Modified the input and output buffers, made them members of the

session object, rather than pointers.
pull/1/head
John Safranek 2014-08-12 17:21:13 -07:00
parent 8cb5fea384
commit f07f623ad6
3 changed files with 95 additions and 112 deletions

View File

@ -148,43 +148,27 @@ const char* IdToName(uint8_t id)
}
Buffer* BufferNew(uint32_t size, void* heap)
int BufferInit(Buffer* buffer, uint32_t size, void* heap)
{
Buffer* newBuffer = NULL;
if (buffer == NULL)
return WS_BAD_ARGUMENT;
if (size <= STATIC_BUFFER_LEN)
size = STATIC_BUFFER_LEN;
newBuffer = (Buffer*)WMALLOC(sizeof(Buffer), heap, WOLFSSH_TYPE_BUFFER);
if (newBuffer != NULL) {
WMEMSET(newBuffer, 0, sizeof(Buffer));
newBuffer->heap = heap;
newBuffer->bufferSz = size;
if (size > STATIC_BUFFER_LEN) {
newBuffer->buffer = (uint8_t*)WMALLOC(size,
heap, WOLFSSH_TYPE_BUFFER);
if (newBuffer->buffer == NULL) {
WFREE(newBuffer, heap, WOLFSSH_TYPE_BUFFER);
newBuffer = NULL;
}
else
newBuffer->dynamicFlag = 1;
}
else
newBuffer->buffer = newBuffer->staticBuffer;
WMEMSET(buffer, 0, sizeof(Buffer));
buffer->heap = heap;
buffer->bufferSz = size;
if (size > STATIC_BUFFER_LEN) {
buffer->buffer = (uint8_t*)WMALLOC(size, heap, WOLFSSH_TYPE_BUFFER);
if (buffer->buffer == NULL)
return WS_MEMORY_E;
buffer->dynamicFlag = 1;
}
else
buffer->buffer = buffer->staticBuffer;
return newBuffer;
}
void BufferFree(Buffer* buf)
{
if (buf != NULL) {
if (buf->dynamicFlag)
WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER);
WFREE(buf, bug->heap, WOLFSSH_TYPE_BUFFER);
}
return WS_SUCCESS;
}
@ -228,24 +212,25 @@ int GrowBuffer(Buffer* buf, uint32_t sz, uint32_t usedSz)
}
void ShrinkBuffer(Buffer* buf)
void ShrinkBuffer(Buffer* buf, int forcedFree)
{
if (buf != NULL) {
uint32_t usedSz = buf->length - buf->idx;
if (usedSz > STATIC_BUFFER_LEN)
if (!forcedFree && usedSz > STATIC_BUFFER_LEN)
return;
WLOG(WS_LOG_DEBUG, "Shrinking buffer");
if (usedSz)
if (!forcedFree && usedSz)
WMEMCPY(buf->staticBuffer, buf->buffer + buf->idx, usedSz);
WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER);
if (buf->dynamicFlag)
WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER);
buf->dynamicFlag = 0;
buf->buffer = buf->staticBuffer;
buf->bufferSz = STATIC_BUFFER_LEN;
buf->length = usedSz;
buf->length = forcedFree ? 0 : usedSz;
buf->idx = 0;
}
}
@ -299,12 +284,12 @@ static int GetInputText(WOLFSSH* ssh)
int inSz = 255;
int in;
if (GrowBuffer(ssh->inputBuffer, inSz, 0) < 0)
if (GrowBuffer(&ssh->inputBuffer, inSz, 0) < 0)
return WS_MEMORY_E;
do {
in = Receive(ssh,
ssh->inputBuffer->buffer + ssh->inputBuffer->length, inSz);
ssh->inputBuffer.buffer + ssh->inputBuffer.length, inSz);
if (in == -1)
return WS_SOCKET_ERROR_E;
@ -315,12 +300,12 @@ static int GetInputText(WOLFSSH* ssh)
if (in > inSz)
return WS_RECV_OVERFLOW_E;
ssh->inputBuffer->length += in;
ssh->inputBuffer.length += in;
inSz -= in;
if (ssh->inputBuffer->length > 2) {
if (ssh->inputBuffer->buffer[ssh->inputBuffer->length - 2] == '\r' &&
ssh->inputBuffer->buffer[ssh->inputBuffer->length - 1] == '\n') {
if (ssh->inputBuffer.length > 2) {
if (ssh->inputBuffer.buffer[ssh->inputBuffer.length - 2] == '\r' &&
ssh->inputBuffer.buffer[ssh->inputBuffer.length - 1] == '\n') {
gotLine = 1;
}
@ -338,26 +323,26 @@ static int SendBuffer(WOLFSSH* ssh)
return -1;
}
while (ssh->outputBuffer->length > 0) {
while (ssh->outputBuffer.length > 0) {
int sent = ssh->ctx->ioSendCb(ssh,
ssh->outputBuffer->buffer + ssh->outputBuffer->idx,
ssh->outputBuffer->length, ssh->ioWriteCtx);
ssh->outputBuffer.buffer + ssh->outputBuffer.idx,
ssh->outputBuffer.length, ssh->ioWriteCtx);
if (sent < 0) {
return WS_SOCKET_ERROR_E;
}
if (sent > (int)ssh->outputBuffer->length) {
if (sent > (int)ssh->outputBuffer.length) {
WLOG(WS_LOG_DEBUG, "Out of bounds read");
return WS_SEND_OOB_READ_E;
}
ssh->outputBuffer->idx += sent;
ssh->outputBuffer->length -= sent;
ssh->outputBuffer.idx += sent;
ssh->outputBuffer.length -= sent;
}
ssh->outputBuffer->idx = 0;
ShrinkBuffer(ssh->outputBuffer);
ssh->outputBuffer.idx = 0;
ShrinkBuffer(&ssh->outputBuffer, 0);
return WS_SUCCESS;
}
@ -365,9 +350,9 @@ static int SendBuffer(WOLFSSH* ssh)
static int SendText(WOLFSSH* ssh, const char* text, uint32_t textLen)
{
GrowBuffer(ssh->outputBuffer, textLen, 0);
WMEMCPY(ssh->outputBuffer->buffer, text, textLen);
ssh->outputBuffer->length = textLen;
GrowBuffer(&ssh->outputBuffer, textLen, 0);
WMEMCPY(ssh->outputBuffer.buffer, text, textLen);
ssh->outputBuffer.length = textLen;
return SendBuffer(ssh);
}
@ -381,8 +366,8 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size)
int usedLength;
/* check max input length */
usedLength = ssh->inputBuffer->length - ssh->inputBuffer->idx;
maxLength = ssh->inputBuffer->bufferSz - usedLength;
usedLength = ssh->inputBuffer.length - ssh->inputBuffer.idx;
maxLength = ssh->inputBuffer.bufferSz - usedLength;
inSz = (int)(size - usedLength); /* from last partial read */
WLOG(WS_LOG_DEBUG, "GID: size = %d", size);
@ -406,26 +391,26 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size)
* buffer and resets idx to 0.
*/
if (inSz > maxLength) {
if (GrowBuffer(ssh->inputBuffer, size, usedLength) < 0)
if (GrowBuffer(&ssh->inputBuffer, size, usedLength) < 0)
return WS_MEMORY_E;
}
/* Put buffer data at start if not there */
/* Compress the buffer if needed, i.e. buffer idx is non-zero */
if (usedLength > 0 && ssh->inputBuffer->idx != 0) {
WMEMMOVE(ssh->inputBuffer->buffer,
ssh->inputBuffer->buffer + ssh->inputBuffer->idx,
if (usedLength > 0 && ssh->inputBuffer.idx != 0) {
WMEMMOVE(ssh->inputBuffer.buffer,
ssh->inputBuffer.buffer + ssh->inputBuffer.idx,
usedLength);
}
/* remove processed data */
ssh->inputBuffer->idx = 0;
ssh->inputBuffer->length = usedLength;
ssh->inputBuffer.idx = 0;
ssh->inputBuffer.length = usedLength;
/* read data from network */
do {
in = Receive(ssh,
ssh->inputBuffer->buffer + ssh->inputBuffer->length, inSz);
ssh->inputBuffer.buffer + ssh->inputBuffer.length, inSz);
if (in == -1)
return WS_SOCKET_ERROR_E;
@ -435,10 +420,10 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size)
if (in > inSz)
return WS_RECV_OVERFLOW_E;
ssh->inputBuffer->length += in;
ssh->inputBuffer.length += in;
inSz -= in;
} while (ssh->inputBuffer->length < size);
} while (ssh->inputBuffer.length < size);
return 0;
}
@ -567,9 +552,9 @@ static int DoKexInit(WOLFSSH* ssh, uint8_t* buf, uint32_t len, uint32_t* idx)
static int DoPacket(WOLFSSH* ssh)
{
uint8_t* buf = (uint8_t*)ssh->inputBuffer->buffer;
uint32_t idx = ssh->inputBuffer->idx;
uint32_t len = ssh->inputBuffer->length;
uint8_t* buf = (uint8_t*)ssh->inputBuffer.buffer;
uint32_t idx = ssh->inputBuffer.idx;
uint32_t len = ssh->inputBuffer.length;
uint8_t msg;
uint8_t padSz;
@ -593,7 +578,7 @@ static int DoPacket(WOLFSSH* ssh)
}
idx += padSz;
ssh->inputBuffer->idx = idx;
ssh->inputBuffer.idx = idx;
return WS_SUCCESS;
}
@ -617,8 +602,8 @@ int ProcessReply(WOLFSSH* ssh)
/* Decrypt first block if encrypted */
case PROCESS_PACKET_LENGTH:
ato32(ssh->inputBuffer->buffer + ssh->inputBuffer->idx, &ssh->curSz);
ssh->inputBuffer->idx += LENGTH_SZ;
ato32(ssh->inputBuffer.buffer + ssh->inputBuffer.idx, &ssh->curSz);
ssh->inputBuffer.idx += LENGTH_SZ;
ssh->processReplyState = PROCESS_PACKET_FINISH;
case PROCESS_PACKET_FINISH:
@ -662,7 +647,7 @@ int ProcessClientVersion(WOLFSSH* ssh)
return error;
}
if (WSTRNCASECMP((char*)ssh->inputBuffer->buffer,
if (WSTRNCASECMP((char*)ssh->inputBuffer.buffer,
sshIdStr, protoLen) == 0) {
ssh->clientState = CLIENT_VERSION_DONE;
}
@ -671,14 +656,14 @@ int ProcessClientVersion(WOLFSSH* ssh)
return WS_VERSION_E;
}
ssh->peerId = (char*)WMALLOC(ssh->inputBuffer->length-1, ssh->ctx->heap, WOLFSSH_ID_TYPE);
ssh->peerId = (char*)WMALLOC(ssh->inputBuffer.length-1, ssh->ctx->heap, WOLFSSH_ID_TYPE);
if (ssh->peerId == NULL) {
return WS_MEMORY_E;
}
WMEMCPY(ssh->peerId, ssh->inputBuffer->buffer, ssh->inputBuffer->length-2);
ssh->peerId[ssh->inputBuffer->length - 1] = 0;
ssh->inputBuffer->idx += ssh->inputBuffer->length;
WMEMCPY(ssh->peerId, ssh->inputBuffer.buffer, ssh->inputBuffer.length-2);
ssh->peerId[ssh->inputBuffer.length - 1] = 0;
ssh->inputBuffer.idx += ssh->inputBuffer.length;
WLOG(WS_LOG_DEBUG, "%s", ssh->peerId);
return WS_SUCCESS;

View File

@ -121,14 +121,7 @@ static WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx)
WMEMSET(ssh, 0, sizeof(WOLFSSH)); /* default init to zeros */
if (ctx)
ssh->ctx = ctx;
else {
WLOG(WS_LOG_ERROR, "Trying to init a wolfSSH w/o wolfSSH_CTX");
wolfSSH_free(ssh);
return NULL;
}
ssh->ctx = ctx;
ssh->rfd = -1; /* set to invalid */
ssh->wfd = -1; /* set to invalid */
ssh->ioReadCtx = &ssh->rfd; /* prevent invalid access if not correctly */
@ -142,8 +135,11 @@ static WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx)
ssh->pendingPublicKeyId = ID_NONE;
ssh->pendingEncryptionId = ID_NONE;
ssh->pendingIntegrityId = ID_NONE;
ssh->inputBuffer = BufferNew(0, ctx->heap);
ssh->outputBuffer = BufferNew(0, ctx->heap);
if (BufferInit(&ssh->inputBuffer, 0, ctx->heap) != WS_SUCCESS ||
BufferInit(&ssh->outputBuffer, 0, ctx->heap) != WS_SUCCESS) {
wolfSSH_free(ssh);
ssh = NULL;
}
return ssh;
}
@ -156,6 +152,10 @@ WOLFSSH* wolfSSH_new(WOLFSSH_CTX* ctx)
if (ctx)
heap = ctx->heap;
else {
WLOG(WS_LOG_ERROR, "Trying to init a wolfSSH w/o wolfSSH_CTX");
return NULL;
}
WLOG(WS_LOG_DEBUG, "Enter wolfSSH_new()");
@ -175,8 +175,8 @@ static void SshResourceFree(WOLFSSH* ssh, void* heap)
WLOG(WS_LOG_DEBUG, "Enter sshResourceFree()");
WFREE(ssh->peerId, heap, WOLFSSH_ID_TYPE);
BufferFree(ssh->inputBuffer);
BufferFree(ssh->outputBuffer);
ShrinkBuffer(&ssh->inputBuffer, 1);
ShrinkBuffer(&ssh->outputBuffer, 1);
}

View File

@ -91,6 +91,28 @@ WOLFSSH_LOCAL uint8_t NameToId(const char*, uint32_t);
WOLFSSH_LOCAL const char* IdToName(uint8_t);
#define STATIC_BUFFER_LEN 16
/* This is one AES block size. We always grab one
* block size first to decrypt to find the size of
* the rest of the data. */
typedef struct Buffer {
void* heap; /* Heap for allocations */
uint32_t length; /* total buffer length used */
uint32_t idx; /* idx to part of length already consumed */
uint8_t* buffer; /* place holder for actual buffer */
uint32_t bufferSz; /* current buffer size */
ALIGN16 uint8_t staticBuffer[STATIC_BUFFER_LEN];
uint8_t dynamicFlag; /* dynamic memory currently in use */
uint32_t offset; /* Offset from start of buffer to data. */
} Buffer;
WOLFSSH_LOCAL int BufferInit(Buffer*, uint32_t, void*);
WOLFSSH_LOCAL int GrowBuffer(Buffer*, uint32_t, uint32_t);
WOLFSSH_LOCAL void ShrinkBuffer(Buffer* buf, int);
/* our wolfSSH Context */
struct WOLFSSH_CTX {
void* heap; /* heap hint */
@ -133,8 +155,8 @@ struct WOLFSSH {
uint8_t pendingEncryptionId;
uint8_t pendingIntegrityId;
struct Buffer* inputBuffer;
struct Buffer* outputBuffer;
Buffer inputBuffer;
Buffer outputBuffer;
Sha handshakeHash;
uint8_t session_id[SHA_DIGEST_SIZE];
@ -186,30 +208,6 @@ enum SshMessageIds {
};
#define STATIC_BUFFER_LEN 16
/* This is one AES block size. We always grab one
* block size first to decrypt to find the size of
* the rest of the data. */
typedef struct Buffer {
void* heap; /* Heap for allocations */
uint32_t length; /* total buffer length used */
uint32_t idx; /* idx to part of length already consumed */
uint8_t* buffer; /* place holder for actual buffer */
uint32_t bufferSz; /* current buffer size */
ALIGN16 uint8_t staticBuffer[STATIC_BUFFER_LEN];
uint8_t dynamicFlag; /* dynamic memory currently in use */
uint32_t offset; /* Offset from start of buffer to data. */
} Buffer;
WOLFSSH_LOCAL Buffer* BufferNew(uint32_t, void*);
WOLFSSH_LOCAL void BufferFree(Buffer*);
WOLFSSH_LOCAL int GrowBuffer(Buffer*, uint32_t, uint32_t);
WOLFSSH_LOCAL void ShrinkBuffer(Buffer* buf);
WOLFSSH_LOCAL int ProcessClientVersion(WOLFSSH*);
WOLFSSH_LOCAL int SendServerVersion(WOLFSSH*);