JNI: initial implementation of static direct ByteBuffer pool for WolfSSLSession.write(), avoids unaligned memory access at JNI layer
parent
6b1e7a6299
commit
1c8963e4fe
|
@ -1111,18 +1111,113 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect
|
|||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
|
||||
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
|
||||
jint length, jint timeout)
|
||||
/**
|
||||
* Write len bytes with wolfSSL_write() from provided input data buffer.
|
||||
*
|
||||
* Internal function called by WolfSSLSession.write() calls.
|
||||
*
|
||||
* If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait writability
|
||||
* with select()/poll() up to provided timeout.
|
||||
*
|
||||
* Returns number of bytes written on success, or negative on error.
|
||||
*/
|
||||
static int SSLWriteNonblockingWithSelectPoll(WOLFSSL* ssl, byte* data,
|
||||
int length, int timeout)
|
||||
{
|
||||
byte* data = NULL;
|
||||
int ret = SSL_FAILURE, err, sockfd;
|
||||
int ret, err, sockfd;
|
||||
int pollRx = 0;
|
||||
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
|
||||
int pollTx = 0;
|
||||
#endif
|
||||
wolfSSL_Mutex* jniSessLock = NULL;
|
||||
SSLAppData* appData = NULL;
|
||||
|
||||
if (ssl == NULL || data == NULL) {
|
||||
return BAD_FUNC_ARG;
|
||||
}
|
||||
|
||||
/* get session mutex from SSL app data */
|
||||
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
|
||||
if (appData == NULL) {
|
||||
return WOLFSSL_FAILURE;
|
||||
}
|
||||
|
||||
jniSessLock = appData->jniSessLock;
|
||||
if (jniSessLock == NULL) {
|
||||
return SSL_FAILURE;
|
||||
}
|
||||
|
||||
do {
|
||||
/* lock mutex around session I/O before write attempt */
|
||||
if (wc_LockMutex(jniSessLock) != 0) {
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
|
||||
ret = wolfSSL_write(ssl, data, length);
|
||||
err = wolfSSL_get_error(ssl, ret);
|
||||
|
||||
/* unlock mutex around session I/O after write attempt */
|
||||
if (wc_UnLockMutex(jniSessLock) != 0) {
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
|
||||
if ((ret < 0) &&
|
||||
((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) {
|
||||
|
||||
sockfd = wolfSSL_get_fd(ssl);
|
||||
if (sockfd == -1) {
|
||||
/* For I/O that does not use sockets, sockfd may be -1,
|
||||
* skip try to call select() */
|
||||
break;
|
||||
}
|
||||
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
pollRx = 1;
|
||||
}
|
||||
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
|
||||
else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
pollTx = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
|
||||
ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0);
|
||||
#else
|
||||
ret = socketPoll(appData, sockfd, (int)timeout, pollRx,
|
||||
pollTx, 0);
|
||||
#endif
|
||||
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
|
||||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
|
||||
/* loop around and try wolfSSL_write() again */
|
||||
continue;
|
||||
} else if (ret == WOLFJNI_IO_EVENT_TIMEOUT ||
|
||||
ret == WOLFJNI_IO_EVENT_FD_CLOSED ||
|
||||
ret == WOLFJNI_IO_EVENT_ERROR ||
|
||||
ret == WOLFJNI_IO_EVENT_POLLHUP ||
|
||||
ret == WOLFJNI_IO_EVENT_FAIL) {
|
||||
/* Java will throw SocketTimeoutException or
|
||||
* SocketException */
|
||||
break;
|
||||
} else {
|
||||
/* error */
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII
|
||||
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
|
||||
jint length, jint timeout)
|
||||
{
|
||||
byte* data = NULL;
|
||||
int ret;
|
||||
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
|
||||
(void)jcl;
|
||||
|
||||
|
@ -1138,85 +1233,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
|
|||
return SSL_FAILURE;
|
||||
}
|
||||
|
||||
/* get session mutex from SSL app data */
|
||||
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
|
||||
if (appData == NULL) {
|
||||
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
|
||||
JNI_ABORT);
|
||||
return WOLFSSL_FAILURE;
|
||||
}
|
||||
|
||||
jniSessLock = appData->jniSessLock;
|
||||
if (jniSessLock == NULL) {
|
||||
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
|
||||
JNI_ABORT);
|
||||
return SSL_FAILURE;
|
||||
}
|
||||
|
||||
do {
|
||||
|
||||
/* lock mutex around session I/O before write attempt */
|
||||
if (wc_LockMutex(jniSessLock) != 0) {
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
|
||||
ret = wolfSSL_write(ssl, data + offset, length);
|
||||
err = wolfSSL_get_error(ssl, ret);
|
||||
|
||||
/* unlock mutex around session I/O after write attempt */
|
||||
if (wc_UnLockMutex(jniSessLock) != 0) {
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret >= 0) /* return if it is success */
|
||||
break;
|
||||
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||
|
||||
sockfd = wolfSSL_get_fd(ssl);
|
||||
if (sockfd == -1) {
|
||||
/* For I/O that does not use sockets, sockfd may be -1,
|
||||
* skip try to call select() */
|
||||
break;
|
||||
}
|
||||
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
pollRx = 1;
|
||||
}
|
||||
#if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API)
|
||||
else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
pollTx = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
|
||||
ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0);
|
||||
#else
|
||||
ret = socketPoll(appData, sockfd, (int)timeout, pollRx,
|
||||
pollTx, 0);
|
||||
#endif
|
||||
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
|
||||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
|
||||
/* loop around and try wolfSSL_write() again */
|
||||
continue;
|
||||
} else if (ret == WOLFJNI_IO_EVENT_TIMEOUT ||
|
||||
ret == WOLFJNI_IO_EVENT_FD_CLOSED ||
|
||||
ret == WOLFJNI_IO_EVENT_ERROR ||
|
||||
ret == WOLFJNI_IO_EVENT_POLLHUP ||
|
||||
ret == WOLFJNI_IO_EVENT_FAIL) {
|
||||
/* Java will throw SocketTimeoutException or
|
||||
* SocketException */
|
||||
break;
|
||||
} else {
|
||||
/* error */
|
||||
ret = WOLFSSL_FAILURE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
|
||||
ret = SSLWriteNonblockingWithSelectPoll(ssl, data + offset,
|
||||
(int)length, (int)timeout);
|
||||
|
||||
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, JNI_ABORT);
|
||||
|
||||
|
@ -1227,6 +1245,86 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
|
|||
}
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII
|
||||
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint position,
|
||||
jint limit, jboolean hasArray, jint length, jint timeout)
|
||||
{
|
||||
int ret;
|
||||
int maxInputSz;
|
||||
int inSz = length;
|
||||
byte* data = NULL;
|
||||
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
|
||||
jbyteArray bufArr = NULL;
|
||||
|
||||
(void)jcl;
|
||||
|
||||
if (jenv == NULL || ssl == NULL || buf == NULL) {
|
||||
return BAD_FUNC_ARG;
|
||||
}
|
||||
|
||||
if (length > 0) {
|
||||
|
||||
/* Only write up to maximum space we have in this ByteBuffer */
|
||||
maxInputSz = (limit - position);
|
||||
if (inSz > maxInputSz) {
|
||||
inSz = maxInputSz;
|
||||
}
|
||||
|
||||
if (inSz <= 0) {
|
||||
return BAD_FUNC_ARG;
|
||||
}
|
||||
|
||||
if (hasArray) {
|
||||
/* Get reference to underlying byte[] from ByteBuffer */
|
||||
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf,
|
||||
g_bufferArrayMethodId);
|
||||
if ((*jenv)->ExceptionCheck(jenv)) {
|
||||
return SSL_FAILURE;
|
||||
}
|
||||
|
||||
/* Get array elements */
|
||||
data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
|
||||
if (data == NULL) {
|
||||
/* Handle any pending exception, we'll throw another below
|
||||
* anyways so just clear it */
|
||||
if ((*jenv)->ExceptionOccurred(jenv)) {
|
||||
(*jenv)->ExceptionDescribe(jenv);
|
||||
(*jenv)->ExceptionClear(jenv);
|
||||
}
|
||||
throwWolfSSLJNIException(jenv,
|
||||
"Failed to get byte[] from ByteBuffer in native write()");
|
||||
return BAD_FUNC_ARG;
|
||||
}
|
||||
}
|
||||
else {
|
||||
data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf);
|
||||
if (data == NULL) {
|
||||
throwWolfSSLJNIException(jenv,
|
||||
"Failed to get DirectBuffer address in native write()");
|
||||
return BAD_FUNC_ARG;
|
||||
}
|
||||
}
|
||||
|
||||
ret = SSLWriteNonblockingWithSelectPoll(ssl, data + position,
|
||||
(int)inSz, (int)timeout);
|
||||
|
||||
/* release memory if using array mode */
|
||||
if (hasArray) {
|
||||
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
|
||||
(jbyte*)data, JNI_ABORT);
|
||||
}
|
||||
}
|
||||
|
||||
/* check for Java exceptions before returning */
|
||||
if ((*jenv)->ExceptionCheck(jenv)) {
|
||||
(*jenv)->ExceptionDescribe(jenv);
|
||||
(*jenv)->ExceptionClear(jenv);
|
||||
return SSL_FAILURE;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read len bytes from wolfSSL_read() back into provided output buffer.
|
||||
*
|
||||
|
|
|
@ -96,9 +96,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect
|
|||
* Method: write
|
||||
* Signature: (J[BIII)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII
|
||||
(JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: com_wolfssl_WolfSSLSession
|
||||
* Method: write
|
||||
* Signature: (JLjava/nio/ByteBuffer;IIZII)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII
|
||||
(JNIEnv *, jobject, jlong, jobject, jint, jint, jboolean, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: com_wolfssl_WolfSSLSession
|
||||
* Method: read
|
||||
|
|
|
@ -110,12 +110,12 @@ public class WolfSSLSession {
|
|||
private final Object sslLock = new Object();
|
||||
|
||||
/* Maximum direct ByteBuffer pool size */
|
||||
private static final int MAX_POOL_SIZE = 32;
|
||||
private static int MAX_POOL_SIZE = 16;
|
||||
|
||||
/* Size of each direct ByteBuffer in the pool. This is set to 17KB, which
|
||||
* is slightly larger than the maximum SSL record size (16KB). This
|
||||
* allows for some overhead (SSL record header, etc) */
|
||||
private static final int BUFFER_SIZE = 17 * 1024;
|
||||
private static int BUFFER_SIZE = 17 * 1024;
|
||||
|
||||
/* Thread-local direct ByteBuffer pool for optimized JNI direct memory
|
||||
* access. Passing byte[] and offset down to JNI, on some systems this
|
||||
|
@ -446,6 +446,9 @@ public class WolfSSLSession {
|
|||
private native int connect(long ssl, int timeout);
|
||||
private native int write(long ssl, byte[] data, int offset, int length,
|
||||
int timeout);
|
||||
private native int write(long ssl, ByteBuffer data, final int position,
|
||||
final int limit, boolean hasArray, int sz, int timeout)
|
||||
throws WolfSSLException;
|
||||
private native int read(long ssl, byte[] data, int offset, int sz,
|
||||
int timeout);
|
||||
private native int read(long ssl, ByteBuffer data, final int position,
|
||||
|
@ -1000,37 +1003,7 @@ public class WolfSSLSession {
|
|||
public int write(byte[] data, int length)
|
||||
throws IllegalStateException, SocketTimeoutException, SocketException {
|
||||
|
||||
final int ret;
|
||||
final int err;
|
||||
long localPtr;
|
||||
|
||||
confirmObjectIsActive();
|
||||
|
||||
/* Fix for Infer scan, since not synchronizing on sslLock for
|
||||
* access to this.sslPtr, see note below */
|
||||
synchronized (sslLock) {
|
||||
localPtr = this.sslPtr;
|
||||
}
|
||||
|
||||
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
|
||||
WolfSSLDebug.INFO, localPtr,
|
||||
() -> "entered write(length: " + length + ")");
|
||||
|
||||
/* not synchronizing on sslLock here since JNI write() locks
|
||||
* session mutex around native wolfSSL_write() call. If sslLock
|
||||
* is locked here, since we call select() inside native JNI we
|
||||
* could timeout waiting for corresponding read() operation to
|
||||
* occur if needed */
|
||||
ret = write(localPtr, data, 0, length, 0);
|
||||
err = getError(ret);
|
||||
|
||||
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
|
||||
WolfSSLDebug.INFO, localPtr,
|
||||
() -> "write() ret: " + ret + ", err: " + err);
|
||||
|
||||
throwExceptionFromIOReturnValue(ret, "wolfSSL_write()");
|
||||
|
||||
return ret;
|
||||
return write(data, 0, length, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1113,9 +1086,12 @@ public class WolfSSLSession {
|
|||
public int write(byte[] data, int offset, int length, int timeout)
|
||||
throws IllegalStateException, SocketTimeoutException, SocketException {
|
||||
|
||||
final int ret;
|
||||
final int err;
|
||||
int ret = 0;
|
||||
int err = 0;
|
||||
int totalWritten = 0;
|
||||
int remaining = length;
|
||||
long localPtr;
|
||||
ByteBuffer directBuffer = null;
|
||||
|
||||
confirmObjectIsActive();
|
||||
|
||||
|
@ -1130,21 +1106,75 @@ public class WolfSSLSession {
|
|||
() -> "entered write(offset: " + offset + ", length: " +
|
||||
length + ", timeout: " + timeout + ")");
|
||||
|
||||
/* not synchronizing on sslLock here since JNI write() locks
|
||||
* session mutex around native wolfSSL_write() call. If sslLock
|
||||
* is locked here, since we call select() inside native JNI we
|
||||
* could timeout waiting for corresponding read() operation to
|
||||
* occur if needed */
|
||||
ret = write(localPtr, data, offset, length, timeout);
|
||||
err = getError(ret);
|
||||
/* Use a direct ByteBuffer from the pool to avoid unaligned
|
||||
* memory access. Otherwise our native JNI code may need to do
|
||||
* "buffer + offset" and end up with unaligned memory which
|
||||
* can be slow on some targets (ex: ARM/Aarch64) */
|
||||
try {
|
||||
/* Get a buffer from the pool */
|
||||
directBuffer = acquireDirectBuffer();
|
||||
|
||||
WolfSSLDebug.log(getClass(),
|
||||
WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, localPtr,
|
||||
() -> "write() using thread-local ByteBuffer pool: pool size: " +
|
||||
directBufferPool.get().size());
|
||||
|
||||
/* Write in chunks until all data is written or an error occurs.
|
||||
* The DirectByteBuffer size might be smaller than the data length,
|
||||
* so we need to loop to handle all the data */
|
||||
while (remaining > 0) {
|
||||
/* Calculate size for current chunk */
|
||||
int writeSize = Math.min(remaining, directBuffer.capacity());
|
||||
|
||||
/* Copy data from user array to direct buffer */
|
||||
directBuffer.clear();
|
||||
directBuffer.put(data, offset + totalWritten, writeSize);
|
||||
directBuffer.flip();
|
||||
|
||||
/* Call native write with DirectByteBuffer */
|
||||
ret = write(localPtr, directBuffer, directBuffer.position(),
|
||||
directBuffer.limit(), false, writeSize, timeout);
|
||||
|
||||
if (ret <= 0) {
|
||||
/* Error occurred, break out of loop */
|
||||
break;
|
||||
}
|
||||
|
||||
/* Update tracking variables */
|
||||
totalWritten += ret;
|
||||
remaining -= ret;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
|
||||
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
|
||||
WolfSSLDebug.ERROR, localPtr,
|
||||
() -> "write() falling back to use byte[]");
|
||||
|
||||
/* Fall back to original implementation on exception (write not
|
||||
* done yet at this point in JNI call above) */
|
||||
totalWritten = write(localPtr, data, offset, length, timeout);
|
||||
|
||||
} finally {
|
||||
/* Return buffer to pool */
|
||||
if (directBuffer != null) {
|
||||
releaseDirectBuffer(directBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
/* Return total bytes written, or last error code */
|
||||
final int finalRet = (totalWritten > 0) ? totalWritten : ret;
|
||||
final int finalErr = getError(finalRet);
|
||||
final int finalTotal = totalWritten;
|
||||
|
||||
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
|
||||
WolfSSLDebug.INFO, localPtr,
|
||||
() -> "write() ret: " + ret + ", err: " + err);
|
||||
() -> "write() ret: " + finalRet + ", err: " + finalErr +
|
||||
", totalWritten: " + finalTotal);
|
||||
|
||||
throwExceptionFromIOReturnValue(ret, "wolfSSL_write()");
|
||||
throwExceptionFromIOReturnValue(finalRet, "wolfSSL_write()");
|
||||
|
||||
return ret;
|
||||
return finalRet;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -2895,8 +2895,9 @@ public class WolfSSLSocket extends SSLSocket {
|
|||
if (ret < 0) {
|
||||
/* print error description string */
|
||||
String errStr = WolfSSL.getErrorString(err);
|
||||
throw new IOException("Native wolfSSL_write() error: "
|
||||
+ errStr + " (error code: " + err + ")");
|
||||
throw new IOException("Native wolfSSL_write() error: " +
|
||||
errStr + " (ret: " + ret + ", error code: " +
|
||||
err + ")");
|
||||
}
|
||||
|
||||
} catch (IllegalStateException e) {
|
||||
|
|
Loading…
Reference in New Issue