From 46beebf1b09d8684a163fe09ae1b68b511900c89 Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Wed, 15 Mar 2023 09:48:22 -0600 Subject: [PATCH] Add synchronization to com.wolfssl.wolfcrypt.Rng, JUnit test cases --- src/main/java/com/wolfssl/wolfcrypt/Rng.java | 56 ++++-- .../com/wolfssl/wolfcrypt/test/RngTest.java | 173 ++++++++++++++++++ 2 files changed, 211 insertions(+), 18 deletions(-) diff --git a/src/main/java/com/wolfssl/wolfcrypt/Rng.java b/src/main/java/com/wolfssl/wolfcrypt/Rng.java index 66aa022..8f7d3ff 100644 --- a/src/main/java/com/wolfssl/wolfcrypt/Rng.java +++ b/src/main/java/com/wolfssl/wolfcrypt/Rng.java @@ -46,33 +46,39 @@ public class Rng extends NativeStruct { int length); private native void rngGenerateBlock(byte[] buffer, int offset, int length); + /* Lock to prevent concurrent access to native WC_RNG */ + private final Object rngLock = new Object(); + /** Default Rng constructor */ public Rng() { } @Override - public void releaseNativeStruct() { + public synchronized void releaseNativeStruct() { free(); - super.releaseNativeStruct(); } /** * Initialize Rng object */ - public void init() { - if (state == WolfCryptState.UNINITIALIZED) { - initRng(); - state = WolfCryptState.INITIALIZED; + public synchronized void init() { + synchronized (rngLock) { + if (state == WolfCryptState.UNINITIALIZED) { + initRng(); + state = WolfCryptState.INITIALIZED; + } } } /** * Free Rng object */ - public void free() { - if (state == WolfCryptState.INITIALIZED) { - freeRng(); - state = WolfCryptState.UNINITIALIZED; + public synchronized void free() { + synchronized (rngLock) { + if (state == WolfCryptState.INITIALIZED) { + freeRng(); + state = WolfCryptState.UNINITIALIZED; + } } } @@ -81,14 +87,23 @@ public class Rng extends NativeStruct { * * Data size will be buffer.remaining() - buffer.position() * - * @param buffer output buffer to place random data + * @param buffer output buffer to place random data, should be direct + * ByteBuffer (ie: ByteBuffer.allocateDirect()) * - * @throws WolfCryptException if native operation fails + * @throws WolfCryptException if native operation fails or input + * ByteBuffer is not direct. */ - public void generateBlock(ByteBuffer buffer) { + public synchronized void generateBlock(ByteBuffer buffer) { init(); - rngGenerateBlock(buffer, buffer.position(), buffer.remaining()); + if (buffer.isDirect() == false) { + throw new WolfCryptException("Input ByteBuffer is not direct"); + } + + synchronized (rngLock) { + rngGenerateBlock(buffer, buffer.position(), buffer.remaining()); + } + buffer.position(buffer.position() + buffer.remaining()); } @@ -101,10 +116,12 @@ public class Rng extends NativeStruct { * * @throws WolfCryptException if native operation fails */ - public void generateBlock(byte[] buffer, int offset, int length) { + public synchronized void generateBlock(byte[] buffer, int offset, int length) { init(); - rngGenerateBlock(buffer, offset, length); + synchronized (rngLock) { + rngGenerateBlock(buffer, offset, length); + } } /** @@ -116,7 +133,9 @@ public class Rng extends NativeStruct { * * @throws WolfCryptException if native operation fails */ - public void generateBlock(byte[] buffer) { + public synchronized void generateBlock(byte[] buffer) { + + /* rngLock acquired inside generateBlock() sub call */ generateBlock(buffer, 0, buffer.length); } @@ -129,9 +148,10 @@ public class Rng extends NativeStruct { * * @throws WolfCryptException if native operation fails */ - public byte[] generateBlock(int length) { + public synchronized byte[] generateBlock(int length) { byte[] buffer = new byte[length]; + /* rngLock acquired inside generateBlock() sub call */ generateBlock(buffer, 0, length); return buffer; diff --git a/src/test/java/com/wolfssl/wolfcrypt/test/RngTest.java b/src/test/java/com/wolfssl/wolfcrypt/test/RngTest.java index d5f519a..1e0eb17 100644 --- a/src/test/java/com/wolfssl/wolfcrypt/test/RngTest.java +++ b/src/test/java/com/wolfssl/wolfcrypt/test/RngTest.java @@ -24,9 +24,17 @@ package com.wolfssl.wolfcrypt.test; import static org.junit.Assert.*; import org.junit.Test; +import java.util.Arrays; +import java.util.Iterator; +import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; import com.wolfssl.wolfcrypt.Rng; import com.wolfssl.wolfcrypt.NativeStruct; +import com.wolfssl.wolfcrypt.WolfCryptException; public class RngTest { @@ -35,4 +43,169 @@ public class RngTest { assertNotEquals(NativeStruct.NULL, new Rng().getNativeStruct()); } + @Test + public void testInitFree() { + Rng wcRng = new Rng(); + assertNotNull(wcRng); + wcRng.init(); + wcRng.free(); + + /* double init should be ok */ + wcRng.init(); + wcRng.init(); + + /* double free should be ok */ + wcRng.free(); + wcRng.free(); + } + + @Test + public void testGenerateBlockByteBuffer() { + ByteBuffer tmpBlockA = ByteBuffer.allocateDirect(32); + ByteBuffer tmpBlockB = ByteBuffer.allocateDirect(32); + ByteBuffer nonDirect = ByteBuffer.allocate(32); + byte[] tmpA = new byte[32]; + byte[] tmpB = new byte[32]; + Rng wcRng = new Rng(); + + assertNotNull(tmpBlockA); + assertNotNull(tmpBlockB); + assertNotNull(wcRng); + + wcRng.init(); + + wcRng.generateBlock(tmpBlockA); + wcRng.generateBlock(tmpBlockB); + + /* Should get exception if input ByteBuffer is not direct */ + try { + wcRng.generateBlock(nonDirect); + fail("Rng.generateBlock should fail if ByteBuffer is not direct"); + } catch (WolfCryptException e) { + /* expected */ + } + + assertEquals(tmpBlockA.position(), 32); + assertEquals(tmpBlockB.position(), 32); + assertEquals(tmpBlockA.remaining(), 0); + assertEquals(tmpBlockA.remaining(), 0); + + tmpBlockA.flip(); + tmpBlockB.flip(); + + tmpBlockA.get(tmpA); + tmpBlockB.get(tmpB); + + assertNotNull(tmpA); + assertNotNull(tmpB); + + assertFalse(Arrays.equals(tmpA, tmpB)); + + wcRng.free(); + } + + @Test + public void testGenerateBlockByteArrayOffsetLength() { + byte[] tmpBlockA = new byte[32]; + byte[] tmpBlockB = new byte[32]; + + Rng wcRng = new Rng(); + + wcRng.init(); + + /* generate two arrays of size 30 using offset and length */ + wcRng.generateBlock(tmpBlockA, 0, 30); + wcRng.generateBlock(tmpBlockB, 0, 30); + + /* make sure two arrays are not equal */ + assertFalse(Arrays.equals(tmpBlockA, tmpBlockB)); + + wcRng.free(); + } + + @Test + public void testGenerateBlockByteArray() { + byte[] tmpBlockA = new byte[32]; + byte[] tmpBlockB = new byte[32]; + + Rng wcRng = new Rng(); + + wcRng.init(); + + /* fill arrays with random data, up to buffer.length */ + wcRng.generateBlock(tmpBlockA); + wcRng.generateBlock(tmpBlockB); + + /* make sure two arrays are not equal */ + assertFalse(Arrays.equals(tmpBlockA, tmpBlockB)); + + wcRng.free(); + } + + @Test + public void testGenerateBlockReturnArray() { + byte[] tmpBlockA = null; + byte[] tmpBlockB = null; + + Rng wcRng = new Rng(); + + wcRng.init(); + + /* generate two arrays of data */ + tmpBlockA = wcRng.generateBlock(32); + tmpBlockB = wcRng.generateBlock(32); + + assertNotNull(tmpBlockA); + assertNotNull(tmpBlockB); + + assertEquals(tmpBlockA.length, 32); + assertEquals(tmpBlockB.length, 32); + + /* make sure two arrays are not equal */ + assertFalse(Arrays.equals(tmpBlockA, tmpBlockB)); + + wcRng.free(); + } + + @Test + public void testThreadedUse() throws InterruptedException { + int numThreads = 15; + ExecutorService service = Executors.newFixedThreadPool(numThreads); + final CountDownLatch latch = new CountDownLatch(numThreads); + final LinkedBlockingQueue results = new LinkedBlockingQueue<>(); + + for (int i = 0; i < numThreads; i++) { + service.submit(new Runnable() { + @Override public void run() { + Rng wcRng = new Rng(); + byte[] tmp = new byte[16]; + wcRng.init(); + /* generate 1000 random 16-byte arrays per thread */ + for (int j = 0; j < 1000; j++) { + wcRng.generateBlock(tmp); + results.add(tmp.clone()); + } + wcRng.free(); + latch.countDown(); + } + }); + } + + /* wait for all threads to complete */ + latch.await(); + + Iterator listIterator = results.iterator(); + byte[] current = listIterator.next(); + while (listIterator.hasNext()) { + byte[] next = listIterator.next(); + if (Arrays.equals(current, next)) { + fail("Found two identical random arrays in threading test:\n" + + Util.b2h(current) + "\n" + Util.b2h(next)); + } + if (listIterator.hasNext()) { + current = listIterator.next(); + } + } + } } +