Merge pull request #8622 from SparkiDev/kyber_improv_3

ML-KEM/Kyber: minor improvements
pull/8618/head
JacobBarthelmeh 2025-04-02 09:56:32 -06:00 committed by GitHub
commit a3d0ffb1ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1786 additions and 1453 deletions

File diff suppressed because it is too large Load Diff

View File

@ -348,12 +348,12 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#else
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
sword16 e[(MLKEM_MAX_K + 1) * MLKEM_MAX_K * MLKEM_N];
sword16 e[(WC_ML_KEM_MAX_K + 1) * WC_ML_KEM_MAX_K * MLKEM_N];
#else
sword16 e[MLKEM_MAX_K * MLKEM_N];
sword16 e[WC_ML_KEM_MAX_K * MLKEM_N];
#endif
#else
sword16 e[MLKEM_MAX_K * MLKEM_N];
sword16 e[WC_ML_KEM_MAX_K * MLKEM_N];
#endif
#endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
@ -667,9 +667,9 @@ static int mlkemkey_encapsulate(MlKemKey* key, const byte* m, byte* r, byte* c)
sword16* y = NULL;
#else
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16 y[((MLKEM_MAX_K + 3) * MLKEM_MAX_K + 3) * MLKEM_N];
sword16 y[((WC_ML_KEM_MAX_K + 3) * WC_ML_KEM_MAX_K + 3) * MLKEM_N];
#else
sword16 y[3 * MLKEM_MAX_K * MLKEM_N];
sword16 y[3 * WC_ML_KEM_MAX_K * MLKEM_N];
#endif
#endif
sword16* u;
@ -1266,39 +1266,6 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
return ret;
}
#ifndef WOLFSSL_NO_ML_KEM
/* Derive the secret from z and cipher text.
*
* @param [in] z Implicit rejection value.
* @param [in] ct Cipher text.
* @param [in] ctSz Length of cipher text in bytes.
* @param [out] ss Shared secret.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation failed.
* @return Other negative when a hash error occurred.
*/
static int mlkem_derive_secret(const byte* z, const byte* ct, word32 ctSz,
byte* ss)
{
int ret;
wc_Shake shake;
ret = wc_InitShake256(&shake, NULL, INVALID_DEVID);
if (ret == 0) {
ret = wc_Shake256_Update(&shake, z, WC_ML_KEM_SYM_SZ);
if (ret == 0) {
ret = wc_Shake256_Update(&shake, ct, ctSz);
}
if (ret == 0) {
ret = wc_Shake256_Final(&shake, ss, WC_ML_KEM_SS_SZ);
}
wc_Shake256_Free(&shake);
}
return ret;
}
#endif
/**
* Decapsulate the cipher text to calculate the shared secret.
*
@ -1461,7 +1428,7 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
#endif
#ifndef WOLFSSL_NO_ML_KEM
{
ret = mlkem_derive_secret(key->z, ct, ctSz, msg);
ret = mlkem_derive_secret(&key->prf, key->z, ct, ctSz, msg);
if (ret == 0) {
/* Set secret to kr or fake secret on comparison failure. */
for (i = 0; i < WC_ML_KEM_SYM_SZ; i++) {

File diff suppressed because it is too large Load Diff

View File

@ -182,7 +182,7 @@ const sword16 zetas[MLKEM_N / 2] = {
817, 1097, 603, 610, 1322, 2044, 1864, 384,
2114, 3193, 1218, 1994, 2455, 220, 2142, 1670,
2144, 1799, 2051, 794, 1819, 2475, 2459, 478,
3221, 3021, 996, 991, 958, 1869, 1522, 1628
3221, 3021, 996, 991, 958, 1869, 1522, 1628
};
@ -540,6 +540,308 @@ static void mlkem_ntt(sword16* r)
#endif
}
#if !defined(WOLFSSL_MLKEM_NO_MAKE_KEY) && \
!defined(WOLFSSL_MLKEM_SMALL) && !defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
/* Number-Theoretic Transform.
*
* FIPS 203, Algorithm 9: NTT(f)
* Computes the NTT representation f_hat of the given polynomial f element of
* R_q.
* 1: f_hat <- f
* 2: i <- 1
* 3: for (len <- 128; len >= 2; len <- len/2)
* 4: for (start <- 0; start < 256; start <- start + 2.len)
* 5: zeta <- zetas^BitRev_7(i) mod q
* 6: i <- i + 1
* 7: for (j <- start; j < start + len; j++)
* 8: t <- zeta.f[j+len]
* 9: f_hat[j+len] <- f_hat[j] - t
* 10: f_hat[j] <- f_hat[j] - t
* 11: end for
* 12: end for
* 13: end for
* 14: return f_hat
*
* @param [in, out] r Polynomial to transform.
*/
static void mlkem_ntt_add_to(sword16* r, sword16* a)
{
#if defined(WOLFSSL_MLKEM_NTT_UNROLL)
/* Unroll len loop (Step 3). */
unsigned int k = 1;
unsigned int j;
unsigned int start;
sword16 zeta = zetas[k++];
/* len = 128 */
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p = (sword32)zeta * r[j + MLKEM_N / 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + MLKEM_N / 2] = rj - t;
r[j] = rj + t;
}
/* len = 64 */
for (start = 0; start < MLKEM_N; start += 2 * 64) {
zeta = zetas[k++];
for (j = 0; j < 64; ++j) {
sword32 p = (sword32)zeta * r[start + j + 64];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 64] = rj - t;
r[start + j] = rj + t;
}
}
/* len = 32 */
for (start = 0; start < MLKEM_N; start += 2 * 32) {
zeta = zetas[k++];
for (j = 0; j < 32; ++j) {
sword32 p = (sword32)zeta * r[start + j + 32];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 32] = rj - t;
r[start + j] = rj + t;
}
}
/* len = 16 */
for (start = 0; start < MLKEM_N; start += 2 * 16) {
zeta = zetas[k++];
for (j = 0; j < 16; ++j) {
sword32 p = (sword32)zeta * r[start + j + 16];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 16] = rj - t;
r[start + j] = rj + t;
}
}
/* len = 8 */
for (start = 0; start < MLKEM_N; start += 2 * 8) {
zeta = zetas[k++];
for (j = 0; j < 8; ++j) {
sword32 p = (sword32)zeta * r[start + j + 8];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 8] = rj - t;
r[start + j] = rj + t;
}
}
/* len = 4 */
for (start = 0; start < MLKEM_N; start += 2 * 4) {
zeta = zetas[k++];
for (j = 0; j < 4; ++j) {
sword32 p = (sword32)zeta * r[start + j + 4];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 4] = rj - t;
r[start + j] = rj + t;
}
}
/* len = 2 */
for (start = 0; start < MLKEM_N; start += 2 * 2) {
zeta = zetas[k++];
for (j = 0; j < 2; ++j) {
sword32 p = (sword32)zeta * r[start + j + 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 2] = rj - t;
r[start + j] = rj + t;
}
}
/* Reduce coefficients with quick algorithm. */
for (j = 0; j < MLKEM_N; ++j) {
sword16 t = a[j] + r[j];
a[j] = MLKEM_BARRETT_RED(t);
}
#else
/* Unroll len (2, 3, 2) and start loops. */
unsigned int j;
sword16 t0;
sword16 t1;
sword16 t2;
sword16 t3;
/* len = 128,64 */
sword16 zeta128 = zetas[1];
sword16 zeta64_0 = zetas[2];
sword16 zeta64_1 = zetas[3];
for (j = 0; j < MLKEM_N / 8; j++) {
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 32];
sword16 r2 = r[j + 64];
sword16 r3 = r[j + 96];
sword16 r4 = r[j + 128];
sword16 r5 = r[j + 160];
sword16 r6 = r[j + 192];
sword16 r7 = r[j + 224];
t0 = MLKEM_MONT_RED((sword32)zeta128 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta128 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta128 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta128 * r7);
r4 = r0 - t0;
r5 = r1 - t1;
r6 = r2 - t2;
r7 = r3 - t3;
r0 += t0;
r1 += t1;
r2 += t2;
r3 += t3;
t0 = MLKEM_MONT_RED((sword32)zeta64_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta64_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta64_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta64_1 * r7);
r2 = r0 - t0;
r3 = r1 - t1;
r6 = r4 - t2;
r7 = r5 - t3;
r0 += t0;
r1 += t1;
r4 += t2;
r5 += t3;
r[j + 0] = r0;
r[j + 32] = r1;
r[j + 64] = r2;
r[j + 96] = r3;
r[j + 128] = r4;
r[j + 160] = r5;
r[j + 192] = r6;
r[j + 224] = r7;
}
/* len = 32,16,8 */
for (j = 0; j < MLKEM_N; j += 64) {
int i;
sword16 zeta32 = zetas[ 4 + j / 64 + 0];
sword16 zeta16_0 = zetas[ 8 + j / 32 + 0];
sword16 zeta16_1 = zetas[ 8 + j / 32 + 1];
sword16 zeta8_0 = zetas[16 + j / 16 + 0];
sword16 zeta8_1 = zetas[16 + j / 16 + 1];
sword16 zeta8_2 = zetas[16 + j / 16 + 2];
sword16 zeta8_3 = zetas[16 + j / 16 + 3];
for (i = 0; i < 8; i++) {
sword16 r0 = r[j + i + 0];
sword16 r1 = r[j + i + 8];
sword16 r2 = r[j + i + 16];
sword16 r3 = r[j + i + 24];
sword16 r4 = r[j + i + 32];
sword16 r5 = r[j + i + 40];
sword16 r6 = r[j + i + 48];
sword16 r7 = r[j + i + 56];
t0 = MLKEM_MONT_RED((sword32)zeta32 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta32 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta32 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta32 * r7);
r4 = r0 - t0;
r5 = r1 - t1;
r6 = r2 - t2;
r7 = r3 - t3;
r0 += t0;
r1 += t1;
r2 += t2;
r3 += t3;
t0 = MLKEM_MONT_RED((sword32)zeta16_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta16_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta16_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta16_1 * r7);
r2 = r0 - t0;
r3 = r1 - t1;
r6 = r4 - t2;
r7 = r5 - t3;
r0 += t0;
r1 += t1;
r4 += t2;
r5 += t3;
t0 = MLKEM_MONT_RED((sword32)zeta8_0 * r1);
t1 = MLKEM_MONT_RED((sword32)zeta8_1 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta8_2 * r5);
t3 = MLKEM_MONT_RED((sword32)zeta8_3 * r7);
r1 = r0 - t0;
r3 = r2 - t1;
r5 = r4 - t2;
r7 = r6 - t3;
r0 += t0;
r2 += t1;
r4 += t2;
r6 += t3;
r[j + i + 0] = r0;
r[j + i + 8] = r1;
r[j + i + 16] = r2;
r[j + i + 24] = r3;
r[j + i + 32] = r4;
r[j + i + 40] = r5;
r[j + i + 48] = r6;
r[j + i + 56] = r7;
}
}
/* len = 4,2 and Final reduction */
for (j = 0; j < MLKEM_N; j += 8) {
sword16 zeta4 = zetas[32 + j / 8 + 0];
sword16 zeta2_0 = zetas[64 + j / 4 + 0];
sword16 zeta2_1 = zetas[64 + j / 4 + 1];
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 1];
sword16 r2 = r[j + 2];
sword16 r3 = r[j + 3];
sword16 r4 = r[j + 4];
sword16 r5 = r[j + 5];
sword16 r6 = r[j + 6];
sword16 r7 = r[j + 7];
t0 = MLKEM_MONT_RED((sword32)zeta4 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta4 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta4 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta4 * r7);
r4 = r0 - t0;
r5 = r1 - t1;
r6 = r2 - t2;
r7 = r3 - t3;
r0 += t0;
r1 += t1;
r2 += t2;
r3 += t3;
t0 = MLKEM_MONT_RED((sword32)zeta2_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta2_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta2_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta2_1 * r7);
r2 = r0 - t0;
r3 = r1 - t1;
r6 = r4 - t2;
r7 = r5 - t3;
r0 += t0;
r1 += t1;
r4 += t2;
r5 += t3;
r0 += a[j + 0];
r1 += a[j + 1];
r2 += a[j + 2];
r3 += a[j + 3];
r4 += a[j + 4];
r5 += a[j + 5];
r6 += a[j + 6];
r7 += a[j + 7];
a[j + 0] = MLKEM_BARRETT_RED(r0);
a[j + 1] = MLKEM_BARRETT_RED(r1);
a[j + 2] = MLKEM_BARRETT_RED(r2);
a[j + 3] = MLKEM_BARRETT_RED(r3);
a[j + 4] = MLKEM_BARRETT_RED(r4);
a[j + 5] = MLKEM_BARRETT_RED(r5);
a[j + 6] = MLKEM_BARRETT_RED(r6);
a[j + 7] = MLKEM_BARRETT_RED(r7);
}
#endif
}
#endif
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
/* Zetas for inverse NTT. */
@ -1524,6 +1826,7 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
}
/* Transform error values polynomial.
* Step 17: e_hat = NTT(e) */
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
mlkem_ntt(e + i * MLKEM_N);
/* Add errors to public key and reduce.
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
@ -1531,6 +1834,11 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
sword16 n = t[i * MLKEM_N + j] + e[i * MLKEM_N + j];
t[i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
}
#else
/* Add errors to public key and reduce.
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
mlkem_ntt_add_to(e + i * MLKEM_N, t + i * MLKEM_N);
#endif
}
}
@ -1633,6 +1941,7 @@ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
}
/* Transform error values polynomial.
* Step 17: e_hat = NTT(e) */
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
mlkem_ntt(e);
/* Add errors to public key and reduce.
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
@ -1640,6 +1949,11 @@ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
sword16 n = t[i * MLKEM_N + j] + e[j];
t[i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
}
#else
/* Add errors to public key and reduce.
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
mlkem_ntt_add_to(e, t + i * MLKEM_N);
#endif
}
return ret;
@ -1684,10 +1998,31 @@ static void mlkem_encapsulate_c(const sword16* pub, sword16* u, sword16* v,
/* Inverse transform u polynomial. */
mlkem_invntt(u + i * MLKEM_N);
/* Add errors to u and reduce. */
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
for (j = 0; j < MLKEM_N; ++j) {
sword16 t = u[i * MLKEM_N + j] + e1[i * MLKEM_N + j];
u[i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
}
#else
for (j = 0; j < MLKEM_N; j += 8) {
sword16 t0 = u[i * MLKEM_N + j + 0] + e1[i * MLKEM_N + j + 0];
sword16 t1 = u[i * MLKEM_N + j + 1] + e1[i * MLKEM_N + j + 1];
sword16 t2 = u[i * MLKEM_N + j + 2] + e1[i * MLKEM_N + j + 2];
sword16 t3 = u[i * MLKEM_N + j + 3] + e1[i * MLKEM_N + j + 3];
sword16 t4 = u[i * MLKEM_N + j + 4] + e1[i * MLKEM_N + j + 4];
sword16 t5 = u[i * MLKEM_N + j + 5] + e1[i * MLKEM_N + j + 5];
sword16 t6 = u[i * MLKEM_N + j + 6] + e1[i * MLKEM_N + j + 6];
sword16 t7 = u[i * MLKEM_N + j + 7] + e1[i * MLKEM_N + j + 7];
u[i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
u[i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
u[i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
u[i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
u[i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
u[i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
u[i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
u[i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
}
#endif
}
/* Multiply public key by y into v polynomial. */
@ -1781,10 +2116,31 @@ int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
break;
}
/* Add errors to u and reduce. */
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
for (j = 0; j < MLKEM_N; ++j) {
sword16 t = u[i * MLKEM_N + j] + e1[j];
u[i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
}
#else
for (j = 0; j < MLKEM_N; j += 8) {
sword16 t0 = u[i * MLKEM_N + j + 0] + e1[j + 0];
sword16 t1 = u[i * MLKEM_N + j + 1] + e1[j + 1];
sword16 t2 = u[i * MLKEM_N + j + 2] + e1[j + 2];
sword16 t3 = u[i * MLKEM_N + j + 3] + e1[j + 3];
sword16 t4 = u[i * MLKEM_N + j + 4] + e1[j + 4];
sword16 t5 = u[i * MLKEM_N + j + 5] + e1[j + 5];
sword16 t6 = u[i * MLKEM_N + j + 6] + e1[j + 6];
sword16 t7 = u[i * MLKEM_N + j + 7] + e1[j + 7];
u[i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
u[i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
u[i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
u[i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
u[i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
u[i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
u[i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
u[i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
}
#endif
}
/* Multiply public key by y into v polynomial. */
@ -1799,10 +2155,31 @@ int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
ret = mlkem_get_noise_eta2_c(prf, e2, coins);
if (ret == 0) {
/* Add errors and message to v and reduce. */
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
for (i = 0; i < MLKEM_N; ++i) {
sword16 t = v[i] + e2[i] + m[i];
tp[i] = MLKEM_BARRETT_RED(t);
v[i] = MLKEM_BARRETT_RED(t);
}
#else
for (i = 0; i < MLKEM_N; i += 8) {
sword16 t0 = v[i + 0] + e2[i + 0] + m[i + 0];
sword16 t1 = v[i + 1] + e2[i + 1] + m[i + 1];
sword16 t2 = v[i + 2] + e2[i + 2] + m[i + 2];
sword16 t3 = v[i + 3] + e2[i + 3] + m[i + 3];
sword16 t4 = v[i + 4] + e2[i + 4] + m[i + 4];
sword16 t5 = v[i + 5] + e2[i + 5] + m[i + 5];
sword16 t6 = v[i + 6] + e2[i + 6] + m[i + 6];
sword16 t7 = v[i + 7] + e2[i + 7] + m[i + 7];
v[i + 0] = MLKEM_BARRETT_RED(t0);
v[i + 1] = MLKEM_BARRETT_RED(t1);
v[i + 2] = MLKEM_BARRETT_RED(t2);
v[i + 3] = MLKEM_BARRETT_RED(t3);
v[i + 4] = MLKEM_BARRETT_RED(t4);
v[i + 5] = MLKEM_BARRETT_RED(t5);
v[i + 6] = MLKEM_BARRETT_RED(t6);
v[i + 7] = MLKEM_BARRETT_RED(t7);
}
#endif
}
return ret;
@ -2713,6 +3090,47 @@ int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
#endif
#endif
#ifndef WOLFSSL_NO_ML_KEM
/* Derive the secret from z and cipher text.
*
* @param [in, out] shake256 SHAKE-256 object.
* @param [in] z Implicit rejection value.
* @param [in] ct Cipher text.
* @param [in] ctSz Length of cipher text in bytes.
* @param [out] ss Shared secret.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation failed.
* @return Other negative when a hash error occurred.
*/
int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct,
word32 ctSz, byte* ss)
{
int ret;
#ifdef USE_INTEL_SPEEDUP
XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ);
XMEMCPY(shake256->t, ct, WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ);
shake256->i = WC_ML_KEM_SYM_SZ;
ct += WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
ctSz -= WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
ret = wc_Shake256_Update(shake256, ct, ctSz);
if (ret == 0) {
ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
}
#else
ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ);
if (ret == 0) {
ret = wc_Shake256_Update(shake256, ct, ctSz);
}
if (ret == 0) {
ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
}
#endif
return ret;
}
#endif
#if !defined(WOLFSSL_ARMASM)
/* Rejection sampling on uniform random bytes to generate uniform random
* integers mod q.
@ -3625,16 +4043,34 @@ static void mlkem_get_noise_x4_eta2_avx2(byte* rand, byte* seed, byte o)
static int mlkem_get_noise_eta2_avx2(MLKEM_PRF_T* prf, sword16* p,
const byte* seed)
{
int ret;
byte rand[ETA2_RAND_SIZE];
word64 state[25];
/* Calculate random bytes from seed with PRF. */
ret = mlkem_prf(prf, rand, sizeof(rand), seed);
if (ret == 0) {
mlkem_cbd_eta2_avx2(p, rand);
(void)prf;
/* Put first WC_ML_KEM_SYM_SZ bytes og key into blank state. */
readUnalignedWords64(state, seed, WC_ML_KEM_SYM_SZ / sizeof(word64));
/* Last byte in with end of content marker. */
state[WC_ML_KEM_SYM_SZ / 8] = 0x1f00 | seed[WC_ML_KEM_SYM_SZ];
/* Set rest of state to 0. */
XMEMSET(state + WC_ML_KEM_SYM_SZ / 8 + 1, 0,
(25 - WC_ML_KEM_SYM_SZ / 8 - 1) * sizeof(word64));
/* ... except for rate marker. */
state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000);
/* Perform a block operation on the state for next block of output. */
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else {
BlockSha3(state);
}
mlkem_cbd_eta2_avx2(p, (byte*)state);
return ret;
return 0;
}
#endif

View File

@ -200,6 +200,10 @@ WOLFSSL_LOCAL
int mlkem_hash512(wc_Sha3* hash, const byte* data1, word32 data1Len,
const byte* data2, word32 data2Len, byte* out);
WOLFSSL_LOCAL
int mlkem_derive_secret(MLKEM_PRF_T* prf, const byte* z, const byte* ct,
word32 ctSz, byte* ss);
WOLFSSL_LOCAL
void mlkem_prf_init(MLKEM_PRF_T* prf);
WOLFSSL_LOCAL