dtls: Refactor handshake timeout logic and simplify newConn

pull/323/head
Juliusz Sosinowicz 2022-06-30 16:29:43 +02:00
parent 88cf4e3e5e
commit 656762d313
1 changed files with 17 additions and 58 deletions

View File

@ -248,33 +248,14 @@ static int newPendingSSL(void)
static void newConn(evutil_socket_t fd, short events, void* arg)
{
struct sockaddr_in cliaddr; /* the client's address */
socklen_t cliLen;
char b;
int ret;
int err;
int drop = 1;
/* Store pointer because pendingSSL can be modified in chGoodCb */
WOLFSSL* ssl = pendingSSL;
(void)events;
(void)arg;
/* Get the incoming address */
cliLen = sizeof(cliaddr);
ret = (int)recvfrom(listenfd, &b, sizeof(b), MSG_PEEK,
(struct sockaddr*)&cliaddr, &cliLen);
if (ret <= 0) {
perror("recvfrom()");
goto error;
}
if (wolfSSL_dtls_set_peer(ssl, &cliaddr, cliLen) != WOLFSSL_SUCCESS) {
fprintf(stderr, "wolfSSL_dtls_set_peer error.\n");
goto error;
}
drop = 0;
ret = wolfSSL_accept(ssl);
if (ret != WOLFSSL_SUCCESS) {
err = wolfSSL_get_error(ssl, 0);
@ -287,11 +268,19 @@ static void newConn(evutil_socket_t fd, short events, void* arg)
exit(1);
}
}
}
error:
/* Drop the datagram */
if (drop)
(void)recv(listenfd, &b, sizeof(b), 0);
static void setHsTimeout(WOLFSSL* ssl, struct timeval *tv)
{
int timeout = wolfSSL_dtls_get_current_timeout(ssl);
if (wolfSSL_dtls13_use_quick_timeout(ssl)) {
if (timeout >= QUICK_MULT)
tv->tv_sec = timeout / QUICK_MULT;
else
tv->tv_usec = timeout * 1000000 / QUICK_MULT;
}
else
tv->tv_sec = timeout;
}
/* Called when we have verified a connection */
@ -320,7 +309,7 @@ static int chGoodCb(WOLFSSL* ssl, void* arg)
goto error;
}
/* We need to change the sfd here so that the ssl object doesn't drop any
/* We need to change the SFD here so that the ssl object doesn't drop any
* new connections */
fd = newFD();
if (fd == INVALID_SOCKET)
@ -332,14 +321,7 @@ static int chGoodCb(WOLFSSL* ssl, void* arg)
goto error;
}
/* Have the ssl object use only the SFD without the peer address (since the
* SFD is now connected) */
if (wolfSSL_dtls_set_peer(ssl, NULL, 0) != WOLFSSL_SUCCESS) {
fprintf(stderr, "wolfSSL_dtls_set_peer error.\n");
goto error;
}
if (wolfSSL_set_fd(ssl, fd) != WOLFSSL_SUCCESS) {
if (wolfSSL_set_dtls_fd_connected(ssl, fd) != WOLFSSL_SUCCESS) {
fprintf(stderr, "wolfSSL_set_fd error.\n");
goto error;
}
@ -355,14 +337,7 @@ static int chGoodCb(WOLFSSL* ssl, void* arg)
goto error;
}
memset(&tv, 0, sizeof(tv));
if (wolfSSL_dtls13_use_quick_timeout(ssl)) {
if (timeout >= QUICK_MULT)
tv.tv_sec = timeout / QUICK_MULT;
else
tv.tv_usec = timeout * 1000000 / QUICK_MULT;
}
else
tv.tv_sec = timeout;
setHsTimeout(ssl, &tv);
/* We are using non-blocking sockets so we will definitely be waiting for
* the peer. Start the timer now. */
if (event_add(connCtx->readEv, &tv) != 0) {
@ -415,15 +390,7 @@ static void dataReady(evutil_socket_t fd, short events, void* arg)
fprintf(stderr, "wolfSSL_dtls_got_timeout failed\n");
goto error;
}
timeout = wolfSSL_dtls_get_current_timeout(connCtx->ssl);
if (wolfSSL_dtls13_use_quick_timeout(connCtx->ssl)) {
if (timeout >= QUICK_MULT)
tv.tv_sec = timeout / QUICK_MULT;
else
tv.tv_usec = timeout * 1000000 / QUICK_MULT;
}
else
tv.tv_sec = timeout;
setHsTimeout(connCtx->ssl, &tv);
if (event_add(connCtx->readEv, &tv) != 0) {
fprintf(stderr, "event_add failed\n");
goto error;
@ -460,15 +427,7 @@ static void dataReady(evutil_socket_t fd, short events, void* arg)
err = wolfSSL_get_error(connCtx->ssl, 0);
if (err == WOLFSSL_ERROR_WANT_READ ||
err == WOLFSSL_ERROR_WANT_WRITE) {
timeout = wolfSSL_dtls_get_current_timeout(connCtx->ssl);
if (wolfSSL_dtls13_use_quick_timeout(connCtx->ssl)) {
if (timeout >= QUICK_MULT)
tv.tv_sec = timeout / QUICK_MULT;
else
tv.tv_usec = timeout * 1000000 / QUICK_MULT;
}
else
tv.tv_sec = timeout;
setHsTimeout(connCtx->ssl, &tv);
if (event_add(err == WOLFSSL_ERROR_WANT_READ ?
connCtx->readEv : connCtx->writeEv, &tv) != 0) {
fprintf(stderr, "event_add failed\n");