diff --git a/cyassl/internal.h b/cyassl/internal.h index 1d81f043a..5c69e622a 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1363,6 +1363,30 @@ enum ClientCertificateType { enum CipherType { stream, block, aead }; +#ifdef CYASSL_DTLS + + #ifdef WORD64_AVAILABLE + typedef word64 DtlsSeq; + #else + typedef word32 DtlsSeq; + #endif + #define DTLS_SEQ_BITS (sizeof(DtlsSeq) * CHAR_BIT) + + typedef struct DtlsState { + DtlsSeq window; /* Sliding window for current epoch */ + word16 nextEpoch; /* Expected epoch in next record */ + word32 nextSeq; /* Expected sequence in next record */ + + word16 curEpoch; /* Received epoch in current record */ + word32 curSeq; /* Received sequence in current record */ + + DtlsSeq prevWindow; /* Sliding window for old epoch */ + word32 prevSeq; /* Next sequence in allowed old epoch */ + } DtlsState; + +#endif /* CYASSL_DTLS */ + + /* keys and secrets */ typedef struct Keys { byte client_write_MAC_secret[MAX_DIGEST_SIZE]; /* max sizes */ @@ -1381,15 +1405,13 @@ typedef struct Keys { word32 sequence_number; #ifdef CYASSL_DTLS - word32 dtls_sequence_number; - word32 dtls_peer_sequence_number; - word32 dtls_expected_peer_sequence_number; - word16 dtls_handshake_number; + DtlsState dtls_state; /* Peer's state */ word16 dtls_peer_handshake_number; word16 dtls_expected_peer_handshake_number; - word16 dtls_epoch; - word16 dtls_peer_epoch; - word16 dtls_expected_peer_epoch; + + word16 dtls_epoch; /* Current tx epoch */ + word32 dtls_sequence_number; /* Current tx sequence */ + word16 dtls_handshake_number; /* Current tx handshake seq */ #endif word32 encryptSz; /* last size of encrypted data */ diff --git a/src/internal.c b/src/internal.c index 85e64f962..dfb4a79b8 100644 --- a/src/internal.c +++ b/src/internal.c @@ -87,6 +87,13 @@ CYASSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS #endif #endif + +#ifdef CYASSL_DTLS + static int DtlsCheckWindow(DtlsState* state); + static int DtlsUpdateWindow(DtlsState* state); +#endif + + typedef enum { doProcessInit = 0, #ifndef NO_CYASSL_SERVER @@ -1421,6 +1428,9 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) #ifdef CYASSL_DTLS ssl->IOCB_CookieCtx = NULL; /* we don't use for default cb */ ssl->dtls_expected_rx = MAX_MTU; + ssl->keys.dtls_state.window = 0; + ssl->keys.dtls_state.nextEpoch = 0; + ssl->keys.dtls_state.nextSeq = 0; #endif #ifndef NO_OLD_TLS @@ -1478,13 +1488,13 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) #ifdef CYASSL_DTLS ssl->keys.dtls_sequence_number = 0; - ssl->keys.dtls_peer_sequence_number = 0; - ssl->keys.dtls_expected_peer_sequence_number = 0; + ssl->keys.dtls_state.curSeq = 0; + ssl->keys.dtls_state.nextSeq = 0; ssl->keys.dtls_handshake_number = 0; ssl->keys.dtls_expected_peer_handshake_number = 0; ssl->keys.dtls_epoch = 0; - ssl->keys.dtls_peer_epoch = 0; - ssl->keys.dtls_expected_peer_epoch = 0; + ssl->keys.dtls_state.curEpoch = 0; + ssl->keys.dtls_state.nextEpoch = 0; ssl->dtls_timeout_init = DTLS_TIMEOUT_INIT; ssl->dtls_timeout_max = DTLS_TIMEOUT_MAX; ssl->dtls_timeout = ssl->dtls_timeout_init; @@ -2762,9 +2772,9 @@ static int GetRecordHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, /* type and version in same sport */ XMEMCPY(rh, input + *inOutIdx, ENUM_LEN + VERSION_SZ); *inOutIdx += ENUM_LEN + VERSION_SZ; - ato16(input + *inOutIdx, &ssl->keys.dtls_peer_epoch); + ato16(input + *inOutIdx, &ssl->keys.dtls_state.curEpoch); *inOutIdx += 4; /* advance past epoch, skip first 2 seq bytes for now */ - ato32(input + *inOutIdx, &ssl->keys.dtls_peer_sequence_number); + ato32(input + *inOutIdx, &ssl->keys.dtls_state.curSeq); *inOutIdx += 4; /* advance past rest of seq */ ato16(input + *inOutIdx, size); *inOutIdx += LENGTH_SZ; @@ -2785,27 +2795,14 @@ static int GetRecordHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, return VERSION_ERROR; /* only use requested version */ } } -#if 0 - /* Instead of this, check the datagram against the sliding window of - * received datagram goodness. */ + #ifdef CYASSL_DTLS - /* If DTLS, check the sequence number against expected. If out of - * order, drop the record. Allows newer records in and resets the - * expected to the next record. */ if (ssl->options.dtls) { - if ((ssl->keys.dtls_expected_peer_epoch == - ssl->keys.dtls_peer_epoch) && - (ssl->keys.dtls_peer_sequence_number >= - ssl->keys.dtls_expected_peer_sequence_number)) { - ssl->keys.dtls_expected_peer_sequence_number = - ssl->keys.dtls_peer_sequence_number + 1; - } - else { + if (DtlsCheckWindow(&ssl->keys.dtls_state) != 1) return SEQUENCE_ERROR; - } } #endif -#endif + /* record layer length check */ #ifdef HAVE_MAX_FRAGMENT if (*size > (ssl->max_fragment + MAX_COMP_EXTRA + MAX_MSG_EXTRA)) @@ -3868,6 +3865,68 @@ static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, #ifdef CYASSL_DTLS + +static INLINE int DtlsCheckWindow(DtlsState* state) +{ + word32 cur; + word32 next; + DtlsSeq window; + + if (state->curEpoch == state->nextEpoch) { + next = state->nextSeq; + window = state->window; + } + else if (state->curEpoch < state->nextEpoch) { + next = state->prevSeq; + window = state->prevWindow; + } + else { + return 0; + } + + cur = state->curSeq; + + if ((next > DTLS_SEQ_BITS) && (cur < next - DTLS_SEQ_BITS)) { + return 0; + } + else if ((cur < next) && (window & (1 << (next - cur - 1)))) { + return 0; + } + + return 1; +} + + +static INLINE int DtlsUpdateWindow(DtlsState* state) +{ + word32 cur; + word32* next; + DtlsSeq* window; + + if (state->curEpoch == state->nextEpoch) { + next = &state->nextSeq; + window = &state->window; + } + else { + next = &state->prevSeq; + window = &state->prevWindow; + } + + cur = state->curSeq; + + if (cur < *next) { + *window |= (1 << (*next - cur - 1)); + } + else { + *window <<= (1 + cur - *next); + *window |= 1; + *next = cur + 1; + } + + return 1; +} + + static int DtlsMsgDrain(CYASSL* ssl) { DtlsMsg* item = ssl->dtls_msg_list; @@ -4888,8 +4947,6 @@ int ProcessReply(CYASSL* ssl) &ssl->curRL, &ssl->curSize); #ifdef CYASSL_DTLS if (ssl->options.dtls && ret == SEQUENCE_ERROR) { - /* This message is out of order. If we are handshaking, save - *it for later. Otherwise go ahead and process it. */ ssl->options.processReply = doProcessInit; ssl->buffers.inputBuffer.length = 0; ssl->buffers.inputBuffer.idx = 0; @@ -4925,7 +4982,14 @@ int ProcessReply(CYASSL* ssl) /* the record layer is here */ case runProcessingOneMessage: - if (ssl->keys.encryptionOn && ssl->keys.decryptedCur == 0) { + #ifdef CYASSL_DTLS + if (ssl->options.dtls && + ssl->keys.dtls_state.curEpoch < ssl->keys.dtls_state.nextEpoch) + ssl->keys.decryptedCur = 1; + #endif + + if (ssl->keys.encryptionOn && ssl->keys.decryptedCur == 0) + { ret = SanityCheckCipherText(ssl, ssl->curSize); if (ret < 0) return ret; @@ -4975,6 +5039,12 @@ int ProcessReply(CYASSL* ssl) ssl->keys.decryptedCur = 1; } + if (ssl->options.dtls) { + #ifdef CYASSL_DTLS + DtlsUpdateWindow(&ssl->keys.dtls_state); + #endif /* CYASSL_DTLS */ + } + CYASSL_MSG("received record layer msg"); switch (ssl->curRL.type) { @@ -5034,8 +5104,8 @@ int ProcessReply(CYASSL* ssl) #ifdef CYASSL_DTLS if (ssl->options.dtls) { DtlsPoolReset(ssl); - ssl->keys.dtls_expected_peer_epoch++; - ssl->keys.dtls_expected_peer_sequence_number = 0; + ssl->keys.dtls_state.nextEpoch++; + ssl->keys.dtls_state.nextSeq = 0; } #endif diff --git a/src/tls.c b/src/tls.c index 3d9a83a20..b8f6b1d8a 100644 --- a/src/tls.c +++ b/src/tls.c @@ -401,7 +401,7 @@ static INLINE word32 GetSEQIncrement(CYASSL* ssl, int verify) #ifdef CYASSL_DTLS if (ssl->options.dtls) { if (verify) - return ssl->keys.dtls_peer_sequence_number; /* explicit from peer */ + return ssl->keys.dtls_state.curSeq; /* explicit from peer */ else return ssl->keys.dtls_sequence_number - 1; /* already incremented */ } @@ -418,9 +418,9 @@ static INLINE word32 GetSEQIncrement(CYASSL* ssl, int verify) static INLINE word32 GetEpoch(CYASSL* ssl, int verify) { if (verify) - return ssl->keys.dtls_peer_epoch; + return ssl->keys.dtls_state.curEpoch; else - return ssl->keys.dtls_epoch; + return ssl->keys.dtls_epoch; } #endif /* CYASSL_DTLS */