diff --git a/cyassl/internal.h b/cyassl/internal.h index 5522a1a0f..6cb8a9163 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1516,7 +1516,7 @@ CYASSL_LOCAL int SendCertificate(CYASSL*); CYASSL_LOCAL int SendCertificateRequest(CYASSL*); CYASSL_LOCAL int SendServerKeyExchange(CYASSL*); CYASSL_LOCAL int SendBuffered(CYASSL*); -CYASSL_LOCAL int ReceiveData(CYASSL*, byte*, int); +CYASSL_LOCAL int ReceiveData(CYASSL*, byte*, int, int); CYASSL_LOCAL int SendFinished(CYASSL*); CYASSL_LOCAL int SendAlert(CYASSL*, int, int); CYASSL_LOCAL int ProcessReply(CYASSL*); diff --git a/cyassl/openssl/ssl.h b/cyassl/openssl/ssl.h index f268b6a37..432abd32d 100644 --- a/cyassl/openssl/ssl.h +++ b/cyassl/openssl/ssl.h @@ -116,6 +116,7 @@ typedef CYASSL_X509_STORE_CTX X509_STORE_CTX; #define SSL_write CyaSSL_write #define SSL_read CyaSSL_read +#define SSL_peek CyaSSL_peek #define SSL_accept CyaSSL_accept #define SSL_CTX_free CyaSSL_CTX_free #define SSL_free CyaSSL_free diff --git a/cyassl/ssl.h b/cyassl/ssl.h index 10b1be9bd..2bbbfa054 100644 --- a/cyassl/ssl.h +++ b/cyassl/ssl.h @@ -190,6 +190,7 @@ CYASSL_API int CyaSSL_connect(CYASSL*); /* please see note at top of README if you get an error from connect */ CYASSL_API int CyaSSL_write(CYASSL*, const void*, int); CYASSL_API int CyaSSL_read(CYASSL*, void*, int); +CYASSL_API int CyaSSL_peek(CYASSL*, void*, int); CYASSL_API int CyaSSL_accept(CYASSL*); CYASSL_API void CyaSSL_CTX_free(CYASSL_CTX*); CYASSL_API void CyaSSL_free(CYASSL*); diff --git a/src/internal.c b/src/internal.c index 6215133a2..8d80b7f17 100644 --- a/src/internal.c +++ b/src/internal.c @@ -4065,7 +4065,7 @@ int SendData(CYASSL* ssl, const void* data, int sz) } /* process input data */ -int ReceiveData(CYASSL* ssl, byte* output, int sz) +int ReceiveData(CYASSL* ssl, byte* output, int sz, int peek) { int size; @@ -4104,8 +4104,11 @@ int ReceiveData(CYASSL* ssl, byte* output, int sz) size = ssl->buffers.clearOutputBuffer.length; XMEMCPY(output, ssl->buffers.clearOutputBuffer.buffer, size); - ssl->buffers.clearOutputBuffer.length -= size; - ssl->buffers.clearOutputBuffer.buffer += size; + + if (peek == 0) { + ssl->buffers.clearOutputBuffer.length -= size; + ssl->buffers.clearOutputBuffer.buffer += size; + } if (ssl->buffers.clearOutputBuffer.length == 0 && ssl->buffers.inputBuffer.dynamicFlag) diff --git a/src/ssl.c b/src/ssl.c index b712cd030..1a52e42ad 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -380,19 +380,19 @@ int CyaSSL_write(CYASSL* ssl, const void* data, int sz) } -int CyaSSL_read(CYASSL* ssl, void* data, int sz) +static int CyaSSL_read_internal(CYASSL* ssl, void* data, int sz, int peek) { int ret; - CYASSL_ENTER("SSL_read()"); + CYASSL_ENTER("CyaSSL_read_internal()"); #ifdef HAVE_ERRNO_H errno = 0; #endif - ret = ReceiveData(ssl, (byte*)data, min(sz, OUTPUT_RECORD_SIZE)); + ret = ReceiveData(ssl, (byte*)data, min(sz, OUTPUT_RECORD_SIZE), peek); - CYASSL_LEAVE("SSL_read()", ret); + CYASSL_LEAVE("CyaSSL_read_internal()", ret); if (ret < 0) return SSL_FATAL_ERROR; @@ -401,6 +401,22 @@ int CyaSSL_read(CYASSL* ssl, void* data, int sz) } +int CyaSSL_peek(CYASSL* ssl, void* data, int sz) +{ + CYASSL_ENTER("CyaSSL_peek()"); + + return CyaSSL_read_internal(ssl, data, sz, TRUE); +} + + +int CyaSSL_read(CYASSL* ssl, void* data, int sz) +{ + CYASSL_ENTER("CyaSSL_read()"); + + return CyaSSL_read_internal(ssl, data, sz, FALSE); +} + + int CyaSSL_send(CYASSL* ssl, const void* data, int sz, int flags) { int ret;