Merge pull request #8868 from SparkiDev/dilithium_win_fixes_1

Dilithium/ML-DSA: Fixes for casting down and uninit
pull/8913/head
Daniel Pouzzner 2025-06-23 09:02:35 -05:00 committed by GitHub
commit 47a8242093
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 275 additions and 259 deletions

View File

@ -502,11 +502,12 @@ static int dilithium_get_hash_oid(int hash, byte* oidBuffer, word32* oidLen)
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
oid = OidFromId(wc_HashGetOID((enum wc_HashType)hash), oidHashType, oidLen);
oid = OidFromId((word32)wc_HashGetOID((enum wc_HashType)hash), oidHashType,
oidLen);
if ((oid != NULL) && (*oidLen <= DILITHIUM_HASH_OID_LEN - 2)) {
#ifndef WOLFSSL_DILITHIUM_REVERSE_HASH_OID
oidBuffer[0] = 0x06; /* ObjectID */
oidBuffer[1] = *oidLen; /* ObjectID */
oidBuffer[0] = 0x06; /* ObjectID */
oidBuffer[1] = (byte)*oidLen; /* ObjectID */
oidBuffer += 2;
XMEMCPY(oidBuffer, oid, *oidLen);
#else
@ -733,19 +734,19 @@ static void dilthium_vec_encode_eta_bits(const sword32* s, byte d, byte eta,
* 8 numbers become 3 bytes. (8 * 3 bits = 3 * 8 bits) */
for (j = 0; j < DILITHIUM_N; j += 8) {
/* Make value a positive integer. */
byte s0 = 2 - s[j + 0];
byte s1 = 2 - s[j + 1];
byte s2 = 2 - s[j + 2];
byte s3 = 2 - s[j + 3];
byte s4 = 2 - s[j + 4];
byte s5 = 2 - s[j + 5];
byte s6 = 2 - s[j + 6];
byte s7 = 2 - s[j + 7];
byte s0 = (byte)(2 - s[j + 0]);
byte s1 = (byte)(2 - s[j + 1]);
byte s2 = (byte)(2 - s[j + 2]);
byte s3 = (byte)(2 - s[j + 3]);
byte s4 = (byte)(2 - s[j + 4]);
byte s5 = (byte)(2 - s[j + 5]);
byte s6 = (byte)(2 - s[j + 6]);
byte s7 = (byte)(2 - s[j + 7]);
/* Pack 8 3-bit values into 3 bytes. */
p[0] = (s0 >> 0) | (s1 << 3) | (s2 << 6);
p[1] = (s2 >> 2) | (s3 << 1) | (s4 << 4) | (s5 << 7);
p[2] = (s5 >> 1) | (s6 << 2) | (s7 << 5);
p[0] = (byte)((s0 >> 0) | (s1 << 3) | (s2 << 6));
p[1] = (byte)((s2 >> 2) | (s3 << 1) | (s4 << 4) | (s5 << 7));
p[2] = (byte)((s5 >> 1) | (s6 << 2) | (s7 << 5));
/* Move to next place to encode into. */
p += DILITHIUM_ETA_2_BITS;
}
@ -774,14 +775,14 @@ static void dilthium_vec_encode_eta_bits(const sword32* s, byte d, byte eta,
* 8 numbers become 4 bytes. (8 * 4 bits = 4 * 8 bits) */
for (j = 0; j < DILITHIUM_N / 2; j += 4) {
/* Make values positive and pack 2 4-bit values into 1 byte. */
p[j + 0] = (((byte)(4 - s[j * 2 + 0])) << 0) |
(((byte)(4 - s[j * 2 + 1])) << 4);
p[j + 1] = (((byte)(4 - s[j * 2 + 2])) << 0) |
(((byte)(4 - s[j * 2 + 3])) << 4);
p[j + 2] = (((byte)(4 - s[j * 2 + 4])) << 0) |
(((byte)(4 - s[j * 2 + 5])) << 4);
p[j + 3] = (((byte)(4 - s[j * 2 + 6])) << 0) |
(((byte)(4 - s[j * 2 + 7])) << 4);
p[j + 0] = (byte)((((byte)(4 - s[j * 2 + 0])) << 0) |
(((byte)(4 - s[j * 2 + 1])) << 4));
p[j + 1] = (byte)((((byte)(4 - s[j * 2 + 2])) << 0) |
(((byte)(4 - s[j * 2 + 3])) << 4));
p[j + 2] = (byte)((((byte)(4 - s[j * 2 + 4])) << 0) |
(((byte)(4 - s[j * 2 + 5])) << 4));
p[j + 3] = (byte)((((byte)(4 - s[j * 2 + 6])) << 0) |
(((byte)(4 - s[j * 2 + 7])) << 4));
}
#endif
/* Move to next place to encode into. */
@ -993,31 +994,39 @@ static void dilithium_vec_encode_t0_t1(sword32* t, byte d, byte* t0, byte* t1)
* Do all polynomial values - 8 at a time. */
for (j = 0; j < DILITHIUM_N; j += 8) {
/* Take 8 values of t and take top bits and make positive. */
word16 n1_0 = (t[j + 0] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_1 = (t[j + 1] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_2 = (t[j + 2] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_3 = (t[j + 3] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_4 = (t[j + 4] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_5 = (t[j + 5] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_6 = (t[j + 6] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_7 = (t[j + 7] + DILITHIUM_D_MAX_HALF - 1) >> DILITHIUM_D;
word16 n1_0 = (word16)((t[j + 0] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_1 = (word16)((t[j + 1] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_2 = (word16)((t[j + 2] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_3 = (word16)((t[j + 3] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_4 = (word16)((t[j + 4] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_5 = (word16)((t[j + 5] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_6 = (word16)((t[j + 6] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
word16 n1_7 = (word16)((t[j + 7] + DILITHIUM_D_MAX_HALF - 1) >>
DILITHIUM_D);
/* Take 8 values of t and take bottom bits and make positive. */
word16 n0_0 = DILITHIUM_D_MAX_HALF -
(t[j + 0] - (n1_0 << DILITHIUM_D));
word16 n0_1 = DILITHIUM_D_MAX_HALF -
(t[j + 1] - (n1_1 << DILITHIUM_D));
word16 n0_2 = DILITHIUM_D_MAX_HALF -
(t[j + 2] - (n1_2 << DILITHIUM_D));
word16 n0_3 = DILITHIUM_D_MAX_HALF -
(t[j + 3] - (n1_3 << DILITHIUM_D));
word16 n0_4 = DILITHIUM_D_MAX_HALF -
(t[j + 4] - (n1_4 << DILITHIUM_D));
word16 n0_5 = DILITHIUM_D_MAX_HALF -
(t[j + 5] - (n1_5 << DILITHIUM_D));
word16 n0_6 = DILITHIUM_D_MAX_HALF -
(t[j + 6] - (n1_6 << DILITHIUM_D));
word16 n0_7 = DILITHIUM_D_MAX_HALF -
(t[j + 7] - (n1_7 << DILITHIUM_D));
word16 n0_0 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 0] - (n1_0 << DILITHIUM_D)));
word16 n0_1 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 1] - (n1_1 << DILITHIUM_D)));
word16 n0_2 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 2] - (n1_2 << DILITHIUM_D)));
word16 n0_3 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 3] - (n1_3 << DILITHIUM_D)));
word16 n0_4 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 4] - (n1_4 << DILITHIUM_D)));
word16 n0_5 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 5] - (n1_5 << DILITHIUM_D)));
word16 n0_6 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 6] - (n1_6 << DILITHIUM_D)));
word16 n0_7 = (word16)(DILITHIUM_D_MAX_HALF -
(t[j + 7] - (n1_7 << DILITHIUM_D)));
/* 13 bits per number.
* 8 numbers become 13 bytes. (8 * 13 bits = 13 * 8 bits) */
@ -1031,20 +1040,20 @@ static void dilithium_vec_encode_t0_t1(sword32* t, byte d, byte* t0, byte* t1)
tp[2] = (n0_4 >> 12) | ((word32)n0_5 << 1) |
((word32)n0_6 << 14) | ((word32)n0_7 << 27);
#else
t0[ 0] = (n0_0 << 0);
t0[ 1] = (n0_0 >> 8) | (n0_1 << 5);
t0[ 2] = (n0_1 >> 3) ;
t0[ 3] = (n0_1 >> 11) | (n0_2 << 2);
t0[ 4] = (n0_2 >> 6) | (n0_3 << 7);
t0[ 5] = (n0_3 >> 1) ;
t0[ 6] = (n0_3 >> 9) | (n0_4 << 4);
t0[ 7] = (n0_4 >> 4) ;
t0[ 8] = (n0_4 >> 12) | (n0_5 << 1);
t0[ 9] = (n0_5 >> 7) | (n0_6 << 6);
t0[10] = (n0_6 >> 2) ;
t0[11] = (n0_6 >> 10) | (n0_7 << 3);
t0[ 0] = (byte)( (n0_0 << 0));
t0[ 1] = (byte)((n0_0 >> 8) | (n0_1 << 5));
t0[ 2] = (byte)((n0_1 >> 3) );
t0[ 3] = (byte)((n0_1 >> 11) | (n0_2 << 2));
t0[ 4] = (byte)((n0_2 >> 6) | (n0_3 << 7));
t0[ 5] = (byte)((n0_3 >> 1) );
t0[ 6] = (byte)((n0_3 >> 9) | (n0_4 << 4));
t0[ 7] = (byte)((n0_4 >> 4) );
t0[ 8] = (byte)((n0_4 >> 12) | (n0_5 << 1));
t0[ 9] = (byte)((n0_5 >> 7) | (n0_6 << 6));
t0[10] = (byte)((n0_6 >> 2) );
t0[11] = (byte)((n0_6 >> 10) | (n0_7 << 3));
#endif
t0[12] = (n0_7 >> 5) ;
t0[12] = (byte)((n0_7 >> 5) );
/* 10 bits per number.
* 8 bytes become 10 bytes. (8 * 10 bits = 10 * 8 bits) */
@ -1055,17 +1064,17 @@ static void dilithium_vec_encode_t0_t1(sword32* t, byte d, byte* t0, byte* t1)
tp[1] = (n1_3 >> 2) | ((word32)n1_4 << 8) |
((word32)n1_5 << 18) | ((word32)n1_6 << 28);
#else
t1[0] = (n1_0 << 0);
t1[1] = (n1_0 >> 8) | (n1_1 << 2);
t1[2] = (n1_1 >> 6) | (n1_2 << 4);
t1[3] = (n1_2 >> 4) | (n1_3 << 6);
t1[4] = (n1_3 >> 2) ;
t1[5] = (n1_4 << 0);
t1[6] = (n1_4 >> 8) | (n1_5 << 2);
t1[7] = (n1_5 >> 6) | (n1_6 << 4);
t1[0] = (byte)( (n1_0 << 0));
t1[1] = (byte)((n1_0 >> 8) | (n1_1 << 2));
t1[2] = (byte)((n1_1 >> 6) | (n1_2 << 4));
t1[3] = (byte)((n1_2 >> 4) | (n1_3 << 6));
t1[4] = (byte)((n1_3 >> 2) );
t1[5] = (byte)( (n1_4 << 0));
t1[6] = (byte)((n1_4 >> 8) | (n1_5 << 2));
t1[7] = (byte)((n1_5 >> 6) | (n1_6 << 4));
#endif
t1[8] = (n1_6 >> 4) | (n1_7 << 6);
t1[9] = (n1_7 >> 2) ;
t1[8] = (byte)((n1_6 >> 4) | (n1_7 << 6));
t1[9] = (byte)((n1_7 >> 2) );
/* Move to next place to encode bottom bits to. */
t0 += DILITHIUM_D;
@ -1106,7 +1115,7 @@ static void dilithium_decode_t0(const byte* t0, sword32* t)
t[j + 1] = DILITHIUM_D_MAX_HALF - ((t64 >> 13) & 0x1fff);
t[j + 2] = DILITHIUM_D_MAX_HALF - ((t64 >> 26) & 0x1fff);
t[j + 3] = DILITHIUM_D_MAX_HALF - ((t64 >> 39) & 0x1fff);
t[j + 4] = DILITHIUM_D_MAX_HALF -
t[j + 4] = DILITHIUM_D_MAX_HALF - (sword32)
((t64 >> 52) | ((t32_2 & 0x0001) << 12));
#else
word32 t32_0 = ((const word32*)t0)[0];
@ -1115,18 +1124,18 @@ static void dilithium_decode_t0(const byte* t0, sword32* t)
( t32_0 & 0x1fff);
t[j + 1] = DILITHIUM_D_MAX_HALF -
((t32_0 >> 13) & 0x1fff);
t[j + 2] = DILITHIUM_D_MAX_HALF -
t[j + 2] = DILITHIUM_D_MAX_HALF - (sword32)
(( t32_0 >> 26 ) | ((t32_1 & 0x007f) << 6));
t[j + 3] = DILITHIUM_D_MAX_HALF -
((t32_1 >> 7) & 0x1fff);
t[j + 4] = DILITHIUM_D_MAX_HALF -
t[j + 4] = DILITHIUM_D_MAX_HALF - (sword32)
(( t32_1 >> 20 ) | ((t32_2 & 0x0001) << 12));
#endif
t[j + 5] = DILITHIUM_D_MAX_HALF -
((t32_2 >> 1) & 0x1fff);
t[j + 6] = DILITHIUM_D_MAX_HALF -
((t32_2 >> 14) & 0x1fff);
t[j + 7] = DILITHIUM_D_MAX_HALF -
t[j + 7] = DILITHIUM_D_MAX_HALF - (sword32)
(( t32_2 >> 27 ) | ((word32)t0[12] ) << 5 );
#else
t[j + 0] = DILITHIUM_D_MAX_HALF -
@ -1216,7 +1225,8 @@ static void dilithium_decode_t1(const byte* t1, sword32* t)
t[j+3] = (sword32)( ((t64 >> 30) & 0x03ff) << DILITHIUM_D);
t[j+4] = (sword32)( ((t64 >> 40) & 0x03ff) << DILITHIUM_D);
t[j+5] = (sword32)( ((t64 >> 50) & 0x03ff) << DILITHIUM_D);
t[j+6] = (sword32)((((t64 >> 60)| (t16 << 4)) & 0x03ff) << DILITHIUM_D);
t[j+6] = (sword32)((((t64 >> 60) |
(word64)(t16 << 4)) & 0x03ff) << DILITHIUM_D);
t[j+7] = (sword32)( ((t16 >> 6) & 0x03ff) << DILITHIUM_D);
#else
word32 t32 = *((const word32*)t1);
@ -1311,10 +1321,10 @@ static void dilithium_encode_gamma1_17_bits(const sword32* z, byte* s)
/* Step 3. Get 18 bits as a number. */
for (j = 0; j < DILITHIUM_N; j += 4) {
word32 z0 = DILITHIUM_GAMMA1_17 - z[j + 0];
word32 z1 = DILITHIUM_GAMMA1_17 - z[j + 1];
word32 z2 = DILITHIUM_GAMMA1_17 - z[j + 2];
word32 z3 = DILITHIUM_GAMMA1_17 - z[j + 3];
word32 z0 = (word32)(DILITHIUM_GAMMA1_17 - z[j + 0]);
word32 z1 = (word32)(DILITHIUM_GAMMA1_17 - z[j + 1]);
word32 z2 = (word32)(DILITHIUM_GAMMA1_17 - z[j + 2]);
word32 z3 = (word32)(DILITHIUM_GAMMA1_17 - z[j + 3]);
/* 18 bits per number.
* 8 numbers become 9 bytes. (8 * 9 bits = 9 * 8 bits) */
@ -1329,16 +1339,16 @@ static void dilithium_encode_gamma1_17_bits(const sword32* z, byte* s)
s32p[1] = (z1 >> 14) | (z2 << 4) | (z3 << 22);
#endif
#else
s[0] = z0 ;
s[1] = z0 >> 8 ;
s[2] = (z0 >> 16) | (z1 << 2);
s[3] = z1 >> 6 ;
s[4] = (z1 >> 14) | (z2 << 4);
s[5] = z2 >> 4 ;
s[6] = (z2 >> 12) | (z3 << 6);
s[7] = z3 >> 2 ;
s[0] = (byte)( z0 );
s[1] = (byte)( z0 >> 8 );
s[2] = (byte)((z0 >> 16) | (z1 << 2));
s[3] = (byte)( z1 >> 6 );
s[4] = (byte)((z1 >> 14) | (z2 << 4));
s[5] = (byte)( z2 >> 4 );
s[6] = (byte)((z2 >> 12) | (z3 << 6));
s[7] = (byte)( z3 >> 2 );
#endif
s[8] = z3 >> 10 ;
s[8] = (byte)( z3 >> 10 );
/* Move to next place to encode to. */
s += DILITHIUM_GAMMA1_17_ENC_BITS / 2;
}
@ -1372,14 +1382,14 @@ static void dilithium_encode_gamma1_19_bits(const sword32* z, byte* s)
word16* s16p = (word16*)s;
#ifdef WC_64BIT_CPU
word64* s64p = (word64*)s;
s64p[0] = z0 | ((word64)z1 << 20) |
s64p[0] = (word64)z0 | ((word64)z1 << 20) |
((word64)z2 << 40) | ((word64)z3 << 60);
#else
word32* s32p = (word32*)s;
s32p[0] = z0 | (z1 << 20) ;
s32p[1] = (z1 >> 12) | (z2 << 8) | (z3 << 28);
s32p[0] = (word16)( z0 | (z1 << 20) );
s32p[1] = (word16)((z1 >> 12) | (z2 << 8) | (z3 << 28));
#endif
s16p[4] = (z3 >> 4) ;
s16p[4] = (word16)((z3 >> 4) );
#else
s[0] = z0 ;
s[1] = (z0 >> 8) ;
@ -1525,69 +1535,69 @@ static void dilithium_decode_gamma1(const byte* s, int bits, sword32* z)
#ifdef WC_64BIT_CPU
word64 s64_0 = *(const word64*)(s+0);
word64 s64_1 = *(const word64*)(s+9);
z[i+0] = (word32)DILITHIUM_GAMMA1_17 -
( s64_0 & 0x3ffff );
z[i+1] = (word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 18) & 0x3ffff );
z[i+2] = (word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 36) & 0x3ffff );
z[i+3] = (word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 54) | (((word32)s[8]) << 10));
z[i+4] = (word32)DILITHIUM_GAMMA1_17 -
( s64_1 & 0x3ffff );
z[i+5] = (word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 18) & 0x3ffff );
z[i+6] = (word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 36) & 0x3ffff );
z[i+7] = (word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 54) | (((word32)s[17]) << 10));
z[i+0] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s64_0 & 0x3ffff ));
z[i+1] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 18) & 0x3ffff ));
z[i+2] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 36) & 0x3ffff ));
z[i+3] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_0 >> 54) | (((word32)s[8]) << 10)));
z[i+4] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s64_1 & 0x3ffff ));
z[i+5] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 18) & 0x3ffff ));
z[i+6] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 36) & 0x3ffff ));
z[i+7] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s64_1 >> 54) | (((word32)s[17]) << 10)));
#else
word32 s32_0 = ((const word32*)(s+0))[0];
word32 s32_1 = ((const word32*)(s+0))[1];
word32 s32_2 = ((const word32*)(s+9))[0];
word32 s32_3 = ((const word32*)(s+9))[1];
z[i+0] = (word32)DILITHIUM_GAMMA1_17 -
( s32_0 & 0x3ffff );
z[i+1] = (word32)DILITHIUM_GAMMA1_17 -
((s32_0 >> 18) | (((s32_1 & 0x0000f) << 14)));
z[i+2] = (word32)DILITHIUM_GAMMA1_17 -
((s32_1 >> 4) & 0x3ffff);
z[i+3] = (word32)DILITHIUM_GAMMA1_17 -
((s32_1 >> 22) | (((word32)s[8]) << 10 ));
z[i+4] = (word32)DILITHIUM_GAMMA1_17 -
( s32_2 & 0x3ffff );
z[i+5] = (word32)DILITHIUM_GAMMA1_17 -
((s32_2 >> 18) | (((s32_3 & 0x0000f) << 14)));
z[i+6] = (word32)DILITHIUM_GAMMA1_17 -
((s32_3 >> 4) & 0x3ffff);
z[i+7] = (word32)DILITHIUM_GAMMA1_17 -
((s32_3 >> 22) | (((word32)s[17]) << 10 ));
z[i+0] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s32_0 & 0x3ffff ));
z[i+1] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_0 >> 18) | (((s32_1 & 0x0000f) << 14))));
z[i+2] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_1 >> 4) & 0x3ffff ));
z[i+3] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_1 >> 22) | (((word32)s[8]) << 10 )));
z[i+4] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s32_2 & 0x3ffff ));
z[i+5] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_2 >> 18) | (((s32_3 & 0x0000f) << 14))));
z[i+6] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_3 >> 4) & 0x3ffff ));
z[i+7] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s32_3 >> 22) | (((word32)s[17]) << 10 )));
#endif
#else
z[i+0] = DILITHIUM_GAMMA1_17 -
( s[ 0] | ((sword32)(s[ 1] << 8) |
(sword32)(s[ 2] & 0x03) << 16));
z[i+1] = DILITHIUM_GAMMA1_17 -
((s[ 2] >> 2) | ((sword32)(s[ 3] << 6) |
(sword32)(s[ 4] & 0x0f) << 14));
z[i+2] = DILITHIUM_GAMMA1_17 -
((s[ 4] >> 4) | ((sword32)(s[ 5] << 4) |
(sword32)(s[ 6] & 0x3f) << 12));
z[i+3] = DILITHIUM_GAMMA1_17 -
((s[ 6] >> 6) | ((sword32)(s[ 7] << 2) |
(sword32)(s[ 8] ) << 10));
z[i+4] = DILITHIUM_GAMMA1_17 -
( s[ 9] | ((sword32)(s[10] << 8) |
(sword32)(s[11] & 0x03) << 16));
z[i+5] = DILITHIUM_GAMMA1_17 -
((s[11] >> 2) | ((sword32)(s[12] << 6) |
(sword32)(s[13] & 0x0f) << 14));
z[i+6] = DILITHIUM_GAMMA1_17 -
((s[13] >> 4) | ((sword32)(s[14] << 4) |
(sword32)(s[15] & 0x3f) << 12));
z[i+7] = DILITHIUM_GAMMA1_17 -
((s[15] >> 6) | ((sword32)(s[16] << 2) |
(sword32)(s[17] ) << 10));
z[i+0] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s[ 0] | ((sword32)(s[ 1] << 8) |
(sword32)(s[ 2] & 0x03) << 16)));
z[i+1] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[ 2] >> 2) | ((sword32)(s[ 3] << 6) |
(sword32)(s[ 4] & 0x0f) << 14)));
z[i+2] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[ 4] >> 4) | ((sword32)(s[ 5] << 4) |
(sword32)(s[ 6] & 0x3f) << 12)));
z[i+3] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[ 6] >> 6) | ((sword32)(s[ 7] << 2) |
(sword32)(s[ 8] ) << 10)));
z[i+4] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
( s[ 9] | ((sword32)(s[10] << 8) |
(sword32)(s[11] & 0x03) << 16)));
z[i+5] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[11] >> 2) | ((sword32)(s[12] << 6) |
(sword32)(s[13] & 0x0f) << 14)));
z[i+6] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[13] >> 4) | ((sword32)(s[14] << 4) |
(sword32)(s[15] & 0x3f) << 12)));
z[i+7] = (sword32)((word32)DILITHIUM_GAMMA1_17 -
((s[15] >> 6) | ((sword32)(s[16] << 2) |
(sword32)(s[17] ) << 10)));
#endif
/* Move to next place to decode from. */
s += DILITHIUM_GAMMA1_17_ENC_BITS;
@ -1646,16 +1656,24 @@ static void dilithium_decode_gamma1(const byte* s, int bits, sword32* z)
#ifdef WC_64BIT_CPU
word64 s64_0 = *(const word64*)(s+0);
word64 s64_1 = *(const word64*)(s+10);
z[i+0] = DILITHIUM_GAMMA1_19 - ( s64_0 & 0xfffff) ;
z[i+1] = DILITHIUM_GAMMA1_19 - ( (s64_0 >> 20) & 0xfffff) ;
z[i+2] = DILITHIUM_GAMMA1_19 - ( (s64_0 >> 40) & 0xfffff) ;
z[i+3] = DILITHIUM_GAMMA1_19 - (((s64_0 >> 60) & 0xfffff) |
((sword32)s16_0 << 4));
z[i+4] = DILITHIUM_GAMMA1_19 - ( s64_1 & 0xfffff) ;
z[i+5] = DILITHIUM_GAMMA1_19 - ( (s64_1 >> 20) & 0xfffff) ;
z[i+6] = DILITHIUM_GAMMA1_19 - ( (s64_1 >> 40) & 0xfffff) ;
z[i+7] = DILITHIUM_GAMMA1_19 - (((s64_1 >> 60) & 0xfffff) |
((sword32)s16_1 << 4));
z[i+0] = DILITHIUM_GAMMA1_19 -
((sword32)( s64_0 & 0xfffff)) ;
z[i+1] = DILITHIUM_GAMMA1_19 -
((sword32)( (s64_0 >> 20) & 0xfffff)) ;
z[i+2] = DILITHIUM_GAMMA1_19 -
((sword32)( (s64_0 >> 40) & 0xfffff)) ;
z[i+3] = DILITHIUM_GAMMA1_19 -
((sword32)(((s64_0 >> 60) & 0xfffff)) |
((sword32)s16_0 << 4));
z[i+4] = DILITHIUM_GAMMA1_19 -
((sword32)( s64_1 & 0xfffff)) ;
z[i+5] = DILITHIUM_GAMMA1_19 -
((sword32)( (s64_1 >> 20) & 0xfffff)) ;
z[i+6] = DILITHIUM_GAMMA1_19 -
((sword32)( (s64_1 >> 40) & 0xfffff)) ;
z[i+7] = DILITHIUM_GAMMA1_19 -
((sword32)(((s64_1 >> 60) & 0xfffff)) |
((sword32)s16_1 << 4));
#else
word32 s32_0 = ((const word32*)(s+ 0))[0];
word32 s32_1 = ((const word32*)(s+ 0))[1];
@ -1767,28 +1785,28 @@ static void dilithium_encode_w1_88(const sword32* w1, byte* w1e)
* 16 numbers in 12 bytes. (16 * 6 bits = 12 * 8 bits) */
#if defined(LITTLE_ENDIAN_ORDER) && (WOLFSSL_DILITHIUM_ALIGNMENT <= 4)
word32* w1e32 = (word32*)w1e;
w1e32[0] = w1[j+ 0] | (w1[j+ 1] << 6) |
(w1[j+ 2] << 12) | (w1[j+ 3] << 18) |
(w1[j+ 4] << 24) | (w1[j+ 5] << 30);
w1e32[1] = (w1[j+ 5] >> 2) | (w1[j+ 6] << 4) |
(w1[j+ 7] << 10) | (w1[j+ 8] << 16) |
(w1[j+ 9] << 22) | (w1[j+10] << 28);
w1e32[2] = (w1[j+10] >> 4) | (w1[j+11] << 2) |
(w1[j+12] << 8) | (w1[j+13] << 14) |
(w1[j+14] << 20) | (w1[j+15] << 26);
w1e32[0] = (word32)( w1[j+ 0] | (w1[j+ 1] << 6) |
(w1[j+ 2] << 12) | (w1[j+ 3] << 18) |
(w1[j+ 4] << 24) | (w1[j+ 5] << 30));
w1e32[1] = (word32)((w1[j+ 5] >> 2) | (w1[j+ 6] << 4) |
(w1[j+ 7] << 10) | (w1[j+ 8] << 16) |
(w1[j+ 9] << 22) | (w1[j+10] << 28));
w1e32[2] = (word32)((w1[j+10] >> 4) | (w1[j+11] << 2) |
(w1[j+12] << 8) | (w1[j+13] << 14) |
(w1[j+14] << 20) | (w1[j+15] << 26));
#else
w1e[ 0] = w1[j+ 0] | (w1[j+ 1] << 6);
w1e[ 1] = (w1[j+ 1] >> 2) | (w1[j+ 2] << 4);
w1e[ 2] = (w1[j+ 2] >> 4) | (w1[j+ 3] << 2);
w1e[ 3] = w1[j+ 4] | (w1[j+ 5] << 6);
w1e[ 4] = (w1[j+ 5] >> 2) | (w1[j+ 6] << 4);
w1e[ 5] = (w1[j+ 6] >> 4) | (w1[j+ 7] << 2);
w1e[ 6] = w1[j+ 8] | (w1[j+ 9] << 6);
w1e[ 7] = (w1[j+ 9] >> 2) | (w1[j+10] << 4);
w1e[ 8] = (w1[j+10] >> 4) | (w1[j+11] << 2);
w1e[ 9] = w1[j+12] | (w1[j+13] << 6);
w1e[10] = (w1[j+13] >> 2) | (w1[j+14] << 4);
w1e[11] = (w1[j+14] >> 4) | (w1[j+15] << 2);
w1e[ 0] = (byte)( w1[j+ 0] | (w1[j+ 1] << 6));
w1e[ 1] = (byte)((w1[j+ 1] >> 2) | (w1[j+ 2] << 4));
w1e[ 2] = (byte)((w1[j+ 2] >> 4) | (w1[j+ 3] << 2));
w1e[ 3] = (byte)( w1[j+ 4] | (w1[j+ 5] << 6));
w1e[ 4] = (byte)((w1[j+ 5] >> 2) | (w1[j+ 6] << 4));
w1e[ 5] = (byte)((w1[j+ 6] >> 4) | (w1[j+ 7] << 2));
w1e[ 6] = (byte)( w1[j+ 8] | (w1[j+ 9] << 6));
w1e[ 7] = (byte)((w1[j+ 9] >> 2) | (w1[j+10] << 4));
w1e[ 8] = (byte)((w1[j+10] >> 4) | (w1[j+11] << 2));
w1e[ 9] = (byte)( w1[j+12] | (w1[j+13] << 6));
w1e[10] = (byte)((w1[j+13] >> 2) | (w1[j+14] << 4));
w1e[11] = (byte)((w1[j+14] >> 4) | (w1[j+15] << 2));
#endif
/* Move to next place to encode to. */
w1e += DILITHIUM_Q_HI_88_ENC_BITS * 2;
@ -1819,23 +1837,23 @@ static void dilithium_encode_w1_32(const sword32* w1, byte* w1e)
* 16 numbers in 8 bytes. (16 * 4 bits = 8 * 8 bits) */
#if defined(LITTLE_ENDIAN_ORDER) && (WOLFSSL_DILITHIUM_ALIGNMENT <= 8)
word32* w1e32 = (word32*)w1e;
w1e32[0] = (w1[j + 0] << 0) | (w1[j + 1] << 4) |
(w1[j + 2] << 8) | (w1[j + 3] << 12) |
(w1[j + 4] << 16) | (w1[j + 5] << 20) |
(w1[j + 6] << 24) | (w1[j + 7] << 28);
w1e32[1] = (w1[j + 8] << 0) | (w1[j + 9] << 4) |
(w1[j + 10] << 8) | (w1[j + 11] << 12) |
(w1[j + 12] << 16) | (w1[j + 13] << 20) |
(w1[j + 14] << 24) | (w1[j + 15] << 28);
w1e32[0] = (word32)((w1[j + 0] << 0) | (w1[j + 1] << 4) |
(w1[j + 2] << 8) | (w1[j + 3] << 12) |
(w1[j + 4] << 16) | (w1[j + 5] << 20) |
(w1[j + 6] << 24) | (w1[j + 7] << 28));
w1e32[1] = (word32)((w1[j + 8] << 0) | (w1[j + 9] << 4) |
(w1[j + 10] << 8) | (w1[j + 11] << 12) |
(w1[j + 12] << 16) | (w1[j + 13] << 20) |
(w1[j + 14] << 24) | (w1[j + 15] << 28));
#else
w1e[0] = w1[j + 0] | (w1[j + 1] << 4);
w1e[1] = w1[j + 2] | (w1[j + 3] << 4);
w1e[2] = w1[j + 4] | (w1[j + 5] << 4);
w1e[3] = w1[j + 6] | (w1[j + 7] << 4);
w1e[4] = w1[j + 8] | (w1[j + 9] << 4);
w1e[5] = w1[j + 10] | (w1[j + 11] << 4);
w1e[6] = w1[j + 12] | (w1[j + 13] << 4);
w1e[7] = w1[j + 14] | (w1[j + 15] << 4);
w1e[0] = (byte)(w1[j + 0] | (w1[j + 1] << 4));
w1e[1] = (byte)(w1[j + 2] | (w1[j + 3] << 4));
w1e[2] = (byte)(w1[j + 4] | (w1[j + 5] << 4));
w1e[3] = (byte)(w1[j + 6] | (w1[j + 7] << 4));
w1e[4] = (byte)(w1[j + 8] | (w1[j + 9] << 4));
w1e[5] = (byte)(w1[j + 10] | (w1[j + 11] << 4));
w1e[6] = (byte)(w1[j + 12] | (w1[j + 13] << 4));
w1e[7] = (byte)(w1[j + 14] | (w1[j + 15] << 4));
#endif
/* Move to next place to encode to. */
w1e += DILITHIUM_Q_HI_32_ENC_BITS * 2;
@ -2802,10 +2820,12 @@ static int dilithium_sample_in_ball_ex(int level, wc_Shake* shake256,
const byte* seed, word32 seedLen, byte tau, sword32* c, byte* block)
{
int ret = 0;
unsigned int k;
unsigned int i;
unsigned int s;
byte signs[DILITHIUM_SIGN_BYTES];
unsigned int i;
/* Step 1: Initialize sign bit index. */
unsigned int s = 0;
/* Step 2: First 8 bytes are used for sign. */
unsigned int k = DILITHIUM_SIGN_BYTES;
if (ret == 0) {
/* Set polynomial to all zeros. */
@ -2828,10 +2848,6 @@ static int dilithium_sample_in_ball_ex(int level, wc_Shake* shake256,
if (ret == 0) {
/* Copy first 8 bytes of first hash block as random sign bits. */
XMEMCPY(signs, block, DILITHIUM_SIGN_BYTES);
/* Step 1: Initialize sign bit index. */
s = 0;
/* Step 2: First 8 bytes are used for sign. */
k = DILITHIUM_SIGN_BYTES;
}
/* Step 3: Put in TAU +/- 1s. */
@ -3354,7 +3370,7 @@ static int dilithium_make_hint_32(const sword32* s, const sword32* w1,
* return Falsam of -1 when too many hints.
*/
static int dilithium_make_hint(const sword32* s, const sword32* w1, byte k,
word32 gamma2, byte omega, byte* h)
sword32 gamma2, byte omega, byte* h)
{
unsigned int i;
byte idx = 0;
@ -3509,12 +3525,12 @@ static void dilithium_use_hint_88(sword32* w1, const byte* h, unsigned int i,
w1[j] = r1 + hint;
/* Fix up w1 to not be 44 but 0. */
w1[j] &= 0 - (((word32)(w1[j] - 44)) >> 31);
w1[j] &= (sword32)(0 - (((word32)(w1[j] - 44)) >> 31));
/* Hint may have reduced 0 to -1 which is actually 43. */
w1[j] += (0 - (((word32)w1[j]) >> 31)) & 44;
w1[j] += (sword32)((0 - (((word32)w1[j]) >> 31)) & 44);
#else
/* Convert value to positive only range. */
r = w1[j] + ((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
r = w1[j] + (sword32)((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
/* Decompose value into low and high parts. */
dilithium_decompose_q88(r, &r0, &r1);
/* Check for hint. */
@ -3570,11 +3586,11 @@ static void dilithium_use_hint_32(sword32* w1, const byte* h, byte omega,
/* Increment hint offset if this index has hint. */
o += hint;
/* Convert value to positive only range. */
r = w1[j] + ((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
r = w1[j] + (sword32)((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
/* Decompose value into low and high parts. */
dilithium_decompose_q32(r, &r0, &r1);
/* Make hint positive or negative based on sign of r0. */
hint = (1 - (2 * (((word32)r0) >> 31))) & (0 - hint);
hint = (sword32)((1 - (2 * (((word32)r0) >> 31))) & (0 - hint));
/* Make w1 only the top part plus the hint. */
w1[j] = r1 + hint;
@ -3582,13 +3598,13 @@ static void dilithium_use_hint_32(sword32* w1, const byte* h, byte omega,
w1[j] &= 0xf;
#else
/* Convert value to positive only range. */
r = w1[j] + ((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
r = w1[j] + (sword32)((0 - (((word32)w1[j]) >> 31)) & DILITHIUM_Q);
/* Decompose value into low and high parts. */
dilithium_decompose_q32(r, &r0, &r1);
/* Check for hint. */
if ((o < h[omega + i]) && (h[o] == (byte)j)) {
/* Add or subtract hint based on sign of r0. */
r1 += 1 - (2 * (((word32)r0) >> 31));
r1 += (sword32)(1 - (2 * (((word32)r0) >> 31)));
/* Go to next hint offset. */
o++;
}
@ -3616,7 +3632,7 @@ static void dilithium_use_hint_32(sword32* w1, const byte* h, byte omega,
* @param [in] omega Max number of hints. Hint counts after this index.
* @param [in] h Hints to apply. In signature encoding.
*/
static void dilithium_vec_use_hint(sword32* w1, byte k, word32 gamma2,
static void dilithium_vec_use_hint(sword32* w1, byte k, sword32 gamma2,
byte omega, const byte* h)
{
unsigned int i;
@ -3883,7 +3899,7 @@ static void dilithium_ntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta32 = zetas[ 4 + j / 64 + 0];
sword32 zeta160 = zetas[ 8 + j / 32 + 0];
sword32 zeta161 = zetas[ 8 + j / 32 + 1];
@ -3915,7 +3931,7 @@ static void dilithium_ntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 16) {
int i;
unsigned int i;
sword32 zeta8 = zetas[16 + j / 16];
sword32 zeta40 = zetas[32 + j / 8 + 0];
sword32 zeta41 = zetas[32 + j / 8 + 1];
@ -4031,7 +4047,7 @@ static void dilithium_ntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta32 = zetas[ 4 + j / 64 + 0];
sword32 zeta160 = zetas[ 8 + j / 32 + 0];
sword32 zeta161 = zetas[ 8 + j / 32 + 1];
@ -4254,7 +4270,7 @@ static void dilithium_ntt_small(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta32 = zetas[ 4 + j / 64 + 0];
sword32 zeta160 = zetas[ 8 + j / 32 + 0];
sword32 zeta161 = zetas[ 8 + j / 32 + 1];
@ -4286,7 +4302,7 @@ static void dilithium_ntt_small(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 16) {
int i;
unsigned int i;
sword32 zeta8 = zetas[16 + j / 16];
sword32 zeta40 = zetas[32 + j / 8 + 0];
sword32 zeta41 = zetas[32 + j / 8 + 1];
@ -4398,7 +4414,7 @@ static void dilithium_ntt_small(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta32 = zetas[ 4 + j / 64 + 0];
sword32 zeta160 = zetas[ 8 + j / 32 + 0];
sword32 zeta161 = zetas[ 8 + j / 32 + 1];
@ -4686,7 +4702,7 @@ static void dilithium_invntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 16) {
int i;
unsigned int i;
sword32 zeta40 = zetas_inv[192 + j / 8 + 0];
sword32 zeta41 = zetas_inv[192 + j / 8 + 1];
sword32 zeta8 = zetas_inv[224 + j / 16 + 0];
@ -4718,7 +4734,7 @@ static void dilithium_invntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta160 = zetas_inv[240 + j / 32 + 0];
sword32 zeta161 = zetas_inv[240 + j / 32 + 1];
sword32 zeta32 = zetas_inv[248 + j / 64 + 0];
@ -4858,7 +4874,7 @@ static void dilithium_invntt(sword32* r)
}
for (j = 0; j < DILITHIUM_N; j += 64) {
int i;
unsigned int i;
sword32 zeta80 = zetas_inv[224 + j / 16 + 0];
sword32 zeta81 = zetas_inv[224 + j / 16 + 1];
sword32 zeta82 = zetas_inv[224 + j / 16 + 2];
@ -7066,7 +7082,7 @@ static int dilithium_sign_ctx_hash_with_seed(dilithium_key* key,
byte seedMu[DILITHIUM_RND_SZ + DILITHIUM_MU_SZ];
byte* mu = seedMu + DILITHIUM_RND_SZ;
byte oidMsgHash[DILITHIUM_HASH_OID_LEN + WC_MAX_DIGEST_SIZE];
word32 oidMsgHashLen;
word32 oidMsgHashLen = 0;
if ((ret == 0) && (hashLen > WC_MAX_DIGEST_SIZE)) {
ret = BUFFER_E;
@ -7649,7 +7665,7 @@ static int dilithium_verify_ctx_msg(dilithium_key* key, const byte* ctx,
if (ret == 0) {
/* Step 6. Calculate mu. */
ret = dilithium_hash256_ctx_msg(&key->shake, tr, DILITHIUM_TR_SZ, 0,
ctx, ctxLen, msg, msgLen, mu, DILITHIUM_MU_SZ);
ctx, (byte)ctxLen, msg, msgLen, mu, DILITHIUM_MU_SZ);
}
if (ret == 0) {
ret = dilithium_verify_mu(key, mu, sig, sigLen, res);
@ -7727,7 +7743,7 @@ static int dilithium_verify_ctx_hash(dilithium_key* key, const byte* ctx,
byte tr[DILITHIUM_TR_SZ];
byte* mu = tr;
byte oidMsgHash[DILITHIUM_HASH_OID_LEN + WC_MAX_DIGEST_SIZE];
word32 oidMsgHashLen;
word32 oidMsgHashLen = 0;
if (key == NULL) {
ret = BAD_FUNC_ARG;
@ -7747,7 +7763,7 @@ static int dilithium_verify_ctx_hash(dilithium_key* key, const byte* ctx,
/* Step 6. Calculate mu. */
ret = dilithium_hash256_ctx_msg(&key->shake, tr, DILITHIUM_TR_SZ, 1,
ctx, ctxLen, oidMsgHash, oidMsgHashLen, mu, DILITHIUM_MU_SZ);
ctx, (byte)ctxLen, oidMsgHash, oidMsgHashLen, mu, DILITHIUM_MU_SZ);
}
if (ret == 0) {
ret = dilithium_verify_mu(key, mu, sig, sigLen, res);
@ -8934,7 +8950,7 @@ int wc_dilithium_check_key(dilithium_key* key)
{
int ret = 0;
#ifdef WOLFSSL_WC_DILITHIUM
const wc_dilithium_params* params;
const wc_dilithium_params* params = NULL;
sword32* a = NULL;
sword32* s1 = NULL;
sword32* s2 = NULL;
@ -9491,7 +9507,7 @@ int wc_dilithium_export_private(dilithium_key* key, byte* out,
word32* outLen)
{
int ret = 0;
word32 inLen;
word32 inLen = 0;
/* Validate parameters. */
if ((key == NULL) || (out == NULL) || (outLen == NULL)) {
@ -9584,7 +9600,7 @@ int wc_dilithium_export_key(dilithium_key* key, byte* priv, word32 *privSz,
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
/* Maps ASN.1 OID to wolfCrypt security level macros */
static int mapOidToSecLevel(word32 oid)
static int mapOidToSecLevel(int oid)
{
switch (oid) {
case ML_DSA_LEVEL2k:
@ -9672,7 +9688,7 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
const byte* pubKey = NULL;
word32 privKeyLen = 0;
word32 pubKeyLen = 0;
int keytype = 0;
int keyType = 0;
/* Validate parameters. */
if ((input == NULL) || (inOutIdx == NULL) || (key == NULL) || (inSz == 0)) {
@ -9684,30 +9700,30 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
if (key->level == 0) { /* Check first, because key->params will be NULL
* when key->level = 0 */
/* Level not set by caller, decode from DER */
keytype = ANONk;
keyType = ANONk;
}
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
else if (key->params == NULL) {
ret = BAD_FUNC_ARG;
}
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
keytype = DILITHIUM_LEVEL2k;
keyType = DILITHIUM_LEVEL2k;
}
else if (key->params->level == WC_ML_DSA_65_DRAFT) {
keytype = DILITHIUM_LEVEL3k;
keyType = DILITHIUM_LEVEL3k;
}
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
keytype = DILITHIUM_LEVEL5k;
keyType = DILITHIUM_LEVEL5k;
}
#endif
else if (key->level == WC_ML_DSA_44) {
keytype = ML_DSA_LEVEL2k;
keyType = ML_DSA_LEVEL2k;
}
else if (key->level == WC_ML_DSA_65) {
keytype = ML_DSA_LEVEL3k;
keyType = ML_DSA_LEVEL3k;
}
else if (key->level == WC_ML_DSA_87) {
keytype = ML_DSA_LEVEL5k;
keyType = ML_DSA_LEVEL5k;
}
else {
ret = BAD_FUNC_ARG;
@ -9718,16 +9734,16 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
/* Decode the asymmetric key and get out private and public key data. */
ret = DecodeAsymKey_Assign(input, inOutIdx, inSz,
&privKey, &privKeyLen,
&pubKey, &pubKeyLen, &keytype);
&pubKey, &pubKeyLen, &keyType);
if (ret == 0
#ifdef WOLFSSL_WC_DILITHIUM
&& key->params == NULL
#endif
) {
/* Set the security level based on the decoded key. */
ret = mapOidToSecLevel(keytype);
ret = mapOidToSecLevel(keyType);
if (ret > 0) {
ret = wc_dilithium_set_level(key, ret);
ret = wc_dilithium_set_level(key, (byte)ret);
}
}
}
@ -9941,7 +9957,7 @@ int wc_Dilithium_PublicKeyDecode(const byte* input, word32* inOutIdx,
dilithium_key* key, word32 inSz)
{
int ret = 0;
const byte* pubKey;
const byte* pubKey = NULL;
word32 pubKeyLen = 0;
/* Validate parameters. */
@ -9954,7 +9970,7 @@ int wc_Dilithium_PublicKeyDecode(const byte* input, word32* inOutIdx,
ret = wc_dilithium_import_public(input, inSz, key);
if (ret != 0) {
#if !defined(WOLFSSL_DILITHIUM_NO_ASN1)
int keytype = 0;
int keyType = 0;
#else
int length;
unsigned char* oid;
@ -9972,43 +9988,43 @@ int wc_Dilithium_PublicKeyDecode(const byte* input, word32* inOutIdx,
ret = BAD_FUNC_ARG;
}
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
keytype = DILITHIUM_LEVEL2k;
keyType = DILITHIUM_LEVEL2k;
}
else if (key->params->level == WC_ML_DSA_65_DRAFT) {
keytype = DILITHIUM_LEVEL3k;
keyType = DILITHIUM_LEVEL3k;
}
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
keytype = DILITHIUM_LEVEL5k;
keyType = DILITHIUM_LEVEL5k;
}
else
#endif
if (key->level == WC_ML_DSA_44) {
keytype = ML_DSA_LEVEL2k;
keyType = ML_DSA_LEVEL2k;
}
else if (key->level == WC_ML_DSA_65) {
keytype = ML_DSA_LEVEL3k;
keyType = ML_DSA_LEVEL3k;
}
else if (key->level == WC_ML_DSA_87) {
keytype = ML_DSA_LEVEL5k;
keyType = ML_DSA_LEVEL5k;
}
else {
/* Level not set by caller, decode from DER */
keytype = ANONk; /* 0, not a valid key type in this situation*/
keyType = ANONk; /* 0, not a valid key type in this situation*/
}
if (ret == 0) {
/* Decode the asymmetric key and get out public key data. */
ret = DecodeAsymKeyPublic_Assign(input, inOutIdx, inSz,
&pubKey, &pubKeyLen,
&keytype);
&keyType);
if (ret == 0
#ifdef WOLFSSL_WC_DILITHIUM
&& key->params == NULL
#endif
) {
/* Set the security level based on the decoded key. */
ret = mapOidToSecLevel(keytype);
ret = mapOidToSecLevel(keyType);
if (ret > 0) {
ret = wc_dilithium_set_level(key, ret);
ret = wc_dilithium_set_level(key, (byte)ret);
}
}
}
@ -10140,8 +10156,8 @@ int wc_Dilithium_PublicKeyToDer(dilithium_key* key, byte* output, word32 len,
int withAlg)
{
int ret = 0;
int keytype = 0;
int pubKeyLen = 0;
int keyType = 0;
word32 pubKeyLen = 0;
/* Validate parameters. */
if (key == NULL) {
@ -10159,29 +10175,29 @@ int wc_Dilithium_PublicKeyToDer(dilithium_key* key, byte* output, word32 len,
ret = BAD_FUNC_ARG;
}
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
keytype = DILITHIUM_LEVEL2k;
keyType = DILITHIUM_LEVEL2k;
pubKeyLen = DILITHIUM_LEVEL2_PUB_KEY_SIZE;
}
else if (key->params->level == WC_ML_DSA_65_DRAFT) {
keytype = DILITHIUM_LEVEL3k;
keyType = DILITHIUM_LEVEL3k;
pubKeyLen = DILITHIUM_LEVEL3_PUB_KEY_SIZE;
}
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
keytype = DILITHIUM_LEVEL5k;
keyType = DILITHIUM_LEVEL5k;
pubKeyLen = DILITHIUM_LEVEL5_PUB_KEY_SIZE;
}
else
#endif
if (key->level == WC_ML_DSA_44) {
keytype = ML_DSA_LEVEL2k;
keyType = ML_DSA_LEVEL2k;
pubKeyLen = ML_DSA_LEVEL2_PUB_KEY_SIZE;
}
else if (key->level == WC_ML_DSA_65) {
keytype = ML_DSA_LEVEL3k;
keyType = ML_DSA_LEVEL3k;
pubKeyLen = ML_DSA_LEVEL3_PUB_KEY_SIZE;
}
else if (key->level == WC_ML_DSA_87) {
keytype = ML_DSA_LEVEL5k;
keyType = ML_DSA_LEVEL5k;
pubKeyLen = ML_DSA_LEVEL5_PUB_KEY_SIZE;
}
else {
@ -10191,7 +10207,7 @@ int wc_Dilithium_PublicKeyToDer(dilithium_key* key, byte* output, word32 len,
}
if (ret == 0) {
ret = SetAsymKeyDerPublic(key->p, pubKeyLen, output, len, keytype,
ret = SetAsymKeyDerPublic(key->p, pubKeyLen, output, len, keyType,
withAlg);
}

View File

@ -651,7 +651,7 @@ typedef struct wc_dilithium_params {
byte omega;
word16 lambda;
byte gamma1_bits;
word32 gamma2;
sword32 gamma2;
word32 w1EncSz;
word16 aSz;
word16 s1Sz;