From 6a243c414933487dd26f2b2bc2bc20157ea8b361 Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Mon, 21 Apr 2025 15:51:01 -0600 Subject: [PATCH] JNI test: add JUnit tests for WolfSSLSession I/O ByteBuffer callbacks --- src/java/com/wolfssl/WolfSSLSession.java | 6 - .../com/wolfssl/test/WolfSSLSessionTest.java | 418 +++++++++++++++++- 2 files changed, 411 insertions(+), 13 deletions(-) diff --git a/src/java/com/wolfssl/WolfSSLSession.java b/src/java/com/wolfssl/WolfSSLSession.java index 8cbde66..70e60f5 100644 --- a/src/java/com/wolfssl/WolfSSLSession.java +++ b/src/java/com/wolfssl/WolfSSLSession.java @@ -3184,9 +3184,6 @@ public class WolfSSLSession { confirmObjectIsActive(); synchronized (sslLock) { - WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, - WolfSSLDebug.INFO, this.sslPtr, "entered getIOReadCtx()"); - return this.ioReadCtx; } } @@ -3235,9 +3232,6 @@ public class WolfSSLSession { confirmObjectIsActive(); synchronized (sslLock) { - WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, - WolfSSLDebug.INFO, this.sslPtr, "entered getIOWriteCtx()"); - return this.ioWriteCtx; } } diff --git a/src/test/com/wolfssl/test/WolfSSLSessionTest.java b/src/test/com/wolfssl/test/WolfSSLSessionTest.java index dbb1914..7e62ed5 100644 --- a/src/test/com/wolfssl/test/WolfSSLSessionTest.java +++ b/src/test/com/wolfssl/test/WolfSSLSessionTest.java @@ -23,13 +23,10 @@ package com.wolfssl.test; import org.junit.Test; import org.junit.BeforeClass; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; import static org.junit.Assert.*; import java.io.ByteArrayOutputStream; import java.io.PrintStream; -import java.io.IOException; import java.net.Socket; import java.net.ServerSocket; import java.net.InetAddress; @@ -43,6 +40,8 @@ import java.util.concurrent.Future; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; import java.util.concurrent.CountDownLatch; +import java.util.Arrays; +import java.nio.ByteBuffer; import com.wolfssl.WolfSSL; import com.wolfssl.WolfSSLDebug; @@ -53,12 +52,11 @@ import com.wolfssl.WolfSSLPskClientCallback; import com.wolfssl.WolfSSLPskServerCallback; import com.wolfssl.WolfSSLTls13SecretCallback; import com.wolfssl.WolfSSLSession; +import com.wolfssl.WolfSSLByteBufferIORecvCallback; +import com.wolfssl.WolfSSLByteBufferIOSendCallback; public class WolfSSLSessionTest { - private final static int TEST_FAIL = -1; - private final static int TEST_SUCCESS = 0; - private static String cliCert = "./examples/certs/client-cert.pem"; private static String cliKey = "./examples/certs/client-key.pem"; private static String srvCert = "./examples/certs/server-cert.pem"; @@ -69,6 +67,9 @@ public class WolfSSLSessionTest { private final static String exampleHost = "www.example.com"; private final static int examplePort = 443; + /* Maximum network buffer size, for test I/O callbacks */ + private final static int MAX_NET_BUF_SZ = 17 * 1024; + private static WolfSSLContext ctx = null; @BeforeClass @@ -889,7 +890,6 @@ public class WolfSSLSessionTest { public void test_WolfSSLSession_setTls13SecretCb() throws WolfSSLJNIException { - int ret; WolfSSL sslLib = null; WolfSSLContext sslCtx = null; WolfSSLSession ssl = null; @@ -1413,5 +1413,409 @@ public class WolfSSLSessionTest { System.out.println("\t... passed"); } + + /** + * wolfSSL I/O context, is passed to I/O callbacks when called + * by native wolfSSL. + */ + private class MyIOCtx { + private byte[] cliToSrv = new byte[MAX_NET_BUF_SZ]; + private byte[] srvToCli = new byte[MAX_NET_BUF_SZ]; + + private int cliToSrvUsed = 0; + private int srvToCliUsed = 0; + + private int CLIENT_END = 1; + private int SERVER_END = 2; + + private final Object cliLock = new Object(); + private final Object srvLock = new Object(); + + private int insertData(byte[] dest, int destUsed, + ByteBuffer src, int len) { + + int freeBufSpace = dest.length - destUsed; + + /* Check if buffer is full */ + if ((len > 0) && (freeBufSpace == 0)) { + return -1; + } + + int bytesToCopy = Math.min(len, freeBufSpace); + if (bytesToCopy > 0) { + src.get(dest, destUsed, bytesToCopy); + } + return bytesToCopy; + } + + private int getData(byte[] src, int srcUsed, + ByteBuffer dest, int len) { + + /* src buffer is empty */ + if ((len > 0) && (srcUsed == 0)) { + return -1; + } + + int bytesToCopy = Math.min(len, srcUsed); + if (bytesToCopy > 0) { + dest.put(src, 0, bytesToCopy); + srcUsed -= bytesToCopy; + /* Shift remaining data to front of buffer */ + if (srcUsed > 0) { + System.arraycopy(src, bytesToCopy, src, 0, srcUsed); + } + } + return bytesToCopy; + } + + public int insertCliToSrvData(ByteBuffer buf, int len) { + synchronized (cliLock) { + int ret = insertData(cliToSrv, cliToSrvUsed, buf, len); + if (ret > 0) { + cliToSrvUsed += ret; + } + return ret; + } + } + + public int insertSrvToCliData(ByteBuffer buf, int len) { + synchronized (srvLock) { + int ret = insertData(srvToCli, srvToCliUsed, buf, len); + if (ret > 0) { + srvToCliUsed += ret; + } + return ret; + } + } + + public int getCliToSrvData(ByteBuffer buf, int len) { + synchronized (cliLock) { + int ret = getData(cliToSrv, cliToSrvUsed, buf, len); + if (ret > 0) { + cliToSrvUsed -= ret; + } + return ret; + } + } + + public int getSrvToCliData(ByteBuffer buf, int len) { + synchronized (srvLock) { + int ret = getData(srvToCli, srvToCliUsed, buf, len); + if (ret > 0) { + srvToCliUsed -= ret; + } + return ret; + } + } + } + + /* Client I/O callback using ByteBuffers */ + private class ClientByteBufferIOCallback + implements WolfSSLByteBufferIORecvCallback, + WolfSSLByteBufferIOSendCallback { + /** + * Receive data is called when wolfSSL needs to read data from the + * transport layer. In this case, we read data from the beginning + * of the internal byte[] (buffer) and place it into the ByteBuffer buf. + * + * Return the number of bytes copied to the ByteBuffer buf, or negative + * on error. + */ + @Override + public synchronized int receiveCallback(WolfSSLSession ssl, + ByteBuffer buf, int len, Object ctx) { + + int ret; + MyIOCtx ioCtx = (MyIOCtx) ctx; + + ret = ioCtx.getSrvToCliData(buf, len); + if (ret == -1) { + /* No data available, return WOLFSSL_CBIO_ERR_WANT_READ */ + ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_READ; + } + + return ret; + } + + /** + * Send data is called when wolfSSL needs to write data to the + * transport layer. In this case, we read data from the ByteBuffer + * buf and place it into our internal byte[] (buffer). + * + * Return the number of bytes copied from the ByteBuffer buf, or + * negative on error. + */ + @Override + public synchronized int sendCallback( + WolfSSLSession ssl, ByteBuffer buf, int len, Object ctx) { + + int ret; + MyIOCtx ioCtx = (MyIOCtx) ctx; + + ret = ioCtx.insertCliToSrvData(buf, len); + if (ret == -1) { + /* No space available, return WOLFSSL_CBIO_ERR_WANT_WRITE */ + ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_WRITE; + } + + return ret; + } + } + + /* Server I/O callback using ByteBuffers */ + private class ServerByteBufferIOCallback + implements WolfSSLByteBufferIORecvCallback, + WolfSSLByteBufferIOSendCallback { + /** + * Receive data is called when wolfSSL needs to read data from the + * transport layer. In this case, we read data from the beginning + * of the internal byte[] (buffer) and place it into the ByteBuffer buf. + * + * Return the number of bytes copied to the ByteBuffer buf, or negative + * on error. + */ + @Override + public synchronized int receiveCallback(WolfSSLSession ssl, + ByteBuffer buf, int len, Object ctx) { + + int ret; + MyIOCtx ioCtx = (MyIOCtx) ctx; + + ret = ioCtx.getCliToSrvData(buf, len); + if (ret == -1) { + /* No data available, return WOLFSSL_CBIO_ERR_WANT_READ */ + ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_READ; + } + + return ret; + } + + /** + * Send data is called when wolfSSL needs to write data to the + * transport layer. In this case, we read data from the ByteBuffer + * buf and place it into our internal byte[] (buffer). + * + * Return the number of bytes copied from the ByteBuffer buf, or + * negative on error. + */ + @Override + public synchronized int sendCallback( + WolfSSLSession ssl, ByteBuffer buf, int len, Object ctx) { + + int ret; + MyIOCtx ioCtx = (MyIOCtx) ctx; + + ret = ioCtx.insertSrvToCliData(buf, len); + if (ret == -1) { + /* No space available, return WOLFSSL_CBIO_ERR_WANT_WRITE */ + ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_WRITE; + } + + return ret; + } + } + + @Test + public void test_WolfSSLSession_ioBuffers() throws Exception { + int ret = 0; + int err = 0; + Socket cliSock = null; + WolfSSLSession cliSes = null; + byte[] testData = "Hello from client".getBytes(); + byte[] servAppBuffer = new byte[MAX_NET_BUF_SZ]; + byte[] cliAppBuffer = new byte[MAX_NET_BUF_SZ]; + int bytesRead = 0; + + /* Create client/server WolfSSLContext objects */ + final WolfSSLContext srvCtx; + WolfSSLContext cliCtx; + + System.out.print("\tTesting I/O CB with ByteBuffers"); + + /* Initialize library */ + WolfSSL lib = new WolfSSL(); + + /* Create ServerSocket first to get ephemeral port */ + final ServerSocket srvSocket = new ServerSocket(0); + + srvCtx = createAndSetupWolfSSLContext(srvCert, srvKey, + WolfSSL.SSL_FILETYPE_PEM, cliCert, + WolfSSL.SSLv23_ServerMethod()); + cliCtx = createAndSetupWolfSSLContext(cliCert, cliKey, + WolfSSL.SSL_FILETYPE_PEM, caCert, + WolfSSL.SSLv23_ClientMethod()); + + MyIOCtx myIOCb = new MyIOCtx(); + ClientByteBufferIOCallback cliIOCb = new ClientByteBufferIOCallback(); + ServerByteBufferIOCallback srvIOCb = new ServerByteBufferIOCallback(); + + ExecutorService es = Executors.newSingleThreadExecutor(); + + /* Start server */ + try { + es.submit(new Callable() { + @Override + public Void call() throws Exception { + int ret; + int err; + Socket server = null; + WolfSSLSession srvSes = null; + int bytesRead = 0; + + try { + server = srvSocket.accept(); + srvSes = new WolfSSLSession(srvCtx); + + /* Set I/O callback and ctx */ + srvSes.setIOSendByteBuffer(srvIOCb); + srvSes.setIORecvByteBuffer(srvIOCb); + srvSes.setIOWriteCtx(myIOCb); + srvSes.setIOReadCtx(myIOCb); + + /* Do handshake */ + do { + ret = srvSes.accept(); + err = srvSes.getError(ret); + } while (ret != WolfSSL.SSL_SUCCESS && + (err == WolfSSL.SSL_ERROR_WANT_READ || + err == WolfSSL.SSL_ERROR_WANT_WRITE)); + + if (ret != WolfSSL.SSL_SUCCESS) { + throw new Exception( + "Server accept failed: " + err); + } + + /* Read data from client */ + bytesRead = srvSes.read(servAppBuffer, + servAppBuffer.length, 0); + if (bytesRead <= 0) { + throw new Exception( + "Server read failed: " + bytesRead); + } + + /* Send same data back to client */ + ret = srvSes.write(servAppBuffer, bytesRead, 0); + if (ret != bytesRead) { + throw new Exception("Server write failed: " + ret); + } + + srvSes.shutdownSSL(); + srvSes.freeSSL(); + srvSes = null; + server.close(); + server = null; + + } finally { + if (srvSes != null) { + srvSes.freeSSL(); + } + if (server != null) { + server.close(); + } + } + + return null; + } + }); + + } catch (Exception e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail(); + } + + try { + /* Client connection */ + cliSock = new Socket(InetAddress.getLocalHost(), + srvSocket.getLocalPort()); + + cliSes = new WolfSSLSession(cliCtx); + + /* Set I/O callback and ctx */ + cliSes.setIOSendByteBuffer(cliIOCb); + cliSes.setIORecvByteBuffer(cliIOCb); + cliSes.setIOWriteCtx(myIOCb); + cliSes.setIOReadCtx(myIOCb); + + /* Do handshake */ + do { + ret = cliSes.connect(); + err = cliSes.getError(ret); + } while (ret != WolfSSL.SSL_SUCCESS && + (err == WolfSSL.SSL_ERROR_WANT_READ || + err == WolfSSL.SSL_ERROR_WANT_WRITE)); + + if (ret != WolfSSL.SSL_SUCCESS) { + throw new Exception( + "Client connect failed: " + err); + } + + /* Send test data */ + ret = cliSes.write(testData, testData.length, 0); + if (ret != testData.length) { + throw new Exception( + "Client write failed: " + ret); + } + + /* Read response */ + do { + bytesRead = cliSes.read(cliAppBuffer, cliAppBuffer.length, 0); + err = cliSes.getError(bytesRead); + } while (ret != WolfSSL.SSL_SUCCESS && + err == WolfSSL.SSL_ERROR_WANT_READ || + err == WolfSSL.SSL_ERROR_WANT_WRITE); + + if (bytesRead != testData.length) { + throw new Exception( + "Client read failed: " + bytesRead); + } + + /* Verify received data matches sent data using Java 8 compatible + * array comparison */ + boolean arraysMatch = true; + if (testData.length != bytesRead) { + arraysMatch = false; + } else { + for (int i = 0; i < testData.length; i++) { + if (testData[i] != cliAppBuffer[i]) { + arraysMatch = false; + break; + } + } + } + if (!arraysMatch) { + throw new Exception("Received data does not match sent data"); + } + + cliSes.shutdownSSL(); + cliSes.freeSSL(); + cliSes = null; + cliSock.close(); + cliSock = null; + + } catch (Exception e) { + System.out.println("\t... failed"); + e.printStackTrace(); + fail(); + + } finally { + /* Free resources */ + if (cliSes != null) { + cliSes.freeSSL(); + } + if (cliSock != null) { + cliSock.close(); + } + if (srvSocket != null) { + srvSocket.close(); + } + if (srvCtx != null) { + srvCtx.free(); + } + es.shutdown(); + } + + System.out.println("\t... passed"); + } }