diff --git a/native/com_wolfssl_WolfSSLSession.c b/native/com_wolfssl_WolfSSLSession.c index a365f94..f18e8f5 100644 --- a/native/com_wolfssl_WolfSSLSession.c +++ b/native/com_wolfssl_WolfSSLSession.c @@ -1111,18 +1111,113 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect return ret; } -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write - (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, - jint length, jint timeout) +/** + * Write len bytes with wolfSSL_write() from provided input data buffer. + * + * Internal function called by WolfSSLSession.write() calls. + * + * If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait writability + * with select()/poll() up to provided timeout. + * + * Returns number of bytes written on success, or negative on error. + */ +static int SSLWriteNonblockingWithSelectPoll(WOLFSSL* ssl, byte* data, + int length, int timeout) { - byte* data = NULL; - int ret = SSL_FAILURE, err, sockfd; + int ret, err, sockfd; int pollRx = 0; #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) int pollTx = 0; #endif wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; + + if (ssl == NULL || data == NULL) { + return BAD_FUNC_ARG; + } + + /* get session mutex from SSL app data */ + appData = (SSLAppData*)wolfSSL_get_app_data(ssl); + if (appData == NULL) { + return WOLFSSL_FAILURE; + } + + jniSessLock = appData->jniSessLock; + if (jniSessLock == NULL) { + return SSL_FAILURE; + } + + do { + /* lock mutex around session I/O before write attempt */ + if (wc_LockMutex(jniSessLock) != 0) { + ret = WOLFSSL_FAILURE; + break; + } + + ret = wolfSSL_write(ssl, data, length); + err = wolfSSL_get_error(ssl, ret); + + /* unlock mutex around session I/O after write attempt */ + if (wc_UnLockMutex(jniSessLock) != 0) { + ret = WOLFSSL_FAILURE; + break; + } + + if ((ret < 0) && + ((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) { + + sockfd = wolfSSL_get_fd(ssl); + if (sockfd == -1) { + /* For I/O that does not use sockets, sockfd may be -1, + * skip try to call select() */ + break; + } + + if (err == SSL_ERROR_WANT_READ) { + pollRx = 1; + } + #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) + else if (err == SSL_ERROR_WANT_WRITE) { + pollTx = 1; + } + #endif + + #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) + ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0); + #else + ret = socketPoll(appData, sockfd, (int)timeout, pollRx, + pollTx, 0); + #endif + if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || + (ret == WOLFJNI_IO_EVENT_SEND_READY)) { + /* loop around and try wolfSSL_write() again */ + continue; + } else if (ret == WOLFJNI_IO_EVENT_TIMEOUT || + ret == WOLFJNI_IO_EVENT_FD_CLOSED || + ret == WOLFJNI_IO_EVENT_ERROR || + ret == WOLFJNI_IO_EVENT_POLLHUP || + ret == WOLFJNI_IO_EVENT_FAIL) { + /* Java will throw SocketTimeoutException or + * SocketException */ + break; + } else { + /* error */ + ret = WOLFSSL_FAILURE; + break; + } + } + + } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + + return ret; +} + +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, + jint length, jint timeout) +{ + byte* data = NULL; + int ret; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; (void)jcl; @@ -1138,85 +1233,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write return SSL_FAILURE; } - /* get session mutex from SSL app data */ - appData = (SSLAppData*)wolfSSL_get_app_data(ssl); - if (appData == NULL) { - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return WOLFSSL_FAILURE; - } - - jniSessLock = appData->jniSessLock; - if (jniSessLock == NULL) { - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return SSL_FAILURE; - } - - do { - - /* lock mutex around session I/O before write attempt */ - if (wc_LockMutex(jniSessLock) != 0) { - ret = WOLFSSL_FAILURE; - break; - } - - ret = wolfSSL_write(ssl, data + offset, length); - err = wolfSSL_get_error(ssl, ret); - - /* unlock mutex around session I/O after write attempt */ - if (wc_UnLockMutex(jniSessLock) != 0) { - ret = WOLFSSL_FAILURE; - break; - } - - if (ret >= 0) /* return if it is success */ - break; - - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { - - sockfd = wolfSSL_get_fd(ssl); - if (sockfd == -1) { - /* For I/O that does not use sockets, sockfd may be -1, - * skip try to call select() */ - break; - } - - if (err == SSL_ERROR_WANT_READ) { - pollRx = 1; - } - #if !defined(WOLFJNI_USE_IO_SELECT) && !defined(USE_WINDOWS_API) - else if (err == SSL_ERROR_WANT_WRITE) { - pollTx = 1; - } - #endif - - #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(appData, sockfd, (int)timeout, pollRx, 0); - #else - ret = socketPoll(appData, sockfd, (int)timeout, pollRx, - pollTx, 0); - #endif - if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || - (ret == WOLFJNI_IO_EVENT_SEND_READY)) { - /* loop around and try wolfSSL_write() again */ - continue; - } else if (ret == WOLFJNI_IO_EVENT_TIMEOUT || - ret == WOLFJNI_IO_EVENT_FD_CLOSED || - ret == WOLFJNI_IO_EVENT_ERROR || - ret == WOLFJNI_IO_EVENT_POLLHUP || - ret == WOLFJNI_IO_EVENT_FAIL) { - /* Java will throw SocketTimeoutException or - * SocketException */ - break; - } else { - /* error */ - ret = WOLFSSL_FAILURE; - break; - } - } - - } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + ret = SSLWriteNonblockingWithSelectPoll(ssl, data + offset, + (int)length, (int)timeout); (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, JNI_ABORT); @@ -1227,6 +1245,86 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write } } +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint position, + jint limit, jboolean hasArray, jint length, jint timeout) +{ + int ret; + int maxInputSz; + int inSz = length; + byte* data = NULL; + WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; + jbyteArray bufArr = NULL; + + (void)jcl; + + if (jenv == NULL || ssl == NULL || buf == NULL) { + return BAD_FUNC_ARG; + } + + if (length > 0) { + + /* Only write up to maximum space we have in this ByteBuffer */ + maxInputSz = (limit - position); + if (inSz > maxInputSz) { + inSz = maxInputSz; + } + + if (inSz <= 0) { + return BAD_FUNC_ARG; + } + + if (hasArray) { + /* Get reference to underlying byte[] from ByteBuffer */ + bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, + g_bufferArrayMethodId); + if ((*jenv)->ExceptionCheck(jenv)) { + return SSL_FAILURE; + } + + /* Get array elements */ + data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL); + if (data == NULL) { + /* Handle any pending exception, we'll throw another below + * anyways so just clear it */ + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + throwWolfSSLJNIException(jenv, + "Failed to get byte[] from ByteBuffer in native write()"); + return BAD_FUNC_ARG; + } + } + else { + data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf); + if (data == NULL) { + throwWolfSSLJNIException(jenv, + "Failed to get DirectBuffer address in native write()"); + return BAD_FUNC_ARG; + } + } + + ret = SSLWriteNonblockingWithSelectPoll(ssl, data + position, + (int)inSz, (int)timeout); + + /* release memory if using array mode */ + if (hasArray) { + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, + (jbyte*)data, JNI_ABORT); + } + } + + /* check for Java exceptions before returning */ + if ((*jenv)->ExceptionCheck(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + return SSL_FAILURE; + } + + return ret; +} + /** * Read len bytes from wolfSSL_read() back into provided output buffer. * diff --git a/native/com_wolfssl_WolfSSLSession.h b/native/com_wolfssl_WolfSSLSession.h index 95b6dda..0b6950c 100644 --- a/native/com_wolfssl_WolfSSLSession.h +++ b/native/com_wolfssl_WolfSSLSession.h @@ -96,9 +96,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_connect * Method: write * Signature: (J[BIII)I */ -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__J_3BIII (JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint); +/* + * Class: com_wolfssl_WolfSSLSession + * Method: write + * Signature: (JLjava/nio/ByteBuffer;IIZII)I + */ +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuffer_2IIZII + (JNIEnv *, jobject, jlong, jobject, jint, jint, jboolean, jint, jint); + /* * Class: com_wolfssl_WolfSSLSession * Method: read diff --git a/src/java/com/wolfssl/WolfSSLSession.java b/src/java/com/wolfssl/WolfSSLSession.java index 62a72c5..a52c0c9 100644 --- a/src/java/com/wolfssl/WolfSSLSession.java +++ b/src/java/com/wolfssl/WolfSSLSession.java @@ -110,12 +110,12 @@ public class WolfSSLSession { private final Object sslLock = new Object(); /* Maximum direct ByteBuffer pool size */ - private static final int MAX_POOL_SIZE = 32; + private static int MAX_POOL_SIZE = 16; /* Size of each direct ByteBuffer in the pool. This is set to 17KB, which * is slightly larger than the maximum SSL record size (16KB). This * allows for some overhead (SSL record header, etc) */ - private static final int BUFFER_SIZE = 17 * 1024; + private static int BUFFER_SIZE = 17 * 1024; /* Thread-local direct ByteBuffer pool for optimized JNI direct memory * access. Passing byte[] and offset down to JNI, on some systems this @@ -446,6 +446,9 @@ public class WolfSSLSession { private native int connect(long ssl, int timeout); private native int write(long ssl, byte[] data, int offset, int length, int timeout); + private native int write(long ssl, ByteBuffer data, final int position, + final int limit, boolean hasArray, int sz, int timeout) + throws WolfSSLException; private native int read(long ssl, byte[] data, int offset, int sz, int timeout); private native int read(long ssl, ByteBuffer data, final int position, @@ -1000,37 +1003,7 @@ public class WolfSSLSession { public int write(byte[] data, int length) throws IllegalStateException, SocketTimeoutException, SocketException { - final int ret; - final int err; - long localPtr; - - confirmObjectIsActive(); - - /* Fix for Infer scan, since not synchronizing on sslLock for - * access to this.sslPtr, see note below */ - synchronized (sslLock) { - localPtr = this.sslPtr; - } - - WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, - WolfSSLDebug.INFO, localPtr, - () -> "entered write(length: " + length + ")"); - - /* not synchronizing on sslLock here since JNI write() locks - * session mutex around native wolfSSL_write() call. If sslLock - * is locked here, since we call select() inside native JNI we - * could timeout waiting for corresponding read() operation to - * occur if needed */ - ret = write(localPtr, data, 0, length, 0); - err = getError(ret); - - WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, - WolfSSLDebug.INFO, localPtr, - () -> "write() ret: " + ret + ", err: " + err); - - throwExceptionFromIOReturnValue(ret, "wolfSSL_write()"); - - return ret; + return write(data, 0, length, 0); } /** @@ -1113,9 +1086,12 @@ public class WolfSSLSession { public int write(byte[] data, int offset, int length, int timeout) throws IllegalStateException, SocketTimeoutException, SocketException { - final int ret; - final int err; + int ret = 0; + int err = 0; + int totalWritten = 0; + int remaining = length; long localPtr; + ByteBuffer directBuffer = null; confirmObjectIsActive(); @@ -1130,21 +1106,75 @@ public class WolfSSLSession { () -> "entered write(offset: " + offset + ", length: " + length + ", timeout: " + timeout + ")"); - /* not synchronizing on sslLock here since JNI write() locks - * session mutex around native wolfSSL_write() call. If sslLock - * is locked here, since we call select() inside native JNI we - * could timeout waiting for corresponding read() operation to - * occur if needed */ - ret = write(localPtr, data, offset, length, timeout); - err = getError(ret); + /* Use a direct ByteBuffer from the pool to avoid unaligned + * memory access. Otherwise our native JNI code may need to do + * "buffer + offset" and end up with unaligned memory which + * can be slow on some targets (ex: ARM/Aarch64) */ + try { + /* Get a buffer from the pool */ + directBuffer = acquireDirectBuffer(); + + WolfSSLDebug.log(getClass(), + WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, localPtr, + () -> "write() using thread-local ByteBuffer pool: pool size: " + + directBufferPool.get().size()); + + /* Write in chunks until all data is written or an error occurs. + * The DirectByteBuffer size might be smaller than the data length, + * so we need to loop to handle all the data */ + while (remaining > 0) { + /* Calculate size for current chunk */ + int writeSize = Math.min(remaining, directBuffer.capacity()); + + /* Copy data from user array to direct buffer */ + directBuffer.clear(); + directBuffer.put(data, offset + totalWritten, writeSize); + directBuffer.flip(); + + /* Call native write with DirectByteBuffer */ + ret = write(localPtr, directBuffer, directBuffer.position(), + directBuffer.limit(), false, writeSize, timeout); + + if (ret <= 0) { + /* Error occurred, break out of loop */ + break; + } + + /* Update tracking variables */ + totalWritten += ret; + remaining -= ret; + } + + } catch (Exception e) { + + WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, + WolfSSLDebug.ERROR, localPtr, + () -> "write() falling back to use byte[]"); + + /* Fall back to original implementation on exception (write not + * done yet at this point in JNI call above) */ + totalWritten = write(localPtr, data, offset, length, timeout); + + } finally { + /* Return buffer to pool */ + if (directBuffer != null) { + releaseDirectBuffer(directBuffer); + } + } + + /* Return total bytes written, or last error code */ + final int finalRet = (totalWritten > 0) ? totalWritten : ret; + final int finalErr = getError(finalRet); + final int finalTotal = totalWritten; WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, WolfSSLDebug.INFO, localPtr, - () -> "write() ret: " + ret + ", err: " + err); + () -> "write() ret: " + finalRet + ", err: " + finalErr + + ", totalWritten: " + finalTotal); - throwExceptionFromIOReturnValue(ret, "wolfSSL_write()"); + throwExceptionFromIOReturnValue(finalRet, "wolfSSL_write()"); - return ret; + return finalRet; } /** diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java b/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java index f59bb79..0bcb297 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLSocket.java @@ -2895,8 +2895,9 @@ public class WolfSSLSocket extends SSLSocket { if (ret < 0) { /* print error description string */ String errStr = WolfSSL.getErrorString(err); - throw new IOException("Native wolfSSL_write() error: " - + errStr + " (error code: " + err + ")"); + throw new IOException("Native wolfSSL_write() error: " + + errStr + " (ret: " + ret + ", error code: " + + err + ")"); } } catch (IllegalStateException e) {