JNI: initial implementation of static direct ByteBuffer pool for WolfSSLSession.write(), avoids unaligned memory access at JNI layer

pull/268/head
Chris Conlon 2025-05-16 11:23:17 -06:00
parent 6b1e7a6299
commit 1c8963e4fe
4 changed files with 269 additions and 132 deletions

View File

@ -1111,18 +1111,113 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect
return ret; return ret;
} }
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write /**
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, * Write len bytes with wolfSSL_write() from provided input data buffer.
jint length, jint timeout) *
* Internal function called by WolfSSLSession.write() calls.
*
* If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait writability
* with select()/poll() up to provided timeout.
*
* Returns number of bytes written on success, or negative on error.
*/
static int SSLWriteNonblockingWithSelectPoll(WOLFSSL* ssl, byte* data,
int length, int timeout)
{ {
byte* data = NULL; int ret, err, sockfd;
int ret = SSL_FAILURE, err, sockfd;
int pollRx = 0; int pollRx = 0;
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
int pollTx = 0; int pollTx = 0;
#endif #endif
wolfSSL_Mutex* jniSessLock = NULL; wolfSSL_Mutex* jniSessLock = NULL;
SSLAppData* appData = NULL; SSLAppData* appData = NULL;
if (ssl == NULL || data == NULL) {
return BAD_FUNC_ARG;
}
/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
return WOLFSSL_FAILURE;
}
jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
return SSL_FAILURE;
}
do {
/* lock mutex around session I/O before write attempt */
if (wc_LockMutex(jniSessLock) != 0) {
ret = WOLFSSL_FAILURE;
break;
}
ret = wolfSSL_write(ssl, data, length);
err = wolfSSL_get_error(ssl, ret);
/* unlock mutex around session I/O after write attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
ret = WOLFSSL_FAILURE;
break;
}
if ((ret < 0) &&
((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) {
sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}
if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
#endif
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0);
#else
ret = socketPoll(appData, sockfd, (int)timeout, pollRx,
pollTx, 0);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_write() again */
continue;
} else if (ret == WOLFJNI_IO_EVENT_TIMEOUT ||
ret == WOLFJNI_IO_EVENT_FD_CLOSED ||
ret == WOLFJNI_IO_EVENT_ERROR ||
ret == WOLFJNI_IO_EVENT_POLLHUP ||
ret == WOLFJNI_IO_EVENT_FAIL) {
/* Java will throw SocketTimeoutException or
* SocketException */
break;
} else {
/* error */
ret = WOLFSSL_FAILURE;
break;
}
}
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
return ret;
}
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
{
byte* data = NULL;
int ret;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
(void)jcl; (void)jcl;
@ -1138,85 +1233,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
return SSL_FAILURE; return SSL_FAILURE;
} }
/* get session mutex from SSL app data */ ret = SSLWriteNonblockingWithSelectPoll(ssl, data + offset,
appData = (SSLAppData*)wolfSSL_get_app_data(ssl); (int)length, (int)timeout);
if (appData == NULL) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
}
jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return SSL_FAILURE;
}
do {
/* lock mutex around session I/O before write attempt */
if (wc_LockMutex(jniSessLock) != 0) {
ret = WOLFSSL_FAILURE;
break;
}
ret = wolfSSL_write(ssl, data + offset, length);
err = wolfSSL_get_error(ssl, ret);
/* unlock mutex around session I/O after write attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
ret = WOLFSSL_FAILURE;
break;
}
if (ret >= 0) /* return if it is success */
break;
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}
if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
#endif
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0);
#else
ret = socketPoll(appData, sockfd, (int)timeout, pollRx,
pollTx, 0);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_write() again */
continue;
} else if (ret == WOLFJNI_IO_EVENT_TIMEOUT ||
ret == WOLFJNI_IO_EVENT_FD_CLOSED ||
ret == WOLFJNI_IO_EVENT_ERROR ||
ret == WOLFJNI_IO_EVENT_POLLHUP ||
ret == WOLFJNI_IO_EVENT_FAIL) {
/* Java will throw SocketTimeoutException or
* SocketException */
break;
} else {
/* error */
ret = WOLFSSL_FAILURE;
break;
}
}
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, JNI_ABORT); (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, JNI_ABORT);
@ -1227,6 +1245,86 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
} }
} }
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint position,
jint limit, jboolean hasArray, jint length, jint timeout)
{
int ret;
int maxInputSz;
int inSz = length;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
jbyteArray bufArr = NULL;
(void)jcl;
if (jenv == NULL || ssl == NULL || buf == NULL) {
return BAD_FUNC_ARG;
}
if (length > 0) {
/* Only write up to maximum space we have in this ByteBuffer */
maxInputSz = (limit - position);
if (inSz > maxInputSz) {
inSz = maxInputSz;
}
if (inSz <= 0) {
return BAD_FUNC_ARG;
}
if (hasArray) {
/* Get reference to underlying byte[] from ByteBuffer */
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf,
g_bufferArrayMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}
/* Get array elements */
data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if (data == NULL) {
/* Handle any pending exception, we'll throw another below
* anyways so just clear it */
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
throwWolfSSLJNIException(jenv,
"Failed to get byte[] from ByteBuffer in native write()");
return BAD_FUNC_ARG;
}
}
else {
data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf);
if (data == NULL) {
throwWolfSSLJNIException(jenv,
"Failed to get DirectBuffer address in native write()");
return BAD_FUNC_ARG;
}
}
ret = SSLWriteNonblockingWithSelectPoll(ssl, data + position,
(int)inSz, (int)timeout);
/* release memory if using array mode */
if (hasArray) {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte*)data, JNI_ABORT);
}
}
/* check for Java exceptions before returning */
if ((*jenv)->ExceptionCheck(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
return SSL_FAILURE;
}
return ret;
}
/** /**
* Read len bytes from wolfSSL_read() back into provided output buffer. * Read len bytes from wolfSSL_read() back into provided output buffer.
* *

View File

@ -96,9 +96,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect
* Method: write * Method: write
* Signature: (J[BIII)I * Signature: (J[BIII)I
*/ */
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII
(JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint); (JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint);
/*
* Class: com_wolfssl_WolfSSLSession
* Method: write
* Signature: (JLjava/nio/ByteBuffer;IIZII)I
*/
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII
(JNIEnv *, jobject, jlong, jobject, jint, jint, jboolean, jint, jint);
/* /*
* Class: com_wolfssl_WolfSSLSession * Class: com_wolfssl_WolfSSLSession
* Method: read * Method: read

View File

@ -110,12 +110,12 @@ public class WolfSSLSession {
private final Object sslLock = new Object(); private final Object sslLock = new Object();
/* Maximum direct ByteBuffer pool size */ /* Maximum direct ByteBuffer pool size */
private static final int MAX_POOL_SIZE = 32; private static int MAX_POOL_SIZE = 16;
/* Size of each direct ByteBuffer in the pool. This is set to 17KB, which /* Size of each direct ByteBuffer in the pool. This is set to 17KB, which
* is slightly larger than the maximum SSL record size (16KB). This * is slightly larger than the maximum SSL record size (16KB). This
* allows for some overhead (SSL record header, etc) */ * allows for some overhead (SSL record header, etc) */
private static final int BUFFER_SIZE = 17 * 1024; private static int BUFFER_SIZE = 17 * 1024;
/* Thread-local direct ByteBuffer pool for optimized JNI direct memory /* Thread-local direct ByteBuffer pool for optimized JNI direct memory
* access. Passing byte[] and offset down to JNI, on some systems this * access. Passing byte[] and offset down to JNI, on some systems this
@ -446,6 +446,9 @@ public class WolfSSLSession {
private native int connect(long ssl, int timeout); private native int connect(long ssl, int timeout);
private native int write(long ssl, byte[] data, int offset, int length, private native int write(long ssl, byte[] data, int offset, int length,
int timeout); int timeout);
private native int write(long ssl, ByteBuffer data, final int position,
final int limit, boolean hasArray, int sz, int timeout)
throws WolfSSLException;
private native int read(long ssl, byte[] data, int offset, int sz, private native int read(long ssl, byte[] data, int offset, int sz,
int timeout); int timeout);
private native int read(long ssl, ByteBuffer data, final int position, private native int read(long ssl, ByteBuffer data, final int position,
@ -1000,37 +1003,7 @@ public class WolfSSLSession {
public int write(byte[] data, int length) public int write(byte[] data, int length)
throws IllegalStateException, SocketTimeoutException, SocketException { throws IllegalStateException, SocketTimeoutException, SocketException {
final int ret; return write(data, 0, length, 0);
final int err;
long localPtr;
confirmObjectIsActive();
/* Fix for Infer scan, since not synchronizing on sslLock for
* access to this.sslPtr, see note below */
synchronized (sslLock) {
localPtr = this.sslPtr;
}
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr,
() -> "entered write(length: " + length + ")");
/* not synchronizing on sslLock here since JNI write() locks
* session mutex around native wolfSSL_write() call. If sslLock
* is locked here, since we call select() inside native JNI we
* could timeout waiting for corresponding read() operation to
* occur if needed */
ret = write(localPtr, data, 0, length, 0);
err = getError(ret);
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr,
() -> "write() ret: " + ret + ", err: " + err);
throwExceptionFromIOReturnValue(ret, "wolfSSL_write()");
return ret;
} }
/** /**
@ -1113,9 +1086,12 @@ public class WolfSSLSession {
public int write(byte[] data, int offset, int length, int timeout) public int write(byte[] data, int offset, int length, int timeout)
throws IllegalStateException, SocketTimeoutException, SocketException { throws IllegalStateException, SocketTimeoutException, SocketException {
final int ret; int ret = 0;
final int err; int err = 0;
int totalWritten = 0;
int remaining = length;
long localPtr; long localPtr;
ByteBuffer directBuffer = null;
confirmObjectIsActive(); confirmObjectIsActive();
@ -1130,21 +1106,75 @@ public class WolfSSLSession {
() -> "entered write(offset: " + offset + ", length: " + () -> "entered write(offset: " + offset + ", length: " +
length + ", timeout: " + timeout + ")"); length + ", timeout: " + timeout + ")");
/* not synchronizing on sslLock here since JNI write() locks /* Use a direct ByteBuffer from the pool to avoid unaligned
* session mutex around native wolfSSL_write() call. If sslLock * memory access. Otherwise our native JNI code may need to do
* is locked here, since we call select() inside native JNI we * "buffer + offset" and end up with unaligned memory which
* could timeout waiting for corresponding read() operation to * can be slow on some targets (ex: ARM/Aarch64) */
* occur if needed */ try {
ret = write(localPtr, data, offset, length, timeout); /* Get a buffer from the pool */
err = getError(ret); directBuffer = acquireDirectBuffer();
WolfSSLDebug.log(getClass(),
WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, localPtr,
() -> "write() using thread-local ByteBuffer pool: pool size: " +
directBufferPool.get().size());
/* Write in chunks until all data is written or an error occurs.
* The DirectByteBuffer size might be smaller than the data length,
* so we need to loop to handle all the data */
while (remaining > 0) {
/* Calculate size for current chunk */
int writeSize = Math.min(remaining, directBuffer.capacity());
/* Copy data from user array to direct buffer */
directBuffer.clear();
directBuffer.put(data, offset + totalWritten, writeSize);
directBuffer.flip();
/* Call native write with DirectByteBuffer */
ret = write(localPtr, directBuffer, directBuffer.position(),
directBuffer.limit(), false, writeSize, timeout);
if (ret <= 0) {
/* Error occurred, break out of loop */
break;
}
/* Update tracking variables */
totalWritten += ret;
remaining -= ret;
}
} catch (Exception e) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.ERROR, localPtr,
() -> "write() falling back to use byte[]");
/* Fall back to original implementation on exception (write not
* done yet at this point in JNI call above) */
totalWritten = write(localPtr, data, offset, length, timeout);
} finally {
/* Return buffer to pool */
if (directBuffer != null) {
releaseDirectBuffer(directBuffer);
}
}
/* Return total bytes written, or last error code */
final int finalRet = (totalWritten > 0) ? totalWritten : ret;
final int finalErr = getError(finalRet);
final int finalTotal = totalWritten;
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, WolfSSLDebug.INFO, localPtr,
() -> "write() ret: " + ret + ", err: " + err); () -> "write() ret: " + finalRet + ", err: " + finalErr +
", totalWritten: " + finalTotal);
throwExceptionFromIOReturnValue(ret, "wolfSSL_write()"); throwExceptionFromIOReturnValue(finalRet, "wolfSSL_write()");
return ret; return finalRet;
} }
/** /**

View File

@ -2895,8 +2895,9 @@ public class WolfSSLSocket extends SSLSocket {
if (ret < 0) { if (ret < 0) {
/* print error description string */ /* print error description string */
String errStr = WolfSSL.getErrorString(err); String errStr = WolfSSL.getErrorString(err);
throw new IOException("Native wolfSSL_write() error: " throw new IOException("Native wolfSSL_write() error: " +
+ errStr + " (error code: " + err + ")"); errStr + " (ret: " + ret + ", error code: " +
err + ")");
} }
} catch (IllegalStateException e) { } catch (IllegalStateException e) {