From 767a2891135f90e0bef14645b0015357e22878ba Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Mon, 14 Apr 2025 16:43:20 -0600 Subject: [PATCH] JSSE: cache jmethodIDs used in native I/O callbacks globally, improves performance --- native/com_wolfssl_WolfSSL.c | 112 ++++++++++++++++++- native/com_wolfssl_WolfSSLContext.c | 19 +--- native/com_wolfssl_WolfSSLSession.c | 168 ++++------------------------ native/com_wolfssl_globals.h | 15 ++- 4 files changed, 153 insertions(+), 161 deletions(-) diff --git a/native/com_wolfssl_WolfSSL.c b/native/com_wolfssl_WolfSSL.c index f8a279f..442b972 100644 --- a/native/com_wolfssl_WolfSSL.c +++ b/native/com_wolfssl_WolfSSL.c @@ -53,6 +53,16 @@ JavaVM* g_vm; /* global object refs for logging callbacks */ static jobject g_loggingCbIfaceObj; +/* global method IDs we can cache for performance */ +jmethodID g_sslIORecvMethodId = NULL; +jmethodID g_sslIOSendMethodId = NULL; +jmethodID g_bufferPositionMethodId = NULL; +jmethodID g_bufferLimitMethodId = NULL; +jmethodID g_bufferHasArrayMethodId = NULL; +jmethodID g_bufferArrayMethodId = NULL; +jmethodID g_bufferSetPositionMethodId = NULL; +jmethodID g_verifyCallbackMethodId = NULL; + #ifdef HAVE_FIPS /* global object ref for FIPS error callback */ static jobject g_fipsCbIfaceObj; @@ -61,16 +71,116 @@ static jobject g_fipsCbIfaceObj; /* custom native fn prototypes */ void NativeLoggingCallback(const int logLevel, const char *const logMessage); -/* called when native library is loaded */ +/* Called when native library is loaded. + * We also cache global jmethodIDs here for performance. */ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) { + JNIEnv* env = NULL; + jclass sslClass = NULL; + jclass byteBufferClass = NULL; + jclass verifyClass = NULL; (void)reserved; /* store JavaVM */ g_vm = vm; + + /* get JNIEnv from JavaVM */ + if ((*vm)->GetEnv(vm, (void**)&env, JNI_VERSION_1_6) != JNI_OK) { + printf("Unable to get JNIEnv from JavaVM in JNI_OnLoad()\n"); + return JNI_ERR; + } + + /* Cache the method ID for IO send and recv callbacks */ + sslClass = (*env)->FindClass(env, "com/wolfssl/WolfSSLSession"); + if (sslClass == NULL) { + return JNI_ERR; + } + + g_sslIORecvMethodId = (*env)->GetMethodID(env, sslClass, + "internalIOSSLRecvCallback", + "(Lcom/wolfssl/WolfSSLSession;[BI)I"); + + g_sslIOSendMethodId = (*env)->GetMethodID(env, sslClass, + "internalIOSSLSendCallback", + "(Lcom/wolfssl/WolfSSLSession;[BI)I"); + + /* Cache ByteBuffer method IDs */ + byteBufferClass = (*env)->FindClass(env, "java/nio/ByteBuffer"); + if (byteBufferClass == NULL) { + return JNI_ERR; + } + + g_bufferPositionMethodId = (*env)->GetMethodID(env, byteBufferClass, + "position", "()I"); + if (g_bufferPositionMethodId == NULL) { + return JNI_ERR; + } + + g_bufferLimitMethodId = (*env)->GetMethodID(env, byteBufferClass, + "limit", "()I"); + if (g_bufferLimitMethodId == NULL) { + return JNI_ERR; + } + + g_bufferHasArrayMethodId = (*env)->GetMethodID(env, byteBufferClass, + "hasArray", "()Z"); + if (g_bufferHasArrayMethodId == NULL) { + return JNI_ERR; + } + + g_bufferArrayMethodId = (*env)->GetMethodID(env, byteBufferClass, + "array", "()[B"); + if (g_bufferArrayMethodId == NULL) { + return JNI_ERR; + } + + g_bufferSetPositionMethodId = (*env)->GetMethodID(env, byteBufferClass, + "position", "(I)Ljava/nio/Buffer;"); + if (g_bufferSetPositionMethodId == NULL) { + return JNI_ERR; + } + + /* Cache verify callback method ID */ + verifyClass = (*env)->FindClass(env, "com/wolfssl/WolfSSLVerifyCallback"); + if (verifyClass == NULL) { + return JNI_ERR; + } + + g_verifyCallbackMethodId = (*env)->GetMethodID(env, verifyClass, + "verifyCallback", "(IJ)I"); + if (g_verifyCallbackMethodId == NULL) { + return JNI_ERR; + } + + /* Clean up local reference to class, not needed */ + (*env)->DeleteLocalRef(env, sslClass); + (*env)->DeleteLocalRef(env, byteBufferClass); + (*env)->DeleteLocalRef(env, verifyClass); + return JNI_VERSION_1_6; } +/* Called when native library is unloaded. + * We clear cached method IDs here. */ +JNIEXPORT void JNICALL JNI_OnUnload(JavaVM* vm, void* reserved) +{ + JNIEnv* env; + + if ((*vm)->GetEnv(vm, (void**)&env, JNI_VERSION_1_6) != JNI_OK) { + return; + } + + /* Clear cached method ID */ + g_sslIORecvMethodId = NULL; + g_sslIOSendMethodId = NULL; + g_bufferPositionMethodId = NULL; + g_bufferLimitMethodId = NULL; + g_bufferHasArrayMethodId = NULL; + g_bufferArrayMethodId = NULL; + g_bufferSetPositionMethodId = NULL; + g_verifyCallbackMethodId = NULL; +} + JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSL_init (JNIEnv* jenv, jobject jcl) { diff --git a/native/com_wolfssl_WolfSSLContext.c b/native/com_wolfssl_WolfSSLContext.c index 1da4b23..cf4de17 100644 --- a/native/com_wolfssl_WolfSSLContext.c +++ b/native/com_wolfssl_WolfSSLContext.c @@ -1028,7 +1028,6 @@ int NativeIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) jobject ctxRef; /* WolfSSLContext object */ jclass innerCtxClass; /* WolfSSLContext class */ - jmethodID recvCbMethodId; /* internalIORecvCallback ID */ jbyteArray inData; if (!g_vm || !ssl || !buf || !ctx) { @@ -1140,18 +1139,10 @@ int NativeIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) return WOLFSSL_CBIO_ERR_GENERAL; } - /* call internal I/O recv callback */ - recvCbMethodId = (*jenv)->GetMethodID(jenv, innerCtxClass, - "internalIORecvCallback", - "(Lcom/wolfssl/WolfSSLSession;[BI)I"); - if (!recvCbMethodId) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } + /* make sure cached recv callback method ID is not null */ + if (!g_sslIORecvMethodId) { (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalIORecvCallback method from JNI"); - (*jenv)->DeleteLocalRef(jenv, ctxRef); + "Cached recv callback method ID is null in NativeIORecvCb"); if (needsDetach) (*g_vm)->DetachCurrentThread(g_vm); return WOLFSSL_CBIO_ERR_GENERAL; @@ -1161,7 +1152,7 @@ int NativeIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) inData = (*jenv)->NewByteArray(jenv, sz); if (!inData) { (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalIORecvCallback method from JNI"); + "Error creating jbyteArray in NativeIORecvCb"); (*jenv)->DeleteLocalRef(jenv, ctxRef); if (needsDetach) (*g_vm)->DetachCurrentThread(g_vm); @@ -1170,7 +1161,7 @@ int NativeIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) /* call Java send callback, ignore native ctx since Java * handles it */ - retval = (*jenv)->CallIntMethod(jenv, ctxRef, recvCbMethodId, + retval = (*jenv)->CallIntMethod(jenv, ctxRef, g_sslIORecvMethodId, (jobject)(*g_cachedSSLObj), inData, (jint)sz); diff --git a/native/com_wolfssl_WolfSSLSession.c b/native/com_wolfssl_WolfSSLSession.c index a74cc1f..83db55d 100644 --- a/native/com_wolfssl_WolfSSLSession.c +++ b/native/com_wolfssl_WolfSSLSession.c @@ -103,8 +103,6 @@ int NativeSSLVerifyCallback(int preverify_ok, WOLFSSL_X509_STORE_CTX* store) jint vmret = 0; jint retval = -1; jclass excClass; - jclass verifyClass = NULL; - jmethodID verifyMethod; jobjectRefType refcheck; SSLAppData* appData; /* WOLFSSL app data, stored verify cb obj */ jobject* g_verifySSLCbIfaceObj; /* Global jobject, stored in app data */ @@ -156,34 +154,20 @@ int NativeSSLVerifyCallback(int preverify_ok, WOLFSSL_X509_STORE_CTX* store) refcheck = (*jenv)->GetObjectRefType(jenv, *g_verifySSLCbIfaceObj); if (refcheck == 2) { - /* lookup WolfSSLVerifyCallback class from global object ref */ - verifyClass = (*jenv)->GetObjectClass(jenv, *g_verifySSLCbIfaceObj); - if (!verifyClass) { + if (!g_verifyCallbackMethodId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLVerifyCallback class reference"); + "verifyCallback method ID is null in NativeSSLVerifyCallback"); return -107; } - verifyMethod = (*jenv)->GetMethodID(jenv, verifyClass, - "verifyCallback", "(IJ)I"); - if (verifyMethod == 0) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - - (*jenv)->ThrowNew(jenv, excClass, - "Error getting verifyCallback method from JNI"); - return -108; - } - retval = (*jenv)->CallIntMethod(jenv, *g_verifySSLCbIfaceObj, - verifyMethod, preverify_ok, (jlong)(uintptr_t)store); + g_verifyCallbackMethodId, preverify_ok, + (jlong)(uintptr_t)store); if ((*jenv)->ExceptionOccurred(jenv)) { /* exception occurred on the Java side during method call */ @@ -1415,14 +1399,7 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff int outSz = length; byte* data = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; - jclass excClass; - jclass buffClass; - jmethodID positionMeth; - jmethodID limitMeth; - jmethodID hasArrayMeth; - jmethodID arrayMeth; - jmethodID setPositionMeth; jint position; jint limit; @@ -1444,51 +1421,15 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff return -1; } - /* Get ByteBuffer class */ - buffClass = (*jenv)->GetObjectClass(jenv, buf); - if (buffClass == NULL) { - (*jenv)->ThrowNew(jenv, excClass, - "Failed to find ByteBuffer class in native read()"); - return -1; - } - /* Get ByteBuffer position */ - positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I"); - if (positionMeth == NULL) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Failed to find ByteBuffer position() method in native read()"); - return -1; - } - position = (*jenv)->CallIntMethod(jenv, buf, positionMeth); + position = (*jenv)->CallIntMethod(jenv, buf, g_bufferPositionMethodId); /* Get ByteBuffer limit */ - limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I"); - if (limitMeth == NULL) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Failed to find ByteBuffer limit() method in native read()"); - return -1; - } - limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth); + limit = (*jenv)->CallIntMethod(jenv, buf, g_bufferLimitMethodId); /* Get and call ByteBuffer.hasArray() before calling array() */ - hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z"); - if (hasArrayMeth == NULL) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Failed to find ByteBuffer hasArray() method in native read()"); - return -1; - } + hasArray = (*jenv)->CallBooleanMethod(jenv, buf, + g_bufferHasArrayMethodId); /* Only read up to maximum space we have in this ByteBuffer */ maxOutputSz = (limit - position); @@ -1496,21 +1437,10 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff outSz = maxOutputSz; } - hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth); - if (hasArray) { /* Get reference to underlying byte[] from ByteBuffer */ - arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B"); - if (arrayMeth == NULL) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Failed to find ByteBuffer array() method in native read()"); - return -1; - } - bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth); + bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, + g_bufferArrayMethodId); /* Get array elements */ data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL); @@ -1518,7 +1448,7 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); (*jenv)->ThrowNew(jenv, excClass, - "Exception when calling ByteBuffer.array() in native read()"); + "Exception when calling ByteBuffer.array() in native read()"); return -1; } } @@ -1526,24 +1456,24 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf); if (data == NULL) { (*jenv)->ThrowNew(jenv, excClass, - "Failed to get DirectBuffer address in native read()"); + "Failed to get DirectBuffer address in native read()"); return BAD_FUNC_ARG; } } if (data != NULL) { size = SSLReadNonblockingWithSelectPoll(ssl, data + position, - maxOutputSz, (int)timeout); + maxOutputSz, (int)timeout); /* Relase array elements */ if (hasArray) { if (size < 0) { - (*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte *)data, - JNI_ABORT); + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, + (jbyte *)data, JNI_ABORT); } else { (*jenv)->ReleaseByteArrayElements(jenv, bufArr, - (jbyte *)data, 0); + (jbyte *)data, 0); } } @@ -1551,22 +1481,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff if (size > 0) { /* Update ByteBuffer position() based on bytes written */ - setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass, - "position", "(I)Ljava/nio/Buffer;"); - if (setPositionMeth == NULL) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Failed to set ByteBuffer position() from " - "native read()"); - size = -1; - } - else { - (*jenv)->CallVoidMethod(jenv, buf, setPositionMeth, - position + size); - } + (*jenv)->CallVoidMethod(jenv, buf, g_bufferSetPositionMethodId, + position + size); } } } @@ -6143,8 +6059,6 @@ int NativeSSLIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) int needsDetach = 0; /* Should we explicitly detach? */ jobject* g_cachedSSLObj; /* WolfSSLSession cached object */ - jclass sslClass; /* WolfSSLSession class */ - jmethodID recvCbMethodId; /* internalIORecvCallback ID */ jbyteArray inData; if (!g_vm || !ssl || !buf || !ctx) { @@ -6189,28 +6103,9 @@ int NativeSSLIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) return 0; } - /* lookup WolfSSLSession class from object */ - sslClass = (*jenv)->GetObjectClass(jenv, (jobject)(*g_cachedSSLObj)); - if (!sslClass) { + if (!g_sslIORecvMethodId) { (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession class reference in " - "NativeSSLIORecvCb"); - if (needsDetach) - (*g_vm)->DetachCurrentThread(g_vm); - return WOLFSSL_CBIO_ERR_GENERAL; - } - - /* call internal I/O recv callback */ - recvCbMethodId = (*jenv)->GetMethodID(jenv, sslClass, - "internalIOSSLRecvCallback", - "(Lcom/wolfssl/WolfSSLSession;[BI)I"); - if (!recvCbMethodId) { - if ((*jenv)->ExceptionOccurred(jenv)) { - (*jenv)->ExceptionDescribe(jenv); - (*jenv)->ExceptionClear(jenv); - } - (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalIORecvCallback method from JNI"); + "Cached recv callback method ID is null in internalIORecvCallback"); if (needsDetach) (*g_vm)->DetachCurrentThread(g_vm); return WOLFSSL_CBIO_ERR_GENERAL; @@ -6229,7 +6124,7 @@ int NativeSSLIORecvCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) /* call Java send callback, ignore native ctx since Java * handles it */ retval = (*jenv)->CallIntMethod(jenv, (jobject)(*g_cachedSSLObj), - recvCbMethodId, (jobject)(*g_cachedSSLObj), inData, (jint)sz); + g_sslIORecvMethodId, (jobject)(*g_cachedSSLObj), inData, (jint)sz); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); @@ -6285,8 +6180,6 @@ int NativeSSLIOSendCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) int needsDetach = 0; /* Should we explicitly detach? */ jobject* g_cachedSSLObj; /* WolfSSLSession cached object */ - jclass sslClass; /* WolfSSLSession class */ - jmethodID sendCbMethodId; /* internalIOSendCallback ID */ jbyteArray outData; /* jbyteArray for data to send */ if (!g_vm || !ssl || !buf || !ctx) { @@ -6331,27 +6224,14 @@ int NativeSSLIOSendCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) return 0; } - /* lookup WolfSSLSession class from object */ - sslClass = (*jenv)->GetObjectClass(jenv, (jobject)(*g_cachedSSLObj)); - if (!sslClass) { - (*jenv)->ThrowNew(jenv, excClass, - "Can't get native WolfSSLSession class reference"); - if (needsDetach) - (*g_vm)->DetachCurrentThread(g_vm); - return WOLFSSL_CBIO_ERR_GENERAL; - } - /* call internal I/O send callback */ - sendCbMethodId = (*jenv)->GetMethodID(jenv, sslClass, - "internalIOSSLSendCallback", - "(Lcom/wolfssl/WolfSSLSession;[BI)I"); - if (!sendCbMethodId) { + if (!g_sslIOSendMethodId) { if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); (*jenv)->ExceptionClear(jenv); } (*jenv)->ThrowNew(jenv, excClass, - "Error getting internalIOSendCallback method from JNI"); + "internalIOSendCallback method ID is null in internalIOSendCb"); if (needsDetach) (*g_vm)->DetachCurrentThread(g_vm); return WOLFSSL_CBIO_ERR_GENERAL; @@ -6382,7 +6262,7 @@ int NativeSSLIOSendCb(WOLFSSL *ssl, char *buf, int sz, void *ctx) /* call Java send callback, ignore native ctx since Java * handles it */ retval = (*jenv)->CallIntMethod(jenv, (jobject)(*g_cachedSSLObj), - sendCbMethodId, (jobject)(*g_cachedSSLObj), outData, (jint)sz); + g_sslIOSendMethodId, (jobject)(*g_cachedSSLObj), outData, (jint)sz); if ((*jenv)->ExceptionOccurred(jenv)) { (*jenv)->ExceptionDescribe(jenv); diff --git a/native/com_wolfssl_globals.h b/native/com_wolfssl_globals.h index af3b0fc..4580bdc 100644 --- a/native/com_wolfssl_globals.h +++ b/native/com_wolfssl_globals.h @@ -25,7 +25,19 @@ #define _Included_com_wolfssl_globals /* global JavaVM reference for JNIEnv lookup */ -extern JavaVM* g_vm; +extern JavaVM* g_vm; + +/* Cache static jmethodIDs for performance, since they are guaranteed to be the + * same across all threads once cached. Initialized in JNI_OnLoad() and freed in + * JNI_OnUnload(). */ +extern jmethodID g_sslIORecvMethodId; /* WolfSSLSession.internalIOSSLRecvCallback */ +extern jmethodID g_sslIOSendMethodId; /* WolfSSLSession.internalIOSSLSendCallback */ +extern jmethodID g_bufferPositionMethodId; /* ByteBuffer.position() */ +extern jmethodID g_bufferLimitMethodId; /* ByteBuffer.limit() */ +extern jmethodID g_bufferHasArrayMethodId; /* ByteBuffer.hasArray() */ +extern jmethodID g_bufferArrayMethodId; /* ByteBuffer.array() */ +extern jmethodID g_bufferSetPositionMethodId; /* ByteBuffer.position(int) */ +extern jmethodID g_verifyCallbackMethodId; /* WolfSSLVerifyCallback.verifyCallback */ /* struct to hold I/O class, object refs */ typedef struct { @@ -39,4 +51,3 @@ unsigned int NativePskServerCb(WOLFSSL* ssl, const char* identity, unsigned char* key, unsigned int max_key_len); #endif -