diff --git a/src/dtls13.c b/src/dtls13.c index 5a9b6dca2..375c00af0 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -2618,19 +2618,16 @@ static int Dtls13RtxIsTrackedByRn(const Dtls13RtxRecord* r, w64wrapper epoch, static int Dtls13KeyUpdateAckReceived(WOLFSSL* ssl) { int ret; - w64Increment(&ssl->dtls13Epoch); - - /* Epoch wrapped up */ - if (w64IsZero(ssl->dtls13Epoch)) - return BAD_STATE_E; ret = DeriveTls13Keys(ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1); if (ret != 0) return ret; - ret = Dtls13NewEpoch(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); - if (ret != 0) - return ret; + w64Increment(&ssl->dtls13Epoch); + + /* Epoch wrapped up */ + if (w64IsZero(ssl->dtls13Epoch)) + return BAD_STATE_E; return Dtls13SetEpochKeys(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); } diff --git a/src/tls13.c b/src/tls13.c index 3a9dae263..80dcc7743 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -1471,6 +1471,9 @@ int DeriveTls13Keys(WOLFSSL* ssl, int secret, int side, int store) byte key_dig[MAX_PRF_DIG]; #endif int provision; +#ifdef WOLFSSL_DTLS13 + w64wrapper epochNumber; +#endif #if defined(WOLFSSL_RENESAS_TSIP_TLS) ret = tsip_Tls13DeriveKeys(ssl, secret, side); @@ -1626,6 +1629,34 @@ int DeriveTls13Keys(WOLFSSL* ssl, int secret, int side, int store) ret = Dtls13DeriveSnKeys(ssl, provision); if (ret != 0) return ret; + + switch (secret) { + case early_data_key: + epochNumber = w64From32(0, DTLS13_EPOCH_EARLYDATA); + break; + case handshake_key: + epochNumber = w64From32(0, DTLS13_EPOCH_HANDSHAKE); + break; + case traffic_key: + case no_key: + epochNumber = w64From32(0, DTLS13_EPOCH_TRAFFIC0); + break; + case update_traffic_key: + if (side == ENCRYPT_SIDE_ONLY) { + epochNumber = ssl->dtls13Epoch; + } + else if (side == DECRYPT_SIDE_ONLY) { + epochNumber = ssl->dtls13PeerEpoch; + } + else { + return BAD_STATE_E; + } + w64Increment(&epochNumber); + break; + } + ret = Dtls13NewEpoch(ssl, epochNumber, side); + if (ret != 0) + return ret; } #endif /* WOLFSSL_DTLS13 */ @@ -4083,15 +4114,6 @@ static int WritePSKBinders(WOLFSSL* ssl, byte* output, word32 idx) if ((ret = SetKeysSide(ssl, ENCRYPT_SIDE_ONLY)) != 0) return ret; -#ifdef WOLFSSL_DTLS13 - if (ssl->options.dtls) { - ret = Dtls13NewEpoch( - ssl, w64From32(0x0, DTLS13_EPOCH_EARLYDATA), ENCRYPT_SIDE_ONLY); - if (ret != 0) - return ret; - } -#endif /* WOLFSSL_DTLS13 */ - } #endif @@ -6296,17 +6318,6 @@ static int CheckPreSharedKeys(WOLFSSL* ssl, const byte* input, word32 helloSz, return ret; ssl->keys.encryptionOn = 1; - -#ifdef WOLFSSL_DTLS13 - if (ssl->options.dtls) { - ret = Dtls13NewEpoch(ssl, - w64From32(0x0, DTLS13_EPOCH_EARLYDATA), - DECRYPT_SIDE_ONLY); - if (ret != 0) - return ret; - } -#endif /* WOLFSSL_DTLS13 */ - ssl->earlyData = process_early_data; } else @@ -7604,11 +7615,6 @@ static int SendTls13EncryptedExtensions(WOLFSSL* ssl) w64wrapper epochHandshake = w64From32(0, DTLS13_EPOCH_HANDSHAKE); ssl->dtls13Epoch = epochHandshake; - ret = Dtls13NewEpoch( - ssl, epochHandshake, ENCRYPT_AND_DECRYPT_SIDE); - if (ret != 0) - return ret; - ret = Dtls13SetEpochKeys( ssl, epochHandshake, ENCRYPT_AND_DECRYPT_SIDE); if (ret != 0) @@ -11194,11 +11200,6 @@ static int SendTls13Finished(WOLFSSL* ssl) ssl->dtls13Epoch = epochTraffic0; ssl->dtls13PeerEpoch = epochTraffic0; - ret = Dtls13NewEpoch( - ssl, epochTraffic0, ENCRYPT_AND_DECRYPT_SIDE); - if (ret != 0) - return ret; - ret = Dtls13SetEpochKeys( ssl, epochTraffic0, ENCRYPT_AND_DECRYPT_SIDE); if (ret != 0) @@ -11236,11 +11237,6 @@ static int SendTls13Finished(WOLFSSL* ssl) ssl->dtls13Epoch = epochTraffic0; ssl->dtls13PeerEpoch = epochTraffic0; - ret = Dtls13NewEpoch( - ssl, epochTraffic0, ENCRYPT_AND_DECRYPT_SIDE); - if (ret != 0) - return ret; - ret = Dtls13SetEpochKeys( ssl, epochTraffic0, ENCRYPT_AND_DECRYPT_SIDE); if (ret != 0) @@ -11440,10 +11436,6 @@ static int DoTls13KeyUpdate(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if (ssl->options.dtls) { w64Increment(&ssl->dtls13PeerEpoch); - ret = Dtls13NewEpoch(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY); - if (ret != 0) - return ret; - ret = Dtls13SetEpochKeys(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY); if (ret != 0) return ret; @@ -12859,11 +12851,6 @@ int DoTls13HandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx, ssl->dtls13Epoch = epochHandshake; ssl->dtls13PeerEpoch = epochHandshake; - ret = Dtls13NewEpoch( - ssl, epochHandshake, ENCRYPT_AND_DECRYPT_SIDE); - if (ret != 0) - return ret; - ret = Dtls13SetEpochKeys( ssl, epochHandshake, ENCRYPT_AND_DECRYPT_SIDE); if (ret != 0) diff --git a/tests/api.c b/tests/api.c index adc826ed7..f2c2d0d2d 100644 --- a/tests/api.c +++ b/tests/api.c @@ -68274,6 +68274,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_inject), TEST_DECL(test_wolfSSL_dtls_cid_parse), TEST_DECL(test_dtls13_epochs), + TEST_DECL(test_dtls_rtx_across_epoch_change), TEST_DECL(test_dtls13_ack_order), TEST_DECL(test_dtls_version_checking), TEST_DECL(test_ocsp_status_callback), diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index d5d589649..380b75327 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -1313,3 +1313,72 @@ int test_records_span_network_boundaries(void) } #endif /* defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \ !defined(WOLFSSL_NO_TLS12) */ + +int test_dtls_rtx_across_epoch_change(void) +{ + EXPECT_DECLS; +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \ + defined(WOLFSSL_DTLS13) && defined(WOLFSSL_DTLS) + WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL; + WOLFSSL *ssl_c = NULL, *ssl_s = NULL; + struct test_memio_ctx test_ctx; + + XMEMSET(&test_ctx, 0, sizeof(test_ctx)); + + /* Setup DTLS contexts */ + ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s, + wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), + 0); + + /* CH0 */ + wolfSSL_SetLoggingPrefix("client:"); + ExpectIntEQ(wolfSSL_connect(ssl_c), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), SSL_ERROR_WANT_READ); + + /* SH */ + wolfSSL_SetLoggingPrefix("server:"); + ExpectIntEQ(wolfSSL_accept(ssl_s), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), SSL_ERROR_WANT_READ); + + /* CH1 */ + wolfSSL_SetLoggingPrefix("client:"); + ExpectIntEQ(wolfSSL_connect(ssl_c), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), SSL_ERROR_WANT_READ); + + wolfSSL_SetLoggingPrefix("server:"); + ExpectIntEQ(wolfSSL_accept(ssl_s), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), SSL_ERROR_WANT_READ); + + /* we should have now SH ... FINISHED messages in the buffer*/ + ExpectIntGE(test_ctx.c_msg_count, 2); + + /* now let's drop everything but the SH */ + while (test_ctx.c_msg_count > 1 && EXPECT_SUCCESS()) { + ExpectIntEQ(test_memio_drop_message(&test_ctx, 1, test_ctx.c_msg_count - 1), 0); + } + + /* Read the SH */ + wolfSSL_SetLoggingPrefix("client:"); + ExpectIntEQ(wolfSSL_connect(ssl_c), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), SSL_ERROR_WANT_READ); + + /* trigger client timeout */ + ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), WOLFSSL_SUCCESS); + /* this should have triggered a rtx */ + ExpectIntGT(test_ctx.s_msg_count, 0); + + /* finish the handshake */ + ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); + + /* Test communication works correctly */ + ExpectIntEQ(test_dtls_communication(ssl_s, ssl_c), TEST_SUCCESS); + + /* Cleanup */ + wolfSSL_free(ssl_c); + wolfSSL_CTX_free(ctx_c); + wolfSSL_free(ssl_s); + wolfSSL_CTX_free(ctx_s); +#endif /* defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \ + defined(WOLFSSL_DTLS13) */ + return EXPECT_RESULT(); +} diff --git a/tests/api/test_dtls.h b/tests/api/test_dtls.h index 7896a9510..18ace27b6 100644 --- a/tests/api/test_dtls.h +++ b/tests/api/test_dtls.h @@ -36,4 +36,5 @@ int test_dtls13_longer_length(void); int test_dtls13_short_read(void); int test_records_span_network_boundaries(void); int test_dtls_record_cross_boundaries(void); +int test_dtls_rtx_across_epoch_change(void); #endif /* TESTS_API_DTLS_H */