Fix div implementation in SP int

pull/3087/head
Sean Parkinson 2020-06-30 22:22:11 +10:00
parent b6aaedd3b4
commit 1af2e5cf02
2 changed files with 176 additions and 84 deletions

View File

@ -653,10 +653,17 @@ void sp_rshb(sp_int* a, int n, sp_int* r)
{
int i;
int j;
int s = n % SP_WORD_SIZE;
for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++)
r->dp[i] = (a->dp[j] >> n) | (a->dp[j+1] << (SP_WORD_SIZE - n));
r->dp[i] = a->dp[j] >> n;
if (s == 0) {
for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++)
r->dp[j] = a->dp[i];
}
else {
for (i = n / SP_WORD_SIZE, j = 0; i < a->used-1; i++, j++)
r->dp[j] = (a->dp[i] >> s) | (a->dp[i+1] << (SP_WORD_SIZE - s));
}
r->dp[j] = a->dp[i] >> s;
r->used = j + 1;
sp_clamp(r);
}
@ -688,6 +695,59 @@ static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
sp_clamp(r);
}
/* Divide a two digit number by a digit number and return. (hi | lo) / d
*
* hi SP integer digit. High digit.
* lo SP integer digit. Lower digit.
* d SP integer digit. Number to divide by.
* returns the division result.
*/
static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
sp_int_digit d)
{
#ifdef WOLFSSL_SP_DIV_WORD_HALF
sp_int_digit div = d >> SP_HALF_SIZE;
sp_int_digit r;
sp_int_digit r2;
sp_int_word w = ((sp_int_word)hi << SP_WORD_SIZE) | lo;
sp_int_word trial;
r = hi / div;
if (r > SP_HALF_MAX) {
r = SP_HALF_MAX;
}
r <<= SP_HALF_SIZE;
trial = r * (sp_int_word)d;
while (trial > w) {
r -= (sp_int_digit)1 << SP_HALF_SIZE;
trial -= (sp_int_word)d << SP_HALF_SIZE;
}
w -= trial;
r2 = ((sp_int_digit)(w >> SP_HALF_SIZE)) / div;
trial = r2 * (sp_int_word)d;
while (trial > w) {
r2--;
trial -= d;
}
w -= trial;
r += r2;
r2 = ((sp_int_digit)w) / d;
r += r2;
return r;
#else
sp_int_word w;
sp_int_digit r;
w = ((sp_int_word)hi << SP_WORD_SIZE) | lo;
w /= d;
r = (sp_int_digit)w;
return r;
#endif
}
/* Divide a by d and return the quotient in r and the remainder in rem.
* r = a / d; rem = a % d
*
@ -705,9 +765,6 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
int done = 0;
int i;
int s;
#ifndef WOLFSSL_SP_DIV_32
sp_int_word w = 0;
#endif
sp_int_digit dt;
sp_int_digit t;
#ifdef WOLFSSL_SMALL_STACK
@ -721,38 +778,53 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sp_int tr[1];
sp_int trial[1];
#endif
#ifdef WOLFSSL_SP_SMALL
int c;
#else
int j, o;
sp_int_word tw;
sp_int_sword sw;
#endif
if (sp_iszero(d))
err = MP_VAL;
ret = sp_cmp(a, d);
if (ret == MP_LT) {
if (rem != NULL) {
sp_copy(a, rem);
if (err == MP_OKAY) {
ret = sp_cmp(a, d);
if (ret == MP_LT) {
if (rem != NULL) {
sp_copy(a, rem);
}
if (r != NULL) {
sp_set(r, 0);
}
done = 1;
}
if (r != NULL) {
sp_set(r, 0);
else if (ret == MP_EQ) {
if (rem != NULL) {
sp_set(rem, 0);
}
if (r != NULL) {
sp_set(r, 1);
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = aSign;
#endif
}
done = 1;
}
done = 1;
}
else if (ret == MP_EQ) {
if (rem != NULL) {
sp_set(rem, 0);
else if (sp_count_bits(a) == sp_count_bits(d)) {
/* a is greater than d but same bit length */
if (rem != NULL) {
sp_sub(a, d, rem);
}
if (r != NULL) {
sp_set(r, 1);
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = aSign;
#endif
}
done = 1;
}
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
}
else if (sp_count_bits(a) == sp_count_bits(d)) {
/* a is greater than d but same bit length */
if (rem != NULL) {
sp_sub(a, d, rem);
}
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
}
#ifdef WOLFSSL_SMALL_STACK
@ -790,79 +862,94 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sp_clear(tr);
tr->used = sa->used - d->used + 1;
dt = d->dp[d->used-1];
#ifndef WOLFSSL_SP_DIV_32
for (i = sa->used - 1; i >= d->used; ) {
if (sa->dp[i] > dt) {
t = (sp_int_digit)-1;
for (i = d->used - 1; i > 0; i--) {
if (sa->dp[sa->used - d->used + i] != d->dp[i]) {
break;
}
}
if (sa->dp[sa->used - d->used + i] >= d->dp[i]) {
i = sa->used;
o = sa->used - d->used;
sp_lshb(d, o * SP_WORD_SIZE);
sp_sub(sa, d, sa);
sp_rshb(d, o * SP_WORD_SIZE, d);
sa->used = i;
if (r != NULL) {
tr->dp[o] = 1;
}
}
for (i = sa->used - 1; i >= d->used; i--) {
if (sa->dp[i] == dt) {
t = (sp_digit)-1;
}
else {
w = ((sp_int_word)sa->dp[i] << SP_WORD_SIZE) | sa->dp[i-1];
w /= dt;
if (w > (sp_int_digit)-1) {
t = (sp_int_digit)-1;
}
else {
t = (sp_int_digit)w;
}
t = sp_div_word(sa->dp[i], sa->dp[i-1], dt);
}
if (t > 0) {
#ifdef WOLFSSL_SP_SMALL
do {
_sp_mul_d(d, t, trial, i - d->used);
while (sp_cmp(trial, sa) == MP_GT) {
c = _sp_cmp_abs(trial, sa);
if (c == MP_GT) {
t--;
_sp_mul_d(d, t, trial, i - d->used);
}
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t)
tr->dp[i + 1 - d->used]++;
}
i = sa->used - 1;
}
while (c == MP_GT);
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t) {
tr->dp[i + 1 - d->used]++;
}
#else
{
sp_int_digit div = (dt >> (SP_WORD_SIZE / 2)) + 1;
for (i = sa->used - 1; i >= d->used; ) {
t = sa->dp[i] / div;
if ((t > 0) && (t << (SP_WORD_SIZE / 2) == 0))
t = (sp_int_digit)-1;
t <<= SP_WORD_SIZE / 2;
if (t == 0) {
t = sa->dp[i] << (SP_WORD_SIZE / 2);
t += sa->dp[i-1] >> (SP_WORD_SIZE / 2);
t /= div;
}
if (t > 0) {
_sp_mul_d(d, t, trial, i - d->used);
while (sp_cmp(trial, sa) == MP_GT) {
t--;
_sp_mul_d(d, t, trial, i - d->used);
o = i - d->used;
do {
tw = 0;
for (j = 0; j < d->used; j++) {
tw += (sp_int_word)d->dp[j] * t;
trial->dp[j] = (sp_int_digit)tw;
tw >>= SP_WORD_SIZE;
}
trial->dp[j] = (sp_int_digit)tw;
for (j = d->used; j > 0; j--) {
if (trial->dp[j] != sa->dp[j + o]) {
break;
}
}
if (trial->dp[j] > sa->dp[j + o]) {
t--;
}
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t)
tr->dp[i + 1 - d->used]++;
}
i = sa->used - 1;
}
while (trial->dp[j] > sa->dp[j + o]);
while (sp_cmp(sa, d) != MP_LT) {
sp_sub(sa, d, sa);
sp_add_d(tr, 1, tr);
}
}
sw = 0;
for (j = 0; j <= d->used; j++) {
sw += sa->dp[j + o];
sw -= trial->dp[j];
sa->dp[j + o] = (sp_digit)sw;
sw >>= SP_WORD_SIZE;
}
tr->dp[o] += t;
if (tr->dp[o] < t) {
tr->dp[o + 1]++;
}
#endif
sp_clamp(tr);
}
sa->used = i + 1;
if (rem != NULL) {
if (s != SP_WORD_SIZE)
sp_rshb(sa, s, sa);
sp_copy(sa, rem);
sp_clamp(rem);
}
if (r != NULL)
if (r != NULL) {
sp_copy(tr, r);
sp_clamp(r);
}
}
#ifdef WOLFSSL_SMALL_STACK

View File

@ -60,6 +60,7 @@
typedef int32 sp_digit;
typedef uint32 sp_int_digit;
typedef uint64 sp_int_word;
typedef int64 sp_int_sword;
#undef SP_WORD_SIZE
#define SP_WORD_SIZE 32
#elif !defined(WOLFSSL_SP_ASM)
@ -67,6 +68,7 @@
typedef int32_t sp_digit;
typedef uint32_t sp_int_digit;
typedef uint64_t sp_int_word;
typedef int64_t sp_int_sword;
#elif SP_WORD_SIZE == 64
typedef int64_t sp_digit;
typedef uint64_t sp_int_digit;
@ -78,6 +80,7 @@
typedef long int128_t __attribute__ ((mode(TI)));
#endif
typedef uint128_t sp_int_word;
typedef int128_t sp_int_sword;
#else
#error Word size not defined
#endif
@ -86,6 +89,7 @@
typedef uint32_t sp_digit;
typedef uint32_t sp_int_digit;
typedef uint64_t sp_int_word;
typedef int64_t sp_int_sword;
#elif SP_WORD_SIZE == 64
typedef uint64_t sp_digit;
typedef uint64_t sp_int_digit;
@ -97,6 +101,7 @@
typedef long int128_t __attribute__ ((mode(TI)));
#endif
typedef uint128_t sp_int_word;
typedef int128_t sp_int_sword;
#else
#error Word size not defined
#endif