JNI test: add JUnit tests for WolfSSLSession I/O ByteBuffer callbacks

pull/257/head
Chris Conlon 2025-04-21 15:51:01 -06:00
parent 6853e02af8
commit 6a243c4149
2 changed files with 411 additions and 13 deletions

View File

@ -3184,9 +3184,6 @@ public class WolfSSLSession {
confirmObjectIsActive();
synchronized (sslLock) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, this.sslPtr, "entered getIOReadCtx()");
return this.ioReadCtx;
}
}
@ -3235,9 +3232,6 @@ public class WolfSSLSession {
confirmObjectIsActive();
synchronized (sslLock) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
WolfSSLDebug.INFO, this.sslPtr, "entered getIOWriteCtx()");
return this.ioWriteCtx;
}
}

View File

@ -23,13 +23,10 @@ package com.wolfssl.test;
import org.junit.Test;
import org.junit.BeforeClass;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import static org.junit.Assert.*;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.io.IOException;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.InetAddress;
@ -43,6 +40,8 @@ import java.util.concurrent.Future;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.CountDownLatch;
import java.util.Arrays;
import java.nio.ByteBuffer;
import com.wolfssl.WolfSSL;
import com.wolfssl.WolfSSLDebug;
@ -53,12 +52,11 @@ import com.wolfssl.WolfSSLPskClientCallback;
import com.wolfssl.WolfSSLPskServerCallback;
import com.wolfssl.WolfSSLTls13SecretCallback;
import com.wolfssl.WolfSSLSession;
import com.wolfssl.WolfSSLByteBufferIORecvCallback;
import com.wolfssl.WolfSSLByteBufferIOSendCallback;
public class WolfSSLSessionTest {
private final static int TEST_FAIL = -1;
private final static int TEST_SUCCESS = 0;
private static String cliCert = "./examples/certs/client-cert.pem";
private static String cliKey = "./examples/certs/client-key.pem";
private static String srvCert = "./examples/certs/server-cert.pem";
@ -69,6 +67,9 @@ public class WolfSSLSessionTest {
private final static String exampleHost = "www.example.com";
private final static int examplePort = 443;
/* Maximum network buffer size, for test I/O callbacks */
private final static int MAX_NET_BUF_SZ = 17 * 1024;
private static WolfSSLContext ctx = null;
@BeforeClass
@ -889,7 +890,6 @@ public class WolfSSLSessionTest {
public void test_WolfSSLSession_setTls13SecretCb()
throws WolfSSLJNIException {
int ret;
WolfSSL sslLib = null;
WolfSSLContext sslCtx = null;
WolfSSLSession ssl = null;
@ -1413,5 +1413,409 @@ public class WolfSSLSessionTest {
System.out.println("\t... passed");
}
/**
* wolfSSL I/O context, is passed to I/O callbacks when called
* by native wolfSSL.
*/
private class MyIOCtx {
private byte[] cliToSrv = new byte[MAX_NET_BUF_SZ];
private byte[] srvToCli = new byte[MAX_NET_BUF_SZ];
private int cliToSrvUsed = 0;
private int srvToCliUsed = 0;
private int CLIENT_END = 1;
private int SERVER_END = 2;
private final Object cliLock = new Object();
private final Object srvLock = new Object();
private int insertData(byte[] dest, int destUsed,
ByteBuffer src, int len) {
int freeBufSpace = dest.length - destUsed;
/* Check if buffer is full */
if ((len > 0) && (freeBufSpace == 0)) {
return -1;
}
int bytesToCopy = Math.min(len, freeBufSpace);
if (bytesToCopy > 0) {
src.get(dest, destUsed, bytesToCopy);
}
return bytesToCopy;
}
private int getData(byte[] src, int srcUsed,
ByteBuffer dest, int len) {
/* src buffer is empty */
if ((len > 0) && (srcUsed == 0)) {
return -1;
}
int bytesToCopy = Math.min(len, srcUsed);
if (bytesToCopy > 0) {
dest.put(src, 0, bytesToCopy);
srcUsed -= bytesToCopy;
/* Shift remaining data to front of buffer */
if (srcUsed > 0) {
System.arraycopy(src, bytesToCopy, src, 0, srcUsed);
}
}
return bytesToCopy;
}
public int insertCliToSrvData(ByteBuffer buf, int len) {
synchronized (cliLock) {
int ret = insertData(cliToSrv, cliToSrvUsed, buf, len);
if (ret > 0) {
cliToSrvUsed += ret;
}
return ret;
}
}
public int insertSrvToCliData(ByteBuffer buf, int len) {
synchronized (srvLock) {
int ret = insertData(srvToCli, srvToCliUsed, buf, len);
if (ret > 0) {
srvToCliUsed += ret;
}
return ret;
}
}
public int getCliToSrvData(ByteBuffer buf, int len) {
synchronized (cliLock) {
int ret = getData(cliToSrv, cliToSrvUsed, buf, len);
if (ret > 0) {
cliToSrvUsed -= ret;
}
return ret;
}
}
public int getSrvToCliData(ByteBuffer buf, int len) {
synchronized (srvLock) {
int ret = getData(srvToCli, srvToCliUsed, buf, len);
if (ret > 0) {
srvToCliUsed -= ret;
}
return ret;
}
}
}
/* Client I/O callback using ByteBuffers */
private class ClientByteBufferIOCallback
implements WolfSSLByteBufferIORecvCallback,
WolfSSLByteBufferIOSendCallback {
/**
* Receive data is called when wolfSSL needs to read data from the
* transport layer. In this case, we read data from the beginning
* of the internal byte[] (buffer) and place it into the ByteBuffer buf.
*
* Return the number of bytes copied to the ByteBuffer buf, or negative
* on error.
*/
@Override
public synchronized int receiveCallback(WolfSSLSession ssl,
ByteBuffer buf, int len, Object ctx) {
int ret;
MyIOCtx ioCtx = (MyIOCtx) ctx;
ret = ioCtx.getSrvToCliData(buf, len);
if (ret == -1) {
/* No data available, return WOLFSSL_CBIO_ERR_WANT_READ */
ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_READ;
}
return ret;
}
/**
* Send data is called when wolfSSL needs to write data to the
* transport layer. In this case, we read data from the ByteBuffer
* buf and place it into our internal byte[] (buffer).
*
* Return the number of bytes copied from the ByteBuffer buf, or
* negative on error.
*/
@Override
public synchronized int sendCallback(
WolfSSLSession ssl, ByteBuffer buf, int len, Object ctx) {
int ret;
MyIOCtx ioCtx = (MyIOCtx) ctx;
ret = ioCtx.insertCliToSrvData(buf, len);
if (ret == -1) {
/* No space available, return WOLFSSL_CBIO_ERR_WANT_WRITE */
ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_WRITE;
}
return ret;
}
}
/* Server I/O callback using ByteBuffers */
private class ServerByteBufferIOCallback
implements WolfSSLByteBufferIORecvCallback,
WolfSSLByteBufferIOSendCallback {
/**
* Receive data is called when wolfSSL needs to read data from the
* transport layer. In this case, we read data from the beginning
* of the internal byte[] (buffer) and place it into the ByteBuffer buf.
*
* Return the number of bytes copied to the ByteBuffer buf, or negative
* on error.
*/
@Override
public synchronized int receiveCallback(WolfSSLSession ssl,
ByteBuffer buf, int len, Object ctx) {
int ret;
MyIOCtx ioCtx = (MyIOCtx) ctx;
ret = ioCtx.getCliToSrvData(buf, len);
if (ret == -1) {
/* No data available, return WOLFSSL_CBIO_ERR_WANT_READ */
ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_READ;
}
return ret;
}
/**
* Send data is called when wolfSSL needs to write data to the
* transport layer. In this case, we read data from the ByteBuffer
* buf and place it into our internal byte[] (buffer).
*
* Return the number of bytes copied from the ByteBuffer buf, or
* negative on error.
*/
@Override
public synchronized int sendCallback(
WolfSSLSession ssl, ByteBuffer buf, int len, Object ctx) {
int ret;
MyIOCtx ioCtx = (MyIOCtx) ctx;
ret = ioCtx.insertSrvToCliData(buf, len);
if (ret == -1) {
/* No space available, return WOLFSSL_CBIO_ERR_WANT_WRITE */
ret = WolfSSL.WOLFSSL_CBIO_ERR_WANT_WRITE;
}
return ret;
}
}
@Test
public void test_WolfSSLSession_ioBuffers() throws Exception {
int ret = 0;
int err = 0;
Socket cliSock = null;
WolfSSLSession cliSes = null;
byte[] testData = "Hello from client".getBytes();
byte[] servAppBuffer = new byte[MAX_NET_BUF_SZ];
byte[] cliAppBuffer = new byte[MAX_NET_BUF_SZ];
int bytesRead = 0;
/* Create client/server WolfSSLContext objects */
final WolfSSLContext srvCtx;
WolfSSLContext cliCtx;
System.out.print("\tTesting I/O CB with ByteBuffers");
/* Initialize library */
WolfSSL lib = new WolfSSL();
/* Create ServerSocket first to get ephemeral port */
final ServerSocket srvSocket = new ServerSocket(0);
srvCtx = createAndSetupWolfSSLContext(srvCert, srvKey,
WolfSSL.SSL_FILETYPE_PEM, cliCert,
WolfSSL.SSLv23_ServerMethod());
cliCtx = createAndSetupWolfSSLContext(cliCert, cliKey,
WolfSSL.SSL_FILETYPE_PEM, caCert,
WolfSSL.SSLv23_ClientMethod());
MyIOCtx myIOCb = new MyIOCtx();
ClientByteBufferIOCallback cliIOCb = new ClientByteBufferIOCallback();
ServerByteBufferIOCallback srvIOCb = new ServerByteBufferIOCallback();
ExecutorService es = Executors.newSingleThreadExecutor();
/* Start server */
try {
es.submit(new Callable<Void>() {
@Override
public Void call() throws Exception {
int ret;
int err;
Socket server = null;
WolfSSLSession srvSes = null;
int bytesRead = 0;
try {
server = srvSocket.accept();
srvSes = new WolfSSLSession(srvCtx);
/* Set I/O callback and ctx */
srvSes.setIOSendByteBuffer(srvIOCb);
srvSes.setIORecvByteBuffer(srvIOCb);
srvSes.setIOWriteCtx(myIOCb);
srvSes.setIOReadCtx(myIOCb);
/* Do handshake */
do {
ret = srvSes.accept();
err = srvSes.getError(ret);
} while (ret != WolfSSL.SSL_SUCCESS &&
(err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (ret != WolfSSL.SSL_SUCCESS) {
throw new Exception(
"Server accept failed: " + err);
}
/* Read data from client */
bytesRead = srvSes.read(servAppBuffer,
servAppBuffer.length, 0);
if (bytesRead <= 0) {
throw new Exception(
"Server read failed: " + bytesRead);
}
/* Send same data back to client */
ret = srvSes.write(servAppBuffer, bytesRead, 0);
if (ret != bytesRead) {
throw new Exception("Server write failed: " + ret);
}
srvSes.shutdownSSL();
srvSes.freeSSL();
srvSes = null;
server.close();
server = null;
} finally {
if (srvSes != null) {
srvSes.freeSSL();
}
if (server != null) {
server.close();
}
}
return null;
}
});
} catch (Exception e) {
System.out.println("\t... failed");
e.printStackTrace();
fail();
}
try {
/* Client connection */
cliSock = new Socket(InetAddress.getLocalHost(),
srvSocket.getLocalPort());
cliSes = new WolfSSLSession(cliCtx);
/* Set I/O callback and ctx */
cliSes.setIOSendByteBuffer(cliIOCb);
cliSes.setIORecvByteBuffer(cliIOCb);
cliSes.setIOWriteCtx(myIOCb);
cliSes.setIOReadCtx(myIOCb);
/* Do handshake */
do {
ret = cliSes.connect();
err = cliSes.getError(ret);
} while (ret != WolfSSL.SSL_SUCCESS &&
(err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE));
if (ret != WolfSSL.SSL_SUCCESS) {
throw new Exception(
"Client connect failed: " + err);
}
/* Send test data */
ret = cliSes.write(testData, testData.length, 0);
if (ret != testData.length) {
throw new Exception(
"Client write failed: " + ret);
}
/* Read response */
do {
bytesRead = cliSes.read(cliAppBuffer, cliAppBuffer.length, 0);
err = cliSes.getError(bytesRead);
} while (ret != WolfSSL.SSL_SUCCESS &&
err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE);
if (bytesRead != testData.length) {
throw new Exception(
"Client read failed: " + bytesRead);
}
/* Verify received data matches sent data using Java 8 compatible
* array comparison */
boolean arraysMatch = true;
if (testData.length != bytesRead) {
arraysMatch = false;
} else {
for (int i = 0; i < testData.length; i++) {
if (testData[i] != cliAppBuffer[i]) {
arraysMatch = false;
break;
}
}
}
if (!arraysMatch) {
throw new Exception("Received data does not match sent data");
}
cliSes.shutdownSSL();
cliSes.freeSSL();
cliSes = null;
cliSock.close();
cliSock = null;
} catch (Exception e) {
System.out.println("\t... failed");
e.printStackTrace();
fail();
} finally {
/* Free resources */
if (cliSes != null) {
cliSes.freeSSL();
}
if (cliSock != null) {
cliSock.close();
}
if (srvSocket != null) {
srvSocket.close();
}
if (srvCtx != null) {
srvCtx.free();
}
es.shutdown();
}
System.out.println("\t... passed");
}
}