diff --git a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java index 8aba56f..c6c366f 100644 --- a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java +++ b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java @@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicIntegerArray; import javax.crypto.KeyAgreement; diff --git a/src/test/java/com/wolfssl/wolfcrypt/test/DhTest.java b/src/test/java/com/wolfssl/wolfcrypt/test/DhTest.java index d738bde..13d516c 100644 --- a/src/test/java/com/wolfssl/wolfcrypt/test/DhTest.java +++ b/src/test/java/com/wolfssl/wolfcrypt/test/DhTest.java @@ -30,10 +30,11 @@ import org.junit.Test; import java.util.Arrays; import java.util.Random; import java.util.Iterator; +import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicIntegerArray; import com.wolfssl.wolfcrypt.Dh; import com.wolfssl.wolfcrypt.Rng; @@ -105,7 +106,16 @@ public class DhTest { int numThreads = 10; ExecutorService service = Executors.newFixedThreadPool(numThreads); final CountDownLatch latch = new CountDownLatch(numThreads); - final LinkedBlockingQueue results = new LinkedBlockingQueue<>(); + + /* Used to detect timeout of CountDownLatch, don't run indefinitely + * if threads are stalled out or deadlocked */ + boolean returnWithoutTimeout = true; + + /* Keep track of failure and success count */ + final AtomicIntegerArray failures = new AtomicIntegerArray(1); + final AtomicIntegerArray success = new AtomicIntegerArray(1); + failures.set(0, 0); + success.set(0, 0); final byte[] p = Util.h2b( "E6969D3D495BE32C7CF180C3BDD4798E91B7818251BB055E" @@ -121,7 +131,6 @@ public class DhTest { service.submit(new Runnable() { @Override public void run() { - int failed = 0; Dh alice = null; Dh bob = null; @@ -134,65 +143,65 @@ public class DhTest { /* keys should be null before generation */ if (alice.getPublicKey() != null || bob.getPublicKey() != null) { - failed = 1; + throw new Exception( + "keys not null before generation"); } /* generate Dh keys */ - if (failed == 0) { - synchronized (rngLock) { - alice.makeKey(rng); - bob.makeKey(rng); - } + synchronized (rngLock) { + alice.makeKey(rng); + bob.makeKey(rng); } /* keys should not be null after generation */ - if (failed == 0) { - if (alice.getPublicKey() == null || - bob.getPublicKey() == null) { - failed = 1; - } + if (alice.getPublicKey() == null || + bob.getPublicKey() == null) { + throw new Exception( + "keys null after generation"); } - if (failed == 0) { - byte[] sharedSecretA = alice.makeSharedSecret(bob); - byte[] sharedSecretB = bob.makeSharedSecret(alice); + byte[] sharedSecretA = alice.makeSharedSecret(bob); + byte[] sharedSecretB = bob.makeSharedSecret(alice); - if (sharedSecretA == null || - sharedSecretB == null || - !Arrays.equals(sharedSecretA, sharedSecretB)) { - failed = 1; - } + if (sharedSecretA == null || + sharedSecretB == null || + !Arrays.equals(sharedSecretA, sharedSecretB)) { + throw new Exception( + "shared secrets null or not equal"); } + /* Log success */ + success.incrementAndGet(0); + } catch (Exception e) { e.printStackTrace(); - failed = 1; + + /* Log failure */ + failures.incrementAndGet(0); } finally { alice.releaseNativeStruct(); bob.releaseNativeStruct(); latch.countDown(); } - - if (failed == 1) { - results.add(1); - } - else { - results.add(0); - } } }); } /* wait for all threads to complete */ - latch.await(); + returnWithoutTimeout = latch.await(10, TimeUnit.SECONDS); + service.shutdown(); - /* Look for any failures that happened */ - Iterator listIterator = results.iterator(); - while (listIterator.hasNext()) { - Integer cur = listIterator.next(); - if (cur == 1) { - fail("Threading error in DH shared secret thread test"); + /* Check failure count and success count against thread count */ + if ((failures.get(0) != 0) || + (success.get(0) != numThreads)) { + if (returnWithoutTimeout == true) { + fail("DH shared secret test threading error: " + + failures.get(0) + " failures, " + + success.get(0) + " success, " + + numThreads + " num threads total"); + } else { + fail("DH shared secret test error, threads timed out"); } } }