diff --git a/src/internal.c b/src/internal.c index ee2f507..f1e2ba3 100644 --- a/src/internal.c +++ b/src/internal.c @@ -148,43 +148,27 @@ const char* IdToName(uint8_t id) } -Buffer* BufferNew(uint32_t size, void* heap) +int BufferInit(Buffer* buffer, uint32_t size, void* heap) { - Buffer* newBuffer = NULL; + if (buffer == NULL) + return WS_BAD_ARGUMENT; if (size <= STATIC_BUFFER_LEN) size = STATIC_BUFFER_LEN; - newBuffer = (Buffer*)WMALLOC(sizeof(Buffer), heap, WOLFSSH_TYPE_BUFFER); - if (newBuffer != NULL) { - WMEMSET(newBuffer, 0, sizeof(Buffer)); - newBuffer->heap = heap; - newBuffer->bufferSz = size; - if (size > STATIC_BUFFER_LEN) { - newBuffer->buffer = (uint8_t*)WMALLOC(size, - heap, WOLFSSH_TYPE_BUFFER); - if (newBuffer->buffer == NULL) { - WFREE(newBuffer, heap, WOLFSSH_TYPE_BUFFER); - newBuffer = NULL; - } - else - newBuffer->dynamicFlag = 1; - } - else - newBuffer->buffer = newBuffer->staticBuffer; + WMEMSET(buffer, 0, sizeof(Buffer)); + buffer->heap = heap; + buffer->bufferSz = size; + if (size > STATIC_BUFFER_LEN) { + buffer->buffer = (uint8_t*)WMALLOC(size, heap, WOLFSSH_TYPE_BUFFER); + if (buffer->buffer == NULL) + return WS_MEMORY_E; + buffer->dynamicFlag = 1; } + else + buffer->buffer = buffer->staticBuffer; - return newBuffer; -} - - -void BufferFree(Buffer* buf) -{ - if (buf != NULL) { - if (buf->dynamicFlag) - WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER); - WFREE(buf, bug->heap, WOLFSSH_TYPE_BUFFER); - } + return WS_SUCCESS; } @@ -228,24 +212,25 @@ int GrowBuffer(Buffer* buf, uint32_t sz, uint32_t usedSz) } -void ShrinkBuffer(Buffer* buf) +void ShrinkBuffer(Buffer* buf, int forcedFree) { if (buf != NULL) { uint32_t usedSz = buf->length - buf->idx; - if (usedSz > STATIC_BUFFER_LEN) + if (!forcedFree && usedSz > STATIC_BUFFER_LEN) return; WLOG(WS_LOG_DEBUG, "Shrinking buffer"); - if (usedSz) + if (!forcedFree && usedSz) WMEMCPY(buf->staticBuffer, buf->buffer + buf->idx, usedSz); - WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER); + if (buf->dynamicFlag) + WFREE(buf->buffer, buf->heap, WOLFSSH_TYPE_BUFFER); buf->dynamicFlag = 0; buf->buffer = buf->staticBuffer; buf->bufferSz = STATIC_BUFFER_LEN; - buf->length = usedSz; + buf->length = forcedFree ? 0 : usedSz; buf->idx = 0; } } @@ -299,12 +284,12 @@ static int GetInputText(WOLFSSH* ssh) int inSz = 255; int in; - if (GrowBuffer(ssh->inputBuffer, inSz, 0) < 0) + if (GrowBuffer(&ssh->inputBuffer, inSz, 0) < 0) return WS_MEMORY_E; do { in = Receive(ssh, - ssh->inputBuffer->buffer + ssh->inputBuffer->length, inSz); + ssh->inputBuffer.buffer + ssh->inputBuffer.length, inSz); if (in == -1) return WS_SOCKET_ERROR_E; @@ -315,12 +300,12 @@ static int GetInputText(WOLFSSH* ssh) if (in > inSz) return WS_RECV_OVERFLOW_E; - ssh->inputBuffer->length += in; + ssh->inputBuffer.length += in; inSz -= in; - if (ssh->inputBuffer->length > 2) { - if (ssh->inputBuffer->buffer[ssh->inputBuffer->length - 2] == '\r' && - ssh->inputBuffer->buffer[ssh->inputBuffer->length - 1] == '\n') { + if (ssh->inputBuffer.length > 2) { + if (ssh->inputBuffer.buffer[ssh->inputBuffer.length - 2] == '\r' && + ssh->inputBuffer.buffer[ssh->inputBuffer.length - 1] == '\n') { gotLine = 1; } @@ -338,26 +323,26 @@ static int SendBuffer(WOLFSSH* ssh) return -1; } - while (ssh->outputBuffer->length > 0) { + while (ssh->outputBuffer.length > 0) { int sent = ssh->ctx->ioSendCb(ssh, - ssh->outputBuffer->buffer + ssh->outputBuffer->idx, - ssh->outputBuffer->length, ssh->ioWriteCtx); + ssh->outputBuffer.buffer + ssh->outputBuffer.idx, + ssh->outputBuffer.length, ssh->ioWriteCtx); if (sent < 0) { return WS_SOCKET_ERROR_E; } - if (sent > (int)ssh->outputBuffer->length) { + if (sent > (int)ssh->outputBuffer.length) { WLOG(WS_LOG_DEBUG, "Out of bounds read"); return WS_SEND_OOB_READ_E; } - ssh->outputBuffer->idx += sent; - ssh->outputBuffer->length -= sent; + ssh->outputBuffer.idx += sent; + ssh->outputBuffer.length -= sent; } - ssh->outputBuffer->idx = 0; - ShrinkBuffer(ssh->outputBuffer); + ssh->outputBuffer.idx = 0; + ShrinkBuffer(&ssh->outputBuffer, 0); return WS_SUCCESS; } @@ -365,9 +350,9 @@ static int SendBuffer(WOLFSSH* ssh) static int SendText(WOLFSSH* ssh, const char* text, uint32_t textLen) { - GrowBuffer(ssh->outputBuffer, textLen, 0); - WMEMCPY(ssh->outputBuffer->buffer, text, textLen); - ssh->outputBuffer->length = textLen; + GrowBuffer(&ssh->outputBuffer, textLen, 0); + WMEMCPY(ssh->outputBuffer.buffer, text, textLen); + ssh->outputBuffer.length = textLen; return SendBuffer(ssh); } @@ -381,8 +366,8 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size) int usedLength; /* check max input length */ - usedLength = ssh->inputBuffer->length - ssh->inputBuffer->idx; - maxLength = ssh->inputBuffer->bufferSz - usedLength; + usedLength = ssh->inputBuffer.length - ssh->inputBuffer.idx; + maxLength = ssh->inputBuffer.bufferSz - usedLength; inSz = (int)(size - usedLength); /* from last partial read */ WLOG(WS_LOG_DEBUG, "GID: size = %d", size); @@ -406,26 +391,26 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size) * buffer and resets idx to 0. */ if (inSz > maxLength) { - if (GrowBuffer(ssh->inputBuffer, size, usedLength) < 0) + if (GrowBuffer(&ssh->inputBuffer, size, usedLength) < 0) return WS_MEMORY_E; } /* Put buffer data at start if not there */ /* Compress the buffer if needed, i.e. buffer idx is non-zero */ - if (usedLength > 0 && ssh->inputBuffer->idx != 0) { - WMEMMOVE(ssh->inputBuffer->buffer, - ssh->inputBuffer->buffer + ssh->inputBuffer->idx, + if (usedLength > 0 && ssh->inputBuffer.idx != 0) { + WMEMMOVE(ssh->inputBuffer.buffer, + ssh->inputBuffer.buffer + ssh->inputBuffer.idx, usedLength); } /* remove processed data */ - ssh->inputBuffer->idx = 0; - ssh->inputBuffer->length = usedLength; + ssh->inputBuffer.idx = 0; + ssh->inputBuffer.length = usedLength; /* read data from network */ do { in = Receive(ssh, - ssh->inputBuffer->buffer + ssh->inputBuffer->length, inSz); + ssh->inputBuffer.buffer + ssh->inputBuffer.length, inSz); if (in == -1) return WS_SOCKET_ERROR_E; @@ -435,10 +420,10 @@ static int GetInputData(WOLFSSH* ssh, uint32_t size) if (in > inSz) return WS_RECV_OVERFLOW_E; - ssh->inputBuffer->length += in; + ssh->inputBuffer.length += in; inSz -= in; - } while (ssh->inputBuffer->length < size); + } while (ssh->inputBuffer.length < size); return 0; } @@ -567,9 +552,9 @@ static int DoKexInit(WOLFSSH* ssh, uint8_t* buf, uint32_t len, uint32_t* idx) static int DoPacket(WOLFSSH* ssh) { - uint8_t* buf = (uint8_t*)ssh->inputBuffer->buffer; - uint32_t idx = ssh->inputBuffer->idx; - uint32_t len = ssh->inputBuffer->length; + uint8_t* buf = (uint8_t*)ssh->inputBuffer.buffer; + uint32_t idx = ssh->inputBuffer.idx; + uint32_t len = ssh->inputBuffer.length; uint8_t msg; uint8_t padSz; @@ -593,7 +578,7 @@ static int DoPacket(WOLFSSH* ssh) } idx += padSz; - ssh->inputBuffer->idx = idx; + ssh->inputBuffer.idx = idx; return WS_SUCCESS; } @@ -617,8 +602,8 @@ int ProcessReply(WOLFSSH* ssh) /* Decrypt first block if encrypted */ case PROCESS_PACKET_LENGTH: - ato32(ssh->inputBuffer->buffer + ssh->inputBuffer->idx, &ssh->curSz); - ssh->inputBuffer->idx += LENGTH_SZ; + ato32(ssh->inputBuffer.buffer + ssh->inputBuffer.idx, &ssh->curSz); + ssh->inputBuffer.idx += LENGTH_SZ; ssh->processReplyState = PROCESS_PACKET_FINISH; case PROCESS_PACKET_FINISH: @@ -662,7 +647,7 @@ int ProcessClientVersion(WOLFSSH* ssh) return error; } - if (WSTRNCASECMP((char*)ssh->inputBuffer->buffer, + if (WSTRNCASECMP((char*)ssh->inputBuffer.buffer, sshIdStr, protoLen) == 0) { ssh->clientState = CLIENT_VERSION_DONE; } @@ -671,14 +656,14 @@ int ProcessClientVersion(WOLFSSH* ssh) return WS_VERSION_E; } - ssh->peerId = (char*)WMALLOC(ssh->inputBuffer->length-1, ssh->ctx->heap, WOLFSSH_ID_TYPE); + ssh->peerId = (char*)WMALLOC(ssh->inputBuffer.length-1, ssh->ctx->heap, WOLFSSH_ID_TYPE); if (ssh->peerId == NULL) { return WS_MEMORY_E; } - WMEMCPY(ssh->peerId, ssh->inputBuffer->buffer, ssh->inputBuffer->length-2); - ssh->peerId[ssh->inputBuffer->length - 1] = 0; - ssh->inputBuffer->idx += ssh->inputBuffer->length; + WMEMCPY(ssh->peerId, ssh->inputBuffer.buffer, ssh->inputBuffer.length-2); + ssh->peerId[ssh->inputBuffer.length - 1] = 0; + ssh->inputBuffer.idx += ssh->inputBuffer.length; WLOG(WS_LOG_DEBUG, "%s", ssh->peerId); return WS_SUCCESS; diff --git a/src/ssh.c b/src/ssh.c index e6f7d2f..ccffded 100644 --- a/src/ssh.c +++ b/src/ssh.c @@ -121,14 +121,7 @@ static WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx) WMEMSET(ssh, 0, sizeof(WOLFSSH)); /* default init to zeros */ - if (ctx) - ssh->ctx = ctx; - else { - WLOG(WS_LOG_ERROR, "Trying to init a wolfSSH w/o wolfSSH_CTX"); - wolfSSH_free(ssh); - return NULL; - } - + ssh->ctx = ctx; ssh->rfd = -1; /* set to invalid */ ssh->wfd = -1; /* set to invalid */ ssh->ioReadCtx = &ssh->rfd; /* prevent invalid access if not correctly */ @@ -142,8 +135,11 @@ static WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx) ssh->pendingPublicKeyId = ID_NONE; ssh->pendingEncryptionId = ID_NONE; ssh->pendingIntegrityId = ID_NONE; - ssh->inputBuffer = BufferNew(0, ctx->heap); - ssh->outputBuffer = BufferNew(0, ctx->heap); + if (BufferInit(&ssh->inputBuffer, 0, ctx->heap) != WS_SUCCESS || + BufferInit(&ssh->outputBuffer, 0, ctx->heap) != WS_SUCCESS) { + wolfSSH_free(ssh); + ssh = NULL; + } return ssh; } @@ -156,6 +152,10 @@ WOLFSSH* wolfSSH_new(WOLFSSH_CTX* ctx) if (ctx) heap = ctx->heap; + else { + WLOG(WS_LOG_ERROR, "Trying to init a wolfSSH w/o wolfSSH_CTX"); + return NULL; + } WLOG(WS_LOG_DEBUG, "Enter wolfSSH_new()"); @@ -175,8 +175,8 @@ static void SshResourceFree(WOLFSSH* ssh, void* heap) WLOG(WS_LOG_DEBUG, "Enter sshResourceFree()"); WFREE(ssh->peerId, heap, WOLFSSH_ID_TYPE); - BufferFree(ssh->inputBuffer); - BufferFree(ssh->outputBuffer); + ShrinkBuffer(&ssh->inputBuffer, 1); + ShrinkBuffer(&ssh->outputBuffer, 1); } diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 3793f6f..5e3c25d 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -91,6 +91,28 @@ WOLFSSH_LOCAL uint8_t NameToId(const char*, uint32_t); WOLFSSH_LOCAL const char* IdToName(uint8_t); +#define STATIC_BUFFER_LEN 16 +/* This is one AES block size. We always grab one + * block size first to decrypt to find the size of + * the rest of the data. */ + + +typedef struct Buffer { + void* heap; /* Heap for allocations */ + uint32_t length; /* total buffer length used */ + uint32_t idx; /* idx to part of length already consumed */ + uint8_t* buffer; /* place holder for actual buffer */ + uint32_t bufferSz; /* current buffer size */ + ALIGN16 uint8_t staticBuffer[STATIC_BUFFER_LEN]; + uint8_t dynamicFlag; /* dynamic memory currently in use */ + uint32_t offset; /* Offset from start of buffer to data. */ +} Buffer; + +WOLFSSH_LOCAL int BufferInit(Buffer*, uint32_t, void*); +WOLFSSH_LOCAL int GrowBuffer(Buffer*, uint32_t, uint32_t); +WOLFSSH_LOCAL void ShrinkBuffer(Buffer* buf, int); + + /* our wolfSSH Context */ struct WOLFSSH_CTX { void* heap; /* heap hint */ @@ -133,8 +155,8 @@ struct WOLFSSH { uint8_t pendingEncryptionId; uint8_t pendingIntegrityId; - struct Buffer* inputBuffer; - struct Buffer* outputBuffer; + Buffer inputBuffer; + Buffer outputBuffer; Sha handshakeHash; uint8_t session_id[SHA_DIGEST_SIZE]; @@ -186,30 +208,6 @@ enum SshMessageIds { }; -#define STATIC_BUFFER_LEN 16 -/* This is one AES block size. We always grab one - * block size first to decrypt to find the size of - * the rest of the data. */ - - -typedef struct Buffer { - void* heap; /* Heap for allocations */ - uint32_t length; /* total buffer length used */ - uint32_t idx; /* idx to part of length already consumed */ - uint8_t* buffer; /* place holder for actual buffer */ - uint32_t bufferSz; /* current buffer size */ - ALIGN16 uint8_t staticBuffer[STATIC_BUFFER_LEN]; - uint8_t dynamicFlag; /* dynamic memory currently in use */ - uint32_t offset; /* Offset from start of buffer to data. */ -} Buffer; - - -WOLFSSH_LOCAL Buffer* BufferNew(uint32_t, void*); -WOLFSSH_LOCAL void BufferFree(Buffer*); -WOLFSSH_LOCAL int GrowBuffer(Buffer*, uint32_t, uint32_t); -WOLFSSH_LOCAL void ShrinkBuffer(Buffer* buf); - - WOLFSSH_LOCAL int ProcessClientVersion(WOLFSSH*); WOLFSSH_LOCAL int SendServerVersion(WOLFSSH*);