JNI: refactor DH threaded test to use AtomicIntegerArray

pull/73/head
Chris Conlon 2024-04-12 16:34:23 -06:00
parent 8242964c3f
commit 3198d3e8da
2 changed files with 46 additions and 38 deletions

View File

@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicIntegerArray;
import javax.crypto.KeyAgreement; import javax.crypto.KeyAgreement;

View File

@ -30,10 +30,11 @@ import org.junit.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random; import java.util.Random;
import java.util.Iterator; import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicIntegerArray;
import com.wolfssl.wolfcrypt.Dh; import com.wolfssl.wolfcrypt.Dh;
import com.wolfssl.wolfcrypt.Rng; import com.wolfssl.wolfcrypt.Rng;
@ -105,7 +106,16 @@ public class DhTest {
int numThreads = 10; int numThreads = 10;
ExecutorService service = Executors.newFixedThreadPool(numThreads); ExecutorService service = Executors.newFixedThreadPool(numThreads);
final CountDownLatch latch = new CountDownLatch(numThreads); final CountDownLatch latch = new CountDownLatch(numThreads);
final LinkedBlockingQueue<Integer> results = new LinkedBlockingQueue<>();
/* Used to detect timeout of CountDownLatch, don't run indefinitely
* if threads are stalled out or deadlocked */
boolean returnWithoutTimeout = true;
/* Keep track of failure and success count */
final AtomicIntegerArray failures = new AtomicIntegerArray(1);
final AtomicIntegerArray success = new AtomicIntegerArray(1);
failures.set(0, 0);
success.set(0, 0);
final byte[] p = Util.h2b( final byte[] p = Util.h2b(
"E6969D3D495BE32C7CF180C3BDD4798E91B7818251BB055E" "E6969D3D495BE32C7CF180C3BDD4798E91B7818251BB055E"
@ -121,7 +131,6 @@ public class DhTest {
service.submit(new Runnable() { service.submit(new Runnable() {
@Override public void run() { @Override public void run() {
int failed = 0;
Dh alice = null; Dh alice = null;
Dh bob = null; Dh bob = null;
@ -134,65 +143,65 @@ public class DhTest {
/* keys should be null before generation */ /* keys should be null before generation */
if (alice.getPublicKey() != null || if (alice.getPublicKey() != null ||
bob.getPublicKey() != null) { bob.getPublicKey() != null) {
failed = 1; throw new Exception(
"keys not null before generation");
} }
/* generate Dh keys */ /* generate Dh keys */
if (failed == 0) { synchronized (rngLock) {
synchronized (rngLock) { alice.makeKey(rng);
alice.makeKey(rng); bob.makeKey(rng);
bob.makeKey(rng);
}
} }
/* keys should not be null after generation */ /* keys should not be null after generation */
if (failed == 0) { if (alice.getPublicKey() == null ||
if (alice.getPublicKey() == null || bob.getPublicKey() == null) {
bob.getPublicKey() == null) { throw new Exception(
failed = 1; "keys null after generation");
}
} }
if (failed == 0) { byte[] sharedSecretA = alice.makeSharedSecret(bob);
byte[] sharedSecretA = alice.makeSharedSecret(bob); byte[] sharedSecretB = bob.makeSharedSecret(alice);
byte[] sharedSecretB = bob.makeSharedSecret(alice);
if (sharedSecretA == null || if (sharedSecretA == null ||
sharedSecretB == null || sharedSecretB == null ||
!Arrays.equals(sharedSecretA, sharedSecretB)) { !Arrays.equals(sharedSecretA, sharedSecretB)) {
failed = 1; throw new Exception(
} "shared secrets null or not equal");
} }
/* Log success */
success.incrementAndGet(0);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
failed = 1;
/* Log failure */
failures.incrementAndGet(0);
} finally { } finally {
alice.releaseNativeStruct(); alice.releaseNativeStruct();
bob.releaseNativeStruct(); bob.releaseNativeStruct();
latch.countDown(); latch.countDown();
} }
if (failed == 1) {
results.add(1);
}
else {
results.add(0);
}
} }
}); });
} }
/* wait for all threads to complete */ /* wait for all threads to complete */
latch.await(); returnWithoutTimeout = latch.await(10, TimeUnit.SECONDS);
service.shutdown();
/* Look for any failures that happened */ /* Check failure count and success count against thread count */
Iterator<Integer> listIterator = results.iterator(); if ((failures.get(0) != 0) ||
while (listIterator.hasNext()) { (success.get(0) != numThreads)) {
Integer cur = listIterator.next(); if (returnWithoutTimeout == true) {
if (cur == 1) { fail("DH shared secret test threading error: " +
fail("Threading error in DH shared secret thread test"); failures.get(0) + " failures, " +
success.get(0) + " success, " +
numThreads + " num threads total");
} else {
fail("DH shared secret test error, threads timed out");
} }
} }
} }