diff --git a/wolfcrypt/src/sp_int.c b/wolfcrypt/src/sp_int.c index 30d3d01a8..adc968c93 100644 --- a/wolfcrypt/src/sp_int.c +++ b/wolfcrypt/src/sp_int.c @@ -4768,11 +4768,17 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r) sp_int_word w; sp_int_word l; sp_int_word h; + #ifdef SP_WORD_OVERFLOW + sp_int_word o; + #endif w = (sp_int_word)a->dp[0] * b->dp[0]; t->dp[0] = (sp_int_digit)w; l = (sp_int_digit)(w >> SP_WORD_SIZE); h = 0; + #ifdef SP_WORD_OVERFLOW + o = 0; + #endif for (k = 1; k <= (a->used - 1) + (b->used - 1); k++) { i = k - (b->used - 1); i &= ~(i >> (sizeof(i) * 8 - 1)); @@ -4781,11 +4787,21 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r) w = (sp_int_word)a->dp[i] * b->dp[j]; l += (sp_int_digit)w; h += (sp_int_digit)(w >> SP_WORD_SIZE); + #ifdef SP_WORD_OVERFLOW + h += (sp_int_digit)(l >> SP_WORD_SIZE); + l &= SP_MASK; + o += (sp_int_digit)(h >> SP_WORD_SIZE); + h &= SP_MASK; + #endif } t->dp[k] = (sp_int_digit)l; l >>= SP_WORD_SIZE; l += (sp_int_digit)h; h >>= SP_WORD_SIZE; + #ifdef SP_WORD_OVERFLOW + h += o & SP_MASK; + o >>= SP_WORD_SIZE; + #endif } t->dp[k] = (sp_int_digit)l; t->dp[k+1] = (sp_int_digit)h; @@ -9512,12 +9528,20 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r) #endif if (err == MP_OKAY) { - sp_int_word w, l, h; + sp_int_word w; + sp_int_word l; + sp_int_word h; + #ifdef SP_WORD_OVERFLOW + sp_int_word o; + #endif w = (sp_int_word)a->dp[0] * a->dp[0]; t->dp[0] = (sp_int_digit)w; l = (sp_int_digit)(w >> SP_WORD_SIZE); h = 0; + #ifdef SP_WORD_OVERFLOW + o = 0; + #endif for (k = 1; k <= (a->used - 1) * 2; k++) { i = k / 2; j = k - i; @@ -9525,18 +9549,40 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r) w = (sp_int_word)a->dp[i] * a->dp[j]; l += (sp_int_digit)w; h += (sp_int_digit)(w >> SP_WORD_SIZE); + #ifdef SP_WORD_OVERFLOW + h += (sp_int_digit)(l >> SP_WORD_SIZE); + l &= SP_MASK; + o += (sp_int_digit)(h >> SP_WORD_SIZE); + h &= SP_MASK; + #endif } for (++i, --j; (i < a->used) && (j >= 0); i++, j--) { w = (sp_int_word)a->dp[i] * a->dp[j]; l += (sp_int_digit)w; h += (sp_int_digit)(w >> SP_WORD_SIZE); + #ifdef SP_WORD_OVERFLOW + h += (sp_int_digit)(l >> SP_WORD_SIZE); + l &= SP_MASK; + o += (sp_int_digit)(h >> SP_WORD_SIZE); + h &= SP_MASK; + #endif l += (sp_int_digit)w; h += (sp_int_digit)(w >> SP_WORD_SIZE); + #ifdef SP_WORD_OVERFLOW + h += (sp_int_digit)(l >> SP_WORD_SIZE); + l &= SP_MASK; + o += (sp_int_digit)(h >> SP_WORD_SIZE); + h &= SP_MASK; + #endif } t->dp[k] = (sp_int_digit)l; l >>= SP_WORD_SIZE; l += (sp_int_digit)h; h >>= SP_WORD_SIZE; + #ifdef SP_WORD_OVERFLOW + h += o & SP_MASK; + o >>= SP_WORD_SIZE; + #endif } t->dp[k] = (sp_int_digit)l; t->dp[k+1] = (sp_int_digit)h; diff --git a/wolfcrypt/src/wolfmath.c b/wolfcrypt/src/wolfmath.c index 40245ffd7..79a992d2b 100644 --- a/wolfcrypt/src/wolfmath.c +++ b/wolfcrypt/src/wolfmath.c @@ -99,7 +99,11 @@ int mp_cond_copy(mp_int* a, int copy, mp_int* b) { int err = MP_OKAY; int i; +#if defined(SP_WORD_SIZE) && SP_WORD_SIZE == 8 + unsigned int mask = (unsigned int)0 - copy; +#else mp_digit mask = (mp_digit)0 - copy; +#endif if (a == NULL || b == NULL) err = BAD_FUNC_ARG; diff --git a/wolfssl/wolfcrypt/sp_int.h b/wolfssl/wolfcrypt/sp_int.h index 96133133b..ccb52af78 100644 --- a/wolfssl/wolfcrypt/sp_int.h +++ b/wolfssl/wolfcrypt/sp_int.h @@ -376,6 +376,25 @@ typedef struct sp_ecc_ctx { #define SP_INT_MAX_BITS (SP_INT_DIGITS * SP_WORD_SIZE) #endif +#if SP_WORD_SIZE < 32 + /* Maximum number of digits in a number to mul or sqr. */ + #define SP_MUL_SQR_DIGITS (SP_INT_MAX_BITS / 2 / SP_WORD_SIZE) + /* Maximum value of partial in mul/sqr. */ + #define SP_MUL_SQR_MAX_PARTIAL \ + (SP_MUL_SQR_DIGITS * ((1 << SP_WORD_SIZE) - 1)) + /* Maximim value in an sp_int_word. */ + #define SP_INT_WORD_MAX ((1 << (SP_WORD_SIZE * 2)) - 1) + + #if SP_MUL_SQR_MAX_PARTIAL > SP_INT_WORD_MAX + /* The sum of the partials in the multiplicaiton/square can exceed the + * size of a word. This will overflow the word and loose data. + * Use an implementation that handles carry after every add and uses an + * extra temporary word for overflowing high word. + */ + #define SP_WORD_OVERFLOW + #endif +#endif + /* For debugging only - format string for different digit sizes. */ #if SP_WORD_SIZE == 64