Simplify OAEP and PSS
Makes things a little bit more like similar APIs. * Hash type is now set in constructor. * MGF is set automtically or manually with `set_mgf()` * Label defaults to emptypull/40/head
parent
e34a0ece53
commit
e3a5470fe8
|
@ -333,6 +333,14 @@ if _lib.RSA_ENABLED:
|
|||
def rsa_private(vectors):
|
||||
return RsaPrivate(vectors[RsaPrivate].key)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_private_oaep(vectors):
|
||||
return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_private_pss(vectors):
|
||||
return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA256)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_private_pkcs8(vectors):
|
||||
return RsaPrivate(vectors[RsaPrivate].pkcs8_key)
|
||||
|
@ -341,6 +349,14 @@ if _lib.RSA_ENABLED:
|
|||
def rsa_public(vectors):
|
||||
return RsaPublic(vectors[RsaPublic].key)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_public_oaep(vectors):
|
||||
return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_public_pss(vectors):
|
||||
return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA256)
|
||||
|
||||
@pytest.fixture
|
||||
def rsa_private_pem(vectors):
|
||||
with open(vectors[RsaPrivate].pem, "rb") as f:
|
||||
|
@ -382,21 +398,21 @@ if _lib.RSA_ENABLED:
|
|||
assert 1024 / 8 == len(ciphertext) == rsa_private.output_size
|
||||
assert plaintext == rsa_private.decrypt(ciphertext)
|
||||
|
||||
def test_rsa_encrypt_decrypt_pad_oaep(rsa_private, rsa_public):
|
||||
def test_rsa_encrypt_decrypt_pad_oaep(rsa_private_oaep, rsa_public_oaep):
|
||||
plaintext = t2b("Everyone gets Friday off.")
|
||||
|
||||
# normal usage, encrypt with public, decrypt with private
|
||||
ciphertext = rsa_public.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "")
|
||||
ciphertext = rsa_public_oaep.encrypt_oaep(plaintext)
|
||||
|
||||
assert 1024 / 8 == len(ciphertext) == rsa_public.output_size
|
||||
assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "")
|
||||
assert 1024 / 8 == len(ciphertext) == rsa_public_oaep.output_size
|
||||
assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext)
|
||||
|
||||
# private object holds both private and public info, so it can also encrypt
|
||||
# using the known public key.
|
||||
ciphertext = rsa_private.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "")
|
||||
ciphertext = rsa_private_oaep.encrypt_oaep(plaintext)
|
||||
|
||||
assert 1024 / 8 == len(ciphertext) == rsa_private.output_size
|
||||
assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "")
|
||||
assert 1024 / 8 == len(ciphertext) == rsa_private_oaep.output_size
|
||||
assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext)
|
||||
|
||||
|
||||
def test_rsa_pkcs8_encrypt_decrypt(rsa_private_pkcs8, rsa_public):
|
||||
|
@ -433,21 +449,21 @@ if _lib.RSA_ENABLED:
|
|||
assert plaintext == rsa_private.verify(signature)
|
||||
|
||||
if _lib.RSA_PSS_ENABLED:
|
||||
def test_rsa_pss_sign_verify(rsa_private, rsa_public):
|
||||
def test_rsa_pss_sign_verify(rsa_private_pss, rsa_public_pss):
|
||||
plaintext = t2b("Everyone gets Friday off yippee.")
|
||||
|
||||
# normal usage, sign with private, verify with public
|
||||
signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256)
|
||||
signature = rsa_private_pss.sign_pss(plaintext)
|
||||
|
||||
assert 1024 / 8 == len(signature) == rsa_private.output_size
|
||||
assert 0 == rsa_public.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256)
|
||||
assert 1024 / 8 == len(signature) == rsa_private_pss.output_size
|
||||
assert 0 == rsa_public_pss.verify_pss(plaintext, signature)
|
||||
|
||||
# private object holds both private and public info, so it can also verify
|
||||
# using the known public key.
|
||||
signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256)
|
||||
signature = rsa_private_pss.sign_pss(plaintext)
|
||||
|
||||
assert 1024 / 8 == len(signature) == rsa_private.output_size
|
||||
assert 0 == rsa_private.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256)
|
||||
assert 1024 / 8 == len(signature) == rsa_private_pss.output_size
|
||||
assert 0 == rsa_private_pss.verify_pss(plaintext, signature)
|
||||
|
||||
def test_rsa_sign_verify_pem(rsa_private_pem, rsa_public_pem):
|
||||
plaintext = t2b("Everyone gets Friday off.")
|
||||
|
|
|
@ -452,6 +452,8 @@ if _lib.DES3_ENABLED:
|
|||
if _lib.RSA_ENABLED:
|
||||
class _Rsa(object): # pylint: disable=too-few-public-methods
|
||||
RSA_MIN_PAD_SIZE = 11
|
||||
_mgf = None
|
||||
_hash_type = None
|
||||
|
||||
def __init__(self):
|
||||
self.native_object = _ffi.new("RsaKey *")
|
||||
|
@ -473,11 +475,30 @@ if _lib.RSA_ENABLED:
|
|||
if self.native_object:
|
||||
self._delete(self.native_object)
|
||||
|
||||
def set_mgf(self, mgf):
|
||||
self._mgf = mgf
|
||||
|
||||
def _get_mgf(self):
|
||||
if self._hash_type == _lib.WC_HASH_TYPE_SHA:
|
||||
self._mgf = _lib.WC_MGF1SHA1
|
||||
elif self._hash_type == _lib.WC_HASH_TYPE_SHA224:
|
||||
self._mgf = _lib.WC_MGF1SHA224
|
||||
elif self._hash_type == _lib.WC_HASH_TYPE_SHA256:
|
||||
self._mgf = _lib.WC_MGF1SHA256
|
||||
elif self._hash_type == _lib.WC_HASH_TYPE_SHA384:
|
||||
self._mgf = _lib.WC_MGF1SHA384
|
||||
elif self._hash_type == _lib.WC_HASH_TYPE_SHA512:
|
||||
self._mgf = _lib.WC_MGF1SHA512
|
||||
else:
|
||||
self._mgf = _lib.WC_MGF1NONE
|
||||
|
||||
|
||||
|
||||
class RsaPublic(_Rsa):
|
||||
def __init__(self, key=None):
|
||||
def __init__(self, key=None, hash_type=None):
|
||||
if key != None:
|
||||
key = t2b(key)
|
||||
self._hash_type = hash_type
|
||||
|
||||
_Rsa.__init__(self)
|
||||
|
||||
|
@ -524,17 +545,18 @@ if _lib.RSA_ENABLED:
|
|||
|
||||
return _ffi.buffer(ciphertext)[:]
|
||||
|
||||
def encrypt_oaep(self, plaintext, hash_type, mgf, label):
|
||||
def encrypt_oaep(self, plaintext, label=""):
|
||||
plaintext = t2b(plaintext)
|
||||
label = t2b(label)
|
||||
ciphertext = _ffi.new("byte[%d]" % self.output_size)
|
||||
|
||||
if self._mgf is None:
|
||||
self._get_mgf()
|
||||
ret = _lib.wc_RsaPublicEncrypt_ex(plaintext, len(plaintext),
|
||||
ciphertext, self.output_size,
|
||||
self.native_object,
|
||||
self._random.native_object,
|
||||
_lib.WC_RSA_OAEP_PAD, hash_type,
|
||||
mgf, label, len(label))
|
||||
_lib.WC_RSA_OAEP_PAD, self._hash_type,
|
||||
self._mgf, label, len(label))
|
||||
|
||||
if ret != self.output_size: # pragma: no cover
|
||||
raise WolfCryptError("Encryption error (%d)" % ret)
|
||||
|
@ -563,7 +585,7 @@ if _lib.RSA_ENABLED:
|
|||
return _ffi.buffer(plaintext, ret)[:]
|
||||
|
||||
if _lib.RSA_PSS_ENABLED:
|
||||
def verify_pss(self, plaintext, signature, hash_type, mgf):
|
||||
def verify_pss(self, plaintext, signature):
|
||||
"""
|
||||
Verifies **signature**, using the public key data in the
|
||||
object. The signature's length must be equal to:
|
||||
|
@ -574,17 +596,19 @@ if _lib.RSA_ENABLED:
|
|||
"""
|
||||
plaintext = t2b(plaintext)
|
||||
signature = t2b(signature)
|
||||
if self._mgf is None:
|
||||
self._get_mgf()
|
||||
verify = _ffi.new("byte[%d]" % self.output_size)
|
||||
|
||||
ret = _lib.wc_RsaPSS_Verify(signature, len(signature),
|
||||
verify, self.output_size,
|
||||
hash_type, mgf,
|
||||
self._hash_type, self._mgf,
|
||||
self.native_object)
|
||||
|
||||
if ret < 0: # pragma: no cover
|
||||
raise WolfCryptError("Verify error (%d)" % ret)
|
||||
ret = _lib.wc_RsaPSS_CheckPadding(plaintext, len(plaintext),
|
||||
verify, ret, hash_type)
|
||||
verify, ret, self._hash_type)
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -613,10 +637,10 @@ if _lib.RSA_ENABLED:
|
|||
|
||||
return rsa
|
||||
|
||||
def __init__(self, key = None): # pylint: disable=super-init-not-called
|
||||
def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called
|
||||
|
||||
_Rsa.__init__(self) # pylint: disable=non-parent-init-called
|
||||
|
||||
self._hash_type = hash_type
|
||||
idx = _ffi.new("word32*")
|
||||
idx[0] = 0
|
||||
|
||||
|
@ -692,7 +716,7 @@ if _lib.RSA_ENABLED:
|
|||
|
||||
return _ffi.buffer(plaintext, ret)[:]
|
||||
|
||||
def decrypt_oaep(self, ciphertext, hash_type, mgf, label):
|
||||
def decrypt_oaep(self, ciphertext, label=""):
|
||||
"""
|
||||
Decrypts **ciphertext**, using the private key data in the
|
||||
object. The ciphertext's length must be equal to:
|
||||
|
@ -704,11 +728,13 @@ if _lib.RSA_ENABLED:
|
|||
ciphertext = t2b(ciphertext)
|
||||
label = t2b(label)
|
||||
plaintext = _ffi.new("byte[%d]" % self.output_size)
|
||||
if self._mgf is None:
|
||||
self._get_mgf()
|
||||
ret = _lib.wc_RsaPrivateDecrypt_ex(ciphertext, len(ciphertext),
|
||||
plaintext, self.output_size,
|
||||
self.native_object,
|
||||
_lib.WC_RSA_OAEP_PAD, hash_type,
|
||||
mgf, label, len(label))
|
||||
_lib.WC_RSA_OAEP_PAD, self._hash_type,
|
||||
self._mgf, label, len(label))
|
||||
|
||||
if ret < 0: # pragma: no cover
|
||||
raise WolfCryptError("Decryption error (%d)" % ret)
|
||||
|
@ -738,7 +764,7 @@ if _lib.RSA_ENABLED:
|
|||
return _ffi.buffer(signature, self.output_size)[:]
|
||||
|
||||
if _lib.RSA_PSS_ENABLED:
|
||||
def sign_pss(self, plaintext, hash_type, mgf):
|
||||
def sign_pss(self, plaintext):
|
||||
"""
|
||||
Signs **plaintext**, using the private key data in the object.
|
||||
The plaintext's length must not be greater than:
|
||||
|
@ -749,10 +775,11 @@ if _lib.RSA_ENABLED:
|
|||
"""
|
||||
plaintext = t2b(plaintext)
|
||||
signature = _ffi.new("byte[%d]" % self.output_size)
|
||||
|
||||
if self._mgf is None:
|
||||
self._get_mgf()
|
||||
ret = _lib.wc_RsaPSS_Sign(plaintext, len(plaintext),
|
||||
signature, self.output_size,
|
||||
hash_type, mgf,
|
||||
self._hash_type, self._mgf,
|
||||
self.native_object,
|
||||
self._random.native_object)
|
||||
|
||||
|
|
Loading…
Reference in New Issue