This commit is contained in:
@@ -15,9 +15,9 @@ from collections import namedtuple
|
||||
# Lazy loading
|
||||
from . import local
|
||||
|
||||
SERIES = 'emerald-rush'
|
||||
SERIES = 'immunity'
|
||||
|
||||
__version__ = '5.3.4'
|
||||
__version__ = '5.5.3'
|
||||
__author__ = 'Ask Solem'
|
||||
__contact__ = 'auvipy@gmail.com'
|
||||
__homepage__ = 'https://docs.celeryq.dev/'
|
||||
|
||||
@@ -249,9 +249,13 @@ class AMQP:
|
||||
if max_priority is None:
|
||||
max_priority = conf.task_queue_max_priority
|
||||
if not queues and conf.task_default_queue:
|
||||
queue_arguments = None
|
||||
if conf.task_default_queue_type == 'quorum':
|
||||
queue_arguments = {'x-queue-type': 'quorum'}
|
||||
queues = (Queue(conf.task_default_queue,
|
||||
exchange=self.default_exchange,
|
||||
routing_key=default_routing_key),)
|
||||
routing_key=default_routing_key,
|
||||
queue_arguments=queue_arguments),)
|
||||
autoexchange = (self.autoexchange if autoexchange is None
|
||||
else autoexchange)
|
||||
return self.queues_cls(
|
||||
@@ -285,7 +289,7 @@ class AMQP:
|
||||
create_sent_event=False, root_id=None, parent_id=None,
|
||||
shadow=None, chain=None, now=None, timezone=None,
|
||||
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
|
||||
**options):
|
||||
replaced_task_nesting=0, **options):
|
||||
|
||||
args = args or ()
|
||||
kwargs = kwargs or {}
|
||||
@@ -339,6 +343,7 @@ class AMQP:
|
||||
'kwargsrepr': kwargsrepr,
|
||||
'origin': origin or anon_nodename(),
|
||||
'ignore_result': ignore_result,
|
||||
'replaced_task_nesting': replaced_task_nesting,
|
||||
'stamped_headers': stamped_headers,
|
||||
'stamps': stamps,
|
||||
}
|
||||
@@ -462,7 +467,8 @@ class AMQP:
|
||||
retry=None, retry_policy=None,
|
||||
serializer=None, delivery_mode=None,
|
||||
compression=None, declare=None,
|
||||
headers=None, exchange_type=None, **kwargs):
|
||||
headers=None, exchange_type=None,
|
||||
timeout=None, confirm_timeout=None, **kwargs):
|
||||
retry = default_retry if retry is None else retry
|
||||
headers2, properties, body, sent_event = message
|
||||
if headers:
|
||||
@@ -523,6 +529,7 @@ class AMQP:
|
||||
retry=retry, retry_policy=_rp,
|
||||
delivery_mode=delivery_mode, declare=declare,
|
||||
headers=headers2,
|
||||
timeout=timeout, confirm_timeout=confirm_timeout,
|
||||
**properties
|
||||
)
|
||||
if after_receivers:
|
||||
|
||||
@@ -34,6 +34,7 @@ BACKEND_ALIASES = {
|
||||
'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend',
|
||||
'arangodb': 'celery.backends.arangodb:ArangoDbBackend',
|
||||
's3': 'celery.backends.s3:S3Backend',
|
||||
'gs': 'celery.backends.gcs:GCSBackend',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
"""Actual App instance implementation."""
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from collections import UserDict, defaultdict, deque
|
||||
from datetime import datetime
|
||||
from datetime import timezone as datetime_timezone
|
||||
from operator import attrgetter
|
||||
|
||||
from click.exceptions import Exit
|
||||
from kombu import pools
|
||||
from dateutil.parser import isoparse
|
||||
from kombu import Exchange, pools
|
||||
from kombu.clocks import LamportClock
|
||||
from kombu.common import oid_from
|
||||
from kombu.transport.native_delayed_delivery import calculate_routing_key
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.uuid import uuid
|
||||
@@ -32,6 +38,8 @@ from celery.utils.log import get_logger
|
||||
from celery.utils.objects import FallbackContext, mro_lookup
|
||||
from celery.utils.time import maybe_make_aware, timezone, to_utc
|
||||
|
||||
from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg
|
||||
from ..utils.quorum_queues import detect_quorum_queues
|
||||
# Load all builtin tasks
|
||||
from . import backends, builtins # noqa
|
||||
from .annotations import prepare as prepare_annotations
|
||||
@@ -41,6 +49,10 @@ from .registry import TaskRegistry
|
||||
from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr,
|
||||
bugreport, detect_settings)
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover # codecov does not capture this
|
||||
# flake8 marks the BaseModel import as unused, because the actual typehint is quoted.
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
__all__ = ('Celery',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -90,6 +102,70 @@ def _after_fork_cleanup_app(app):
|
||||
logger.info('after forker raised exception: %r', exc, exc_info=1)
|
||||
|
||||
|
||||
def pydantic_wrapper(
|
||||
app: "Celery",
|
||||
task_fun: typing.Callable[..., typing.Any],
|
||||
task_name: str,
|
||||
strict: bool = True,
|
||||
context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
|
||||
):
|
||||
"""Wrapper to validate arguments and serialize return values using Pydantic."""
|
||||
try:
|
||||
pydantic = importlib.import_module('pydantic')
|
||||
except ModuleNotFoundError as ex:
|
||||
raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex
|
||||
|
||||
BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking
|
||||
|
||||
if context is None:
|
||||
context = {}
|
||||
if dump_kwargs is None:
|
||||
dump_kwargs = {}
|
||||
dump_kwargs.setdefault('mode', 'json')
|
||||
|
||||
task_signature = inspect.signature(task_fun)
|
||||
|
||||
@functools.wraps(task_fun)
|
||||
def wrapper(*task_args, **task_kwargs):
|
||||
# Validate task parameters if type hinted as BaseModel
|
||||
bound_args = task_signature.bind(*task_args, **task_kwargs)
|
||||
for arg_name, arg_value in bound_args.arguments.items():
|
||||
arg_annotation = task_signature.parameters[arg_name].annotation
|
||||
|
||||
optional_arg = get_optional_arg(arg_annotation)
|
||||
if optional_arg is not None and arg_value is not None:
|
||||
arg_annotation = optional_arg
|
||||
|
||||
if annotation_issubclass(arg_annotation, BaseModel):
|
||||
bound_args.arguments[arg_name] = arg_annotation.model_validate(
|
||||
arg_value,
|
||||
strict=strict,
|
||||
context={**context, 'celery_app': app, 'celery_task_name': task_name},
|
||||
)
|
||||
|
||||
# Call the task with (potentially) converted arguments
|
||||
returned_value = task_fun(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
# Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its
|
||||
# class matches the typehint
|
||||
return_annotation = task_signature.return_annotation
|
||||
optional_return_annotation = get_optional_arg(return_annotation)
|
||||
if optional_return_annotation is not None:
|
||||
return_annotation = optional_return_annotation
|
||||
|
||||
if (
|
||||
annotation_is_class(return_annotation)
|
||||
and isinstance(returned_value, BaseModel)
|
||||
and isinstance(returned_value, return_annotation)
|
||||
):
|
||||
return returned_value.model_dump(**dump_kwargs)
|
||||
|
||||
return returned_value
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PendingConfiguration(UserDict, AttributeDictMixin):
|
||||
# `app.conf` will be of this type before being explicitly configured,
|
||||
# meaning the app can keep any configuration set directly
|
||||
@@ -238,6 +314,12 @@ class Celery:
|
||||
self.loader_cls = loader or self._get_default_loader()
|
||||
self.log_cls = log or self.log_cls
|
||||
self.control_cls = control or self.control_cls
|
||||
self._custom_task_cls_used = (
|
||||
# Custom task class provided as argument
|
||||
bool(task_cls)
|
||||
# subclass of Celery with a task_cls attribute
|
||||
or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls')
|
||||
)
|
||||
self.task_cls = task_cls or self.task_cls
|
||||
self.set_as_current = set_as_current
|
||||
self.registry_cls = symbol_by_name(self.registry_cls)
|
||||
@@ -433,6 +515,7 @@ class Celery:
|
||||
if shared:
|
||||
def cons(app):
|
||||
return app._task_from_fun(fun, **opts)
|
||||
|
||||
cons.__name__ = fun.__name__
|
||||
connect_on_app_finalize(cons)
|
||||
if not lazy or self.finalized:
|
||||
@@ -461,13 +544,27 @@ class Celery:
|
||||
def type_checker(self, fun, bound=False):
|
||||
return staticmethod(head_from_fun(fun, bound=bound))
|
||||
|
||||
def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
|
||||
def _task_from_fun(
|
||||
self,
|
||||
fun,
|
||||
name=None,
|
||||
base=None,
|
||||
bind=False,
|
||||
pydantic: bool = False,
|
||||
pydantic_strict: bool = False,
|
||||
pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
**options,
|
||||
):
|
||||
if not self.finalized and not self.autofinalize:
|
||||
raise RuntimeError('Contract breach: app not finalized')
|
||||
name = name or self.gen_task_name(fun.__name__, fun.__module__)
|
||||
base = base or self.Task
|
||||
|
||||
if name not in self._tasks:
|
||||
if pydantic is True:
|
||||
fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs)
|
||||
|
||||
run = fun if bind else staticmethod(fun)
|
||||
task = type(fun.__name__, (base,), dict({
|
||||
'app': self,
|
||||
@@ -711,7 +808,7 @@ class Celery:
|
||||
retries=0, chord=None,
|
||||
reply_to=None, time_limit=None, soft_time_limit=None,
|
||||
root_id=None, parent_id=None, route_name=None,
|
||||
shadow=None, chain=None, task_type=None, **options):
|
||||
shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options):
|
||||
"""Send task by name.
|
||||
|
||||
Supports the same arguments as :meth:`@-Task.apply_async`.
|
||||
@@ -734,13 +831,48 @@ class Celery:
|
||||
ignore_result = options.pop('ignore_result', False)
|
||||
options = router.route(
|
||||
options, route_name or name, args, kwargs, task_type)
|
||||
|
||||
driver_type = self.producer_pool.connections.connection.transport.driver_type
|
||||
|
||||
if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]:
|
||||
|
||||
queue = options.get("queue")
|
||||
exchange_type = queue.exchange.type if queue else options["exchange_type"]
|
||||
routing_key = queue.routing_key if queue else options["routing_key"]
|
||||
exchange_name = queue.exchange.name if queue else options["exchange"]
|
||||
|
||||
if exchange_type != 'direct':
|
||||
if eta:
|
||||
if isinstance(eta, str):
|
||||
eta = isoparse(eta)
|
||||
countdown = (maybe_make_aware(eta) - self.now()).total_seconds()
|
||||
|
||||
if countdown:
|
||||
if countdown > 0:
|
||||
routing_key = calculate_routing_key(int(countdown), routing_key)
|
||||
exchange = Exchange(
|
||||
'celery_delayed_27',
|
||||
type='topic',
|
||||
)
|
||||
options.pop("queue", None)
|
||||
options['routing_key'] = routing_key
|
||||
options['exchange'] = exchange
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
'Direct exchanges are not supported with native delayed delivery.\n'
|
||||
f'{exchange_name} is a direct exchange but should be a topic exchange or '
|
||||
'a fanout exchange in order for native delayed delivery to work properly.\n'
|
||||
'If quorum queues are used, this task may block the worker process until the ETA arrives.'
|
||||
)
|
||||
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime):
|
||||
expires_s = (maybe_make_aware(
|
||||
expires) - self.now()).total_seconds()
|
||||
elif isinstance(expires, str):
|
||||
expires_s = (maybe_make_aware(
|
||||
datetime.fromisoformat(expires)) - self.now()).total_seconds()
|
||||
isoparse(expires)) - self.now()).total_seconds()
|
||||
else:
|
||||
expires_s = expires
|
||||
|
||||
@@ -781,7 +913,7 @@ class Celery:
|
||||
self.conf.task_send_sent_event,
|
||||
root_id, parent_id, shadow, chain,
|
||||
ignore_result=ignore_result,
|
||||
**options
|
||||
replaced_task_nesting=replaced_task_nesting, **options
|
||||
)
|
||||
|
||||
stamped_headers = options.pop('stamped_headers', [])
|
||||
@@ -894,6 +1026,7 @@ class Celery:
|
||||
'broker_connection_timeout', connect_timeout
|
||||
),
|
||||
)
|
||||
|
||||
broker_connection = connection
|
||||
|
||||
def _acquire_connection(self, pool=True):
|
||||
@@ -913,6 +1046,7 @@ class Celery:
|
||||
will be acquired from the connection pool.
|
||||
"""
|
||||
return FallbackContext(connection, self._acquire_connection, pool=pool)
|
||||
|
||||
default_connection = connection_or_acquire # XXX compat
|
||||
|
||||
def producer_or_acquire(self, producer=None):
|
||||
@@ -928,6 +1062,7 @@ class Celery:
|
||||
return FallbackContext(
|
||||
producer, self.producer_pool.acquire, block=True,
|
||||
)
|
||||
|
||||
default_producer = producer_or_acquire # XXX compat
|
||||
|
||||
def prepare_config(self, c):
|
||||
@@ -936,7 +1071,7 @@ class Celery:
|
||||
|
||||
def now(self):
|
||||
"""Return the current time and date as a datetime."""
|
||||
now_in_utc = to_utc(datetime.utcnow())
|
||||
now_in_utc = to_utc(datetime.now(datetime_timezone.utc))
|
||||
return now_in_utc.astimezone(self.timezone)
|
||||
|
||||
def select_queues(self, queues=None):
|
||||
@@ -974,7 +1109,14 @@ class Celery:
|
||||
This is used by PendingConfiguration:
|
||||
as soon as you access a key the configuration is read.
|
||||
"""
|
||||
conf = self._conf = self._load_config()
|
||||
try:
|
||||
conf = self._conf = self._load_config()
|
||||
except AttributeError as err:
|
||||
# AttributeError is not propagated, it is "handled" by
|
||||
# PendingConfiguration parent class. This causes
|
||||
# confusing RecursionError.
|
||||
raise ModuleNotFoundError(*err.args) from err
|
||||
|
||||
return conf
|
||||
|
||||
def _load_config(self):
|
||||
|
||||
@@ -360,7 +360,7 @@ class Inspect:
|
||||
* ``routing_key`` - Routing key used when task was published
|
||||
* ``priority`` - Priority used when task was published
|
||||
* ``redelivered`` - True if the task was redelivered
|
||||
* ``worker_pid`` - PID of worker processin the task
|
||||
* ``worker_pid`` - PID of worker processing the task
|
||||
|
||||
"""
|
||||
# signature used be unary: query_task(ids=[id1, id2])
|
||||
@@ -527,7 +527,8 @@ class Control:
|
||||
if result:
|
||||
for host in result:
|
||||
for response in host.values():
|
||||
task_ids.update(response['ok'])
|
||||
if isinstance(response['ok'], set):
|
||||
task_ids.update(response['ok'])
|
||||
|
||||
if task_ids:
|
||||
return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs)
|
||||
|
||||
@@ -95,6 +95,7 @@ NAMESPACES = Namespace(
|
||||
heartbeat=Option(120, type='int'),
|
||||
heartbeat_checkrate=Option(3.0, type='int'),
|
||||
login_method=Option(None, type='string'),
|
||||
native_delayed_delivery_queue_type=Option(default='quorum', type='string'),
|
||||
pool_limit=Option(10, type='int'),
|
||||
use_ssl=Option(False, type='bool'),
|
||||
|
||||
@@ -140,6 +141,12 @@ NAMESPACES = Namespace(
|
||||
connection_timeout=Option(20, type='int'),
|
||||
read_timeout=Option(120, type='int'),
|
||||
),
|
||||
gcs=Namespace(
|
||||
bucket=Option(type='string'),
|
||||
project=Option(type='string'),
|
||||
base_path=Option('', type='string'),
|
||||
ttl=Option(0, type='float'),
|
||||
),
|
||||
control=Namespace(
|
||||
queue_ttl=Option(300.0, type='float'),
|
||||
queue_expires=Option(10.0, type='float'),
|
||||
@@ -243,6 +250,7 @@ NAMESPACES = Namespace(
|
||||
),
|
||||
table_schemas=Option(type='dict'),
|
||||
table_names=Option(type='dict', old={'celery_result_db_tablenames'}),
|
||||
create_tables_at_setup=Option(True, type='bool'),
|
||||
),
|
||||
task=Namespace(
|
||||
__old__=OLD_NS,
|
||||
@@ -255,6 +263,7 @@ NAMESPACES = Namespace(
|
||||
inherit_parent_priority=Option(False, type='bool'),
|
||||
default_delivery_mode=Option(2, type='string'),
|
||||
default_queue=Option('celery'),
|
||||
default_queue_type=Option('classic', type='string'),
|
||||
default_exchange=Option(None, type='string'), # taken from queue
|
||||
default_exchange_type=Option('direct'),
|
||||
default_routing_key=Option(None, type='string'), # taken from queue
|
||||
@@ -302,6 +311,8 @@ NAMESPACES = Namespace(
|
||||
cancel_long_running_tasks_on_connection_loss=Option(
|
||||
False, type='bool'
|
||||
),
|
||||
soft_shutdown_timeout=Option(0.0, type='float'),
|
||||
enable_soft_shutdown_on_idle=Option(False, type='bool'),
|
||||
concurrency=Option(None, type='int'),
|
||||
consumer=Option('celery.worker.consumer:Consumer', type='string'),
|
||||
direct=Option(False, type='bool', old={'celery_worker_direct'}),
|
||||
@@ -325,6 +336,7 @@ NAMESPACES = Namespace(
|
||||
pool_restarts=Option(False, type='bool'),
|
||||
proc_alive_timeout=Option(4.0, type='float'),
|
||||
prefetch_multiplier=Option(4, type='int'),
|
||||
enable_prefetch_count_reduction=Option(True, type='bool'),
|
||||
redirect_stdouts=Option(
|
||||
True, type='bool', old={'celery_redirect_stdouts'},
|
||||
),
|
||||
@@ -338,6 +350,7 @@ NAMESPACES = Namespace(
|
||||
task_log_format=Option(DEFAULT_TASK_LOG_FMT),
|
||||
timer=Option(type='string'),
|
||||
timer_precision=Option(1.0, type='float'),
|
||||
detect_quorum_queues=Option(True, type='bool'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from celery import signals
|
||||
from celery._state import get_current_task
|
||||
from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
|
||||
from celery.local import class_property
|
||||
from celery.platforms import isatty
|
||||
from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel,
|
||||
reset_multiprocessing_logger)
|
||||
from celery.utils.nodenames import node_format
|
||||
@@ -203,7 +204,7 @@ class Logging:
|
||||
if colorize or colorize is None:
|
||||
# Only use color if there's no active log file
|
||||
# and stderr is an actual terminal.
|
||||
return logfile is None and sys.stderr.isatty()
|
||||
return logfile is None and isatty(sys.stderr)
|
||||
return colorize
|
||||
|
||||
def colored(self, logfile=None, enabled=None):
|
||||
|
||||
@@ -20,7 +20,7 @@ except AttributeError: # pragma: no cover
|
||||
# for support Python 3.7
|
||||
Pattern = re.Pattern
|
||||
|
||||
__all__ = ('MapRoute', 'Router', 'prepare')
|
||||
__all__ = ('MapRoute', 'Router', 'expand_router_string', 'prepare')
|
||||
|
||||
|
||||
class MapRoute:
|
||||
|
||||
@@ -104,7 +104,7 @@ class Context:
|
||||
def _get_custom_headers(self, *args, **kwargs):
|
||||
headers = {}
|
||||
headers.update(*args, **kwargs)
|
||||
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'}
|
||||
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr', 'compression'}
|
||||
for key in celery_keys:
|
||||
headers.pop(key, None)
|
||||
if not headers:
|
||||
@@ -466,7 +466,7 @@ class Task:
|
||||
shadow (str): Override task name used in logs/monitoring.
|
||||
Default is retrieved from :meth:`shadow_name`.
|
||||
|
||||
connection (kombu.Connection): Re-use existing broker connection
|
||||
connection (kombu.Connection): Reuse existing broker connection
|
||||
instead of acquiring one from the connection pool.
|
||||
|
||||
retry (bool): If enabled sending of the task message will be
|
||||
@@ -535,6 +535,8 @@ class Task:
|
||||
publisher (kombu.Producer): Deprecated alias to ``producer``.
|
||||
|
||||
headers (Dict): Message headers to be included in the message.
|
||||
The headers can be used as an overlay for custom labeling
|
||||
using the :ref:`canvas-stamping` feature.
|
||||
|
||||
Returns:
|
||||
celery.result.AsyncResult: Promise of future evaluation.
|
||||
@@ -543,6 +545,8 @@ class Task:
|
||||
TypeError: If not enough arguments are passed, or too many
|
||||
arguments are passed. Note that signature checks may
|
||||
be disabled by specifying ``@task(typing=False)``.
|
||||
ValueError: If soft_time_limit and time_limit both are set
|
||||
but soft_time_limit is greater than time_limit
|
||||
kombu.exceptions.OperationalError: If a connection to the
|
||||
transport cannot be made, or if the connection is lost.
|
||||
|
||||
@@ -550,6 +554,9 @@ class Task:
|
||||
Also supports all keyword arguments supported by
|
||||
:meth:`kombu.Producer.publish`.
|
||||
"""
|
||||
if self.soft_time_limit and self.time_limit and self.soft_time_limit > self.time_limit:
|
||||
raise ValueError('soft_time_limit must be less than or equal to time_limit')
|
||||
|
||||
if self.typing:
|
||||
try:
|
||||
check_arguments = self.__header__
|
||||
@@ -788,6 +795,7 @@ class Task:
|
||||
|
||||
request = {
|
||||
'id': task_id,
|
||||
'task': self.name,
|
||||
'retries': retries,
|
||||
'is_eager': True,
|
||||
'logfile': logfile,
|
||||
@@ -824,7 +832,7 @@ class Task:
|
||||
if isinstance(retval, Retry) and retval.sig is not None:
|
||||
return retval.sig.apply(retries=retries + 1)
|
||||
state = states.SUCCESS if ret.info is None else ret.info.state
|
||||
return EagerResult(task_id, retval, state, traceback=tb)
|
||||
return EagerResult(task_id, retval, state, traceback=tb, name=self.name)
|
||||
|
||||
def AsyncResult(self, task_id, **kwargs):
|
||||
"""Get AsyncResult instance for the specified task.
|
||||
@@ -954,11 +962,20 @@ class Task:
|
||||
root_id=self.request.root_id,
|
||||
replaced_task_nesting=replaced_task_nesting
|
||||
)
|
||||
|
||||
# If the replaced task is a chain, we want to set all of the chain tasks
|
||||
# with the same replaced_task_nesting value to mark their replacement nesting level
|
||||
if isinstance(sig, _chain):
|
||||
for chain_task in maybe_list(sig.tasks) or []:
|
||||
chain_task.set(replaced_task_nesting=replaced_task_nesting)
|
||||
|
||||
# If the task being replaced is part of a chain, we need to re-create
|
||||
# it with the replacement signature - these subsequent tasks will
|
||||
# retain their original task IDs as well
|
||||
for t in reversed(self.request.chain or []):
|
||||
sig |= signature(t, app=self.app)
|
||||
chain_task = signature(t, app=self.app)
|
||||
chain_task.set(replaced_task_nesting=replaced_task_nesting)
|
||||
sig |= chain_task
|
||||
return self.on_replace(sig)
|
||||
|
||||
def add_to_chord(self, sig, lazy=False):
|
||||
@@ -1099,7 +1116,7 @@ class Task:
|
||||
return result
|
||||
|
||||
def push_request(self, *args, **kwargs):
|
||||
self.request_stack.push(Context(*args, **kwargs))
|
||||
self.request_stack.push(Context(*args, **{**self.request.__dict__, **kwargs}))
|
||||
|
||||
def pop_request(self):
|
||||
self.request_stack.pop()
|
||||
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union
|
||||
from warnings import warn
|
||||
|
||||
from billiard.einfo import ExceptionInfo, ExceptionWithTraceback
|
||||
@@ -17,8 +16,6 @@ from kombu.serialization import loads as loads_message
|
||||
from kombu.serialization import prepare_accept_content
|
||||
from kombu.utils.encoding import safe_repr, safe_str
|
||||
|
||||
import celery
|
||||
import celery.loaders.app
|
||||
from celery import current_app, group, signals, states
|
||||
from celery._state import _task_stack
|
||||
from celery.app.task import Context
|
||||
@@ -294,20 +291,10 @@ def traceback_clear(exc=None):
|
||||
tb = tb.tb_next
|
||||
|
||||
|
||||
def build_tracer(
|
||||
name: str,
|
||||
task: Union[celery.Task, celery.local.PromiseProxy],
|
||||
loader: Optional[celery.loaders.app.AppLoader] = None,
|
||||
hostname: Optional[str] = None,
|
||||
store_errors: bool = True,
|
||||
Info: Type[TraceInfo] = TraceInfo,
|
||||
eager: bool = False,
|
||||
propagate: bool = False,
|
||||
app: Optional[celery.Celery] = None,
|
||||
monotonic: Callable[[], int] = time.monotonic,
|
||||
trace_ok_t: Type[trace_ok_t] = trace_ok_t,
|
||||
IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \
|
||||
Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]:
|
||||
def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
|
||||
Info=TraceInfo, eager=False, propagate=False, app=None,
|
||||
monotonic=time.monotonic, trace_ok_t=trace_ok_t,
|
||||
IGNORE_STATES=IGNORE_STATES):
|
||||
"""Return a function that traces task execution.
|
||||
|
||||
Catches all exceptions and updates result backend with the
|
||||
@@ -387,12 +374,7 @@ def build_tracer(
|
||||
from celery import canvas
|
||||
signature = canvas.maybe_signature # maybe_ does not clone if already
|
||||
|
||||
def on_error(
|
||||
request: celery.app.task.Context,
|
||||
exc: Union[Exception, Type[Exception]],
|
||||
state: str = FAILURE,
|
||||
call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]:
|
||||
"""Handle any errors raised by a `Task`'s execution."""
|
||||
def on_error(request, exc, state=FAILURE, call_errbacks=True):
|
||||
if propagate:
|
||||
raise
|
||||
I = Info(state, exc)
|
||||
@@ -401,13 +383,7 @@ def build_tracer(
|
||||
)
|
||||
return I, R, I.state, I.retval
|
||||
|
||||
def trace_task(
|
||||
uuid: str,
|
||||
args: Sequence[Any],
|
||||
kwargs: Dict[str, Any],
|
||||
request: Optional[Dict[str, Any]] = None) -> trace_ok_t:
|
||||
"""Execute and trace a `Task`."""
|
||||
|
||||
def trace_task(uuid, args, kwargs, request=None):
|
||||
# R - is the possibly prepared return value.
|
||||
# I - is the Info object.
|
||||
# T - runtime
|
||||
|
||||
@@ -35,7 +35,7 @@ settings -> transport:{transport} results:{results}
|
||||
"""
|
||||
|
||||
HIDDEN_SETTINGS = re.compile(
|
||||
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE',
|
||||
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE|BEAT_DBURI',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from kombu.utils.encoding import safe_str
|
||||
from celery import VERSION_BANNER, platforms, signals
|
||||
from celery.app import trace
|
||||
from celery.loaders.app import AppLoader
|
||||
from celery.platforms import EX_FAILURE, EX_OK, check_privileges
|
||||
from celery.platforms import EX_FAILURE, EX_OK, check_privileges, isatty
|
||||
from celery.utils import static, term
|
||||
from celery.utils.debug import cry
|
||||
from celery.utils.imports import qualname
|
||||
@@ -77,8 +77,9 @@ def active_thread_count():
|
||||
if not t.name.startswith('Dummy-'))
|
||||
|
||||
|
||||
def safe_say(msg):
|
||||
print(f'\n{msg}', file=sys.__stderr__, flush=True)
|
||||
def safe_say(msg, f=sys.__stderr__):
|
||||
if hasattr(f, 'fileno') and f.fileno() is not None:
|
||||
os.write(f.fileno(), f'\n{msg}\n'.encode())
|
||||
|
||||
|
||||
class Worker(WorkController):
|
||||
@@ -106,7 +107,7 @@ class Worker(WorkController):
|
||||
super().setup_defaults(**kwargs)
|
||||
self.purge = purge
|
||||
self.no_color = no_color
|
||||
self._isatty = sys.stdout.isatty()
|
||||
self._isatty = isatty(sys.stdout)
|
||||
self.colored = self.app.log.colored(
|
||||
self.logfile,
|
||||
enabled=not no_color if no_color is not None else no_color
|
||||
@@ -278,15 +279,27 @@ class Worker(WorkController):
|
||||
)
|
||||
|
||||
|
||||
def _shutdown_handler(worker, sig='TERM', how='Warm',
|
||||
callback=None, exitcode=EX_OK):
|
||||
def _shutdown_handler(worker: Worker, sig='SIGTERM', how='Warm', callback=None, exitcode=EX_OK, verbose=True):
|
||||
"""Install signal handler for warm/cold shutdown.
|
||||
|
||||
The handler will run from the MainProcess.
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that received the signal.
|
||||
sig (str, optional): The signal that was received. Defaults to 'TERM'.
|
||||
how (str, optional): The type of shutdown to perform. Defaults to 'Warm'.
|
||||
callback (Callable, optional): Signal handler. Defaults to None.
|
||||
exitcode (int, optional): The exit code to use. Defaults to EX_OK.
|
||||
verbose (bool, optional): Whether to print the type of shutdown. Defaults to True.
|
||||
"""
|
||||
def _handle_request(*args):
|
||||
with in_sighandler():
|
||||
from celery.worker import state
|
||||
if current_process()._name == 'MainProcess':
|
||||
if callback:
|
||||
callback(worker)
|
||||
safe_say(f'worker: {how} shutdown (MainProcess)')
|
||||
if verbose:
|
||||
safe_say(f'worker: {how} shutdown (MainProcess)', sys.__stdout__)
|
||||
signals.worker_shutting_down.send(
|
||||
sender=worker.hostname, sig=sig, how=how,
|
||||
exitcode=exitcode,
|
||||
@@ -297,19 +310,126 @@ def _shutdown_handler(worker, sig='TERM', how='Warm',
|
||||
platforms.signals[sig] = _handle_request
|
||||
|
||||
|
||||
def on_hard_shutdown(worker: Worker):
|
||||
"""Signal handler for hard shutdown.
|
||||
|
||||
The handler will terminate the worker immediately by force using the exit code ``EX_FAILURE``.
|
||||
|
||||
In practice, you should never get here, as the standard shutdown process should be enough.
|
||||
This handler is only for the worst-case scenario, where the worker is stuck and cannot be
|
||||
terminated gracefully (e.g., spamming the Ctrl+C in the terminal to force the worker to terminate).
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that received the signal.
|
||||
|
||||
Raises:
|
||||
WorkerTerminate: This exception will be raised in the MainProcess to terminate the worker immediately.
|
||||
"""
|
||||
from celery.exceptions import WorkerTerminate
|
||||
raise WorkerTerminate(EX_FAILURE)
|
||||
|
||||
|
||||
def during_soft_shutdown(worker: Worker):
|
||||
"""This signal handler is called when the worker is in the middle of the soft shutdown process.
|
||||
|
||||
When the worker is in the soft shutdown process, it is waiting for tasks to finish. If the worker
|
||||
receives a SIGINT (Ctrl+C) or SIGQUIT signal (or possibly SIGTERM if REMAP_SIGTERM is set to "SIGQUIT"),
|
||||
the handler will cancels all unacked requests to allow the worker to terminate gracefully and replace the
|
||||
signal handler for SIGINT and SIGQUIT with the hard shutdown handler ``on_hard_shutdown`` to terminate
|
||||
the worker immediately by force next time the signal is received.
|
||||
|
||||
It will give the worker once last chance to gracefully terminate (the cold shutdown), after canceling all
|
||||
unacked requests, before using the hard shutdown handler to terminate the worker forcefully.
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that received the signal.
|
||||
"""
|
||||
# Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM)
|
||||
# with the hard shutdown handler to terminate the worker immediately by force
|
||||
install_worker_term_hard_handler(worker, sig='SIGINT', callback=on_hard_shutdown, verbose=False)
|
||||
install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=on_hard_shutdown)
|
||||
|
||||
# Cancel all unacked requests and allow the worker to terminate naturally
|
||||
worker.consumer.cancel_all_unacked_requests()
|
||||
|
||||
# We get here if the worker was in the middle of the soft (cold) shutdown process,
|
||||
# and the matching signal was received. This can typically happen when the worker is
|
||||
# waiting for tasks to finish, and the user decides to still cancel the running tasks.
|
||||
# We give the worker the last chance to gracefully terminate by letting the soft shutdown
|
||||
# waiting time to finish, which is running in the MainProcess from the previous signal handler call.
|
||||
safe_say('Waiting gracefully for cold shutdown to complete...', sys.__stdout__)
|
||||
|
||||
|
||||
def on_cold_shutdown(worker: Worker):
|
||||
"""Signal handler for cold shutdown.
|
||||
|
||||
Registered for SIGQUIT and SIGINT (Ctrl+C) signals. If REMAP_SIGTERM is set to "SIGQUIT", this handler will also
|
||||
be registered for SIGTERM.
|
||||
|
||||
This handler will initiate the cold (and soft if enabled) shutdown procesdure for the worker.
|
||||
|
||||
Worker running with N tasks:
|
||||
- SIGTERM:
|
||||
-The worker will initiate the warm shutdown process until all tasks are finished. Additional.
|
||||
SIGTERM signals will be ignored. SIGQUIT will transition to the cold shutdown process described below.
|
||||
- SIGQUIT:
|
||||
- The worker will initiate the cold shutdown process.
|
||||
- If the soft shutdown is enabled, the worker will wait for the tasks to finish up to the soft
|
||||
shutdown timeout (practically having a limited warm shutdown just before the cold shutdown).
|
||||
- Cancel all tasks (from the MainProcess) and allow the worker to complete the cold shutdown
|
||||
process gracefully.
|
||||
|
||||
Caveats:
|
||||
- SIGINT (Ctrl+C) signal is defined to replace itself with the cold shutdown (SIGQUIT) after first use,
|
||||
and to emit a message to the user to hit Ctrl+C again to initiate the cold shutdown process. But, most
|
||||
important, it will also be caught in WorkController.start() to initiate the warm shutdown process.
|
||||
- SIGTERM will also be handled in WorkController.start() to initiate the warm shutdown process (the same).
|
||||
- If REMAP_SIGTERM is set to "SIGQUIT", the SIGTERM signal will be remapped to SIGQUIT, and the cold
|
||||
shutdown process will be initiated instead of the warm shutdown process using SIGTERM.
|
||||
- If SIGQUIT is received (also via SIGINT) during the cold/soft shutdown process, the handler will cancel all
|
||||
unacked requests but still wait for the soft shutdown process to finish before terminating the worker
|
||||
gracefully. The next time the signal is received though, the worker will terminate immediately by force.
|
||||
|
||||
So, the purpose of this handler is to allow waiting for the soft shutdown timeout, then cancel all tasks from
|
||||
the MainProcess and let the WorkController.terminate() to terminate the worker naturally. If the soft shutdown
|
||||
is disabled, it will immediately cancel all tasks let the cold shutdown finish normally.
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that received the signal.
|
||||
"""
|
||||
safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!', sys.__stdout__)
|
||||
|
||||
# Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM)
|
||||
install_worker_term_hard_handler(worker, sig='SIGINT', callback=during_soft_shutdown)
|
||||
install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=during_soft_shutdown)
|
||||
if REMAP_SIGTERM == "SIGQUIT":
|
||||
install_worker_term_hard_handler(worker, sig='SIGTERM', callback=during_soft_shutdown)
|
||||
# else, SIGTERM will print the _shutdown_handler's message and do nothing, every time it is received..
|
||||
|
||||
# Initiate soft shutdown process (if enabled and tasks are running)
|
||||
worker.wait_for_soft_shutdown()
|
||||
|
||||
# Cancel all unacked requests and allow the worker to terminate naturally
|
||||
worker.consumer.cancel_all_unacked_requests()
|
||||
|
||||
# Stop the pool to allow successful tasks call on_success()
|
||||
worker.consumer.pool.stop()
|
||||
|
||||
|
||||
# Allow SIGTERM to be remapped to SIGQUIT to initiate cold shutdown instead of warm shutdown using SIGTERM
|
||||
if REMAP_SIGTERM == "SIGQUIT":
|
||||
install_worker_term_handler = partial(
|
||||
_shutdown_handler, sig='SIGTERM', how='Cold', exitcode=EX_FAILURE,
|
||||
_shutdown_handler, sig='SIGTERM', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE,
|
||||
)
|
||||
else:
|
||||
install_worker_term_handler = partial(
|
||||
_shutdown_handler, sig='SIGTERM', how='Warm',
|
||||
)
|
||||
|
||||
|
||||
if not is_jython: # pragma: no cover
|
||||
install_worker_term_hard_handler = partial(
|
||||
_shutdown_handler, sig='SIGQUIT', how='Cold',
|
||||
exitcode=EX_FAILURE,
|
||||
_shutdown_handler, sig='SIGQUIT', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE,
|
||||
)
|
||||
else: # pragma: no cover
|
||||
install_worker_term_handler = \
|
||||
@@ -317,8 +437,9 @@ else: # pragma: no cover
|
||||
|
||||
|
||||
def on_SIGINT(worker):
|
||||
safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!')
|
||||
install_worker_term_hard_handler(worker, sig='SIGINT')
|
||||
safe_say('worker: Hitting Ctrl+C again will initiate cold shutdown, terminating all running tasks!',
|
||||
sys.__stdout__)
|
||||
install_worker_term_hard_handler(worker, sig='SIGINT', verbose=False)
|
||||
|
||||
|
||||
if not is_jython: # pragma: no cover
|
||||
@@ -343,7 +464,8 @@ def install_worker_restart_handler(worker, sig='SIGHUP'):
|
||||
def restart_worker_sig_handler(*args):
|
||||
"""Signal handler restarting the current python program."""
|
||||
set_in_sighandler(True)
|
||||
safe_say(f"Restarting celery worker ({' '.join(sys.argv)})")
|
||||
safe_say(f"Restarting celery worker ({' '.join(sys.argv)})",
|
||||
sys.__stdout__)
|
||||
import atexit
|
||||
atexit.register(_reload_current_worker)
|
||||
from celery.worker import state
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""The Azure Storage Block Blob backend for Celery."""
|
||||
from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport
|
||||
from kombu.utils import cached_property
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
|
||||
@@ -28,6 +29,13 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
|
||||
container_name=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""
|
||||
Supported URL formats:
|
||||
|
||||
azureblockblob://CONNECTION_STRING
|
||||
azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL
|
||||
azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if azurestorage is None or azurestorage.__version__ < '12':
|
||||
@@ -65,11 +73,26 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
|
||||
the container is created if it doesn't yet exist.
|
||||
|
||||
"""
|
||||
client = BlobServiceClient.from_connection_string(
|
||||
self._connection_string,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout
|
||||
)
|
||||
if (
|
||||
"DefaultAzureCredential" in self._connection_string or
|
||||
"ManagedIdentityCredential" in self._connection_string
|
||||
):
|
||||
# Leveraging the work that Kombu already did for us
|
||||
credential_, url = AzureStorageQueuesTransport.parse_uri(
|
||||
self._connection_string
|
||||
)
|
||||
client = BlobServiceClient(
|
||||
account_url=url,
|
||||
credential=credential_,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout,
|
||||
)
|
||||
else:
|
||||
client = BlobServiceClient.from_connection_string(
|
||||
self._connection_string,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
client.create_container(name=self._container_name)
|
||||
|
||||
@@ -9,7 +9,7 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
@@ -460,7 +460,7 @@ class Backend:
|
||||
state, traceback, request, format_date=True,
|
||||
encode=False):
|
||||
if state in self.READY_STATES:
|
||||
date_done = datetime.utcnow()
|
||||
date_done = self.app.now()
|
||||
if format_date:
|
||||
date_done = date_done.isoformat()
|
||||
else:
|
||||
@@ -833,9 +833,11 @@ class BaseKeyValueStoreBackend(Backend):
|
||||
"""
|
||||
global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None)
|
||||
if global_keyprefix:
|
||||
self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}"
|
||||
self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}"
|
||||
self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}"
|
||||
if global_keyprefix[-1] not in ':_-.':
|
||||
global_keyprefix += '_'
|
||||
self.task_keyprefix = f"{global_keyprefix}{self.task_keyprefix}"
|
||||
self.group_keyprefix = f"{global_keyprefix}{self.group_keyprefix}"
|
||||
self.chord_keyprefix = f"{global_keyprefix}{self.chord_keyprefix}"
|
||||
|
||||
def _encode_prefixes(self):
|
||||
self.task_keyprefix = self.key_t(self.task_keyprefix)
|
||||
@@ -1080,7 +1082,7 @@ class BaseKeyValueStoreBackend(Backend):
|
||||
)
|
||||
finally:
|
||||
deps.delete()
|
||||
self.client.delete(key)
|
||||
self.delete(key)
|
||||
else:
|
||||
self.expire(key, self.expires)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class CassandraBackend(BaseBackend):
|
||||
supports_autoexpire = True # autoexpire supported via entry_ttl
|
||||
|
||||
def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
|
||||
port=9042, bundle_path=None, **kwargs):
|
||||
port=None, bundle_path=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not cassandra:
|
||||
@@ -96,7 +96,7 @@ class CassandraBackend(BaseBackend):
|
||||
self.servers = servers or conf.get('cassandra_servers', None)
|
||||
self.bundle_path = bundle_path or conf.get(
|
||||
'cassandra_secure_bundle_path', None)
|
||||
self.port = port or conf.get('cassandra_port', None)
|
||||
self.port = port or conf.get('cassandra_port', None) or 9042
|
||||
self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
|
||||
self.table = table or conf.get('cassandra_table', None)
|
||||
self.cassandra_options = conf.get('cassandra_options', {})
|
||||
|
||||
@@ -98,11 +98,23 @@ class DatabaseBackend(BaseBackend):
|
||||
'Missing connection string! Do you have the'
|
||||
' database_url setting set to a real value?')
|
||||
|
||||
self.session_manager = SessionManager()
|
||||
|
||||
create_tables_at_setup = conf.database_create_tables_at_setup
|
||||
if create_tables_at_setup is True:
|
||||
self._create_tables()
|
||||
|
||||
@property
|
||||
def extended_result(self):
|
||||
return self.app.conf.find_value_for_key('extended', 'result')
|
||||
|
||||
def ResultSession(self, session_manager=SessionManager()):
|
||||
def _create_tables(self):
|
||||
"""Create the task and taskset tables."""
|
||||
self.ResultSession()
|
||||
|
||||
def ResultSession(self, session_manager=None):
|
||||
if session_manager is None:
|
||||
session_manager = self.session_manager
|
||||
return session_manager.session_factory(
|
||||
dburi=self.url,
|
||||
short_lived_sessions=self.short_lived_sessions,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Database models used by the SQLAlchemy result store backend."""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.types import PickleType
|
||||
@@ -22,8 +22,8 @@ class Task(ResultModelBase):
|
||||
task_id = sa.Column(sa.String(155), unique=True)
|
||||
status = sa.Column(sa.String(50), default=states.PENDING)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
onupdate=datetime.now(timezone.utc), nullable=True)
|
||||
traceback = sa.Column(sa.Text, nullable=True)
|
||||
|
||||
def __init__(self, task_id):
|
||||
@@ -84,7 +84,7 @@ class TaskSet(ResultModelBase):
|
||||
autoincrement=True, primary_key=True)
|
||||
taskset_id = sa.Column(sa.String(155), unique=True)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
nullable=True)
|
||||
|
||||
def __init__(self, taskset_id, result):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""AWS DynamoDB result store backend."""
|
||||
from collections import namedtuple
|
||||
from ipaddress import ip_address
|
||||
from time import sleep, time
|
||||
from typing import Any, Dict
|
||||
|
||||
from kombu.utils.url import _parse_url as parse_url
|
||||
|
||||
@@ -54,11 +56,15 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
supports_autoexpire = True
|
||||
|
||||
_key_field = DynamoDBAttribute(name='id', data_type='S')
|
||||
# Each record has either a value field or count field
|
||||
_value_field = DynamoDBAttribute(name='result', data_type='B')
|
||||
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
|
||||
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
|
||||
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
|
||||
_available_fields = None
|
||||
|
||||
implements_incr = True
|
||||
|
||||
def __init__(self, url=None, table_name=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -91,9 +97,9 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
|
||||
aws_credentials_given = access_key_given
|
||||
|
||||
if region == 'localhost':
|
||||
if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
|
||||
# We are using the downloadable, local version of DynamoDB
|
||||
self.endpoint_url = f'http://localhost:{port}'
|
||||
self.endpoint_url = f'http://{region}:{port}'
|
||||
self.aws_region = 'us-east-1'
|
||||
logger.warning(
|
||||
'Using local-only DynamoDB endpoint URL: {}'.format(
|
||||
@@ -148,6 +154,14 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
secret_access_key=aws_secret_access_key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_ip(ip):
|
||||
try:
|
||||
ip_address(ip)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _get_client(self, access_key_id=None, secret_access_key=None):
|
||||
"""Get client connection."""
|
||||
if self._client is None:
|
||||
@@ -459,6 +473,40 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
})
|
||||
return put_request
|
||||
|
||||
def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter initialization request parameters"""
|
||||
timestamp = time()
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Item': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
},
|
||||
self._count_filed.name: {
|
||||
self._count_filed.data_type: "0"
|
||||
},
|
||||
self._timestamp_field.name: {
|
||||
self._timestamp_field.data_type: str(timestamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter increment request parameters"""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Key': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
}
|
||||
},
|
||||
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
|
||||
"ExpressionAttributeValues": {
|
||||
":num": {"N": "1"},
|
||||
},
|
||||
"ReturnValues": "UPDATED_NEW",
|
||||
}
|
||||
|
||||
def _item_to_dict(self, raw_response):
|
||||
"""Convert get_item() response to field-value pairs."""
|
||||
if 'Item' not in raw_response:
|
||||
@@ -491,3 +539,18 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
self.client.delete_item(**request_parameters)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
"""Atomically increase the chord_count and return the new count"""
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_inc_count_request(key)
|
||||
item_response = self.client.update_item(**request_parameters)
|
||||
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
|
||||
return int(new_count)
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
chord_key = self.get_key_for_chord(header_result_args[0])
|
||||
init_count_request = self._prepare_init_count_request(str(chord_key))
|
||||
self.client.put_item(**init_count_request)
|
||||
return super()._apply_chord_incr(
|
||||
header_result_args, body, **kwargs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Elasticsearch result store backend."""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.url import _parse_url
|
||||
@@ -14,6 +14,11 @@ try:
|
||||
except ImportError:
|
||||
elasticsearch = None
|
||||
|
||||
try:
|
||||
import elastic_transport
|
||||
except ImportError:
|
||||
elastic_transport = None
|
||||
|
||||
__all__ = ('ElasticsearchBackend',)
|
||||
|
||||
E_LIB_MISSING = """\
|
||||
@@ -31,7 +36,7 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
"""
|
||||
|
||||
index = 'celery'
|
||||
doc_type = 'backend'
|
||||
doc_type = None
|
||||
scheme = 'http'
|
||||
host = 'localhost'
|
||||
port = 9200
|
||||
@@ -83,17 +88,17 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
self._server = None
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
|
||||
if isinstance(exc, elasticsearch.exceptions.ApiError):
|
||||
# 401: Unauthorized
|
||||
# 409: Conflict
|
||||
# 429: Too Many Requests
|
||||
# 500: Internal Server Error
|
||||
# 502: Bad Gateway
|
||||
# 503: Service Unavailable
|
||||
# 504: Gateway Timeout
|
||||
# N/A: Low level exception (i.e. socket exception)
|
||||
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
|
||||
if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}:
|
||||
return True
|
||||
if isinstance(exc, elasticsearch.exceptions.TransportError):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, key):
|
||||
@@ -108,17 +113,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
pass
|
||||
|
||||
def _get(self, key):
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
id=key,
|
||||
)
|
||||
if self.doc_type:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
doc_type=self.doc_type,
|
||||
)
|
||||
else:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
)
|
||||
|
||||
def _set_with_state(self, key, value, state):
|
||||
body = {
|
||||
'result': value,
|
||||
'@timestamp': '{}Z'.format(
|
||||
datetime.utcnow().isoformat()[:-3]
|
||||
datetime.now(timezone.utc).isoformat()[:-9]
|
||||
),
|
||||
}
|
||||
try:
|
||||
@@ -135,14 +146,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
|
||||
def _index(self, id, body, **kwargs):
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
if self.doc_type:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _update(self, id, body, state, **kwargs):
|
||||
"""Update state in a conflict free manner.
|
||||
@@ -182,19 +202,32 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
prim_term = res_get.get('_primary_term', 1)
|
||||
|
||||
# try to update document with current seq_no and primary_term
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
if self.doc_type:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
# result is elastic search update query result
|
||||
# noop = query did not update any document
|
||||
# updated = at least one document got updated
|
||||
if res['result'] == 'noop':
|
||||
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
|
||||
raise elasticsearch.exceptions.ConflictError(
|
||||
"conflicting update occurred concurrently",
|
||||
elastic_transport.ApiResponseMeta(409, "HTTP/1.1",
|
||||
elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig(
|
||||
self.scheme, self.host, self.port)), None)
|
||||
return res
|
||||
|
||||
def encode(self, data):
|
||||
@@ -225,7 +258,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
|
||||
if self.doc_type:
|
||||
self.server.delete(index=self.index, id=key, doc_type=self.doc_type)
|
||||
else:
|
||||
self.server.delete(index=self.index, id=key)
|
||||
|
||||
def _get_server(self):
|
||||
"""Connect to the Elasticsearch server."""
|
||||
@@ -233,11 +269,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
if self.username and self.password:
|
||||
http_auth = (self.username, self.password)
|
||||
return elasticsearch.Elasticsearch(
|
||||
f'{self.host}:{self.port}',
|
||||
f'{self.scheme}://{self.host}:{self.port}',
|
||||
retry_on_timeout=self.es_retry_on_timeout,
|
||||
max_retries=self.es_max_retries,
|
||||
timeout=self.es_timeout,
|
||||
scheme=self.scheme,
|
||||
http_auth=http_auth,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class FilesystemBackend(KeyValueStoreBackend):
|
||||
self.open = open
|
||||
self.unlink = unlink
|
||||
|
||||
# Lets verify that we've everything setup right
|
||||
# Let's verify that we've everything setup right
|
||||
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
|
||||
352
venv/lib/python3.12/site-packages/celery/backends/gcs.py
Normal file
352
venv/lib/python3.12/site-packages/celery/backends/gcs.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Google Cloud Storage result store backend for Celery."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from os import getpid
|
||||
from threading import RLock
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.functional import dictfilter
|
||||
from kombu.utils.url import url_to_parts
|
||||
|
||||
from celery.canvas import maybe_signature
|
||||
from celery.exceptions import ChordError, ImproperlyConfigured
|
||||
from celery.result import GroupResult, allow_join_result
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import requests
|
||||
from google.api_core import retry
|
||||
from google.api_core.exceptions import Conflict
|
||||
from google.api_core.retry import if_exception_type
|
||||
from google.cloud import storage
|
||||
from google.cloud.storage import Client
|
||||
from google.cloud.storage.retry import DEFAULT_RETRY
|
||||
except ImportError:
|
||||
storage = None
|
||||
|
||||
try:
|
||||
from google.cloud import firestore, firestore_admin_v1
|
||||
except ImportError:
|
||||
firestore = None
|
||||
firestore_admin_v1 = None
|
||||
|
||||
|
||||
__all__ = ('GCSBackend',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GCSBackendBase(KeyValueStoreBackend):
|
||||
"""Google Cloud Storage task result backend."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not storage:
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-storage to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self._client_lock = RLock()
|
||||
self._pid = getpid()
|
||||
self._retry_policy = DEFAULT_RETRY
|
||||
self._client = None
|
||||
|
||||
conf = self.app.conf
|
||||
if self.url:
|
||||
url_params = self._params_from_url()
|
||||
conf.update(**dictfilter(url_params))
|
||||
|
||||
self.bucket_name = conf.get('gcs_bucket')
|
||||
if not self.bucket_name:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing bucket name: specify gcs_bucket to use gcs backend'
|
||||
)
|
||||
self.project = conf.get('gcs_project')
|
||||
if not self.project:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing project:specify gcs_project to use gcs backend'
|
||||
)
|
||||
self.base_path = conf.get('gcs_base_path', '').strip('/')
|
||||
self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
|
||||
self.ttl = float(conf.get('gcs_ttl') or 0)
|
||||
if self.ttl < 0:
|
||||
raise ImproperlyConfigured(
|
||||
f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
|
||||
)
|
||||
elif self.ttl:
|
||||
if not self._is_bucket_lifecycle_rule_exists():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing lifecycle rule to use gcs backend with ttl on '
|
||||
f'bucket: {self.bucket_name}'
|
||||
)
|
||||
|
||||
def get(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
try:
|
||||
return blob.download_as_bytes(retry=self._retry_policy)
|
||||
except storage.blob.NotFound:
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if self.ttl:
|
||||
blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl)
|
||||
blob.upload_from_string(value, retry=self._retry_policy)
|
||||
|
||||
def delete(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if blob.exists():
|
||||
blob.delete(retry=self._retry_policy)
|
||||
|
||||
def mget(self, keys):
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return list(pool.map(self.get, keys))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Returns a storage client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._client_lock:
|
||||
if self._client and self._pid == getpid():
|
||||
return self._client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._client = Client(project=self.project)
|
||||
self._pid = getpid()
|
||||
|
||||
# config the number of connections to the server
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=self._threadpool_maxsize,
|
||||
pool_maxsize=self._threadpool_maxsize,
|
||||
max_retries=3,
|
||||
)
|
||||
client_http = self._client._http
|
||||
client_http.mount("https://", adapter)
|
||||
client_http._auth_request.session.mount("https://", adapter)
|
||||
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return self.client.bucket(self.bucket_name)
|
||||
|
||||
def _get_blob(self, key):
|
||||
key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
|
||||
return self.bucket.blob(key_bucket_path)
|
||||
|
||||
def _is_bucket_lifecycle_rule_exists(self):
|
||||
bucket = self.bucket
|
||||
bucket.reload()
|
||||
for rule in bucket.lifecycle_rules:
|
||||
if rule['action']['type'] == 'Delete':
|
||||
return True
|
||||
return False
|
||||
|
||||
def _params_from_url(self):
|
||||
url_parts = url_to_parts(self.url)
|
||||
|
||||
return {
|
||||
'gcs_bucket': url_parts.hostname,
|
||||
'gcs_base_path': url_parts.path,
|
||||
**url_parts.query,
|
||||
}
|
||||
|
||||
|
||||
class GCSBackend(GCSBackendBase):
|
||||
"""Google Cloud Storage task result backend.
|
||||
|
||||
Uses Firestore for chord ref count.
|
||||
"""
|
||||
|
||||
implements_incr = True
|
||||
supports_native_join = True
|
||||
|
||||
# Firestore parameters
|
||||
_collection_name = 'celery'
|
||||
_field_count = 'chord_count'
|
||||
_field_expires = 'expires_at'
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not (firestore and firestore_admin_v1):
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-firestore to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._firestore_lock = RLock()
|
||||
self._firestore_client = None
|
||||
|
||||
self.firestore_project = self.app.conf.get(
|
||||
'firestore_project', self.project
|
||||
)
|
||||
if not self._is_firestore_ttl_policy_enabled():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing TTL policy to use gcs backend with ttl on '
|
||||
f'Firestore collection: {self._collection_name} '
|
||||
f'project: {self.firestore_project}'
|
||||
)
|
||||
|
||||
@property
|
||||
def firestore_client(self):
|
||||
"""Returns a firestore client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._firestore_lock:
|
||||
if self._firestore_client and self._pid == getpid():
|
||||
return self._firestore_client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._firestore_client = firestore.Client(
|
||||
project=self.firestore_project
|
||||
)
|
||||
self._pid = getpid()
|
||||
return self._firestore_client
|
||||
|
||||
def _is_firestore_ttl_policy_enabled(self):
|
||||
client = firestore_admin_v1.FirestoreAdminClient()
|
||||
|
||||
name = (
|
||||
f"projects/{self.firestore_project}"
|
||||
f"/databases/(default)/collectionGroups/{self._collection_name}"
|
||||
f"/fields/{self._field_expires}"
|
||||
)
|
||||
request = firestore_admin_v1.GetFieldRequest(name=name)
|
||||
field = client.get_field(request=request)
|
||||
|
||||
ttl_config = field.ttl_config
|
||||
return ttl_config and ttl_config.state in {
|
||||
firestore_admin_v1.Field.TtlConfig.State.ACTIVE,
|
||||
firestore_admin_v1.Field.TtlConfig.State.CREATING,
|
||||
}
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
key = self.get_key_for_chord(header_result_args[0]).decode()
|
||||
self._expire_chord_key(key, 86400)
|
||||
return super()._apply_chord_incr(header_result_args, body, **kwargs)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
doc = self._firestore_document(key)
|
||||
resp = doc.set(
|
||||
{self._field_count: firestore.Increment(1)},
|
||||
merge=True,
|
||||
retry=retry.Retry(
|
||||
predicate=if_exception_type(Conflict),
|
||||
initial=1.0,
|
||||
maximum=180.0,
|
||||
multiplier=2.0,
|
||||
timeout=180.0,
|
||||
),
|
||||
)
|
||||
return resp.transform_results[0].integer_value
|
||||
|
||||
def on_chord_part_return(self, request, state, result, **kwargs):
|
||||
"""Chord part return callback.
|
||||
|
||||
Called for each task in the chord.
|
||||
Increments the counter stored in Firestore.
|
||||
If the counter reaches the number of tasks in the chord, the callback
|
||||
is called.
|
||||
If the callback raises an exception, the chord is marked as errored.
|
||||
If the callback returns a value, the chord is marked as successful.
|
||||
"""
|
||||
app = self.app
|
||||
gid = request.group
|
||||
if not gid:
|
||||
return
|
||||
key = self.get_key_for_chord(gid)
|
||||
val = self.incr(key)
|
||||
size = request.chord.get("chord_size")
|
||||
if size is None:
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
size = len(deps)
|
||||
if val > size: # pragma: no cover
|
||||
logger.warning(
|
||||
'Chord counter incremented too many times for %r', gid
|
||||
)
|
||||
elif val == size:
|
||||
# Read the deps once, to reduce the number of reads from GCS ($$)
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
j = deps.join_native
|
||||
try:
|
||||
with allow_join_result():
|
||||
ret = j(
|
||||
timeout=app.conf.result_chord_join_timeout,
|
||||
propagate=True,
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
try:
|
||||
culprit = next(deps._failed_join_report())
|
||||
reason = 'Dependency {0.id} raised {1!r}'.format(
|
||||
culprit,
|
||||
exc,
|
||||
)
|
||||
except StopIteration:
|
||||
reason = repr(exc)
|
||||
|
||||
logger.exception('Chord %r raised: %r', gid, reason)
|
||||
self.chord_error_from_stack(callback, ChordError(reason))
|
||||
else:
|
||||
try:
|
||||
callback.delay(ret)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Callback error: {exc!r}'),
|
||||
)
|
||||
finally:
|
||||
deps.delete()
|
||||
# Firestore doesn't have an exact ttl policy, so delete the key.
|
||||
self._delete_chord_key(key)
|
||||
|
||||
def _restore_deps(self, gid, request):
|
||||
app = self.app
|
||||
try:
|
||||
deps = GroupResult.restore(gid, backend=self)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Cannot restore group: {exc!r}'),
|
||||
)
|
||||
return
|
||||
if deps is None:
|
||||
try:
|
||||
raise ValueError(gid)
|
||||
except ValueError as exc:
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord callback %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'GroupResult {gid} no longer exists'),
|
||||
)
|
||||
return deps
|
||||
|
||||
def _delete_chord_key(self, key):
|
||||
doc = self._firestore_document(key)
|
||||
doc.delete()
|
||||
|
||||
def _expire_chord_key(self, key, expires):
|
||||
"""Set TTL policy for a Firestore document.
|
||||
|
||||
Firestore ttl data is typically deleted within 24 hours after its
|
||||
expiration date.
|
||||
"""
|
||||
val_expires = datetime.utcnow() + timedelta(seconds=expires)
|
||||
doc = self._firestore_document(key)
|
||||
doc.set({self._field_expires: val_expires}, merge=True)
|
||||
|
||||
def _firestore_document(self, key):
|
||||
return self.firestore_client.collection(
|
||||
self._collection_name
|
||||
).document(bytes_to_str(key))
|
||||
@@ -1,5 +1,5 @@
|
||||
"""MongoDB result store backend."""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from kombu.exceptions import EncodeError
|
||||
from kombu.utils.objects import cached_property
|
||||
@@ -228,7 +228,7 @@ class MongoBackend(BaseBackend):
|
||||
meta = {
|
||||
'_id': group_id,
|
||||
'result': self.encode([i.id for i in result]),
|
||||
'date_done': datetime.utcnow(),
|
||||
'date_done': datetime.now(timezone.utc),
|
||||
}
|
||||
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
|
||||
return result
|
||||
|
||||
@@ -359,6 +359,11 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
|
||||
connparams.update(query)
|
||||
return connparams
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, self.connection_errors):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def retry_policy(self):
|
||||
retry_policy = super().retry_policy
|
||||
|
||||
@@ -222,7 +222,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
|
||||
|
||||
def on_out_of_band_result(self, task_id, message):
|
||||
# Callback called when a reply for a task is received,
|
||||
# but we have no idea what do do with it.
|
||||
# but we have no idea what to do with it.
|
||||
# Since the result is not pending, we put it in a separate
|
||||
# buffer: probably it will become pending later.
|
||||
if self.result_consumer:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""The periodic task scheduler."""
|
||||
|
||||
import copy
|
||||
import dbm
|
||||
import errno
|
||||
import heapq
|
||||
import os
|
||||
@@ -568,11 +569,11 @@ class PersistentScheduler(Scheduler):
|
||||
for _ in (1, 2):
|
||||
try:
|
||||
self._store['entries']
|
||||
except KeyError:
|
||||
except (KeyError, UnicodeDecodeError, TypeError):
|
||||
# new schedule db
|
||||
try:
|
||||
self._store['entries'] = {}
|
||||
except KeyError as exc:
|
||||
except (KeyError, UnicodeDecodeError, TypeError) + dbm.error as exc:
|
||||
self._store = self._destroy_open_corrupted_schedule(exc)
|
||||
continue
|
||||
else:
|
||||
|
||||
@@ -4,9 +4,10 @@ import numbers
|
||||
from collections import OrderedDict
|
||||
from functools import update_wrapper
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from click import ParamType
|
||||
from click import Context, ParamType
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from celery._state import get_current_app
|
||||
@@ -170,19 +171,37 @@ class CeleryCommand(click.Command):
|
||||
formatter.write_dl(opts_group)
|
||||
|
||||
|
||||
class DaemonOption(CeleryOption):
|
||||
"""Common daemonization option"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(args,
|
||||
help_group=kwargs.pop("help_group", "Daemonization Options"),
|
||||
callback=kwargs.pop("callback", self.daemon_setting),
|
||||
**kwargs)
|
||||
|
||||
def daemon_setting(self, ctx: Context, opt: CeleryOption, value: Any) -> Any:
|
||||
"""
|
||||
Try to fetch daemonization option from applications settings.
|
||||
Use the daemon command name as prefix (eg. `worker` -> `worker_pidfile`)
|
||||
"""
|
||||
return value or getattr(ctx.obj.app.conf, f"{ctx.command.name}_{self.name}", None)
|
||||
|
||||
|
||||
class CeleryDaemonCommand(CeleryCommand):
|
||||
"""Daemon commands."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize a Celery command with common daemon options."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.params.append(CeleryOption(('-f', '--logfile'), help_group="Daemonization Options",
|
||||
help="Log destination; defaults to stderr"))
|
||||
self.params.append(CeleryOption(('--pidfile',), help_group="Daemonization Options"))
|
||||
self.params.append(CeleryOption(('--uid',), help_group="Daemonization Options"))
|
||||
self.params.append(CeleryOption(('--gid',), help_group="Daemonization Options"))
|
||||
self.params.append(CeleryOption(('--umask',), help_group="Daemonization Options"))
|
||||
self.params.append(CeleryOption(('--executable',), help_group="Daemonization Options"))
|
||||
self.params.extend((
|
||||
DaemonOption("--logfile", "-f", help="Log destination; defaults to stderr"),
|
||||
DaemonOption("--pidfile", help="PID file path; defaults to no PID file"),
|
||||
DaemonOption("--uid", help="Drops privileges to this user ID"),
|
||||
DaemonOption("--gid", help="Drops privileges to this group ID"),
|
||||
DaemonOption("--umask", help="Create files and directories with this umask"),
|
||||
DaemonOption("--executable", help="Override path to the Python executable"),
|
||||
))
|
||||
|
||||
|
||||
class CommaSeparatedList(ParamType):
|
||||
|
||||
@@ -11,7 +11,6 @@ except ImportError:
|
||||
|
||||
import click
|
||||
import click.exceptions
|
||||
from click.types import ParamType
|
||||
from click_didyoumean import DYMGroup
|
||||
from click_plugins import with_plugins
|
||||
|
||||
@@ -48,34 +47,6 @@ Unable to load celery application.
|
||||
{0}""")
|
||||
|
||||
|
||||
class App(ParamType):
|
||||
"""Application option."""
|
||||
|
||||
name = "application"
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
try:
|
||||
return find_app(value)
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name != value:
|
||||
exc = traceback.format_exc()
|
||||
self.fail(
|
||||
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
|
||||
)
|
||||
self.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name))
|
||||
except AttributeError as e:
|
||||
attribute_name = e.args[0].capitalize()
|
||||
self.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name))
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
self.fail(
|
||||
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
|
||||
)
|
||||
|
||||
|
||||
APP = App()
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
_PLUGINS = entry_points(group='celery.commands')
|
||||
else:
|
||||
@@ -91,7 +62,11 @@ else:
|
||||
'--app',
|
||||
envvar='APP',
|
||||
cls=CeleryOption,
|
||||
type=APP,
|
||||
# May take either: a str when invoked from command line (Click),
|
||||
# or a Celery object when invoked from inside Celery; hence the
|
||||
# need to prevent Click from "processing" the Celery object and
|
||||
# converting it into its str representation.
|
||||
type=click.UNPROCESSED,
|
||||
help_group="Global Options")
|
||||
@click.option('-b',
|
||||
'--broker',
|
||||
@@ -160,6 +135,26 @@ def celery(ctx, app, broker, result_backend, loader, config, workdir,
|
||||
os.environ['CELERY_CONFIG_MODULE'] = config
|
||||
if skip_checks:
|
||||
os.environ['CELERY_SKIP_CHECKS'] = 'true'
|
||||
|
||||
if isinstance(app, str):
|
||||
try:
|
||||
app = find_app(app)
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name != app:
|
||||
exc = traceback.format_exc()
|
||||
ctx.fail(
|
||||
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc)
|
||||
)
|
||||
ctx.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name))
|
||||
except AttributeError as e:
|
||||
attribute_name = e.args[0].capitalize()
|
||||
ctx.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name))
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
ctx.fail(
|
||||
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc)
|
||||
)
|
||||
|
||||
ctx.obj = CLIContext(app=app, no_color=no_color, workdir=workdir,
|
||||
quiet=quiet)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""The ``celery control``, ``. inspect`` and ``. status`` programs."""
|
||||
from functools import partial
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
from kombu.utils.json import dumps
|
||||
@@ -39,18 +40,69 @@ def _consume_arguments(meta, method, args):
|
||||
args[:] = args[i:]
|
||||
|
||||
|
||||
def _compile_arguments(action, args):
|
||||
meta = Panel.meta[action]
|
||||
def _compile_arguments(command, args):
|
||||
meta = Panel.meta[command]
|
||||
arguments = {}
|
||||
if meta.args:
|
||||
arguments.update({
|
||||
k: v for k, v in _consume_arguments(meta, action, args)
|
||||
k: v for k, v in _consume_arguments(meta, command, args)
|
||||
})
|
||||
if meta.variadic:
|
||||
arguments.update({meta.variadic: args})
|
||||
return arguments
|
||||
|
||||
|
||||
_RemoteControlType = Literal['inspect', 'control']
|
||||
|
||||
|
||||
def _verify_command_name(type_: _RemoteControlType, command: str) -> None:
|
||||
choices = _get_commands_of_type(type_)
|
||||
|
||||
if command not in choices:
|
||||
command_listing = ", ".join(choices)
|
||||
raise click.UsageError(
|
||||
message=f'Command {command} not recognized. Available {type_} commands: {command_listing}',
|
||||
)
|
||||
|
||||
|
||||
def _list_option(type_: _RemoteControlType):
|
||||
def callback(ctx: click.Context, param, value) -> None:
|
||||
if not value:
|
||||
return
|
||||
choices = _get_commands_of_type(type_)
|
||||
|
||||
formatter = click.HelpFormatter()
|
||||
|
||||
with formatter.section(f'{type_.capitalize()} Commands'):
|
||||
command_list = []
|
||||
for command_name, info in choices.items():
|
||||
if info.signature:
|
||||
command_preview = f'{command_name} {info.signature}'
|
||||
else:
|
||||
command_preview = command_name
|
||||
command_list.append((command_preview, info.help))
|
||||
formatter.write_dl(command_list)
|
||||
ctx.obj.echo(formatter.getvalue(), nl=False)
|
||||
ctx.exit()
|
||||
|
||||
return click.option(
|
||||
'--list',
|
||||
is_flag=True,
|
||||
help=f'List available {type_} commands and exit.',
|
||||
expose_value=False,
|
||||
is_eager=True,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
def _get_commands_of_type(type_: _RemoteControlType) -> dict:
|
||||
command_name_info_pairs = [
|
||||
(name, info) for name, info in Panel.meta.items()
|
||||
if info.type == type_ and info.visible
|
||||
]
|
||||
return dict(sorted(command_name_info_pairs))
|
||||
|
||||
|
||||
@click.command(cls=CeleryCommand)
|
||||
@click.option('-t',
|
||||
'--timeout',
|
||||
@@ -96,10 +148,8 @@ def status(ctx, timeout, destination, json, **kwargs):
|
||||
|
||||
@click.command(cls=CeleryCommand,
|
||||
context_settings={'allow_extra_args': True})
|
||||
@click.argument("action", type=click.Choice([
|
||||
name for name, info in Panel.meta.items()
|
||||
if info.type == 'inspect' and info.visible
|
||||
]))
|
||||
@click.argument('command')
|
||||
@_list_option('inspect')
|
||||
@click.option('-t',
|
||||
'--timeout',
|
||||
cls=CeleryOption,
|
||||
@@ -121,19 +171,19 @@ def status(ctx, timeout, destination, json, **kwargs):
|
||||
help='Use json as output format.')
|
||||
@click.pass_context
|
||||
@handle_preload_options
|
||||
def inspect(ctx, action, timeout, destination, json, **kwargs):
|
||||
"""Inspect the worker at runtime.
|
||||
def inspect(ctx, command, timeout, destination, json, **kwargs):
|
||||
"""Inspect the workers by sending them the COMMAND inspect command.
|
||||
|
||||
Availability: RabbitMQ (AMQP) and Redis transports.
|
||||
"""
|
||||
_verify_command_name('inspect', command)
|
||||
callback = None if json else partial(_say_remote_command_reply, ctx,
|
||||
show_reply=True)
|
||||
arguments = _compile_arguments(action, ctx.args)
|
||||
arguments = _compile_arguments(command, ctx.args)
|
||||
inspect = ctx.obj.app.control.inspect(timeout=timeout,
|
||||
destination=destination,
|
||||
callback=callback)
|
||||
replies = inspect._request(action,
|
||||
**arguments)
|
||||
replies = inspect._request(command, **arguments)
|
||||
|
||||
if not replies:
|
||||
raise CeleryCommandException(
|
||||
@@ -153,10 +203,8 @@ def inspect(ctx, action, timeout, destination, json, **kwargs):
|
||||
|
||||
@click.command(cls=CeleryCommand,
|
||||
context_settings={'allow_extra_args': True})
|
||||
@click.argument("action", type=click.Choice([
|
||||
name for name, info in Panel.meta.items()
|
||||
if info.type == 'control' and info.visible
|
||||
]))
|
||||
@click.argument('command')
|
||||
@_list_option('control')
|
||||
@click.option('-t',
|
||||
'--timeout',
|
||||
cls=CeleryOption,
|
||||
@@ -178,16 +226,17 @@ def inspect(ctx, action, timeout, destination, json, **kwargs):
|
||||
help='Use json as output format.')
|
||||
@click.pass_context
|
||||
@handle_preload_options
|
||||
def control(ctx, action, timeout, destination, json):
|
||||
"""Workers remote control.
|
||||
def control(ctx, command, timeout, destination, json):
|
||||
"""Send the COMMAND control command to the workers.
|
||||
|
||||
Availability: RabbitMQ (AMQP), Redis, and MongoDB transports.
|
||||
"""
|
||||
_verify_command_name('control', command)
|
||||
callback = None if json else partial(_say_remote_command_reply, ctx,
|
||||
show_reply=True)
|
||||
args = ctx.args
|
||||
arguments = _compile_arguments(action, args)
|
||||
replies = ctx.obj.app.control.broadcast(action, timeout=timeout,
|
||||
arguments = _compile_arguments(command, args)
|
||||
replies = ctx.obj.app.control.broadcast(command, timeout=timeout,
|
||||
destination=destination,
|
||||
callback=callback,
|
||||
reply=True,
|
||||
|
||||
@@ -396,7 +396,7 @@ class Signature(dict):
|
||||
else:
|
||||
args, kwargs, options = self.args, self.kwargs, self.options
|
||||
# pylint: disable=too-many-function-args
|
||||
# Borks on this, as it's a property
|
||||
# Works on this, as it's a property
|
||||
return _apply(args, kwargs, **options)
|
||||
|
||||
def _merge(self, args=None, kwargs=None, options=None, force=False):
|
||||
@@ -515,7 +515,7 @@ class Signature(dict):
|
||||
if group_index is not None:
|
||||
opts['group_index'] = group_index
|
||||
# pylint: disable=too-many-function-args
|
||||
# Borks on this, as it's a property.
|
||||
# Works on this, as it's a property.
|
||||
return self.AsyncResult(tid)
|
||||
|
||||
_freeze = freeze
|
||||
@@ -958,6 +958,8 @@ class _chain(Signature):
|
||||
if isinstance(other, group):
|
||||
# unroll group with one member
|
||||
other = maybe_unroll_group(other)
|
||||
if not isinstance(other, group):
|
||||
return self.__or__(other)
|
||||
# chain | group() -> chain
|
||||
tasks = self.unchain_tasks()
|
||||
if not tasks:
|
||||
@@ -972,15 +974,20 @@ class _chain(Signature):
|
||||
tasks, other), app=self._app)
|
||||
elif isinstance(other, _chain):
|
||||
# chain | chain -> chain
|
||||
# use type(self) for _chain subclasses
|
||||
return type(self)(seq_concat_seq(
|
||||
self.unchain_tasks(), other.unchain_tasks()), app=self._app)
|
||||
return reduce(operator.or_, other.unchain_tasks(), self)
|
||||
elif isinstance(other, Signature):
|
||||
if self.tasks and isinstance(self.tasks[-1], group):
|
||||
# CHAIN [last item is group] | TASK -> chord
|
||||
sig = self.clone()
|
||||
sig.tasks[-1] = chord(
|
||||
sig.tasks[-1], other, app=self._app)
|
||||
# In the scenario where the second-to-last item in a chain is a chord,
|
||||
# it leads to a situation where two consecutive chords are formed.
|
||||
# In such cases, a further upgrade can be considered.
|
||||
# This would involve chaining the body of the second-to-last chord with the last chord."
|
||||
if len(sig.tasks) > 1 and isinstance(sig.tasks[-2], chord):
|
||||
sig.tasks[-2].body = sig.tasks[-2].body | sig.tasks[-1]
|
||||
sig.tasks = sig.tasks[:-1]
|
||||
return sig
|
||||
elif self.tasks and isinstance(self.tasks[-1], chord):
|
||||
# CHAIN [last item is chord] -> chain with chord body.
|
||||
@@ -1216,6 +1223,12 @@ class _chain(Signature):
|
||||
task, body=prev_task,
|
||||
root_id=root_id, app=app,
|
||||
)
|
||||
if tasks:
|
||||
prev_task = tasks[-1]
|
||||
prev_res = results[-1]
|
||||
else:
|
||||
prev_task = None
|
||||
prev_res = None
|
||||
|
||||
if is_last_task:
|
||||
# chain(task_id=id) means task id is set for the last task
|
||||
@@ -1261,6 +1274,7 @@ class _chain(Signature):
|
||||
while node.parent:
|
||||
node = node.parent
|
||||
prev_res = node
|
||||
self.id = last_task_id
|
||||
return tasks, results
|
||||
|
||||
def apply(self, args=None, kwargs=None, **options):
|
||||
@@ -1672,6 +1686,8 @@ class group(Signature):
|
||||
#
|
||||
# We return a concretised tuple of the signatures actually applied to
|
||||
# each child task signature, of which there might be none!
|
||||
sig = maybe_signature(sig)
|
||||
|
||||
return tuple(child_task.link_error(sig.clone(immutable=True)) for child_task in self.tasks)
|
||||
|
||||
def _prepared(self, tasks, partial_args, group_id, root_id, app,
|
||||
@@ -2271,6 +2287,8 @@ class _chord(Signature):
|
||||
``False`` (the current default), then the error callback will only be
|
||||
applied to the body.
|
||||
"""
|
||||
errback = maybe_signature(errback)
|
||||
|
||||
if self.app.conf.task_allow_error_cb_on_chord_header:
|
||||
for task in maybe_list(self.tasks) or []:
|
||||
task.link_error(errback.clone(immutable=True))
|
||||
@@ -2289,6 +2307,13 @@ class _chord(Signature):
|
||||
CPendingDeprecationWarning
|
||||
)
|
||||
|
||||
# Edge case for nested chords in the header
|
||||
for task in maybe_list(self.tasks) or []:
|
||||
if isinstance(task, chord):
|
||||
# Let the nested chord do the error linking itself on its
|
||||
# header and body where needed, based on the current configuration
|
||||
task.link_error(errback)
|
||||
|
||||
self.body.link_error(errback)
|
||||
return errback
|
||||
|
||||
|
||||
@@ -103,26 +103,35 @@ def _get_job_writer(job):
|
||||
return writer() # is a weakref
|
||||
|
||||
|
||||
def _ensure_integral_fd(fd):
|
||||
return fd if isinstance(fd, Integral) else fd.fileno()
|
||||
|
||||
|
||||
if hasattr(select, 'poll'):
|
||||
def _select_imp(readers=None, writers=None, err=None, timeout=0,
|
||||
poll=select.poll, POLLIN=select.POLLIN,
|
||||
POLLOUT=select.POLLOUT, POLLERR=select.POLLERR):
|
||||
poller = poll()
|
||||
register = poller.register
|
||||
fd_to_mask = {}
|
||||
|
||||
if readers:
|
||||
[register(fd, POLLIN) for fd in readers]
|
||||
for fd in map(_ensure_integral_fd, readers):
|
||||
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLIN
|
||||
if writers:
|
||||
[register(fd, POLLOUT) for fd in writers]
|
||||
for fd in map(_ensure_integral_fd, writers):
|
||||
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLOUT
|
||||
if err:
|
||||
[register(fd, POLLERR) for fd in err]
|
||||
for fd in map(_ensure_integral_fd, err):
|
||||
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLERR
|
||||
|
||||
for fd, event_mask in fd_to_mask.items():
|
||||
register(fd, event_mask)
|
||||
|
||||
R, W = set(), set()
|
||||
timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3)
|
||||
events = poller.poll(timeout)
|
||||
for fd, event in events:
|
||||
if not isinstance(fd, Integral):
|
||||
fd = fd.fileno()
|
||||
if event & POLLIN:
|
||||
R.add(fd)
|
||||
if event & POLLOUT:
|
||||
@@ -194,7 +203,7 @@ def iterate_file_descriptors_safely(fds_iter, source_data,
|
||||
or possibly other reasons, so safely manage our lists of FDs.
|
||||
:param fds_iter: the file descriptors to iterate and apply hub_method
|
||||
:param source_data: data source to remove FD if it renders OSError
|
||||
:param hub_method: the method to call with with each fd and kwargs
|
||||
:param hub_method: the method to call with each fd and kwargs
|
||||
:*args to pass through to the hub_method;
|
||||
with a special syntax string '*fd*' represents a substitution
|
||||
for the current fd object in the iteration (for some callers).
|
||||
@@ -772,7 +781,7 @@ class AsynPool(_pool.Pool):
|
||||
None, WRITE | ERR, consolidate=True)
|
||||
else:
|
||||
iterate_file_descriptors_safely(
|
||||
inactive, all_inqueues, hub_remove)
|
||||
inactive, all_inqueues, hub.remove_writer)
|
||||
self.on_poll_start = on_poll_start
|
||||
|
||||
def on_inqueue_close(fd, proc):
|
||||
@@ -818,7 +827,7 @@ class AsynPool(_pool.Pool):
|
||||
# worker is already busy with another task
|
||||
continue
|
||||
if ready_fd not in all_inqueues:
|
||||
hub_remove(ready_fd)
|
||||
hub.remove_writer(ready_fd)
|
||||
continue
|
||||
try:
|
||||
job = pop_message()
|
||||
@@ -829,7 +838,7 @@ class AsynPool(_pool.Pool):
|
||||
# this may create a spinloop where the event loop
|
||||
# always wakes up.
|
||||
for inqfd in diff(active_writes):
|
||||
hub_remove(inqfd)
|
||||
hub.remove_writer(inqfd)
|
||||
break
|
||||
|
||||
else:
|
||||
@@ -927,7 +936,7 @@ class AsynPool(_pool.Pool):
|
||||
else:
|
||||
errors = 0
|
||||
finally:
|
||||
hub_remove(fd)
|
||||
hub.remove_writer(fd)
|
||||
write_stats[proc.index] += 1
|
||||
# message written, so this fd is now available
|
||||
active_writes.discard(fd)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Gevent execution pool."""
|
||||
import functools
|
||||
import types
|
||||
from time import monotonic
|
||||
|
||||
from kombu.asynchronous import timer as _timer
|
||||
@@ -16,15 +18,22 @@ __all__ = ('TaskPool',)
|
||||
# We cache globals and attribute lookups, so disable this warning.
|
||||
|
||||
|
||||
def apply_target(target, args=(), kwargs=None, callback=None,
|
||||
accept_callback=None, getpid=None, **_):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return base.apply_target(target, args, kwargs, callback, accept_callback,
|
||||
pid=getpid(), **_)
|
||||
|
||||
|
||||
def apply_timeout(target, args=(), kwargs=None, callback=None,
|
||||
accept_callback=None, pid=None, timeout=None,
|
||||
accept_callback=None, getpid=None, timeout=None,
|
||||
timeout_callback=None, Timeout=Timeout,
|
||||
apply_target=base.apply_target, **rest):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
try:
|
||||
with Timeout(timeout):
|
||||
return apply_target(target, args, kwargs, callback,
|
||||
accept_callback, pid,
|
||||
accept_callback, getpid(),
|
||||
propagate=(Timeout,), **rest)
|
||||
except Timeout:
|
||||
return timeout_callback(False, timeout)
|
||||
@@ -82,18 +91,22 @@ class TaskPool(base.BasePool):
|
||||
is_green = True
|
||||
task_join_will_block = False
|
||||
_pool = None
|
||||
_pool_map = None
|
||||
_quick_put = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
from gevent import spawn_raw
|
||||
from gevent import getcurrent, spawn_raw
|
||||
from gevent.pool import Pool
|
||||
self.Pool = Pool
|
||||
self.getcurrent = getcurrent
|
||||
self.getpid = lambda: id(getcurrent())
|
||||
self.spawn_n = spawn_raw
|
||||
self.timeout = kwargs.get('timeout')
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def on_start(self):
|
||||
self._pool = self.Pool(self.limit)
|
||||
self._pool_map = {}
|
||||
self._quick_put = self._pool.spawn
|
||||
|
||||
def on_stop(self):
|
||||
@@ -102,12 +115,15 @@ class TaskPool(base.BasePool):
|
||||
|
||||
def on_apply(self, target, args=None, kwargs=None, callback=None,
|
||||
accept_callback=None, timeout=None,
|
||||
timeout_callback=None, apply_target=base.apply_target, **_):
|
||||
timeout_callback=None, apply_target=apply_target, **_):
|
||||
timeout = self.timeout if timeout is None else timeout
|
||||
return self._quick_put(apply_timeout if timeout else apply_target,
|
||||
target, args, kwargs, callback, accept_callback,
|
||||
timeout=timeout,
|
||||
timeout_callback=timeout_callback)
|
||||
target = self._make_killable_target(target)
|
||||
greenlet = self._quick_put(apply_timeout if timeout else apply_target,
|
||||
target, args, kwargs, callback, accept_callback,
|
||||
self.getpid, timeout=timeout, timeout_callback=timeout_callback)
|
||||
self._add_to_pool_map(id(greenlet), greenlet)
|
||||
greenlet.terminate = types.MethodType(_terminate, greenlet)
|
||||
return greenlet
|
||||
|
||||
def grow(self, n=1):
|
||||
self._pool._semaphore.counter += n
|
||||
@@ -117,6 +133,39 @@ class TaskPool(base.BasePool):
|
||||
self._pool._semaphore.counter -= n
|
||||
self._pool.size -= n
|
||||
|
||||
def terminate_job(self, pid, signal=None):
|
||||
import gevent
|
||||
|
||||
if pid in self._pool_map:
|
||||
greenlet = self._pool_map[pid]
|
||||
gevent.kill(greenlet)
|
||||
|
||||
@property
|
||||
def num_processes(self):
|
||||
return len(self._pool)
|
||||
|
||||
@staticmethod
|
||||
def _make_killable_target(target):
|
||||
def killable_target(*args, **kwargs):
|
||||
from greenlet import GreenletExit
|
||||
try:
|
||||
return target(*args, **kwargs)
|
||||
except GreenletExit:
|
||||
return (False, None, None)
|
||||
|
||||
return killable_target
|
||||
|
||||
def _add_to_pool_map(self, pid, greenlet):
|
||||
self._pool_map[pid] = greenlet
|
||||
greenlet.link(
|
||||
functools.partial(self._cleanup_after_job_finish, pid=pid, pool_map=self._pool_map),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_after_job_finish(greenlet, pool_map, pid):
|
||||
del pool_map[pid]
|
||||
|
||||
|
||||
def _terminate(self, signal):
|
||||
# Done in `TaskPool.terminate_job`
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import functools
|
||||
|
||||
from django.db import transaction
|
||||
|
||||
from celery.app.task import Task
|
||||
|
||||
|
||||
class DjangoTask(Task):
|
||||
"""
|
||||
Extend the base :class:`~celery.app.task.Task` for Django.
|
||||
|
||||
Provide a nicer API to trigger tasks at the end of the DB transaction.
|
||||
"""
|
||||
|
||||
def delay_on_commit(self, *args, **kwargs) -> None:
|
||||
"""Call :meth:`~celery.app.task.Task.delay` with Django's ``on_commit()``."""
|
||||
transaction.on_commit(functools.partial(self.delay, *args, **kwargs))
|
||||
|
||||
def apply_async_on_commit(self, *args, **kwargs) -> None:
|
||||
"""Call :meth:`~celery.app.task.Task.apply_async` with Django's ``on_commit()``."""
|
||||
transaction.on_commit(functools.partial(self.apply_async, *args, **kwargs))
|
||||
@@ -3,10 +3,10 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterable, Union # noqa
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import celery.worker.consumer # noqa
|
||||
from celery import Celery, worker # noqa
|
||||
from celery import Celery, worker
|
||||
from celery.result import _set_task_join_will_block, allow_join_result
|
||||
from celery.utils.dispatch import Signal
|
||||
from celery.utils.nodenames import anon_nodename
|
||||
@@ -30,6 +30,10 @@ test_worker_stopped = Signal(
|
||||
class TestWorkController(worker.WorkController):
|
||||
"""Worker that can synchronize on being fully started."""
|
||||
|
||||
# When this class is imported in pytest files, prevent pytest from thinking
|
||||
# this is a test class
|
||||
__test__ = False
|
||||
|
||||
logger_queue = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -131,16 +135,15 @@ def start_worker(
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _start_worker_thread(app,
|
||||
concurrency=1,
|
||||
pool='solo',
|
||||
loglevel=WORKER_LOGLEVEL,
|
||||
logfile=None,
|
||||
WorkController=TestWorkController,
|
||||
perform_ping_check=True,
|
||||
shutdown_timeout=10.0,
|
||||
**kwargs):
|
||||
# type: (Celery, int, str, Union[str, int], str, Any, **Any) -> Iterable
|
||||
def _start_worker_thread(app: Celery,
|
||||
concurrency: int = 1,
|
||||
pool: str = 'solo',
|
||||
loglevel: Union[str, int] = WORKER_LOGLEVEL,
|
||||
logfile: Optional[str] = None,
|
||||
WorkController: Any = TestWorkController,
|
||||
perform_ping_check: bool = True,
|
||||
shutdown_timeout: float = 10.0,
|
||||
**kwargs) -> Iterable[worker.WorkController]:
|
||||
"""Start Celery worker in a thread.
|
||||
|
||||
Yields:
|
||||
@@ -156,7 +159,7 @@ def _start_worker_thread(app,
|
||||
worker = WorkController(
|
||||
app=app,
|
||||
concurrency=concurrency,
|
||||
hostname=anon_nodename(),
|
||||
hostname=kwargs.pop("hostname", anon_nodename()),
|
||||
pool=pool,
|
||||
loglevel=loglevel,
|
||||
logfile=logfile,
|
||||
@@ -211,8 +214,7 @@ def _start_worker_process(app,
|
||||
cluster.stopwait()
|
||||
|
||||
|
||||
def setup_app_for_worker(app, loglevel, logfile) -> None:
|
||||
# type: (Celery, Union[str, int], str) -> None
|
||||
def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None:
|
||||
"""Setup the app to be used for starting an embedded worker."""
|
||||
app.finalize()
|
||||
app.set_current()
|
||||
|
||||
@@ -55,7 +55,7 @@ def get_exchange(conn, name=EVENT_EXCHANGE_NAME):
|
||||
(from topic -> fanout).
|
||||
"""
|
||||
ex = copy(event_exchange)
|
||||
if conn.transport.driver_type == 'redis':
|
||||
if conn.transport.driver_type in {'redis', 'gcpubsub'}:
|
||||
# quick hack for Issue #436
|
||||
ex.type = 'fanout'
|
||||
if name != ex.name:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from importlib import import_module
|
||||
from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
|
||||
|
||||
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from typing import Protocol
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.utils import ConnectionHandler
|
||||
|
||||
from celery.app.base import Celery
|
||||
@@ -78,6 +79,9 @@ class DjangoFixup:
|
||||
self._settings = symbol_by_name('django.conf:settings')
|
||||
self.app.loader.now = self.now
|
||||
|
||||
if not self.app._custom_task_cls_used:
|
||||
self.app.task_cls = 'celery.contrib.django.task:DjangoTask'
|
||||
|
||||
signals.import_modules.connect(self.on_import_modules)
|
||||
signals.worker_init.connect(self.on_worker_init)
|
||||
return self
|
||||
@@ -100,7 +104,7 @@ class DjangoFixup:
|
||||
self.worker_fixup.install()
|
||||
|
||||
def now(self, utc: bool = False) -> datetime:
|
||||
return datetime.utcnow() if utc else self._now()
|
||||
return datetime.now(timezone.utc) if utc else self._now()
|
||||
|
||||
def autodiscover_tasks(self) -> List[str]:
|
||||
from django.apps import apps
|
||||
@@ -161,15 +165,16 @@ class DjangoWorkerFixup:
|
||||
# network IO that close() might cause.
|
||||
for c in self._db.connections.all():
|
||||
if c and c.connection:
|
||||
self._maybe_close_db_fd(c.connection)
|
||||
self._maybe_close_db_fd(c)
|
||||
|
||||
# use the _ version to avoid DB_REUSE preventing the conn.close() call
|
||||
self._close_database(force=True)
|
||||
self.close_cache()
|
||||
|
||||
def _maybe_close_db_fd(self, fd: IO) -> None:
|
||||
def _maybe_close_db_fd(self, c: "BaseDatabaseWrapper") -> None:
|
||||
try:
|
||||
_maybe_close_fd(fd)
|
||||
with c.wrap_database_errors:
|
||||
_maybe_close_fd(c.connection)
|
||||
except self.interface_errors:
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import importlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from kombu.utils import json
|
||||
from kombu.utils.objects import cached_property
|
||||
@@ -62,7 +62,7 @@ class BaseLoader:
|
||||
|
||||
def now(self, utc=True):
|
||||
if utc:
|
||||
return datetime.utcnow()
|
||||
return datetime.now(timezone.utc)
|
||||
return datetime.now()
|
||||
|
||||
def on_task_init(self, task_id, task):
|
||||
@@ -253,10 +253,12 @@ def find_related_module(package, related_name):
|
||||
# Django 1.7 allows for specifying a class name in INSTALLED_APPS.
|
||||
# (Issue #2248).
|
||||
try:
|
||||
# Return package itself when no related_name.
|
||||
module = importlib.import_module(package)
|
||||
if not related_name and module:
|
||||
return module
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
# On import error, try to walk package up one level.
|
||||
package, _, _ = package.rpartition('.')
|
||||
if not package:
|
||||
raise
|
||||
@@ -264,9 +266,13 @@ def find_related_module(package, related_name):
|
||||
module_name = f'{package}.{related_name}'
|
||||
|
||||
try:
|
||||
# Try to find related_name under package.
|
||||
return importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
import_exc_name = getattr(e, 'name', module_name)
|
||||
if import_exc_name is not None and import_exc_name != module_name:
|
||||
raise e
|
||||
return
|
||||
except ModuleNotFoundError as e:
|
||||
import_exc_name = getattr(e, 'name', None)
|
||||
# If candidate does not exist, then return None.
|
||||
if import_exc_name and module_name == import_exc_name:
|
||||
return
|
||||
|
||||
# Otherwise, raise because error probably originated from a nested import.
|
||||
raise e
|
||||
|
||||
@@ -397,7 +397,6 @@ COMPAT_MODULES = {
|
||||
},
|
||||
'log': {
|
||||
'get_default_logger': 'log.get_default_logger',
|
||||
'setup_logger': 'log.setup_logger',
|
||||
'setup_logging_subsystem': 'log.setup_logging_subsystem',
|
||||
'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
|
||||
},
|
||||
|
||||
@@ -42,7 +42,7 @@ __all__ = (
|
||||
'DaemonContext', 'detached', 'parse_uid', 'parse_gid', 'setgroups',
|
||||
'initgroups', 'setgid', 'setuid', 'maybe_drop_privileges', 'signals',
|
||||
'signal_name', 'set_process_title', 'set_mp_process_title',
|
||||
'get_errno_name', 'ignore_errno', 'fd_by_path',
|
||||
'get_errno_name', 'ignore_errno', 'fd_by_path', 'isatty',
|
||||
)
|
||||
|
||||
# exitcodes
|
||||
@@ -95,6 +95,14 @@ SIGNAMES = {
|
||||
SIGMAP = {getattr(_signal, name): name for name in SIGNAMES}
|
||||
|
||||
|
||||
def isatty(fh):
|
||||
"""Return true if the process has a controlling terminal."""
|
||||
try:
|
||||
return fh.isatty()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def pyimplementation():
|
||||
"""Return string identifying the current Python implementation."""
|
||||
if hasattr(_platform, 'python_implementation'):
|
||||
@@ -186,10 +194,14 @@ class Pidfile:
|
||||
if not pid:
|
||||
self.remove()
|
||||
return True
|
||||
if pid == os.getpid():
|
||||
# this can be common in k8s pod with PID of 1 - don't kill
|
||||
self.remove()
|
||||
return True
|
||||
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except os.error as exc:
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.ESRCH or exc.errno == errno.EPERM:
|
||||
print('Stale pidfile exists - Removing it.', file=sys.stderr)
|
||||
self.remove()
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from weakref import proxy
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from kombu.utils.objects import cached_property
|
||||
from vine import Thenable, barrier, promise
|
||||
|
||||
@@ -532,7 +533,7 @@ class AsyncResult(ResultBase):
|
||||
"""UTC date and time."""
|
||||
date_done = self._get_task_meta().get('date_done')
|
||||
if date_done and not isinstance(date_done, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(date_done)
|
||||
return isoparse(date_done)
|
||||
return date_done
|
||||
|
||||
@property
|
||||
@@ -983,13 +984,14 @@ class GroupResult(ResultSet):
|
||||
class EagerResult(AsyncResult):
|
||||
"""Result that we know has already been executed."""
|
||||
|
||||
def __init__(self, id, ret_value, state, traceback=None):
|
||||
def __init__(self, id, ret_value, state, traceback=None, name=None):
|
||||
# pylint: disable=super-init-not-called
|
||||
# XXX should really not be inheriting from AsyncResult
|
||||
self.id = id
|
||||
self._result = ret_value
|
||||
self._state = state
|
||||
self._traceback = traceback
|
||||
self._name = name
|
||||
self.on_ready = promise()
|
||||
self.on_ready(self)
|
||||
|
||||
@@ -1042,6 +1044,7 @@ class EagerResult(AsyncResult):
|
||||
'result': self._result,
|
||||
'status': self._state,
|
||||
'traceback': self._traceback,
|
||||
'name': self._name,
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,9 +4,8 @@ from __future__ import annotations
|
||||
import re
|
||||
from bisect import bisect, bisect_left
|
||||
from collections import namedtuple
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timedelta, tzinfo
|
||||
from typing import Any, Callable, Mapping, Sequence
|
||||
from typing import Any, Callable, Iterable, Mapping, Sequence, Union
|
||||
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
@@ -15,7 +14,7 @@ from celery import Celery
|
||||
from . import current_app
|
||||
from .utils.collections import AttributeDict
|
||||
from .utils.time import (ffwd, humanize_seconds, localize, maybe_make_aware, maybe_timedelta, remaining, timezone,
|
||||
weekday)
|
||||
weekday, yearmonth)
|
||||
|
||||
__all__ = (
|
||||
'ParseException', 'schedule', 'crontab', 'crontab_parser',
|
||||
@@ -52,7 +51,10 @@ Argument event "{event}" is invalid, must be one of {all_events}.\
|
||||
"""
|
||||
|
||||
|
||||
def cronfield(s: str) -> str:
|
||||
Cronspec = Union[int, str, Iterable[int]]
|
||||
|
||||
|
||||
def cronfield(s: Cronspec | None) -> Cronspec:
|
||||
return '*' if s is None else s
|
||||
|
||||
|
||||
@@ -300,9 +302,12 @@ class crontab_parser:
|
||||
i = int(s)
|
||||
except ValueError:
|
||||
try:
|
||||
i = weekday(s)
|
||||
i = yearmonth(s)
|
||||
except KeyError:
|
||||
raise ValueError(f'Invalid weekday literal {s!r}.')
|
||||
try:
|
||||
i = weekday(s)
|
||||
except KeyError:
|
||||
raise ValueError(f'Invalid weekday literal {s!r}.')
|
||||
|
||||
max_val = self.min_ + self.max_ - 1
|
||||
if i > max_val:
|
||||
@@ -393,8 +398,8 @@ class crontab(BaseSchedule):
|
||||
present in ``month_of_year``.
|
||||
"""
|
||||
|
||||
def __init__(self, minute: str = '*', hour: str = '*', day_of_week: str = '*',
|
||||
day_of_month: str = '*', month_of_year: str = '*', **kwargs: Any) -> None:
|
||||
def __init__(self, minute: Cronspec = '*', hour: Cronspec = '*', day_of_week: Cronspec = '*',
|
||||
day_of_month: Cronspec = '*', month_of_year: Cronspec = '*', **kwargs: Any) -> None:
|
||||
self._orig_minute = cronfield(minute)
|
||||
self._orig_hour = cronfield(hour)
|
||||
self._orig_day_of_week = cronfield(day_of_week)
|
||||
@@ -408,9 +413,26 @@ class crontab(BaseSchedule):
|
||||
self.month_of_year = self._expand_cronspec(month_of_year, 12, 1)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, crontab: str) -> crontab:
|
||||
"""
|
||||
Create a Crontab from a cron expression string. For example ``crontab.from_string('* * * * *')``.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
┌───────────── minute (0–59)
|
||||
│ ┌───────────── hour (0–23)
|
||||
│ │ ┌───────────── day of the month (1–31)
|
||||
│ │ │ ┌───────────── month (1–12)
|
||||
│ │ │ │ ┌───────────── day of the week (0–6) (Sunday to Saturday)
|
||||
* * * * *
|
||||
"""
|
||||
minute, hour, day_of_month, month_of_year, day_of_week = crontab.split(" ")
|
||||
return cls(minute, hour, day_of_week, day_of_month, month_of_year)
|
||||
|
||||
@staticmethod
|
||||
def _expand_cronspec(
|
||||
cronspec: int | str | Iterable,
|
||||
cronspec: Cronspec,
|
||||
max_: int, min_: int = 0) -> set[Any]:
|
||||
"""Expand cron specification.
|
||||
|
||||
@@ -535,7 +557,7 @@ class crontab(BaseSchedule):
|
||||
def __repr__(self) -> str:
|
||||
return CRON_REPR.format(self)
|
||||
|
||||
def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], Any]:
|
||||
def __reduce__(self) -> tuple[type, tuple[Cronspec, Cronspec, Cronspec, Cronspec, Cronspec], Any]:
|
||||
return (self.__class__, (self._orig_minute,
|
||||
self._orig_hour,
|
||||
self._orig_day_of_week,
|
||||
|
||||
@@ -43,7 +43,7 @@ class Certificate:
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
"""Check if the certificate has expired."""
|
||||
return datetime.datetime.utcnow() >= self._cert.not_valid_after
|
||||
return datetime.datetime.now(datetime.timezone.utc) >= self._cert.not_valid_after_utc
|
||||
|
||||
def get_pubkey(self) -> (
|
||||
DSAPublicKey | EllipticCurvePublicKey | Ed448PublicKey | Ed25519PublicKey | RSAPublicKey
|
||||
|
||||
@@ -11,6 +11,11 @@ from .utils import get_digest_algorithm, reraise_errors
|
||||
|
||||
__all__ = ('SecureSerializer', 'register_auth')
|
||||
|
||||
# Note: we guarantee that this value won't appear in the serialized data,
|
||||
# so we can use it as a separator.
|
||||
# If you change this value, make sure it's not present in the serialized data.
|
||||
DEFAULT_SEPARATOR = str_to_bytes("\x00\x01")
|
||||
|
||||
|
||||
class SecureSerializer:
|
||||
"""Signed serializer."""
|
||||
@@ -29,7 +34,8 @@ class SecureSerializer:
|
||||
assert self._cert is not None
|
||||
with reraise_errors('Unable to serialize: {0!r}', (Exception,)):
|
||||
content_type, content_encoding, body = dumps(
|
||||
bytes_to_str(data), serializer=self._serializer)
|
||||
data, serializer=self._serializer)
|
||||
|
||||
# What we sign is the serialized body, not the body itself.
|
||||
# this way the receiver doesn't have to decode the contents
|
||||
# to verify the signature (and thus avoiding potential flaws
|
||||
@@ -48,43 +54,26 @@ class SecureSerializer:
|
||||
payload['signer'],
|
||||
payload['body'])
|
||||
self._cert_store[signer].verify(body, signature, self._digest)
|
||||
return loads(bytes_to_str(body), payload['content_type'],
|
||||
return loads(body, payload['content_type'],
|
||||
payload['content_encoding'], force=True)
|
||||
|
||||
def _pack(self, body, content_type, content_encoding, signer, signature,
|
||||
sep=str_to_bytes('\x00\x01')):
|
||||
sep=DEFAULT_SEPARATOR):
|
||||
fields = sep.join(
|
||||
ensure_bytes(s) for s in [signer, signature, content_type,
|
||||
content_encoding, body]
|
||||
ensure_bytes(s) for s in [b64encode(signer), b64encode(signature),
|
||||
content_type, content_encoding, body]
|
||||
)
|
||||
return b64encode(fields)
|
||||
|
||||
def _unpack(self, payload, sep=str_to_bytes('\x00\x01')):
|
||||
def _unpack(self, payload, sep=DEFAULT_SEPARATOR):
|
||||
raw_payload = b64decode(ensure_bytes(payload))
|
||||
first_sep = raw_payload.find(sep)
|
||||
|
||||
signer = raw_payload[:first_sep]
|
||||
signer_cert = self._cert_store[signer]
|
||||
|
||||
# shift 3 bits right to get signature length
|
||||
# 2048bit rsa key has a signature length of 256
|
||||
# 4096bit rsa key has a signature length of 512
|
||||
sig_len = signer_cert.get_pubkey().key_size >> 3
|
||||
sep_len = len(sep)
|
||||
signature_start_position = first_sep + sep_len
|
||||
signature_end_position = signature_start_position + sig_len
|
||||
signature = raw_payload[
|
||||
signature_start_position:signature_end_position
|
||||
]
|
||||
|
||||
v = raw_payload[signature_end_position + sep_len:].split(sep)
|
||||
|
||||
v = raw_payload.split(sep, maxsplit=4)
|
||||
return {
|
||||
'signer': signer,
|
||||
'signature': signature,
|
||||
'content_type': bytes_to_str(v[0]),
|
||||
'content_encoding': bytes_to_str(v[1]),
|
||||
'body': bytes_to_str(v[2]),
|
||||
'signer': b64decode(v[0]),
|
||||
'signature': b64decode(v[1]),
|
||||
'content_type': bytes_to_str(v[2]),
|
||||
'content_encoding': bytes_to_str(v[3]),
|
||||
'body': v[4],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Code related to handling annotations."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from inspect import isclass
|
||||
|
||||
|
||||
def is_none_type(value: typing.Any) -> bool:
|
||||
"""Check if the given value is a NoneType."""
|
||||
if sys.version_info < (3, 10):
|
||||
# raise Exception('below 3.10', value, type(None))
|
||||
return value is type(None)
|
||||
return value == types.NoneType # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_optional_arg(annotation: typing.Any) -> typing.Any:
|
||||
"""Get the argument from an Optional[...] annotation, or None if it is no such annotation."""
|
||||
origin = typing.get_origin(annotation)
|
||||
if origin != typing.Union and (sys.version_info >= (3, 10) and origin != types.UnionType):
|
||||
return None
|
||||
|
||||
union_args = typing.get_args(annotation)
|
||||
if len(union_args) != 2: # Union does _not_ have two members, so it's not an Optional
|
||||
return None
|
||||
|
||||
has_none_arg = any(is_none_type(arg) for arg in union_args)
|
||||
# There will always be at least one type arg, as we have already established that this is a Union with exactly
|
||||
# two members, and both cannot be None (`Union[None, None]` does not work).
|
||||
type_arg = next(arg for arg in union_args if not is_none_type(arg)) # pragma: no branch
|
||||
|
||||
if has_none_arg:
|
||||
return type_arg
|
||||
return None
|
||||
|
||||
|
||||
def annotation_is_class(annotation: typing.Any) -> bool:
|
||||
"""Test if a given annotation is a class that can be used in isinstance()/issubclass()."""
|
||||
# isclass() returns True for generic type hints (e.g. `list[str]`) until Python 3.10.
|
||||
# NOTE: The guard for Python 3.9 is because types.GenericAlias is only added in Python 3.9. This is not a problem
|
||||
# as the syntax is added in the same version in the first place.
|
||||
if (3, 9) <= sys.version_info < (3, 11) and isinstance(annotation, types.GenericAlias):
|
||||
return False
|
||||
return isclass(annotation)
|
||||
|
||||
|
||||
def annotation_issubclass(annotation: typing.Any, cls: type) -> bool:
|
||||
"""Test if a given annotation is of the given subclass."""
|
||||
return annotation_is_class(annotation) and issubclass(annotation, cls)
|
||||
@@ -595,8 +595,7 @@ class LimitedSet:
|
||||
break # oldest item hasn't expired yet
|
||||
self.pop()
|
||||
|
||||
def pop(self, default=None) -> Any:
|
||||
# type: (Any) -> Any
|
||||
def pop(self, default: Any = None) -> Any:
|
||||
"""Remove and return the oldest item, or :const:`None` when empty."""
|
||||
while self._heap:
|
||||
_, item = heappop(self._heap)
|
||||
|
||||
@@ -54,6 +54,9 @@ def _boundmethod_safe_weakref(obj):
|
||||
def _make_lookup_key(receiver, sender, dispatch_uid):
|
||||
if dispatch_uid:
|
||||
return (dispatch_uid, _make_id(sender))
|
||||
# Issue #9119 - retry-wrapped functions use the underlying function for dispatch_uid
|
||||
elif hasattr(receiver, '_dispatch_uid'):
|
||||
return (receiver._dispatch_uid, _make_id(sender))
|
||||
else:
|
||||
return (_make_id(receiver), _make_id(sender))
|
||||
|
||||
@@ -170,6 +173,7 @@ class Signal: # pragma: no cover
|
||||
# it up later with the original func id
|
||||
options['dispatch_uid'] = _make_id(fun)
|
||||
fun = _retry_receiver(fun)
|
||||
fun._dispatch_uid = options['dispatch_uid']
|
||||
|
||||
self._connect_signal(fun, sender, options['weak'],
|
||||
options['dispatch_uid'])
|
||||
|
||||
@@ -51,8 +51,13 @@ def instantiate(name, *args, **kwargs):
|
||||
@contextmanager
|
||||
def cwd_in_path():
|
||||
"""Context adding the current working directory to sys.path."""
|
||||
cwd = os.getcwd()
|
||||
if cwd in sys.path:
|
||||
try:
|
||||
cwd = os.getcwd()
|
||||
except FileNotFoundError:
|
||||
cwd = None
|
||||
if not cwd:
|
||||
yield
|
||||
elif cwd in sys.path:
|
||||
yield
|
||||
else:
|
||||
sys.path.insert(0, cwd)
|
||||
|
||||
@@ -50,9 +50,9 @@ TIMEZONE_REGEX = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def parse_iso8601(datestring):
|
||||
def parse_iso8601(datestring: str) -> datetime:
|
||||
"""Parse and convert ISO-8601 string to datetime."""
|
||||
warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat")
|
||||
warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat or dateutil.parser.isoparse")
|
||||
m = ISO8601_REGEX.match(datestring)
|
||||
if not m:
|
||||
raise ValueError('unable to parse date string %r' % datestring)
|
||||
|
||||
@@ -37,7 +37,7 @@ base_logger = logger = _get_logger('celery')
|
||||
|
||||
|
||||
def set_in_sighandler(value):
|
||||
"""Set flag signifiying that we're inside a signal handler."""
|
||||
"""Set flag signifying that we're inside a signal handler."""
|
||||
global _in_sighandler
|
||||
_in_sighandler = value
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Worker name utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
from functools import partial
|
||||
@@ -22,13 +24,18 @@ NODENAME_DEFAULT = 'celery'
|
||||
gethostname = memoize(1, Cache=dict)(socket.gethostname)
|
||||
|
||||
__all__ = (
|
||||
'worker_direct', 'gethostname', 'nodename',
|
||||
'anon_nodename', 'nodesplit', 'default_nodename',
|
||||
'node_format', 'host_format',
|
||||
'worker_direct',
|
||||
'gethostname',
|
||||
'nodename',
|
||||
'anon_nodename',
|
||||
'nodesplit',
|
||||
'default_nodename',
|
||||
'node_format',
|
||||
'host_format',
|
||||
)
|
||||
|
||||
|
||||
def worker_direct(hostname):
|
||||
def worker_direct(hostname: str | Queue) -> Queue:
|
||||
"""Return the :class:`kombu.Queue` being a direct route to a worker.
|
||||
|
||||
Arguments:
|
||||
@@ -46,21 +53,20 @@ def worker_direct(hostname):
|
||||
)
|
||||
|
||||
|
||||
def nodename(name, hostname):
|
||||
def nodename(name: str, hostname: str) -> str:
|
||||
"""Create node name from name/hostname pair."""
|
||||
return NODENAME_SEP.join((name, hostname))
|
||||
|
||||
|
||||
def anon_nodename(hostname=None, prefix='gen'):
|
||||
def anon_nodename(hostname: str | None = None, prefix: str = 'gen') -> str:
|
||||
"""Return the nodename for this process (not a worker).
|
||||
|
||||
This is used for e.g. the origin task message field.
|
||||
"""
|
||||
return nodename(''.join([prefix, str(os.getpid())]),
|
||||
hostname or gethostname())
|
||||
return nodename(''.join([prefix, str(os.getpid())]), hostname or gethostname())
|
||||
|
||||
|
||||
def nodesplit(name):
|
||||
def nodesplit(name: str) -> tuple[None, str] | list[str]:
|
||||
"""Split node name into tuple of name/hostname."""
|
||||
parts = name.split(NODENAME_SEP, 1)
|
||||
if len(parts) == 1:
|
||||
@@ -68,21 +74,21 @@ def nodesplit(name):
|
||||
return parts
|
||||
|
||||
|
||||
def default_nodename(hostname):
|
||||
def default_nodename(hostname: str) -> str:
|
||||
"""Return the default nodename for this process."""
|
||||
name, host = nodesplit(hostname or '')
|
||||
return nodename(name or NODENAME_DEFAULT, host or gethostname())
|
||||
|
||||
|
||||
def node_format(s, name, **extra):
|
||||
def node_format(s: str, name: str, **extra: dict) -> str:
|
||||
"""Format worker node name (name@host.com)."""
|
||||
shortname, host = nodesplit(name)
|
||||
return host_format(
|
||||
s, host, shortname or NODENAME_DEFAULT, p=name, **extra)
|
||||
return host_format(s, host, shortname or NODENAME_DEFAULT, p=name, **extra)
|
||||
|
||||
|
||||
def _fmt_process_index(prefix='', default='0'):
|
||||
def _fmt_process_index(prefix: str = '', default: str = '0') -> str:
|
||||
from .log import current_process_index
|
||||
|
||||
index = current_process_index()
|
||||
return f'{prefix}{index}' if index else default
|
||||
|
||||
@@ -90,13 +96,19 @@ def _fmt_process_index(prefix='', default='0'):
|
||||
_fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '')
|
||||
|
||||
|
||||
def host_format(s, host=None, name=None, **extra):
|
||||
def host_format(s: str, host: str | None = None, name: str | None = None, **extra: dict) -> str:
|
||||
"""Format host %x abbreviations."""
|
||||
host = host or gethostname()
|
||||
hname, _, domain = host.partition('.')
|
||||
name = name or hname
|
||||
keys = dict({
|
||||
'h': host, 'n': name, 'd': domain,
|
||||
'i': _fmt_process_index, 'I': _fmt_process_index_with_prefix,
|
||||
}, **extra)
|
||||
keys = dict(
|
||||
{
|
||||
'h': host,
|
||||
'n': name,
|
||||
'd': domain,
|
||||
'i': _fmt_process_index,
|
||||
'I': _fmt_process_index_with_prefix,
|
||||
},
|
||||
**extra,
|
||||
)
|
||||
return simple_format(s, keys)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def detect_quorum_queues(app, driver_type: str) -> tuple[bool, str]:
|
||||
"""Detect if any of the queues are quorum queues.
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing a boolean indicating if any of the queues are quorum queues
|
||||
and the name of the first quorum queue found or an empty string if no quorum queues were found.
|
||||
"""
|
||||
is_rabbitmq_broker = driver_type == 'amqp'
|
||||
|
||||
if is_rabbitmq_broker:
|
||||
queues = app.amqp.queues
|
||||
for qname in queues:
|
||||
qarguments = queues[qname].queue_arguments or {}
|
||||
if qarguments.get("x-queue-type") == "quorum":
|
||||
return True, qname
|
||||
|
||||
return False, ""
|
||||
@@ -15,7 +15,7 @@ from decimal import Decimal
|
||||
from itertools import chain
|
||||
from numbers import Number
|
||||
from pprint import _recursion
|
||||
from typing import Any, AnyStr, Callable, Dict, Iterator, List, Sequence, Set, Tuple # noqa
|
||||
from typing import Any, AnyStr, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple # noqa
|
||||
|
||||
from .text import truncate
|
||||
|
||||
@@ -41,7 +41,7 @@ _quoted = namedtuple('_quoted', ('value',))
|
||||
#: Recursion protection.
|
||||
_dirty = namedtuple('_dirty', ('objid',))
|
||||
|
||||
#: Types that are repsented as chars.
|
||||
#: Types that are represented as chars.
|
||||
chars_t = (bytes, str)
|
||||
|
||||
#: Types that are regarded as safe to call repr on.
|
||||
@@ -194,9 +194,12 @@ def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
|
||||
)
|
||||
|
||||
|
||||
def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance):
|
||||
def reprstream(stack: deque,
|
||||
seen: Optional[Set] = None,
|
||||
maxlevels: int = 3,
|
||||
level: int = 0,
|
||||
isinstance: Callable = isinstance) -> Iterator[Any]:
|
||||
"""Streaming repr, yielding tokens."""
|
||||
# type: (deque, Set, int, int, Callable) -> Iterator[Any]
|
||||
seen = seen or set()
|
||||
append = stack.append
|
||||
popleft = stack.popleft
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""System information utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from math import ceil
|
||||
|
||||
@@ -9,16 +11,16 @@ __all__ = ('load_average', 'df')
|
||||
|
||||
if hasattr(os, 'getloadavg'):
|
||||
|
||||
def _load_average():
|
||||
def _load_average() -> tuple[float, ...]:
|
||||
return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg())
|
||||
|
||||
else: # pragma: no cover
|
||||
# Windows doesn't have getloadavg
|
||||
def _load_average():
|
||||
return (0.0, 0.0, 0.0)
|
||||
def _load_average() -> tuple[float, ...]:
|
||||
return 0.0, 0.0, 0.0,
|
||||
|
||||
|
||||
def load_average():
|
||||
def load_average() -> tuple[float, ...]:
|
||||
"""Return system load average as a triple."""
|
||||
return _load_average()
|
||||
|
||||
@@ -26,23 +28,23 @@ def load_average():
|
||||
class df:
|
||||
"""Disk information."""
|
||||
|
||||
def __init__(self, path):
|
||||
def __init__(self, path: str | bytes | os.PathLike) -> None:
|
||||
self.path = path
|
||||
|
||||
@property
|
||||
def total_blocks(self):
|
||||
def total_blocks(self) -> float:
|
||||
return self.stat.f_blocks * self.stat.f_frsize / 1024
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
def available(self) -> float:
|
||||
return self.stat.f_bavail * self.stat.f_frsize / 1024
|
||||
|
||||
@property
|
||||
def capacity(self):
|
||||
def capacity(self) -> int:
|
||||
avail = self.stat.f_bavail
|
||||
used = self.stat.f_blocks - self.stat.f_bfree
|
||||
return int(ceil(used * 100.0 / (used + avail) + 0.5))
|
||||
|
||||
@cached_property
|
||||
def stat(self):
|
||||
def stat(self) -> os.statvfs_result:
|
||||
return os.statvfs(os.path.abspath(self.path))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Terminals and colors."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import codecs
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
@@ -8,6 +9,8 @@ from functools import reduce
|
||||
|
||||
__all__ = ('colored',)
|
||||
|
||||
from typing import Any
|
||||
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||
OP_SEQ = '\033[%dm'
|
||||
RESET_SEQ = '\033[0m'
|
||||
@@ -26,7 +29,7 @@ _IMG_PRE = '\033Ptmux;\033\033]' if TERM_IS_SCREEN else '\033]'
|
||||
_IMG_POST = '\a\033\\' if TERM_IS_SCREEN else '\a'
|
||||
|
||||
|
||||
def fg(s):
|
||||
def fg(s: int) -> str:
|
||||
return COLOR_SEQ % s
|
||||
|
||||
|
||||
@@ -41,11 +44,11 @@ class colored:
|
||||
... c.green('dog ')))
|
||||
"""
|
||||
|
||||
def __init__(self, *s, **kwargs):
|
||||
self.s = s
|
||||
self.enabled = not IS_WINDOWS and kwargs.get('enabled', True)
|
||||
self.op = kwargs.get('op', '')
|
||||
self.names = {
|
||||
def __init__(self, *s: object, **kwargs: Any) -> None:
|
||||
self.s: tuple[object, ...] = s
|
||||
self.enabled: bool = not IS_WINDOWS and kwargs.get('enabled', True)
|
||||
self.op: str = kwargs.get('op', '')
|
||||
self.names: dict[str, Any] = {
|
||||
'black': self.black,
|
||||
'red': self.red,
|
||||
'green': self.green,
|
||||
@@ -56,10 +59,10 @@ class colored:
|
||||
'white': self.white,
|
||||
}
|
||||
|
||||
def _add(self, a, b):
|
||||
return str(a) + str(b)
|
||||
def _add(self, a: object, b: object) -> str:
|
||||
return f"{a}{b}"
|
||||
|
||||
def _fold_no_color(self, a, b):
|
||||
def _fold_no_color(self, a: Any, b: Any) -> str:
|
||||
try:
|
||||
A = a.no_color()
|
||||
except AttributeError:
|
||||
@@ -69,109 +72,113 @@ class colored:
|
||||
except AttributeError:
|
||||
B = str(b)
|
||||
|
||||
return ''.join((str(A), str(B)))
|
||||
return f"{A}{B}"
|
||||
|
||||
def no_color(self):
|
||||
def no_color(self) -> str:
|
||||
if self.s:
|
||||
return str(reduce(self._fold_no_color, self.s))
|
||||
return ''
|
||||
|
||||
def embed(self):
|
||||
def embed(self) -> str:
|
||||
prefix = ''
|
||||
if self.enabled:
|
||||
prefix = self.op
|
||||
return ''.join((str(prefix), str(reduce(self._add, self.s))))
|
||||
return f"{prefix}{reduce(self._add, self.s)}"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
suffix = ''
|
||||
if self.enabled:
|
||||
suffix = RESET_SEQ
|
||||
return str(''.join((self.embed(), str(suffix))))
|
||||
return f"{self.embed()}{suffix}"
|
||||
|
||||
def node(self, s, op):
|
||||
def node(self, s: tuple[object, ...], op: str) -> colored:
|
||||
return self.__class__(enabled=self.enabled, op=op, *s)
|
||||
|
||||
def black(self, *s):
|
||||
def black(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + BLACK))
|
||||
|
||||
def red(self, *s):
|
||||
def red(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + RED))
|
||||
|
||||
def green(self, *s):
|
||||
def green(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + GREEN))
|
||||
|
||||
def yellow(self, *s):
|
||||
def yellow(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + YELLOW))
|
||||
|
||||
def blue(self, *s):
|
||||
def blue(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + BLUE))
|
||||
|
||||
def magenta(self, *s):
|
||||
def magenta(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + MAGENTA))
|
||||
|
||||
def cyan(self, *s):
|
||||
def cyan(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + CYAN))
|
||||
|
||||
def white(self, *s):
|
||||
def white(self, *s: object) -> colored:
|
||||
return self.node(s, fg(30 + WHITE))
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.no_color())
|
||||
|
||||
def bold(self, *s):
|
||||
def bold(self, *s: object) -> colored:
|
||||
return self.node(s, OP_SEQ % 1)
|
||||
|
||||
def underline(self, *s):
|
||||
def underline(self, *s: object) -> colored:
|
||||
return self.node(s, OP_SEQ % 4)
|
||||
|
||||
def blink(self, *s):
|
||||
def blink(self, *s: object) -> colored:
|
||||
return self.node(s, OP_SEQ % 5)
|
||||
|
||||
def reverse(self, *s):
|
||||
def reverse(self, *s: object) -> colored:
|
||||
return self.node(s, OP_SEQ % 7)
|
||||
|
||||
def bright(self, *s):
|
||||
def bright(self, *s: object) -> colored:
|
||||
return self.node(s, OP_SEQ % 8)
|
||||
|
||||
def ired(self, *s):
|
||||
def ired(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + RED))
|
||||
|
||||
def igreen(self, *s):
|
||||
def igreen(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + GREEN))
|
||||
|
||||
def iyellow(self, *s):
|
||||
def iyellow(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + YELLOW))
|
||||
|
||||
def iblue(self, *s):
|
||||
def iblue(self, *s: colored) -> colored:
|
||||
return self.node(s, fg(40 + BLUE))
|
||||
|
||||
def imagenta(self, *s):
|
||||
def imagenta(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + MAGENTA))
|
||||
|
||||
def icyan(self, *s):
|
||||
def icyan(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + CYAN))
|
||||
|
||||
def iwhite(self, *s):
|
||||
def iwhite(self, *s: object) -> colored:
|
||||
return self.node(s, fg(40 + WHITE))
|
||||
|
||||
def reset(self, *s):
|
||||
return self.node(s or [''], RESET_SEQ)
|
||||
def reset(self, *s: object) -> colored:
|
||||
return self.node(s or ('',), RESET_SEQ)
|
||||
|
||||
def __add__(self, other):
|
||||
return str(self) + str(other)
|
||||
def __add__(self, other: object) -> str:
|
||||
return f"{self}{other}"
|
||||
|
||||
|
||||
def supports_images():
|
||||
return sys.stdin.isatty() and ITERM_PROFILE
|
||||
def supports_images() -> bool:
|
||||
|
||||
try:
|
||||
return sys.stdin.isatty() and bool(os.environ.get('ITERM_PROFILE'))
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def _read_as_base64(path):
|
||||
with codecs.open(path, mode='rb') as fh:
|
||||
def _read_as_base64(path: str) -> str:
|
||||
with open(path, mode='rb') as fh:
|
||||
encoded = base64.b64encode(fh.read())
|
||||
return encoded if isinstance(encoded, str) else encoded.decode('ascii')
|
||||
return encoded.decode('ascii')
|
||||
|
||||
|
||||
def imgcat(path, inline=1, preserve_aspect_ratio=0, **kwargs):
|
||||
def imgcat(path: str, inline: int = 1, preserve_aspect_ratio: int = 0, **kwargs: Any) -> str:
|
||||
return '\n%s1337;File=inline=%d;preserveAspectRatio=%d:%s%s' % (
|
||||
_IMG_PRE, inline, preserve_aspect_ratio,
|
||||
_read_as_base64(path), _IMG_POST)
|
||||
|
||||
@@ -14,6 +14,7 @@ from types import ModuleType
|
||||
from typing import Any, Callable
|
||||
|
||||
from dateutil import tz as dateutil_tz
|
||||
from dateutil.parser import isoparse
|
||||
from kombu.utils.functional import reprcall
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
@@ -40,6 +41,9 @@ C_REMDEBUG = os.environ.get('C_REMDEBUG', False)
|
||||
DAYNAMES = 'sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'
|
||||
WEEKDAYS = dict(zip(DAYNAMES, range(7)))
|
||||
|
||||
MONTHNAMES = 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'
|
||||
YEARMONTHS = dict(zip(MONTHNAMES, range(1, 13)))
|
||||
|
||||
RATE_MODIFIER_MAP = {
|
||||
's': lambda n: n,
|
||||
'm': lambda n: n / 60.0,
|
||||
@@ -200,7 +204,7 @@ def delta_resolution(dt: datetime, delta: timedelta) -> datetime:
|
||||
def remaining(
|
||||
start: datetime, ends_in: timedelta, now: Callable | None = None,
|
||||
relative: bool = False) -> timedelta:
|
||||
"""Calculate the remaining time for a start date and a timedelta.
|
||||
"""Calculate the real remaining time for a start date and a timedelta.
|
||||
|
||||
For example, "how many seconds left for 30 seconds after start?"
|
||||
|
||||
@@ -211,24 +215,28 @@ def remaining(
|
||||
using :func:`delta_resolution` (i.e., rounded to the
|
||||
resolution of `ends_in`).
|
||||
now (Callable): Function returning the current time and date.
|
||||
Defaults to :func:`datetime.utcnow`.
|
||||
Defaults to :func:`datetime.now(timezone.utc)`.
|
||||
|
||||
Returns:
|
||||
~datetime.timedelta: Remaining time.
|
||||
"""
|
||||
now = now or datetime.utcnow()
|
||||
if str(
|
||||
start.tzinfo) == str(
|
||||
now.tzinfo) and now.utcoffset() != start.utcoffset():
|
||||
# DST started/ended
|
||||
start = start.replace(tzinfo=now.tzinfo)
|
||||
now = now or datetime.now(datetime_timezone.utc)
|
||||
end_date = start + ends_in
|
||||
if relative:
|
||||
end_date = delta_resolution(end_date, ends_in).replace(microsecond=0)
|
||||
ret = end_date - now
|
||||
|
||||
# Using UTC to calculate real time difference.
|
||||
# Python by default uses wall time in arithmetic between datetimes with
|
||||
# equal non-UTC timezones.
|
||||
now_utc = now.astimezone(timezone.utc)
|
||||
end_date_utc = end_date.astimezone(timezone.utc)
|
||||
ret = end_date_utc - now_utc
|
||||
if C_REMDEBUG: # pragma: no cover
|
||||
print('rem: NOW:{!r} START:{!r} ENDS_IN:{!r} END_DATE:{} REM:{}'.format(
|
||||
now, start, ends_in, end_date, ret))
|
||||
print(
|
||||
'rem: NOW:{!r} NOW_UTC:{!r} START:{!r} ENDS_IN:{!r} '
|
||||
'END_DATE:{} END_DATE_UTC:{!r} REM:{}'.format(
|
||||
now, now_utc, start, ends_in, end_date, end_date_utc, ret)
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
@@ -257,6 +265,21 @@ def weekday(name: str) -> int:
|
||||
raise KeyError(name)
|
||||
|
||||
|
||||
def yearmonth(name: str) -> int:
|
||||
"""Return the position of a month: 1 - 12, where 1 is January.
|
||||
|
||||
Example:
|
||||
>>> yearmonth('january'), yearmonth('jan'), yearmonth('may')
|
||||
(1, 1, 5)
|
||||
"""
|
||||
abbreviation = name[0:3].lower()
|
||||
try:
|
||||
return YEARMONTHS[abbreviation]
|
||||
except KeyError:
|
||||
# Show original day name in exception, instead of abbr.
|
||||
raise KeyError(name)
|
||||
|
||||
|
||||
def humanize_seconds(
|
||||
secs: int, prefix: str = '', sep: str = '', now: str = 'now',
|
||||
microseconds: bool = False) -> str:
|
||||
@@ -288,7 +311,7 @@ def maybe_iso8601(dt: datetime | str | None) -> None | datetime:
|
||||
return
|
||||
if isinstance(dt, datetime):
|
||||
return dt
|
||||
return datetime.fromisoformat(dt)
|
||||
return isoparse(dt)
|
||||
|
||||
|
||||
def is_naive(dt: datetime) -> bool:
|
||||
@@ -302,7 +325,7 @@ def _can_detect_ambiguous(tz: tzinfo) -> bool:
|
||||
return isinstance(tz, ZoneInfo) or hasattr(tz, "is_ambiguous")
|
||||
|
||||
|
||||
def _is_ambigious(dt: datetime, tz: tzinfo) -> bool:
|
||||
def _is_ambiguous(dt: datetime, tz: tzinfo) -> bool:
|
||||
"""Helper function to determine if a timezone is ambiguous using python's dateutil module.
|
||||
|
||||
Returns False if the timezone cannot detect ambiguity, or if there is no ambiguity, otherwise True.
|
||||
@@ -319,7 +342,7 @@ def make_aware(dt: datetime, tz: tzinfo) -> datetime:
|
||||
"""Set timezone for a :class:`~datetime.datetime` object."""
|
||||
|
||||
dt = dt.replace(tzinfo=tz)
|
||||
if _is_ambigious(dt, tz):
|
||||
if _is_ambiguous(dt, tz):
|
||||
dt = min(dt.replace(fold=0), dt.replace(fold=1))
|
||||
return dt
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import threading
|
||||
from itertools import count
|
||||
from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX
|
||||
from time import sleep
|
||||
from typing import Any, Callable, Iterator, Optional, Tuple
|
||||
|
||||
from kombu.asynchronous.timer import Entry
|
||||
from kombu.asynchronous.timer import Timer as Schedule
|
||||
@@ -30,20 +31,23 @@ class Timer(threading.Thread):
|
||||
Entry = Entry
|
||||
Schedule = Schedule
|
||||
|
||||
running = False
|
||||
on_tick = None
|
||||
running: bool = False
|
||||
on_tick: Optional[Callable[[float], None]] = None
|
||||
|
||||
_timer_count = count(1)
|
||||
_timer_count: count = count(1)
|
||||
|
||||
if TIMER_DEBUG: # pragma: no cover
|
||||
def start(self, *args, **kwargs):
|
||||
def start(self, *args: Any, **kwargs: Any) -> None:
|
||||
import traceback
|
||||
print('- Timer starting')
|
||||
traceback.print_stack()
|
||||
super().start(*args, **kwargs)
|
||||
|
||||
def __init__(self, schedule=None, on_error=None, on_tick=None,
|
||||
on_start=None, max_interval=None, **kwargs):
|
||||
def __init__(self, schedule: Optional[Schedule] = None,
|
||||
on_error: Optional[Callable[[Exception], None]] = None,
|
||||
on_tick: Optional[Callable[[float], None]] = None,
|
||||
on_start: Optional[Callable[['Timer'], None]] = None,
|
||||
max_interval: Optional[float] = None, **kwargs: Any) -> None:
|
||||
self.schedule = schedule or self.Schedule(on_error=on_error,
|
||||
max_interval=max_interval)
|
||||
self.on_start = on_start
|
||||
@@ -60,8 +64,10 @@ class Timer(threading.Thread):
|
||||
self.daemon = True
|
||||
self.name = f'Timer-{next(self._timer_count)}'
|
||||
|
||||
def _next_entry(self):
|
||||
def _next_entry(self) -> Optional[float]:
|
||||
with self.not_empty:
|
||||
delay: Optional[float]
|
||||
entry: Optional[Entry]
|
||||
delay, entry = next(self.scheduler)
|
||||
if entry is None:
|
||||
if delay is None:
|
||||
@@ -70,10 +76,10 @@ class Timer(threading.Thread):
|
||||
return self.schedule.apply_entry(entry)
|
||||
__next__ = next = _next_entry # for 2to3
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self.running = True
|
||||
self.scheduler = iter(self.schedule)
|
||||
self.scheduler: Iterator[Tuple[Optional[float], Optional[Entry]]] = iter(self.schedule)
|
||||
|
||||
while not self.__is_shutdown.is_set():
|
||||
delay = self._next_entry()
|
||||
@@ -94,61 +100,61 @@ class Timer(threading.Thread):
|
||||
sys.stderr.flush()
|
||||
os._exit(1)
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self.__is_shutdown.set()
|
||||
if self.running:
|
||||
self.__is_stopped.wait()
|
||||
self.join(THREAD_TIMEOUT_MAX)
|
||||
self.running = False
|
||||
|
||||
def ensure_started(self):
|
||||
def ensure_started(self) -> None:
|
||||
if not self.running and not self.is_alive():
|
||||
if self.on_start:
|
||||
self.on_start(self)
|
||||
self.start()
|
||||
|
||||
def _do_enter(self, meth, *args, **kwargs):
|
||||
def _do_enter(self, meth: str, *args: Any, **kwargs: Any) -> Entry:
|
||||
self.ensure_started()
|
||||
with self.mutex:
|
||||
entry = getattr(self.schedule, meth)(*args, **kwargs)
|
||||
self.not_empty.notify()
|
||||
return entry
|
||||
|
||||
def enter(self, entry, eta, priority=None):
|
||||
def enter(self, entry: Entry, eta: float, priority: Optional[int] = None) -> Entry:
|
||||
return self._do_enter('enter_at', entry, eta, priority=priority)
|
||||
|
||||
def call_at(self, *args, **kwargs):
|
||||
def call_at(self, *args: Any, **kwargs: Any) -> Entry:
|
||||
return self._do_enter('call_at', *args, **kwargs)
|
||||
|
||||
def enter_after(self, *args, **kwargs):
|
||||
def enter_after(self, *args: Any, **kwargs: Any) -> Entry:
|
||||
return self._do_enter('enter_after', *args, **kwargs)
|
||||
|
||||
def call_after(self, *args, **kwargs):
|
||||
def call_after(self, *args: Any, **kwargs: Any) -> Entry:
|
||||
return self._do_enter('call_after', *args, **kwargs)
|
||||
|
||||
def call_repeatedly(self, *args, **kwargs):
|
||||
def call_repeatedly(self, *args: Any, **kwargs: Any) -> Entry:
|
||||
return self._do_enter('call_repeatedly', *args, **kwargs)
|
||||
|
||||
def exit_after(self, secs, priority=10):
|
||||
def exit_after(self, secs: float, priority: int = 10) -> None:
|
||||
self.call_after(secs, sys.exit, priority)
|
||||
|
||||
def cancel(self, tref):
|
||||
def cancel(self, tref: Entry) -> None:
|
||||
tref.cancel()
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
self.schedule.clear()
|
||||
|
||||
def empty(self):
|
||||
def empty(self) -> bool:
|
||||
return not len(self)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.schedule)
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
"""``bool(timer)``."""
|
||||
return True
|
||||
__nonzero__ = __bool__
|
||||
|
||||
@property
|
||||
def queue(self):
|
||||
def queue(self) -> list:
|
||||
return self.schedule.queue
|
||||
|
||||
@@ -169,6 +169,7 @@ class Consumer:
|
||||
'celery.worker.consumer.heart:Heart',
|
||||
'celery.worker.consumer.control:Control',
|
||||
'celery.worker.consumer.tasks:Tasks',
|
||||
'celery.worker.consumer.delayed_delivery:DelayedDelivery',
|
||||
'celery.worker.consumer.consumer:Evloop',
|
||||
'celery.worker.consumer.agent:Agent',
|
||||
]
|
||||
@@ -390,20 +391,21 @@ class Consumer:
|
||||
else:
|
||||
warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)
|
||||
|
||||
self.initial_prefetch_count = max(
|
||||
self.prefetch_multiplier,
|
||||
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
|
||||
)
|
||||
|
||||
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
|
||||
if not self._maximum_prefetch_restored:
|
||||
logger.info(
|
||||
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching "
|
||||
f"since {len(tuple(active_requests))} tasks are currently being processed.\n"
|
||||
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
|
||||
"complete processing."
|
||||
if self.app.conf.worker_enable_prefetch_count_reduction:
|
||||
self.initial_prefetch_count = max(
|
||||
self.prefetch_multiplier,
|
||||
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
|
||||
)
|
||||
|
||||
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
|
||||
if not self._maximum_prefetch_restored:
|
||||
logger.info(
|
||||
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid "
|
||||
f"over-fetching since {len(tuple(active_requests))} tasks are currently being processed.\n"
|
||||
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
|
||||
"complete processing."
|
||||
)
|
||||
|
||||
def register_with_event_loop(self, hub):
|
||||
self.blueprint.send_all(
|
||||
self, 'register_with_event_loop', args=(hub,),
|
||||
@@ -411,6 +413,7 @@ class Consumer:
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
self.perform_pending_operations()
|
||||
self.blueprint.shutdown(self)
|
||||
|
||||
def stop(self):
|
||||
@@ -475,9 +478,9 @@ class Consumer:
|
||||
return self.ensure_connected(
|
||||
self.app.connection_for_read(heartbeat=heartbeat))
|
||||
|
||||
def connection_for_write(self, heartbeat=None):
|
||||
def connection_for_write(self, url=None, heartbeat=None):
|
||||
return self.ensure_connected(
|
||||
self.app.connection_for_write(heartbeat=heartbeat))
|
||||
self.app.connection_for_write(url=url, heartbeat=heartbeat))
|
||||
|
||||
def ensure_connected(self, conn):
|
||||
# Callback called for each retry while the connection
|
||||
@@ -504,13 +507,14 @@ class Consumer:
|
||||
# to determine whether connection retries are disabled.
|
||||
retry_disabled = not self.app.conf.broker_connection_retry
|
||||
|
||||
warnings.warn(
|
||||
CPendingDeprecationWarning(
|
||||
f"The broker_connection_retry configuration setting will no longer determine\n"
|
||||
f"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
|
||||
f"If you wish to retain the existing behavior for retrying connections on startup,\n"
|
||||
f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.")
|
||||
)
|
||||
if retry_disabled:
|
||||
warnings.warn(
|
||||
CPendingDeprecationWarning(
|
||||
"The broker_connection_retry configuration setting will no longer determine\n"
|
||||
"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
|
||||
"If you wish to refrain from retrying connections on startup,\n"
|
||||
"you should set broker_connection_retry_on_startup to False instead.")
|
||||
)
|
||||
else:
|
||||
if self.first_connection_attempt:
|
||||
retry_disabled = not self.app.conf.broker_connection_retry_on_startup
|
||||
@@ -696,7 +700,10 @@ class Consumer:
|
||||
|
||||
def _restore_prefetch_count_after_connection_restart(self, p, *args):
|
||||
with self.qos._mutex:
|
||||
if self._maximum_prefetch_restored:
|
||||
if any((
|
||||
not self.app.conf.worker_enable_prefetch_count_reduction,
|
||||
self._maximum_prefetch_restored,
|
||||
)):
|
||||
return
|
||||
|
||||
new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
|
||||
@@ -726,6 +733,29 @@ class Consumer:
|
||||
self=self, state=self.blueprint.human_state(),
|
||||
)
|
||||
|
||||
def cancel_all_unacked_requests(self):
|
||||
"""Cancel all active requests that either do not require late acknowledgments or,
|
||||
if they do, have not been acknowledged yet.
|
||||
"""
|
||||
|
||||
def should_cancel(request):
|
||||
if not request.task.acks_late:
|
||||
# Task does not require late acknowledgment, cancel it.
|
||||
return True
|
||||
|
||||
if not request.acknowledged:
|
||||
# Task is late acknowledged, but it has not been acknowledged yet, cancel it.
|
||||
return True
|
||||
|
||||
# Task is late acknowledged, but it has already been acknowledged.
|
||||
return False # Do not cancel and allow it to gracefully finish as it has already been acknowledged.
|
||||
|
||||
requests_to_cancel = tuple(filter(should_cancel, active_requests))
|
||||
|
||||
if requests_to_cancel:
|
||||
for request in requests_to_cancel:
|
||||
request.cancel(self.pool)
|
||||
|
||||
|
||||
class Evloop(bootsteps.StartStopStep):
|
||||
"""Event loop service.
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Native delayed delivery functionality for Celery workers.
|
||||
|
||||
This module provides the DelayedDelivery bootstep which handles setup and configuration
|
||||
of native delayed delivery functionality when using quorum queues.
|
||||
"""
|
||||
from typing import Iterator, List, Optional, Set, Union, ValuesView
|
||||
|
||||
from kombu import Connection, Queue
|
||||
from kombu.transport.native_delayed_delivery import (bind_queue_to_native_delayed_delivery_exchange,
|
||||
declare_native_delayed_delivery_exchanges_and_queues)
|
||||
from kombu.utils.functional import retry_over_time
|
||||
|
||||
from celery import Celery, bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.quorum_queues import detect_quorum_queues
|
||||
from celery.worker.consumer import Consumer, Tasks
|
||||
|
||||
__all__ = ('DelayedDelivery',)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Default retry settings
|
||||
RETRY_INTERVAL = 1.0 # seconds between retries
|
||||
MAX_RETRIES = 3 # maximum number of retries
|
||||
|
||||
|
||||
# Valid queue types for delayed delivery
|
||||
VALID_QUEUE_TYPES = {'classic', 'quorum'}
|
||||
|
||||
|
||||
class DelayedDelivery(bootsteps.StartStopStep):
|
||||
"""Bootstep that sets up native delayed delivery functionality.
|
||||
|
||||
This component handles the setup and configuration of native delayed delivery
|
||||
for Celery workers. It is automatically included when quorum queues are
|
||||
detected in the application configuration.
|
||||
|
||||
Responsibilities:
|
||||
- Declaring native delayed delivery exchanges and queues
|
||||
- Binding all application queues to the delayed delivery exchanges
|
||||
- Handling connection failures gracefully with retries
|
||||
- Validating configuration settings
|
||||
"""
|
||||
|
||||
requires = (Tasks,)
|
||||
|
||||
def include_if(self, c: Consumer) -> bool:
|
||||
"""Determine if this bootstep should be included.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
|
||||
Returns:
|
||||
bool: True if quorum queues are detected, False otherwise
|
||||
"""
|
||||
return detect_quorum_queues(c.app, c.app.connection_for_write().transport.driver_type)[0]
|
||||
|
||||
def start(self, c: Consumer) -> None:
|
||||
"""Initialize delayed delivery for all broker URLs.
|
||||
|
||||
Attempts to set up delayed delivery for each broker URL in the configuration.
|
||||
Failures are logged but don't prevent attempting remaining URLs.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration validation fails
|
||||
"""
|
||||
app: Celery = c.app
|
||||
|
||||
try:
|
||||
self._validate_configuration(app)
|
||||
except ValueError as e:
|
||||
logger.critical("Configuration validation failed: %s", str(e))
|
||||
raise
|
||||
|
||||
broker_urls = self._validate_broker_urls(app.conf.broker_url)
|
||||
setup_errors = []
|
||||
|
||||
for broker_url in broker_urls:
|
||||
try:
|
||||
retry_over_time(
|
||||
self._setup_delayed_delivery,
|
||||
args=(c, broker_url),
|
||||
catch=(ConnectionRefusedError, OSError),
|
||||
errback=self._on_retry,
|
||||
interval_start=RETRY_INTERVAL,
|
||||
max_retries=MAX_RETRIES,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to setup delayed delivery for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
setup_errors.append((broker_url, e))
|
||||
|
||||
if len(setup_errors) == len(broker_urls):
|
||||
logger.critical(
|
||||
"Failed to setup delayed delivery for all broker URLs. "
|
||||
"Native delayed delivery will not be available."
|
||||
)
|
||||
|
||||
def _setup_delayed_delivery(self, c: Consumer, broker_url: str) -> None:
|
||||
"""Set up delayed delivery for a specific broker URL.
|
||||
|
||||
Args:
|
||||
c: The Celery consumer instance
|
||||
broker_url: The broker URL to configure
|
||||
|
||||
Raises:
|
||||
ConnectionRefusedError: If connection to the broker fails
|
||||
OSError: If there are network-related issues
|
||||
Exception: For other unexpected errors during setup
|
||||
"""
|
||||
connection: Connection = c.app.connection_for_write(url=broker_url)
|
||||
queue_type = c.app.conf.broker_native_delayed_delivery_queue_type
|
||||
logger.debug(
|
||||
"Setting up delayed delivery for broker %r with queue type %r",
|
||||
broker_url, queue_type
|
||||
)
|
||||
|
||||
try:
|
||||
declare_native_delayed_delivery_exchanges_and_queues(
|
||||
connection,
|
||||
queue_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to declare exchanges and queues for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
self._bind_queues(c.app, connection)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to bind queues for %r: %s",
|
||||
broker_url, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
def _bind_queues(self, app: Celery, connection: Connection) -> None:
|
||||
"""Bind all application queues to delayed delivery exchanges.
|
||||
|
||||
Args:
|
||||
app: The Celery application instance
|
||||
connection: The broker connection to use
|
||||
|
||||
Raises:
|
||||
Exception: If queue binding fails
|
||||
"""
|
||||
queues: ValuesView[Queue] = app.amqp.queues.values()
|
||||
if not queues:
|
||||
logger.warning("No queues found to bind for delayed delivery")
|
||||
return
|
||||
|
||||
for queue in queues:
|
||||
try:
|
||||
logger.debug("Binding queue %r to delayed delivery exchange", queue.name)
|
||||
bind_queue_to_native_delayed_delivery_exchange(connection, queue)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to bind queue %r: %s",
|
||||
queue.name, str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
def _on_retry(self, exc: Exception, interval_range: Iterator[float], intervals_count: int) -> None:
|
||||
"""Callback for retry attempts.
|
||||
|
||||
Args:
|
||||
exc: The exception that triggered the retry
|
||||
interval_range: An iterator which returns the time in seconds to sleep next
|
||||
intervals_count: Number of retry attempts so far
|
||||
"""
|
||||
logger.warning(
|
||||
"Retrying delayed delivery setup (attempt %d/%d) after error: %s",
|
||||
intervals_count + 1, MAX_RETRIES, str(exc)
|
||||
)
|
||||
|
||||
def _validate_configuration(self, app: Celery) -> None:
|
||||
"""Validate all required configuration settings.
|
||||
|
||||
Args:
|
||||
app: The Celery application instance
|
||||
|
||||
Raises:
|
||||
ValueError: If any configuration is invalid
|
||||
"""
|
||||
# Validate broker URLs
|
||||
self._validate_broker_urls(app.conf.broker_url)
|
||||
|
||||
# Validate queue type
|
||||
self._validate_queue_type(app.conf.broker_native_delayed_delivery_queue_type)
|
||||
|
||||
def _validate_broker_urls(self, broker_urls: Union[str, List[str]]) -> Set[str]:
|
||||
"""Validate and split broker URLs.
|
||||
|
||||
Args:
|
||||
broker_urls: Broker URLs, either as a semicolon-separated string
|
||||
or as a list of strings
|
||||
|
||||
Returns:
|
||||
Set of valid broker URLs
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid broker URLs are found or if invalid URLs are provided
|
||||
"""
|
||||
if not broker_urls:
|
||||
raise ValueError("broker_url configuration is empty")
|
||||
|
||||
if isinstance(broker_urls, str):
|
||||
brokers = broker_urls.split(";")
|
||||
elif isinstance(broker_urls, list):
|
||||
if not all(isinstance(url, str) for url in broker_urls):
|
||||
raise ValueError("All broker URLs must be strings")
|
||||
brokers = broker_urls
|
||||
else:
|
||||
raise ValueError(f"broker_url must be a string or list, got {broker_urls!r}")
|
||||
|
||||
valid_urls = {url for url in brokers}
|
||||
|
||||
if not valid_urls:
|
||||
raise ValueError("No valid broker URLs found in configuration")
|
||||
|
||||
return valid_urls
|
||||
|
||||
def _validate_queue_type(self, queue_type: Optional[str]) -> None:
|
||||
"""Validate the queue type configuration.
|
||||
|
||||
Args:
|
||||
queue_type: The configured queue type
|
||||
|
||||
Raises:
|
||||
ValueError: If queue type is invalid
|
||||
"""
|
||||
if not queue_type:
|
||||
raise ValueError("broker_native_delayed_delivery_queue_type is not configured")
|
||||
|
||||
if queue_type not in VALID_QUEUE_TYPES:
|
||||
sorted_types = sorted(VALID_QUEUE_TYPES)
|
||||
raise ValueError(
|
||||
f"Invalid queue type {queue_type!r}. Must be one of: {', '.join(sorted_types)}"
|
||||
)
|
||||
@@ -176,6 +176,7 @@ class Gossip(bootsteps.ConsumerStep):
|
||||
channel,
|
||||
queues=[ev.queue],
|
||||
on_message=partial(self.on_message, ev.event_from_message),
|
||||
accept=ev.accept,
|
||||
no_ack=True
|
||||
)]
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class Mingle(bootsteps.StartStopStep):
|
||||
|
||||
label = 'Mingle'
|
||||
requires = (Events,)
|
||||
compatible_transports = {'amqp', 'redis'}
|
||||
compatible_transports = {'amqp', 'redis', 'gcpubsub'}
|
||||
|
||||
def __init__(self, c, without_mingle=False, **kwargs):
|
||||
self.enabled = not without_mingle and self.compatible_transport(c.app)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""Worker Task Consumer Bootstep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from kombu.common import QoS, ignore_errors
|
||||
|
||||
from celery import bootsteps
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.quorum_queues import detect_quorum_queues
|
||||
|
||||
from .mingle import Mingle
|
||||
|
||||
__all__ = ('Tasks',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
debug = logger.debug
|
||||
|
||||
@@ -25,10 +30,7 @@ class Tasks(bootsteps.StartStopStep):
|
||||
"""Start task consumer."""
|
||||
c.update_strategies()
|
||||
|
||||
# - RabbitMQ 3.3 completely redefines how basic_qos works...
|
||||
# This will detect if the new qos semantics is in effect,
|
||||
# and if so make sure the 'apply_global' flag is set on qos updates.
|
||||
qos_global = not c.connection.qos_semantics_matches_spec
|
||||
qos_global = self.qos_global(c)
|
||||
|
||||
# set initial prefetch count
|
||||
c.connection.default_channel.basic_qos(
|
||||
@@ -63,3 +65,24 @@ class Tasks(bootsteps.StartStopStep):
|
||||
def info(self, c):
|
||||
"""Return task consumer info."""
|
||||
return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
|
||||
|
||||
def qos_global(self, c) -> bool:
|
||||
"""Determine if global QoS should be applied.
|
||||
|
||||
Additional information:
|
||||
https://www.rabbitmq.com/docs/consumer-prefetch
|
||||
https://www.rabbitmq.com/docs/quorum-queues#global-qos
|
||||
"""
|
||||
# - RabbitMQ 3.3 completely redefines how basic_qos works...
|
||||
# This will detect if the new qos semantics is in effect,
|
||||
# and if so make sure the 'apply_global' flag is set on qos updates.
|
||||
qos_global = not c.connection.qos_semantics_matches_spec
|
||||
|
||||
if c.app.conf.worker_detect_quorum_queues:
|
||||
using_quorum_queues, qname = detect_quorum_queues(c.app, c.connection.transport.driver_type)
|
||||
|
||||
if using_quorum_queues:
|
||||
qos_global = False
|
||||
logger.info("Global QoS is disabled. Prefetch count in now static.")
|
||||
|
||||
return qos_global
|
||||
|
||||
@@ -7,6 +7,7 @@ from billiard.common import TERM_SIGNAME
|
||||
from kombu.utils.encoding import safe_repr
|
||||
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.platforms import EX_OK
|
||||
from celery.platforms import signals as _signals
|
||||
from celery.utils.functional import maybe_list
|
||||
from celery.utils.log import get_logger
|
||||
@@ -580,7 +581,7 @@ def autoscale(state, max=None, min=None):
|
||||
def shutdown(state, msg='Got shutdown from remote', **kwargs):
|
||||
"""Shutdown worker(s)."""
|
||||
logger.warning(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
raise WorkerShutdown(EX_OK)
|
||||
|
||||
|
||||
# -- Queues
|
||||
|
||||
@@ -119,8 +119,10 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
|
||||
|
||||
obj.on_ready()
|
||||
|
||||
while blueprint.state == RUN and obj.connection:
|
||||
state.maybe_shutdown()
|
||||
def _loop_cycle():
|
||||
"""
|
||||
Perform one iteration of the blocking event loop.
|
||||
"""
|
||||
if heartbeat_error[0] is not None:
|
||||
raise heartbeat_error[0]
|
||||
if qos.prev != qos.value:
|
||||
@@ -133,3 +135,9 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
|
||||
except OSError:
|
||||
if blueprint.state == RUN:
|
||||
raise
|
||||
|
||||
while blueprint.state == RUN and obj.connection:
|
||||
try:
|
||||
state.maybe_shutdown()
|
||||
finally:
|
||||
_loop_cycle()
|
||||
|
||||
@@ -602,8 +602,8 @@ class Request:
|
||||
is_worker_lost = isinstance(exc, WorkerLostError)
|
||||
if self.task.acks_late:
|
||||
reject = (
|
||||
self.task.reject_on_worker_lost and
|
||||
is_worker_lost
|
||||
(self.task.reject_on_worker_lost and is_worker_lost)
|
||||
or (isinstance(exc, TimeLimitExceeded) and not self.task.acks_on_failure_or_timeout)
|
||||
)
|
||||
ack = self.task.acks_on_failure_or_timeout
|
||||
if reject:
|
||||
@@ -777,7 +777,7 @@ def create_request_cls(base, task, pool, hostname, eventer,
|
||||
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
|
||||
raise exc
|
||||
return self.on_failure(retval, return_ok=True)
|
||||
task_ready(self)
|
||||
task_ready(self, successful=True)
|
||||
|
||||
if acks_late:
|
||||
self.acknowledge()
|
||||
|
||||
@@ -14,7 +14,8 @@ The worker consists of several components, all managed by bootsteps
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from time import sleep
|
||||
|
||||
from billiard import cpu_count
|
||||
from kombu.utils.compat import detect_environment
|
||||
@@ -89,7 +90,7 @@ class WorkController:
|
||||
def __init__(self, app=None, hostname=None, **kwargs):
|
||||
self.app = app or self.app
|
||||
self.hostname = default_nodename(hostname)
|
||||
self.startup_time = datetime.utcnow()
|
||||
self.startup_time = datetime.now(timezone.utc)
|
||||
self.app.loader.init_worker()
|
||||
self.on_before_init(**kwargs)
|
||||
self.setup_defaults(**kwargs)
|
||||
@@ -241,7 +242,7 @@ class WorkController:
|
||||
not self.app.IS_WINDOWS)
|
||||
|
||||
def stop(self, in_sighandler=False, exitcode=None):
|
||||
"""Graceful shutdown of the worker server."""
|
||||
"""Graceful shutdown of the worker server (Warm shutdown)."""
|
||||
if exitcode is not None:
|
||||
self.exitcode = exitcode
|
||||
if self.blueprint.state == RUN:
|
||||
@@ -251,7 +252,7 @@ class WorkController:
|
||||
self._send_worker_shutdown()
|
||||
|
||||
def terminate(self, in_sighandler=False):
|
||||
"""Not so graceful shutdown of the worker server."""
|
||||
"""Not so graceful shutdown of the worker server (Cold shutdown)."""
|
||||
if self.blueprint.state != TERMINATE:
|
||||
self.signal_consumer_close()
|
||||
if not in_sighandler or self.pool.signal_safe:
|
||||
@@ -293,7 +294,7 @@ class WorkController:
|
||||
return reload_from_cwd(sys.modules[module], reloader)
|
||||
|
||||
def info(self):
|
||||
uptime = datetime.utcnow() - self.startup_time
|
||||
uptime = datetime.now(timezone.utc) - self.startup_time
|
||||
return {'total': self.state.total_count,
|
||||
'pid': os.getpid(),
|
||||
'clock': str(self.app.clock),
|
||||
@@ -407,3 +408,28 @@ class WorkController:
|
||||
'worker_disable_rate_limits', disable_rate_limits,
|
||||
)
|
||||
self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait)
|
||||
|
||||
def wait_for_soft_shutdown(self):
|
||||
"""Wait :setting:`worker_soft_shutdown_timeout` if soft shutdown is enabled.
|
||||
|
||||
To enable soft shutdown, set the :setting:`worker_soft_shutdown_timeout` in the
|
||||
configuration. Soft shutdown can be used to allow the worker to finish processing
|
||||
few more tasks before initiating a cold shutdown. This mechanism allows the worker
|
||||
to finish short tasks that are already in progress and requeue long-running tasks
|
||||
to be picked up by another worker.
|
||||
|
||||
.. warning::
|
||||
If there are no tasks in the worker, the worker will not wait for the
|
||||
soft shutdown timeout even if it is set as it makes no sense to wait for
|
||||
the timeout when there are no tasks to process.
|
||||
"""
|
||||
app = self.app
|
||||
requests = tuple(state.active_requests)
|
||||
|
||||
if app.conf.worker_enable_soft_shutdown_on_idle:
|
||||
requests = True
|
||||
|
||||
if app.conf.worker_soft_shutdown_timeout > 0 and requests:
|
||||
log = f"Initiating Soft Shutdown, terminating in {app.conf.worker_soft_shutdown_timeout} seconds"
|
||||
logger.warning(log)
|
||||
sleep(app.conf.worker_soft_shutdown_timeout)
|
||||
|
||||
Reference in New Issue
Block a user