From 43c564d48b6dc3ff719f4fd2e1f7f787b445b45c Mon Sep 17 00:00:00 2001 From: Juliusz Sosinowicz Date: Thu, 17 Apr 2025 16:35:32 +0200 Subject: [PATCH] dtls13: send acks with correct record number order --- src/dtls13.c | 24 +++++++++++--- tests/api.c | 1 + tests/api/test_dtls.c | 74 +++++++++++++++++++++++++++++++++++++++++++ tests/api/test_dtls.h | 1 + wolfssl/internal.h | 5 ++- 5 files changed, 100 insertions(+), 5 deletions(-) diff --git a/src/dtls13.c b/src/dtls13.c index 11d7a018f..5a9b6dca2 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -718,7 +718,7 @@ static Dtls13RecordNumber* Dtls13NewRecordNumber(w64wrapper epoch, return rn; } -static int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) +int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) { Dtls13RecordNumber* rn; @@ -728,12 +728,28 @@ static int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0) #endif { + /* Find location to insert new record */ + Dtls13RecordNumber** prevNext = &ssl->dtls13Rtx.seenRecords; + Dtls13RecordNumber* cur = ssl->dtls13Rtx.seenRecords; + + for (; cur != NULL; prevNext = &cur->next, cur = cur->next) { + if (w64Equal(cur->epoch, epoch) && w64Equal(cur->seq, seq)) { + /* already in list. no duplicates. */ + return 0; + } + else if (w64LT(epoch, cur->epoch) + || (w64Equal(epoch, cur->epoch) + && w64LT(seq, cur->seq))) { + break; + } + } + rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap); if (rn == NULL) return MEMORY_E; - rn->next = ssl->dtls13Rtx.seenRecords; - ssl->dtls13Rtx.seenRecords = rn; + *prevNext = rn; + rn->next = cur; #ifdef WOLFSSL_RW_THREADED wc_UnLockMutex(&ssl->dtls13Rtx.mutex); #endif @@ -2522,7 +2538,7 @@ static int Dtls13GetAckListLength(Dtls13RecordNumber* list, word16* length) return 0; } -static int Dtls13WriteAckMessage(WOLFSSL* ssl, +int Dtls13WriteAckMessage(WOLFSSL* ssl, Dtls13RecordNumber* recordNumberList, word32* length) { word16 msgSz, headerLength; diff --git a/tests/api.c b/tests/api.c index 22d583af9..0f7d8df8c 100644 --- a/tests/api.c +++ b/tests/api.c @@ -67812,6 +67812,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_inject), TEST_DECL(test_wolfSSL_dtls_cid_parse), TEST_DECL(test_dtls13_epochs), + TEST_DECL(test_dtls13_ack_order), TEST_DECL(test_ocsp_status_callback), TEST_DECL(test_ocsp_basic_verify), TEST_DECL(test_ocsp_response_parsing), diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index 6617c6ccc..d236b316b 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -646,3 +646,77 @@ int test_dtls13_epochs(void) { return EXPECT_RESULT(); } +int test_dtls13_ack_order(void) +{ + EXPECT_DECLS; +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13) + WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL; + WOLFSSL *ssl_c = NULL, *ssl_s = NULL; + struct test_memio_ctx test_ctx; + unsigned char readBuf[50]; + word32 length = 0; + /* struct { + * uint64 epoch; + * uint64 sequence_number; + * } RecordNumber; + * Big endian */ + unsigned char expected_output[] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, + }; + + XMEMSET(&test_ctx, 0, sizeof(test_ctx)); + + /* Get a populated DTLS object */ + ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s, + wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), 0); + ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); + ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), -1); + /* Clear the buffer of any extra messages */ + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); + ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); + ExpectIntEQ(test_ctx.c_len, 0); + ExpectIntEQ(test_ctx.s_len, 0); + + /* Add seen records */ + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 2)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 0)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 1)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 4)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 0)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 6)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3), w64From32(0, 6)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 1)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0); + ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0); + ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords, + &length), 0); + /* N * RecordNumber + 2 extra bytes for length */ + ExpectIntEQ(length, sizeof(expected_output) + 2); + ExpectNotNull(mymemmem(ssl_c->buffers.outputBuffer.buffer, + ssl_c->buffers.outputBuffer.bufferSize, expected_output, + sizeof(expected_output))); + + + wolfSSL_free(ssl_c); + wolfSSL_CTX_free(ctx_c); + wolfSSL_free(ssl_s); + wolfSSL_CTX_free(ctx_s); +#endif + return EXPECT_RESULT(); +} diff --git a/tests/api/test_dtls.h b/tests/api/test_dtls.h index a44b03676..bf82a2035 100644 --- a/tests/api/test_dtls.h +++ b/tests/api/test_dtls.h @@ -26,5 +26,6 @@ int test_dtls12_basic_connection_id(void); int test_dtls13_basic_connection_id(void); int test_wolfSSL_dtls_cid_parse(void); int test_dtls13_epochs(void); +int test_dtls13_ack_order(void); #endif /* TESTS_API_DTLS_H */ diff --git a/wolfssl/internal.h b/wolfssl/internal.h index fa211b732..29f4e0681 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -6578,7 +6578,7 @@ WOLFSSL_LOCAL int TLSv1_3_Capable(WOLFSSL* ssl); WOLFSSL_LOCAL void FreeHandshakeResources(WOLFSSL* ssl); WOLFSSL_LOCAL void ShrinkInputBuffer(WOLFSSL* ssl, int forcedFree); WOLFSSL_LOCAL void ShrinkOutputBuffer(WOLFSSL* ssl); -WOLFSSL_LOCAL byte* GetOutputBuffer(WOLFSSL* ssl); +WOLFSSL_TEST_VIS byte* GetOutputBuffer(WOLFSSL* ssl); WOLFSSL_LOCAL int CipherRequires(byte first, byte second, int requirement); WOLFSSL_LOCAL int VerifyClientSuite(word16 havePSK, byte cipherSuite0, @@ -7066,7 +7066,10 @@ WOLFSSL_LOCAL int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits, w64wrapper* epoch); WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl, Dtls13UnifiedHdrInfo* hdrInfo, w64wrapper* out); +WOLFSSL_TEST_VIS int Dtls13WriteAckMessage(WOLFSSL* ssl, + Dtls13RecordNumber* recordNumberList, word32* length); WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl); +WOLFSSL_TEST_VIS int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq); WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input, word32 inputSize); WOLFSSL_LOCAL int Dtls13HashHandshake(WOLFSSL* ssl, const byte* input,