This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = "3.3.0"
|
||||
__version__ = "3.5.0"
|
||||
__author__ = "Michael Davis"
|
||||
__license__ = "MIT"
|
||||
__copyright__ = "Copyright 2016 Michael Davis"
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
try:
|
||||
from jose.backends.cryptography_backend import get_random_bytes # noqa: F401
|
||||
except ImportError:
|
||||
try:
|
||||
from jose.backends.pycrypto_backend import get_random_bytes # noqa: F401
|
||||
except ImportError:
|
||||
from jose.backends.native import get_random_bytes # noqa: F401
|
||||
from jose.backends.native import get_random_bytes # noqa: F401
|
||||
|
||||
try:
|
||||
from jose.backends.cryptography_backend import CryptographyRSAKey as RSAKey # noqa: F401
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
Required by rsa_backend but not cryptography_backend.
|
||||
"""
|
||||
|
||||
from pyasn1.codec.der import decoder, encoder
|
||||
from pyasn1.type import namedtype, univ
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import warnings
|
||||
|
||||
from cryptography.exceptions import InvalidSignature, InvalidTag
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.bindings.openssl.binding import Binding
|
||||
from cryptography.hazmat.primitives import hashes, hmac, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
|
||||
@@ -16,35 +15,21 @@ from cryptography.x509 import load_pem_x509_certificate
|
||||
|
||||
from ..constants import ALGORITHMS
|
||||
from ..exceptions import JWEError, JWKError
|
||||
from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64
|
||||
from ..utils import (
|
||||
base64_to_long,
|
||||
base64url_decode,
|
||||
base64url_encode,
|
||||
ensure_binary,
|
||||
is_pem_format,
|
||||
is_ssh_key,
|
||||
long_to_base64,
|
||||
)
|
||||
from . import get_random_bytes
|
||||
from .base import Key
|
||||
|
||||
_binding = None
|
||||
|
||||
|
||||
def get_random_bytes(num_bytes):
|
||||
"""
|
||||
Get random bytes
|
||||
|
||||
Currently, Cryptography returns OS random bytes. If you want OpenSSL
|
||||
generated random bytes, you'll have to switch the RAND engine after
|
||||
initializing the OpenSSL backend
|
||||
Args:
|
||||
num_bytes (int): Number of random bytes to generate and return
|
||||
Returns:
|
||||
bytes: Random bytes
|
||||
"""
|
||||
global _binding
|
||||
|
||||
if _binding is None:
|
||||
_binding = Binding()
|
||||
|
||||
buf = _binding.ffi.new("char[]", num_bytes)
|
||||
_binding.lib.RAND_bytes(buf, num_bytes)
|
||||
rand_bytes = _binding.ffi.buffer(buf, num_bytes)[:]
|
||||
return rand_bytes
|
||||
|
||||
|
||||
class CryptographyECKey(Key):
|
||||
SHA256 = hashes.SHA256
|
||||
SHA384 = hashes.SHA384
|
||||
@@ -243,8 +228,8 @@ class CryptographyRSAKey(Key):
|
||||
|
||||
self.cryptography_backend = cryptography_backend
|
||||
|
||||
# if it conforms to RSAPublicKey interface
|
||||
if hasattr(key, "public_bytes") and hasattr(key, "public_numbers"):
|
||||
# if it conforms to RSAPublicKey or RSAPrivateKey interface
|
||||
if (hasattr(key, "public_bytes") and hasattr(key, "public_numbers")) or hasattr(key, "private_bytes"):
|
||||
self.prepared_key = key
|
||||
return
|
||||
|
||||
@@ -439,6 +424,8 @@ class CryptographyAESKey(Key):
|
||||
ALGORITHMS.A256KW: None,
|
||||
}
|
||||
|
||||
IV_BYTE_LENGTH_MODE_MAP = {"CBC": algorithms.AES.block_size // 8, "GCM": 96 // 8}
|
||||
|
||||
def __init__(self, key, algorithm):
|
||||
if algorithm not in ALGORITHMS.AES:
|
||||
raise JWKError("%s is not a valid AES algorithm" % algorithm)
|
||||
@@ -468,7 +455,8 @@ class CryptographyAESKey(Key):
|
||||
def encrypt(self, plain_text, aad=None):
|
||||
plain_text = ensure_binary(plain_text)
|
||||
try:
|
||||
iv = get_random_bytes(algorithms.AES.block_size // 8)
|
||||
iv_byte_length = self.IV_BYTE_LENGTH_MODE_MAP.get(self._mode.name, algorithms.AES.block_size)
|
||||
iv = get_random_bytes(iv_byte_length)
|
||||
mode = self._mode(iv)
|
||||
if mode.name == "GCM":
|
||||
cipher = aead.AESGCM(self._key)
|
||||
@@ -552,14 +540,7 @@ class CryptographyHMACKey(Key):
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
|
||||
invalid_strings = [
|
||||
b"-----BEGIN PUBLIC KEY-----",
|
||||
b"-----BEGIN RSA PUBLIC KEY-----",
|
||||
b"-----BEGIN CERTIFICATE-----",
|
||||
b"ssh-rsa",
|
||||
]
|
||||
|
||||
if any(string_value in key for string_value in invalid_strings):
|
||||
if is_pem_format(key) or is_ssh_key(key):
|
||||
raise JWKError(
|
||||
"The specified key is an asymmetric key or x509 certificate and"
|
||||
" should not be used as an HMAC secret."
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from jose.backends.base import Key
|
||||
from jose.constants import ALGORITHMS
|
||||
from jose.exceptions import JWKError
|
||||
from jose.utils import base64url_decode, base64url_encode
|
||||
from jose.utils import base64url_decode, base64url_encode, is_pem_format, is_ssh_key
|
||||
|
||||
|
||||
def get_random_bytes(num_bytes):
|
||||
@@ -36,14 +36,7 @@ class HMACKey(Key):
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
|
||||
invalid_strings = [
|
||||
b"-----BEGIN PUBLIC KEY-----",
|
||||
b"-----BEGIN RSA PUBLIC KEY-----",
|
||||
b"-----BEGIN CERTIFICATE-----",
|
||||
b"ssh-rsa",
|
||||
]
|
||||
|
||||
if any(string_value in key for string_value in invalid_strings):
|
||||
if is_pem_format(key) or is_ssh_key(key):
|
||||
raise JWKError(
|
||||
"The specified key is an asymmetric key or x509 certificate and"
|
||||
" should not be used as an HMAC secret."
|
||||
|
||||
@@ -221,7 +221,6 @@ class RSAKey(Key):
|
||||
return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm)
|
||||
|
||||
def to_pem(self, pem_format="PKCS8"):
|
||||
|
||||
if isinstance(self._prepared_key, pyrsa.PrivateKey):
|
||||
der = self._prepared_key.save_pkcs1(format="DER")
|
||||
if pem_format == "PKCS8":
|
||||
|
||||
@@ -96,3 +96,5 @@ class Zips:
|
||||
|
||||
|
||||
ZIPS = Zips()
|
||||
|
||||
JWE_SIZE_LIMIT = 250 * 1024
|
||||
|
||||
@@ -6,13 +6,13 @@ from struct import pack
|
||||
|
||||
from . import jwk
|
||||
from .backends import get_random_bytes
|
||||
from .constants import ALGORITHMS, ZIPS
|
||||
from .constants import ALGORITHMS, JWE_SIZE_LIMIT, ZIPS
|
||||
from .exceptions import JWEError, JWEParseError
|
||||
from .utils import base64url_decode, base64url_encode, ensure_binary
|
||||
|
||||
|
||||
def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.DIR, zip=None, cty=None, kid=None):
|
||||
"""Encrypts plaintext and returns a JWE cmpact serialization string.
|
||||
"""Encrypts plaintext and returns a JWE compact serialization string.
|
||||
|
||||
Args:
|
||||
plaintext (bytes): A bytes object to encrypt
|
||||
@@ -76,6 +76,13 @@ def decrypt(jwe_str, key):
|
||||
>>> jwe.decrypt(jwe_string, 'asecret128bitkey')
|
||||
'Hello, World!'
|
||||
"""
|
||||
|
||||
# Limit the token size - if the data is compressed then decompressing the
|
||||
# data could lead to large memory usage. This helps address This addresses
|
||||
# CVE-2024-33664. Also see _decompress()
|
||||
if len(jwe_str) > JWE_SIZE_LIMIT:
|
||||
raise JWEError(f"JWE string {len(jwe_str)} bytes exceeds {JWE_SIZE_LIMIT} bytes")
|
||||
|
||||
header, encoded_header, encrypted_key, iv, cipher_text, auth_tag = _jwe_compact_deserialize(jwe_str)
|
||||
|
||||
# Verify that the implementation understands and can process all
|
||||
@@ -424,13 +431,13 @@ def _compress(zip, plaintext):
|
||||
(bytes): Compressed plaintext
|
||||
"""
|
||||
if zip not in ZIPS.SUPPORTED:
|
||||
raise NotImplementedError("ZIP {} is not supported!")
|
||||
raise NotImplementedError(f"ZIP {zip} is not supported!")
|
||||
if zip is None:
|
||||
compressed = plaintext
|
||||
elif zip == ZIPS.DEF:
|
||||
compressed = zlib.compress(plaintext)
|
||||
else:
|
||||
raise NotImplementedError("ZIP {} is not implemented!")
|
||||
raise NotImplementedError(f"ZIP {zip} is not implemented!")
|
||||
return compressed
|
||||
|
||||
|
||||
@@ -446,13 +453,18 @@ def _decompress(zip, compressed):
|
||||
(bytes): Compressed plaintext
|
||||
"""
|
||||
if zip not in ZIPS.SUPPORTED:
|
||||
raise NotImplementedError("ZIP {} is not supported!")
|
||||
raise NotImplementedError(f"ZIP {zip} is not supported!")
|
||||
if zip is None:
|
||||
decompressed = compressed
|
||||
elif zip == ZIPS.DEF:
|
||||
decompressed = zlib.decompress(compressed)
|
||||
# If, during decompression, there is more data than expected, the
|
||||
# decompression halts and raise an error. This addresses CVE-2024-33664
|
||||
decompressor = zlib.decompressobj()
|
||||
decompressed = decompressor.decompress(compressed, max_length=JWE_SIZE_LIMIT)
|
||||
if decompressor.unconsumed_tail:
|
||||
raise JWEError(f"Decompressed JWE string exceeds {JWE_SIZE_LIMIT} bytes")
|
||||
else:
|
||||
raise NotImplementedError("ZIP {} is not implemented!")
|
||||
raise NotImplementedError(f"ZIP {zip} is not implemented!")
|
||||
return decompressed
|
||||
|
||||
|
||||
@@ -530,7 +542,7 @@ def _get_key_wrap_cek(enc, key):
|
||||
|
||||
def _get_random_cek_bytes_for_enc(enc):
|
||||
"""
|
||||
Get the random cek bytes based on the encryptionn algorithm
|
||||
Get the random cek bytes based on the encryption algorithm
|
||||
|
||||
Args:
|
||||
enc (str): Encryption algorithm
|
||||
|
||||
@@ -71,9 +71,9 @@ def construct(key_data, algorithm=None):
|
||||
algorithm = key_data.get("alg", None)
|
||||
|
||||
if not algorithm:
|
||||
raise JWKError("Unable to find an algorithm for key: %s" % key_data)
|
||||
raise JWKError("Unable to find an algorithm for key")
|
||||
|
||||
key_class = get_key(algorithm)
|
||||
if not key_class:
|
||||
raise JWKError("Unable to find an algorithm for key: %s" % key_data)
|
||||
raise JWKError("Unable to find an algorithm for key")
|
||||
return key_class(key_data, algorithm)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import binascii
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
try:
|
||||
from collections.abc import Iterable, Mapping
|
||||
except ImportError:
|
||||
from collections import Mapping, Iterable
|
||||
|
||||
from jose import jwk
|
||||
from jose.backends.base import Key
|
||||
@@ -215,7 +219,6 @@ def _sig_matches_keys(keys, signing_input, signature, alg):
|
||||
|
||||
|
||||
def _get_keys(key):
|
||||
|
||||
if isinstance(key, Key):
|
||||
return (key,)
|
||||
|
||||
@@ -248,7 +251,6 @@ def _get_keys(key):
|
||||
|
||||
|
||||
def _verify_signature(signing_input, header, signature, key="", algorithms=None):
|
||||
|
||||
alg = header.get("alg")
|
||||
if not alg:
|
||||
raise JWSError("No algorithm was specified in the JWS header.")
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
import json
|
||||
from calendar import timegm
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
from collections.abc import Mapping
|
||||
except ImportError:
|
||||
from collections import Mapping
|
||||
|
||||
try:
|
||||
from datetime import UTC # Preferred in Python 3.13+
|
||||
except ImportError:
|
||||
from datetime import timezone
|
||||
|
||||
UTC = timezone.utc # Preferred in Python 3.12 and below
|
||||
|
||||
from jose import jws
|
||||
|
||||
from .constants import ALGORITHMS
|
||||
@@ -42,7 +53,6 @@ def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=N
|
||||
"""
|
||||
|
||||
for time_claim in ["exp", "iat", "nbf"]:
|
||||
|
||||
# Convert datetime to a intDate value in known time-format claims
|
||||
if isinstance(claims.get(time_claim), datetime):
|
||||
claims[time_claim] = timegm(claims[time_claim].utctimetuple())
|
||||
@@ -58,8 +68,15 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
|
||||
|
||||
Args:
|
||||
token (str): A signed JWS to be verified.
|
||||
key (str or dict): A key to attempt to verify the payload with. Can be
|
||||
individual JWK or JWK set.
|
||||
key (str or iterable): A key to attempt to verify the payload with.
|
||||
This can be simple string with an individual key (e.g. "a1234"),
|
||||
a tuple or list of keys (e.g. ("a1234...", "b3579"),
|
||||
a JSON string, (e.g. '["a1234", "b3579"]'),
|
||||
a dict with the 'keys' key that gives a tuple or list of keys (e.g {'keys': [...]} ) or
|
||||
a dict or JSON string for a JWK set as defined by RFC 7517 (e.g.
|
||||
{'keys': [{'kty': 'oct', 'k': 'YTEyMzQ'}, {'kty': 'oct', 'k':'YjM1Nzk'}]} or
|
||||
'{"keys": [{"kty":"oct","k":"YTEyMzQ"},{"kty":"oct","k":"YjM1Nzk"}]}'
|
||||
) in which case the keys must be base64 url safe encoded (with optional padding).
|
||||
algorithms (str or list): Valid algorithms that should be used to verify the JWS.
|
||||
audience (str): The intended audience of the token. If the "aud" claim is
|
||||
included in the claim set, then the audience must be included and must equal
|
||||
@@ -278,7 +295,7 @@ def _validate_nbf(claims, leeway=0):
|
||||
except ValueError:
|
||||
raise JWTClaimsError("Not Before claim (nbf) must be an integer.")
|
||||
|
||||
now = timegm(datetime.utcnow().utctimetuple())
|
||||
now = timegm(datetime.now(UTC).utctimetuple())
|
||||
|
||||
if nbf > (now + leeway):
|
||||
raise JWTClaimsError("The token is not yet valid (nbf)")
|
||||
@@ -308,7 +325,7 @@ def _validate_exp(claims, leeway=0):
|
||||
except ValueError:
|
||||
raise JWTClaimsError("Expiration Time claim (exp) must be an integer.")
|
||||
|
||||
now = timegm(datetime.utcnow().utctimetuple())
|
||||
now = timegm(datetime.now(UTC).utctimetuple())
|
||||
|
||||
if exp < (now - leeway):
|
||||
raise ExpiredSignatureError("Signature has expired.")
|
||||
@@ -382,7 +399,7 @@ def _validate_sub(claims, subject=None):
|
||||
"sub" value is a case-sensitive string containing a StringOrURI
|
||||
value. Use of this claim is OPTIONAL.
|
||||
|
||||
Args:
|
||||
Arg
|
||||
claims (dict): The claims dictionary to validate.
|
||||
subject (str): The subject of the token.
|
||||
"""
|
||||
@@ -456,7 +473,6 @@ def _validate_at_hash(claims, access_token, algorithm):
|
||||
|
||||
|
||||
def _validate_claims(claims, audience=None, issuer=None, subject=None, algorithm=None, access_token=None, options=None):
|
||||
|
||||
leeway = options.get("leeway", 0)
|
||||
|
||||
if isinstance(leeway, timedelta):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import re
|
||||
import struct
|
||||
|
||||
# Piggyback of the backends implementation of the function that converts a long
|
||||
@@ -9,7 +10,6 @@ try:
|
||||
def long_to_bytes(n, blocksize=0):
|
||||
return _long_to_bytes(n, blocksize or None)
|
||||
|
||||
|
||||
except ImportError:
|
||||
from ecdsa.ecdsa import int_to_string as _long_to_bytes
|
||||
|
||||
@@ -67,7 +67,7 @@ def base64url_decode(input):
|
||||
"""Helper method to base64url_decode a string.
|
||||
|
||||
Args:
|
||||
input (str): A base64url_encoded string to decode.
|
||||
input (bytes): A base64url_encoded string (bytes) to decode.
|
||||
|
||||
"""
|
||||
rem = len(input) % 4
|
||||
@@ -82,7 +82,7 @@ def base64url_encode(input):
|
||||
"""Helper method to base64url_encode a string.
|
||||
|
||||
Args:
|
||||
input (str): A base64url_encoded string to encode.
|
||||
input (bytes): A base64url_encoded string (bytes) to encode.
|
||||
|
||||
"""
|
||||
return base64.urlsafe_b64encode(input).replace(b"=", b"")
|
||||
@@ -106,3 +106,60 @@ def ensure_binary(s):
|
||||
if isinstance(s, str):
|
||||
return s.encode("utf-8", "strict")
|
||||
raise TypeError(f"not expecting type '{type(s)}'")
|
||||
|
||||
|
||||
# The following was copied from PyJWT:
|
||||
# https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc
|
||||
# Based on:
|
||||
# https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
|
||||
_PEMS = {
|
||||
b"CERTIFICATE",
|
||||
b"TRUSTED CERTIFICATE",
|
||||
b"PRIVATE KEY",
|
||||
b"PUBLIC KEY",
|
||||
b"ENCRYPTED PRIVATE KEY",
|
||||
b"OPENSSH PRIVATE KEY",
|
||||
b"DSA PRIVATE KEY",
|
||||
b"RSA PRIVATE KEY",
|
||||
b"RSA PUBLIC KEY",
|
||||
b"EC PRIVATE KEY",
|
||||
b"DH PARAMETERS",
|
||||
b"NEW CERTIFICATE REQUEST",
|
||||
b"CERTIFICATE REQUEST",
|
||||
b"SSH2 PUBLIC KEY",
|
||||
b"SSH2 ENCRYPTED PRIVATE KEY",
|
||||
b"X509 CRL",
|
||||
}
|
||||
_PEM_RE = re.compile(
|
||||
b"----[- ]BEGIN (" + b"|".join(re.escape(pem) for pem in _PEMS) + b")[- ]----",
|
||||
)
|
||||
|
||||
|
||||
def is_pem_format(key: bytes) -> bool:
|
||||
return bool(_PEM_RE.search(key))
|
||||
|
||||
|
||||
# Based on
|
||||
# https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b
|
||||
# /src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
|
||||
_CERT_SUFFIX = b"-cert-v01@openssh.com"
|
||||
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
|
||||
_SSH_KEY_FORMATS = [
|
||||
b"ssh-ed25519",
|
||||
b"ssh-rsa",
|
||||
b"ssh-dss",
|
||||
b"ecdsa-sha2-nistp256",
|
||||
b"ecdsa-sha2-nistp384",
|
||||
b"ecdsa-sha2-nistp521",
|
||||
]
|
||||
|
||||
|
||||
def is_ssh_key(key: bytes) -> bool:
|
||||
if any(string_value in key for string_value in _SSH_KEY_FORMATS):
|
||||
return True
|
||||
ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
|
||||
if ssh_pubkey_match:
|
||||
key_type = ssh_pubkey_match.group(1)
|
||||
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user