Merge pull request #8436 from SparkiDev/mlkem_cache_a

ML-KEM/Kyber: cache A from key generation for decapsulation
pull/8447/head
David Garske 2025-02-12 17:29:38 -08:00 committed by GitHub
commit db0fa304a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 26 deletions

View File

@ -1399,6 +1399,9 @@ do
small) small)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL" AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL"
;; ;;
cache-a)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A"
;;
512) 512)
ENABLED_KYBER512=yes ENABLED_KYBER512=yes
;; ;;

View File

@ -9630,17 +9630,37 @@ exit:
#endif #endif
} }
static void bench_kyber_encap(const char* name, int keySize, KyberKey* key) static void bench_kyber_encap(int type, const char* name, int keySize,
KyberKey* key1, KyberKey* key2)
{ {
int ret = 0, times, count, pending = 0; int ret = 0, times, count, pending = 0;
double start; double start;
const char**desc = bench_desc_words[lng_index]; const char**desc = bench_desc_words[lng_index];
byte ct[KYBER_MAX_CIPHER_TEXT_SIZE]; byte ct[KYBER_MAX_CIPHER_TEXT_SIZE];
byte ss[KYBER_SS_SZ]; byte ss[KYBER_SS_SZ];
byte pub[KYBER_MAX_PUBLIC_KEY_SIZE];
word32 pubLen;
word32 ctSz; word32 ctSz;
DECLARE_MULTI_VALUE_STATS_VARS() DECLARE_MULTI_VALUE_STATS_VARS()
ret = wc_KyberKey_CipherTextSize(key, &ctSz); ret = wc_KyberKey_PublicKeySize(key1, &pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_EncodePublicKey(key1, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID);
if (ret != 0) {
return;
}
ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_CipherTextSize(key2, &ctSz);
if (ret != 0) { if (ret != 0) {
return; return;
} }
@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
/* while free pending slots in queue, submit ops */ /* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) { for (times = 0; times < agreeTimes || pending > 0; times++) {
#ifdef KYBER_NONDETERMINISTIC #ifdef KYBER_NONDETERMINISTIC
ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng); ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng);
#else #else
unsigned char rand[KYBER_ENC_RAND_SZ] = {0,}; unsigned char rand[KYBER_ENC_RAND_SZ] = {0,};
ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand, ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand,
sizeof(rand)); sizeof(rand));
#endif #endif
if (ret != 0) if (ret != 0)
@ -9681,7 +9701,7 @@ exit_encap:
do { do {
/* while free pending slots in queue, submit ops */ /* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) { for (times = 0; times < agreeTimes || pending > 0; times++) {
ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz); ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz);
if (ret != 0) if (ret != 0)
goto exit_decap; goto exit_decap;
RECORD_MULTI_VALUE_STATS(); RECORD_MULTI_VALUE_STATS();
@ -9702,7 +9722,8 @@ exit_decap:
void bench_kyber(int type) void bench_kyber(int type)
{ {
KyberKey key; KyberKey key1;
KyberKey key2;
const char* name = NULL; const char* name = NULL;
int keySize = 0; int keySize = 0;
@ -9749,10 +9770,11 @@ void bench_kyber(int type)
#endif #endif
} }
bench_kyber_keygen(type, name, keySize, &key); bench_kyber_keygen(type, name, keySize, &key1);
bench_kyber_encap(name, keySize, &key); bench_kyber_encap(type, name, keySize, &key1, &key2);
wc_KyberKey_Free(&key); wc_KyberKey_Free(&key2);
wc_KyberKey_Free(&key1);
} }
#endif #endif

View File

@ -63,6 +63,12 @@
#error "Can't use small memory with assembly optimized code" #error "Can't use small memory with assembly optimized code"
#endif #endif
#endif #endif
#if defined(WOLFSSL_MLKEM_CACHE_A)
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
#error "Can't cache A with small memory code"
#endif
#endif
#ifdef WOLFSSL_WC_KYBER #ifdef WOLFSSL_WC_KYBER
@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
sword16* e = NULL; sword16* e = NULL;
#else #else
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N]; sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N];
#else #else
sword16 e[KYBER_MAX_K * KYBER_N]; sword16 e[KYBER_MAX_K * KYBER_N];
#endif #endif
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#endif #endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
sword16* a = NULL; sword16* a = NULL;
@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
} }
if (ret == 0) { if (ret == 0) {
key->flags = 0;
/* Establish parameters based on key type. */ /* Establish parameters based on key type. */
switch (key->type) { switch (key->type) {
#ifndef WOLFSSL_NO_ML_KEM #ifndef WOLFSSL_NO_ML_KEM
@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) { if (ret == 0) {
/* Allocate dynamic memory for matrix and error vector. */ /* Allocate dynamic memory for matrix and error vector. */
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
/* e (v) | a (m) */
e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16), e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER); key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else #else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16), e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER); key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif #endif
@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) { if (ret == 0) {
const byte* d = rand; const byte* d = rand;
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM #ifdef WOLFSSL_MLKEM_CACHE_A
/* Error vector allocated at end of a. */ a = key->a;
#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM)
/* Matrix A allocated at end of error vector. */
a = e + (kp * KYBER_N); a = e + (kp * KYBER_N);
#endif #endif
@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0); ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0);
} }
if (ret == 0) { if (ret == 0) {
#ifdef WOLFSSL_MLKEM_CACHE_A
key->flags |= KYBER_FLAG_A_SET;
#endif
/* Generate key pair from random data. */ /* Generate key pair from random data. */
kyber_keygen(key->priv, key->pub, e, a, kp); kyber_keygen(key->priv, key->pub, e, a, kp);
#else #else
@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned char* ct) unsigned char* ct)
{ {
int ret = 0; int ret = 0;
sword16* sp = NULL; sword16* at = NULL;
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16* k = NULL; sword16* k = NULL;
sword16* ep = NULL; sword16* ep = NULL;
@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned int kp = 0; unsigned int kp = 0;
unsigned int compVecSz = 0; unsigned int compVecSz = 0;
#ifndef WOLFSSL_NO_MALLOC #ifndef WOLFSSL_NO_MALLOC
sword16* at = NULL; sword16* sp = NULL;
#else #else
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N]; sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
#else #else
sword16 at[3 * KYBER_MAX_K * KYBER_N]; sword16 sp[3 * KYBER_MAX_K * KYBER_N];
#endif #endif
#endif #endif
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM #ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) { if (ret == 0) {
/* Allocate dynamic memory for all matrices, vectors and polynomials. */ /* Allocate dynamic memory for all matrices, vectors and polynomials. */
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16), sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER); key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else #else
at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap, sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
DYNAMIC_TYPE_TMP_BUFFER); DYNAMIC_TYPE_TMP_BUFFER);
#endif #endif
if (at == NULL) { if (sp == NULL) {
ret = MEMORY_E; ret = MEMORY_E;
} }
} }
@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) { if (ret == 0) {
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
/* Assign allocated dynamic memory to pointers. /* Assign allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */ * sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
at = sp + KYBER_N * kp;
k = at + KYBER_N * kp * kp; k = at + KYBER_N * kp * kp;
sp = k + KYBER_N; ep = k + KYBER_N;
ep = sp + KYBER_N * kp;
epp = ep + KYBER_N * kp; epp = ep + KYBER_N * kp;
#else #else
/* Assign allocated dynamic memory to pointers. /* Assign allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */ * sp (v) | at (v) | bp (v) */
sp = at + KYBER_N * kp; at = sp + KYBER_N * kp;
#endif #endif
/* Initialize the PRF for use in the noise generation. */ /* Initialize the PRF for use in the noise generation. */
@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
/* Generate noise using PRF. */ /* Generate noise using PRF. */
ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins); ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins);
} }
#ifdef WOLFSSL_MLKEM_CACHE_A
if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) {
unsigned int i;
/* Transpose matrix. */
for (i = 0; i < kp; i++) {
unsigned int j;
for (j = 0; j < kp; j++) {
XMEMCPY(&at[(i * kp + j) * KYBER_N],
&key->a[(j * kp + i) * KYBER_N],
KYBER_N * 2);
}
}
}
else
#endif
if (ret == 0) { if (ret == 0) {
/* Generate the transposed matrix. */ /* Generate the transposed matrix. */
ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1); ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1);
@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
sword16* v; sword16* v;
/* Assign remaining allocated dynamic memory to pointers. /* Assign remaining allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/ * sp (v) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p)*/
bp = epp + KYBER_N; bp = epp + KYBER_N;
v = bp + KYBER_N * kp; v = bp + KYBER_N * kp;
@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
} }
if (ret == 0) { if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers. /* Assign remaining allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */ * sp (v) | at (v) | bp (v) */
bp = sp + KYBER_N * kp; bp = sp + KYBER_N * kp;
v = at; v = at;
@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
#ifndef WOLFSSL_NO_MALLOC #ifndef WOLFSSL_NO_MALLOC
/* Dispose of dynamic memory allocated in function. */ /* Dispose of dynamic memory allocated in function. */
XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER); XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif #endif
return ret; return ret;

View File

@ -62,6 +62,7 @@ enum {
KYBER_FLAG_PUB_SET = 0x0002, KYBER_FLAG_PUB_SET = 0x0002,
KYBER_FLAG_BOTH_SET = 0x0003, KYBER_FLAG_BOTH_SET = 0x0003,
KYBER_FLAG_H_SET = 0x0004, KYBER_FLAG_H_SET = 0x0004,
KYBER_FLAG_A_SET = 0x0008,
/* 2 bits of random used to create noise value. */ /* 2 bits of random used to create noise value. */
KYBER_CBD_ETA2 = 2, KYBER_CBD_ETA2 = 2,
@ -137,6 +138,10 @@ struct KyberKey {
byte h[KYBER_SYM_SZ]; byte h[KYBER_SYM_SZ];
/* Randomizer for decapsulation. */ /* Randomizer for decapsulation. */
byte z[KYBER_SYM_SZ]; byte z[KYBER_SYM_SZ];
#ifdef WOLFSSL_MLKEM_CACHE_A
/* A matrix from key generation. */
sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N];
#endif
}; };
#ifdef __cplusplus #ifdef __cplusplus