Merge pull request #260 from cconlon/jniSessionTestCleanup

Add more checks to JNI WolfSSLSession.read(ByteBuffer)
pull/261/head
JacobBarthelmeh 2025-04-30 15:33:38 -06:00 committed by GitHub
commit 30bffcc6da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 30 deletions

View File

@ -1384,13 +1384,15 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
if (length > 0) { if (length > 0) {
/* Get ByteBuffer position */ /* Get ByteBuffer position */
position = (*jenv)->CallIntMethod(jenv, buf, g_bufferPositionMethodId); position = (*jenv)->CallIntMethod(jenv, buf, g_bufferPositionMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}
/* Get ByteBuffer limit */ /* Get ByteBuffer limit */
limit = (*jenv)->CallIntMethod(jenv, buf, g_bufferLimitMethodId); limit = (*jenv)->CallIntMethod(jenv, buf, g_bufferLimitMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
/* Get and call ByteBuffer.hasArray() before calling array() */ return SSL_FAILURE;
hasArray = (*jenv)->CallBooleanMethod(jenv, buf, }
g_bufferHasArrayMethodId);
/* Only read up to maximum space we have in this ByteBuffer */ /* Only read up to maximum space we have in this ByteBuffer */
maxOutputSz = (limit - position); maxOutputSz = (limit - position);
@ -1398,10 +1400,24 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
outSz = maxOutputSz; outSz = maxOutputSz;
} }
if (outSz <= 0) {
return BAD_FUNC_ARG;
}
/* Get and call ByteBuffer.hasArray() before calling array() */
hasArray = (*jenv)->CallBooleanMethod(jenv, buf,
g_bufferHasArrayMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}
if (hasArray) { if (hasArray) {
/* Get reference to underlying byte[] from ByteBuffer */ /* Get reference to underlying byte[] from ByteBuffer */
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf,
g_bufferArrayMethodId); g_bufferArrayMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}
/* Get array elements */ /* Get array elements */
data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL); data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
@ -1412,6 +1428,12 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
"Exception when calling ByteBuffer.array() in native read()"); "Exception when calling ByteBuffer.array() in native read()");
return -1; return -1;
} }
if (data == NULL) {
throwWolfSSLJNIException(jenv,
"Failed to get byte[] from ByteBuffer in native read()");
return BAD_FUNC_ARG;
}
} }
else { else {
data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf); data = (byte *)(*jenv)->GetDirectBufferAddress(jenv, buf);
@ -1422,28 +1444,28 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
} }
} }
if (data != NULL) { size = SSLReadNonblockingWithSelectPoll(ssl, data + position,
size = SSLReadNonblockingWithSelectPoll(ssl, data + position, maxOutputSz, (int)timeout);
maxOutputSz, (int)timeout);
/* Relase array elements */ /* Release array elements if using array-backed buffer.
if (hasArray) { * Note: DirectByteBuffer doesn't need releasing data */
if (size < 0) { if (hasArray) {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr, if (size < 0) {
(jbyte *)data, JNI_ABORT); (*jenv)->ReleaseByteArrayElements(jenv, bufArr,
} (jbyte *)data, JNI_ABORT);
else {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte *)data, 0);
}
} }
else {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte *)data, 0);
}
}
/* Note: DirectByteBuffer doesn't need releasing data */ /* Update ByteBuffer position() based on bytes written, on success */
if (size > 0) {
if (size > 0) { (*jenv)->CallVoidMethod(jenv, buf, g_bufferSetPositionMethodId,
/* Update ByteBuffer position() based on bytes written */ position + size);
(*jenv)->CallVoidMethod(jenv, buf, g_bufferSetPositionMethodId, if ((*jenv)->ExceptionCheck(jenv)) {
position + size); return SSL_FAILURE;
} }
} }
} }

View File

@ -1149,6 +1149,8 @@ public class WolfSSLSessionTest {
else { else {
System.setProperty("wolfssljni.debug", originalProp); System.setProperty("wolfssljni.debug", originalProp);
} }
/* Refresh debug flags */
WolfSSLDebug.refreshDebugFlags(); WolfSSLDebug.refreshDebugFlags();
/* Restore System.out direction */ /* Restore System.out direction */
@ -1686,15 +1688,27 @@ public class WolfSSLSessionTest {
} }
/* Read data from client */ /* Read data from client */
bytesRead = srvSes.read(servAppBuffer, do {
servAppBuffer.length, 0); bytesRead = srvSes.read(servAppBuffer,
servAppBuffer.length, 0);
err = srvSes.getError(bytesRead);
} while ((bytesRead < 0) &&
(err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (bytesRead <= 0) { if (bytesRead <= 0) {
throw new Exception( throw new Exception(
"Server read failed: " + bytesRead); "Server read failed: " + bytesRead);
} }
/* Send same data back to client */ /* Send same data back to client */
ret = srvSes.write(servAppBuffer, bytesRead, 0); do {
ret = srvSes.write(servAppBuffer, bytesRead, 0);
err = srvSes.getError(ret);
} while ((ret < 0) &&
(err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (ret != bytesRead) { if (ret != bytesRead) {
throw new Exception("Server write failed: " + ret); throw new Exception("Server write failed: " + ret);
} }
@ -1751,7 +1765,13 @@ public class WolfSSLSessionTest {
} }
/* Send test data */ /* Send test data */
ret = cliSes.write(testData, testData.length, 0); do {
ret = cliSes.write(testData, testData.length, 0);
err = cliSes.getError(ret);
} while ((ret < 0) &&
(err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (ret != testData.length) { if (ret != testData.length) {
throw new Exception( throw new Exception(
"Client write failed: " + ret); "Client write failed: " + ret);
@ -1761,9 +1781,9 @@ public class WolfSSLSessionTest {
do { do {
bytesRead = cliSes.read(cliAppBuffer, cliAppBuffer.length, 0); bytesRead = cliSes.read(cliAppBuffer, cliAppBuffer.length, 0);
err = cliSes.getError(bytesRead); err = cliSes.getError(bytesRead);
} while (ret != WolfSSL.SSL_SUCCESS && } while ((bytesRead < 0) &&
err == WolfSSL.SSL_ERROR_WANT_READ || (err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE); err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (bytesRead != testData.length) { if (bytesRead != testData.length) {
throw new Exception( throw new Exception(