JNI/JSSE: optimize out array creation in WolfSSLEngine RecvAppData(), pass ByteBuffer down to JNI directly

pull/244/head
Chris Conlon 2024-12-23 14:04:27 -07:00
parent 9db7ff1f49
commit 12eae28c14
4 changed files with 379 additions and 88 deletions

View File

@ -1002,16 +1002,104 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
}
}
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
/**
* Read len bytes from wolfSSL_read() back into provided output buffer.
*
* Internal function called by WolfSSLSession.read() calls.
*
* If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait for
* data with select()/poll() up to provided timeout.
*
* Returns number of bytes read on success, or negative on error.
*/
static int SSLReadNonblockingWithSelectPoll(WOLFSSL* ssl, byte* out,
int length, int timeout)
{
byte* data = NULL;
int size = 0, ret, err, sockfd;
int size, ret, err, sockfd;
int pollRx = 0;
int pollTx = 0;
wolfSSL_Mutex* jniSessLock = NULL;
SSLAppData* appData = NULL;
if (ssl == NULL || out == NULL) {
return BAD_FUNC_ARG;
}
/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
return WOLFSSL_FAILURE;
}
jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
return WOLFSSL_FAILURE;
}
do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}
size = wolfSSL_read(ssl, out, length);
err = wolfSSL_get_error(ssl, size);
/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}
if (size < 0 &&
((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) {
sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}
if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, timeout, pollRx);
#else
ret = socketPoll(sockfd, timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
}
}
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
return size;
}
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
{
int size = 0;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
(void)jcl;
@ -1027,79 +1115,178 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
return SSL_FAILURE;
}
/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + offset,
(int)length, (int)timeout);
if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
}
}
return size;
}
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint length, jint timeout)
{
int size = 0;
int maxOutputSz;
int outSz = length;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
jclass excClass;
jclass buffClass;
jmethodID positionMeth;
jmethodID limitMeth;
jmethodID hasArrayMeth;
jmethodID arrayMeth;
jmethodID setPositionMeth;
jint position;
jint limit;
jboolean hasArray;
jbyteArray bufArr;
(void)jcl;
if (jenv == NULL || ssl == NULL || buf == NULL) {
return BAD_FUNC_ARG;
}
if (length > 0) {
/* Get WolfSSLException class */
excClass = (*jenv)->FindClass(jenv, "com/wolfssl/WolfSSLException");
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
return -1;
}
jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
/* Get ByteBuffer class */
buffClass = (*jenv)->GetObjectClass(jenv, buf);
if (buffClass == NULL) {
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer class in native read()");
return -1;
}
do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
/* Get ByteBuffer position */
positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I");
if (positionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer position() method in native read()");
return -1;
}
position = (*jenv)->CallIntMethod(jenv, buf, positionMeth);
size = wolfSSL_read(ssl, data + offset, length);
err = wolfSSL_get_error(ssl, size);
/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
/* Get ByteBuffer limit */
limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I");
if (limitMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer limit() method in native read()");
return -1;
}
limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth);
if (size < 0 && ((err == SSL_ERROR_WANT_READ) || \
(err == SSL_ERROR_WANT_WRITE))) {
/* Get and call ByteBuffer.hasArray() before calling array() */
hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z");
if (hasArrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer hasArray() method in native read()");
return -1;
}
sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
/* ByteBuffer.hasArray() does not throw any exceptions */
hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth);
if (!hasArray) {
(*jenv)->ThrowNew(jenv, excClass,
"ByteBuffer.hasArray() is false in native read()");
return BAD_FUNC_ARG;
}
/* Only read up to maximum space we have in this ByteBuffer */
maxOutputSz = (limit - position);
if (outSz > maxOutputSz) {
outSz = maxOutputSz;
}
/* Get reference to underlying byte[] from ByteBuffer */
arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B");
if (arrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer array() method in native read()");
return -1;
}
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth);
/* Get array elements */
data = (byte*)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
(*jenv)->ThrowNew(jenv, excClass,
"Exception when calling ByteBuffer.array() in native read()");
return -1;
}
if (data != NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + position,
maxOutputSz, (int)timeout);
/* Relase array elements */
if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte*)data,
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte*)data, 0);
/* Update ByteBuffer position() based on bytes written */
setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass,
"position", "(I)Ljava/nio/Buffer;");
if (setPositionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to set ByteBuffer position() from "
"native read()");
size = -1;
}
if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, (int)timeout, pollRx);
#else
ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
else {
(*jenv)->CallVoidMethod(jenv, buf, setPositionMeth,
position + size);
}
}
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
}
}
return size;

View File

@ -100,9 +100,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
* Method: read
* Signature: (J[BIII)I
*/
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII
(JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint);
/*
* Class: com_wolfssl_WolfSSLSession
* Method: read
* Signature: (JLjava/nio/ByteBuffer;II)I
*/
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II
(JNIEnv *, jobject, jlong, jobject, jint, jint);
/*
* Class: com_wolfssl_WolfSSLSession
* Method: accept

View File

@ -28,6 +28,7 @@ import java.net.DatagramSocket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.lang.StringBuilder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import com.wolfssl.WolfSSLException;
@ -253,6 +254,8 @@ public class WolfSSLSession {
int timeout);
private native int read(long ssl, byte[] data, int offset, int sz,
int timeout);
private native int read(long ssl, ByteBuffer data, int sz, int timeout)
throws WolfSSLException;
private native int accept(long ssl, int timeout);
private native void freeSSL(long ssl);
private native int shutdownSSL(long ssl, int timeout);
@ -1112,6 +1115,86 @@ public class WolfSSLSession {
return ret;
}
/**
* Reads bytes from the SSL session and returns the read bytes into
* the provided ByteBuffer, using socket timeout value in milliseconds.
*
* The bytes read are removed from the internal receive buffer.
* <p>
* If necessary, <code>read()</code> will negotiate an SSL/TLS session
* if the handshake has not already been performed yet by <code>connect()
* </code> or <code>accept()</code>.
* <p>
* The SSL/TLS protocol uses SSL records which have a maximum size of
* 16kB. As such, wolfSSL needs to read an entire SSL record internally
* before it is able to process and decrypt the record. Because of this,
* a call to <code>read()</code> will only be able to return the
* maximum buffer size which has been decrypted at the time of calling.
* There may be additional not-yet-decrypted data waiting in the internal
* wolfSSL receive buffer which will be retrieved and decrypted with the
* next call to <code>read()</code>.
*
* @param data ByteBuffer where the data read from the SSL connection
* will be placed. position() will be updated after this
* method writes data to the ByteBuffer.
* @param sz number of bytes to read into <b><code>data</code></b>,
* may be adjusted to the maximum space in data if that is
* smaller than this size.
* @param timeout read timeout, milliseconds.
* @return the number of bytes read upon success. <code>SSL_FAILURE
* </code> will be returned upon failure which may be caused
* by either a clean (close notify alert) shutdown or just
* that the peer closed the connection. <code>
* SSL_FATAL_ERROR</code> upon failure when either an error
* occurred or, when using non-blocking sockets, the
* <b>SSL_ERROR_WANT_READ</b> or <b>SSL_ERROR_WANT_WRITE</b>
* error was received and the application needs to call
* <code>read()</code> again. Use <code>getError</code> to
* get a specific error code.
* <code>BAD_FUNC_ARC</code> when bad arguments are used.
* @throws IllegalStateException WolfSSLContext has been freed
* @throws SocketTimeoutException if socket timeout occurs
* @throws SocketException Native socket select/poll() failed
*/
public int read(ByteBuffer data, int sz, int timeout)
throws IllegalStateException, SocketTimeoutException, SocketException {
int ret;
long localPtr;
confirmObjectIsActive();
/* Fix for Infer scan, since not synchronizing on sslLock for
* access to this.sslPtr, see note below */
synchronized (sslLock) {
localPtr = this.sslPtr;
}
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, "entered read(ByteBuffer, " +
"sz: " + sz + ", timeout: " + timeout + ")");
/* not synchronizing on sslLock here since JNI read() locks
* session mutex around native wolfSSL_read() call. If sslLock
* is locked here, since we call select() inside native JNI we
* could timeout waiting for corresponding write() operation to
* occur if needed */
try {
ret = read(localPtr, data, sz, timeout);
} catch (WolfSSLException e) {
/* JNI code may throw WolfSSLException on JNI specific errors */
throw new SocketException(e.getMessage());
}
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, "read() ret: " + ret +
", err: " + getError(ret));
throwExceptionFromIOReturnValue(ret, "wolfSSL_read()");
return ret;
}
/**
* Waits for an SSL client to initiate the SSL/TLS handshake.
* This method is called on the server side. When it is called, the

View File

@ -818,15 +818,23 @@ public class WolfSSLEngine extends SSLEngine {
int ret = 0;
int idx = 0; /* index into out[] array */
int err = 0;
byte[] tmp;
byte[] tmp = null;
/* create read buffer of max output size */
/* Calculate maximum output size across ByteBuffer arrays */
maxOutSz = getTotalOutputSize(out, ofst, length);
tmp = new byte[maxOutSz];
synchronized (ioLock) {
try {
ret = this.ssl.read(tmp, maxOutSz);
/* If we only have one ByteBuffer, skip allocating
* separate intermediate byte[] and write directly to underlying
* ByteBuffer array */
if (out.length == 1) {
ret = this.ssl.read(out[0], maxOutSz, 0);
}
else {
tmp = new byte[maxOutSz];
ret = this.ssl.read(tmp, maxOutSz);
}
} catch (SocketTimeoutException | SocketException e) {
throw new SSLException(e);
}
@ -883,27 +891,32 @@ public class WolfSSLEngine extends SSLEngine {
}
}
else {
/* write processed data into output buffers */
for (i = 0; i < ret;) {
if (idx + ofst >= length) {
/* no more output buffers left */
break;
}
if (out.length == 1) {
totalRead = ret;
}
else {
/* write processed data into output buffers */
for (i = 0; i < ret;) {
if (idx + ofst >= length) {
/* no more output buffers left */
break;
}
bufSpace = out[idx + ofst].remaining();
if (bufSpace == 0) {
/* no more space in current out buffer, advance */
idx++;
continue;
}
bufSpace = out[idx + ofst].remaining();
if (bufSpace == 0) {
/* no more space in current out buffer, advance */
idx++;
continue;
}
sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
out[idx + ofst].put(tmp, i, sz);
i += sz;
totalRead += sz;
sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
out[idx + ofst].put(tmp, i, sz);
i += sz;
totalRead += sz;
if ((ret - i) > 0) {
idx++; /* go to next output buffer */
if ((ret - i) > 0) {
idx++; /* go to next output buffer */
}
}
}
}