Add synchronization to com.wolfssl.wolfcrypt.Rng, JUnit test cases

pull/44/head
Chris Conlon 2023-03-15 09:48:22 -06:00
parent 63b9f6bdb3
commit 46beebf1b0
2 changed files with 211 additions and 18 deletions

View File

@ -46,33 +46,39 @@ public class Rng extends NativeStruct {
int length);
private native void rngGenerateBlock(byte[] buffer, int offset, int length);
/* Lock to prevent concurrent access to native WC_RNG */
private final Object rngLock = new Object();
/** Default Rng constructor */
public Rng() { }
@Override
public void releaseNativeStruct() {
public synchronized void releaseNativeStruct() {
free();
super.releaseNativeStruct();
}
/**
* Initialize Rng object
*/
public void init() {
if (state == WolfCryptState.UNINITIALIZED) {
initRng();
state = WolfCryptState.INITIALIZED;
public synchronized void init() {
synchronized (rngLock) {
if (state == WolfCryptState.UNINITIALIZED) {
initRng();
state = WolfCryptState.INITIALIZED;
}
}
}
/**
* Free Rng object
*/
public void free() {
if (state == WolfCryptState.INITIALIZED) {
freeRng();
state = WolfCryptState.UNINITIALIZED;
public synchronized void free() {
synchronized (rngLock) {
if (state == WolfCryptState.INITIALIZED) {
freeRng();
state = WolfCryptState.UNINITIALIZED;
}
}
}
@ -81,14 +87,23 @@ public class Rng extends NativeStruct {
*
* Data size will be buffer.remaining() - buffer.position()
*
* @param buffer output buffer to place random data
* @param buffer output buffer to place random data, should be direct
* ByteBuffer (ie: ByteBuffer.allocateDirect())
*
* @throws WolfCryptException if native operation fails
* @throws WolfCryptException if native operation fails or input
* ByteBuffer is not direct.
*/
public void generateBlock(ByteBuffer buffer) {
public synchronized void generateBlock(ByteBuffer buffer) {
init();
rngGenerateBlock(buffer, buffer.position(), buffer.remaining());
if (buffer.isDirect() == false) {
throw new WolfCryptException("Input ByteBuffer is not direct");
}
synchronized (rngLock) {
rngGenerateBlock(buffer, buffer.position(), buffer.remaining());
}
buffer.position(buffer.position() + buffer.remaining());
}
@ -101,10 +116,12 @@ public class Rng extends NativeStruct {
*
* @throws WolfCryptException if native operation fails
*/
public void generateBlock(byte[] buffer, int offset, int length) {
public synchronized void generateBlock(byte[] buffer, int offset, int length) {
init();
rngGenerateBlock(buffer, offset, length);
synchronized (rngLock) {
rngGenerateBlock(buffer, offset, length);
}
}
/**
@ -116,7 +133,9 @@ public class Rng extends NativeStruct {
*
* @throws WolfCryptException if native operation fails
*/
public void generateBlock(byte[] buffer) {
public synchronized void generateBlock(byte[] buffer) {
/* rngLock acquired inside generateBlock() sub call */
generateBlock(buffer, 0, buffer.length);
}
@ -129,9 +148,10 @@ public class Rng extends NativeStruct {
*
* @throws WolfCryptException if native operation fails
*/
public byte[] generateBlock(int length) {
public synchronized byte[] generateBlock(int length) {
byte[] buffer = new byte[length];
/* rngLock acquired inside generateBlock() sub call */
generateBlock(buffer, 0, length);
return buffer;

View File

@ -24,9 +24,17 @@ package com.wolfssl.wolfcrypt.test;
import static org.junit.Assert.*;
import org.junit.Test;
import java.util.Arrays;
import java.util.Iterator;
import java.nio.ByteBuffer;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import com.wolfssl.wolfcrypt.Rng;
import com.wolfssl.wolfcrypt.NativeStruct;
import com.wolfssl.wolfcrypt.WolfCryptException;
public class RngTest {
@ -35,4 +43,169 @@ public class RngTest {
assertNotEquals(NativeStruct.NULL, new Rng().getNativeStruct());
}
@Test
public void testInitFree() {
Rng wcRng = new Rng();
assertNotNull(wcRng);
wcRng.init();
wcRng.free();
/* double init should be ok */
wcRng.init();
wcRng.init();
/* double free should be ok */
wcRng.free();
wcRng.free();
}
@Test
public void testGenerateBlockByteBuffer() {
ByteBuffer tmpBlockA = ByteBuffer.allocateDirect(32);
ByteBuffer tmpBlockB = ByteBuffer.allocateDirect(32);
ByteBuffer nonDirect = ByteBuffer.allocate(32);
byte[] tmpA = new byte[32];
byte[] tmpB = new byte[32];
Rng wcRng = new Rng();
assertNotNull(tmpBlockA);
assertNotNull(tmpBlockB);
assertNotNull(wcRng);
wcRng.init();
wcRng.generateBlock(tmpBlockA);
wcRng.generateBlock(tmpBlockB);
/* Should get exception if input ByteBuffer is not direct */
try {
wcRng.generateBlock(nonDirect);
fail("Rng.generateBlock should fail if ByteBuffer is not direct");
} catch (WolfCryptException e) {
/* expected */
}
assertEquals(tmpBlockA.position(), 32);
assertEquals(tmpBlockB.position(), 32);
assertEquals(tmpBlockA.remaining(), 0);
assertEquals(tmpBlockA.remaining(), 0);
tmpBlockA.flip();
tmpBlockB.flip();
tmpBlockA.get(tmpA);
tmpBlockB.get(tmpB);
assertNotNull(tmpA);
assertNotNull(tmpB);
assertFalse(Arrays.equals(tmpA, tmpB));
wcRng.free();
}
@Test
public void testGenerateBlockByteArrayOffsetLength() {
byte[] tmpBlockA = new byte[32];
byte[] tmpBlockB = new byte[32];
Rng wcRng = new Rng();
wcRng.init();
/* generate two arrays of size 30 using offset and length */
wcRng.generateBlock(tmpBlockA, 0, 30);
wcRng.generateBlock(tmpBlockB, 0, 30);
/* make sure two arrays are not equal */
assertFalse(Arrays.equals(tmpBlockA, tmpBlockB));
wcRng.free();
}
@Test
public void testGenerateBlockByteArray() {
byte[] tmpBlockA = new byte[32];
byte[] tmpBlockB = new byte[32];
Rng wcRng = new Rng();
wcRng.init();
/* fill arrays with random data, up to buffer.length */
wcRng.generateBlock(tmpBlockA);
wcRng.generateBlock(tmpBlockB);
/* make sure two arrays are not equal */
assertFalse(Arrays.equals(tmpBlockA, tmpBlockB));
wcRng.free();
}
@Test
public void testGenerateBlockReturnArray() {
byte[] tmpBlockA = null;
byte[] tmpBlockB = null;
Rng wcRng = new Rng();
wcRng.init();
/* generate two arrays of data */
tmpBlockA = wcRng.generateBlock(32);
tmpBlockB = wcRng.generateBlock(32);
assertNotNull(tmpBlockA);
assertNotNull(tmpBlockB);
assertEquals(tmpBlockA.length, 32);
assertEquals(tmpBlockB.length, 32);
/* make sure two arrays are not equal */
assertFalse(Arrays.equals(tmpBlockA, tmpBlockB));
wcRng.free();
}
@Test
public void testThreadedUse() throws InterruptedException {
int numThreads = 15;
ExecutorService service = Executors.newFixedThreadPool(numThreads);
final CountDownLatch latch = new CountDownLatch(numThreads);
final LinkedBlockingQueue<byte[]> results = new LinkedBlockingQueue<>();
for (int i = 0; i < numThreads; i++) {
service.submit(new Runnable() {
@Override public void run() {
Rng wcRng = new Rng();
byte[] tmp = new byte[16];
wcRng.init();
/* generate 1000 random 16-byte arrays per thread */
for (int j = 0; j < 1000; j++) {
wcRng.generateBlock(tmp);
results.add(tmp.clone());
}
wcRng.free();
latch.countDown();
}
});
}
/* wait for all threads to complete */
latch.await();
Iterator<byte[]> listIterator = results.iterator();
byte[] current = listIterator.next();
while (listIterator.hasNext()) {
byte[] next = listIterator.next();
if (Arrays.equals(current, next)) {
fail("Found two identical random arrays in threading test:\n" +
Util.b2h(current) + "\n" + Util.b2h(next));
}
if (listIterator.hasNext()) {
current = listIterator.next();
}
}
}
}