From 3b443c01fc8abb58c45af4f11e46135e8e7b716d Mon Sep 17 00:00:00 2001 From: John Safranek Date: Fri, 3 Nov 2023 10:19:07 -0700 Subject: [PATCH] PR Review 1. Add better error checking to the OpenSSH key code. 2. Add a couple heaps that were missing. --- src/internal.c | 125 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 90 insertions(+), 35 deletions(-) diff --git a/src/internal.c b/src/internal.c index 863847f1..b9e33888 100644 --- a/src/internal.c +++ b/src/internal.c @@ -897,7 +897,7 @@ static int wolfSSH_KEY_init(WS_KeySignature* key, byte keyId, void* heap) switch (keyId) { #ifndef WOLFSSH_NO_RSA case ID_SSH_RSA: - ret = wc_InitRsaKey(&key->ks.rsa.key, NULL); + ret = wc_InitRsaKey(&key->ks.rsa.key, heap); break; #endif #ifndef WOLFSSH_NO_ECDSA @@ -1045,39 +1045,84 @@ int IdentifyAsn1Key(const byte* in, word32 inSz, int isPrivate, void* heap) } -static int GetOpenSshKeyRsa(RsaKey* key, +/* + * Utility function to read an Mpint from the stream directly into a mp_int. + */ +static INLINE int GetMpintToMp(mp_int* mp, const byte* buf, word32 len, word32* idx) { const byte* val = NULL; word32 valSz = 0; - mp_int m; + int ret; - GetMpint(&valSz, &val, buf, len, idx); /* n */ - mp_read_unsigned_bin(&key->n, val, valSz); - GetMpint(&valSz, &val, buf, len, idx); /* e */ - mp_read_unsigned_bin(&key->e, val, valSz); - GetMpint(&valSz, &val, buf, len, idx); /* d */ - mp_read_unsigned_bin(&key->d, val, valSz); - GetMpint(&valSz, &val, buf, len, idx); /* iqmp */ - mp_read_unsigned_bin(&key->u, val, valSz); - GetMpint(&valSz, &val, buf, len, idx); /* p */ - mp_read_unsigned_bin(&key->p, val, valSz); - GetMpint(&valSz, &val, buf, len, idx); /* q */ - mp_read_unsigned_bin(&key->q, val, valSz); + ret = GetMpint(&valSz, &val, buf, len, idx); + if (ret == WS_SUCCESS) + ret = mp_read_unsigned_bin(mp, val, valSz); - /* Calculate dP and dQ for wolfCrypt. */ - mp_init(&m); - mp_sub_d(&key->p, 1, &m); - mp_mod(&key->d, &m, &key->dP); - mp_sub_d(&key->q, 1, &m); - mp_mod(&key->d, &m, &key->dQ); - mp_forcezero(&m); - mp_free(&m); - - return 0; + return ret; } +/* + * For the given RSA key, calculate p^-1 and q^-1. wolfCrypt's RSA + * code expects them, but the OpenSSH format key doesn't store them. + * TODO: Add a RSA read function to wolfCrypt to handle this condition. + */ +static INLINE int CalcRsaInverses(RsaKey* key) +{ + mp_int m; + int ret; + + ret = mp_init(&m); + if (ret == MP_OKAY) { + ret = mp_sub_d(&key->p, 1, &m); + if (ret == MP_OKAY) + ret = mp_mod(&key->d, &m, &key->dP); + if (ret == MP_OKAY) + ret = mp_sub_d(&key->q, 1, &m); + if (ret == MP_OKAY) + ret = mp_mod(&key->d, &m, &key->dQ); + mp_forcezero(&m); + } + + return ret; +} + + +/* + * Utility for GetOpenSshKey() to read in RSA keys. + */ +static int GetOpenSshKeyRsa(RsaKey* key, + const byte* buf, word32 len, word32* idx) +{ + int ret; + + ret = GetMpintToMp(&key->n, buf, len, idx); + if (ret == WS_SUCCESS) + ret = GetMpintToMp(&key->e, buf, len, idx); + if (ret == WS_SUCCESS) + ret = GetMpintToMp(&key->d, buf, len, idx); + if (ret == WS_SUCCESS) + ret = GetMpintToMp(&key->u, buf, len, idx); + if (ret == WS_SUCCESS) + ret = GetMpintToMp(&key->p, buf, len, idx); + if (ret == WS_SUCCESS) + ret = GetMpintToMp(&key->q, buf, len, idx); + + /* Calculate dP and dQ for wolfCrypt. */ + if (ret == WS_SUCCESS) + ret = CalcRsaInverses(key); + + if (ret != WS_SUCCESS) + ret = WS_RSA_E; + + return ret; +} + + +/* + * Utility for GetOpenSshKey() to read in ECDSA keys. + */ static int GetOpenSshKeyEcc(ecc_key* key, const byte* buf, word32 len, word32* idx) { @@ -1085,18 +1130,26 @@ static int GetOpenSshKeyEcc(ecc_key* key, word32 nameSz = 0, privSz = 0, pubSz = 0; int ret; - GetStringRef(&nameSz, &name, buf, len, idx); /* curve name */ - GetStringRef(&pubSz, &pub, buf, len, idx); /* Q */ - GetMpint(&privSz, &priv, buf, len, idx); /* d */ + ret = GetStringRef(&nameSz, &name, buf, len, idx); /* curve name */ + if (ret == WS_SUCCESS) + ret = GetStringRef(&pubSz, &pub, buf, len, idx); /* Q */ + if (ret == WS_SUCCESS) + ret = GetMpint(&privSz, &priv, buf, len, idx); /* d */ - ret = wc_ecc_import_private_key_ex(priv, privSz, pub, pubSz, - key, ECC_CURVE_DEF); + if (ret == WS_SUCCESS) + ret = wc_ecc_import_private_key_ex(priv, privSz, pub, pubSz, + key, ECC_CURVE_DEF); - return ret != 0; + if (ret != WS_SUCCESS) + ret = WS_ECC_E; + + return ret; } - +/* + * Decodes an OpenSSH format key. + */ static int GetOpenSshKey(WS_KeySignature *key, const byte* buf, word32 len, word32* idx) { @@ -1155,17 +1208,19 @@ static int GetOpenSshKey(WS_KeySignature *key, str, strSz, &subIdx); if (ret == WS_SUCCESS) { keyId = NameToId((const char*)subStr, subStrSz); - wolfSSH_KEY_init(key, keyId, NULL); + ret = wolfSSH_KEY_init(key, keyId, NULL); + } + if (ret == WS_SUCCESS) { switch (keyId) { #ifndef WOLFSSH_NO_RSA case ID_SSH_RSA: - GetOpenSshKeyRsa(&key->ks.rsa.key, + ret = GetOpenSshKeyRsa(&key->ks.rsa.key, str, strSz, &subIdx); break; #endif #ifndef WOLFSSH_NO_ECDSA case ID_ECDSA_SHA2_NISTP256: - GetOpenSshKeyEcc(&key->ks.ecc.key, + ret = GetOpenSshKeyEcc(&key->ks.ecc.key, str, strSz, &subIdx); break; #endif