This commit is contained in:
@@ -1,23 +1,17 @@
|
||||
"""Actual App instance implementation."""
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from collections import UserDict, defaultdict, deque
|
||||
from datetime import datetime
|
||||
from datetime import timezone as datetime_timezone
|
||||
from operator import attrgetter
|
||||
|
||||
from click.exceptions import Exit
|
||||
from dateutil.parser import isoparse
|
||||
from kombu import Exchange, pools
|
||||
from kombu import pools
|
||||
from kombu.clocks import LamportClock
|
||||
from kombu.common import oid_from
|
||||
from kombu.transport.native_delayed_delivery import calculate_routing_key
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.uuid import uuid
|
||||
@@ -38,8 +32,6 @@ from celery.utils.log import get_logger
|
||||
from celery.utils.objects import FallbackContext, mro_lookup
|
||||
from celery.utils.time import maybe_make_aware, timezone, to_utc
|
||||
|
||||
from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg
|
||||
from ..utils.quorum_queues import detect_quorum_queues
|
||||
# Load all builtin tasks
|
||||
from . import backends, builtins # noqa
|
||||
from .annotations import prepare as prepare_annotations
|
||||
@@ -49,10 +41,6 @@ from .registry import TaskRegistry
|
||||
from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr,
|
||||
bugreport, detect_settings)
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover # codecov does not capture this
|
||||
# flake8 marks the BaseModel import as unused, because the actual typehint is quoted.
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
__all__ = ('Celery',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -102,70 +90,6 @@ def _after_fork_cleanup_app(app):
|
||||
logger.info('after forker raised exception: %r', exc, exc_info=1)
|
||||
|
||||
|
||||
def pydantic_wrapper(
|
||||
app: "Celery",
|
||||
task_fun: typing.Callable[..., typing.Any],
|
||||
task_name: str,
|
||||
strict: bool = True,
|
||||
context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
|
||||
):
|
||||
"""Wrapper to validate arguments and serialize return values using Pydantic."""
|
||||
try:
|
||||
pydantic = importlib.import_module('pydantic')
|
||||
except ModuleNotFoundError as ex:
|
||||
raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex
|
||||
|
||||
BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking
|
||||
|
||||
if context is None:
|
||||
context = {}
|
||||
if dump_kwargs is None:
|
||||
dump_kwargs = {}
|
||||
dump_kwargs.setdefault('mode', 'json')
|
||||
|
||||
task_signature = inspect.signature(task_fun)
|
||||
|
||||
@functools.wraps(task_fun)
|
||||
def wrapper(*task_args, **task_kwargs):
|
||||
# Validate task parameters if type hinted as BaseModel
|
||||
bound_args = task_signature.bind(*task_args, **task_kwargs)
|
||||
for arg_name, arg_value in bound_args.arguments.items():
|
||||
arg_annotation = task_signature.parameters[arg_name].annotation
|
||||
|
||||
optional_arg = get_optional_arg(arg_annotation)
|
||||
if optional_arg is not None and arg_value is not None:
|
||||
arg_annotation = optional_arg
|
||||
|
||||
if annotation_issubclass(arg_annotation, BaseModel):
|
||||
bound_args.arguments[arg_name] = arg_annotation.model_validate(
|
||||
arg_value,
|
||||
strict=strict,
|
||||
context={**context, 'celery_app': app, 'celery_task_name': task_name},
|
||||
)
|
||||
|
||||
# Call the task with (potentially) converted arguments
|
||||
returned_value = task_fun(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
# Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its
|
||||
# class matches the typehint
|
||||
return_annotation = task_signature.return_annotation
|
||||
optional_return_annotation = get_optional_arg(return_annotation)
|
||||
if optional_return_annotation is not None:
|
||||
return_annotation = optional_return_annotation
|
||||
|
||||
if (
|
||||
annotation_is_class(return_annotation)
|
||||
and isinstance(returned_value, BaseModel)
|
||||
and isinstance(returned_value, return_annotation)
|
||||
):
|
||||
return returned_value.model_dump(**dump_kwargs)
|
||||
|
||||
return returned_value
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PendingConfiguration(UserDict, AttributeDictMixin):
|
||||
# `app.conf` will be of this type before being explicitly configured,
|
||||
# meaning the app can keep any configuration set directly
|
||||
@@ -314,12 +238,6 @@ class Celery:
|
||||
self.loader_cls = loader or self._get_default_loader()
|
||||
self.log_cls = log or self.log_cls
|
||||
self.control_cls = control or self.control_cls
|
||||
self._custom_task_cls_used = (
|
||||
# Custom task class provided as argument
|
||||
bool(task_cls)
|
||||
# subclass of Celery with a task_cls attribute
|
||||
or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls')
|
||||
)
|
||||
self.task_cls = task_cls or self.task_cls
|
||||
self.set_as_current = set_as_current
|
||||
self.registry_cls = symbol_by_name(self.registry_cls)
|
||||
@@ -515,7 +433,6 @@ class Celery:
|
||||
if shared:
|
||||
def cons(app):
|
||||
return app._task_from_fun(fun, **opts)
|
||||
|
||||
cons.__name__ = fun.__name__
|
||||
connect_on_app_finalize(cons)
|
||||
if not lazy or self.finalized:
|
||||
@@ -544,27 +461,13 @@ class Celery:
|
||||
def type_checker(self, fun, bound=False):
|
||||
return staticmethod(head_from_fun(fun, bound=bound))
|
||||
|
||||
def _task_from_fun(
|
||||
self,
|
||||
fun,
|
||||
name=None,
|
||||
base=None,
|
||||
bind=False,
|
||||
pydantic: bool = False,
|
||||
pydantic_strict: bool = False,
|
||||
pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
**options,
|
||||
):
|
||||
def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
|
||||
if not self.finalized and not self.autofinalize:
|
||||
raise RuntimeError('Contract breach: app not finalized')
|
||||
name = name or self.gen_task_name(fun.__name__, fun.__module__)
|
||||
base = base or self.Task
|
||||
|
||||
if name not in self._tasks:
|
||||
if pydantic is True:
|
||||
fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs)
|
||||
|
||||
run = fun if bind else staticmethod(fun)
|
||||
task = type(fun.__name__, (base,), dict({
|
||||
'app': self,
|
||||
@@ -808,7 +711,7 @@ class Celery:
|
||||
retries=0, chord=None,
|
||||
reply_to=None, time_limit=None, soft_time_limit=None,
|
||||
root_id=None, parent_id=None, route_name=None,
|
||||
shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options):
|
||||
shadow=None, chain=None, task_type=None, **options):
|
||||
"""Send task by name.
|
||||
|
||||
Supports the same arguments as :meth:`@-Task.apply_async`.
|
||||
@@ -831,48 +734,13 @@ class Celery:
|
||||
ignore_result = options.pop('ignore_result', False)
|
||||
options = router.route(
|
||||
options, route_name or name, args, kwargs, task_type)
|
||||
|
||||
driver_type = self.producer_pool.connections.connection.transport.driver_type
|
||||
|
||||
if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]:
|
||||
|
||||
queue = options.get("queue")
|
||||
exchange_type = queue.exchange.type if queue else options["exchange_type"]
|
||||
routing_key = queue.routing_key if queue else options["routing_key"]
|
||||
exchange_name = queue.exchange.name if queue else options["exchange"]
|
||||
|
||||
if exchange_type != 'direct':
|
||||
if eta:
|
||||
if isinstance(eta, str):
|
||||
eta = isoparse(eta)
|
||||
countdown = (maybe_make_aware(eta) - self.now()).total_seconds()
|
||||
|
||||
if countdown:
|
||||
if countdown > 0:
|
||||
routing_key = calculate_routing_key(int(countdown), routing_key)
|
||||
exchange = Exchange(
|
||||
'celery_delayed_27',
|
||||
type='topic',
|
||||
)
|
||||
options.pop("queue", None)
|
||||
options['routing_key'] = routing_key
|
||||
options['exchange'] = exchange
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
'Direct exchanges are not supported with native delayed delivery.\n'
|
||||
f'{exchange_name} is a direct exchange but should be a topic exchange or '
|
||||
'a fanout exchange in order for native delayed delivery to work properly.\n'
|
||||
'If quorum queues are used, this task may block the worker process until the ETA arrives.'
|
||||
)
|
||||
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime):
|
||||
expires_s = (maybe_make_aware(
|
||||
expires) - self.now()).total_seconds()
|
||||
elif isinstance(expires, str):
|
||||
expires_s = (maybe_make_aware(
|
||||
isoparse(expires)) - self.now()).total_seconds()
|
||||
datetime.fromisoformat(expires)) - self.now()).total_seconds()
|
||||
else:
|
||||
expires_s = expires
|
||||
|
||||
@@ -913,7 +781,7 @@ class Celery:
|
||||
self.conf.task_send_sent_event,
|
||||
root_id, parent_id, shadow, chain,
|
||||
ignore_result=ignore_result,
|
||||
replaced_task_nesting=replaced_task_nesting, **options
|
||||
**options
|
||||
)
|
||||
|
||||
stamped_headers = options.pop('stamped_headers', [])
|
||||
@@ -1026,7 +894,6 @@ class Celery:
|
||||
'broker_connection_timeout', connect_timeout
|
||||
),
|
||||
)
|
||||
|
||||
broker_connection = connection
|
||||
|
||||
def _acquire_connection(self, pool=True):
|
||||
@@ -1046,7 +913,6 @@ class Celery:
|
||||
will be acquired from the connection pool.
|
||||
"""
|
||||
return FallbackContext(connection, self._acquire_connection, pool=pool)
|
||||
|
||||
default_connection = connection_or_acquire # XXX compat
|
||||
|
||||
def producer_or_acquire(self, producer=None):
|
||||
@@ -1062,7 +928,6 @@ class Celery:
|
||||
return FallbackContext(
|
||||
producer, self.producer_pool.acquire, block=True,
|
||||
)
|
||||
|
||||
default_producer = producer_or_acquire # XXX compat
|
||||
|
||||
def prepare_config(self, c):
|
||||
@@ -1071,7 +936,7 @@ class Celery:
|
||||
|
||||
def now(self):
|
||||
"""Return the current time and date as a datetime."""
|
||||
now_in_utc = to_utc(datetime.now(datetime_timezone.utc))
|
||||
now_in_utc = to_utc(datetime.utcnow())
|
||||
return now_in_utc.astimezone(self.timezone)
|
||||
|
||||
def select_queues(self, queues=None):
|
||||
@@ -1109,14 +974,7 @@ class Celery:
|
||||
This is used by PendingConfiguration:
|
||||
as soon as you access a key the configuration is read.
|
||||
"""
|
||||
try:
|
||||
conf = self._conf = self._load_config()
|
||||
except AttributeError as err:
|
||||
# AttributeError is not propagated, it is "handled" by
|
||||
# PendingConfiguration parent class. This causes
|
||||
# confusing RecursionError.
|
||||
raise ModuleNotFoundError(*err.args) from err
|
||||
|
||||
conf = self._conf = self._load_config()
|
||||
return conf
|
||||
|
||||
def _load_config(self):
|
||||
|
||||
Reference in New Issue
Block a user