This commit is contained in:
@@ -6,7 +6,7 @@ from .api_jws import (
|
||||
register_algorithm,
|
||||
unregister_algorithm,
|
||||
)
|
||||
from .api_jwt import PyJWT, decode, decode_complete, encode
|
||||
from .api_jwt import PyJWT, decode, encode
|
||||
from .exceptions import (
|
||||
DecodeError,
|
||||
ExpiredSignatureError,
|
||||
@@ -27,7 +27,7 @@ from .exceptions import (
|
||||
)
|
||||
from .jwks_client import PyJWKClient
|
||||
|
||||
__version__ = "2.10.1"
|
||||
__version__ = "2.8.0"
|
||||
|
||||
__title__ = "PyJWT"
|
||||
__description__ = "JSON Web Token implementation in Python"
|
||||
@@ -49,7 +49,6 @@ __all__ = [
|
||||
"PyJWK",
|
||||
"PyJWKSet",
|
||||
"decode",
|
||||
"decode_complete",
|
||||
"encode",
|
||||
"get_unverified_header",
|
||||
"register_algorithm",
|
||||
|
||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
|
||||
|
||||
from .exceptions import InvalidKeyError
|
||||
from .types import HashlibHash, JWKDict
|
||||
@@ -20,8 +21,14 @@ from .utils import (
|
||||
to_base64url_uint,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
try:
|
||||
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
@@ -187,16 +194,18 @@ class Algorithm(ABC):
|
||||
@overload
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover
|
||||
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover
|
||||
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
|
||||
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
"""
|
||||
Serializes a given key into a JWK
|
||||
"""
|
||||
@@ -265,18 +274,16 @@ class HMACAlgorithm(Algorithm):
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: str | bytes, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: str | bytes, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
|
||||
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
jwk = {
|
||||
"k": base64url_encode(force_bytes(key_obj)).decode(),
|
||||
"kty": "oct",
|
||||
@@ -297,7 +304,7 @@ class HMACAlgorithm(Algorithm):
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
|
||||
if obj.get("kty") != "oct":
|
||||
raise InvalidKeyError("Not an HMAC key")
|
||||
@@ -343,27 +350,22 @@ if has_crypto:
|
||||
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
|
||||
)
|
||||
except ValueError:
|
||||
try:
|
||||
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
|
||||
except (ValueError, UnsupportedAlgorithm):
|
||||
raise InvalidKeyError(
|
||||
"Could not parse the provided public key."
|
||||
) from None
|
||||
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedRSAKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
def to_jwk(
|
||||
key_obj: AllowedRSAKeys, as_dict: bool = False
|
||||
) -> Union[JWKDict, str]:
|
||||
obj: dict[str, Any] | None = None
|
||||
|
||||
if hasattr(key_obj, "private_numbers"):
|
||||
@@ -411,10 +413,10 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
|
||||
if obj.get("kty") != "RSA":
|
||||
raise InvalidKeyError("Not an RSA key") from None
|
||||
raise InvalidKeyError("Not an RSA key")
|
||||
|
||||
if "d" in obj and "e" in obj and "n" in obj:
|
||||
# Private key
|
||||
@@ -430,7 +432,7 @@ if has_crypto:
|
||||
if any_props_found and not all(props_found):
|
||||
raise InvalidKeyError(
|
||||
"RSA key must include all parameters if any are present besides d"
|
||||
) from None
|
||||
)
|
||||
|
||||
public_numbers = RSAPublicNumbers(
|
||||
from_base64url_uint(obj["e"]),
|
||||
@@ -522,7 +524,7 @@ if has_crypto:
|
||||
):
|
||||
raise InvalidKeyError(
|
||||
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
|
||||
) from None
|
||||
)
|
||||
|
||||
return crypto_key
|
||||
|
||||
@@ -531,7 +533,7 @@ if has_crypto:
|
||||
|
||||
return der_to_raw_signature(der_sig, key.curve)
|
||||
|
||||
def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
|
||||
def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
|
||||
try:
|
||||
der_sig = raw_to_der_signature(sig, key.curve)
|
||||
except ValueError:
|
||||
@@ -550,18 +552,18 @@ if has_crypto:
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
def to_jwk(
|
||||
key_obj: AllowedECKeys, as_dict: bool = False
|
||||
) -> Union[JWKDict, str]:
|
||||
if isinstance(key_obj, EllipticCurvePrivateKey):
|
||||
public_numbers = key_obj.public_key().public_numbers()
|
||||
elif isinstance(key_obj, EllipticCurvePublicKey):
|
||||
@@ -583,20 +585,13 @@ if has_crypto:
|
||||
obj: dict[str, Any] = {
|
||||
"kty": "EC",
|
||||
"crv": crv,
|
||||
"x": to_base64url_uint(
|
||||
public_numbers.x,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
).decode(),
|
||||
"y": to_base64url_uint(
|
||||
public_numbers.y,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
).decode(),
|
||||
"x": to_base64url_uint(public_numbers.x).decode(),
|
||||
"y": to_base64url_uint(public_numbers.y).decode(),
|
||||
}
|
||||
|
||||
if isinstance(key_obj, EllipticCurvePrivateKey):
|
||||
obj["d"] = to_base64url_uint(
|
||||
key_obj.private_numbers().private_value,
|
||||
bit_length=key_obj.curve.key_size,
|
||||
key_obj.private_numbers().private_value
|
||||
).decode()
|
||||
|
||||
if as_dict:
|
||||
@@ -614,13 +609,13 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
|
||||
if obj.get("kty") != "EC":
|
||||
raise InvalidKeyError("Not an Elliptic curve key") from None
|
||||
raise InvalidKeyError("Not an Elliptic curve key")
|
||||
|
||||
if "x" not in obj or "y" not in obj:
|
||||
raise InvalidKeyError("Not an Elliptic curve key") from None
|
||||
raise InvalidKeyError("Not an Elliptic curve key")
|
||||
|
||||
x = base64url_decode(obj.get("x"))
|
||||
y = base64url_decode(obj.get("y"))
|
||||
@@ -632,23 +627,17 @@ if has_crypto:
|
||||
if len(x) == len(y) == 32:
|
||||
curve_obj = SECP256R1()
|
||||
else:
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 32 bytes for curve P-256"
|
||||
) from None
|
||||
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
|
||||
elif curve == "P-384":
|
||||
if len(x) == len(y) == 48:
|
||||
curve_obj = SECP384R1()
|
||||
else:
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 48 bytes for curve P-384"
|
||||
) from None
|
||||
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
|
||||
elif curve == "P-521":
|
||||
if len(x) == len(y) == 66:
|
||||
curve_obj = SECP521R1()
|
||||
else:
|
||||
raise InvalidKeyError(
|
||||
"Coords should be 66 bytes for curve P-521"
|
||||
) from None
|
||||
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
|
||||
elif curve == "secp256k1":
|
||||
if len(x) == len(y) == 32:
|
||||
curve_obj = SECP256K1()
|
||||
@@ -782,18 +771,16 @@ if has_crypto:
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key: AllowedOKPKeys, as_dict: Literal[True]
|
||||
) -> JWKDict: ... # pragma: no cover
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
|
||||
... # pragma: no cover
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_jwk(
|
||||
key: AllowedOKPKeys, as_dict: Literal[False] = False
|
||||
) -> str: ... # pragma: no cover
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
|
||||
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
|
||||
x = key.public_bytes(
|
||||
encoding=Encoding.Raw,
|
||||
@@ -849,7 +836,7 @@ if has_crypto:
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidKeyError("Key is not valid JSON") from None
|
||||
raise InvalidKeyError("Key is not valid JSON")
|
||||
|
||||
if obj.get("kty") != "OKP":
|
||||
raise InvalidKeyError("Not an Octet Key Pair")
|
||||
|
||||
@@ -5,13 +5,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
|
||||
from .exceptions import (
|
||||
InvalidKeyError,
|
||||
MissingCryptographyError,
|
||||
PyJWKError,
|
||||
PyJWKSetError,
|
||||
PyJWTError,
|
||||
)
|
||||
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
|
||||
from .types import JWKDict
|
||||
|
||||
|
||||
@@ -56,25 +50,21 @@ class PyJWK:
|
||||
raise InvalidKeyError(f"Unsupported kty: {kty}")
|
||||
|
||||
if not has_crypto and algorithm in requires_cryptography:
|
||||
raise MissingCryptographyError(
|
||||
f"{algorithm} requires 'cryptography' to be installed."
|
||||
)
|
||||
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
|
||||
|
||||
self.algorithm_name = algorithm
|
||||
self.Algorithm = self._algorithms.get(algorithm)
|
||||
|
||||
if algorithm in self._algorithms:
|
||||
self.Algorithm = self._algorithms[algorithm]
|
||||
else:
|
||||
if not self.Algorithm:
|
||||
raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
|
||||
|
||||
self.key = self.Algorithm.from_jwk(self._jwk_data)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK:
|
||||
def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
|
||||
return PyJWK(obj, algorithm)
|
||||
|
||||
@staticmethod
|
||||
def from_json(data: str, algorithm: None = None) -> PyJWK:
|
||||
def from_json(data: str, algorithm: None = None) -> "PyJWK":
|
||||
obj = json.loads(data)
|
||||
return PyJWK.from_dict(obj, algorithm)
|
||||
|
||||
@@ -104,9 +94,7 @@ class PyJWKSet:
|
||||
for key in keys:
|
||||
try:
|
||||
self.keys.append(PyJWK(key))
|
||||
except PyJWTError as error:
|
||||
if isinstance(error, MissingCryptographyError):
|
||||
raise error
|
||||
except PyJWTError:
|
||||
# skip unusable keys
|
||||
continue
|
||||
|
||||
@@ -116,16 +104,16 @@ class PyJWKSet:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict[str, Any]) -> PyJWKSet:
|
||||
def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
|
||||
keys = obj.get("keys", [])
|
||||
return PyJWKSet(keys)
|
||||
|
||||
@staticmethod
|
||||
def from_json(data: str) -> PyJWKSet:
|
||||
def from_json(data: str) -> "PyJWKSet":
|
||||
obj = json.loads(data)
|
||||
return PyJWKSet.from_dict(obj)
|
||||
|
||||
def __getitem__(self, kid: str) -> PyJWK:
|
||||
def __getitem__(self, kid: str) -> "PyJWK":
|
||||
for key in self.keys:
|
||||
if key.key_id == kid:
|
||||
return key
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import binascii
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .algorithms import (
|
||||
@@ -12,7 +11,6 @@ from .algorithms import (
|
||||
has_crypto,
|
||||
requires_cryptography,
|
||||
)
|
||||
from .api_jwk import PyJWK
|
||||
from .exceptions import (
|
||||
DecodeError,
|
||||
InvalidAlgorithmError,
|
||||
@@ -31,7 +29,7 @@ class PyJWS:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithms: Sequence[str] | None = None,
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._algorithms = get_default_algorithms()
|
||||
@@ -105,8 +103,8 @@ class PyJWS:
|
||||
def encode(
|
||||
self,
|
||||
payload: bytes,
|
||||
key: AllowedPrivateKeys | PyJWK | str | bytes,
|
||||
algorithm: str | None = None,
|
||||
key: AllowedPrivateKeys | str | bytes,
|
||||
algorithm: str | None = "HS256",
|
||||
headers: dict[str, Any] | None = None,
|
||||
json_encoder: type[json.JSONEncoder] | None = None,
|
||||
is_payload_detached: bool = False,
|
||||
@@ -115,13 +113,7 @@ class PyJWS:
|
||||
segments = []
|
||||
|
||||
# declare a new var to narrow the type for type checkers
|
||||
if algorithm is None:
|
||||
if isinstance(key, PyJWK):
|
||||
algorithm_ = key.algorithm_name
|
||||
else:
|
||||
algorithm_ = "HS256"
|
||||
else:
|
||||
algorithm_ = algorithm
|
||||
algorithm_: str = algorithm if algorithm is not None else "none"
|
||||
|
||||
# Prefer headers values if present to function parameters.
|
||||
if headers:
|
||||
@@ -165,8 +157,6 @@ class PyJWS:
|
||||
signing_input = b".".join(segments)
|
||||
|
||||
alg_obj = self.get_algorithm_by_name(algorithm_)
|
||||
if isinstance(key, PyJWK):
|
||||
key = key.key
|
||||
key = alg_obj.prepare_key(key)
|
||||
signature = alg_obj.sign(signing_input, key)
|
||||
|
||||
@@ -182,8 +172,8 @@ class PyJWS:
|
||||
def decode_complete(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
detached_payload: bytes | None = None,
|
||||
**kwargs,
|
||||
@@ -194,14 +184,13 @@ class PyJWS:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if options is None:
|
||||
options = {}
|
||||
merged_options = {**self.options, **options}
|
||||
verify_signature = merged_options["verify_signature"]
|
||||
|
||||
if verify_signature and not algorithms and not isinstance(key, PyJWK):
|
||||
if verify_signature and not algorithms:
|
||||
raise DecodeError(
|
||||
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
|
||||
)
|
||||
@@ -228,8 +217,8 @@ class PyJWS:
|
||||
def decode(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
detached_payload: bytes | None = None,
|
||||
**kwargs,
|
||||
@@ -240,7 +229,6 @@ class PyJWS:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
decoded = self.decode_complete(
|
||||
jwt, key, algorithms, options, detached_payload=detached_payload
|
||||
@@ -301,28 +289,22 @@ class PyJWS:
|
||||
signing_input: bytes,
|
||||
header: dict[str, Any],
|
||||
signature: bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
) -> None:
|
||||
if algorithms is None and isinstance(key, PyJWK):
|
||||
algorithms = [key.algorithm_name]
|
||||
try:
|
||||
alg = header["alg"]
|
||||
except KeyError:
|
||||
raise InvalidAlgorithmError("Algorithm not specified") from None
|
||||
raise InvalidAlgorithmError("Algorithm not specified")
|
||||
|
||||
if not alg or (algorithms is not None and alg not in algorithms):
|
||||
raise InvalidAlgorithmError("The specified alg value is not allowed")
|
||||
|
||||
if isinstance(key, PyJWK):
|
||||
alg_obj = key.Algorithm
|
||||
prepared_key = key.key
|
||||
else:
|
||||
try:
|
||||
alg_obj = self.get_algorithm_by_name(alg)
|
||||
except NotImplementedError as e:
|
||||
raise InvalidAlgorithmError("Algorithm not supported") from e
|
||||
prepared_key = alg_obj.prepare_key(key)
|
||||
try:
|
||||
alg_obj = self.get_algorithm_by_name(alg)
|
||||
except NotImplementedError as e:
|
||||
raise InvalidAlgorithmError("Algorithm not supported") from e
|
||||
prepared_key = alg_obj.prepare_key(key)
|
||||
|
||||
if not alg_obj.verify(signing_input, prepared_key, signature):
|
||||
raise InvalidSignatureError("Signature verification failed")
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import warnings
|
||||
from calendar import timegm
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -15,15 +15,12 @@ from .exceptions import (
|
||||
InvalidAudienceError,
|
||||
InvalidIssuedAtError,
|
||||
InvalidIssuerError,
|
||||
InvalidJTIError,
|
||||
InvalidSubjectError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from .warnings import RemovedInPyjwt3Warning
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
|
||||
from .api_jwk import PyJWK
|
||||
|
||||
|
||||
class PyJWT:
|
||||
@@ -41,16 +38,14 @@ class PyJWT:
|
||||
"verify_iat": True,
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_sub": True,
|
||||
"verify_jti": True,
|
||||
"require": [],
|
||||
}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
key: AllowedPrivateKeys | PyJWK | str | bytes,
|
||||
algorithm: str | None = None,
|
||||
key: AllowedPrivateKeys | str | bytes,
|
||||
algorithm: str | None = "HS256",
|
||||
headers: dict[str, Any] | None = None,
|
||||
json_encoder: type[json.JSONEncoder] | None = None,
|
||||
sort_headers: bool = True,
|
||||
@@ -105,8 +100,8 @@ class PyJWT:
|
||||
def decode_complete(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
# deprecated arg, remove in pyjwt3
|
||||
verify: bool | None = None,
|
||||
@@ -115,8 +110,7 @@ class PyJWT:
|
||||
# passthrough arguments to _validate_claims
|
||||
# consider putting in options
|
||||
audience: str | Iterable[str] | None = None,
|
||||
issuer: str | Sequence[str] | None = None,
|
||||
subject: str | None = None,
|
||||
issuer: str | None = None,
|
||||
leeway: float | timedelta = 0,
|
||||
# kwargs
|
||||
**kwargs: Any,
|
||||
@@ -127,7 +121,6 @@ class PyJWT:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
options = dict(options or {}) # shallow-copy or initialize an empty dict
|
||||
options.setdefault("verify_signature", True)
|
||||
@@ -141,7 +134,6 @@ class PyJWT:
|
||||
"The equivalent is setting `verify_signature` to False in the `options` dictionary. "
|
||||
"This invocation has a mismatch between the kwarg and the option entry.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not options["verify_signature"]:
|
||||
@@ -150,8 +142,11 @@ class PyJWT:
|
||||
options.setdefault("verify_iat", False)
|
||||
options.setdefault("verify_aud", False)
|
||||
options.setdefault("verify_iss", False)
|
||||
options.setdefault("verify_sub", False)
|
||||
options.setdefault("verify_jti", False)
|
||||
|
||||
if options["verify_signature"] and not algorithms:
|
||||
raise DecodeError(
|
||||
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
|
||||
)
|
||||
|
||||
decoded = api_jws.decode_complete(
|
||||
jwt,
|
||||
@@ -165,12 +160,7 @@ class PyJWT:
|
||||
|
||||
merged_options = {**self.options, **options}
|
||||
self._validate_claims(
|
||||
payload,
|
||||
merged_options,
|
||||
audience=audience,
|
||||
issuer=issuer,
|
||||
leeway=leeway,
|
||||
subject=subject,
|
||||
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
|
||||
)
|
||||
|
||||
decoded["payload"] = payload
|
||||
@@ -187,7 +177,7 @@ class PyJWT:
|
||||
try:
|
||||
payload = json.loads(decoded["payload"])
|
||||
except ValueError as e:
|
||||
raise DecodeError(f"Invalid payload string: {e}") from e
|
||||
raise DecodeError(f"Invalid payload string: {e}")
|
||||
if not isinstance(payload, dict):
|
||||
raise DecodeError("Invalid payload string: must be a json object")
|
||||
return payload
|
||||
@@ -195,8 +185,8 @@ class PyJWT:
|
||||
def decode(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
# deprecated arg, remove in pyjwt3
|
||||
verify: bool | None = None,
|
||||
@@ -205,8 +195,7 @@ class PyJWT:
|
||||
# passthrough arguments to _validate_claims
|
||||
# consider putting in options
|
||||
audience: str | Iterable[str] | None = None,
|
||||
subject: str | None = None,
|
||||
issuer: str | Sequence[str] | None = None,
|
||||
issuer: str | None = None,
|
||||
leeway: float | timedelta = 0,
|
||||
# kwargs
|
||||
**kwargs: Any,
|
||||
@@ -217,7 +206,6 @@ class PyJWT:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
decoded = self.decode_complete(
|
||||
jwt,
|
||||
@@ -227,7 +215,6 @@ class PyJWT:
|
||||
verify=verify,
|
||||
detached_payload=detached_payload,
|
||||
audience=audience,
|
||||
subject=subject,
|
||||
issuer=issuer,
|
||||
leeway=leeway,
|
||||
)
|
||||
@@ -239,7 +226,6 @@ class PyJWT:
|
||||
options: dict[str, Any],
|
||||
audience=None,
|
||||
issuer=None,
|
||||
subject: str | None = None,
|
||||
leeway: float | timedelta = 0,
|
||||
) -> None:
|
||||
if isinstance(leeway, timedelta):
|
||||
@@ -269,12 +255,6 @@ class PyJWT:
|
||||
payload, audience, strict=options.get("strict_aud", False)
|
||||
)
|
||||
|
||||
if options["verify_sub"]:
|
||||
self._validate_sub(payload, subject)
|
||||
|
||||
if options["verify_jti"]:
|
||||
self._validate_jti(payload)
|
||||
|
||||
def _validate_required_claims(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
@@ -284,39 +264,6 @@ class PyJWT:
|
||||
if payload.get(claim) is None:
|
||||
raise MissingRequiredClaimError(claim)
|
||||
|
||||
def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
|
||||
"""
|
||||
Checks whether "sub" if in the payload is valid ot not.
|
||||
This is an Optional claim
|
||||
|
||||
:param payload(dict): The payload which needs to be validated
|
||||
:param subject(str): The subject of the token
|
||||
"""
|
||||
|
||||
if "sub" not in payload:
|
||||
return
|
||||
|
||||
if not isinstance(payload["sub"], str):
|
||||
raise InvalidSubjectError("Subject must be a string")
|
||||
|
||||
if subject is not None:
|
||||
if payload.get("sub") != subject:
|
||||
raise InvalidSubjectError("Invalid subject")
|
||||
|
||||
def _validate_jti(self, payload: dict[str, Any]) -> None:
|
||||
"""
|
||||
Checks whether "jti" if in the payload is valid ot not
|
||||
This is an Optional claim
|
||||
|
||||
:param payload(dict): The payload which needs to be validated
|
||||
"""
|
||||
|
||||
if "jti" not in payload:
|
||||
return
|
||||
|
||||
if not isinstance(payload.get("jti"), str):
|
||||
raise InvalidJTIError("JWT ID must be a string")
|
||||
|
||||
def _validate_iat(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
@@ -326,9 +273,7 @@ class PyJWT:
|
||||
try:
|
||||
iat = int(payload["iat"])
|
||||
except ValueError:
|
||||
raise InvalidIssuedAtError(
|
||||
"Issued At claim (iat) must be an integer."
|
||||
) from None
|
||||
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
|
||||
if iat > (now + leeway):
|
||||
raise ImmatureSignatureError("The token is not yet valid (iat)")
|
||||
|
||||
@@ -341,7 +286,7 @@ class PyJWT:
|
||||
try:
|
||||
nbf = int(payload["nbf"])
|
||||
except ValueError:
|
||||
raise DecodeError("Not Before claim (nbf) must be an integer.") from None
|
||||
raise DecodeError("Not Before claim (nbf) must be an integer.")
|
||||
|
||||
if nbf > (now + leeway):
|
||||
raise ImmatureSignatureError("The token is not yet valid (nbf)")
|
||||
@@ -355,9 +300,7 @@ class PyJWT:
|
||||
try:
|
||||
exp = int(payload["exp"])
|
||||
except ValueError:
|
||||
raise DecodeError(
|
||||
"Expiration Time claim (exp) must be an integer."
|
||||
) from None
|
||||
raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
|
||||
|
||||
if exp <= (now - leeway):
|
||||
raise ExpiredSignatureError("Signature has expired")
|
||||
@@ -419,12 +362,8 @@ class PyJWT:
|
||||
if "iss" not in payload:
|
||||
raise MissingRequiredClaimError("iss")
|
||||
|
||||
if isinstance(issuer, str):
|
||||
if payload["iss"] != issuer:
|
||||
raise InvalidIssuerError("Invalid issuer")
|
||||
else:
|
||||
if payload["iss"] not in issuer:
|
||||
raise InvalidIssuerError("Invalid issuer")
|
||||
if payload["iss"] != issuer:
|
||||
raise InvalidIssuerError("Invalid issuer")
|
||||
|
||||
|
||||
_jwt_global_obj = PyJWT()
|
||||
|
||||
@@ -58,10 +58,6 @@ class PyJWKError(PyJWTError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingCryptographyError(PyJWKError):
|
||||
pass
|
||||
|
||||
|
||||
class PyJWKSetError(PyJWTError):
|
||||
pass
|
||||
|
||||
@@ -72,11 +68,3 @@ class PyJWKClientError(PyJWTError):
|
||||
|
||||
class PyJWKClientConnectionError(PyJWKClientError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidSubjectError(InvalidTokenError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidJTIError(InvalidTokenError):
|
||||
pass
|
||||
|
||||
@@ -39,10 +39,7 @@ def info() -> Dict[str, Dict[str, str]]:
|
||||
)
|
||||
if pypy_version_info.releaselevel != "final":
|
||||
implementation_version = "".join(
|
||||
[
|
||||
implementation_version,
|
||||
pypy_version_info.releaselevel,
|
||||
]
|
||||
[implementation_version, pypy_version_info.releaselevel]
|
||||
)
|
||||
else:
|
||||
implementation_version = "Unknown"
|
||||
|
||||
@@ -45,9 +45,7 @@ class PyJWKClient:
|
||||
if cache_keys:
|
||||
# Cache signing keys
|
||||
# Ignore mypy (https://github.com/python/mypy/issues/2427)
|
||||
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(
|
||||
self.get_signing_key
|
||||
) # type: ignore
|
||||
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
|
||||
|
||||
def fetch_data(self) -> Any:
|
||||
jwk_set: Any = None
|
||||
@@ -60,7 +58,7 @@ class PyJWKClient:
|
||||
except (URLError, TimeoutError) as e:
|
||||
raise PyJWKClientConnectionError(
|
||||
f'Fail to fetch data from the url, err: "{e}"'
|
||||
) from e
|
||||
)
|
||||
else:
|
||||
return jwk_set
|
||||
finally:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import binascii
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
|
||||
@@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes:
|
||||
return base64.urlsafe_b64encode(input).replace(b"=", b"")
|
||||
|
||||
|
||||
def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes:
|
||||
def to_base64url_uint(val: int) -> bytes:
|
||||
if val < 0:
|
||||
raise ValueError("Must be a positive integer")
|
||||
|
||||
int_bytes = bytes_from_int(val, bit_length=bit_length)
|
||||
int_bytes = bytes_from_int(val)
|
||||
|
||||
if len(int_bytes) == 0:
|
||||
int_bytes = b"\x00"
|
||||
@@ -63,10 +63,13 @@ def bytes_to_number(string: bytes) -> int:
|
||||
return int(binascii.b2a_hex(string), 16)
|
||||
|
||||
|
||||
def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes:
|
||||
if bit_length is None:
|
||||
bit_length = val.bit_length()
|
||||
byte_length = (bit_length + 7) // 8
|
||||
def bytes_from_int(val: int) -> bytes:
|
||||
remaining = val
|
||||
byte_length = 0
|
||||
|
||||
while remaining != 0:
|
||||
remaining >>= 8
|
||||
byte_length += 1
|
||||
|
||||
return val.to_bytes(byte_length, "big", signed=False)
|
||||
|
||||
@@ -128,15 +131,26 @@ def is_pem_format(key: bytes) -> bool:
|
||||
|
||||
|
||||
# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
|
||||
_SSH_KEY_FORMATS = (
|
||||
_CERT_SUFFIX = b"-cert-v01@openssh.com"
|
||||
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
|
||||
_SSH_KEY_FORMATS = [
|
||||
b"ssh-ed25519",
|
||||
b"ssh-rsa",
|
||||
b"ssh-dss",
|
||||
b"ecdsa-sha2-nistp256",
|
||||
b"ecdsa-sha2-nistp384",
|
||||
b"ecdsa-sha2-nistp521",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def is_ssh_key(key: bytes) -> bool:
|
||||
return key.startswith(_SSH_KEY_FORMATS)
|
||||
if any(string_value in key for string_value in _SSH_KEY_FORMATS):
|
||||
return True
|
||||
|
||||
ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
|
||||
if ssh_pubkey_match:
|
||||
key_type = ssh_pubkey_match.group(1)
|
||||
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user