API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import warnings
from calendar import timegm
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
@@ -15,12 +15,15 @@ from .exceptions import (
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
from .api_jwk import PyJWK
class PyJWT:
@@ -38,14 +41,16 @@ class PyJWT:
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"require": [],
}
def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = None,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
@@ -100,8 +105,8 @@ class PyJWT:
def decode_complete(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: Sequence[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: bool | None = None,
@@ -110,7 +115,8 @@ class PyJWT:
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
issuer: str | Sequence[str] | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
@@ -121,6 +127,7 @@ class PyJWT:
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
stacklevel=2,
)
options = dict(options or {}) # shallow-copy or initialize an empty dict
options.setdefault("verify_signature", True)
@@ -134,6 +141,7 @@ class PyJWT:
"The equivalent is setting `verify_signature` to False in the `options` dictionary. "
"This invocation has a mismatch between the kwarg and the option entry.",
category=DeprecationWarning,
stacklevel=2,
)
if not options["verify_signature"]:
@@ -142,11 +150,8 @@ class PyJWT:
options.setdefault("verify_iat", False)
options.setdefault("verify_aud", False)
options.setdefault("verify_iss", False)
if options["verify_signature"] and not algorithms:
raise DecodeError(
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
options.setdefault("verify_sub", False)
options.setdefault("verify_jti", False)
decoded = api_jws.decode_complete(
jwt,
@@ -160,7 +165,12 @@ class PyJWT:
merged_options = {**self.options, **options}
self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
payload,
merged_options,
audience=audience,
issuer=issuer,
leeway=leeway,
subject=subject,
)
decoded["payload"] = payload
@@ -177,7 +187,7 @@ class PyJWT:
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
raise DecodeError(f"Invalid payload string: {e}") from e
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload
@@ -185,8 +195,8 @@ class PyJWT:
def decode(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: Sequence[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: bool | None = None,
@@ -195,7 +205,8 @@ class PyJWT:
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
subject: str | None = None,
issuer: str | Sequence[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
@@ -206,6 +217,7 @@ class PyJWT:
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
stacklevel=2,
)
decoded = self.decode_complete(
jwt,
@@ -215,6 +227,7 @@ class PyJWT:
verify=verify,
detached_payload=detached_payload,
audience=audience,
subject=subject,
issuer=issuer,
leeway=leeway,
)
@@ -226,6 +239,7 @@ class PyJWT:
options: dict[str, Any],
audience=None,
issuer=None,
subject: str | None = None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
@@ -255,6 +269,12 @@ class PyJWT:
payload, audience, strict=options.get("strict_aud", False)
)
if options["verify_sub"]:
self._validate_sub(payload, subject)
if options["verify_jti"]:
self._validate_jti(payload)
def _validate_required_claims(
self,
payload: dict[str, Any],
@@ -264,6 +284,39 @@ class PyJWT:
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)
def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
"""
Checks whether "sub" if in the payload is valid ot not.
This is an Optional claim
:param payload(dict): The payload which needs to be validated
:param subject(str): The subject of the token
"""
if "sub" not in payload:
return
if not isinstance(payload["sub"], str):
raise InvalidSubjectError("Subject must be a string")
if subject is not None:
if payload.get("sub") != subject:
raise InvalidSubjectError("Invalid subject")
def _validate_jti(self, payload: dict[str, Any]) -> None:
"""
Checks whether "jti" if in the payload is valid ot not
This is an Optional claim
:param payload(dict): The payload which needs to be validated
"""
if "jti" not in payload:
return
if not isinstance(payload.get("jti"), str):
raise InvalidJTIError("JWT ID must be a string")
def _validate_iat(
self,
payload: dict[str, Any],
@@ -273,7 +326,9 @@ class PyJWT:
try:
iat = int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
raise InvalidIssuedAtError(
"Issued At claim (iat) must be an integer."
) from None
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
@@ -286,7 +341,7 @@ class PyJWT:
try:
nbf = int(payload["nbf"])
except ValueError:
raise DecodeError("Not Before claim (nbf) must be an integer.")
raise DecodeError("Not Before claim (nbf) must be an integer.") from None
if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
@@ -300,7 +355,9 @@ class PyJWT:
try:
exp = int(payload["exp"])
except ValueError:
raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
raise DecodeError(
"Expiration Time claim (exp) must be an integer."
) from None
if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
@@ -362,8 +419,12 @@ class PyJWT:
if "iss" not in payload:
raise MissingRequiredClaimError("iss")
if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")
if isinstance(issuer, str):
if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")
else:
if payload["iss"] not in issuer:
raise InvalidIssuerError("Invalid issuer")
_jwt_global_obj = PyJWT()