Fix div implementation in SP int

pull/3087/head
Sean Parkinson 2020-06-30 22:22:11 +10:00
parent b6aaedd3b4
commit 1af2e5cf02
2 changed files with 176 additions and 84 deletions

View File

@ -653,10 +653,17 @@ void sp_rshb(sp_int* a, int n, sp_int* r)
{ {
int i; int i;
int j; int j;
int s = n % SP_WORD_SIZE;
for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++) if (s == 0) {
r->dp[i] = (a->dp[j] >> n) | (a->dp[j+1] << (SP_WORD_SIZE - n)); for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++)
r->dp[i] = a->dp[j] >> n; r->dp[j] = a->dp[i];
}
else {
for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++)
r->dp[j] = (a->dp[i] >> s) | (a->dp[i+1] << (SP_WORD_SIZE - s));
}
r->dp[j] = a->dp[i] >> s;
r->used = j + 1; r->used = j + 1;
sp_clamp(r); sp_clamp(r);
} }
@ -688,6 +695,59 @@ static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
sp_clamp(r); sp_clamp(r);
} }
/* Divide a two digit number by a digit number and return. (hi | lo) / d
 *
 * The quotient is returned in a single digit: the caller must make sure it
 * fits (hi < d), otherwise the result is truncated to the low digit.
 * d must not be zero.
 *
 * When WOLFSSL_SP_DIV_WORD_HALF is defined the quotient is built half a digit
 * at a time from estimates made with the top half of d, so that no division
 * of a double-width word is performed. On that path, when hi is not zero,
 * the top half of d (d >> SP_HALF_SIZE) must be non-zero as it is used as a
 * divisor.
 *
 * hi SP integer digit. High digit.
 * lo SP integer digit. Lower digit.
 * d SP integer digit. Number to divide by.
 * returns the division result.
 */
static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
                                          sp_int_digit d)
{
#ifdef WOLFSSL_SP_DIV_WORD_HALF
    sp_int_digit r;

    if (hi == 0) {
        /* Dividend is a single digit: divide directly. This skips the
         * estimate/correct loops and never divides by the top half of d,
         * which is zero when d has no bits set in its upper half.
         */
        r = lo / d;
    }
    else {
        /* Top half of the divisor, used to estimate quotient halves.
         * Not called 'div' so that it cannot shadow the C library div().
         */
        sp_int_digit divh = d >> SP_HALF_SIZE;
        sp_int_digit r2;
        sp_int_word w = ((sp_int_word)hi << SP_WORD_SIZE) | lo;
        sp_int_word trial;

        /* Estimate the top half of the quotient, capped at SP_HALF_MAX. */
        r = hi / divh;
        if (r > SP_HALF_MAX) {
            r = SP_HALF_MAX;
        }
        r <<= SP_HALF_SIZE;
        /* Step the estimate down until estimate * d is no more than w. */
        trial = r * (sp_int_word)d;
        while (trial > w) {
            r -= (sp_int_digit)1 << SP_HALF_SIZE;
            trial -= (sp_int_word)d << SP_HALF_SIZE;
        }
        w -= trial;

        /* Estimate the bottom half of the quotient from what is left. */
        r2 = ((sp_int_digit)(w >> SP_HALF_SIZE)) / divh;
        trial = r2 * (sp_int_word)d;
        while (trial > w) {
            r2--;
            trial -= d;
        }
        w -= trial;
        r += r2;

        /* Divide the low digit of what remains by d and add it in. */
        r2 = ((sp_int_digit)w) / d;
        r += r2;
    }

    return r;
#else
    sp_int_word w;
    sp_int_digit r;

    /* Combine the digits into one word and use double-width division. */
    w = ((sp_int_word)hi << SP_WORD_SIZE) | lo;
    w /= d;
    /* Only the low digit of the quotient is returned. */
    r = (sp_int_digit)w;

    return r;
#endif
}
/* Divide a by d and return the quotient in r and the remainder in rem. /* Divide a by d and return the quotient in r and the remainder in rem.
* r = a / d; rem = a % d * r = a / d; rem = a % d
* *
@ -705,9 +765,6 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
int done = 0; int done = 0;
int i; int i;
int s; int s;
#ifndef WOLFSSL_SP_DIV_32
sp_int_word w = 0;
#endif
sp_int_digit dt; sp_int_digit dt;
sp_int_digit t; sp_int_digit t;
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
@ -721,38 +778,53 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sp_int tr[1]; sp_int tr[1];
sp_int trial[1]; sp_int trial[1];
#endif #endif
#ifdef WOLFSSL_SP_SMALL
int c;
#else
int j, o;
sp_int_word tw;
sp_int_sword sw;
#endif
if (sp_iszero(d)) if (sp_iszero(d))
err = MP_VAL; err = MP_VAL;
ret = sp_cmp(a, d); if (err == MP_OKAY) {
if (ret == MP_LT) { ret = sp_cmp(a, d);
if (rem != NULL) { if (ret == MP_LT) {
sp_copy(a, rem); if (rem != NULL) {
sp_copy(a, rem);
}
if (r != NULL) {
sp_set(r, 0);
}
done = 1;
} }
if (r != NULL) { else if (ret == MP_EQ) {
sp_set(r, 0); if (rem != NULL) {
sp_set(rem, 0);
}
if (r != NULL) {
sp_set(r, 1);
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = aSign;
#endif
}
done = 1;
} }
done = 1; else if (sp_count_bits(a) == sp_count_bits(d)) {
} /* a is greater than d but same bit length */
else if (ret == MP_EQ) { if (rem != NULL) {
if (rem != NULL) { sp_sub(a, d, rem);
sp_set(rem, 0); }
if (r != NULL) {
sp_set(r, 1);
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = aSign;
#endif
}
done = 1;
} }
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
}
else if (sp_count_bits(a) == sp_count_bits(d)) {
/* a is greater than d but same bit length */
if (rem != NULL) {
sp_sub(a, d, rem);
}
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
} }
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
@ -790,79 +862,94 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sp_clear(tr); sp_clear(tr);
tr->used = sa->used - d->used + 1; tr->used = sa->used - d->used + 1;
dt = d->dp[d->used-1]; dt = d->dp[d->used-1];
#ifndef WOLFSSL_SP_DIV_32
for (i = sa->used - 1; i >= d->used; ) { for (i = d->used - 1; i > 0; i--) {
if (sa->dp[i] > dt) { if (sa->dp[sa->used - d->used + i] != d->dp[i]) {
t = (sp_int_digit)-1; break;
}
}
if (sa->dp[sa->used - d->used + i] >= d->dp[i]) {
i = sa->used;
o = sa->used - d->used;
sp_lshb(d, o * SP_WORD_SIZE);
sp_sub(sa, d, sa);
sp_rshb(d, o * SP_WORD_SIZE, d);
sa->used = i;
if (r != NULL) {
tr->dp[o] = 1;
}
}
for (i = sa->used - 1; i >= d->used; i--) {
if (sa->dp[i] == dt) {
t = (sp_digit)-1;
} }
else { else {
w = ((sp_int_word)sa->dp[i] << SP_WORD_SIZE) | sa->dp[i-1]; t = sp_div_word(sa->dp[i], sa->dp[i-1], dt);
w /= dt;
if (w > (sp_int_digit)-1) {
t = (sp_int_digit)-1;
}
else {
t = (sp_int_digit)w;
}
} }
if (t > 0) { #ifdef WOLFSSL_SP_SMALL
do {
_sp_mul_d(d, t, trial, i - d->used); _sp_mul_d(d, t, trial, i - d->used);
while (sp_cmp(trial, sa) == MP_GT) { c = _sp_cmp_abs(trial, sa);
if (c == MP_GT) {
t--; t--;
_sp_mul_d(d, t, trial, i - d->used);
} }
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t)
tr->dp[i + 1 - d->used]++;
} }
i = sa->used - 1; while (c == MP_GT);
}
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t) {
tr->dp[i + 1 - d->used]++;
}
#else #else
{ o = i - d->used;
sp_int_digit div = (dt >> (SP_WORD_SIZE / 2)) + 1; do {
for (i = sa->used - 1; i >= d->used; ) { tw = 0;
t = sa->dp[i] / div; for (j = 0; j < d->used; j++) {
if ((t > 0) && (t << (SP_WORD_SIZE / 2) == 0)) tw += (sp_int_word)d->dp[j] * t;
t = (sp_int_digit)-1; trial->dp[j] = (sp_int_digit)tw;
t <<= SP_WORD_SIZE / 2; tw >>= SP_WORD_SIZE;
if (t == 0) { }
t = sa->dp[i] << (SP_WORD_SIZE / 2); trial->dp[j] = (sp_int_digit)tw;
t += sa->dp[i-1] >> (SP_WORD_SIZE / 2);
t /= div; for (j = d->used; j > 0; j--) {
} if (trial->dp[j] != sa->dp[j + o]) {
break;
if (t > 0) { }
_sp_mul_d(d, t, trial, i - d->used); }
while (sp_cmp(trial, sa) == MP_GT) { if (trial->dp[j] > sa->dp[j + o]) {
t--; t--;
_sp_mul_d(d, t, trial, i - d->used);
} }
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t)
tr->dp[i + 1 - d->used]++;
} }
i = sa->used - 1; while (trial->dp[j] > sa->dp[j + o]);
}
while (sp_cmp(sa, d) != MP_LT) { sw = 0;
sp_sub(sa, d, sa); for (j = 0; j <= d->used; j++) {
sp_add_d(tr, 1, tr); sw += sa->dp[j + o];
} sw -= trial->dp[j];
} sa->dp[j + o] = (sp_digit)sw;
sw >>= SP_WORD_SIZE;
}
tr->dp[o] += t;
if (tr->dp[o] < t) {
tr->dp[o + 1]++;
}
#endif #endif
}
sp_clamp(tr); sa->used = i + 1;
if (rem != NULL) { if (rem != NULL) {
if (s != SP_WORD_SIZE) if (s != SP_WORD_SIZE)
sp_rshb(sa, s, sa); sp_rshb(sa, s, sa);
sp_copy(sa, rem); sp_copy(sa, rem);
sp_clamp(rem);
} }
if (r != NULL) if (r != NULL) {
sp_copy(tr, r); sp_copy(tr, r);
sp_clamp(r);
}
} }
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK

View File

@ -60,6 +60,7 @@
typedef int32 sp_digit; typedef int32 sp_digit;
typedef uint32 sp_int_digit; typedef uint32 sp_int_digit;
typedef uint64 sp_int_word; typedef uint64 sp_int_word;
typedef int64 sp_int_sword;
#undef SP_WORD_SIZE #undef SP_WORD_SIZE
#define SP_WORD_SIZE 32 #define SP_WORD_SIZE 32
#elif !defined(WOLFSSL_SP_ASM) #elif !defined(WOLFSSL_SP_ASM)
@ -67,6 +68,7 @@
typedef int32_t sp_digit; typedef int32_t sp_digit;
typedef uint32_t sp_int_digit; typedef uint32_t sp_int_digit;
typedef uint64_t sp_int_word; typedef uint64_t sp_int_word;
typedef int64_t sp_int_sword;
#elif SP_WORD_SIZE == 64 #elif SP_WORD_SIZE == 64
typedef int64_t sp_digit; typedef int64_t sp_digit;
typedef uint64_t sp_int_digit; typedef uint64_t sp_int_digit;
@ -78,6 +80,7 @@
typedef long int128_t __attribute__ ((mode(TI))); typedef long int128_t __attribute__ ((mode(TI)));
#endif #endif
typedef uint128_t sp_int_word; typedef uint128_t sp_int_word;
typedef int128_t sp_int_sword;
#else #else
#error Word size not defined #error Word size not defined
#endif #endif
@ -86,6 +89,7 @@
typedef uint32_t sp_digit; typedef uint32_t sp_digit;
typedef uint32_t sp_int_digit; typedef uint32_t sp_int_digit;
typedef uint64_t sp_int_word; typedef uint64_t sp_int_word;
typedef int64_t sp_int_sword;
#elif SP_WORD_SIZE == 64 #elif SP_WORD_SIZE == 64
typedef uint64_t sp_digit; typedef uint64_t sp_digit;
typedef uint64_t sp_int_digit; typedef uint64_t sp_int_digit;
@ -97,6 +101,7 @@
typedef long int128_t __attribute__ ((mode(TI))); typedef long int128_t __attribute__ ((mode(TI)));
#endif #endif
typedef uint128_t sp_int_word; typedef uint128_t sp_int_word;
typedef int128_t sp_int_sword;
#else #else
#error Word size not defined #error Word size not defined
#endif #endif