diff --git a/examples/provider/CryptoBenchmark.java b/examples/provider/CryptoBenchmark.java index 1253450..3a6d054 100644 --- a/examples/provider/CryptoBenchmark.java +++ b/examples/provider/CryptoBenchmark.java @@ -8,6 +8,8 @@ import java.security.SecureRandom; import java.security.Security; import java.security.spec.AlgorithmParameterSpec; import java.util.*; +import java.security.KeyPair; +import java.security.KeyPairGenerator; import com.wolfssl.provider.jce.WolfCryptProvider; import com.wolfssl.wolfcrypt.FeatureDetect; @@ -20,6 +22,9 @@ public class CryptoBenchmark { private static final int AES_BLOCK_SIZE = 16; private static final int DES3_BLOCK_SIZE = 8; private static final int GCM_TAG_LENGTH = 128; + private static final int[] RSA_KEY_SIZES = {2048, 3072, 4096}; + private static final int RSA_MIN_TIME_SECONDS = 1; /* minimum time to run each test */ + private static final int SMALL_MESSAGE_SIZE = 32; /* small message size for RSA ops */ /* Class to store benchmark results */ private static class BenchmarkResult { @@ -30,9 +35,9 @@ public class CryptoBenchmark { /* Constructor */ BenchmarkResult(String provider, String operation, double throughput) { - this.provider = provider; - this.operation = operation; - this.throughput = throughput; + this.provider = provider; + this.operation = operation; + this.throughput = throughput; } } @@ -70,56 +75,91 @@ public class CryptoBenchmark { } private static void printDeltaTable() { - /* Variables for table generation */ - Map> groupedResults; - String operation; - Map providerResults; - double wolfSpeed; - String provider; - double otherSpeed; - double deltaMiBs; - double deltaPercent; + /* Variables for table generation */ + Map> groupedResults; + Map providerResults; + double wolfSpeed; + String provider; + double otherSpeed; + double deltaValue; + double deltaPercent; - System.out.println("\nPerformance Delta (compared to wolfJCE)"); - System.out.println("-----------------------------------------------------------------------------"); - System.out.println("| Operation | Provider | Delta | Delta |"); - System.out.println("| | | (MiB/s) | (%) |"); - System.out.println("|------------------------------------------|----------|----------|----------|"); + System.out.println("\nPerformance Delta (compared to wolfJCE)"); + System.out.println("--------------------------------------------------------------------------------"); + System.out.println("| Operation | Provider | Delta | Delta |"); + System.out.println("| | | Value* | (%) |"); + System.out.println("|------------------------------------------|--------------|----------|----------|"); - /* Group results by operation */ - groupedResults = new HashMap<>(); - for (BenchmarkResult result : results) { - groupedResults - .computeIfAbsent(result.operation, k -> new HashMap<>()) - .put(result.provider, result.throughput); - } - - /* Calculate and print deltas */ - for (Map.Entry> entry : groupedResults.entrySet()) { - operation = entry.getKey(); - providerResults = entry.getValue(); - wolfSpeed = providerResults.getOrDefault("wolfJCE", 0.0); - - for (Map.Entry providerEntry : providerResults.entrySet()) { - provider = providerEntry.getKey(); - if (!provider.equals("wolfJCE")) { - otherSpeed = providerEntry.getValue(); - deltaMiBs = wolfSpeed - otherSpeed; - deltaPercent = ((wolfSpeed / otherSpeed) - 1.0) * 100; - - System.out.printf("| %-40s | %-8s | %+8.2f | %+8.1f |%n", - operation, - provider, - deltaMiBs, - deltaPercent); - } + /* Group results by operation */ + groupedResults = new HashMap<>(); + for (BenchmarkResult result : results) { + groupedResults + .computeIfAbsent(result.operation, k -> new HashMap<>()) + .put(result.provider, result.throughput); } - } - System.out.println("-----------------------------------------------------------------------------"); + + /* Sort operations to group RSA operations together */ + List sortedOperations = new ArrayList<>(groupedResults.keySet()); + Collections.sort(sortedOperations, (a, b) -> { + boolean aIsRSA = a.startsWith("RSA"); + boolean bIsRSA = b.startsWith("RSA"); + + if (aIsRSA && !bIsRSA) return -1; + if (!aIsRSA && bIsRSA) return 1; + return a.compareTo(b); + }); + + /* Calculate and print deltas */ + for (String operation : sortedOperations) { + providerResults = groupedResults.get(operation); + wolfSpeed = providerResults.getOrDefault("wolfJCE", 0.0); + boolean isRSAOperation = operation.startsWith("RSA"); + + for (Map.Entry providerEntry : providerResults.entrySet()) { + provider = providerEntry.getKey(); + if (!provider.equals("wolfJCE")) { + otherSpeed = providerEntry.getValue(); + + /* Adjust provider name for RSA operations */ + String displayProvider = provider; + if (isRSAOperation) { + if (operation.contains("key gen")) { + displayProvider = "SunRsaSign"; // Key generation uses SunRsaSign + } else { + displayProvider = "SunJCE"; // Public/private operations use SunJCE + } + } + + if (isRSAOperation) { + deltaValue = wolfSpeed - otherSpeed; + deltaPercent = ((wolfSpeed / otherSpeed) - 1.0) * 100; + } else { + deltaValue = wolfSpeed - otherSpeed; + deltaPercent = ((wolfSpeed / otherSpeed) - 1.0) * 100; + } + + /* Ensure unique operation-provider combination */ + String uniqueKey = operation + "|" + displayProvider; + if (!groupedResults.containsKey(uniqueKey)) { + System.out.printf("| %-40s | %-12s | %+8.2f | %+8.1f |%n", + operation.replace("RSA", "RSA/ECB/PKCS1Padding RSA"), + displayProvider, + deltaValue, + deltaPercent); + + /* Mark this combination as processed */ + groupedResults.put(uniqueKey, null); + } + } + } + } + System.out.println("--------------------------------------------------------------------------------"); + System.out.println("* Delta Value: MiB/s for symmetric ciphers, operations/second for RSA"); } - private static void runBenchmark(String algorithm, String mode, String padding, - String providerName) throws Exception { + /* Run symmetric encryption/decryption benchmarks */ + private static void runEncDecBenchmark(String algorithm, String mode, String padding, + String providerName) throws Exception { SecretKey key; byte[] ivBytes; AlgorithmParameterSpec params; @@ -128,7 +168,7 @@ public class CryptoBenchmark { double dataSizeMiB; Cipher cipher; String cipherName = algorithm + "/" + mode + "/" + padding; - + /* Timing variables */ long startTime; long endTime; @@ -170,7 +210,7 @@ public class CryptoBenchmark { /* Initialize cipher with specific provider */ cipher = Cipher.getInstance(cipherName, providerName); - + /* Warm up phase */ for (int i = 0; i < WARMUP_ITERATIONS; i++) { if (mode.equals("GCM")) { @@ -179,7 +219,7 @@ public class CryptoBenchmark { } cipher.init(Cipher.ENCRYPT_MODE, key, params); encryptedData = cipher.doFinal(testData); - + cipher.init(Cipher.DECRYPT_MODE, key, params); cipher.doFinal(encryptedData); } @@ -202,8 +242,8 @@ public class CryptoBenchmark { encryptThroughput = (DATA_SIZE / (encryptTime / 1000000000.0)) / (1024.0 * 1024.0); String testName = String.format("%s (%s)", cipherName, providerName); - System.out.printf("| %-40s | %8.3f | %8.3f | %8.3f |%n", - testName + " enc", dataSizeMiB, encryptTimeMS, encryptThroughput); + System.out.printf(" %-40s %8.3f MiB %8.3f ms %8.3f MiB/s%n", + testName + " enc", dataSizeMiB, encryptTimeMS, encryptThroughput); results.add(new BenchmarkResult(providerName, cipherName + " enc", encryptThroughput)); @@ -219,13 +259,117 @@ public class CryptoBenchmark { decryptTimeMS = decryptTime / 1000000.0; decryptThroughput = (DATA_SIZE / (decryptTime / 1000000000.0)) / (1024.0 * 1024.0); - System.out.printf("| %-40s | %8.3f | %8.3f | %8.3f |%n", - testName + " dec", dataSizeMiB, decryptTimeMS, decryptThroughput); + System.out.printf(" %-40s %8.3f MiB %8.3f ms %8.3f MiB/s%n", + testName + " dec", dataSizeMiB , decryptTimeMS, decryptThroughput); /* Store decryption result */ results.add(new BenchmarkResult(providerName, cipherName + " dec", decryptThroughput)); } + /* Print RSA results in simpler format */ + private static void printRSAResults(int operations, double totalTime, String operation, + String providerName, String mode) { + /* Variables for result calculations */ + double avgTimeMs; + double opsPerSec; + + /* Calculate metrics */ + avgTimeMs = (totalTime * 1000.0) / operations; + opsPerSec = operations / totalTime; + + /* Print formatted results */ + System.out.printf("%-12s %-8s %8d ops took %.3f sec, avg %.3f ms, %.3f ops/sec%n", + operation + " (" + mode + ")", + " ", + operations, + totalTime, + avgTimeMs, + opsPerSec); + + /* Store results for delta table */ + String fullOperation = operation; + results.add(new BenchmarkResult(providerName, fullOperation, opsPerSec)); + } + + /* Run RSA benchmarks for specified provider and key size */ + private static void runRSABenchmark(String providerName, int keySize) throws Exception { + /* Variables for benchmark operations */ + KeyPairGenerator keyGen; + Cipher cipher; + byte[] testData; + int keyGenOps; + long startTime; + double elapsedTime; + KeyPair keyPair; + int publicOps; + int privateOps; + byte[] encrypted; + String keyGenOp; + String cipherMode = "RSA/ECB/PKCS1Padding"; + + /* Initialize key generator and cipher */ + if (providerName.equals("SunJCE")) { + keyGen = KeyPairGenerator.getInstance("RSA", "SunRsaSign"); + cipher = Cipher.getInstance(cipherMode, "SunJCE"); + providerName = "SunRsaSign"; + } else { + keyGen = KeyPairGenerator.getInstance("RSA", providerName); + cipher = Cipher.getInstance(cipherMode, providerName); + } + testData = generateTestData(SMALL_MESSAGE_SIZE); + + /* Key Generation benchmark */ + keyGen.initialize(keySize); + keyGenOps = 0; + startTime = System.nanoTime(); + elapsedTime = 0; + + /* Run key generation benchmark */ + do { + keyGen.generateKeyPair(); + keyGenOps++; + elapsedTime = (System.nanoTime() - startTime) / 1_000_000_000.0; + } while (elapsedTime < RSA_MIN_TIME_SECONDS); + + keyGenOp = String.format("RSA %d key gen", keySize); + printRSAResults(keyGenOps, elapsedTime, keyGenOp, providerName, cipherMode); + + /* For 2048-bit keys, test public/private operations */ + if (keySize == 2048) { + /* Generate key pair for public/private operations */ + keyPair = keyGen.generateKeyPair(); + + /* Public key operations benchmark */ + publicOps = 0; + startTime = System.nanoTime(); + + do { + cipher.init(Cipher.ENCRYPT_MODE, keyPair.getPublic()); + cipher.doFinal(testData); + publicOps++; + elapsedTime = (System.nanoTime() - startTime) / 1_000_000_000.0; + } while (elapsedTime < RSA_MIN_TIME_SECONDS); + + printRSAResults(publicOps, elapsedTime, "RSA 2048 public", providerName, cipherMode); + + /* Private key operations benchmark */ + cipher.init(Cipher.ENCRYPT_MODE, keyPair.getPublic()); + encrypted = cipher.doFinal(testData); + + privateOps = 0; + startTime = System.nanoTime(); + + do { + cipher.init(Cipher.DECRYPT_MODE, keyPair.getPrivate()); + cipher.doFinal(encrypted); + privateOps++; + elapsedTime = (System.nanoTime() - startTime) / 1_000_000_000.0; + } while (elapsedTime < RSA_MIN_TIME_SECONDS); + + printRSAResults(privateOps, elapsedTime, "RSA 2048 private", providerName, cipherMode); + } + } + public static void main(String[] args) { try { /* Check if Bouncy Castle is available */ @@ -242,18 +386,18 @@ public class CryptoBenchmark { /* Create provider list based on availability */ java.util.List providerList = new java.util.ArrayList<>(); java.util.List providerNameList = new java.util.ArrayList<>(); - + providerList.add(new WolfCryptProvider()); providerNameList.add("wolfJCE"); - + providerList.add(new com.sun.crypto.provider.SunJCE()); providerNameList.add("SunJCE"); - + if (hasBouncyCastle && bcProvider != null) { providerList.add(bcProvider); providerNameList.add("BC"); } - + Provider[] providers = providerList.toArray(new Provider[0]); String[] providerNames = providerNameList.toArray(new String[0]); @@ -263,34 +407,41 @@ public class CryptoBenchmark { } System.out.println("-----------------------------------------------------------------------------"); - System.out.println(" JCE Crypto Provider Benchmark"); - System.out.println("-----------------------------------------------------------------------------"); - - System.out.println("| Operation | Size MiB | ms | MiB/s |"); - System.out.println("|------------------------------------------|----------|----------|----------|"); + System.out.println(" Symmetric Cipher Benchmark"); + System.out.println("-----------------------------------------------------------------------------\n"); - /* Test each provider */ + /* Run symmetric benchmarks */ for (int i = 0; i < providers.length; i++) { Security.insertProviderAt(providers[i], 1); - - /* Run benchmarks for different algorithms */ - runBenchmark("AES", "CBC", "NoPadding", providerNames[i]); - runBenchmark("AES", "CBC", "PKCS5Padding", providerNames[i]); - runBenchmark("AES", "GCM", "NoPadding", providerNames[i]); - /* Only run DES3 benchmark if it's enabled */ + + runEncDecBenchmark("AES", "CBC", "NoPadding", providerNames[i]); + runEncDecBenchmark("AES", "CBC", "PKCS5Padding", providerNames[i]); + runEncDecBenchmark("AES", "GCM", "NoPadding", providerNames[i]); + if (FeatureDetect.Des3Enabled()) { - runBenchmark("DESede", "CBC", "NoPadding", providerNames[i]); + runEncDecBenchmark("DESede", "CBC", "NoPadding", providerNames[i]); } - - if (i < providers.length - 1) { - System.out.println("|------------------------------------------|----------|----------|----------|"); - } - + Security.removeProvider(providers[i].getName()); } - + + + /* Run RSA benchmarks */ + System.out.println("\n-----------------------------------------------------------------------------"); + System.out.println("RSA Benchmark Results"); System.out.println("-----------------------------------------------------------------------------"); + for (Provider provider : providers) { + Security.insertProviderAt(provider, 1); + System.out.println("\n" + (provider.getName().equals("SunJCE") ? "SunJCE / SunRsaSign" : provider.getName()) + ":"); + for (int keySize : RSA_KEY_SIZES) { + runRSABenchmark(provider.getName(), keySize); + } + Security.removeProvider(provider.getName()); + } + System.out.println("-----------------------------------------------------------------------------"); + + /* Print delta table */ printDeltaTable(); } catch (Exception e) {