diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 05dcefbda..2d7703a41 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -765,7 +765,7 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, } else if (saltLen > hLen || saltLen < -1) return PSS_SALTLEN_E; - if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2) + if ((int)pkcsBlockLen - hLen < saltLen + 2) return PSS_SALTLEN_E; s = m = pkcsBlock; @@ -1038,7 +1038,7 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, } else if (saltLen > hLen || saltLen < -1) return PSS_SALTLEN_E; - if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2) + if ((int)pkcsBlockLen - hLen < saltLen + 2) return PSS_SALTLEN_E; if (pkcsBlock[pkcsBlockLen - 1] != RSA_PSS_PAD_TERM) { @@ -1074,11 +1074,8 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); - i = pkcsBlockLen - (RSA_PSS_PAD_SZ + saltLen + 2 * hLen + 1); - XMEMSET(pkcsBlock + i, 0, RSA_PSS_PAD_SZ); - - *output = pkcsBlock + i; - return RSA_PSS_PAD_SZ + saltLen + 2 * hLen; + *output = pkcsBlock + pkcsBlockLen - (hLen + saltLen + 1); + return saltLen + hLen; } #endif @@ -2170,7 +2167,7 @@ int wc_RsaPSS_Verify_ex(byte* in, word32 inLen, byte* out, word32 outLen, int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig, word32 sigSz, enum wc_HashType hashType) { - return wc_RsaPSS_CheckPadding_ex(in, inSz, sig, sigSz, hashType, inSz); + return wc_RsaPSS_CheckPadding_ex(in, inSz, sig, sigSz, hashType, inSz, 0); } /* Checks the PSS data to ensure that the signature matches. @@ -2188,33 +2185,46 @@ int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig, */ int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig, word32 sigSz, enum wc_HashType hashType, - int saltLen) + int saltLen, int bits) { int ret = 0; + byte sigCheck[WC_MAX_DIGEST_SIZE*2 + RSA_PSS_PAD_SZ]; + + (void)bits; if (in == NULL || sig == NULL || inSz != (word32)wc_HashGetDigestSize(hashType)) ret = BAD_FUNC_ARG; if (ret == 0) { - if (saltLen == -1) + if (saltLen == -1) { saltLen = inSz; + #ifdef WOLFSSL_SHA512 + /* See FIPS 186-4 section 5.5 item (e). */ + if (bits == 1024 && inSz == WC_SHA512_DIGEST_SIZE) + saltLen = RSA_PSS_SALT_MAX_SZ; + #endif + } else if (saltLen < -1 || (word32)saltLen > inSz) ret = PSS_SALTLEN_E; } - /* Sig = 8 * 0x00 | Space for Message Hash | Salt | Exp Hash */ + + /* Sig = Salt | Exp Hash */ if (ret == 0) { - if (sigSz != RSA_PSS_PAD_SZ + inSz + (word32)saltLen + inSz) + if (sigSz != inSz + saltLen) ret = BAD_PADDING_E; } + /* Exp Hash = HASH(8 * 0x00 | Message Hash | Salt) */ if (ret == 0) { - XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz); - ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz + saltLen, sig, - inSz); + XMEMSET(sigCheck, 0, RSA_PSS_PAD_SZ); + XMEMCPY(sigCheck + RSA_PSS_PAD_SZ, in, inSz); + XMEMCPY(sigCheck + RSA_PSS_PAD_SZ + inSz, sig, saltLen); + ret = wc_Hash(hashType, sigCheck, RSA_PSS_PAD_SZ + inSz + saltLen, + sigCheck, inSz); } if (ret == 0) { - if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz + saltLen, inSz) != 0) { + if (XMEMCMP(sigCheck, sig + saltLen, inSz) != 0) { WOLFSSL_MSG("RsaPSS_CheckPadding: Padding Error"); ret = BAD_PADDING_E; } @@ -2242,7 +2252,7 @@ int wc_RsaPSS_VerifyCheckInline(byte* in, word32 inLen, byte** out, const byte* digest, word32 digestLen, enum wc_HashType hash, int mgf, RsaKey* key) { - int ret = 0, verify, saltLen, hLen; + int ret = 0, verify, saltLen, hLen, bits = 0; hLen = wc_HashGetDigestSize(hash); if (hLen < 0) @@ -2253,17 +2263,15 @@ int wc_RsaPSS_VerifyCheckInline(byte* in, word32 inLen, byte** out, saltLen = hLen; #ifdef WOLFSSL_SHA512 /* See FIPS 186-4 section 5.5 item (e). */ - if (mp_unsigned_bin_size(&key->n) == 1024 && - hLen == WC_SHA512_DIGEST_SIZE) { - + bits = mp_count_bits(&key->n); + if (bits == 1024 && hLen == WC_SHA512_DIGEST_SIZE) saltLen = RSA_PSS_SALT_MAX_SZ; - } #endif verify = wc_RsaPSS_VerifyInline_ex(in, inLen, out, hash, mgf, saltLen, key); if (verify > 0) ret = wc_RsaPSS_CheckPadding_ex(digest, digestLen, *out, verify, - hash, saltLen); + hash, saltLen, bits); if (ret == 0) ret = verify; @@ -2290,7 +2298,7 @@ int wc_RsaPSS_VerifyCheck(byte* in, word32 inLen, byte* out, word32 outLen, enum wc_HashType hash, int mgf, RsaKey* key) { - int ret = 0, verify, saltLen, hLen; + int ret = 0, verify, saltLen, hLen, bits = 0; hLen = wc_HashGetDigestSize(hash); if (hLen < 0) @@ -2301,18 +2309,16 @@ int wc_RsaPSS_VerifyCheck(byte* in, word32 inLen, byte* out, word32 outLen, saltLen = hLen; #ifdef WOLFSSL_SHA512 /* See FIPS 186-4 section 5.5 item (e). */ - if (mp_unsigned_bin_size(&key->n) == 1024 && - hLen == WC_SHA512_DIGEST_SIZE) { - + bits = mp_count_bits(&key->n); + if (bits == 1024 && hLen == WC_SHA512_DIGEST_SIZE) saltLen = RSA_PSS_SALT_MAX_SZ; - } #endif verify = wc_RsaPSS_Verify_ex(in, inLen, out, outLen, hash, mgf, saltLen, key); if (verify > 0) ret = wc_RsaPSS_CheckPadding_ex(digest, digestLen, out, verify, - hash, saltLen); + hash, saltLen, bits); if (ret == 0) ret = verify; diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index 4eb8f3205..90b1eab45 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -8873,8 +8873,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key) ERROR_OUT(-5452, exit_rsa_pss); plainSz = ret; - ret = wc_RsaPSS_CheckPadding(digest, digestSz, plain, plainSz, - hash[j]); + ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, + hash[j], -1, wc_RsaEncryptSize(key)*8); if (ret != 0) ERROR_OUT(-5453, exit_rsa_pss); @@ -8942,7 +8942,7 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key) #endif if (ret >= 0) { ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, sig, plainSz, - hash[0], 0); + hash[0], 0, 0); } } while (ret == WC_PENDING_E); if (ret != 0) @@ -8965,7 +8965,7 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key) plainSz = ret; ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], - 0); + 0, 0); if (ret != 0) ERROR_OUT(-5464, exit_rsa_pss); @@ -9025,11 +9025,11 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key) ERROR_OUT(-5473, exit_rsa_pss); ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], - -2); + -2, 0); if (ret != PSS_SALTLEN_E) ERROR_OUT(-5474, exit_rsa_pss); ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], - digestSz + 1); + digestSz + 1, 0); if (ret != PSS_SALTLEN_E) ERROR_OUT(-5475, exit_rsa_pss); diff --git a/wolfssl/wolfcrypt/rsa.h b/wolfssl/wolfcrypt/rsa.h index 4121c7863..8d2c12798 100644 --- a/wolfssl/wolfcrypt/rsa.h +++ b/wolfssl/wolfcrypt/rsa.h @@ -190,7 +190,7 @@ WOLFSSL_API int wc_RsaPSS_CheckPadding(const byte* in, word32 inLen, byte* sig, WOLFSSL_API int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inLen, byte* sig, word32 sigSz, enum wc_HashType hashType, - int saltLen); + int saltLen, int bits); WOLFSSL_API int wc_RsaPSS_VerifyCheckInline(byte* in, word32 inLen, byte** out, const byte* digest, word32 digentLen, enum wc_HashType hash, int mgf,