diff --git a/src/internal.c b/src/internal.c index 2ce8f19e..f90755c5 100644 --- a/src/internal.c +++ b/src/internal.c @@ -139,6 +139,9 @@ const char* GetErrorString(int err) case WS_CRYPTO_FAILED: return "crypto action failed"; + case WS_INVALID_STATE_E: + return "invalid state"; + default: return "Unknown error code"; } @@ -146,6 +149,26 @@ const char* GetErrorString(int err) } +static int wsHighwater(byte dir, void* ctx) +{ + int ret = WS_SUCCESS; + + (void)dir; + + if (ctx) { + WOLFSSH* ssh = (WOLFSSH*)ctx; + + WLOG(WS_LOG_DEBUG, "HIGHWATER MARK: (%u) %s\n", + wolfSSH_GetHighwater(ssh), + (dir == WOLFSSH_HWSIDE_RECEIVE) ? "receive" : "transmit"); + + ret = wolfSSH_TriggerKeyExchange(ssh); + } + + return ret; +} + + WOLFSSH_CTX* CtxInit(WOLFSSH_CTX* ctx, void* heap) { WLOG(WS_LOG_DEBUG, "Entering CtxInit()"); @@ -162,7 +185,8 @@ WOLFSSH_CTX* CtxInit(WOLFSSH_CTX* ctx, void* heap) ctx->ioRecvCb = wsEmbedRecv; ctx->ioSendCb = wsEmbedSend; #endif /* WOLFSSH_USER_IO */ - ctx->countHighwater = DEFAULT_COUNT_HIGHWATER; + ctx->highwaterMark = DEFAULT_HIGHWATER_MARK; + ctx->highwaterCb = wsHighwater; return ctx; } @@ -213,9 +237,11 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx) ssh->wfd = -1; /* set to invalid */ ssh->ioReadCtx = &ssh->rfd; /* prevent invalid access if not correctly */ ssh->ioWriteCtx = &ssh->wfd; /* set */ - ssh->countHighwater = ctx->countHighwater; + ssh->highwaterMark = ctx->highwaterMark; + ssh->highwaterCtx = (void*)ssh; ssh->acceptState = ACCEPT_BEGIN; ssh->clientState = CLIENT_BEGIN; + ssh->keyingState = KEYING_UNKEYED; ssh->nextChannel = DEFAULT_NEXT_CHANNEL; ssh->blockSz = MIN_BLOCK_SZ; ssh->encryptId = ID_NONE; @@ -265,6 +291,9 @@ void SshResourceFree(WOLFSSH* ssh, void* heap) if (ssh->userName) { WFREE(ssh->userName, heap, DYNTYPE_STRING); } + if (ssh->clientId) { + WFREE(ssh->clientId, heap, DYNTYPE_STRING); + } if (ssh->channelList) { WOLFSSH_CHANNEL* cur = ssh->channelList; WOLFSSH_CHANNEL* next; @@ -1011,6 +1040,9 @@ static int DoKexInit(WOLFSSH* ssh, uint8_t* buf, uint32_t len, uint32_t* idx) if (ssh == NULL || buf == NULL || len == 0 || idx == NULL) ret = WS_BAD_ARGUMENT; + + if (ssh->keyingState != KEYING_UNKEYED && ssh->keyingState != KEYING_KEYED) + ret = WS_INVALID_STATE_E; /* * I don't need to save what the client sends here. I should decode * each list into a local array of IDs, and pick the one the peer is @@ -2035,6 +2067,11 @@ static int DoUserAuthRequest(WOLFSSH* ssh, authData.sf.publicKey.dataToSign = buf + *idx; ret = DoUserAuthRequestPublicKey(ssh, &authData, buf, len, &begin); } +#ifdef WOLFSSH_ALLOW_USERAUTH_NONE + else if (authNameId == ID_NONE) { + ssh->clientState = CLIENT_USERAUTH_DONE; + } +#endif else { WLOG(WS_LOG_DEBUG, "invalid userauth type: %s", IdToName(authNameId)); @@ -2413,10 +2450,13 @@ static INLINE int Encrypt(WOLFSSH* ssh, uint8_t* cipher, const uint8_t* input, } ssh->txCount += sz; - if (ssh->countHighwater && ssh->txCount > ssh->countHighwater) { + if (ssh->highwaterMark && !ssh->highwaterFlag && + ssh->txCount > ssh->highwaterMark) { + WLOG(WS_LOG_DEBUG, "Transmit over high water mark"); if (ssh->ctx->highwaterCb) ssh->ctx->highwaterCb(WOLFSSH_HWSIDE_TRANSMIT, ssh->highwaterCtx); + ssh->highwaterFlag = 1; } return ret; @@ -2450,10 +2490,13 @@ static INLINE int Decrypt(WOLFSSH* ssh, uint8_t* plain, const uint8_t* input, } ssh->rxCount += sz; - if (ssh->countHighwater && ssh->rxCount > ssh->countHighwater) { + if (ssh->highwaterMark && !ssh->highwaterFlag && + ssh->rxCount > ssh->highwaterMark) { + WLOG(WS_LOG_DEBUG, "Receive over high water mark"); if (ssh->ctx->highwaterCb) ssh->ctx->highwaterCb(WOLFSSH_HWSIDE_RECEIVE, ssh->highwaterCtx); + ssh->highwaterFlag = 1; } return ret; @@ -2679,16 +2722,17 @@ static const char sshIdStr[] = "SSH-2.0-wolfSSHv" int ProcessClientVersion(WOLFSSH* ssh) { - int error; - uint32_t protoLen = 7; /* Length of the SSH-2.0 portion of the ID str */ - uint8_t scratch[LENGTH_SZ]; + int ret; + uint32_t clientIdSz; - if ( (error = GetInputText(ssh)) < 0) { + if ( (ret = GetInputText(ssh)) < 0) { WLOG(WS_LOG_DEBUG, "get input text failed"); - return error; + return ret; } - if (WSTRNCASECMP((char*)ssh->inputBuffer.buffer, sshIdStr, protoLen) == 0) { + if (WSTRNCASECMP((char*)ssh->inputBuffer.buffer, + sshIdStr, SSH_PROTO_SZ) == 0) { + ssh->clientState = CLIENT_VERSION_DONE; } else { @@ -2696,38 +2740,58 @@ int ProcessClientVersion(WOLFSSH* ssh) return WS_VERSION_E; } - c32toa(ssh->inputBuffer.length - 2, scratch); - wc_ShaUpdate(&ssh->handshake->hash, scratch, LENGTH_SZ); - wc_ShaUpdate(&ssh->handshake->hash, ssh->inputBuffer.buffer, - ssh->inputBuffer.length - 2); + clientIdSz = ssh->inputBuffer.length - SSH_PROTO_EOL_SZ; + + ssh->clientId = (uint8_t*)WMALLOC(clientIdSz, + ssh->ctx->heap, DYNTYPE_STRING); + if (ssh->clientId == NULL) + ret = WS_MEMORY_E; + else { + uint8_t flatSz[LENGTH_SZ]; + + /* Store the client version string. Will need it later during rekey */ + WMEMCPY(ssh->clientId, ssh->inputBuffer.buffer, clientIdSz); + ssh->clientIdSz = clientIdSz; + c32toa(clientIdSz, flatSz); + ret = wc_ShaUpdate(&ssh->handshake->hash, flatSz, LENGTH_SZ); + } + + if (ret == WS_SUCCESS) + ret = wc_ShaUpdate(&ssh->handshake->hash, ssh->inputBuffer.buffer, + clientIdSz); + ssh->inputBuffer.idx += ssh->inputBuffer.length; - return WS_SUCCESS; + return ret; } int SendServerVersion(WOLFSSH* ssh) { int ret = WS_SUCCESS; - uint32_t sshIdStrSz = (uint32_t)WSTRLEN(sshIdStr); - uint8_t sshIdStrSzFlat[LENGTH_SZ]; + uint32_t sshIdStrSz; if (ssh == NULL) ret = WS_BAD_ARGUMENT; if (ret == WS_SUCCESS) { WLOG(WS_LOG_DEBUG, "%s", sshIdStr); - ret = SendText(ssh, sshIdStr, (uint32_t)WSTRLEN(sshIdStr)); + sshIdStrSz = (uint32_t)WSTRLEN(sshIdStr); + ret = SendText(ssh, sshIdStr, sshIdStrSz); } if (ret == WS_SUCCESS) { - sshIdStrSz -= 2; /* Remove the CRLF */ + uint8_t sshIdStrSzFlat[LENGTH_SZ]; + + sshIdStrSz -= SSH_PROTO_EOL_SZ; c32toa(sshIdStrSz, sshIdStrSzFlat); - wc_ShaUpdate(&ssh->handshake->hash, sshIdStrSzFlat, LENGTH_SZ); - wc_ShaUpdate(&ssh->handshake->hash, - (const uint8_t*)sshIdStr, sshIdStrSz); + ret = wc_ShaUpdate(&ssh->handshake->hash, sshIdStrSzFlat, LENGTH_SZ); } + if (ret == WS_SUCCESS) + ret = wc_ShaUpdate(&ssh->handshake->hash, + (const uint8_t*)sshIdStr, sshIdStrSz); + return ret; } @@ -2748,7 +2812,7 @@ static int PreparePacket(WOLFSSH* ssh, uint32_t payloadSz) /* Minimum value for paddingSz is 4. */ paddingSz = ssh->blockSz - (LENGTH_SZ + PAD_LENGTH_SZ + payloadSz) % ssh->blockSz; - if (paddingSz < 4) + if (paddingSz < MIN_PAD_LENGTH) paddingSz += ssh->blockSz; ssh->paddingSz = paddingSz; packetSz = PAD_LENGTH_SZ + payloadSz + paddingSz; @@ -3186,6 +3250,7 @@ int SendNewKeys(WOLFSSH* ssh) } ssh->txCount = 0; + ssh->highwaterFlag = 0; } return ret; diff --git a/src/ssh.c b/src/ssh.c index b751073f..8e3a4ecf 100644 --- a/src/ssh.c +++ b/src/ssh.c @@ -165,7 +165,7 @@ int wolfSSH_SetHighwater(WOLFSSH* ssh, uint32_t highwater) WLOG(WS_LOG_DEBUG, "Entering wolfSSH_SetHighwater()"); if (ssh) { - ssh->countHighwater = highwater; + ssh->highwaterMark = highwater; return WS_SUCCESS; } @@ -179,7 +179,7 @@ uint32_t wolfSSH_GetHighwater(WOLFSSH* ssh) WLOG(WS_LOG_DEBUG, "Entering wolfSSH_GetHighwater()"); if (ssh) - return ssh->countHighwater; + return ssh->highwaterMark; return 0; } @@ -191,7 +191,7 @@ void wolfSSH_SetHighwaterCb(WOLFSSH_CTX* ctx, uint32_t highwater, WLOG(WS_LOG_DEBUG, "Entering wolfSSH_SetHighwaterCb()"); if (ctx) { - ctx->countHighwater = highwater; + ctx->highwaterMark = highwater; ctx->highwaterCb = cb; } } @@ -269,61 +269,21 @@ int wolfSSH_accept(WOLFSSH* ssh) WLOG(WS_LOG_DEBUG, acceptState, "SERVER_VERSION_SENT"); case ACCEPT_SERVER_VERSION_SENT: - while (ssh->clientState < CLIENT_KEXINIT_DONE) { + while (ssh->keyingState < KEYING_KEYED) { if ( (ssh->error = ProcessReply(ssh)) < WS_SUCCESS) { WLOG(WS_LOG_DEBUG, acceptError, "SERVER_VERSION_SENT", ssh->error); return WS_FATAL_ERROR; } } - ssh->acceptState = ACCEPT_CLIENT_KEXINIT_DONE; - WLOG(WS_LOG_DEBUG, acceptState, "CLIENT_KEXINIT_DONE"); + ssh->acceptState = ACCEPT_KEYED; + WLOG(WS_LOG_DEBUG, acceptState, "KEYED"); - case ACCEPT_CLIENT_KEXINIT_DONE: - if ( (ssh->error = SendKexInit(ssh)) < WS_SUCCESS) { - WLOG(WS_LOG_DEBUG, acceptError, - "CLIENT_KEXINIT_DONE", ssh->error); - return WS_FATAL_ERROR; - } - ssh->acceptState = ACCEPT_SERVER_KEXINIT_SENT; - WLOG(WS_LOG_DEBUG, acceptState, "SERVER_KEXINIT_SENT"); - - case ACCEPT_SERVER_KEXINIT_SENT: - while (ssh->clientState < CLIENT_KEXDH_INIT_DONE) { - if ( (ssh->error = ProcessReply(ssh)) < 0) { - WLOG(WS_LOG_DEBUG, acceptError, - "SERVER_KEXINIT_SENT", ssh->error); - return WS_FATAL_ERROR; - } - } - ssh->acceptState = ACCEPT_CLIENT_KEXDH_INIT_DONE; - WLOG(WS_LOG_DEBUG, acceptState, "CLIENT_KEXDH_INIT_DONE"); - - case ACCEPT_CLIENT_KEXDH_INIT_DONE: - if ( (ssh->error = SendKexDhReply(ssh)) < WS_SUCCESS) { - WLOG(WS_LOG_DEBUG, acceptError, - "CLIENT_KEXDH_INIT_DONE", ssh->error); - return WS_FATAL_ERROR; - } - ssh->acceptState = ACCEPT_SERVER_KEXDH_REPLY_SENT; - WLOG(WS_LOG_DEBUG, acceptState, "SERVER_KEXDH_REPLY_SENT"); - - case ACCEPT_SERVER_KEXDH_REPLY_SENT: - while (ssh->clientState < CLIENT_USING_KEYS) { - if ( (ssh->error = ProcessReply(ssh)) < 0) { - WLOG(WS_LOG_DEBUG, acceptError, - "SERVER_KEXDH_REPLY_SENT", ssh->error); - return WS_FATAL_ERROR; - } - } - ssh->acceptState = ACCEPT_USING_KEYS; - WLOG(WS_LOG_DEBUG, acceptState, "USING_KEYS"); - - case ACCEPT_USING_KEYS: + case ACCEPT_KEYED: while (ssh->clientState < CLIENT_USERAUTH_REQUEST_DONE) { if ( (ssh->error = ProcessReply(ssh)) < 0) { WLOG(WS_LOG_DEBUG, acceptError, - "USING_KEYS", ssh->error); + "KEYED", ssh->error); return WS_FATAL_ERROR; } } @@ -385,6 +345,13 @@ int wolfSSH_accept(WOLFSSH* ssh) } +int wolfSSH_TriggerKeyExchange(WOLFSSH* ssh) +{ + (void)ssh; + return WS_SUCCESS; +} + + int wolfSSH_stream_read(WOLFSSH* ssh, uint8_t* buf, uint32_t bufSz) { Buffer* inputBuffer; diff --git a/wolfssh/error.h b/wolfssh/error.h index 559290ac..a1105875 100644 --- a/wolfssh/error.h +++ b/wolfssh/error.h @@ -66,7 +66,8 @@ enum WS_ErrorCodes { WS_INVALID_CHANTYPE = -26, /* invalid channel type */ WS_INVALID_CHANID = -27, WS_INVALID_USERNAME = -28, - WS_CRYPTO_FAILED = -29 /* crypto action failed */ + WS_CRYPTO_FAILED = -29, /* crypto action failed */ + WS_INVALID_STATE_E = -30 }; diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 00935d0f..f604980a 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -97,11 +97,14 @@ enum { #define COOKIE_SZ 16 #define LENGTH_SZ 4 #define PAD_LENGTH_SZ 1 +#define MIN_PAD_LENGTH 4 #define BOOLEAN_SZ 1 #define MSG_ID_SZ 1 #define SHA1_96_SZ 12 #define UINT32_SZ 4 -#define DEFAULT_COUNT_HIGHWATER ((1024 * 1024 * 1024) - (32 * 1024)) +#define SSH_PROTO_SZ 7 /* "SSH-2.0" */ +#define SSH_PROTO_EOL_SZ 2 /* Just the CRLF */ +#define DEFAULT_HIGHWATER_MARK ((1024 * 1024 * 1024) - (32 * 1024)) #ifndef DEFAULT_WINDOW_SZ #define DEFAULT_WINDOW_SZ (1024 * 1024) #endif @@ -145,7 +148,7 @@ struct WOLFSSH_CTX { uint8_t* privateKey; /* Owned by CTX */ uint32_t privateKeySz; - uint32_t countHighwater; + uint32_t highwaterMark; }; @@ -179,6 +182,8 @@ typedef struct HandshakeInfo { Sha hash; uint8_t e[257]; /* May have a leading zero, for unsigned. */ uint32_t eSz; + uint8_t* serverKexInit; /* Used for server initiated rekey. */ + uint32_t serverKeyInitSz; } HandshakeInfo; @@ -194,7 +199,8 @@ struct WOLFSSH { int wflags; /* optional write flags */ uint32_t txCount; uint32_t rxCount; - uint32_t countHighwater; + uint32_t highwaterMark; + uint8_t highwaterFlag; /* Set when highwater CB called */ void* highwaterCtx; uint32_t curSz; uint32_t seq; @@ -204,6 +210,7 @@ struct WOLFSSH { uint8_t acceptState; uint8_t clientState; uint8_t processReplyState; + uint8_t keyingState; uint8_t connReset; uint8_t isClosed; @@ -245,6 +252,8 @@ struct WOLFSSH { uint32_t userNameSz; uint8_t* pkBlob; uint32_t pkBlobSz; + uint8_t* clientId; /* Save for rekey */ + uint32_t clientIdSz; }; @@ -252,7 +261,6 @@ struct WOLFSSH_CHANNEL { uint8_t channelType; uint32_t channel; uint32_t windowSz; - uint32_t highwaterMark; uint32_t maxPacketSz; uint32_t peerChannel; uint32_t peerWindowSz; @@ -310,15 +318,27 @@ WOLFSSH_LOCAL int GenerateKey(uint8_t, uint8_t, uint8_t*, uint32_t, const uint8_t*, uint32_t); +enum KeyingStates { + KEYING_UNKEYED = 0, + + KEYING_KEXINIT_SENT, + KEYING_KEXINIT_RECV, + KEYING_KEXINIT_DONE, + + KEYING_KEXDH_INIT_RECV, + KEYING_KEXDH_DONE, + + KEYING_USING_KEYS_SENT, + KEYING_USING_KEYS_RECV, + KEYING_KEYED +}; + + enum AcceptStates { ACCEPT_BEGIN = 0, ACCEPT_CLIENT_VERSION_DONE, ACCEPT_SERVER_VERSION_SENT, - ACCEPT_CLIENT_KEXINIT_DONE, - ACCEPT_SERVER_KEXINIT_SENT, - ACCEPT_CLIENT_KEXDH_INIT_DONE, - ACCEPT_SERVER_KEXDH_REPLY_SENT, - ACCEPT_USING_KEYS, + ACCEPT_KEYED, ACCEPT_CLIENT_USERAUTH_REQUEST_DONE, ACCEPT_SERVER_USERAUTH_ACCEPT_SENT, ACCEPT_CLIENT_USERAUTH_DONE, diff --git a/wolfssh/ssh.h b/wolfssh/ssh.h index a6f8827b..b3368f43 100644 --- a/wolfssh/ssh.h +++ b/wolfssh/ssh.h @@ -133,6 +133,7 @@ WOLFSSH_API int wolfSSH_stream_send(WOLFSSH*, uint8_t*, uint32_t); WOLFSSH_API int wolfSSH_channel_read(WOLFSSH_CHANNEL*, uint8_t*, uint32_t); WOLFSSH_API int wolfSSH_channel_send(WOLFSSH_CHANNEL*, uint8_t*, uint32_t); WOLFSSH_API int wolfSSH_worker(WOLFSSH*); +WOLFSSH_API int wolfSSH_TriggerKeyExchange(WOLFSSH*); WOLFSSH_API int wolfSSH_KDF(uint8_t, uint8_t, uint8_t*, uint32_t, const uint8_t*, uint32_t, const uint8_t*, uint32_t,