diff --git a/doc/dox_comments/header_files/ssl.h b/doc/dox_comments/header_files/ssl.h index ca208ce56..114f3c5bc 100644 --- a/doc/dox_comments/header_files/ssl.h +++ b/doc/dox_comments/header_files/ssl.h @@ -5162,6 +5162,61 @@ WOLFSSL_API void wolfSSL_CTX_set_psk_server_callback(WOLFSSL_CTX*, WOLFSSL_API void wolfSSL_set_psk_server_callback(WOLFSSL*, wc_psk_server_callback); + +/*! + \brief Sets a PSK user context in the WOLFSSL structure options member. + + \return WOLFSSL_SUCCESS or WOLFSSL_FAILURE + + \param ssl a pointer to a WOLFSSL structure, created using wolfSSL_new(). + \param psk_ctx void pointer to user PSK context + + \sa wolfSSL_get_psk_callback_ctx + \sa wolfSSL_CTX_set_psk_callback_ctx + \sa wolfSSL_CTX_get_psk_callback_ctx +*/ +WOLFSSL_API int wolfSSL_set_psk_callback_ctx(WOLFSSL* ssl, void* psk_ctx); + +/*! + \brief Sets a PSK user context in the WOLFSSL_CTX structure. + + \return WOLFSSL_SUCCESS or WOLFSSL_FAILURE + + \param ctx a pointer to a WOLFSSL_CTX structure, created using wolfSSL_CTX_new(). + \param psk_ctx void pointer to user PSK context + + \sa wolfSSL_set_psk_callback_ctx + \sa wolfSSL_get_psk_callback_ctx + \sa wolfSSL_CTX_get_psk_callback_ctx +*/ +WOLFSSL_API int wolfSSL_CTX_set_psk_callback_ctx(WOLFSSL_CTX* ctx, void* psk_ctx); + +/*! + \brief Get a PSK user context in the WOLFSSL structure options member. + + \return void pointer to user PSK context + + \param ssl a pointer to a WOLFSSL structure, created using wolfSSL_new(). + + \sa wolfSSL_set_psk_callback_ctx + \sa wolfSSL_CTX_set_psk_callback_ctx + \sa wolfSSL_CTX_get_psk_callback_ctx +*/ +WOLFSSL_API void* wolfSSL_get_psk_callback_ctx(WOLFSSL* ssl); + +/*! + \brief Get a PSK user context in the WOLFSSL_CTX structure. + + \return void pointer to user PSK context + + \param ctx a pointer to a WOLFSSL_CTX structure, created using wolfSSL_CTX_new(). + + \sa wolfSSL_CTX_set_psk_callback_ctx + \sa wolfSSL_set_psk_callback_ctx + \sa wolfSSL_get_psk_callback_ctx +*/ +WOLFSSL_API void* wolfSSL_CTX_get_psk_callback_ctx(WOLFSSL_CTX* ctx); + /*! \ingroup Setup diff --git a/src/internal.c b/src/internal.c index 6b390b64c..ab5ba26e7 100644 --- a/src/internal.c +++ b/src/internal.c @@ -5279,9 +5279,10 @@ int SetSSL_CTX(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) ssl->options.haveStaticECC = ctx->haveStaticECC; #ifndef NO_PSK - ssl->options.havePSK = ctx->havePSK; + ssl->options.havePSK = ctx->havePSK; ssl->options.client_psk_cb = ctx->client_psk_cb; ssl->options.server_psk_cb = ctx->server_psk_cb; + ssl->options.psk_ctx = ctx->psk_ctx; #ifdef WOLFSSL_TLS13 ssl->options.client_psk_tls13_cb = ctx->client_psk_tls13_cb; ssl->options.server_psk_tls13_cb = ctx->server_psk_tls13_cb; diff --git a/src/ssl.c b/src/ssl.c index 278d4fb16..173b3d67c 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -13458,7 +13458,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) ctx->client_psk_cb = cb; } - void wolfSSL_set_psk_client_callback(WOLFSSL* ssl,wc_psk_client_callback cb) { byte haveRSA = 1; @@ -13484,7 +13483,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) ssl->options.haveStaticECC, ssl->options.side); } - void wolfSSL_CTX_set_psk_server_callback(WOLFSSL_CTX* ctx, wc_psk_server_callback cb) { @@ -13495,7 +13493,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) ctx->server_psk_cb = cb; } - void wolfSSL_set_psk_server_callback(WOLFSSL* ssl,wc_psk_server_callback cb) { byte haveRSA = 1; @@ -13520,7 +13517,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) ssl->options.haveStaticECC, ssl->options.side); } - const char* wolfSSL_get_psk_identity_hint(const WOLFSSL* ssl) { WOLFSSL_ENTER("SSL_get_psk_identity_hint"); @@ -13542,7 +13538,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) return ssl->arrays->client_identity; } - int wolfSSL_CTX_use_psk_identity_hint(WOLFSSL_CTX* ctx, const char* hint) { WOLFSSL_ENTER("SSL_CTX_use_psk_identity_hint"); @@ -13559,7 +13554,6 @@ int wolfSSL_set_compression(WOLFSSL* ssl) return WOLFSSL_SUCCESS; } - int wolfSSL_use_psk_identity_hint(WOLFSSL* ssl, const char* hint) { WOLFSSL_ENTER("SSL_use_psk_identity_hint"); @@ -13577,6 +13571,28 @@ int wolfSSL_set_compression(WOLFSSL* ssl) return WOLFSSL_SUCCESS; } + void* wolfSSL_get_psk_callback_ctx(WOLFSSL* ssl) + { + return ssl ? ssl->options.psk_ctx : NULL; + } + void* wolfSSL_CTX_get_psk_callback_ctx(WOLFSSL_CTX* ctx) + { + return ctx ? ctx->psk_ctx : NULL; + } + int wolfSSL_set_psk_callback_ctx(WOLFSSL* ssl, void* psk_ctx) + { + if (ssl == NULL) + return WOLFSSL_FAILURE; + ssl->options.psk_ctx = psk_ctx; + return WOLFSSL_SUCCESS; + } + int wolfSSL_CTX_set_psk_callback_ctx(WOLFSSL_CTX* ctx, void* psk_ctx) + { + if (ctx == NULL) + return WOLFSSL_FAILURE; + ctx->psk_ctx = psk_ctx; + return WOLFSSL_SUCCESS; + } #endif /* NO_PSK */ diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 18541a648..a4af16aea 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -2748,6 +2748,7 @@ struct WOLFSSL_CTX { wc_psk_client_tls13_callback client_psk_tls13_cb; /* client callback */ wc_psk_server_tls13_callback server_psk_tls13_cb; /* server callback */ #endif + void* psk_ctx; char server_hint[MAX_PSK_ID_LEN + NULL_TERM_LEN]; #endif /* HAVE_SESSION_TICKET || !NO_PSK */ #ifdef WOLFSSL_TLS13 @@ -3337,6 +3338,7 @@ typedef struct Options { wc_psk_client_tls13_callback client_psk_tls13_cb; /* client callback */ wc_psk_server_tls13_callback server_psk_tls13_cb; /* server callback */ #endif + void* psk_ctx; #endif /* NO_PSK */ #if defined(OPENSSL_EXTRA) || defined(HAVE_WEBSERVER) || defined(WOLFSSL_WPAS_SMALL) unsigned long mask; /* store SSL_OP_ flags */ diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index 96a3e1d0e..39d7a9911 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -1950,6 +1950,11 @@ enum { /* ssl Constants */ WOLFSSL_API void wolfSSL_set_psk_server_tls13_callback(WOLFSSL*, wc_psk_server_tls13_callback); #endif + WOLFSSL_API void* wolfSSL_get_psk_callback_ctx(WOLFSSL*); + WOLFSSL_API int wolfSSL_set_psk_callback_ctx(WOLFSSL*, void*); + + WOLFSSL_API void* wolfSSL_CTX_get_psk_callback_ctx(WOLFSSL_CTX*); + WOLFSSL_API int wolfSSL_CTX_set_psk_callback_ctx(WOLFSSL_CTX*, void*); #define PSK_TYPES_DEFINED #endif /* NO_PSK */