This commit is contained in:
@@ -3,9 +3,8 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
|
||||
|
||||
from .exceptions import InvalidKeyError
|
||||
from .types import HashlibHash, JWKDict
|
||||
@@ -21,14 +20,8 @@ from .utils import (
|
||||
to_base64url_uint,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
try:
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
@@ -194,18 +187,16 @@ class Algorithm(ABC):
|
||||
@overload
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
|
||||
"""
|
||||
Serializes a given key into a JWK
|
||||
"""
|
||||
@@ -274,16 +265,18 @@ class HMACAlgorithm(Algorithm):
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key_obj: str | bytes, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key_obj: str | bytes, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
|
||||
jwk = {
|
||||
"k": base64url_encode(force_bytes(key_obj)).decode(),
|
||||
"kty": "oct",
|
||||
@@ -304,7 +297,7 @@ class HMACAlgorithm(Algorithm):
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
|
||||
if obj.get("kty") != "oct":
|
||||
raise InvalidKeyError("Not an HMAC key")
|
||||
@@ -350,22 +343,27 @@ if has_crypto:
|
||||
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
|
||||
)
|
||||
except ValueError:
|
||||
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
|
||||
try:
|
||||
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
|
||||
except (ValueError, UnsupportedAlgorithm):
|
||||
raise InvalidKeyError(
|
||||
"Could not parse the provided public key."
|
||||
) from None
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedRSAKeys, as_dict: bool = False
|
||||
) -> Union[JWKDict, str]:
|
||||
key_obj: AllowedRSAKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
obj: dict[str, Any] | None = None
|
||||
|
||||
if hasattr(key_obj, "private_numbers"):
|
||||
@@ -413,10 +411,10 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
|
||||
if obj.get("kty") != "RSA":
|
||||
raise InvalidKeyError("Not an RSA key")
|
||||
raise InvalidKeyError("Not an RSA key") from None
|
||||
|
||||
if "d" in obj and "e" in obj and "n" in obj:
|
||||
# Private key
|
||||
@@ -432,7 +430,7 @@ if has_crypto:
|
||||
if any_props_found and not all(props_found):
|
||||
raise InvalidKeyError(
|
||||
"RSA key must include all parameters if any are present besides d"
|
||||
)
|
||||
) from None
|
||||
|
||||
public_numbers = RSAPublicNumbers(
|
||||
from_base64url_uint(obj["e"]),
|
||||
@@ -524,7 +522,7 @@ if has_crypto:
|
||||
):
|
||||
raise InvalidKeyError(
|
||||
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
|
||||
)
|
||||
) from None
|
||||
|
||||
return crypto_key
|
||||
|
||||
@@ -533,7 +531,7 @@ if has_crypto:
|
||||
|
||||
return der_to_raw_signature(der_sig, key.curve)
|
||||
|
||||
def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
|
||||
def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
|
||||
try:
|
||||
der_sig = raw_to_der_signature(sig, key.curve)
|
||||
except ValueError:
|
||||
@@ -552,18 +550,18 @@ if has_crypto:
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: bool = False
|
||||
) -> Union[JWKDict, str]:
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
if isinstance(key_obj, EllipticCurvePrivateKey):
|
||||
public_numbers = key_obj.public_key().public_numbers()
|
||||
elif isinstance(key_obj, EllipticCurvePublicKey):
|
||||
@@ -585,13 +583,20 @@ if has_crypto:
|
||||
obj: dict[str, Any] = {
|
||||
"kty": "EC",
|
||||
"crv": crv,
|
||||
"x": to_base64url_uint(public_numbers.x).decode(),
|
||||
"y": to_base64url_uint(public_numbers.y).decode(),
|
||||
"x": to_base64url_uint(
|
||||
public_numbers.x,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
).decode(),
|
||||
"y": to_base64url_uint(
|
||||
public_numbers.y,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
).decode(),
|
||||
}
|
||||
|
||||
if isinstance(key_obj, EllipticCurvePrivateKey):
|
||||
obj["d"] = to_base64url_uint(
|
||||
key_obj.private_numbers().private_value
|
||||
key_obj.private_numbers().private_value,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
).decode()
|
||||
|
||||
if as_dict:
|
||||
@@ -609,13 +614,13 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
|
||||
if obj.get("kty") != "EC":
|
||||
raise InvalidKeyError("Not an Elliptic curve key")
|
||||
raise InvalidKeyError("Not an Elliptic curve key") from None
|
||||
|
||||
if "x" not in obj or "y" not in obj:
|
||||
raise InvalidKeyError("Not an Elliptic curve key")
|
||||
raise InvalidKeyError("Not an Elliptic curve key") from None
|
||||
|
||||
x = base64url_decode(obj.get("x"))
|
||||
y = base64url_decode(obj.get("y"))
|
||||
@@ -627,17 +632,23 @@ if has_crypto:
|
||||
if len(x) == len(y) == 32:
|
||||
curve_obj = SECP256R1()
|
||||
else:
|
||||
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 32 bytes for curve P-256"
|
||||
) from None
|
||||
elif curve == "P-384":
|
||||
if len(x) == len(y) == 48:
|
||||
curve_obj = SECP384R1()
|
||||
else:
|
||||
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 48 bytes for curve P-384"
|
||||
) from None
|
||||
elif curve == "P-521":
|
||||
if len(x) == len(y) == 66:
|
||||
curve_obj = SECP521R1()
|
||||
else:
|
||||
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 66 bytes for curve P-521"
|
||||
) from None
|
||||
elif curve == "secp256k1":
|
||||
if len(x) == len(y) == 32:
|
||||
curve_obj = SECP256K1()
|
||||
@@ -771,16 +782,18 @@ if has_crypto:
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key: AllowedOKPKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
def to_jwk(
|
||||
key: AllowedOKPKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
|
||||
x = key.public_bytes(
|
||||
encoding=Encoding.Raw,
|
||||
@@ -836,7 +849,7 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
|
||||
if obj.get("kty") != "OKP":
|
||||
raise InvalidKeyError("Not an Octet Key Pair")
|
||||
|
||||
Reference in New Issue
Block a user