Maths: mp_mod_2d supports negative value now

SRP: don't clear an mp_int that hasn't been initialized
pull/4219/head
Sean Parkinson 2021-07-19 15:26:50 +10:00
parent bbe47a81b7
commit ed6e173fc3
5 changed files with 102 additions and 23 deletions

View File

@ -660,7 +660,7 @@ void mp_rshd (mp_int * a, int b)
/* calc a value mod 2**b */
int mp_mod_2d (mp_int * a, int b, mp_int * c)
{
int x, res;
int x, res, bmax;
/* if b is <= 0 then zero the int */
if (b <= 0) {
@ -669,7 +669,7 @@ int mp_mod_2d (mp_int * a, int b, mp_int * c)
}
/* if the modulus is larger than the value than return */
if (b >= (int) (a->used * DIGIT_BIT)) {
if (a->sign == MP_ZPOS && b >= (int) (a->used * DIGIT_BIT)) {
res = mp_copy (a, c);
return res;
}
@ -679,14 +679,35 @@ int mp_mod_2d (mp_int * a, int b, mp_int * c)
return res;
}
/* calculate number of digits in mod value */
bmax = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1);
/* zero digits above the last digit of the modulus */
for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
for (x = bmax; x < c->used; x++) {
c->dp[x] = 0;
}
if (c->sign == MP_NEG) {
mp_digit carry = 0;
/* grow result to size of modulus */
mp_grow(c, bmax);
/* negate value */
for (x = 0; x < c->used; x++) {
mp_digit next = c->dp[x] > 0;
c->dp[x] = ((mp_digit)0 - c->dp[x] - carry) & MP_MASK;
carry |= next;
}
for (; x < bmax; x++) {
c->dp[x] = ((mp_digit)0 - carry) & MP_MASK;
}
c->used = bmax;
c->sign = MP_ZPOS;
}
/* clear the digit that is not completely outside/inside the modulus */
x = DIGIT_BIT - (b % DIGIT_BIT);
if (x != DIGIT_BIT) {
c->dp[b / DIGIT_BIT] &=
c->dp[bmax - 1] &=
((mp_digit)~((mp_digit)0)) >> (x + ((sizeof(mp_digit)*8) - DIGIT_BIT));
}
mp_clamp (c);

View File

@ -9281,22 +9281,39 @@ int sp_mod_2d(sp_int* a, int e, sp_int* r)
int digits = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
if (a != r) {
XMEMCPY(r->dp, a->dp, digits * sizeof(sp_int_digit));
r->used = a->used;
}
/* Set used and mask off top digit of result. */
r->used = digits;
e &= SP_WORD_MASK;
if (e > 0) {
r->dp[r->used - 1] &= ((sp_int_digit)1 << e) - 1;
}
sp_clamp(r);
#ifdef WOLFSSL_SP_INT_NEGATIVE
if (sp_iszero(r)) {
r->sign = MP_ZPOS;
}
else if (a != r) {
r->sign = a->sign;
}
#ifndef WOLFSSL_SP_INT_NEGATIVE
if (digits <= a->used)
#else
if ((a->sign != MP_ZPOS) || (digits <= a->used))
#endif
{
#ifdef WOLFSSL_SP_INT_NEGATIVE
if (a->sign == MP_NEG) {
int i;
sp_int_digit carry = 0;
/* Negate value. */
for (i = 0; i < r->used; i++) {
sp_int_digit next = r->dp[i] > 0;
r->dp[i] = (sp_int_digit)0 - r->dp[i] - carry;
carry |= next;
}
for (; i < digits; i++) {
r->dp[i] = (sp_int_digit)0 - carry;
}
r->sign = MP_ZPOS;
}
#endif
/* Set used and mask off top digit of result. */
r->used = digits;
e &= SP_WORD_MASK;
if (e > 0) {
r->dp[r->used - 1] &= ((sp_int_digit)1 << e) - 1;
}
sp_clamp(r);
}
}
return err;

View File

@ -654,8 +654,7 @@ int wc_SrpComputeKey(Srp* srp, byte* clientPubKey, word32 clientPubKeySz,
if (!srp || !clientPubKey || clientPubKeySz == 0
|| !serverPubKey || serverPubKeySz == 0) {
r = BAD_FUNC_ARG;
goto out;
return BAD_FUNC_ARG;
}
#ifdef WOLFSSL_SMALL_STACK

View File

@ -998,6 +998,7 @@ int fp_mod(fp_int *a, fp_int *b, fp_int *c)
void fp_mod_2d(fp_int *a, int b, fp_int *c)
{
int x;
int bmax;
/* zero if count less than or equal to zero */
if (b <= 0) {
@ -1009,18 +1010,35 @@ void fp_mod_2d(fp_int *a, int b, fp_int *c)
fp_copy(a, c);
/* if 2**d is larger than we just return */
if (b >= (DIGIT_BIT * a->used)) {
if (c->sign == FP_ZPOS && b >= (DIGIT_BIT * a->used)) {
return;
}
bmax = (b + DIGIT_BIT - 1) / DIGIT_BIT;
/* zero digits above the last digit of the modulus */
for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
for (x = bmax; x < c->used; x++) {
c->dp[x] = 0;
}
if (c->sign == FP_NEG) {
fp_digit carry = 0;
/* negate value */
for (x = 0; x < c->used; x++) {
fp_digit next = c->dp[x] > 0;
c->dp[x] = (fp_digit)0 - c->dp[x] - carry;
carry |= next;
}
for (; x < bmax; x++) {
c->dp[x] = (fp_digit)0 - carry;
}
c->used = bmax;
c->sign = FP_ZPOS;
}
/* clear the digit that is not completely outside/inside the modulus */
x = DIGIT_BIT - (b % DIGIT_BIT);
if (x != DIGIT_BIT) {
c->dp[b / DIGIT_BIT] &= ~((fp_digit)0) >> x;
c->dp[bmax - 1] &= ~((fp_digit)0) >> x;
}
fp_clamp (c);

View File

@ -35558,6 +35558,30 @@ static int mp_test_mod_2d(mp_int* a, mp_int* r, mp_int* t, WC_RNG* rng)
}
}
#if !defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_INT_NEGATIVE)
/* Test negative value being moded. */
for (j = 0; j < 20; j++) {
ret = randNum(a, 2, rng, NULL);
if (ret != 0)
return -13122;
a->sign = MP_NEG;
for (i = 1; i < DIGIT_BIT * 3 + 1; i++) {
ret = mp_mod_2d(a, i, r);
if (ret != 0)
return -13124;
mp_zero(t);
ret = mp_set_bit(t, i);
if (ret != 0)
return -13125;
ret = mp_mod(a, t, t);
if (ret != 0)
return -13126;
if (mp_cmp(r, t) != MP_EQ)
return -13127;
}
}
#endif
return 0;
}
#endif