This commit is contained in:
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import binascii
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .algorithms import (
|
||||
@@ -12,7 +11,6 @@ from .algorithms import (
|
||||
has_crypto,
|
||||
requires_cryptography,
|
||||
)
|
||||
from .api_jwk import PyJWK
|
||||
from .exceptions import (
|
||||
DecodeError,
|
||||
InvalidAlgorithmError,
|
||||
@@ -31,7 +29,7 @@ class PyJWS:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithms: Sequence[str] | None = None,
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._algorithms = get_default_algorithms()
|
||||
@@ -105,8 +103,8 @@ class PyJWS:
|
||||
def encode(
|
||||
self,
|
||||
payload: bytes,
|
||||
key: AllowedPrivateKeys | PyJWK | str | bytes,
|
||||
algorithm: str | None = None,
|
||||
key: AllowedPrivateKeys | str | bytes,
|
||||
algorithm: str | None = "HS256",
|
||||
headers: dict[str, Any] | None = None,
|
||||
json_encoder: type[json.JSONEncoder] | None = None,
|
||||
is_payload_detached: bool = False,
|
||||
@@ -115,13 +113,7 @@ class PyJWS:
|
||||
segments = []
|
||||
|
||||
# declare a new var to narrow the type for type checkers
|
||||
if algorithm is None:
|
||||
if isinstance(key, PyJWK):
|
||||
algorithm_ = key.algorithm_name
|
||||
else:
|
||||
algorithm_ = "HS256"
|
||||
else:
|
||||
algorithm_ = algorithm
|
||||
algorithm_: str = algorithm if algorithm is not None else "none"
|
||||
|
||||
# Prefer headers values if present to function parameters.
|
||||
if headers:
|
||||
@@ -165,8 +157,6 @@ class PyJWS:
|
||||
signing_input = b".".join(segments)
|
||||
|
||||
alg_obj = self.get_algorithm_by_name(algorithm_)
|
||||
if isinstance(key, PyJWK):
|
||||
key = key.key
|
||||
key = alg_obj.prepare_key(key)
|
||||
signature = alg_obj.sign(signing_input, key)
|
||||
|
||||
@@ -182,8 +172,8 @@ class PyJWS:
|
||||
def decode_complete(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
detached_payload: bytes | None = None,
|
||||
**kwargs,
|
||||
@@ -194,14 +184,13 @@ class PyJWS:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if options is None:
|
||||
options = {}
|
||||
merged_options = {**self.options, **options}
|
||||
verify_signature = merged_options["verify_signature"]
|
||||
|
||||
if verify_signature and not algorithms and not isinstance(key, PyJWK):
|
||||
if verify_signature and not algorithms:
|
||||
raise DecodeError(
|
||||
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
|
||||
)
|
||||
@@ -228,8 +217,8 @@ class PyJWS:
|
||||
def decode(
|
||||
self,
|
||||
jwt: str | bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
detached_payload: bytes | None = None,
|
||||
**kwargs,
|
||||
@@ -240,7 +229,6 @@ class PyJWS:
|
||||
"and will be removed in pyjwt version 3. "
|
||||
f"Unsupported kwargs: {tuple(kwargs.keys())}",
|
||||
RemovedInPyjwt3Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
decoded = self.decode_complete(
|
||||
jwt, key, algorithms, options, detached_payload=detached_payload
|
||||
@@ -301,28 +289,22 @@ class PyJWS:
|
||||
signing_input: bytes,
|
||||
header: dict[str, Any],
|
||||
signature: bytes,
|
||||
key: AllowedPublicKeys | PyJWK | str | bytes = "",
|
||||
algorithms: Sequence[str] | None = None,
|
||||
key: AllowedPublicKeys | str | bytes = "",
|
||||
algorithms: list[str] | None = None,
|
||||
) -> None:
|
||||
if algorithms is None and isinstance(key, PyJWK):
|
||||
algorithms = [key.algorithm_name]
|
||||
try:
|
||||
alg = header["alg"]
|
||||
except KeyError:
|
||||
raise InvalidAlgorithmError("Algorithm not specified") from None
|
||||
raise InvalidAlgorithmError("Algorithm not specified")
|
||||
|
||||
if not alg or (algorithms is not None and alg not in algorithms):
|
||||
raise InvalidAlgorithmError("The specified alg value is not allowed")
|
||||
|
||||
if isinstance(key, PyJWK):
|
||||
alg_obj = key.Algorithm
|
||||
prepared_key = key.key
|
||||
else:
|
||||
try:
|
||||
alg_obj = self.get_algorithm_by_name(alg)
|
||||
except NotImplementedError as e:
|
||||
raise InvalidAlgorithmError("Algorithm not supported") from e
|
||||
prepared_key = alg_obj.prepare_key(key)
|
||||
try:
|
||||
alg_obj = self.get_algorithm_by_name(alg)
|
||||
except NotImplementedError as e:
|
||||
raise InvalidAlgorithmError("Algorithm not supported") from e
|
||||
prepared_key = alg_obj.prepare_key(key)
|
||||
|
||||
if not alg_obj.verify(signing_input, prepared_key, signature):
|
||||
raise InvalidSignatureError("Signature verification failed")
|
||||
|
||||
Reference in New Issue
Block a user