Merge pull request #8436 from SparkiDev/mlkem_cache_a

ML-KEM/Kyber: cache A from key generation for decapsulation
pull/8447/head
David Garske 2025-02-12 17:29:38 -08:00 committed by GitHub
commit db0fa304a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 26 deletions

View File

@ -1399,6 +1399,9 @@ do
small)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL"
;;
cache-a)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A"
;;
512)
ENABLED_KYBER512=yes
;;

View File

@ -9630,17 +9630,37 @@ exit:
#endif
}
static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
static void bench_kyber_encap(int type, const char* name, int keySize,
KyberKey* key1, KyberKey* key2)
{
int ret = 0, times, count, pending = 0;
double start;
const char**desc = bench_desc_words[lng_index];
byte ct[KYBER_MAX_CIPHER_TEXT_SIZE];
byte ss[KYBER_SS_SZ];
byte pub[KYBER_MAX_PUBLIC_KEY_SIZE];
word32 pubLen;
word32 ctSz;
DECLARE_MULTI_VALUE_STATS_VARS()
ret = wc_KyberKey_CipherTextSize(key, &ctSz);
ret = wc_KyberKey_PublicKeySize(key1, &pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_EncodePublicKey(key1, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID);
if (ret != 0) {
return;
}
ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_CipherTextSize(key2, &ctSz);
if (ret != 0) {
return;
}
@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
#ifdef KYBER_NONDETERMINISTIC
ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng);
ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng);
#else
unsigned char rand[KYBER_ENC_RAND_SZ] = {0,};
ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand,
ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand,
sizeof(rand));
#endif
if (ret != 0)
@ -9681,7 +9701,7 @@ exit_encap:
do {
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz);
ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz);
if (ret != 0)
goto exit_decap;
RECORD_MULTI_VALUE_STATS();
@ -9702,7 +9722,8 @@ exit_decap:
void bench_kyber(int type)
{
KyberKey key;
KyberKey key1;
KyberKey key2;
const char* name = NULL;
int keySize = 0;
@ -9749,10 +9770,11 @@ void bench_kyber(int type)
#endif
}
bench_kyber_keygen(type, name, keySize, &key);
bench_kyber_encap(name, keySize, &key);
bench_kyber_keygen(type, name, keySize, &key1);
bench_kyber_encap(type, name, keySize, &key1, &key2);
wc_KyberKey_Free(&key);
wc_KyberKey_Free(&key2);
wc_KyberKey_Free(&key1);
}
#endif

View File

@ -63,6 +63,12 @@
#error "Can't use small memory with assembly optimized code"
#endif
#endif
#if defined(WOLFSSL_MLKEM_CACHE_A)
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
#error "Can't cache A with small memory code"
#endif
#endif
#ifdef WOLFSSL_WC_KYBER
@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
sword16* e = NULL;
#else
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N];
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
sword16* a = NULL;
@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
}
if (ret == 0) {
key->flags = 0;
/* Establish parameters based on key type. */
switch (key->type) {
#ifndef WOLFSSL_NO_ML_KEM
@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
/* Allocate dynamic memory for matrix and error vector. */
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
/* e (v) | a (m) */
e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
const byte* d = rand;
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
/* Error vector allocated at end of a. */
#ifdef WOLFSSL_MLKEM_CACHE_A
a = key->a;
#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM)
/* Matrix A allocated at end of error vector. */
a = e + (kp * KYBER_N);
#endif
@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0);
}
if (ret == 0) {
#ifdef WOLFSSL_MLKEM_CACHE_A
key->flags |= KYBER_FLAG_A_SET;
#endif
/* Generate key pair from random data. */
kyber_keygen(key->priv, key->pub, e, a, kp);
#else
@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned char* ct)
{
int ret = 0;
sword16* sp = NULL;
sword16* at = NULL;
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16* k = NULL;
sword16* ep = NULL;
@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned int kp = 0;
unsigned int compVecSz = 0;
#ifndef WOLFSSL_NO_MALLOC
sword16* at = NULL;
sword16* sp = NULL;
#else
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
#else
sword16 at[3 * KYBER_MAX_K * KYBER_N];
sword16 sp[3 * KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
/* Allocate dynamic memory for all matrices, vectors and polynomials. */
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
DYNAMIC_TYPE_TMP_BUFFER);
#endif
if (at == NULL) {
if (sp == NULL) {
ret = MEMORY_E;
}
}
@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
/* Assign allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */
* sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
at = sp + KYBER_N * kp;
k = at + KYBER_N * kp * kp;
sp = k + KYBER_N;
ep = sp + KYBER_N * kp;
ep = k + KYBER_N;
epp = ep + KYBER_N * kp;
#else
/* Assign allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
sp = at + KYBER_N * kp;
* sp (v) | at (v) | bp (v) */
at = sp + KYBER_N * kp;
#endif
/* Initialize the PRF for use in the noise generation. */
@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
/* Generate noise using PRF. */
ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins);
}
#ifdef WOLFSSL_MLKEM_CACHE_A
if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) {
unsigned int i;
/* Transpose matrix. */
for (i = 0; i < kp; i++) {
unsigned int j;
for (j = 0; j < kp; j++) {
XMEMCPY(&at[(i * kp + j) * KYBER_N],
&key->a[(j * kp + i) * KYBER_N],
KYBER_N * 2);
}
}
}
else
#endif
if (ret == 0) {
/* Generate the transposed matrix. */
ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1);
@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
sword16* v;
/* Assign remaining allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/
* sp (v) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p)*/
bp = epp + KYBER_N;
v = bp + KYBER_N * kp;
@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
}
if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
* sp (v) | at (v) | bp (v) */
bp = sp + KYBER_N * kp;
v = at;
@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
#ifndef WOLFSSL_NO_MALLOC
/* Dispose of dynamic memory allocated in function. */
XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
return ret;

View File

@ -62,6 +62,7 @@ enum {
KYBER_FLAG_PUB_SET = 0x0002,
KYBER_FLAG_BOTH_SET = 0x0003,
KYBER_FLAG_H_SET = 0x0004,
KYBER_FLAG_A_SET = 0x0008,
/* 2 bits of random used to create noise value. */
KYBER_CBD_ETA2 = 2,
@ -137,6 +138,10 @@ struct KyberKey {
byte h[KYBER_SYM_SZ];
/* Randomizer for decapsulation. */
byte z[KYBER_SYM_SZ];
#ifdef WOLFSSL_MLKEM_CACHE_A
/* A matrix from key generation. */
sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N];
#endif
};
#ifdef __cplusplus