API refactor

2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -249,9 +249,13 @@ class AMQP:
         if max_priority is None:
             max_priority = conf.task_queue_max_priority
         if not queues and conf.task_default_queue:
+            queue_arguments = None
+            if conf.task_default_queue_type == 'quorum':
+                queue_arguments = {'x-queue-type': 'quorum'}
             queues = (Queue(conf.task_default_queue,
                             exchange=self.default_exchange,
-                            routing_key=default_routing_key),)
+                            routing_key=default_routing_key,
+                            queue_arguments=queue_arguments),)
         autoexchange = (self.autoexchange if autoexchange is None
                         else autoexchange)
         return self.queues_cls(
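
This hunk makes the default queue honour the new ``task_default_queue_type`` setting: when it is 'quorum', the queue is declared with the RabbitMQ ``x-queue-type: quorum`` argument. A minimal configuration sketch (app name and broker URL are placeholders, not part of this commit):

    from celery import Celery

    app = Celery('proj', broker='amqp://localhost//')
    # Declare the default 'celery' queue as a RabbitMQ quorum queue
    app.conf.task_default_queue_type = 'quorum'
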
@@ -285,7 +289,7 @@ class AMQP:
                    create_sent_event=False, root_id=None, parent_id=None,
                    shadow=None, chain=None, now=None, timezone=None,
                    origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
-                   **options):
+                   replaced_task_nesting=0, **options):
         args = args or ()
         kwargs = kwargs or {}
@@ -339,6 +343,7 @@ class AMQP:
             'kwargsrepr': kwargsrepr,
             'origin': origin or anon_nodename(),
             'ignore_result': ignore_result,
+            'replaced_task_nesting': replaced_task_nesting,
             'stamped_headers': stamped_headers,
             'stamps': stamps,
         }
@@ -462,7 +467,8 @@ class AMQP:
                           retry=None, retry_policy=None,
                           serializer=None, delivery_mode=None,
                           compression=None, declare=None,
-                          headers=None, exchange_type=None, **kwargs):
+                          headers=None, exchange_type=None,
+                          timeout=None, confirm_timeout=None, **kwargs):
         retry = default_retry if retry is None else retry
         headers2, properties, body, sent_event = message
         if headers:
@@ -523,6 +529,7 @@ class AMQP:
                 retry=retry, retry_policy=_rp,
                 delivery_mode=delivery_mode, declare=declare,
                 headers=headers2,
+                timeout=timeout, confirm_timeout=confirm_timeout,
                 **properties
             )
             if after_receivers:
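
The new ``timeout`` and ``confirm_timeout`` parameters are forwarded unchanged to kombu's Producer.publish(), bounding how long a publish waits for the broker and for publisher confirms. Assuming they travel through ``apply_async(**options)`` like the other publish options (an assumption; the hunk only shows send_task_message itself, and the task name is hypothetical):

    # Wait at most 5s to publish, and 10s for the broker's publisher confirm
    result = add.apply_async((2, 2), timeout=5, confirm_timeout=10)
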

View File

@@ -34,6 +34,7 @@ BACKEND_ALIASES = {
     'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend',
     'arangodb': 'celery.backends.arangodb:ArangoDbBackend',
     's3': 'celery.backends.s3:S3Backend',
+    'gs': 'celery.backends.gcs:GCSBackend',
 }
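
With this alias registered, a ``gs://`` result backend URL resolves to the new Google Cloud Storage backend. One plausible wiring, using the ``gcs`` settings namespace added in defaults.py further down (bucket and project names are placeholders):

    app.conf.result_backend = 'gs://my-result-bucket/results'
    app.conf.gcs_project = 'my-gcp-project'  # maps to the gcs.project Option
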

View File

@@ -1,17 +1,23 @@
"""Actual App instance implementation."""
import functools
import importlib
import inspect
import os
import sys
import threading
import typing
import warnings
from collections import UserDict, defaultdict, deque
from datetime import datetime
from datetime import timezone as datetime_timezone
from operator import attrgetter
from click.exceptions import Exit
from kombu import pools
from dateutil.parser import isoparse
from kombu import Exchange, pools
from kombu.clocks import LamportClock
from kombu.common import oid_from
from kombu.transport.native_delayed_delivery import calculate_routing_key
from kombu.utils.compat import register_after_fork
from kombu.utils.objects import cached_property
from kombu.utils.uuid import uuid
@@ -32,6 +38,8 @@ from celery.utils.log import get_logger
 from celery.utils.objects import FallbackContext, mro_lookup
 from celery.utils.time import maybe_make_aware, timezone, to_utc

+from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg
+from ..utils.quorum_queues import detect_quorum_queues
 # Load all builtin tasks
 from . import backends, builtins  # noqa
 from .annotations import prepare as prepare_annotations
@@ -41,6 +49,10 @@ from .registry import TaskRegistry
 from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr,
                     bugreport, detect_settings)

+if typing.TYPE_CHECKING:  # pragma: no cover  # codecov does not capture this
+    # flake8 marks the BaseModel import as unused, because the actual typehint is quoted.
+    from pydantic import BaseModel  # noqa: F401
+
 __all__ = ('Celery',)

 logger = get_logger(__name__)
@@ -90,6 +102,70 @@ def _after_fork_cleanup_app(app):
         logger.info('after forker raised exception: %r', exc, exc_info=1)


+def pydantic_wrapper(
+    app: "Celery",
+    task_fun: typing.Callable[..., typing.Any],
+    task_name: str,
+    strict: bool = True,
+    context: typing.Optional[typing.Dict[str, typing.Any]] = None,
+    dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
+):
+    """Wrapper to validate arguments and serialize return values using Pydantic."""
+    try:
+        pydantic = importlib.import_module('pydantic')
+    except ModuleNotFoundError as ex:
+        raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex
+
+    BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel  # noqa: F811 # only defined when type checking
+
+    if context is None:
+        context = {}
+    if dump_kwargs is None:
+        dump_kwargs = {}
+    dump_kwargs.setdefault('mode', 'json')
+
+    task_signature = inspect.signature(task_fun)
+
+    @functools.wraps(task_fun)
+    def wrapper(*task_args, **task_kwargs):
+        # Validate task parameters if type hinted as BaseModel
+        bound_args = task_signature.bind(*task_args, **task_kwargs)
+        for arg_name, arg_value in bound_args.arguments.items():
+            arg_annotation = task_signature.parameters[arg_name].annotation
+
+            optional_arg = get_optional_arg(arg_annotation)
+            if optional_arg is not None and arg_value is not None:
+                arg_annotation = optional_arg
+
+            if annotation_issubclass(arg_annotation, BaseModel):
+                bound_args.arguments[arg_name] = arg_annotation.model_validate(
+                    arg_value,
+                    strict=strict,
+                    context={**context, 'celery_app': app, 'celery_task_name': task_name},
+                )
+
+        # Call the task with (potentially) converted arguments
+        returned_value = task_fun(*bound_args.args, **bound_args.kwargs)
+
+        # Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its
+        # class matches the typehint
+        return_annotation = task_signature.return_annotation
+        optional_return_annotation = get_optional_arg(return_annotation)
+        if optional_return_annotation is not None:
+            return_annotation = optional_return_annotation
+
+        if (
+            annotation_is_class(return_annotation)
+            and isinstance(returned_value, BaseModel)
+            and isinstance(returned_value, return_annotation)
+        ):
+            return returned_value.model_dump(**dump_kwargs)
+
+        return returned_value
+
+    return wrapper
+
+
 class PendingConfiguration(UserDict, AttributeDictMixin):
     # `app.conf` will be of this type before being explicitly configured,
     # meaning the app can keep any configuration set directly
@@ -238,6 +314,12 @@ class Celery:
         self.loader_cls = loader or self._get_default_loader()
         self.log_cls = log or self.log_cls
         self.control_cls = control or self.control_cls
+        self._custom_task_cls_used = (
+            # Custom task class provided as argument
+            bool(task_cls)
+            # subclass of Celery with a task_cls attribute
+            or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls')
+        )
         self.task_cls = task_cls or self.task_cls
         self.set_as_current = set_as_current
         self.registry_cls = symbol_by_name(self.registry_cls)
@@ -433,6 +515,7 @@ class Celery:
         if shared:
             def cons(app):
                 return app._task_from_fun(fun, **opts)
+            cons.__name__ = fun.__name__
             connect_on_app_finalize(cons)
         if not lazy or self.finalized:
@@ -461,13 +544,27 @@ class Celery:
     def type_checker(self, fun, bound=False):
         return staticmethod(head_from_fun(fun, bound=bound))

-    def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
+    def _task_from_fun(
+        self,
+        fun,
+        name=None,
+        base=None,
+        bind=False,
+        pydantic: bool = False,
+        pydantic_strict: bool = False,
+        pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
+        **options,
+    ):
         if not self.finalized and not self.autofinalize:
             raise RuntimeError('Contract breach: app not finalized')
         name = name or self.gen_task_name(fun.__name__, fun.__module__)
         base = base or self.Task

         if name not in self._tasks:
+            if pydantic is True:
+                fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs)
+
             run = fun if bind else staticmethod(fun)
             task = type(fun.__name__, (base,), dict({
                 'app': self,
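
Together with pydantic_wrapper above, this lets a task opt into Pydantic validation and serialization through the task decorator. A hedged usage sketch, assuming an existing ``app`` instance (model and task names are illustrative):

    from pydantic import BaseModel

    class AddInput(BaseModel):
        x: int
        y: int

    @app.task(pydantic=True)
    def add(data: AddInput) -> AddInput:
        # On the wire `data` is a plain dict; the wrapper validates it into
        # AddInput, and the returned model is dumped back via model_dump(mode='json').
        return AddInput(x=data.x + data.y, y=0)
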
@@ -711,7 +808,7 @@ class Celery:
                   retries=0, chord=None,
                   reply_to=None, time_limit=None, soft_time_limit=None,
                   root_id=None, parent_id=None, route_name=None,
-                  shadow=None, chain=None, task_type=None, **options):
+                  shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options):
         """Send task by name.

         Supports the same arguments as :meth:`@-Task.apply_async`.
@@ -734,13 +831,48 @@ class Celery:
         ignore_result = options.pop('ignore_result', False)
         options = router.route(
             options, route_name or name, args, kwargs, task_type)
+        driver_type = self.producer_pool.connections.connection.transport.driver_type
+        if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]:
+            queue = options.get("queue")
+            exchange_type = queue.exchange.type if queue else options["exchange_type"]
+            routing_key = queue.routing_key if queue else options["routing_key"]
+            exchange_name = queue.exchange.name if queue else options["exchange"]
+            if exchange_type != 'direct':
+                if eta:
+                    if isinstance(eta, str):
+                        eta = isoparse(eta)
+                    countdown = (maybe_make_aware(eta) - self.now()).total_seconds()
+                if countdown:
+                    if countdown > 0:
+                        routing_key = calculate_routing_key(int(countdown), routing_key)
+                        exchange = Exchange(
+                            'celery_delayed_27',
+                            type='topic',
+                        )
+                        options.pop("queue", None)
+                        options['routing_key'] = routing_key
+                        options['exchange'] = exchange
+            else:
+                logger.warning(
+                    'Direct exchanges are not supported with native delayed delivery.\n'
+                    f'{exchange_name} is a direct exchange but should be a topic exchange or '
+                    'a fanout exchange in order for native delayed delivery to work properly.\n'
+                    'If quorum queues are used, this task may block the worker process until the ETA arrives.'
+                )
+
         if expires is not None:
             if isinstance(expires, datetime):
                 expires_s = (maybe_make_aware(
                     expires) - self.now()).total_seconds()
             elif isinstance(expires, str):
                 expires_s = (maybe_make_aware(
-                    datetime.fromisoformat(expires)) - self.now()).total_seconds()
+                    isoparse(expires)) - self.now()).total_seconds()
             else:
                 expires_s = expires
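
When quorum queues are detected, ETA/countdown tasks no longer rely on broker-side delays that can block consumers (as the warning above notes); instead the delay is encoded into the routing key and the message is published to the ``celery_delayed_27`` topic exchange of kombu's native delayed delivery topology. A rough illustration of the helper imported above (values are examples only):

    from kombu.transport.native_delayed_delivery import calculate_routing_key

    # Encode a 30-second delay into the routing key; the delayed-delivery
    # exchanges then route the message through per-level delay queues until
    # it expires onto the original routing key.
    delayed_key = calculate_routing_key(30, 'celery')
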
@@ -781,7 +913,7 @@ class Celery:
             self.conf.task_send_sent_event,
             root_id, parent_id, shadow, chain,
             ignore_result=ignore_result,
-            **options
+            replaced_task_nesting=replaced_task_nesting, **options
         )

         stamped_headers = options.pop('stamped_headers', [])
@@ -894,6 +1026,7 @@ class Celery:
                 'broker_connection_timeout', connect_timeout
             ),
         )
+
     broker_connection = connection

     def _acquire_connection(self, pool=True):
@@ -913,6 +1046,7 @@ class Celery:
         will be acquired from the connection pool.
         """
         return FallbackContext(connection, self._acquire_connection, pool=pool)
+
     default_connection = connection_or_acquire  # XXX compat

     def producer_or_acquire(self, producer=None):
@@ -928,6 +1062,7 @@ class Celery:
         return FallbackContext(
             producer, self.producer_pool.acquire, block=True,
         )
+
     default_producer = producer_or_acquire  # XXX compat

     def prepare_config(self, c):
@@ -936,7 +1071,7 @@ class Celery:
     def now(self):
         """Return the current time and date as a datetime."""
-        now_in_utc = to_utc(datetime.utcnow())
+        now_in_utc = to_utc(datetime.now(datetime_timezone.utc))
         return now_in_utc.astimezone(self.timezone)

     def select_queues(self, queues=None):
@@ -974,7 +1109,14 @@ class Celery:
         This is used by PendingConfiguration:
             as soon as you access a key the configuration is read.
         """
-        conf = self._conf = self._load_config()
+        try:
+            conf = self._conf = self._load_config()
+        except AttributeError as err:
+            # AttributeError is not propagated, it is "handled" by
+            # PendingConfiguration parent class. This causes
+            # confusing RecursionError.
+            raise ModuleNotFoundError(*err.args) from err
+
         return conf

     def _load_config(self):
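
Without the re-raise, an AttributeError escaping _load_config() is swallowed by the attribute machinery behind PendingConfiguration and the lookup is retried, surfacing as an opaque RecursionError instead of the real import failure. A standalone sketch of that failure mode (illustrative only, not Celery code):

    class LazyConf:
        @property
        def data(self):
            raise AttributeError('broken import inside the config loader')

        def __getattr__(self, name):
            # The property's AttributeError lands here; touching self.data
            # again re-raises, recursing until RecursionError.
            return getattr(self.data, name)

    LazyConf().changes  # RecursionError, hiding the original error
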

View File

@@ -360,7 +360,7 @@ class Inspect:
         * ``routing_key`` - Routing key used when task was published
         * ``priority`` - Priority used when task was published
         * ``redelivered`` - True if the task was redelivered
-        * ``worker_pid`` - PID of worker processin the task
+        * ``worker_pid`` - PID of worker processing the task
         """
         # signature used be unary: query_task(ids=[id1, id2])
@@ -527,7 +527,8 @@ class Control:
         if result:
             for host in result:
                 for response in host.values():
-                    task_ids.update(response['ok'])
+                    if isinstance(response['ok'], set):
+                        task_ids.update(response['ok'])

         if task_ids:
             return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs)
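
The added isinstance check guards against worker replies where ``'ok'`` carries an informational string rather than a set of revoked task ids. A hedged illustration of the two reply shapes this distinguishes (payloads illustrative):

    ok_reply = {'ok': {'id-1', 'id-2'}}        # set of task ids: merged into task_ids
    info_reply = {'ok': 'no matching tasks'}   # plain string: now skipped safely
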

View File

@@ -95,6 +95,7 @@ NAMESPACES = Namespace(
         heartbeat=Option(120, type='int'),
         heartbeat_checkrate=Option(3.0, type='int'),
         login_method=Option(None, type='string'),
+        native_delayed_delivery_queue_type=Option(default='quorum', type='string'),
         pool_limit=Option(10, type='int'),
         use_ssl=Option(False, type='bool'),
@@ -140,6 +141,12 @@ NAMESPACES = Namespace(
             connection_timeout=Option(20, type='int'),
             read_timeout=Option(120, type='int'),
         ),
+        gcs=Namespace(
+            bucket=Option(type='string'),
+            project=Option(type='string'),
+            base_path=Option('', type='string'),
+            ttl=Option(0, type='float'),
+        ),
         control=Namespace(
             queue_ttl=Option(300.0, type='float'),
             queue_expires=Option(10.0, type='float'),
@@ -243,6 +250,7 @@ NAMESPACES = Namespace(
         ),
         table_schemas=Option(type='dict'),
         table_names=Option(type='dict', old={'celery_result_db_tablenames'}),
+        create_tables_at_setup=Option(True, type='bool'),
     ),
     task=Namespace(
         __old__=OLD_NS,
@@ -255,6 +263,7 @@ NAMESPACES = Namespace(
         inherit_parent_priority=Option(False, type='bool'),
         default_delivery_mode=Option(2, type='string'),
         default_queue=Option('celery'),
+        default_queue_type=Option('classic', type='string'),
         default_exchange=Option(None, type='string'),  # taken from queue
         default_exchange_type=Option('direct'),
         default_routing_key=Option(None, type='string'),  # taken from queue
@@ -302,6 +311,8 @@ NAMESPACES = Namespace(
         cancel_long_running_tasks_on_connection_loss=Option(
             False, type='bool'
         ),
+        soft_shutdown_timeout=Option(0.0, type='float'),
+        enable_soft_shutdown_on_idle=Option(False, type='bool'),
         concurrency=Option(None, type='int'),
         consumer=Option('celery.worker.consumer:Consumer', type='string'),
         direct=Option(False, type='bool', old={'celery_worker_direct'}),
@@ -325,6 +336,7 @@ NAMESPACES = Namespace(
         pool_restarts=Option(False, type='bool'),
         proc_alive_timeout=Option(4.0, type='float'),
         prefetch_multiplier=Option(4, type='int'),
+        enable_prefetch_count_reduction=Option(True, type='bool'),
         redirect_stdouts=Option(
             True, type='bool', old={'celery_redirect_stdouts'},
         ),
@@ -338,6 +350,7 @@ NAMESPACES = Namespace(
         task_log_format=Option(DEFAULT_TASK_LOG_FMT),
         timer=Option(type='string'),
         timer_precision=Option(1.0, type='float'),
+        detect_quorum_queues=Option(True, type='bool'),
     ),
 )
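
The new options introduced across these hunks are all reachable through standard configuration; the full setting names take the broker/task/worker namespace prefixes shown above. A hedged sketch of overriding them (values illustrative):

    app.conf.broker_native_delayed_delivery_queue_type = 'quorum'  # already the default
    app.conf.task_default_queue_type = 'quorum'
    app.conf.worker_soft_shutdown_timeout = 10.0   # grace period before cold shutdown
    app.conf.worker_enable_soft_shutdown_on_idle = True
    app.conf.worker_enable_prefetch_count_reduction = False
    app.conf.worker_detect_quorum_queues = True    # already the default
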

View File

@@ -18,6 +18,7 @@ from celery import signals
 from celery._state import get_current_task
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
 from celery.local import class_property
+from celery.platforms import isatty
 from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel,
                               reset_multiprocessing_logger)
 from celery.utils.nodenames import node_format
@@ -203,7 +204,7 @@ class Logging:
         if colorize or colorize is None:
             # Only use color if there's no active log file
             # and stderr is an actual terminal.
-            return logfile is None and sys.stderr.isatty()
+            return logfile is None and isatty(sys.stderr)
         return colorize

     def colored(self, logfile=None, enabled=None):

View File

@@ -20,7 +20,7 @@ except AttributeError: # pragma: no cover
     # for support Python 3.7
     Pattern = re.Pattern

-__all__ = ('MapRoute', 'Router', 'prepare')
+__all__ = ('MapRoute', 'Router', 'expand_router_string', 'prepare')


 class MapRoute:

View File

@@ -104,7 +104,7 @@ class Context:
     def _get_custom_headers(self, *args, **kwargs):
         headers = {}
         headers.update(*args, **kwargs)
-        celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'}
+        celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr', 'compression'}
         for key in celery_keys:
             headers.pop(key, None)
         if not headers:
@@ -466,7 +466,7 @@ class Task:
             shadow (str): Override task name used in logs/monitoring.
                 Default is retrieved from :meth:`shadow_name`.

-            connection (kombu.Connection): Re-use existing broker connection
+            connection (kombu.Connection): Reuse existing broker connection
                 instead of acquiring one from the connection pool.

             retry (bool): If enabled sending of the task message will be
@@ -535,6 +535,8 @@ class Task:
             publisher (kombu.Producer): Deprecated alias to ``producer``.

             headers (Dict): Message headers to be included in the message.
+                The headers can be used as an overlay for custom labeling
+                using the :ref:`canvas-stamping` feature.

         Returns:
             celery.result.AsyncResult: Promise of future evaluation.
@@ -543,6 +545,8 @@ class Task:
             TypeError: If not enough arguments are passed, or too many
                 arguments are passed. Note that signature checks may
                 be disabled by specifying ``@task(typing=False)``.
+            ValueError: If soft_time_limit and time_limit both are set
+                but soft_time_limit is greater than time_limit
             kombu.exceptions.OperationalError: If a connection to the
                 transport cannot be made, or if the connection is lost.
@@ -550,6 +554,9 @@ class Task:
             Also supports all keyword arguments supported by
             :meth:`kombu.Producer.publish`.
         """
+        if self.soft_time_limit and self.time_limit and self.soft_time_limit > self.time_limit:
+            raise ValueError('soft_time_limit must be less than or equal to time_limit')
+
         if self.typing:
             try:
                 check_arguments = self.__header__
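
With this guard, conflicting per-task limits fail fast at publish time instead of being silently accepted. A short sketch of the rejected configuration (task name hypothetical, assuming an existing ``app``):

    @app.task(soft_time_limit=60, time_limit=30)
    def crunch():
        ...

    # Raises ValueError: soft_time_limit must be less than or equal to time_limit
    crunch.apply_async()
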
@@ -788,6 +795,7 @@ class Task:
         request = {
             'id': task_id,
             'task': self.name,
+            'retries': retries,
             'is_eager': True,
             'logfile': logfile,
@@ -824,7 +832,7 @@ class Task:
         if isinstance(retval, Retry) and retval.sig is not None:
             return retval.sig.apply(retries=retries + 1)
         state = states.SUCCESS if ret.info is None else ret.info.state
-        return EagerResult(task_id, retval, state, traceback=tb)
+        return EagerResult(task_id, retval, state, traceback=tb, name=self.name)

     def AsyncResult(self, task_id, **kwargs):
         """Get AsyncResult instance for the specified task.
@@ -954,11 +962,20 @@ class Task:
             root_id=self.request.root_id,
             replaced_task_nesting=replaced_task_nesting
         )

+        # If the replaced task is a chain, we want to set all of the chain tasks
+        # with the same replaced_task_nesting value to mark their replacement nesting level
+        if isinstance(sig, _chain):
+            for chain_task in maybe_list(sig.tasks) or []:
+                chain_task.set(replaced_task_nesting=replaced_task_nesting)
+
         # If the task being replaced is part of a chain, we need to re-create
         # it with the replacement signature - these subsequent tasks will
         # retain their original task IDs as well
         for t in reversed(self.request.chain or []):
-            sig |= signature(t, app=self.app)
+            chain_task = signature(t, app=self.app)
+            chain_task.set(replaced_task_nesting=replaced_task_nesting)
+            sig |= chain_task
         return self.on_replace(sig)

     def add_to_chord(self, sig, lazy=False):
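
The net effect is that every link of a replacement chain carries the caller's replaced_task_nesting stamp, so nesting depth can be tracked across repeated replacements. A hedged sketch (task names hypothetical, assuming an existing ``app``):

    from celery import chain

    @app.task(bind=True)
    def outer(self):
        # Both steps of the replacement chain are stamped with the same
        # replaced_task_nesting as this task's replacement.
        raise self.replace(chain(step1.s(), step2.s()))
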
@@ -1099,7 +1116,7 @@ class Task:
         return result

     def push_request(self, *args, **kwargs):
-        self.request_stack.push(Context(*args, **kwargs))
+        self.request_stack.push(Context(*args, **{**self.request.__dict__, **kwargs}))

     def pop_request(self):
         self.request_stack.pop()

View File

@@ -8,7 +8,6 @@ import os
 import sys
 import time
 from collections import namedtuple
-from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union
 from warnings import warn

 from billiard.einfo import ExceptionInfo, ExceptionWithTraceback
@@ -17,8 +16,6 @@ from kombu.serialization import loads as loads_message
 from kombu.serialization import prepare_accept_content
 from kombu.utils.encoding import safe_repr, safe_str

-import celery
-import celery.loaders.app
 from celery import current_app, group, signals, states
 from celery._state import _task_stack
 from celery.app.task import Context
@@ -294,20 +291,10 @@ def traceback_clear(exc=None):
         tb = tb.tb_next


-def build_tracer(
-        name: str,
-        task: Union[celery.Task, celery.local.PromiseProxy],
-        loader: Optional[celery.loaders.app.AppLoader] = None,
-        hostname: Optional[str] = None,
-        store_errors: bool = True,
-        Info: Type[TraceInfo] = TraceInfo,
-        eager: bool = False,
-        propagate: bool = False,
-        app: Optional[celery.Celery] = None,
-        monotonic: Callable[[], int] = time.monotonic,
-        trace_ok_t: Type[trace_ok_t] = trace_ok_t,
-        IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \
-            Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]:
+def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
+                 Info=TraceInfo, eager=False, propagate=False, app=None,
+                 monotonic=time.monotonic, trace_ok_t=trace_ok_t,
+                 IGNORE_STATES=IGNORE_STATES):
     """Return a function that traces task execution.

     Catches all exceptions and updates result backend with the
@@ -387,12 +374,7 @@ def build_tracer(
     from celery import canvas
     signature = canvas.maybe_signature  # maybe_ does not clone if already

-    def on_error(
-            request: celery.app.task.Context,
-            exc: Union[Exception, Type[Exception]],
-            state: str = FAILURE,
-            call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]:
-        """Handle any errors raised by a `Task`'s execution."""
+    def on_error(request, exc, state=FAILURE, call_errbacks=True):
         if propagate:
             raise
         I = Info(state, exc)
@@ -401,13 +383,7 @@ def build_tracer(
         )
         return I, R, I.state, I.retval

-    def trace_task(
-            uuid: str,
-            args: Sequence[Any],
-            kwargs: Dict[str, Any],
-            request: Optional[Dict[str, Any]] = None) -> trace_ok_t:
-        """Execute and trace a `Task`."""
+    def trace_task(uuid, args, kwargs, request=None):
         # R - is the possibly prepared return value.
         # I - is the Info object.
         # T - runtime

View File

@@ -35,7 +35,7 @@ settings -> transport:{transport} results:{results}
"""
HIDDEN_SETTINGS = re.compile(
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE',
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE|BEAT_DBURI',
re.IGNORECASE,
)
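
Settings whose names match this pattern are censored in bug reports and in ``conf.humanize()`` output; adding BEAT_DBURI extends the mask to the scheduler database URI used by SQLAlchemy-backed beat schedulers. Illustrative behaviour (the setting name comes from that third-party ecosystem, and the URI is a placeholder):

    app.conf.beat_dburi = 'mysql://user:secret@localhost/celery_beat'
    print(app.conf.humanize(with_defaults=False, censored=True))
    # ... beat_dburi: '********' ...
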