SP int: fixes for 8-bit digits

Fix mask type in mp_cond_copy to be at least 16 bits to handle 'used'
being larger than 8-bit but mp_digit being 8-bit.
When large numbers are used with 8-bit words, mul/sqr partial sums will
overflow a word. Fix implementations to handle this.
pull/3734/head
Sean Parkinson 2021-02-08 12:24:28 +10:00
parent 2933db8915
commit 7986b37aa5
3 changed files with 70 additions and 1 deletions

View File

@ -4768,11 +4768,17 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
sp_int_word w;
sp_int_word l;
sp_int_word h;
#ifdef SP_WORD_OVERFLOW
sp_int_word o;
#endif
w = (sp_int_word)a->dp[0] * b->dp[0];
t->dp[0] = (sp_int_digit)w;
l = (sp_int_digit)(w >> SP_WORD_SIZE);
h = 0;
#ifdef SP_WORD_OVERFLOW
o = 0;
#endif
for (k = 1; k <= (a->used - 1) + (b->used - 1); k++) {
i = k - (b->used - 1);
i &= ~(i >> (sizeof(i) * 8 - 1));
@ -4781,11 +4787,21 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
w = (sp_int_word)a->dp[i] * b->dp[j];
l += (sp_int_digit)w;
h += (sp_int_digit)(w >> SP_WORD_SIZE);
#ifdef SP_WORD_OVERFLOW
h += (sp_int_digit)(l >> SP_WORD_SIZE);
l &= SP_MASK;
o += (sp_int_digit)(h >> SP_WORD_SIZE);
h &= SP_MASK;
#endif
}
t->dp[k] = (sp_int_digit)l;
l >>= SP_WORD_SIZE;
l += (sp_int_digit)h;
h >>= SP_WORD_SIZE;
#ifdef SP_WORD_OVERFLOW
h += o & SP_MASK;
o >>= SP_WORD_SIZE;
#endif
}
t->dp[k] = (sp_int_digit)l;
t->dp[k+1] = (sp_int_digit)h;
@ -9512,12 +9528,20 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r)
#endif
if (err == MP_OKAY) {
sp_int_word w, l, h;
sp_int_word w;
sp_int_word l;
sp_int_word h;
#ifdef SP_WORD_OVERFLOW
sp_int_word o;
#endif
w = (sp_int_word)a->dp[0] * a->dp[0];
t->dp[0] = (sp_int_digit)w;
l = (sp_int_digit)(w >> SP_WORD_SIZE);
h = 0;
#ifdef SP_WORD_OVERFLOW
o = 0;
#endif
for (k = 1; k <= (a->used - 1) * 2; k++) {
i = k / 2;
j = k - i;
@ -9525,18 +9549,40 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r)
w = (sp_int_word)a->dp[i] * a->dp[j];
l += (sp_int_digit)w;
h += (sp_int_digit)(w >> SP_WORD_SIZE);
#ifdef SP_WORD_OVERFLOW
h += (sp_int_digit)(l >> SP_WORD_SIZE);
l &= SP_MASK;
o += (sp_int_digit)(h >> SP_WORD_SIZE);
h &= SP_MASK;
#endif
}
for (++i, --j; (i < a->used) && (j >= 0); i++, j--) {
w = (sp_int_word)a->dp[i] * a->dp[j];
l += (sp_int_digit)w;
h += (sp_int_digit)(w >> SP_WORD_SIZE);
#ifdef SP_WORD_OVERFLOW
h += (sp_int_digit)(l >> SP_WORD_SIZE);
l &= SP_MASK;
o += (sp_int_digit)(h >> SP_WORD_SIZE);
h &= SP_MASK;
#endif
l += (sp_int_digit)w;
h += (sp_int_digit)(w >> SP_WORD_SIZE);
#ifdef SP_WORD_OVERFLOW
h += (sp_int_digit)(l >> SP_WORD_SIZE);
l &= SP_MASK;
o += (sp_int_digit)(h >> SP_WORD_SIZE);
h &= SP_MASK;
#endif
}
t->dp[k] = (sp_int_digit)l;
l >>= SP_WORD_SIZE;
l += (sp_int_digit)h;
h >>= SP_WORD_SIZE;
#ifdef SP_WORD_OVERFLOW
h += o & SP_MASK;
o >>= SP_WORD_SIZE;
#endif
}
t->dp[k] = (sp_int_digit)l;
t->dp[k+1] = (sp_int_digit)h;

View File

@ -99,7 +99,11 @@ int mp_cond_copy(mp_int* a, int copy, mp_int* b)
{
int err = MP_OKAY;
int i;
#if defined(SP_WORD_SIZE) && SP_WORD_SIZE == 8
unsigned int mask = (unsigned int)0 - copy;
#else
mp_digit mask = (mp_digit)0 - copy;
#endif
if (a == NULL || b == NULL)
err = BAD_FUNC_ARG;

View File

@ -376,6 +376,25 @@ typedef struct sp_ecc_ctx {
#define SP_INT_MAX_BITS (SP_INT_DIGITS * SP_WORD_SIZE)
#endif
#if SP_WORD_SIZE < 32
/* Maximum number of digits in a number to mul or sqr. */
#define SP_MUL_SQR_DIGITS (SP_INT_MAX_BITS / 2 / SP_WORD_SIZE)
/* Maximum value of partial in mul/sqr. */
#define SP_MUL_SQR_MAX_PARTIAL \
(SP_MUL_SQR_DIGITS * ((1 << SP_WORD_SIZE) - 1))
/* Maximim value in an sp_int_word. */
#define SP_INT_WORD_MAX ((1 << (SP_WORD_SIZE * 2)) - 1)
#if SP_MUL_SQR_MAX_PARTIAL > SP_INT_WORD_MAX
/* The sum of the partials in the multiplicaiton/square can exceed the
* size of a word. This will overflow the word and loose data.
* Use an implementation that handles carry after every add and uses an
* extra temporary word for overflowing high word.
*/
#define SP_WORD_OVERFLOW
#endif
#endif
/* For debugging only - format string for different digit sizes. */
#if SP_WORD_SIZE == 64