JNI: initial implementation of static direct ByteBuffer pool for WolfSSLSession.read(), avoids unaligned memory access at JNI layer

pull/268/head
Chris Conlon 2025-05-14 14:26:15 -06:00
parent a306dfff0e
commit 6b1e7a6299
2 changed files with 132 additions and 44 deletions

View File

@ -7,6 +7,10 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
#undef com_wolfssl_WolfSSLSession_MAX_POOL_SIZE
#define com_wolfssl_WolfSSLSession_MAX_POOL_SIZE 32L
#undef com_wolfssl_WolfSSLSession_BUFFER_SIZE
#define com_wolfssl_WolfSSLSession_BUFFER_SIZE 17408L
/* /*
* Class: com_wolfssl_WolfSSLSession * Class: com_wolfssl_WolfSSLSession
* Method: newSSL * Method: newSSL

View File

@ -30,6 +30,7 @@ import java.net.SocketTimeoutException;
import java.lang.StringBuilder; import java.lang.StringBuilder;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.concurrent.ConcurrentLinkedQueue;
/** /**
* Wraps a native WolfSSL session object and contains methods directly related * Wraps a native WolfSSL session object and contains methods directly related
@ -108,6 +109,74 @@ public class WolfSSLSession {
/* lock around native WOLFSSL pointer use */ /* lock around native WOLFSSL pointer use */
private final Object sslLock = new Object(); private final Object sslLock = new Object();
/* Maximum direct ByteBuffer pool size */
private static final int MAX_POOL_SIZE = 32;
/* Size of each direct ByteBuffer in the pool. This is set to 17KB, which
* is slightly larger than the maximum SSL record size (16KB). This
* allows for some overhead (SSL record header, etc) */
private static final int BUFFER_SIZE = 17 * 1024;
/* Thread-local pool of direct ByteBuffers used for JNI direct memory
* access. When a byte[] and offset are passed down to JNI, the native
* code does pointer addition (buffer + offset), which on some systems
* results in unaligned memory access. Unaligned memory access can be
* considerably slower (ex: Aarch64). To avoid this, we use a
* thread-local pool of ByteBuffers here so native JNI does not do
* unaligned memory access, and so there is no cross-thread contention
* on the pool. */
private static final ThreadLocal<ConcurrentLinkedQueue<ByteBuffer>> directBufferPool =
ThreadLocal.withInitial(() -> new ConcurrentLinkedQueue<>());
/**
 * Obtain a direct ByteBuffer for the calling thread.
 *
 * A buffer is taken from the calling thread's pool when one is available,
 * and is cleared before being handed out. When the pool has nothing to
 * offer, a new direct buffer of BUFFER_SIZE bytes is allocated instead.
 *
 * @return a direct ByteBuffer ready to use
 */
private static ByteBuffer acquireDirectBuffer() {

    final ConcurrentLinkedQueue<ByteBuffer> pool = directBufferPool.get();
    final ByteBuffer pooled = pool.poll();

    if (pooled != null) {
        /* Reuse path: log first, then reset position/limit for caller */
        WolfSSLDebug.log(WolfSSLSession.class, WolfSSLDebug.Component.JNI,
            WolfSSLDebug.INFO, 0,
            () -> "Reusing DirectByteBuffer from thread-local pool, " +
            "pool size: " + pool.size());
        pooled.clear();

        return pooled;
    }

    /* Nothing pooled for this thread, fall through to fresh allocation */
    WolfSSLDebug.log(WolfSSLSession.class, WolfSSLDebug.Component.JNI,
        WolfSSLDebug.INFO, 0,
        () -> "Thread-local DirectByteBuffer pool empty, " +
        "allocating new buffer");

    return ByteBuffer.allocateDirect(BUFFER_SIZE);
}
/**
 * Hand a direct ByteBuffer back to the calling thread's pool for reuse.
 *
 * Null and non-direct buffers are ignored. The buffer is cleared before
 * the pool is consulted. If the pool already holds MAX_POOL_SIZE buffers,
 * the buffer is not retained and is left for the garbage collector.
 *
 * @param buffer the buffer to return to the pool
 */
private static void releaseDirectBuffer(ByteBuffer buffer) {

    /* Only direct buffers belong in this pool */
    if (buffer == null || !buffer.isDirect()) {
        return;
    }

    buffer.clear();

    final ConcurrentLinkedQueue<ByteBuffer> pool = directBufferPool.get();

    /* Pool at capacity, drop the buffer instead of growing the pool */
    if (pool.size() >= MAX_POOL_SIZE) {
        return;
    }

    WolfSSLDebug.log(WolfSSLSession.class,
        WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, 0,
        () -> "Returning DirectByteBuffer to thread-local pool, " +
        "pool size: " + pool.size());
    pool.offer(buffer);
}
/* SNI requested by this WolfSSLSession if client side and useSNI() /* SNI requested by this WolfSSLSession if client side and useSNI()
* was called successfully. */ * was called successfully. */
private byte[] clientSNIRequested = null; private byte[] clientSNIRequested = null;
@ -1118,38 +1187,7 @@ public class WolfSSLSession {
public int read(byte[] data, int sz) public int read(byte[] data, int sz)
throws IllegalStateException, SocketTimeoutException, SocketException { throws IllegalStateException, SocketTimeoutException, SocketException {
final int ret; return read(data, 0, sz, 0);
final int err;
final int readSz = sz;
long localPtr;
confirmObjectIsActive();
/* Fix for Infer scan, since not synchronizing on sslLock for
* access to this.sslPtr, see note below */
synchronized (sslLock) {
localPtr = this.sslPtr;
}
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, () -> "entered read(sz: " +
readSz + ")");
/* not synchronizing on sslLock here since JNI read() locks
* session mutex around native wolfSSL_read() call. If sslLock
* is locked here, since we call select() inside native JNI we
* could timeout waiting for corresponding write() operation to
* occur if needed */
ret = read(localPtr, data, 0, readSz, 0);
err = getError(ret);
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr,
() -> "read() ret: " + ret + ", err: " + err);
throwExceptionFromIOReturnValue(ret, "wolfSSL_read()");
return ret;
} }
/** /**
@ -1236,12 +1274,14 @@ public class WolfSSLSession {
public int read(byte[] data, int offset, int sz, int timeout) public int read(byte[] data, int offset, int sz, int timeout)
throws IllegalStateException, SocketTimeoutException, SocketException { throws IllegalStateException, SocketTimeoutException, SocketException {
final int ret; int ret;
final int err; int err;
int readSz = sz;
final int readOff = offset; final int readOff = offset;
final int readSz = sz; final int tmpReadSz = sz;
final int readTimeout = timeout; final int readTimeout = timeout;
long localPtr; long localPtr;
ByteBuffer directBuffer = null;
confirmObjectIsActive(); confirmObjectIsActive();
@ -1253,20 +1293,64 @@ public class WolfSSLSession {
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, WolfSSLDebug.INFO, localPtr,
() -> "entered read(offset: " + readOff + ", sz: " + readSz + () -> "entered read(offset: " + readOff + ", sz: " + tmpReadSz +
", timeout: " + readTimeout + ")"); ", timeout: " + readTimeout + ")");
/* not synchronizing on sslLock here since JNI read() locks /* Use a DirectByteBuffer from the pool to avoid unaligned
* session mutex around native wolfSSL_read() call. If sslLock * memory access. Otherwise our native JNI code may need to
* is locked here, since we call select() inside native JNI we * do "buffer + offset" and end up with unaligned memory which
* could timeout waiting for corresponding write() operation to * can be slow on some targets (ex: ARM/Aarch64) */
* occur if needed */ try {
ret = read(localPtr, data, readOff, readSz, readTimeout); /* Get a buffer from the pool */
err = getError(ret); directBuffer = acquireDirectBuffer();
WolfSSLDebug.log(getClass(),
WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, localPtr,
() -> "read() using thread-local ByteBuffer pool: pool size: " +
directBufferPool.get().size());
/* Only read up to the size of the buffer or readSz,
* whichever is smaller. */
readSz = Math.min(readSz, directBuffer.capacity());
/* Use direct buffer for JNI call */
directBuffer.limit(readSz);
/* Call native read with DirectByteBuffer */
ret = read(localPtr, directBuffer, 0, readSz, false,
readSz, readTimeout);
if (ret > 0) {
/* Copy data from direct buffer to user array */
directBuffer.flip();
directBuffer.get(data, offset, ret);
}
err = getError(ret);
} catch (Exception e) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr,
() -> "read() falling back to use byte[]");
/* Fall back to original implementation on errors */
ret = read(localPtr, data, readOff, readSz, readTimeout);
err = getError(ret);
} finally {
/* Return buffer to pool */
if (directBuffer != null) {
releaseDirectBuffer(directBuffer);
}
}
final int finalRet = ret;
final int finalErr = err;
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, localPtr, WolfSSLDebug.INFO, localPtr,
() -> "read() ret: " + ret + ", err: " + err); () -> "read() ret: " + finalRet + ", err: " + finalErr);
throwExceptionFromIOReturnValue(ret, "wolfSSL_read()"); throwExceptionFromIOReturnValue(ret, "wolfSSL_read()");