This commit is contained in:
@@ -249,9 +249,13 @@ class AMQP:
|
||||
if max_priority is None:
|
||||
max_priority = conf.task_queue_max_priority
|
||||
if not queues and conf.task_default_queue:
|
||||
queue_arguments = None
|
||||
if conf.task_default_queue_type == 'quorum':
|
||||
queue_arguments = {'x-queue-type': 'quorum'}
|
||||
queues = (Queue(conf.task_default_queue,
|
||||
exchange=self.default_exchange,
|
||||
routing_key=default_routing_key),)
|
||||
routing_key=default_routing_key,
|
||||
queue_arguments=queue_arguments),)
|
||||
autoexchange = (self.autoexchange if autoexchange is None
|
||||
else autoexchange)
|
||||
return self.queues_cls(
|
||||
@@ -285,7 +289,7 @@ class AMQP:
|
||||
create_sent_event=False, root_id=None, parent_id=None,
|
||||
shadow=None, chain=None, now=None, timezone=None,
|
||||
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
|
||||
**options):
|
||||
replaced_task_nesting=0, **options):
|
||||
|
||||
args = args or ()
|
||||
kwargs = kwargs or {}
|
||||
@@ -339,6 +343,7 @@ class AMQP:
|
||||
'kwargsrepr': kwargsrepr,
|
||||
'origin': origin or anon_nodename(),
|
||||
'ignore_result': ignore_result,
|
||||
'replaced_task_nesting': replaced_task_nesting,
|
||||
'stamped_headers': stamped_headers,
|
||||
'stamps': stamps,
|
||||
}
|
||||
@@ -462,7 +467,8 @@ class AMQP:
|
||||
retry=None, retry_policy=None,
|
||||
serializer=None, delivery_mode=None,
|
||||
compression=None, declare=None,
|
||||
headers=None, exchange_type=None, **kwargs):
|
||||
headers=None, exchange_type=None,
|
||||
timeout=None, confirm_timeout=None, **kwargs):
|
||||
retry = default_retry if retry is None else retry
|
||||
headers2, properties, body, sent_event = message
|
||||
if headers:
|
||||
@@ -523,6 +529,7 @@ class AMQP:
|
||||
retry=retry, retry_policy=_rp,
|
||||
delivery_mode=delivery_mode, declare=declare,
|
||||
headers=headers2,
|
||||
timeout=timeout, confirm_timeout=confirm_timeout,
|
||||
**properties
|
||||
)
|
||||
if after_receivers:
|
||||
|
||||
@@ -34,6 +34,7 @@ BACKEND_ALIASES = {
|
||||
'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend',
|
||||
'arangodb': 'celery.backends.arangodb:ArangoDbBackend',
|
||||
's3': 'celery.backends.s3:S3Backend',
|
||||
'gs': 'celery.backends.gcs:GCSBackend',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
"""Actual App instance implementation."""
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from collections import UserDict, defaultdict, deque
|
||||
from datetime import datetime
|
||||
from datetime import timezone as datetime_timezone
|
||||
from operator import attrgetter
|
||||
|
||||
from click.exceptions import Exit
|
||||
from kombu import pools
|
||||
from dateutil.parser import isoparse
|
||||
from kombu import Exchange, pools
|
||||
from kombu.clocks import LamportClock
|
||||
from kombu.common import oid_from
|
||||
from kombu.transport.native_delayed_delivery import calculate_routing_key
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.uuid import uuid
|
||||
@@ -32,6 +38,8 @@ from celery.utils.log import get_logger
|
||||
from celery.utils.objects import FallbackContext, mro_lookup
|
||||
from celery.utils.time import maybe_make_aware, timezone, to_utc
|
||||
|
||||
from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg
|
||||
from ..utils.quorum_queues import detect_quorum_queues
|
||||
# Load all builtin tasks
|
||||
from . import backends, builtins # noqa
|
||||
from .annotations import prepare as prepare_annotations
|
||||
@@ -41,6 +49,10 @@ from .registry import TaskRegistry
|
||||
from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr,
|
||||
bugreport, detect_settings)
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover # codecov does not capture this
|
||||
# flake8 marks the BaseModel import as unused, because the actual typehint is quoted.
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
__all__ = ('Celery',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -90,6 +102,70 @@ def _after_fork_cleanup_app(app):
|
||||
logger.info('after forker raised exception: %r', exc, exc_info=1)
|
||||
|
||||
|
||||
def pydantic_wrapper(
|
||||
app: "Celery",
|
||||
task_fun: typing.Callable[..., typing.Any],
|
||||
task_name: str,
|
||||
strict: bool = True,
|
||||
context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
|
||||
):
|
||||
"""Wrapper to validate arguments and serialize return values using Pydantic."""
|
||||
try:
|
||||
pydantic = importlib.import_module('pydantic')
|
||||
except ModuleNotFoundError as ex:
|
||||
raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex
|
||||
|
||||
BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking
|
||||
|
||||
if context is None:
|
||||
context = {}
|
||||
if dump_kwargs is None:
|
||||
dump_kwargs = {}
|
||||
dump_kwargs.setdefault('mode', 'json')
|
||||
|
||||
task_signature = inspect.signature(task_fun)
|
||||
|
||||
@functools.wraps(task_fun)
|
||||
def wrapper(*task_args, **task_kwargs):
|
||||
# Validate task parameters if type hinted as BaseModel
|
||||
bound_args = task_signature.bind(*task_args, **task_kwargs)
|
||||
for arg_name, arg_value in bound_args.arguments.items():
|
||||
arg_annotation = task_signature.parameters[arg_name].annotation
|
||||
|
||||
optional_arg = get_optional_arg(arg_annotation)
|
||||
if optional_arg is not None and arg_value is not None:
|
||||
arg_annotation = optional_arg
|
||||
|
||||
if annotation_issubclass(arg_annotation, BaseModel):
|
||||
bound_args.arguments[arg_name] = arg_annotation.model_validate(
|
||||
arg_value,
|
||||
strict=strict,
|
||||
context={**context, 'celery_app': app, 'celery_task_name': task_name},
|
||||
)
|
||||
|
||||
# Call the task with (potentially) converted arguments
|
||||
returned_value = task_fun(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
# Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its
|
||||
# class matches the typehint
|
||||
return_annotation = task_signature.return_annotation
|
||||
optional_return_annotation = get_optional_arg(return_annotation)
|
||||
if optional_return_annotation is not None:
|
||||
return_annotation = optional_return_annotation
|
||||
|
||||
if (
|
||||
annotation_is_class(return_annotation)
|
||||
and isinstance(returned_value, BaseModel)
|
||||
and isinstance(returned_value, return_annotation)
|
||||
):
|
||||
return returned_value.model_dump(**dump_kwargs)
|
||||
|
||||
return returned_value
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PendingConfiguration(UserDict, AttributeDictMixin):
|
||||
# `app.conf` will be of this type before being explicitly configured,
|
||||
# meaning the app can keep any configuration set directly
|
||||
@@ -238,6 +314,12 @@ class Celery:
|
||||
self.loader_cls = loader or self._get_default_loader()
|
||||
self.log_cls = log or self.log_cls
|
||||
self.control_cls = control or self.control_cls
|
||||
self._custom_task_cls_used = (
|
||||
# Custom task class provided as argument
|
||||
bool(task_cls)
|
||||
# subclass of Celery with a task_cls attribute
|
||||
or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls')
|
||||
)
|
||||
self.task_cls = task_cls or self.task_cls
|
||||
self.set_as_current = set_as_current
|
||||
self.registry_cls = symbol_by_name(self.registry_cls)
|
||||
@@ -433,6 +515,7 @@ class Celery:
|
||||
if shared:
|
||||
def cons(app):
|
||||
return app._task_from_fun(fun, **opts)
|
||||
|
||||
cons.__name__ = fun.__name__
|
||||
connect_on_app_finalize(cons)
|
||||
if not lazy or self.finalized:
|
||||
@@ -461,13 +544,27 @@ class Celery:
|
||||
def type_checker(self, fun, bound=False):
|
||||
return staticmethod(head_from_fun(fun, bound=bound))
|
||||
|
||||
def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
|
||||
def _task_from_fun(
|
||||
self,
|
||||
fun,
|
||||
name=None,
|
||||
base=None,
|
||||
bind=False,
|
||||
pydantic: bool = False,
|
||||
pydantic_strict: bool = False,
|
||||
pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
**options,
|
||||
):
|
||||
if not self.finalized and not self.autofinalize:
|
||||
raise RuntimeError('Contract breach: app not finalized')
|
||||
name = name or self.gen_task_name(fun.__name__, fun.__module__)
|
||||
base = base or self.Task
|
||||
|
||||
if name not in self._tasks:
|
||||
if pydantic is True:
|
||||
fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs)
|
||||
|
||||
run = fun if bind else staticmethod(fun)
|
||||
task = type(fun.__name__, (base,), dict({
|
||||
'app': self,
|
||||
@@ -711,7 +808,7 @@ class Celery:
|
||||
retries=0, chord=None,
|
||||
reply_to=None, time_limit=None, soft_time_limit=None,
|
||||
root_id=None, parent_id=None, route_name=None,
|
||||
shadow=None, chain=None, task_type=None, **options):
|
||||
shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options):
|
||||
"""Send task by name.
|
||||
|
||||
Supports the same arguments as :meth:`@-Task.apply_async`.
|
||||
@@ -734,13 +831,48 @@ class Celery:
|
||||
ignore_result = options.pop('ignore_result', False)
|
||||
options = router.route(
|
||||
options, route_name or name, args, kwargs, task_type)
|
||||
|
||||
driver_type = self.producer_pool.connections.connection.transport.driver_type
|
||||
|
||||
if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]:
|
||||
|
||||
queue = options.get("queue")
|
||||
exchange_type = queue.exchange.type if queue else options["exchange_type"]
|
||||
routing_key = queue.routing_key if queue else options["routing_key"]
|
||||
exchange_name = queue.exchange.name if queue else options["exchange"]
|
||||
|
||||
if exchange_type != 'direct':
|
||||
if eta:
|
||||
if isinstance(eta, str):
|
||||
eta = isoparse(eta)
|
||||
countdown = (maybe_make_aware(eta) - self.now()).total_seconds()
|
||||
|
||||
if countdown:
|
||||
if countdown > 0:
|
||||
routing_key = calculate_routing_key(int(countdown), routing_key)
|
||||
exchange = Exchange(
|
||||
'celery_delayed_27',
|
||||
type='topic',
|
||||
)
|
||||
options.pop("queue", None)
|
||||
options['routing_key'] = routing_key
|
||||
options['exchange'] = exchange
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
'Direct exchanges are not supported with native delayed delivery.\n'
|
||||
f'{exchange_name} is a direct exchange but should be a topic exchange or '
|
||||
'a fanout exchange in order for native delayed delivery to work properly.\n'
|
||||
'If quorum queues are used, this task may block the worker process until the ETA arrives.'
|
||||
)
|
||||
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime):
|
||||
expires_s = (maybe_make_aware(
|
||||
expires) - self.now()).total_seconds()
|
||||
elif isinstance(expires, str):
|
||||
expires_s = (maybe_make_aware(
|
||||
datetime.fromisoformat(expires)) - self.now()).total_seconds()
|
||||
isoparse(expires)) - self.now()).total_seconds()
|
||||
else:
|
||||
expires_s = expires
|
||||
|
||||
@@ -781,7 +913,7 @@ class Celery:
|
||||
self.conf.task_send_sent_event,
|
||||
root_id, parent_id, shadow, chain,
|
||||
ignore_result=ignore_result,
|
||||
**options
|
||||
replaced_task_nesting=replaced_task_nesting, **options
|
||||
)
|
||||
|
||||
stamped_headers = options.pop('stamped_headers', [])
|
||||
@@ -894,6 +1026,7 @@ class Celery:
|
||||
'broker_connection_timeout', connect_timeout
|
||||
),
|
||||
)
|
||||
|
||||
broker_connection = connection
|
||||
|
||||
def _acquire_connection(self, pool=True):
|
||||
@@ -913,6 +1046,7 @@ class Celery:
|
||||
will be acquired from the connection pool.
|
||||
"""
|
||||
return FallbackContext(connection, self._acquire_connection, pool=pool)
|
||||
|
||||
default_connection = connection_or_acquire # XXX compat
|
||||
|
||||
def producer_or_acquire(self, producer=None):
|
||||
@@ -928,6 +1062,7 @@ class Celery:
|
||||
return FallbackContext(
|
||||
producer, self.producer_pool.acquire, block=True,
|
||||
)
|
||||
|
||||
default_producer = producer_or_acquire # XXX compat
|
||||
|
||||
def prepare_config(self, c):
|
||||
@@ -936,7 +1071,7 @@ class Celery:
|
||||
|
||||
def now(self):
|
||||
"""Return the current time and date as a datetime."""
|
||||
now_in_utc = to_utc(datetime.utcnow())
|
||||
now_in_utc = to_utc(datetime.now(datetime_timezone.utc))
|
||||
return now_in_utc.astimezone(self.timezone)
|
||||
|
||||
def select_queues(self, queues=None):
|
||||
@@ -974,7 +1109,14 @@ class Celery:
|
||||
This is used by PendingConfiguration:
|
||||
as soon as you access a key the configuration is read.
|
||||
"""
|
||||
conf = self._conf = self._load_config()
|
||||
try:
|
||||
conf = self._conf = self._load_config()
|
||||
except AttributeError as err:
|
||||
# AttributeError is not propagated, it is "handled" by
|
||||
# PendingConfiguration parent class. This causes
|
||||
# confusing RecursionError.
|
||||
raise ModuleNotFoundError(*err.args) from err
|
||||
|
||||
return conf
|
||||
|
||||
def _load_config(self):
|
||||
|
||||
@@ -360,7 +360,7 @@ class Inspect:
|
||||
* ``routing_key`` - Routing key used when task was published
|
||||
* ``priority`` - Priority used when task was published
|
||||
* ``redelivered`` - True if the task was redelivered
|
||||
* ``worker_pid`` - PID of worker processin the task
|
||||
* ``worker_pid`` - PID of worker processing the task
|
||||
|
||||
"""
|
||||
# signature used be unary: query_task(ids=[id1, id2])
|
||||
@@ -527,7 +527,8 @@ class Control:
|
||||
if result:
|
||||
for host in result:
|
||||
for response in host.values():
|
||||
task_ids.update(response['ok'])
|
||||
if isinstance(response['ok'], set):
|
||||
task_ids.update(response['ok'])
|
||||
|
||||
if task_ids:
|
||||
return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs)
|
||||
|
||||
@@ -95,6 +95,7 @@ NAMESPACES = Namespace(
|
||||
heartbeat=Option(120, type='int'),
|
||||
heartbeat_checkrate=Option(3.0, type='int'),
|
||||
login_method=Option(None, type='string'),
|
||||
native_delayed_delivery_queue_type=Option(default='quorum', type='string'),
|
||||
pool_limit=Option(10, type='int'),
|
||||
use_ssl=Option(False, type='bool'),
|
||||
|
||||
@@ -140,6 +141,12 @@ NAMESPACES = Namespace(
|
||||
connection_timeout=Option(20, type='int'),
|
||||
read_timeout=Option(120, type='int'),
|
||||
),
|
||||
gcs=Namespace(
|
||||
bucket=Option(type='string'),
|
||||
project=Option(type='string'),
|
||||
base_path=Option('', type='string'),
|
||||
ttl=Option(0, type='float'),
|
||||
),
|
||||
control=Namespace(
|
||||
queue_ttl=Option(300.0, type='float'),
|
||||
queue_expires=Option(10.0, type='float'),
|
||||
@@ -243,6 +250,7 @@ NAMESPACES = Namespace(
|
||||
),
|
||||
table_schemas=Option(type='dict'),
|
||||
table_names=Option(type='dict', old={'celery_result_db_tablenames'}),
|
||||
create_tables_at_setup=Option(True, type='bool'),
|
||||
),
|
||||
task=Namespace(
|
||||
__old__=OLD_NS,
|
||||
@@ -255,6 +263,7 @@ NAMESPACES = Namespace(
|
||||
inherit_parent_priority=Option(False, type='bool'),
|
||||
default_delivery_mode=Option(2, type='string'),
|
||||
default_queue=Option('celery'),
|
||||
default_queue_type=Option('classic', type='string'),
|
||||
default_exchange=Option(None, type='string'), # taken from queue
|
||||
default_exchange_type=Option('direct'),
|
||||
default_routing_key=Option(None, type='string'), # taken from queue
|
||||
@@ -302,6 +311,8 @@ NAMESPACES = Namespace(
|
||||
cancel_long_running_tasks_on_connection_loss=Option(
|
||||
False, type='bool'
|
||||
),
|
||||
soft_shutdown_timeout=Option(0.0, type='float'),
|
||||
enable_soft_shutdown_on_idle=Option(False, type='bool'),
|
||||
concurrency=Option(None, type='int'),
|
||||
consumer=Option('celery.worker.consumer:Consumer', type='string'),
|
||||
direct=Option(False, type='bool', old={'celery_worker_direct'}),
|
||||
@@ -325,6 +336,7 @@ NAMESPACES = Namespace(
|
||||
pool_restarts=Option(False, type='bool'),
|
||||
proc_alive_timeout=Option(4.0, type='float'),
|
||||
prefetch_multiplier=Option(4, type='int'),
|
||||
enable_prefetch_count_reduction=Option(True, type='bool'),
|
||||
redirect_stdouts=Option(
|
||||
True, type='bool', old={'celery_redirect_stdouts'},
|
||||
),
|
||||
@@ -338,6 +350,7 @@ NAMESPACES = Namespace(
|
||||
task_log_format=Option(DEFAULT_TASK_LOG_FMT),
|
||||
timer=Option(type='string'),
|
||||
timer_precision=Option(1.0, type='float'),
|
||||
detect_quorum_queues=Option(True, type='bool'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from celery import signals
|
||||
from celery._state import get_current_task
|
||||
from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
|
||||
from celery.local import class_property
|
||||
from celery.platforms import isatty
|
||||
from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel,
|
||||
reset_multiprocessing_logger)
|
||||
from celery.utils.nodenames import node_format
|
||||
@@ -203,7 +204,7 @@ class Logging:
|
||||
if colorize or colorize is None:
|
||||
# Only use color if there's no active log file
|
||||
# and stderr is an actual terminal.
|
||||
return logfile is None and sys.stderr.isatty()
|
||||
return logfile is None and isatty(sys.stderr)
|
||||
return colorize
|
||||
|
||||
def colored(self, logfile=None, enabled=None):
|
||||
|
||||
@@ -20,7 +20,7 @@ except AttributeError: # pragma: no cover
|
||||
# for support Python 3.7
|
||||
Pattern = re.Pattern
|
||||
|
||||
__all__ = ('MapRoute', 'Router', 'prepare')
|
||||
__all__ = ('MapRoute', 'Router', 'expand_router_string', 'prepare')
|
||||
|
||||
|
||||
class MapRoute:
|
||||
|
||||
@@ -104,7 +104,7 @@ class Context:
|
||||
def _get_custom_headers(self, *args, **kwargs):
|
||||
headers = {}
|
||||
headers.update(*args, **kwargs)
|
||||
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'}
|
||||
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr', 'compression'}
|
||||
for key in celery_keys:
|
||||
headers.pop(key, None)
|
||||
if not headers:
|
||||
@@ -466,7 +466,7 @@ class Task:
|
||||
shadow (str): Override task name used in logs/monitoring.
|
||||
Default is retrieved from :meth:`shadow_name`.
|
||||
|
||||
connection (kombu.Connection): Re-use existing broker connection
|
||||
connection (kombu.Connection): Reuse existing broker connection
|
||||
instead of acquiring one from the connection pool.
|
||||
|
||||
retry (bool): If enabled sending of the task message will be
|
||||
@@ -535,6 +535,8 @@ class Task:
|
||||
publisher (kombu.Producer): Deprecated alias to ``producer``.
|
||||
|
||||
headers (Dict): Message headers to be included in the message.
|
||||
The headers can be used as an overlay for custom labeling
|
||||
using the :ref:`canvas-stamping` feature.
|
||||
|
||||
Returns:
|
||||
celery.result.AsyncResult: Promise of future evaluation.
|
||||
@@ -543,6 +545,8 @@ class Task:
|
||||
TypeError: If not enough arguments are passed, or too many
|
||||
arguments are passed. Note that signature checks may
|
||||
be disabled by specifying ``@task(typing=False)``.
|
||||
ValueError: If soft_time_limit and time_limit both are set
|
||||
but soft_time_limit is greater than time_limit
|
||||
kombu.exceptions.OperationalError: If a connection to the
|
||||
transport cannot be made, or if the connection is lost.
|
||||
|
||||
@@ -550,6 +554,9 @@ class Task:
|
||||
Also supports all keyword arguments supported by
|
||||
:meth:`kombu.Producer.publish`.
|
||||
"""
|
||||
if self.soft_time_limit and self.time_limit and self.soft_time_limit > self.time_limit:
|
||||
raise ValueError('soft_time_limit must be less than or equal to time_limit')
|
||||
|
||||
if self.typing:
|
||||
try:
|
||||
check_arguments = self.__header__
|
||||
@@ -788,6 +795,7 @@ class Task:
|
||||
|
||||
request = {
|
||||
'id': task_id,
|
||||
'task': self.name,
|
||||
'retries': retries,
|
||||
'is_eager': True,
|
||||
'logfile': logfile,
|
||||
@@ -824,7 +832,7 @@ class Task:
|
||||
if isinstance(retval, Retry) and retval.sig is not None:
|
||||
return retval.sig.apply(retries=retries + 1)
|
||||
state = states.SUCCESS if ret.info is None else ret.info.state
|
||||
return EagerResult(task_id, retval, state, traceback=tb)
|
||||
return EagerResult(task_id, retval, state, traceback=tb, name=self.name)
|
||||
|
||||
def AsyncResult(self, task_id, **kwargs):
|
||||
"""Get AsyncResult instance for the specified task.
|
||||
@@ -954,11 +962,20 @@ class Task:
|
||||
root_id=self.request.root_id,
|
||||
replaced_task_nesting=replaced_task_nesting
|
||||
)
|
||||
|
||||
# If the replaced task is a chain, we want to set all of the chain tasks
|
||||
# with the same replaced_task_nesting value to mark their replacement nesting level
|
||||
if isinstance(sig, _chain):
|
||||
for chain_task in maybe_list(sig.tasks) or []:
|
||||
chain_task.set(replaced_task_nesting=replaced_task_nesting)
|
||||
|
||||
# If the task being replaced is part of a chain, we need to re-create
|
||||
# it with the replacement signature - these subsequent tasks will
|
||||
# retain their original task IDs as well
|
||||
for t in reversed(self.request.chain or []):
|
||||
sig |= signature(t, app=self.app)
|
||||
chain_task = signature(t, app=self.app)
|
||||
chain_task.set(replaced_task_nesting=replaced_task_nesting)
|
||||
sig |= chain_task
|
||||
return self.on_replace(sig)
|
||||
|
||||
def add_to_chord(self, sig, lazy=False):
|
||||
@@ -1099,7 +1116,7 @@ class Task:
|
||||
return result
|
||||
|
||||
def push_request(self, *args, **kwargs):
|
||||
self.request_stack.push(Context(*args, **kwargs))
|
||||
self.request_stack.push(Context(*args, **{**self.request.__dict__, **kwargs}))
|
||||
|
||||
def pop_request(self):
|
||||
self.request_stack.pop()
|
||||
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union
|
||||
from warnings import warn
|
||||
|
||||
from billiard.einfo import ExceptionInfo, ExceptionWithTraceback
|
||||
@@ -17,8 +16,6 @@ from kombu.serialization import loads as loads_message
|
||||
from kombu.serialization import prepare_accept_content
|
||||
from kombu.utils.encoding import safe_repr, safe_str
|
||||
|
||||
import celery
|
||||
import celery.loaders.app
|
||||
from celery import current_app, group, signals, states
|
||||
from celery._state import _task_stack
|
||||
from celery.app.task import Context
|
||||
@@ -294,20 +291,10 @@ def traceback_clear(exc=None):
|
||||
tb = tb.tb_next
|
||||
|
||||
|
||||
def build_tracer(
|
||||
name: str,
|
||||
task: Union[celery.Task, celery.local.PromiseProxy],
|
||||
loader: Optional[celery.loaders.app.AppLoader] = None,
|
||||
hostname: Optional[str] = None,
|
||||
store_errors: bool = True,
|
||||
Info: Type[TraceInfo] = TraceInfo,
|
||||
eager: bool = False,
|
||||
propagate: bool = False,
|
||||
app: Optional[celery.Celery] = None,
|
||||
monotonic: Callable[[], int] = time.monotonic,
|
||||
trace_ok_t: Type[trace_ok_t] = trace_ok_t,
|
||||
IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \
|
||||
Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]:
|
||||
def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
|
||||
Info=TraceInfo, eager=False, propagate=False, app=None,
|
||||
monotonic=time.monotonic, trace_ok_t=trace_ok_t,
|
||||
IGNORE_STATES=IGNORE_STATES):
|
||||
"""Return a function that traces task execution.
|
||||
|
||||
Catches all exceptions and updates result backend with the
|
||||
@@ -387,12 +374,7 @@ def build_tracer(
|
||||
from celery import canvas
|
||||
signature = canvas.maybe_signature # maybe_ does not clone if already
|
||||
|
||||
def on_error(
|
||||
request: celery.app.task.Context,
|
||||
exc: Union[Exception, Type[Exception]],
|
||||
state: str = FAILURE,
|
||||
call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]:
|
||||
"""Handle any errors raised by a `Task`'s execution."""
|
||||
def on_error(request, exc, state=FAILURE, call_errbacks=True):
|
||||
if propagate:
|
||||
raise
|
||||
I = Info(state, exc)
|
||||
@@ -401,13 +383,7 @@ def build_tracer(
|
||||
)
|
||||
return I, R, I.state, I.retval
|
||||
|
||||
def trace_task(
|
||||
uuid: str,
|
||||
args: Sequence[Any],
|
||||
kwargs: Dict[str, Any],
|
||||
request: Optional[Dict[str, Any]] = None) -> trace_ok_t:
|
||||
"""Execute and trace a `Task`."""
|
||||
|
||||
def trace_task(uuid, args, kwargs, request=None):
|
||||
# R - is the possibly prepared return value.
|
||||
# I - is the Info object.
|
||||
# T - runtime
|
||||
|
||||
@@ -35,7 +35,7 @@ settings -> transport:{transport} results:{results}
|
||||
"""
|
||||
|
||||
HIDDEN_SETTINGS = re.compile(
|
||||
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE',
|
||||
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE|BEAT_DBURI',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user