API refactor

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -15,9 +15,9 @@ from collections import namedtuple
# Lazy loading
from . import local
SERIES = 'emerald-rush'
SERIES = 'immunity'
__version__ = '5.3.4'
__version__ = '5.5.3'
__author__ = 'Ask Solem'
__contact__ = 'auvipy@gmail.com'
__homepage__ = 'https://docs.celeryq.dev/'

View File

@@ -249,9 +249,13 @@ class AMQP:
if max_priority is None:
max_priority = conf.task_queue_max_priority
if not queues and conf.task_default_queue:
queue_arguments = None
if conf.task_default_queue_type == 'quorum':
queue_arguments = {'x-queue-type': 'quorum'}
queues = (Queue(conf.task_default_queue,
exchange=self.default_exchange,
routing_key=default_routing_key),)
routing_key=default_routing_key,
queue_arguments=queue_arguments),)
autoexchange = (self.autoexchange if autoexchange is None
else autoexchange)
return self.queues_cls(
@@ -285,7 +289,7 @@ class AMQP:
create_sent_event=False, root_id=None, parent_id=None,
shadow=None, chain=None, now=None, timezone=None,
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
**options):
replaced_task_nesting=0, **options):
args = args or ()
kwargs = kwargs or {}
@@ -339,6 +343,7 @@ class AMQP:
'kwargsrepr': kwargsrepr,
'origin': origin or anon_nodename(),
'ignore_result': ignore_result,
'replaced_task_nesting': replaced_task_nesting,
'stamped_headers': stamped_headers,
'stamps': stamps,
}
@@ -462,7 +467,8 @@ class AMQP:
retry=None, retry_policy=None,
serializer=None, delivery_mode=None,
compression=None, declare=None,
headers=None, exchange_type=None, **kwargs):
headers=None, exchange_type=None,
timeout=None, confirm_timeout=None, **kwargs):
retry = default_retry if retry is None else retry
headers2, properties, body, sent_event = message
if headers:
@@ -523,6 +529,7 @@ class AMQP:
retry=retry, retry_policy=_rp,
delivery_mode=delivery_mode, declare=declare,
headers=headers2,
timeout=timeout, confirm_timeout=confirm_timeout,
**properties
)
if after_receivers:

View File

@@ -34,6 +34,7 @@ BACKEND_ALIASES = {
'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend',
'arangodb': 'celery.backends.arangodb:ArangoDbBackend',
's3': 'celery.backends.s3:S3Backend',
'gs': 'celery.backends.gcs:GCSBackend',
}

View File

@@ -1,17 +1,23 @@
"""Actual App instance implementation."""
import functools
import importlib
import inspect
import os
import sys
import threading
import typing
import warnings
from collections import UserDict, defaultdict, deque
from datetime import datetime
from datetime import timezone as datetime_timezone
from operator import attrgetter
from click.exceptions import Exit
from kombu import pools
from dateutil.parser import isoparse
from kombu import Exchange, pools
from kombu.clocks import LamportClock
from kombu.common import oid_from
from kombu.transport.native_delayed_delivery import calculate_routing_key
from kombu.utils.compat import register_after_fork
from kombu.utils.objects import cached_property
from kombu.utils.uuid import uuid
@@ -32,6 +38,8 @@ from celery.utils.log import get_logger
from celery.utils.objects import FallbackContext, mro_lookup
from celery.utils.time import maybe_make_aware, timezone, to_utc
from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg
from ..utils.quorum_queues import detect_quorum_queues
# Load all builtin tasks
from . import backends, builtins # noqa
from .annotations import prepare as prepare_annotations
@@ -41,6 +49,10 @@ from .registry import TaskRegistry
from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr,
bugreport, detect_settings)
if typing.TYPE_CHECKING: # pragma: no cover # codecov does not capture this
# flake8 marks the BaseModel import as unused, because the actual typehint is quoted.
from pydantic import BaseModel # noqa: F401
__all__ = ('Celery',)
logger = get_logger(__name__)
@@ -90,6 +102,70 @@ def _after_fork_cleanup_app(app):
logger.info('after forker raised exception: %r', exc, exc_info=1)
def pydantic_wrapper(
app: "Celery",
task_fun: typing.Callable[..., typing.Any],
task_name: str,
strict: bool = True,
context: typing.Optional[typing.Dict[str, typing.Any]] = None,
dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None
):
"""Wrapper to validate arguments and serialize return values using Pydantic."""
try:
pydantic = importlib.import_module('pydantic')
except ModuleNotFoundError as ex:
raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex
BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking
if context is None:
context = {}
if dump_kwargs is None:
dump_kwargs = {}
dump_kwargs.setdefault('mode', 'json')
task_signature = inspect.signature(task_fun)
@functools.wraps(task_fun)
def wrapper(*task_args, **task_kwargs):
# Validate task parameters if type hinted as BaseModel
bound_args = task_signature.bind(*task_args, **task_kwargs)
for arg_name, arg_value in bound_args.arguments.items():
arg_annotation = task_signature.parameters[arg_name].annotation
optional_arg = get_optional_arg(arg_annotation)
if optional_arg is not None and arg_value is not None:
arg_annotation = optional_arg
if annotation_issubclass(arg_annotation, BaseModel):
bound_args.arguments[arg_name] = arg_annotation.model_validate(
arg_value,
strict=strict,
context={**context, 'celery_app': app, 'celery_task_name': task_name},
)
# Call the task with (potentially) converted arguments
returned_value = task_fun(*bound_args.args, **bound_args.kwargs)
# Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its
# class matches the typehint
return_annotation = task_signature.return_annotation
optional_return_annotation = get_optional_arg(return_annotation)
if optional_return_annotation is not None:
return_annotation = optional_return_annotation
if (
annotation_is_class(return_annotation)
and isinstance(returned_value, BaseModel)
and isinstance(returned_value, return_annotation)
):
return returned_value.model_dump(**dump_kwargs)
return returned_value
return wrapper
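For reference, a minimal usage sketch of the pydantic support wired in above (assuming an existing Celery app instance named `app` and pydantic installed; the task and model names are hypothetical). Arguments annotated with a BaseModel subclass are validated by the wrapper, and a matching BaseModel return value is dumped back to a dict:

from pydantic import BaseModel

class ArgModel(BaseModel):
    value: int

class ResultModel(BaseModel):
    value: int

@app.task(pydantic=True)
def double(arg: ArgModel) -> ResultModel:
    # `arg` arrives as a validated ArgModel instance, not a raw dict.
    return ResultModel(value=arg.value * 2)

# Callers pass plain, serializable data; the wrapper handles conversion both ways.
double.delay({'value': 21})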
class PendingConfiguration(UserDict, AttributeDictMixin):
# `app.conf` will be of this type before being explicitly configured,
# meaning the app can keep any configuration set directly
@@ -238,6 +314,12 @@ class Celery:
self.loader_cls = loader or self._get_default_loader()
self.log_cls = log or self.log_cls
self.control_cls = control or self.control_cls
self._custom_task_cls_used = (
# Custom task class provided as argument
bool(task_cls)
# subclass of Celery with a task_cls attribute
or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls')
)
self.task_cls = task_cls or self.task_cls
self.set_as_current = set_as_current
self.registry_cls = symbol_by_name(self.registry_cls)
@@ -433,6 +515,7 @@ class Celery:
if shared:
def cons(app):
return app._task_from_fun(fun, **opts)
cons.__name__ = fun.__name__
connect_on_app_finalize(cons)
if not lazy or self.finalized:
@@ -461,13 +544,27 @@ class Celery:
def type_checker(self, fun, bound=False):
return staticmethod(head_from_fun(fun, bound=bound))
def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
def _task_from_fun(
self,
fun,
name=None,
base=None,
bind=False,
pydantic: bool = False,
pydantic_strict: bool = False,
pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None,
pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
**options,
):
if not self.finalized and not self.autofinalize:
raise RuntimeError('Contract breach: app not finalized')
name = name or self.gen_task_name(fun.__name__, fun.__module__)
base = base or self.Task
if name not in self._tasks:
if pydantic is True:
fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs)
run = fun if bind else staticmethod(fun)
task = type(fun.__name__, (base,), dict({
'app': self,
@@ -711,7 +808,7 @@ class Celery:
retries=0, chord=None,
reply_to=None, time_limit=None, soft_time_limit=None,
root_id=None, parent_id=None, route_name=None,
shadow=None, chain=None, task_type=None, **options):
shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options):
"""Send task by name.
Supports the same arguments as :meth:`@-Task.apply_async`.
@@ -734,13 +831,48 @@ class Celery:
ignore_result = options.pop('ignore_result', False)
options = router.route(
options, route_name or name, args, kwargs, task_type)
driver_type = self.producer_pool.connections.connection.transport.driver_type
if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]:
queue = options.get("queue")
exchange_type = queue.exchange.type if queue else options["exchange_type"]
routing_key = queue.routing_key if queue else options["routing_key"]
exchange_name = queue.exchange.name if queue else options["exchange"]
if exchange_type != 'direct':
if eta:
if isinstance(eta, str):
eta = isoparse(eta)
countdown = (maybe_make_aware(eta) - self.now()).total_seconds()
if countdown:
if countdown > 0:
routing_key = calculate_routing_key(int(countdown), routing_key)
exchange = Exchange(
'celery_delayed_27',
type='topic',
)
options.pop("queue", None)
options['routing_key'] = routing_key
options['exchange'] = exchange
else:
logger.warning(
'Direct exchanges are not supported with native delayed delivery.\n'
f'{exchange_name} is a direct exchange but should be a topic exchange or '
'a fanout exchange in order for native delayed delivery to work properly.\n'
'If quorum queues are used, this task may block the worker process until the ETA arrives.'
)
if expires is not None:
if isinstance(expires, datetime):
expires_s = (maybe_make_aware(
expires) - self.now()).total_seconds()
elif isinstance(expires, str):
expires_s = (maybe_make_aware(
datetime.fromisoformat(expires)) - self.now()).total_seconds()
isoparse(expires)) - self.now()).total_seconds()
else:
expires_s = expires
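As a rough illustration of the ETA handling added above (a sketch: `add` is a hypothetical task, a RabbitMQ broker is assumed, and quorum queues are enabled through the new task_default_queue_type setting):

app.conf.task_default_queue_type = 'quorum'

# With worker_detect_quorum_queues enabled (the default) and a non-direct exchange,
# the delayed message is published to the 'celery_delayed_27' topic exchange with a
# delay-encoded routing key instead of being held in worker memory until the ETA;
# direct exchanges fall back to the warning branch above.
result = add.apply_async((2, 2), countdown=60)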
@@ -781,7 +913,7 @@ class Celery:
self.conf.task_send_sent_event,
root_id, parent_id, shadow, chain,
ignore_result=ignore_result,
**options
replaced_task_nesting=replaced_task_nesting, **options
)
stamped_headers = options.pop('stamped_headers', [])
@@ -894,6 +1026,7 @@ class Celery:
'broker_connection_timeout', connect_timeout
),
)
broker_connection = connection
def _acquire_connection(self, pool=True):
@@ -913,6 +1046,7 @@ class Celery:
will be acquired from the connection pool.
"""
return FallbackContext(connection, self._acquire_connection, pool=pool)
default_connection = connection_or_acquire # XXX compat
def producer_or_acquire(self, producer=None):
@@ -928,6 +1062,7 @@ class Celery:
return FallbackContext(
producer, self.producer_pool.acquire, block=True,
)
default_producer = producer_or_acquire # XXX compat
def prepare_config(self, c):
@@ -936,7 +1071,7 @@ class Celery:
def now(self):
"""Return the current time and date as a datetime."""
now_in_utc = to_utc(datetime.utcnow())
now_in_utc = to_utc(datetime.now(datetime_timezone.utc))
return now_in_utc.astimezone(self.timezone)
def select_queues(self, queues=None):
@@ -974,7 +1109,14 @@ class Celery:
This is used by PendingConfiguration:
as soon as you access a key the configuration is read.
"""
conf = self._conf = self._load_config()
try:
conf = self._conf = self._load_config()
except AttributeError as err:
# AttributeError is not propagated, it is "handled" by
# PendingConfiguration parent class. This causes
# confusing RecursionError.
raise ModuleNotFoundError(*err.args) from err
return conf
def _load_config(self):

View File

@@ -360,7 +360,7 @@ class Inspect:
* ``routing_key`` - Routing key used when task was published
* ``priority`` - Priority used when task was published
* ``redelivered`` - True if the task was redelivered
* ``worker_pid`` - PID of worker processin the task
* ``worker_pid`` - PID of worker processing the task
"""
# signature used to be unary: query_task(ids=[id1, id2])
@@ -527,7 +527,8 @@ class Control:
if result:
for host in result:
for response in host.values():
task_ids.update(response['ok'])
if isinstance(response['ok'], set):
task_ids.update(response['ok'])
if task_ids:
return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs)

View File

@@ -95,6 +95,7 @@ NAMESPACES = Namespace(
heartbeat=Option(120, type='int'),
heartbeat_checkrate=Option(3.0, type='int'),
login_method=Option(None, type='string'),
native_delayed_delivery_queue_type=Option(default='quorum', type='string'),
pool_limit=Option(10, type='int'),
use_ssl=Option(False, type='bool'),
@@ -140,6 +141,12 @@ NAMESPACES = Namespace(
connection_timeout=Option(20, type='int'),
read_timeout=Option(120, type='int'),
),
gcs=Namespace(
bucket=Option(type='string'),
project=Option(type='string'),
base_path=Option('', type='string'),
ttl=Option(0, type='float'),
),
control=Namespace(
queue_ttl=Option(300.0, type='float'),
queue_expires=Option(10.0, type='float'),
@@ -243,6 +250,7 @@ NAMESPACES = Namespace(
),
table_schemas=Option(type='dict'),
table_names=Option(type='dict', old={'celery_result_db_tablenames'}),
create_tables_at_setup=Option(True, type='bool'),
),
task=Namespace(
__old__=OLD_NS,
@@ -255,6 +263,7 @@ NAMESPACES = Namespace(
inherit_parent_priority=Option(False, type='bool'),
default_delivery_mode=Option(2, type='string'),
default_queue=Option('celery'),
default_queue_type=Option('classic', type='string'),
default_exchange=Option(None, type='string'), # taken from queue
default_exchange_type=Option('direct'),
default_routing_key=Option(None, type='string'), # taken from queue
@@ -302,6 +311,8 @@ NAMESPACES = Namespace(
cancel_long_running_tasks_on_connection_loss=Option(
False, type='bool'
),
soft_shutdown_timeout=Option(0.0, type='float'),
enable_soft_shutdown_on_idle=Option(False, type='bool'),
concurrency=Option(None, type='int'),
consumer=Option('celery.worker.consumer:Consumer', type='string'),
direct=Option(False, type='bool', old={'celery_worker_direct'}),
@@ -325,6 +336,7 @@ NAMESPACES = Namespace(
pool_restarts=Option(False, type='bool'),
proc_alive_timeout=Option(4.0, type='float'),
prefetch_multiplier=Option(4, type='int'),
enable_prefetch_count_reduction=Option(True, type='bool'),
redirect_stdouts=Option(
True, type='bool', old={'celery_redirect_stdouts'},
),
@@ -338,6 +350,7 @@ NAMESPACES = Namespace(
task_log_format=Option(DEFAULT_TASK_LOG_FMT),
timer=Option(type='string'),
timer_precision=Option(1.0, type='float'),
detect_quorum_queues=Option(True, type='bool'),
),
)
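Taken together, the options added above surface as the following settings (a hedged sketch assuming an app instance named `app`; the names follow the namespace-to-setting convention and the values shown are simply the defaults from these hunks):

app.conf.update(
    task_default_queue_type='classic',                   # 'quorum' opts the default queue into RabbitMQ quorum queues
    broker_native_delayed_delivery_queue_type='quorum',
    worker_soft_shutdown_timeout=0.0,                    # seconds to wait for running tasks before a cold shutdown
    worker_enable_soft_shutdown_on_idle=False,
    worker_enable_prefetch_count_reduction=True,
    worker_detect_quorum_queues=True,
    database_create_tables_at_setup=True,
)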

View File

@@ -18,6 +18,7 @@ from celery import signals
from celery._state import get_current_task
from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
from celery.local import class_property
from celery.platforms import isatty
from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel,
reset_multiprocessing_logger)
from celery.utils.nodenames import node_format
@@ -203,7 +204,7 @@ class Logging:
if colorize or colorize is None:
# Only use color if there's no active log file
# and stderr is an actual terminal.
return logfile is None and sys.stderr.isatty()
return logfile is None and isatty(sys.stderr)
return colorize
def colored(self, logfile=None, enabled=None):

View File

@@ -20,7 +20,7 @@ except AttributeError: # pragma: no cover
# to support Python 3.7
Pattern = re.Pattern
__all__ = ('MapRoute', 'Router', 'prepare')
__all__ = ('MapRoute', 'Router', 'expand_router_string', 'prepare')
class MapRoute:

View File

@@ -104,7 +104,7 @@ class Context:
def _get_custom_headers(self, *args, **kwargs):
headers = {}
headers.update(*args, **kwargs)
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'}
celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr', 'compression'}
for key in celery_keys:
headers.pop(key, None)
if not headers:
@@ -466,7 +466,7 @@ class Task:
shadow (str): Override task name used in logs/monitoring.
Default is retrieved from :meth:`shadow_name`.
connection (kombu.Connection): Re-use existing broker connection
connection (kombu.Connection): Reuse existing broker connection
instead of acquiring one from the connection pool.
retry (bool): If enabled sending of the task message will be
@@ -535,6 +535,8 @@ class Task:
publisher (kombu.Producer): Deprecated alias to ``producer``.
headers (Dict): Message headers to be included in the message.
The headers can be used as an overlay for custom labeling
using the :ref:`canvas-stamping` feature.
Returns:
celery.result.AsyncResult: Promise of future evaluation.
@@ -543,6 +545,8 @@ class Task:
TypeError: If not enough arguments are passed, or too many
arguments are passed. Note that signature checks may
be disabled by specifying ``@task(typing=False)``.
ValueError: If both soft_time_limit and time_limit are set
but soft_time_limit is greater than time_limit.
kombu.exceptions.OperationalError: If a connection to the
transport cannot be made, or if the connection is lost.
@@ -550,6 +554,9 @@ class Task:
Also supports all keyword arguments supported by
:meth:`kombu.Producer.publish`.
"""
if self.soft_time_limit and self.time_limit and self.soft_time_limit > self.time_limit:
raise ValueError('soft_time_limit must be less than or equal to time_limit')
if self.typing:
try:
check_arguments = self.__header__
@@ -788,6 +795,7 @@ class Task:
request = {
'id': task_id,
'task': self.name,
'retries': retries,
'is_eager': True,
'logfile': logfile,
@@ -824,7 +832,7 @@ class Task:
if isinstance(retval, Retry) and retval.sig is not None:
return retval.sig.apply(retries=retries + 1)
state = states.SUCCESS if ret.info is None else ret.info.state
return EagerResult(task_id, retval, state, traceback=tb)
return EagerResult(task_id, retval, state, traceback=tb, name=self.name)
def AsyncResult(self, task_id, **kwargs):
"""Get AsyncResult instance for the specified task.
@@ -954,11 +962,20 @@ class Task:
root_id=self.request.root_id,
replaced_task_nesting=replaced_task_nesting
)
# If the replaced task is a chain, we want to set all of the chain tasks
# with the same replaced_task_nesting value to mark their replacement nesting level
if isinstance(sig, _chain):
for chain_task in maybe_list(sig.tasks) or []:
chain_task.set(replaced_task_nesting=replaced_task_nesting)
# If the task being replaced is part of a chain, we need to re-create
# it with the replacement signature - these subsequent tasks will
# retain their original task IDs as well
for t in reversed(self.request.chain or []):
sig |= signature(t, app=self.app)
chain_task = signature(t, app=self.app)
chain_task.set(replaced_task_nesting=replaced_task_nesting)
sig |= chain_task
return self.on_replace(sig)
def add_to_chord(self, sig, lazy=False):
@@ -1099,7 +1116,7 @@ class Task:
return result
def push_request(self, *args, **kwargs):
self.request_stack.push(Context(*args, **kwargs))
self.request_stack.push(Context(*args, **{**self.request.__dict__, **kwargs}))
def pop_request(self):
self.request_stack.pop()

View File

@@ -8,7 +8,6 @@ import os
import sys
import time
from collections import namedtuple
from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union
from warnings import warn
from billiard.einfo import ExceptionInfo, ExceptionWithTraceback
@@ -17,8 +16,6 @@ from kombu.serialization import loads as loads_message
from kombu.serialization import prepare_accept_content
from kombu.utils.encoding import safe_repr, safe_str
import celery
import celery.loaders.app
from celery import current_app, group, signals, states
from celery._state import _task_stack
from celery.app.task import Context
@@ -294,20 +291,10 @@ def traceback_clear(exc=None):
tb = tb.tb_next
def build_tracer(
name: str,
task: Union[celery.Task, celery.local.PromiseProxy],
loader: Optional[celery.loaders.app.AppLoader] = None,
hostname: Optional[str] = None,
store_errors: bool = True,
Info: Type[TraceInfo] = TraceInfo,
eager: bool = False,
propagate: bool = False,
app: Optional[celery.Celery] = None,
monotonic: Callable[[], int] = time.monotonic,
trace_ok_t: Type[trace_ok_t] = trace_ok_t,
IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \
Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]:
def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
Info=TraceInfo, eager=False, propagate=False, app=None,
monotonic=time.monotonic, trace_ok_t=trace_ok_t,
IGNORE_STATES=IGNORE_STATES):
"""Return a function that traces task execution.
Catches all exceptions and updates result backend with the
@@ -387,12 +374,7 @@ def build_tracer(
from celery import canvas
signature = canvas.maybe_signature # maybe_ does not clone if already
def on_error(
request: celery.app.task.Context,
exc: Union[Exception, Type[Exception]],
state: str = FAILURE,
call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]:
"""Handle any errors raised by a `Task`'s execution."""
def on_error(request, exc, state=FAILURE, call_errbacks=True):
if propagate:
raise
I = Info(state, exc)
@@ -401,13 +383,7 @@ def build_tracer(
)
return I, R, I.state, I.retval
def trace_task(
uuid: str,
args: Sequence[Any],
kwargs: Dict[str, Any],
request: Optional[Dict[str, Any]] = None) -> trace_ok_t:
"""Execute and trace a `Task`."""
def trace_task(uuid, args, kwargs, request=None):
# R - is the possibly prepared return value.
# I - is the Info object.
# T - runtime

View File

@@ -35,7 +35,7 @@ settings -> transport:{transport} results:{results}
"""
HIDDEN_SETTINGS = re.compile(
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE',
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE|BEAT_DBURI',
re.IGNORECASE,
)

View File

@@ -20,7 +20,7 @@ from kombu.utils.encoding import safe_str
from celery import VERSION_BANNER, platforms, signals
from celery.app import trace
from celery.loaders.app import AppLoader
from celery.platforms import EX_FAILURE, EX_OK, check_privileges
from celery.platforms import EX_FAILURE, EX_OK, check_privileges, isatty
from celery.utils import static, term
from celery.utils.debug import cry
from celery.utils.imports import qualname
@@ -77,8 +77,9 @@ def active_thread_count():
if not t.name.startswith('Dummy-'))
def safe_say(msg):
print(f'\n{msg}', file=sys.__stderr__, flush=True)
def safe_say(msg, f=sys.__stderr__):
if hasattr(f, 'fileno') and f.fileno() is not None:
os.write(f.fileno(), f'\n{msg}\n'.encode())
class Worker(WorkController):
@@ -106,7 +107,7 @@ class Worker(WorkController):
super().setup_defaults(**kwargs)
self.purge = purge
self.no_color = no_color
self._isatty = sys.stdout.isatty()
self._isatty = isatty(sys.stdout)
self.colored = self.app.log.colored(
self.logfile,
enabled=not no_color if no_color is not None else no_color
@@ -278,15 +279,27 @@ class Worker(WorkController):
)
def _shutdown_handler(worker, sig='TERM', how='Warm',
callback=None, exitcode=EX_OK):
def _shutdown_handler(worker: Worker, sig='SIGTERM', how='Warm', callback=None, exitcode=EX_OK, verbose=True):
"""Install signal handler for warm/cold shutdown.
The handler will run from the MainProcess.
Args:
worker (Worker): The worker that received the signal.
sig (str, optional): The signal that was received. Defaults to 'SIGTERM'.
how (str, optional): The type of shutdown to perform. Defaults to 'Warm'.
callback (Callable, optional): Signal handler. Defaults to None.
exitcode (int, optional): The exit code to use. Defaults to EX_OK.
verbose (bool, optional): Whether to print the type of shutdown. Defaults to True.
"""
def _handle_request(*args):
with in_sighandler():
from celery.worker import state
if current_process()._name == 'MainProcess':
if callback:
callback(worker)
safe_say(f'worker: {how} shutdown (MainProcess)')
if verbose:
safe_say(f'worker: {how} shutdown (MainProcess)', sys.__stdout__)
signals.worker_shutting_down.send(
sender=worker.hostname, sig=sig, how=how,
exitcode=exitcode,
@@ -297,19 +310,126 @@ def _shutdown_handler(worker, sig='TERM', how='Warm',
platforms.signals[sig] = _handle_request
def on_hard_shutdown(worker: Worker):
"""Signal handler for hard shutdown.
The handler will terminate the worker immediately by force using the exit code ``EX_FAILURE``.
In practice, you should never get here, as the standard shutdown process should be enough.
This handler is only for the worst-case scenario, where the worker is stuck and cannot be
terminated gracefully (e.g., spamming the Ctrl+C in the terminal to force the worker to terminate).
Args:
worker (Worker): The worker that received the signal.
Raises:
WorkerTerminate: This exception will be raised in the MainProcess to terminate the worker immediately.
"""
from celery.exceptions import WorkerTerminate
raise WorkerTerminate(EX_FAILURE)
def during_soft_shutdown(worker: Worker):
"""This signal handler is called when the worker is in the middle of the soft shutdown process.
When the worker is in the soft shutdown process, it is waiting for tasks to finish. If the worker
receives a SIGINT (Ctrl+C) or SIGQUIT signal (or possibly SIGTERM if REMAP_SIGTERM is set to "SIGQUIT"),
the handler will cancel all unacked requests to allow the worker to terminate gracefully and replace the
signal handler for SIGINT and SIGQUIT with the hard shutdown handler ``on_hard_shutdown`` to terminate
the worker immediately by force next time the signal is received.
It will give the worker one last chance to gracefully terminate (the cold shutdown), after canceling all
unacked requests, before using the hard shutdown handler to terminate the worker forcefully.
Args:
worker (Worker): The worker that received the signal.
"""
# Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM)
# with the hard shutdown handler to terminate the worker immediately by force
install_worker_term_hard_handler(worker, sig='SIGINT', callback=on_hard_shutdown, verbose=False)
install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=on_hard_shutdown)
# Cancel all unacked requests and allow the worker to terminate naturally
worker.consumer.cancel_all_unacked_requests()
# We get here if the worker was in the middle of the soft (cold) shutdown process,
# and the matching signal was received. This can typically happen when the worker is
# waiting for tasks to finish, and the user decides to still cancel the running tasks.
# We give the worker the last chance to gracefully terminate by letting the soft shutdown
# waiting time to finish, which is running in the MainProcess from the previous signal handler call.
safe_say('Waiting gracefully for cold shutdown to complete...', sys.__stdout__)
def on_cold_shutdown(worker: Worker):
"""Signal handler for cold shutdown.
Registered for SIGQUIT and SIGINT (Ctrl+C) signals. If REMAP_SIGTERM is set to "SIGQUIT", this handler will also
be registered for SIGTERM.
This handler will initiate the cold (and soft if enabled) shutdown procedure for the worker.
Worker running with N tasks:
- SIGTERM:
- The worker will initiate the warm shutdown process until all tasks are finished. Additional
SIGTERM signals will be ignored. SIGQUIT will transition to the cold shutdown process described below.
- SIGQUIT:
- The worker will initiate the cold shutdown process.
- If the soft shutdown is enabled, the worker will wait for the tasks to finish up to the soft
shutdown timeout (practically having a limited warm shutdown just before the cold shutdown).
- Cancel all tasks (from the MainProcess) and allow the worker to complete the cold shutdown
process gracefully.
Caveats:
- SIGINT (Ctrl+C) signal is defined to replace itself with the cold shutdown (SIGQUIT) after first use,
and to emit a message to the user to hit Ctrl+C again to initiate the cold shutdown process. But, most
important, it will also be caught in WorkController.start() to initiate the warm shutdown process.
- SIGTERM will also be handled in WorkController.start() to initiate the warm shutdown process (the same).
- If REMAP_SIGTERM is set to "SIGQUIT", the SIGTERM signal will be remapped to SIGQUIT, and the cold
shutdown process will be initiated instead of the warm shutdown process using SIGTERM.
- If SIGQUIT is received (also via SIGINT) during the cold/soft shutdown process, the handler will cancel all
unacked requests but still wait for the soft shutdown process to finish before terminating the worker
gracefully. The next time the signal is received though, the worker will terminate immediately by force.
So, the purpose of this handler is to allow waiting for the soft shutdown timeout, then cancel all tasks from
the MainProcess and let WorkController.terminate() terminate the worker naturally. If the soft shutdown
is disabled, it will immediately cancel all tasks and let the cold shutdown finish normally.
Args:
worker (Worker): The worker that received the signal.
"""
safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!', sys.__stdout__)
# Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM)
install_worker_term_hard_handler(worker, sig='SIGINT', callback=during_soft_shutdown)
install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=during_soft_shutdown)
if REMAP_SIGTERM == "SIGQUIT":
install_worker_term_hard_handler(worker, sig='SIGTERM', callback=during_soft_shutdown)
# else, SIGTERM will print the _shutdown_handler's message and do nothing every time it is received.
# Initiate soft shutdown process (if enabled and tasks are running)
worker.wait_for_soft_shutdown()
# Cancel all unacked requests and allow the worker to terminate naturally
worker.consumer.cancel_all_unacked_requests()
# Stop the pool to allow successful tasks call on_success()
worker.consumer.pool.stop()
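A small sketch of how the soft shutdown described in the docstring above is switched on (setting names follow the worker namespace options added in this commit; REMAP_SIGTERM is read from the environment elsewhere in this module):

# Give running tasks up to 10 seconds to finish once a cold shutdown (SIGQUIT or a
# second Ctrl+C) is requested, instead of cancelling them immediately.
app.conf.worker_soft_shutdown_timeout = 10.0
# Run the soft shutdown window even when no tasks are active (optional).
app.conf.worker_enable_soft_shutdown_on_idle = False
# Shell side: `REMAP_SIGTERM=SIGQUIT celery -A proj worker` makes plain SIGTERM take
# the cold shutdown path instead of the warm one.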
# Allow SIGTERM to be remapped to SIGQUIT to initiate cold shutdown instead of warm shutdown using SIGTERM
if REMAP_SIGTERM == "SIGQUIT":
install_worker_term_handler = partial(
_shutdown_handler, sig='SIGTERM', how='Cold', exitcode=EX_FAILURE,
_shutdown_handler, sig='SIGTERM', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE,
)
else:
install_worker_term_handler = partial(
_shutdown_handler, sig='SIGTERM', how='Warm',
)
if not is_jython: # pragma: no cover
install_worker_term_hard_handler = partial(
_shutdown_handler, sig='SIGQUIT', how='Cold',
exitcode=EX_FAILURE,
_shutdown_handler, sig='SIGQUIT', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE,
)
else: # pragma: no cover
install_worker_term_handler = \
@@ -317,8 +437,9 @@ else: # pragma: no cover
def on_SIGINT(worker):
safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!')
install_worker_term_hard_handler(worker, sig='SIGINT')
safe_say('worker: Hitting Ctrl+C again will initiate cold shutdown, terminating all running tasks!',
sys.__stdout__)
install_worker_term_hard_handler(worker, sig='SIGINT', verbose=False)
if not is_jython: # pragma: no cover
@@ -343,7 +464,8 @@ def install_worker_restart_handler(worker, sig='SIGHUP'):
def restart_worker_sig_handler(*args):
"""Signal handler restarting the current python program."""
set_in_sighandler(True)
safe_say(f"Restarting celery worker ({' '.join(sys.argv)})")
safe_say(f"Restarting celery worker ({' '.join(sys.argv)})",
sys.__stdout__)
import atexit
atexit.register(_reload_current_worker)
from celery.worker import state

View File

@@ -1,4 +1,5 @@
"""The Azure Storage Block Blob backend for Celery."""
from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
@@ -28,6 +29,13 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
container_name=None,
*args,
**kwargs):
"""
Supported URL formats:
azureblockblob://CONNECTION_STRING
azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL
azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL
"""
super().__init__(*args, **kwargs)
if azurestorage is None or azurestorage.__version__ < '12':
@@ -65,11 +73,26 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
the container is created if it doesn't yet exist.
"""
client = BlobServiceClient.from_connection_string(
self._connection_string,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout
)
if (
"DefaultAzureCredential" in self._connection_string or
"ManagedIdentityCredential" in self._connection_string
):
# Leveraging the work that Kombu already did for us
credential_, url = AzureStorageQueuesTransport.parse_uri(
self._connection_string
)
client = BlobServiceClient(
account_url=url,
credential=credential_,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout,
)
else:
client = BlobServiceClient.from_connection_string(
self._connection_string,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout,
)
try:
client.create_container(name=self._container_name)
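For context, a hedged example of the credential-based URL format handled above (the account name is a placeholder, and the exact account URL form follows the docstring at the top of this class rather than verified documentation):

# Kombu's AzureStorageQueuesTransport.parse_uri() splits the credential name off the
# front; the remainder is used as the BlobServiceClient account_url.
app.conf.result_backend = (
    'azureblockblob://DefaultAzureCredential@https://myaccount.blob.core.windows.net'
)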

View File

@@ -9,7 +9,7 @@ import sys
import time
import warnings
from collections import namedtuple
from datetime import datetime, timedelta
from datetime import timedelta
from functools import partial
from weakref import WeakValueDictionary
@@ -460,7 +460,7 @@ class Backend:
state, traceback, request, format_date=True,
encode=False):
if state in self.READY_STATES:
date_done = datetime.utcnow()
date_done = self.app.now()
if format_date:
date_done = date_done.isoformat()
else:
@@ -833,9 +833,11 @@ class BaseKeyValueStoreBackend(Backend):
"""
global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None)
if global_keyprefix:
self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}"
self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}"
self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}"
if global_keyprefix[-1] not in ':_-.':
global_keyprefix += '_'
self.task_keyprefix = f"{global_keyprefix}{self.task_keyprefix}"
self.group_keyprefix = f"{global_keyprefix}{self.group_keyprefix}"
self.chord_keyprefix = f"{global_keyprefix}{self.chord_keyprefix}"
def _encode_prefixes(self):
self.task_keyprefix = self.key_t(self.task_keyprefix)
@@ -1080,7 +1082,7 @@ class BaseKeyValueStoreBackend(Backend):
)
finally:
deps.delete()
self.client.delete(key)
self.delete(key)
else:
self.expire(key, self.expires)
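A short sketch of the global_keyprefix handling changed above (the prefix value is illustrative):

app.conf.result_backend_transport_options = {'global_keyprefix': 'myapp'}
# 'myapp' has no trailing separator, so '_' is appended and task keys become
# 'myapp_celery-task-meta-<task_id>'; a prefix already ending in ':', '_', '-' or '.'
# is now used as-is instead of gaining a second separator.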

View File

@@ -86,7 +86,7 @@ class CassandraBackend(BaseBackend):
supports_autoexpire = True # autoexpire supported via entry_ttl
def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
port=9042, bundle_path=None, **kwargs):
port=None, bundle_path=None, **kwargs):
super().__init__(**kwargs)
if not cassandra:
@@ -96,7 +96,7 @@ class CassandraBackend(BaseBackend):
self.servers = servers or conf.get('cassandra_servers', None)
self.bundle_path = bundle_path or conf.get(
'cassandra_secure_bundle_path', None)
self.port = port or conf.get('cassandra_port', None)
self.port = port or conf.get('cassandra_port', None) or 9042
self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
self.table = table or conf.get('cassandra_table', None)
self.cassandra_options = conf.get('cassandra_options', {})

View File

@@ -98,11 +98,23 @@ class DatabaseBackend(BaseBackend):
'Missing connection string! Do you have the'
' database_url setting set to a real value?')
self.session_manager = SessionManager()
create_tables_at_setup = conf.database_create_tables_at_setup
if create_tables_at_setup is True:
self._create_tables()
@property
def extended_result(self):
return self.app.conf.find_value_for_key('extended', 'result')
def ResultSession(self, session_manager=SessionManager()):
def _create_tables(self):
"""Create the task and taskset tables."""
self.ResultSession()
def ResultSession(self, session_manager=None):
if session_manager is None:
session_manager = self.session_manager
return session_manager.session_factory(
dburi=self.url,
short_lived_sessions=self.short_lived_sessions,

View File

@@ -1,5 +1,5 @@
"""Database models used by the SQLAlchemy result store backend."""
from datetime import datetime
from datetime import datetime, timezone
import sqlalchemy as sa
from sqlalchemy.types import PickleType
@@ -22,8 +22,8 @@ class Task(ResultModelBase):
task_id = sa.Column(sa.String(155), unique=True)
status = sa.Column(sa.String(50), default=states.PENDING)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
onupdate=datetime.utcnow, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
onupdate=datetime.now(timezone.utc), nullable=True)
traceback = sa.Column(sa.Text, nullable=True)
def __init__(self, task_id):
@@ -84,7 +84,7 @@ class TaskSet(ResultModelBase):
autoincrement=True, primary_key=True)
taskset_id = sa.Column(sa.String(155), unique=True)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
nullable=True)
def __init__(self, taskset_id, result):

View File

@@ -1,6 +1,8 @@
"""AWS DynamoDB result store backend."""
from collections import namedtuple
from ipaddress import ip_address
from time import sleep, time
from typing import Any, Dict
from kombu.utils.url import _parse_url as parse_url
@@ -54,11 +56,15 @@ class DynamoDBBackend(KeyValueStoreBackend):
supports_autoexpire = True
_key_field = DynamoDBAttribute(name='id', data_type='S')
# Each record has either a value field or count field
_value_field = DynamoDBAttribute(name='result', data_type='B')
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
_available_fields = None
implements_incr = True
def __init__(self, url=None, table_name=None, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -91,9 +97,9 @@ class DynamoDBBackend(KeyValueStoreBackend):
aws_credentials_given = access_key_given
if region == 'localhost':
if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
# We are using the downloadable, local version of DynamoDB
self.endpoint_url = f'http://localhost:{port}'
self.endpoint_url = f'http://{region}:{port}'
self.aws_region = 'us-east-1'
logger.warning(
'Using local-only DynamoDB endpoint URL: {}'.format(
@@ -148,6 +154,14 @@ class DynamoDBBackend(KeyValueStoreBackend):
secret_access_key=aws_secret_access_key
)
@staticmethod
def _is_valid_ip(ip):
try:
ip_address(ip)
return True
except ValueError:
return False
def _get_client(self, access_key_id=None, secret_access_key=None):
"""Get client connection."""
if self._client is None:
@@ -459,6 +473,40 @@ class DynamoDBBackend(KeyValueStoreBackend):
})
return put_request
def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter initialization request parameters"""
timestamp = time()
return {
'TableName': self.table_name,
'Item': {
self._key_field.name: {
self._key_field.data_type: key
},
self._count_filed.name: {
self._count_filed.data_type: "0"
},
self._timestamp_field.name: {
self._timestamp_field.data_type: str(timestamp)
}
}
}
def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter increment request parameters"""
return {
'TableName': self.table_name,
'Key': {
self._key_field.name: {
self._key_field.data_type: key
}
},
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
"ExpressionAttributeValues": {
":num": {"N": "1"},
},
"ReturnValues": "UPDATED_NEW",
}
def _item_to_dict(self, raw_response):
"""Convert get_item() response to field-value pairs."""
if 'Item' not in raw_response:
@@ -491,3 +539,18 @@ class DynamoDBBackend(KeyValueStoreBackend):
key = str(key)
request_parameters = self._prepare_get_request(key)
self.client.delete_item(**request_parameters)
def incr(self, key: bytes) -> int:
"""Atomically increase the chord_count and return the new count"""
key = str(key)
request_parameters = self._prepare_inc_count_request(key)
item_response = self.client.update_item(**request_parameters)
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
return int(new_count)
def _apply_chord_incr(self, header_result_args, body, **kwargs):
chord_key = self.get_key_for_chord(header_result_args[0])
init_count_request = self._prepare_init_count_request(str(chord_key))
self.client.put_item(**init_count_request)
return super()._apply_chord_incr(
header_result_args, body, **kwargs)

View File

@@ -1,5 +1,5 @@
"""Elasticsearch result store backend."""
from datetime import datetime
from datetime import datetime, timezone
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url
@@ -14,6 +14,11 @@ try:
except ImportError:
elasticsearch = None
try:
import elastic_transport
except ImportError:
elastic_transport = None
__all__ = ('ElasticsearchBackend',)
E_LIB_MISSING = """\
@@ -31,7 +36,7 @@ class ElasticsearchBackend(KeyValueStoreBackend):
"""
index = 'celery'
doc_type = 'backend'
doc_type = None
scheme = 'http'
host = 'localhost'
port = 9200
@@ -83,17 +88,17 @@ class ElasticsearchBackend(KeyValueStoreBackend):
self._server = None
def exception_safe_to_retry(self, exc):
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
if isinstance(exc, elasticsearch.exceptions.ApiError):
# 401: Unauthorized
# 409: Conflict
# 429: Too Many Requests
# 500: Internal Server Error
# 502: Bad Gateway
# 503: Service Unavailable
# 504: Gateway Timeout
# N/A: Low level exception (i.e. socket exception)
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}:
return True
if isinstance(exc, elasticsearch.exceptions.TransportError):
return True
return False
def get(self, key):
@@ -108,17 +113,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
pass
def _get(self, key):
return self.server.get(
index=self.index,
doc_type=self.doc_type,
id=key,
)
if self.doc_type:
return self.server.get(
index=self.index,
id=key,
doc_type=self.doc_type,
)
else:
return self.server.get(
index=self.index,
id=key,
)
def _set_with_state(self, key, value, state):
body = {
'result': value,
'@timestamp': '{}Z'.format(
datetime.utcnow().isoformat()[:-3]
datetime.now(timezone.utc).isoformat()[:-9]
),
}
try:
@@ -135,14 +146,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
def _index(self, id, body, **kwargs):
body = {bytes_to_str(k): v for k, v in body.items()}
return self.server.index(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body=body,
params={'op_type': 'create'},
**kwargs
)
if self.doc_type:
return self.server.index(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body=body,
params={'op_type': 'create'},
**kwargs
)
else:
return self.server.index(
id=bytes_to_str(id),
index=self.index,
body=body,
params={'op_type': 'create'},
**kwargs
)
def _update(self, id, body, state, **kwargs):
"""Update state in a conflict free manner.
@@ -182,19 +202,32 @@ class ElasticsearchBackend(KeyValueStoreBackend):
prim_term = res_get.get('_primary_term', 1)
# try to update document with current seq_no and primary_term
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
if self.doc_type:
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
else:
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
# result is elastic search update query result
# noop = query did not update any document
# updated = at least one document got updated
if res['result'] == 'noop':
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
raise elasticsearch.exceptions.ConflictError(
"conflicting update occurred concurrently",
elastic_transport.ApiResponseMeta(409, "HTTP/1.1",
elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig(
self.scheme, self.host, self.port)), None)
return res
def encode(self, data):
@@ -225,7 +258,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
return [self.get(key) for key in keys]
def delete(self, key):
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
if self.doc_type:
self.server.delete(index=self.index, id=key, doc_type=self.doc_type)
else:
self.server.delete(index=self.index, id=key)
def _get_server(self):
"""Connect to the Elasticsearch server."""
@@ -233,11 +269,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
if self.username and self.password:
http_auth = (self.username, self.password)
return elasticsearch.Elasticsearch(
f'{self.host}:{self.port}',
f'{self.scheme}://{self.host}:{self.port}',
retry_on_timeout=self.es_retry_on_timeout,
max_retries=self.es_max_retries,
timeout=self.es_timeout,
scheme=self.scheme,
http_auth=http_auth,
)

View File

@@ -50,7 +50,7 @@ class FilesystemBackend(KeyValueStoreBackend):
self.open = open
self.unlink = unlink
# Lets verify that we've everything setup right
# Let's verify that we've everything setup right
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
def __reduce__(self, args=(), kwargs=None):

View File

@@ -0,0 +1,352 @@
"""Google Cloud Storage result store backend for Celery."""
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from os import getpid
from threading import RLock
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import dictfilter
from kombu.utils.url import url_to_parts
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.result import GroupResult, allow_join_result
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
import requests
from google.api_core import retry
from google.api_core.exceptions import Conflict
from google.api_core.retry import if_exception_type
from google.cloud import storage
from google.cloud.storage import Client
from google.cloud.storage.retry import DEFAULT_RETRY
except ImportError:
storage = None
try:
from google.cloud import firestore, firestore_admin_v1
except ImportError:
firestore = None
firestore_admin_v1 = None
__all__ = ('GCSBackend',)
logger = get_logger(__name__)
class GCSBackendBase(KeyValueStoreBackend):
"""Google Cloud Storage task result backend."""
def __init__(self, **kwargs):
if not storage:
raise ImproperlyConfigured(
'You must install google-cloud-storage to use gcs backend'
)
super().__init__(**kwargs)
self._client_lock = RLock()
self._pid = getpid()
self._retry_policy = DEFAULT_RETRY
self._client = None
conf = self.app.conf
if self.url:
url_params = self._params_from_url()
conf.update(**dictfilter(url_params))
self.bucket_name = conf.get('gcs_bucket')
if not self.bucket_name:
raise ImproperlyConfigured(
'Missing bucket name: specify gcs_bucket to use gcs backend'
)
self.project = conf.get('gcs_project')
if not self.project:
raise ImproperlyConfigured(
'Missing project: specify gcs_project to use gcs backend'
)
self.base_path = conf.get('gcs_base_path', '').strip('/')
self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
self.ttl = float(conf.get('gcs_ttl') or 0)
if self.ttl < 0:
raise ImproperlyConfigured(
f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
)
elif self.ttl:
if not self._is_bucket_lifecycle_rule_exists():
raise ImproperlyConfigured(
f'Missing lifecycle rule to use gcs backend with ttl on '
f'bucket: {self.bucket_name}'
)
def get(self, key):
key = bytes_to_str(key)
blob = self._get_blob(key)
try:
return blob.download_as_bytes(retry=self._retry_policy)
except storage.blob.NotFound:
return None
def set(self, key, value):
key = bytes_to_str(key)
blob = self._get_blob(key)
if self.ttl:
blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl)
blob.upload_from_string(value, retry=self._retry_policy)
def delete(self, key):
key = bytes_to_str(key)
blob = self._get_blob(key)
if blob.exists():
blob.delete(retry=self._retry_policy)
def mget(self, keys):
with ThreadPoolExecutor() as pool:
return list(pool.map(self.get, keys))
@property
def client(self):
"""Returns a storage client."""
# make sure it's thread-safe, as creating a new client is expensive
with self._client_lock:
if self._client and self._pid == getpid():
return self._client
# make sure each process gets its own connection after a fork
self._client = Client(project=self.project)
self._pid = getpid()
# config the number of connections to the server
adapter = requests.adapters.HTTPAdapter(
pool_connections=self._threadpool_maxsize,
pool_maxsize=self._threadpool_maxsize,
max_retries=3,
)
client_http = self._client._http
client_http.mount("https://", adapter)
client_http._auth_request.session.mount("https://", adapter)
return self._client
@property
def bucket(self):
return self.client.bucket(self.bucket_name)
def _get_blob(self, key):
key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
return self.bucket.blob(key_bucket_path)
def _is_bucket_lifecycle_rule_exists(self):
bucket = self.bucket
bucket.reload()
for rule in bucket.lifecycle_rules:
if rule['action']['type'] == 'Delete':
return True
return False
def _params_from_url(self):
url_parts = url_to_parts(self.url)
return {
'gcs_bucket': url_parts.hostname,
'gcs_base_path': url_parts.path,
**url_parts.query,
}
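A hedged configuration sketch for the backend above (bucket, project and prefix are placeholders), matching the URL unpacking in _params_from_url:

# Bucket from the host part, base path from the URL path, remaining options as query
# parameters; internally these populate the gcs_bucket, gcs_project and gcs_base_path settings.
app.conf.result_backend = 'gs://my-results-bucket/results?gcs_project=my-gcp-project'
# A non-zero gcs_ttl requires a matching Delete lifecycle rule on the bucket (checked at init above).
app.conf.gcs_ttl = 0.0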
class GCSBackend(GCSBackendBase):
"""Google Cloud Storage task result backend.
Uses Firestore for chord ref count.
"""
implements_incr = True
supports_native_join = True
# Firestore parameters
_collection_name = 'celery'
_field_count = 'chord_count'
_field_expires = 'expires_at'
def __init__(self, **kwargs):
if not (firestore and firestore_admin_v1):
raise ImproperlyConfigured(
'You must install google-cloud-firestore to use gcs backend'
)
super().__init__(**kwargs)
self._firestore_lock = RLock()
self._firestore_client = None
self.firestore_project = self.app.conf.get(
'firestore_project', self.project
)
if not self._is_firestore_ttl_policy_enabled():
raise ImproperlyConfigured(
f'Missing TTL policy to use gcs backend with ttl on '
f'Firestore collection: {self._collection_name} '
f'project: {self.firestore_project}'
)
@property
def firestore_client(self):
"""Returns a firestore client."""
# make sure it's thread-safe, as creating a new client is expensive
with self._firestore_lock:
if self._firestore_client and self._pid == getpid():
return self._firestore_client
# make sure each process gets its own connection after a fork
self._firestore_client = firestore.Client(
project=self.firestore_project
)
self._pid = getpid()
return self._firestore_client
def _is_firestore_ttl_policy_enabled(self):
client = firestore_admin_v1.FirestoreAdminClient()
name = (
f"projects/{self.firestore_project}"
f"/databases/(default)/collectionGroups/{self._collection_name}"
f"/fields/{self._field_expires}"
)
request = firestore_admin_v1.GetFieldRequest(name=name)
field = client.get_field(request=request)
ttl_config = field.ttl_config
return ttl_config and ttl_config.state in {
firestore_admin_v1.Field.TtlConfig.State.ACTIVE,
firestore_admin_v1.Field.TtlConfig.State.CREATING,
}
def _apply_chord_incr(self, header_result_args, body, **kwargs):
key = self.get_key_for_chord(header_result_args[0]).decode()
self._expire_chord_key(key, 86400)
return super()._apply_chord_incr(header_result_args, body, **kwargs)
def incr(self, key: bytes) -> int:
doc = self._firestore_document(key)
resp = doc.set(
{self._field_count: firestore.Increment(1)},
merge=True,
retry=retry.Retry(
predicate=if_exception_type(Conflict),
initial=1.0,
maximum=180.0,
multiplier=2.0,
timeout=180.0,
),
)
return resp.transform_results[0].integer_value
def on_chord_part_return(self, request, state, result, **kwargs):
"""Chord part return callback.
Called for each task in the chord.
Increments the counter stored in Firestore.
If the counter reaches the number of tasks in the chord, the callback
is called.
If the callback raises an exception, the chord is marked as errored.
If the callback returns a value, the chord is marked as successful.
"""
app = self.app
gid = request.group
if not gid:
return
key = self.get_key_for_chord(gid)
val = self.incr(key)
size = request.chord.get("chord_size")
if size is None:
deps = self._restore_deps(gid, request)
if deps is None:
return
size = len(deps)
if val > size: # pragma: no cover
logger.warning(
'Chord counter incremented too many times for %r', gid
)
elif val == size:
# Read the deps once, to reduce the number of reads from GCS ($$)
deps = self._restore_deps(gid, request)
if deps is None:
return
callback = maybe_signature(request.chord, app=app)
j = deps.join_native
try:
with allow_join_result():
ret = j(
timeout=app.conf.result_chord_join_timeout,
propagate=True,
)
except Exception as exc: # pylint: disable=broad-except
try:
culprit = next(deps._failed_join_report())
reason = 'Dependency {0.id} raised {1!r}'.format(
culprit,
exc,
)
except StopIteration:
reason = repr(exc)
logger.exception('Chord %r raised: %r', gid, reason)
self.chord_error_from_stack(callback, ChordError(reason))
else:
try:
callback.delay(ret)
except Exception as exc: # pylint: disable=broad-except
logger.exception('Chord %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'Callback error: {exc!r}'),
)
finally:
deps.delete()
# Firestore doesn't have an exact ttl policy, so delete the key.
self._delete_chord_key(key)
def _restore_deps(self, gid, request):
app = self.app
try:
deps = GroupResult.restore(gid, backend=self)
except Exception as exc: # pylint: disable=broad-except
callback = maybe_signature(request.chord, app=app)
logger.exception('Chord %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'Cannot restore group: {exc!r}'),
)
return
if deps is None:
try:
raise ValueError(gid)
except ValueError as exc:
callback = maybe_signature(request.chord, app=app)
logger.exception('Chord callback %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'GroupResult {gid} no longer exists'),
)
return deps
def _delete_chord_key(self, key):
doc = self._firestore_document(key)
doc.delete()
def _expire_chord_key(self, key, expires):
"""Set TTL policy for a Firestore document.
Firestore ttl data is typically deleted within 24 hours after its
expiration date.
"""
val_expires = datetime.utcnow() + timedelta(seconds=expires)
doc = self._firestore_document(key)
doc.set({self._field_expires: val_expires}, merge=True)
def _firestore_document(self, key):
return self.firestore_client.collection(
self._collection_name
).document(bytes_to_str(key))
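For orientation, the chord flow described in on_chord_part_return above is exercised by ordinary canvas code such as the following (a sketch with hypothetical tasks `add` and `tsum`); each header task return bumps the Firestore counter, and the callback fires once the count reaches the chord size:

from celery import chord

# Ten header tasks; tsum receives the list of their results once all have finished.
result = chord(add.s(i, i) for i in range(10))(tsum.s())
result.get()  # -> 90 with the hypothetical add/tsum implementations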

View File

@@ -1,5 +1,5 @@
"""MongoDB result store backend."""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from kombu.exceptions import EncodeError
from kombu.utils.objects import cached_property
@@ -228,7 +228,7 @@ class MongoBackend(BaseBackend):
meta = {
'_id': group_id,
'result': self.encode([i.id for i in result]),
'date_done': datetime.utcnow(),
'date_done': datetime.now(timezone.utc),
}
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
return result

View File

@@ -359,6 +359,11 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
connparams.update(query)
return connparams
def exception_safe_to_retry(self, exc):
if isinstance(exc, self.connection_errors):
return True
return False
@cached_property
def retry_policy(self):
retry_policy = super().retry_policy

View File

@@ -222,7 +222,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
def on_out_of_band_result(self, task_id, message):
# Callback called when a reply for a task is received,
# but we have no idea what do do with it.
# but we have no idea what to do with it.
# Since the result is not pending, we put it in a separate
# buffer: probably it will become pending later.
if self.result_consumer:

View File

@@ -1,6 +1,7 @@
"""The periodic task scheduler."""
import copy
import dbm
import errno
import heapq
import os
@@ -568,11 +569,11 @@ class PersistentScheduler(Scheduler):
for _ in (1, 2):
try:
self._store['entries']
except KeyError:
except (KeyError, UnicodeDecodeError, TypeError):
# new schedule db
try:
self._store['entries'] = {}
except KeyError as exc:
except (KeyError, UnicodeDecodeError, TypeError) + dbm.error as exc:
self._store = self._destroy_open_corrupted_schedule(exc)
continue
else:

View File

@@ -4,9 +4,10 @@ import numbers
from collections import OrderedDict
from functools import update_wrapper
from pprint import pformat
from typing import Any
import click
from click import ParamType
from click import Context, ParamType
from kombu.utils.objects import cached_property
from celery._state import get_current_app
@@ -170,19 +171,37 @@ class CeleryCommand(click.Command):
formatter.write_dl(opts_group)
class DaemonOption(CeleryOption):
"""Common daemonization option"""
def __init__(self, *args, **kwargs):
super().__init__(args,
help_group=kwargs.pop("help_group", "Daemonization Options"),
callback=kwargs.pop("callback", self.daemon_setting),
**kwargs)
def daemon_setting(self, ctx: Context, opt: CeleryOption, value: Any) -> Any:
"""
Try to fetch daemonization option from applications settings.
        Use the daemon command name as prefix (e.g. `worker` -> `worker_pidfile`)
"""
return value or getattr(ctx.obj.app.conf, f"{ctx.command.name}_{self.name}", None)
class CeleryDaemonCommand(CeleryCommand):
"""Daemon commands."""
def __init__(self, *args, **kwargs):
"""Initialize a Celery command with common daemon options."""
super().__init__(*args, **kwargs)
self.params.append(CeleryOption(('-f', '--logfile'), help_group="Daemonization Options",
help="Log destination; defaults to stderr"))
self.params.append(CeleryOption(('--pidfile',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--uid',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--gid',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--umask',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--executable',), help_group="Daemonization Options"))
self.params.extend((
DaemonOption("--logfile", "-f", help="Log destination; defaults to stderr"),
DaemonOption("--pidfile", help="PID file path; defaults to no PID file"),
DaemonOption("--uid", help="Drops privileges to this user ID"),
DaemonOption("--gid", help="Drops privileges to this group ID"),
DaemonOption("--umask", help="Create files and directories with this umask"),
DaemonOption("--executable", help="Override path to the Python executable"),
))
class CommaSeparatedList(ParamType):
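Not part of the commit, but as a rough sketch of the fallback performed by DaemonOption.daemon_setting above: the CLI value wins, otherwise the "<command>_<option>" attribute is looked up on the app configuration. The helper and the `app` object below are hypothetical.

def resolve_daemon_option(cli_value, app, command_name, option_name):
    # CLI value takes precedence; otherwise fall back to e.g. app.conf.worker_pidfile.
    return cli_value or getattr(app.conf, f"{command_name}_{option_name}", None)

# resolve_daemon_option(None, app, "worker", "pidfile")               -> app.conf.worker_pidfile
# resolve_daemon_option("/run/celery.pid", app, "worker", "pidfile")  -> "/run/celery.pid"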

View File

@@ -11,7 +11,6 @@ except ImportError:
import click
import click.exceptions
from click.types import ParamType
from click_didyoumean import DYMGroup
from click_plugins import with_plugins
@@ -48,34 +47,6 @@ Unable to load celery application.
{0}""")
class App(ParamType):
"""Application option."""
name = "application"
def convert(self, value, param, ctx):
try:
return find_app(value)
except ModuleNotFoundError as e:
if e.name != value:
exc = traceback.format_exc()
self.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
)
self.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name))
except AttributeError as e:
attribute_name = e.args[0].capitalize()
self.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name))
except Exception:
exc = traceback.format_exc()
self.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
)
APP = App()
if sys.version_info >= (3, 10):
_PLUGINS = entry_points(group='celery.commands')
else:
@@ -91,7 +62,11 @@ else:
'--app',
envvar='APP',
cls=CeleryOption,
type=APP,
# May take either: a str when invoked from command line (Click),
# or a Celery object when invoked from inside Celery; hence the
# need to prevent Click from "processing" the Celery object and
# converting it into its str representation.
type=click.UNPROCESSED,
help_group="Global Options")
@click.option('-b',
'--broker',
@@ -160,6 +135,26 @@ def celery(ctx, app, broker, result_backend, loader, config, workdir,
os.environ['CELERY_CONFIG_MODULE'] = config
if skip_checks:
os.environ['CELERY_SKIP_CHECKS'] = 'true'
if isinstance(app, str):
try:
app = find_app(app)
except ModuleNotFoundError as e:
if e.name != app:
exc = traceback.format_exc()
ctx.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc)
)
ctx.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name))
except AttributeError as e:
attribute_name = e.args[0].capitalize()
ctx.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name))
except Exception:
exc = traceback.format_exc()
ctx.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc)
)
ctx.obj = CLIContext(app=app, no_color=no_color, workdir=workdir,
quiet=quiet)

View File

@@ -1,5 +1,6 @@
"""The ``celery control``, ``. inspect`` and ``. status`` programs."""
from functools import partial
from typing import Literal
import click
from kombu.utils.json import dumps
@@ -39,18 +40,69 @@ def _consume_arguments(meta, method, args):
args[:] = args[i:]
def _compile_arguments(action, args):
meta = Panel.meta[action]
def _compile_arguments(command, args):
meta = Panel.meta[command]
arguments = {}
if meta.args:
arguments.update({
k: v for k, v in _consume_arguments(meta, action, args)
k: v for k, v in _consume_arguments(meta, command, args)
})
if meta.variadic:
arguments.update({meta.variadic: args})
return arguments
_RemoteControlType = Literal['inspect', 'control']
def _verify_command_name(type_: _RemoteControlType, command: str) -> None:
choices = _get_commands_of_type(type_)
if command not in choices:
command_listing = ", ".join(choices)
raise click.UsageError(
message=f'Command {command} not recognized. Available {type_} commands: {command_listing}',
)
def _list_option(type_: _RemoteControlType):
def callback(ctx: click.Context, param, value) -> None:
if not value:
return
choices = _get_commands_of_type(type_)
formatter = click.HelpFormatter()
with formatter.section(f'{type_.capitalize()} Commands'):
command_list = []
for command_name, info in choices.items():
if info.signature:
command_preview = f'{command_name} {info.signature}'
else:
command_preview = command_name
command_list.append((command_preview, info.help))
formatter.write_dl(command_list)
ctx.obj.echo(formatter.getvalue(), nl=False)
ctx.exit()
return click.option(
'--list',
is_flag=True,
help=f'List available {type_} commands and exit.',
expose_value=False,
is_eager=True,
callback=callback,
)
def _get_commands_of_type(type_: _RemoteControlType) -> dict:
command_name_info_pairs = [
(name, info) for name, info in Panel.meta.items()
if info.type == type_ and info.visible
]
return dict(sorted(command_name_info_pairs))
@click.command(cls=CeleryCommand)
@click.option('-t',
'--timeout',
@@ -96,10 +148,8 @@ def status(ctx, timeout, destination, json, **kwargs):
@click.command(cls=CeleryCommand,
context_settings={'allow_extra_args': True})
@click.argument("action", type=click.Choice([
name for name, info in Panel.meta.items()
if info.type == 'inspect' and info.visible
]))
@click.argument('command')
@_list_option('inspect')
@click.option('-t',
'--timeout',
cls=CeleryOption,
@@ -121,19 +171,19 @@ def status(ctx, timeout, destination, json, **kwargs):
help='Use json as output format.')
@click.pass_context
@handle_preload_options
def inspect(ctx, action, timeout, destination, json, **kwargs):
"""Inspect the worker at runtime.
def inspect(ctx, command, timeout, destination, json, **kwargs):
"""Inspect the workers by sending them the COMMAND inspect command.
Availability: RabbitMQ (AMQP) and Redis transports.
"""
_verify_command_name('inspect', command)
callback = None if json else partial(_say_remote_command_reply, ctx,
show_reply=True)
arguments = _compile_arguments(action, ctx.args)
arguments = _compile_arguments(command, ctx.args)
inspect = ctx.obj.app.control.inspect(timeout=timeout,
destination=destination,
callback=callback)
replies = inspect._request(action,
**arguments)
replies = inspect._request(command, **arguments)
if not replies:
raise CeleryCommandException(
@@ -153,10 +203,8 @@ def inspect(ctx, action, timeout, destination, json, **kwargs):
@click.command(cls=CeleryCommand,
context_settings={'allow_extra_args': True})
@click.argument("action", type=click.Choice([
name for name, info in Panel.meta.items()
if info.type == 'control' and info.visible
]))
@click.argument('command')
@_list_option('control')
@click.option('-t',
'--timeout',
cls=CeleryOption,
@@ -178,16 +226,17 @@ def inspect(ctx, action, timeout, destination, json, **kwargs):
help='Use json as output format.')
@click.pass_context
@handle_preload_options
def control(ctx, action, timeout, destination, json):
"""Workers remote control.
def control(ctx, command, timeout, destination, json):
"""Send the COMMAND control command to the workers.
Availability: RabbitMQ (AMQP), Redis, and MongoDB transports.
"""
_verify_command_name('control', command)
callback = None if json else partial(_say_remote_command_reply, ctx,
show_reply=True)
args = ctx.args
arguments = _compile_arguments(action, args)
replies = ctx.obj.app.control.broadcast(action, timeout=timeout,
arguments = _compile_arguments(command, args)
replies = ctx.obj.app.control.broadcast(command, timeout=timeout,
destination=destination,
callback=callback,
reply=True,

View File

@@ -396,7 +396,7 @@ class Signature(dict):
else:
args, kwargs, options = self.args, self.kwargs, self.options
# pylint: disable=too-many-function-args
# Borks on this, as it's a property
# Works on this, as it's a property
return _apply(args, kwargs, **options)
def _merge(self, args=None, kwargs=None, options=None, force=False):
@@ -515,7 +515,7 @@ class Signature(dict):
if group_index is not None:
opts['group_index'] = group_index
# pylint: disable=too-many-function-args
# Borks on this, as it's a property.
# Works on this, as it's a property.
return self.AsyncResult(tid)
_freeze = freeze
@@ -958,6 +958,8 @@ class _chain(Signature):
if isinstance(other, group):
# unroll group with one member
other = maybe_unroll_group(other)
if not isinstance(other, group):
return self.__or__(other)
# chain | group() -> chain
tasks = self.unchain_tasks()
if not tasks:
@@ -972,15 +974,20 @@ class _chain(Signature):
tasks, other), app=self._app)
elif isinstance(other, _chain):
# chain | chain -> chain
# use type(self) for _chain subclasses
return type(self)(seq_concat_seq(
self.unchain_tasks(), other.unchain_tasks()), app=self._app)
return reduce(operator.or_, other.unchain_tasks(), self)
elif isinstance(other, Signature):
if self.tasks and isinstance(self.tasks[-1], group):
# CHAIN [last item is group] | TASK -> chord
sig = self.clone()
sig.tasks[-1] = chord(
sig.tasks[-1], other, app=self._app)
# In the scenario where the second-to-last item in a chain is a chord,
# it leads to a situation where two consecutive chords are formed.
# In such cases, a further upgrade can be considered.
            # This would involve chaining the body of the second-to-last chord with the last chord.
if len(sig.tasks) > 1 and isinstance(sig.tasks[-2], chord):
sig.tasks[-2].body = sig.tasks[-2].body | sig.tasks[-1]
sig.tasks = sig.tasks[:-1]
return sig
elif self.tasks and isinstance(self.tasks[-1], chord):
# CHAIN [last item is chord] -> chain with chord body.
@@ -1216,6 +1223,12 @@ class _chain(Signature):
task, body=prev_task,
root_id=root_id, app=app,
)
if tasks:
prev_task = tasks[-1]
prev_res = results[-1]
else:
prev_task = None
prev_res = None
if is_last_task:
# chain(task_id=id) means task id is set for the last task
@@ -1261,6 +1274,7 @@ class _chain(Signature):
while node.parent:
node = node.parent
prev_res = node
self.id = last_task_id
return tasks, results
def apply(self, args=None, kwargs=None, **options):
@@ -1672,6 +1686,8 @@ class group(Signature):
#
# We return a concretised tuple of the signatures actually applied to
# each child task signature, of which there might be none!
sig = maybe_signature(sig)
return tuple(child_task.link_error(sig.clone(immutable=True)) for child_task in self.tasks)
def _prepared(self, tasks, partial_args, group_id, root_id, app,
@@ -2271,6 +2287,8 @@ class _chord(Signature):
``False`` (the current default), then the error callback will only be
applied to the body.
"""
errback = maybe_signature(errback)
if self.app.conf.task_allow_error_cb_on_chord_header:
for task in maybe_list(self.tasks) or []:
task.link_error(errback.clone(immutable=True))
@@ -2289,6 +2307,13 @@ class _chord(Signature):
CPendingDeprecationWarning
)
# Edge case for nested chords in the header
for task in maybe_list(self.tasks) or []:
if isinstance(task, chord):
# Let the nested chord do the error linking itself on its
# header and body where needed, based on the current configuration
task.link_error(errback)
self.body.link_error(errback)
return errback
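
A rough usage sketch of the behaviour above (not part of the commit; `add`, `tsum` and `log_error` are hypothetical registered tasks, and `task_allow_error_cb_on_chord_header` is assumed to be enabled):

from celery import chord

sig = chord([add.s(2, 2), add.s(4, 4)], tsum.s())
# With the setting enabled, each header task gets log_error.s() linked as an
# immutable clone and the body gets it linked as well; nested chords in the
# header perform their own linking as described above.
sig.link_error(log_error.s())
sig.apply_async()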

View File

@@ -103,26 +103,35 @@ def _get_job_writer(job):
return writer() # is a weakref
def _ensure_integral_fd(fd):
return fd if isinstance(fd, Integral) else fd.fileno()
if hasattr(select, 'poll'):
def _select_imp(readers=None, writers=None, err=None, timeout=0,
poll=select.poll, POLLIN=select.POLLIN,
POLLOUT=select.POLLOUT, POLLERR=select.POLLERR):
poller = poll()
register = poller.register
fd_to_mask = {}
if readers:
[register(fd, POLLIN) for fd in readers]
for fd in map(_ensure_integral_fd, readers):
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLIN
if writers:
[register(fd, POLLOUT) for fd in writers]
for fd in map(_ensure_integral_fd, writers):
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLOUT
if err:
[register(fd, POLLERR) for fd in err]
for fd in map(_ensure_integral_fd, err):
fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLERR
for fd, event_mask in fd_to_mask.items():
register(fd, event_mask)
R, W = set(), set()
timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3)
events = poller.poll(timeout)
for fd, event in events:
if not isinstance(fd, Integral):
fd = fd.fileno()
if event & POLLIN:
R.add(fd)
if event & POLLOUT:
@@ -194,7 +203,7 @@ def iterate_file_descriptors_safely(fds_iter, source_data,
or possibly other reasons, so safely manage our lists of FDs.
:param fds_iter: the file descriptors to iterate and apply hub_method
:param source_data: data source to remove FD if it renders OSError
:param hub_method: the method to call with with each fd and kwargs
:param hub_method: the method to call with each fd and kwargs
:*args to pass through to the hub_method;
with a special syntax string '*fd*' represents a substitution
for the current fd object in the iteration (for some callers).
@@ -772,7 +781,7 @@ class AsynPool(_pool.Pool):
None, WRITE | ERR, consolidate=True)
else:
iterate_file_descriptors_safely(
inactive, all_inqueues, hub_remove)
inactive, all_inqueues, hub.remove_writer)
self.on_poll_start = on_poll_start
def on_inqueue_close(fd, proc):
@@ -818,7 +827,7 @@ class AsynPool(_pool.Pool):
# worker is already busy with another task
continue
if ready_fd not in all_inqueues:
hub_remove(ready_fd)
hub.remove_writer(ready_fd)
continue
try:
job = pop_message()
@@ -829,7 +838,7 @@ class AsynPool(_pool.Pool):
# this may create a spinloop where the event loop
# always wakes up.
for inqfd in diff(active_writes):
hub_remove(inqfd)
hub.remove_writer(inqfd)
break
else:
@@ -927,7 +936,7 @@ class AsynPool(_pool.Pool):
else:
errors = 0
finally:
hub_remove(fd)
hub.remove_writer(fd)
write_stats[proc.index] += 1
# message written, so this fd is now available
active_writes.discard(fd)

View File

@@ -1,4 +1,6 @@
"""Gevent execution pool."""
import functools
import types
from time import monotonic
from kombu.asynchronous import timer as _timer
@@ -16,15 +18,22 @@ __all__ = ('TaskPool',)
# We cache globals and attribute lookups, so disable this warning.
def apply_target(target, args=(), kwargs=None, callback=None,
accept_callback=None, getpid=None, **_):
kwargs = {} if not kwargs else kwargs
return base.apply_target(target, args, kwargs, callback, accept_callback,
pid=getpid(), **_)
def apply_timeout(target, args=(), kwargs=None, callback=None,
accept_callback=None, pid=None, timeout=None,
accept_callback=None, getpid=None, timeout=None,
timeout_callback=None, Timeout=Timeout,
apply_target=base.apply_target, **rest):
kwargs = {} if not kwargs else kwargs
try:
with Timeout(timeout):
return apply_target(target, args, kwargs, callback,
accept_callback, pid,
accept_callback, getpid(),
propagate=(Timeout,), **rest)
except Timeout:
return timeout_callback(False, timeout)
@@ -82,18 +91,22 @@ class TaskPool(base.BasePool):
is_green = True
task_join_will_block = False
_pool = None
_pool_map = None
_quick_put = None
def __init__(self, *args, **kwargs):
from gevent import spawn_raw
from gevent import getcurrent, spawn_raw
from gevent.pool import Pool
self.Pool = Pool
self.getcurrent = getcurrent
self.getpid = lambda: id(getcurrent())
self.spawn_n = spawn_raw
self.timeout = kwargs.get('timeout')
super().__init__(*args, **kwargs)
def on_start(self):
self._pool = self.Pool(self.limit)
self._pool_map = {}
self._quick_put = self._pool.spawn
def on_stop(self):
@@ -102,12 +115,15 @@ class TaskPool(base.BasePool):
def on_apply(self, target, args=None, kwargs=None, callback=None,
accept_callback=None, timeout=None,
timeout_callback=None, apply_target=base.apply_target, **_):
timeout_callback=None, apply_target=apply_target, **_):
timeout = self.timeout if timeout is None else timeout
return self._quick_put(apply_timeout if timeout else apply_target,
target, args, kwargs, callback, accept_callback,
timeout=timeout,
timeout_callback=timeout_callback)
target = self._make_killable_target(target)
greenlet = self._quick_put(apply_timeout if timeout else apply_target,
target, args, kwargs, callback, accept_callback,
self.getpid, timeout=timeout, timeout_callback=timeout_callback)
self._add_to_pool_map(id(greenlet), greenlet)
greenlet.terminate = types.MethodType(_terminate, greenlet)
return greenlet
def grow(self, n=1):
self._pool._semaphore.counter += n
@@ -117,6 +133,39 @@ class TaskPool(base.BasePool):
self._pool._semaphore.counter -= n
self._pool.size -= n
def terminate_job(self, pid, signal=None):
import gevent
if pid in self._pool_map:
greenlet = self._pool_map[pid]
gevent.kill(greenlet)
@property
def num_processes(self):
return len(self._pool)
@staticmethod
def _make_killable_target(target):
def killable_target(*args, **kwargs):
from greenlet import GreenletExit
try:
return target(*args, **kwargs)
except GreenletExit:
return (False, None, None)
return killable_target
def _add_to_pool_map(self, pid, greenlet):
self._pool_map[pid] = greenlet
greenlet.link(
functools.partial(self._cleanup_after_job_finish, pid=pid, pool_map=self._pool_map),
)
@staticmethod
def _cleanup_after_job_finish(greenlet, pool_map, pid):
del pool_map[pid]
def _terminate(self, signal):
# Done in `TaskPool.terminate_job`
pass

View File

@@ -0,0 +1,21 @@
import functools
from django.db import transaction
from celery.app.task import Task
class DjangoTask(Task):
"""
Extend the base :class:`~celery.app.task.Task` for Django.
Provide a nicer API to trigger tasks at the end of the DB transaction.
"""
def delay_on_commit(self, *args, **kwargs) -> None:
"""Call :meth:`~celery.app.task.Task.delay` with Django's ``on_commit()``."""
transaction.on_commit(functools.partial(self.delay, *args, **kwargs))
def apply_async_on_commit(self, *args, **kwargs) -> None:
"""Call :meth:`~celery.app.task.Task.apply_async` with Django's ``on_commit()``."""
transaction.on_commit(functools.partial(self.apply_async, *args, **kwargs))
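
A short usage sketch, assuming a Django project where Celery is configured to use this task class and a hypothetical `send_welcome_email` task exists:

from django.db import transaction

def register_user(form):
    with transaction.atomic():
        user = form.save()
        # Enqueued only if the surrounding transaction commits; a rolled-back
        # transaction never publishes the task.
        send_welcome_email.delay_on_commit(user.pk)
    return user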

View File

@@ -3,10 +3,10 @@ import logging
import os
import threading
from contextlib import contextmanager
from typing import Any, Iterable, Union # noqa
from typing import Any, Iterable, Optional, Union
import celery.worker.consumer # noqa
from celery import Celery, worker # noqa
from celery import Celery, worker
from celery.result import _set_task_join_will_block, allow_join_result
from celery.utils.dispatch import Signal
from celery.utils.nodenames import anon_nodename
@@ -30,6 +30,10 @@ test_worker_stopped = Signal(
class TestWorkController(worker.WorkController):
"""Worker that can synchronize on being fully started."""
# When this class is imported in pytest files, prevent pytest from thinking
# this is a test class
__test__ = False
logger_queue = None
def __init__(self, *args, **kwargs):
@@ -131,16 +135,15 @@ def start_worker(
@contextmanager
def _start_worker_thread(app,
concurrency=1,
pool='solo',
loglevel=WORKER_LOGLEVEL,
logfile=None,
WorkController=TestWorkController,
perform_ping_check=True,
shutdown_timeout=10.0,
**kwargs):
# type: (Celery, int, str, Union[str, int], str, Any, **Any) -> Iterable
def _start_worker_thread(app: Celery,
concurrency: int = 1,
pool: str = 'solo',
loglevel: Union[str, int] = WORKER_LOGLEVEL,
logfile: Optional[str] = None,
WorkController: Any = TestWorkController,
perform_ping_check: bool = True,
shutdown_timeout: float = 10.0,
**kwargs) -> Iterable[worker.WorkController]:
"""Start Celery worker in a thread.
Yields:
@@ -156,7 +159,7 @@ def _start_worker_thread(app,
worker = WorkController(
app=app,
concurrency=concurrency,
hostname=anon_nodename(),
hostname=kwargs.pop("hostname", anon_nodename()),
pool=pool,
loglevel=loglevel,
logfile=logfile,
@@ -211,8 +214,7 @@ def _start_worker_process(app,
cluster.stopwait()
def setup_app_for_worker(app, loglevel, logfile) -> None:
# type: (Celery, Union[str, int], str) -> None
def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None:
"""Setup the app to be used for starting an embedded worker."""
app.finalize()
app.set_current()

View File

@@ -55,7 +55,7 @@ def get_exchange(conn, name=EVENT_EXCHANGE_NAME):
(from topic -> fanout).
"""
ex = copy(event_exchange)
if conn.transport.driver_type == 'redis':
if conn.transport.driver_type in {'redis', 'gcpubsub'}:
# quick hack for Issue #436
ex.type = 'fanout'
if name != ex.name:

View File

@@ -2,7 +2,7 @@
import os
import sys
import warnings
from datetime import datetime
from datetime import datetime, timezone
from importlib import import_module
from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from types import ModuleType
from typing import Protocol
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.utils import ConnectionHandler
from celery.app.base import Celery
@@ -78,6 +79,9 @@ class DjangoFixup:
self._settings = symbol_by_name('django.conf:settings')
self.app.loader.now = self.now
if not self.app._custom_task_cls_used:
self.app.task_cls = 'celery.contrib.django.task:DjangoTask'
signals.import_modules.connect(self.on_import_modules)
signals.worker_init.connect(self.on_worker_init)
return self
@@ -100,7 +104,7 @@ class DjangoFixup:
self.worker_fixup.install()
def now(self, utc: bool = False) -> datetime:
return datetime.utcnow() if utc else self._now()
return datetime.now(timezone.utc) if utc else self._now()
def autodiscover_tasks(self) -> List[str]:
from django.apps import apps
@@ -161,15 +165,16 @@ class DjangoWorkerFixup:
# network IO that close() might cause.
for c in self._db.connections.all():
if c and c.connection:
self._maybe_close_db_fd(c.connection)
self._maybe_close_db_fd(c)
# use the _ version to avoid DB_REUSE preventing the conn.close() call
self._close_database(force=True)
self.close_cache()
def _maybe_close_db_fd(self, fd: IO) -> None:
def _maybe_close_db_fd(self, c: "BaseDatabaseWrapper") -> None:
try:
_maybe_close_fd(fd)
with c.wrap_database_errors:
_maybe_close_fd(c.connection)
except self.interface_errors:
pass

View File

@@ -3,7 +3,7 @@ import importlib
import os
import re
import sys
from datetime import datetime
from datetime import datetime, timezone
from kombu.utils import json
from kombu.utils.objects import cached_property
@@ -62,7 +62,7 @@ class BaseLoader:
def now(self, utc=True):
if utc:
return datetime.utcnow()
return datetime.now(timezone.utc)
return datetime.now()
def on_task_init(self, task_id, task):
@@ -253,10 +253,12 @@ def find_related_module(package, related_name):
# Django 1.7 allows for specifying a class name in INSTALLED_APPS.
# (Issue #2248).
try:
# Return package itself when no related_name.
module = importlib.import_module(package)
if not related_name and module:
return module
except ImportError:
except ModuleNotFoundError:
# On import error, try to walk package up one level.
package, _, _ = package.rpartition('.')
if not package:
raise
@@ -264,9 +266,13 @@ def find_related_module(package, related_name):
module_name = f'{package}.{related_name}'
try:
# Try to find related_name under package.
return importlib.import_module(module_name)
except ImportError as e:
import_exc_name = getattr(e, 'name', module_name)
if import_exc_name is not None and import_exc_name != module_name:
raise e
return
except ModuleNotFoundError as e:
import_exc_name = getattr(e, 'name', None)
# If candidate does not exist, then return None.
if import_exc_name and module_name == import_exc_name:
return
# Otherwise, raise because error probably originated from a nested import.
raise e

View File

@@ -397,7 +397,6 @@ COMPAT_MODULES = {
},
'log': {
'get_default_logger': 'log.get_default_logger',
'setup_logger': 'log.setup_logger',
'setup_logging_subsystem': 'log.setup_logging_subsystem',
'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
},

View File

@@ -42,7 +42,7 @@ __all__ = (
'DaemonContext', 'detached', 'parse_uid', 'parse_gid', 'setgroups',
'initgroups', 'setgid', 'setuid', 'maybe_drop_privileges', 'signals',
'signal_name', 'set_process_title', 'set_mp_process_title',
'get_errno_name', 'ignore_errno', 'fd_by_path',
'get_errno_name', 'ignore_errno', 'fd_by_path', 'isatty',
)
# exitcodes
@@ -95,6 +95,14 @@ SIGNAMES = {
SIGMAP = {getattr(_signal, name): name for name in SIGNAMES}
def isatty(fh):
"""Return true if the process has a controlling terminal."""
try:
return fh.isatty()
except AttributeError:
pass
def pyimplementation():
"""Return string identifying the current Python implementation."""
if hasattr(_platform, 'python_implementation'):
@@ -186,10 +194,14 @@ class Pidfile:
if not pid:
self.remove()
return True
if pid == os.getpid():
# this can be common in k8s pod with PID of 1 - don't kill
self.remove()
return True
try:
os.kill(pid, 0)
except os.error as exc:
except OSError as exc:
if exc.errno == errno.ESRCH or exc.errno == errno.EPERM:
print('Stale pidfile exists - Removing it.', file=sys.stderr)
self.remove()

View File

@@ -6,6 +6,7 @@ from collections import deque
from contextlib import contextmanager
from weakref import proxy
from dateutil.parser import isoparse
from kombu.utils.objects import cached_property
from vine import Thenable, barrier, promise
@@ -532,7 +533,7 @@ class AsyncResult(ResultBase):
"""UTC date and time."""
date_done = self._get_task_meta().get('date_done')
if date_done and not isinstance(date_done, datetime.datetime):
return datetime.datetime.fromisoformat(date_done)
return isoparse(date_done)
return date_done
@property
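
The switch to dateutil's isoparse matters mostly on older interpreters: before Python 3.11, datetime.fromisoformat() rejects a trailing "Z", which result backends may store (assumption here; the timestamp below is made up):

from dateutil.parser import isoparse

isoparse('2024-05-01T12:30:00.123456Z')   # -> an aware datetime in UTC
# datetime.datetime.fromisoformat('2024-05-01T12:30:00.123456Z')
#   raises ValueError on Python 3.10 and earlier.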
@@ -983,13 +984,14 @@ class GroupResult(ResultSet):
class EagerResult(AsyncResult):
"""Result that we know has already been executed."""
def __init__(self, id, ret_value, state, traceback=None):
def __init__(self, id, ret_value, state, traceback=None, name=None):
# pylint: disable=super-init-not-called
# XXX should really not be inheriting from AsyncResult
self.id = id
self._result = ret_value
self._state = state
self._traceback = traceback
self._name = name
self.on_ready = promise()
self.on_ready(self)
@@ -1042,6 +1044,7 @@ class EagerResult(AsyncResult):
'result': self._result,
'status': self._state,
'traceback': self._traceback,
'name': self._name,
}
@property

View File

@@ -4,9 +4,8 @@ from __future__ import annotations
import re
from bisect import bisect, bisect_left
from collections import namedtuple
from collections.abc import Iterable
from datetime import datetime, timedelta, tzinfo
from typing import Any, Callable, Mapping, Sequence
from typing import Any, Callable, Iterable, Mapping, Sequence, Union
from kombu.utils.objects import cached_property
@@ -15,7 +14,7 @@ from celery import Celery
from . import current_app
from .utils.collections import AttributeDict
from .utils.time import (ffwd, humanize_seconds, localize, maybe_make_aware, maybe_timedelta, remaining, timezone,
weekday)
weekday, yearmonth)
__all__ = (
'ParseException', 'schedule', 'crontab', 'crontab_parser',
@@ -52,7 +51,10 @@ Argument event "{event}" is invalid, must be one of {all_events}.\
"""
def cronfield(s: str) -> str:
Cronspec = Union[int, str, Iterable[int]]
def cronfield(s: Cronspec | None) -> Cronspec:
return '*' if s is None else s
@@ -300,9 +302,12 @@ class crontab_parser:
i = int(s)
except ValueError:
try:
i = weekday(s)
i = yearmonth(s)
except KeyError:
raise ValueError(f'Invalid weekday literal {s!r}.')
try:
i = weekday(s)
except KeyError:
raise ValueError(f'Invalid weekday literal {s!r}.')
max_val = self.min_ + self.max_ - 1
if i > max_val:
@@ -393,8 +398,8 @@ class crontab(BaseSchedule):
present in ``month_of_year``.
"""
def __init__(self, minute: str = '*', hour: str = '*', day_of_week: str = '*',
day_of_month: str = '*', month_of_year: str = '*', **kwargs: Any) -> None:
def __init__(self, minute: Cronspec = '*', hour: Cronspec = '*', day_of_week: Cronspec = '*',
day_of_month: Cronspec = '*', month_of_year: Cronspec = '*', **kwargs: Any) -> None:
self._orig_minute = cronfield(minute)
self._orig_hour = cronfield(hour)
self._orig_day_of_week = cronfield(day_of_week)
@@ -408,9 +413,26 @@ class crontab(BaseSchedule):
self.month_of_year = self._expand_cronspec(month_of_year, 12, 1)
super().__init__(**kwargs)
@classmethod
def from_string(cls, crontab: str) -> crontab:
"""
Create a Crontab from a cron expression string. For example ``crontab.from_string('* * * * *')``.
.. code-block:: text
┌───────────── minute (0-59)
│ ┌───────────── hour (0-23)
│ │ ┌───────────── day of the month (1-31)
│ │ │ ┌───────────── month (1-12)
│ │ │ │ ┌───────────── day of the week (0-6) (Sunday to Saturday)
* * * * *
"""
minute, hour, day_of_month, month_of_year, day_of_week = crontab.split(" ")
return cls(minute, hour, day_of_week, day_of_month, month_of_year)
@staticmethod
def _expand_cronspec(
cronspec: int | str | Iterable,
cronspec: Cronspec,
max_: int, min_: int = 0) -> set[Any]:
"""Expand cron specification.
@@ -535,7 +557,7 @@ class crontab(BaseSchedule):
def __repr__(self) -> str:
return CRON_REPR.format(self)
def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], Any]:
def __reduce__(self) -> tuple[type, tuple[Cronspec, Cronspec, Cronspec, Cronspec, Cronspec], Any]:
return (self.__class__, (self._orig_minute,
self._orig_hour,
self._orig_day_of_week,

View File

@@ -43,7 +43,7 @@ class Certificate:
def has_expired(self) -> bool:
"""Check if the certificate has expired."""
return datetime.datetime.utcnow() >= self._cert.not_valid_after
return datetime.datetime.now(datetime.timezone.utc) >= self._cert.not_valid_after_utc
def get_pubkey(self) -> (
DSAPublicKey | EllipticCurvePublicKey | Ed448PublicKey | Ed25519PublicKey | RSAPublicKey

View File

@@ -11,6 +11,11 @@ from .utils import get_digest_algorithm, reraise_errors
__all__ = ('SecureSerializer', 'register_auth')
# Note: we guarantee that this value won't appear in the serialized data,
# so we can use it as a separator.
# If you change this value, make sure it's not present in the serialized data.
DEFAULT_SEPARATOR = str_to_bytes("\x00\x01")
class SecureSerializer:
"""Signed serializer."""
@@ -29,7 +34,8 @@ class SecureSerializer:
assert self._cert is not None
with reraise_errors('Unable to serialize: {0!r}', (Exception,)):
content_type, content_encoding, body = dumps(
bytes_to_str(data), serializer=self._serializer)
data, serializer=self._serializer)
# What we sign is the serialized body, not the body itself.
# this way the receiver doesn't have to decode the contents
# to verify the signature (and thus avoiding potential flaws
@@ -48,43 +54,26 @@ class SecureSerializer:
payload['signer'],
payload['body'])
self._cert_store[signer].verify(body, signature, self._digest)
return loads(bytes_to_str(body), payload['content_type'],
return loads(body, payload['content_type'],
payload['content_encoding'], force=True)
def _pack(self, body, content_type, content_encoding, signer, signature,
sep=str_to_bytes('\x00\x01')):
sep=DEFAULT_SEPARATOR):
fields = sep.join(
ensure_bytes(s) for s in [signer, signature, content_type,
content_encoding, body]
ensure_bytes(s) for s in [b64encode(signer), b64encode(signature),
content_type, content_encoding, body]
)
return b64encode(fields)
def _unpack(self, payload, sep=str_to_bytes('\x00\x01')):
def _unpack(self, payload, sep=DEFAULT_SEPARATOR):
raw_payload = b64decode(ensure_bytes(payload))
first_sep = raw_payload.find(sep)
signer = raw_payload[:first_sep]
signer_cert = self._cert_store[signer]
# shift 3 bits right to get signature length
# 2048bit rsa key has a signature length of 256
# 4096bit rsa key has a signature length of 512
sig_len = signer_cert.get_pubkey().key_size >> 3
sep_len = len(sep)
signature_start_position = first_sep + sep_len
signature_end_position = signature_start_position + sig_len
signature = raw_payload[
signature_start_position:signature_end_position
]
v = raw_payload[signature_end_position + sep_len:].split(sep)
v = raw_payload.split(sep, maxsplit=4)
return {
'signer': signer,
'signature': signature,
'content_type': bytes_to_str(v[0]),
'content_encoding': bytes_to_str(v[1]),
'body': bytes_to_str(v[2]),
'signer': b64decode(v[0]),
'signature': b64decode(v[1]),
'content_type': bytes_to_str(v[2]),
'content_encoding': bytes_to_str(v[3]),
'body': v[4],
}

View File

@@ -0,0 +1,49 @@
"""Code related to handling annotations."""
import sys
import types
import typing
from inspect import isclass
def is_none_type(value: typing.Any) -> bool:
"""Check if the given value is a NoneType."""
if sys.version_info < (3, 10):
# raise Exception('below 3.10', value, type(None))
return value is type(None)
return value == types.NoneType # type: ignore[no-any-return]
def get_optional_arg(annotation: typing.Any) -> typing.Any:
"""Get the argument from an Optional[...] annotation, or None if it is no such annotation."""
origin = typing.get_origin(annotation)
if origin != typing.Union and (sys.version_info >= (3, 10) and origin != types.UnionType):
return None
union_args = typing.get_args(annotation)
if len(union_args) != 2: # Union does _not_ have two members, so it's not an Optional
return None
has_none_arg = any(is_none_type(arg) for arg in union_args)
# There will always be at least one type arg, as we have already established that this is a Union with exactly
# two members, and both cannot be None (`Union[None, None]` does not work).
type_arg = next(arg for arg in union_args if not is_none_type(arg)) # pragma: no branch
if has_none_arg:
return type_arg
return None
def annotation_is_class(annotation: typing.Any) -> bool:
"""Test if a given annotation is a class that can be used in isinstance()/issubclass()."""
# isclass() returns True for generic type hints (e.g. `list[str]`) until Python 3.10.
# NOTE: The guard for Python 3.9 is because types.GenericAlias is only added in Python 3.9. This is not a problem
# as the syntax is added in the same version in the first place.
if (3, 9) <= sys.version_info < (3, 11) and isinstance(annotation, types.GenericAlias):
return False
return isclass(annotation)
def annotation_issubclass(annotation: typing.Any, cls: type) -> bool:
"""Test if a given annotation is of the given subclass."""
return annotation_is_class(annotation) and issubclass(annotation, cls)
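
A quick sanity check of how these helpers behave, as I read the code above (not part of the commit; requires Python 3.9+ for the `list[str]` syntax):

import typing

from celery.utils.annotations import annotation_issubclass, get_optional_arg

assert get_optional_arg(typing.Optional[int]) is int      # Optional[X] -> X
assert get_optional_arg(typing.Union[int, str]) is None   # a Union, but not an Optional
assert get_optional_arg(int) is None                      # not a Union at all

assert annotation_issubclass(bool, int) is True
assert annotation_issubclass(list[str], list) is False    # generic alias, not a plain class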

View File

@@ -595,8 +595,7 @@ class LimitedSet:
break # oldest item hasn't expired yet
self.pop()
def pop(self, default=None) -> Any:
# type: (Any) -> Any
def pop(self, default: Any = None) -> Any:
"""Remove and return the oldest item, or :const:`None` when empty."""
while self._heap:
_, item = heappop(self._heap)

View File

@@ -54,6 +54,9 @@ def _boundmethod_safe_weakref(obj):
def _make_lookup_key(receiver, sender, dispatch_uid):
if dispatch_uid:
return (dispatch_uid, _make_id(sender))
# Issue #9119 - retry-wrapped functions use the underlying function for dispatch_uid
elif hasattr(receiver, '_dispatch_uid'):
return (receiver._dispatch_uid, _make_id(sender))
else:
return (_make_id(receiver), _make_id(sender))
@@ -170,6 +173,7 @@ class Signal: # pragma: no cover
# it up later with the original func id
options['dispatch_uid'] = _make_id(fun)
fun = _retry_receiver(fun)
fun._dispatch_uid = options['dispatch_uid']
self._connect_signal(fun, sender, options['weak'],
options['dispatch_uid'])

View File

@@ -51,8 +51,13 @@ def instantiate(name, *args, **kwargs):
@contextmanager
def cwd_in_path():
"""Context adding the current working directory to sys.path."""
cwd = os.getcwd()
if cwd in sys.path:
try:
cwd = os.getcwd()
except FileNotFoundError:
cwd = None
if not cwd:
yield
elif cwd in sys.path:
yield
else:
sys.path.insert(0, cwd)

View File

@@ -50,9 +50,9 @@ TIMEZONE_REGEX = re.compile(
)
def parse_iso8601(datestring):
def parse_iso8601(datestring: str) -> datetime:
"""Parse and convert ISO-8601 string to datetime."""
warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat")
warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat or dateutil.parser.isoparse")
m = ISO8601_REGEX.match(datestring)
if not m:
raise ValueError('unable to parse date string %r' % datestring)

View File

@@ -37,7 +37,7 @@ base_logger = logger = _get_logger('celery')
def set_in_sighandler(value):
"""Set flag signifiying that we're inside a signal handler."""
"""Set flag signifying that we're inside a signal handler."""
global _in_sighandler
_in_sighandler = value

View File

@@ -1,4 +1,6 @@
"""Worker name utilities."""
from __future__ import annotations
import os
import socket
from functools import partial
@@ -22,13 +24,18 @@ NODENAME_DEFAULT = 'celery'
gethostname = memoize(1, Cache=dict)(socket.gethostname)
__all__ = (
'worker_direct', 'gethostname', 'nodename',
'anon_nodename', 'nodesplit', 'default_nodename',
'node_format', 'host_format',
'worker_direct',
'gethostname',
'nodename',
'anon_nodename',
'nodesplit',
'default_nodename',
'node_format',
'host_format',
)
def worker_direct(hostname):
def worker_direct(hostname: str | Queue) -> Queue:
"""Return the :class:`kombu.Queue` being a direct route to a worker.
Arguments:
@@ -46,21 +53,20 @@ def worker_direct(hostname):
)
def nodename(name, hostname):
def nodename(name: str, hostname: str) -> str:
"""Create node name from name/hostname pair."""
return NODENAME_SEP.join((name, hostname))
def anon_nodename(hostname=None, prefix='gen'):
def anon_nodename(hostname: str | None = None, prefix: str = 'gen') -> str:
"""Return the nodename for this process (not a worker).
This is used for e.g. the origin task message field.
"""
return nodename(''.join([prefix, str(os.getpid())]),
hostname or gethostname())
return nodename(''.join([prefix, str(os.getpid())]), hostname or gethostname())
def nodesplit(name):
def nodesplit(name: str) -> tuple[None, str] | list[str]:
"""Split node name into tuple of name/hostname."""
parts = name.split(NODENAME_SEP, 1)
if len(parts) == 1:
@@ -68,21 +74,21 @@ def nodesplit(name):
return parts
def default_nodename(hostname):
def default_nodename(hostname: str) -> str:
"""Return the default nodename for this process."""
name, host = nodesplit(hostname or '')
return nodename(name or NODENAME_DEFAULT, host or gethostname())
def node_format(s, name, **extra):
def node_format(s: str, name: str, **extra: dict) -> str:
"""Format worker node name (name@host.com)."""
shortname, host = nodesplit(name)
return host_format(
s, host, shortname or NODENAME_DEFAULT, p=name, **extra)
return host_format(s, host, shortname or NODENAME_DEFAULT, p=name, **extra)
def _fmt_process_index(prefix='', default='0'):
def _fmt_process_index(prefix: str = '', default: str = '0') -> str:
from .log import current_process_index
index = current_process_index()
return f'{prefix}{index}' if index else default
@@ -90,13 +96,19 @@ def _fmt_process_index(prefix='', default='0'):
_fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '')
def host_format(s, host=None, name=None, **extra):
def host_format(s: str, host: str | None = None, name: str | None = None, **extra: dict) -> str:
"""Format host %x abbreviations."""
host = host or gethostname()
hname, _, domain = host.partition('.')
name = name or hname
keys = dict({
'h': host, 'n': name, 'd': domain,
'i': _fmt_process_index, 'I': _fmt_process_index_with_prefix,
}, **extra)
keys = dict(
{
'h': host,
'n': name,
'd': domain,
'i': _fmt_process_index,
'I': _fmt_process_index_with_prefix,
},
**extra,
)
return simple_format(s, keys)
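
A small sketch of the %-expansion performed by host_format()/node_format(); the hostnames are hypothetical:

from celery.utils.nodenames import host_format, node_format

host_format('%n@%h', host='worker1.example.com')
# -> 'worker1@worker1.example.com'  (%n defaults to the short host name)

node_format('%p.log', 'email@worker1.example.com')
# -> 'email@worker1.example.com.log'  (%p expands to the full node name)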

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
def detect_quorum_queues(app, driver_type: str) -> tuple[bool, str]:
"""Detect if any of the queues are quorum queues.
Returns:
tuple[bool, str]: A tuple containing a boolean indicating if any of the queues are quorum queues
and the name of the first quorum queue found or an empty string if no quorum queues were found.
"""
is_rabbitmq_broker = driver_type == 'amqp'
if is_rabbitmq_broker:
queues = app.amqp.queues
for qname in queues:
qarguments = queues[qname].queue_arguments or {}
if qarguments.get("x-queue-type") == "quorum":
return True, qname
return False, ""

View File

@@ -15,7 +15,7 @@ from decimal import Decimal
from itertools import chain
from numbers import Number
from pprint import _recursion
from typing import Any, AnyStr, Callable, Dict, Iterator, List, Sequence, Set, Tuple # noqa
from typing import Any, AnyStr, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple # noqa
from .text import truncate
@@ -41,7 +41,7 @@ _quoted = namedtuple('_quoted', ('value',))
#: Recursion protection.
_dirty = namedtuple('_dirty', ('objid',))
#: Types that are repsented as chars.
#: Types that are represented as chars.
chars_t = (bytes, str)
#: Types that are regarded as safe to call repr on.
@@ -194,9 +194,12 @@ def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
)
def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance):
def reprstream(stack: deque,
seen: Optional[Set] = None,
maxlevels: int = 3,
level: int = 0,
isinstance: Callable = isinstance) -> Iterator[Any]:
"""Streaming repr, yielding tokens."""
# type: (deque, Set, int, int, Callable) -> Iterator[Any]
seen = seen or set()
append = stack.append
popleft = stack.popleft

View File

@@ -1,4 +1,6 @@
"""System information utilities."""
from __future__ import annotations
import os
from math import ceil
@@ -9,16 +11,16 @@ __all__ = ('load_average', 'df')
if hasattr(os, 'getloadavg'):
def _load_average():
def _load_average() -> tuple[float, ...]:
return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg())
else: # pragma: no cover
# Windows doesn't have getloadavg
def _load_average():
return (0.0, 0.0, 0.0)
def _load_average() -> tuple[float, ...]:
return 0.0, 0.0, 0.0,
def load_average():
def load_average() -> tuple[float, ...]:
"""Return system load average as a triple."""
return _load_average()
@@ -26,23 +28,23 @@ def load_average():
class df:
"""Disk information."""
def __init__(self, path):
def __init__(self, path: str | bytes | os.PathLike) -> None:
self.path = path
@property
def total_blocks(self):
def total_blocks(self) -> float:
return self.stat.f_blocks * self.stat.f_frsize / 1024
@property
def available(self):
def available(self) -> float:
return self.stat.f_bavail * self.stat.f_frsize / 1024
@property
def capacity(self):
def capacity(self) -> int:
avail = self.stat.f_bavail
used = self.stat.f_blocks - self.stat.f_bfree
return int(ceil(used * 100.0 / (used + avail) + 0.5))
@cached_property
def stat(self):
def stat(self) -> os.statvfs_result:
return os.statvfs(os.path.abspath(self.path))

View File

@@ -1,6 +1,7 @@
"""Terminals and colors."""
from __future__ import annotations
import base64
import codecs
import os
import platform
import sys
@@ -8,6 +9,8 @@ from functools import reduce
__all__ = ('colored',)
from typing import Any
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
OP_SEQ = '\033[%dm'
RESET_SEQ = '\033[0m'
@@ -26,7 +29,7 @@ _IMG_PRE = '\033Ptmux;\033\033]' if TERM_IS_SCREEN else '\033]'
_IMG_POST = '\a\033\\' if TERM_IS_SCREEN else '\a'
def fg(s):
def fg(s: int) -> str:
return COLOR_SEQ % s
@@ -41,11 +44,11 @@ class colored:
... c.green('dog ')))
"""
def __init__(self, *s, **kwargs):
self.s = s
self.enabled = not IS_WINDOWS and kwargs.get('enabled', True)
self.op = kwargs.get('op', '')
self.names = {
def __init__(self, *s: object, **kwargs: Any) -> None:
self.s: tuple[object, ...] = s
self.enabled: bool = not IS_WINDOWS and kwargs.get('enabled', True)
self.op: str = kwargs.get('op', '')
self.names: dict[str, Any] = {
'black': self.black,
'red': self.red,
'green': self.green,
@@ -56,10 +59,10 @@ class colored:
'white': self.white,
}
def _add(self, a, b):
return str(a) + str(b)
def _add(self, a: object, b: object) -> str:
return f"{a}{b}"
def _fold_no_color(self, a, b):
def _fold_no_color(self, a: Any, b: Any) -> str:
try:
A = a.no_color()
except AttributeError:
@@ -69,109 +72,113 @@ class colored:
except AttributeError:
B = str(b)
return ''.join((str(A), str(B)))
return f"{A}{B}"
def no_color(self):
def no_color(self) -> str:
if self.s:
return str(reduce(self._fold_no_color, self.s))
return ''
def embed(self):
def embed(self) -> str:
prefix = ''
if self.enabled:
prefix = self.op
return ''.join((str(prefix), str(reduce(self._add, self.s))))
return f"{prefix}{reduce(self._add, self.s)}"
def __str__(self):
def __str__(self) -> str:
suffix = ''
if self.enabled:
suffix = RESET_SEQ
return str(''.join((self.embed(), str(suffix))))
return f"{self.embed()}{suffix}"
def node(self, s, op):
def node(self, s: tuple[object, ...], op: str) -> colored:
return self.__class__(enabled=self.enabled, op=op, *s)
def black(self, *s):
def black(self, *s: object) -> colored:
return self.node(s, fg(30 + BLACK))
def red(self, *s):
def red(self, *s: object) -> colored:
return self.node(s, fg(30 + RED))
def green(self, *s):
def green(self, *s: object) -> colored:
return self.node(s, fg(30 + GREEN))
def yellow(self, *s):
def yellow(self, *s: object) -> colored:
return self.node(s, fg(30 + YELLOW))
def blue(self, *s):
def blue(self, *s: object) -> colored:
return self.node(s, fg(30 + BLUE))
def magenta(self, *s):
def magenta(self, *s: object) -> colored:
return self.node(s, fg(30 + MAGENTA))
def cyan(self, *s):
def cyan(self, *s: object) -> colored:
return self.node(s, fg(30 + CYAN))
def white(self, *s):
def white(self, *s: object) -> colored:
return self.node(s, fg(30 + WHITE))
def __repr__(self):
def __repr__(self) -> str:
return repr(self.no_color())
def bold(self, *s):
def bold(self, *s: object) -> colored:
return self.node(s, OP_SEQ % 1)
def underline(self, *s):
def underline(self, *s: object) -> colored:
return self.node(s, OP_SEQ % 4)
def blink(self, *s):
def blink(self, *s: object) -> colored:
return self.node(s, OP_SEQ % 5)
def reverse(self, *s):
def reverse(self, *s: object) -> colored:
return self.node(s, OP_SEQ % 7)
def bright(self, *s):
def bright(self, *s: object) -> colored:
return self.node(s, OP_SEQ % 8)
def ired(self, *s):
def ired(self, *s: object) -> colored:
return self.node(s, fg(40 + RED))
def igreen(self, *s):
def igreen(self, *s: object) -> colored:
return self.node(s, fg(40 + GREEN))
def iyellow(self, *s):
def iyellow(self, *s: object) -> colored:
return self.node(s, fg(40 + YELLOW))
def iblue(self, *s):
def iblue(self, *s: colored) -> colored:
return self.node(s, fg(40 + BLUE))
def imagenta(self, *s):
def imagenta(self, *s: object) -> colored:
return self.node(s, fg(40 + MAGENTA))
def icyan(self, *s):
def icyan(self, *s: object) -> colored:
return self.node(s, fg(40 + CYAN))
def iwhite(self, *s):
def iwhite(self, *s: object) -> colored:
return self.node(s, fg(40 + WHITE))
def reset(self, *s):
return self.node(s or [''], RESET_SEQ)
def reset(self, *s: object) -> colored:
return self.node(s or ('',), RESET_SEQ)
def __add__(self, other):
return str(self) + str(other)
def __add__(self, other: object) -> str:
return f"{self}{other}"
def supports_images():
return sys.stdin.isatty() and ITERM_PROFILE
def supports_images() -> bool:
try:
return sys.stdin.isatty() and bool(os.environ.get('ITERM_PROFILE'))
except AttributeError:
return False
def _read_as_base64(path):
with codecs.open(path, mode='rb') as fh:
def _read_as_base64(path: str) -> str:
with open(path, mode='rb') as fh:
encoded = base64.b64encode(fh.read())
return encoded if isinstance(encoded, str) else encoded.decode('ascii')
return encoded.decode('ascii')
def imgcat(path, inline=1, preserve_aspect_ratio=0, **kwargs):
def imgcat(path: str, inline: int = 1, preserve_aspect_ratio: int = 0, **kwargs: Any) -> str:
return '\n%s1337;File=inline=%d;preserveAspectRatio=%d:%s%s' % (
_IMG_PRE, inline, preserve_aspect_ratio,
_read_as_base64(path), _IMG_POST)

View File

@@ -14,6 +14,7 @@ from types import ModuleType
from typing import Any, Callable
from dateutil import tz as dateutil_tz
from dateutil.parser import isoparse
from kombu.utils.functional import reprcall
from kombu.utils.objects import cached_property
@@ -40,6 +41,9 @@ C_REMDEBUG = os.environ.get('C_REMDEBUG', False)
DAYNAMES = 'sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'
WEEKDAYS = dict(zip(DAYNAMES, range(7)))
MONTHNAMES = 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'
YEARMONTHS = dict(zip(MONTHNAMES, range(1, 13)))
RATE_MODIFIER_MAP = {
's': lambda n: n,
'm': lambda n: n / 60.0,
@@ -200,7 +204,7 @@ def delta_resolution(dt: datetime, delta: timedelta) -> datetime:
def remaining(
start: datetime, ends_in: timedelta, now: Callable | None = None,
relative: bool = False) -> timedelta:
"""Calculate the remaining time for a start date and a timedelta.
"""Calculate the real remaining time for a start date and a timedelta.
For example, "how many seconds left for 30 seconds after start?"
@@ -211,24 +215,28 @@ def remaining(
using :func:`delta_resolution` (i.e., rounded to the
resolution of `ends_in`).
now (Callable): Function returning the current time and date.
Defaults to :func:`datetime.utcnow`.
Defaults to :func:`datetime.now(timezone.utc)`.
Returns:
~datetime.timedelta: Remaining time.
"""
now = now or datetime.utcnow()
if str(
start.tzinfo) == str(
now.tzinfo) and now.utcoffset() != start.utcoffset():
# DST started/ended
start = start.replace(tzinfo=now.tzinfo)
now = now or datetime.now(datetime_timezone.utc)
end_date = start + ends_in
if relative:
end_date = delta_resolution(end_date, ends_in).replace(microsecond=0)
ret = end_date - now
# Using UTC to calculate real time difference.
# Python by default uses wall time in arithmetic between datetimes with
# equal non-UTC timezones.
now_utc = now.astimezone(timezone.utc)
end_date_utc = end_date.astimezone(timezone.utc)
ret = end_date_utc - now_utc
if C_REMDEBUG: # pragma: no cover
print('rem: NOW:{!r} START:{!r} ENDS_IN:{!r} END_DATE:{} REM:{}'.format(
now, start, ends_in, end_date, ret))
print(
'rem: NOW:{!r} NOW_UTC:{!r} START:{!r} ENDS_IN:{!r} '
'END_DATE:{} END_DATE_UTC:{!r} REM:{}'.format(
now, now_utc, start, ends_in, end_date, end_date_utc, ret)
)
return ret
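
To see why the UTC conversion matters, here is a minimal sketch (not part of the commit) with two aware datetimes in the same zone that straddle a DST switch; the London dates are picked purely for illustration:

from datetime import datetime
from zoneinfo import ZoneInfo

london = ZoneInfo('Europe/London')
now = datetime(2024, 3, 30, 12, 0, tzinfo=london)       # GMT, UTC+0
end_date = datetime(2024, 3, 31, 12, 0, tzinfo=london)  # BST, UTC+1 (clocks went forward)

end_date - now
# -> timedelta(days=1): wall-clock arithmetic between equal tzinfos ignores the lost hour
end_date.astimezone(ZoneInfo('UTC')) - now.astimezone(ZoneInfo('UTC'))
# -> timedelta(hours=23): the real remaining time, which remaining() now returns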
@@ -257,6 +265,21 @@ def weekday(name: str) -> int:
raise KeyError(name)
def yearmonth(name: str) -> int:
"""Return the position of a month: 1 - 12, where 1 is January.
Example:
>>> yearmonth('january'), yearmonth('jan'), yearmonth('may')
(1, 1, 5)
"""
abbreviation = name[0:3].lower()
try:
return YEARMONTHS[abbreviation]
except KeyError:
# Show original day name in exception, instead of abbr.
raise KeyError(name)
def humanize_seconds(
secs: int, prefix: str = '', sep: str = '', now: str = 'now',
microseconds: bool = False) -> str:
@@ -288,7 +311,7 @@ def maybe_iso8601(dt: datetime | str | None) -> None | datetime:
return
if isinstance(dt, datetime):
return dt
return datetime.fromisoformat(dt)
return isoparse(dt)
def is_naive(dt: datetime) -> bool:
@@ -302,7 +325,7 @@ def _can_detect_ambiguous(tz: tzinfo) -> bool:
return isinstance(tz, ZoneInfo) or hasattr(tz, "is_ambiguous")
def _is_ambigious(dt: datetime, tz: tzinfo) -> bool:
def _is_ambiguous(dt: datetime, tz: tzinfo) -> bool:
"""Helper function to determine if a timezone is ambiguous using python's dateutil module.
Returns False if the timezone cannot detect ambiguity, or if there is no ambiguity, otherwise True.
@@ -319,7 +342,7 @@ def make_aware(dt: datetime, tz: tzinfo) -> datetime:
"""Set timezone for a :class:`~datetime.datetime` object."""
dt = dt.replace(tzinfo=tz)
if _is_ambigious(dt, tz):
if _is_ambiguous(dt, tz):
dt = min(dt.replace(fold=0), dt.replace(fold=1))
return dt

View File

@@ -10,6 +10,7 @@ import threading
from itertools import count
from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX
from time import sleep
from typing import Any, Callable, Iterator, Optional, Tuple
from kombu.asynchronous.timer import Entry
from kombu.asynchronous.timer import Timer as Schedule
@@ -30,20 +31,23 @@ class Timer(threading.Thread):
Entry = Entry
Schedule = Schedule
running = False
on_tick = None
running: bool = False
on_tick: Optional[Callable[[float], None]] = None
_timer_count = count(1)
_timer_count: count = count(1)
if TIMER_DEBUG: # pragma: no cover
def start(self, *args, **kwargs):
def start(self, *args: Any, **kwargs: Any) -> None:
import traceback
print('- Timer starting')
traceback.print_stack()
super().start(*args, **kwargs)
def __init__(self, schedule=None, on_error=None, on_tick=None,
on_start=None, max_interval=None, **kwargs):
def __init__(self, schedule: Optional[Schedule] = None,
on_error: Optional[Callable[[Exception], None]] = None,
on_tick: Optional[Callable[[float], None]] = None,
on_start: Optional[Callable[['Timer'], None]] = None,
max_interval: Optional[float] = None, **kwargs: Any) -> None:
self.schedule = schedule or self.Schedule(on_error=on_error,
max_interval=max_interval)
self.on_start = on_start
@@ -60,8 +64,10 @@ class Timer(threading.Thread):
self.daemon = True
self.name = f'Timer-{next(self._timer_count)}'
def _next_entry(self):
def _next_entry(self) -> Optional[float]:
with self.not_empty:
delay: Optional[float]
entry: Optional[Entry]
delay, entry = next(self.scheduler)
if entry is None:
if delay is None:
@@ -70,10 +76,10 @@ class Timer(threading.Thread):
return self.schedule.apply_entry(entry)
__next__ = next = _next_entry # for 2to3
def run(self):
def run(self) -> None:
try:
self.running = True
self.scheduler = iter(self.schedule)
self.scheduler: Iterator[Tuple[Optional[float], Optional[Entry]]] = iter(self.schedule)
while not self.__is_shutdown.is_set():
delay = self._next_entry()
@@ -94,61 +100,61 @@ class Timer(threading.Thread):
sys.stderr.flush()
os._exit(1)
def stop(self):
def stop(self) -> None:
self.__is_shutdown.set()
if self.running:
self.__is_stopped.wait()
self.join(THREAD_TIMEOUT_MAX)
self.running = False
def ensure_started(self):
def ensure_started(self) -> None:
if not self.running and not self.is_alive():
if self.on_start:
self.on_start(self)
self.start()
def _do_enter(self, meth, *args, **kwargs):
def _do_enter(self, meth: str, *args: Any, **kwargs: Any) -> Entry:
self.ensure_started()
with self.mutex:
entry = getattr(self.schedule, meth)(*args, **kwargs)
self.not_empty.notify()
return entry
def enter(self, entry, eta, priority=None):
def enter(self, entry: Entry, eta: float, priority: Optional[int] = None) -> Entry:
return self._do_enter('enter_at', entry, eta, priority=priority)
def call_at(self, *args, **kwargs):
def call_at(self, *args: Any, **kwargs: Any) -> Entry:
return self._do_enter('call_at', *args, **kwargs)
def enter_after(self, *args, **kwargs):
def enter_after(self, *args: Any, **kwargs: Any) -> Entry:
return self._do_enter('enter_after', *args, **kwargs)
def call_after(self, *args, **kwargs):
def call_after(self, *args: Any, **kwargs: Any) -> Entry:
return self._do_enter('call_after', *args, **kwargs)
def call_repeatedly(self, *args, **kwargs):
def call_repeatedly(self, *args: Any, **kwargs: Any) -> Entry:
return self._do_enter('call_repeatedly', *args, **kwargs)
def exit_after(self, secs, priority=10):
def exit_after(self, secs: float, priority: int = 10) -> None:
self.call_after(secs, sys.exit, priority)
def cancel(self, tref):
def cancel(self, tref: Entry) -> None:
tref.cancel()
def clear(self):
def clear(self) -> None:
self.schedule.clear()
def empty(self):
def empty(self) -> bool:
return not len(self)
def __len__(self):
def __len__(self) -> int:
return len(self.schedule)
def __bool__(self):
def __bool__(self) -> bool:
"""``bool(timer)``."""
return True
__nonzero__ = __bool__
@property
def queue(self):
def queue(self) -> list:
return self.schedule.queue

View File

@@ -169,6 +169,7 @@ class Consumer:
'celery.worker.consumer.heart:Heart',
'celery.worker.consumer.control:Control',
'celery.worker.consumer.tasks:Tasks',
'celery.worker.consumer.delayed_delivery:DelayedDelivery',
'celery.worker.consumer.consumer:Evloop',
'celery.worker.consumer.agent:Agent',
]
@@ -390,20 +391,21 @@ class Consumer:
else:
warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)
self.initial_prefetch_count = max(
self.prefetch_multiplier,
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
)
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
if not self._maximum_prefetch_restored:
logger.info(
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching "
f"since {len(tuple(active_requests))} tasks are currently being processed.\n"
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
"complete processing."
if self.app.conf.worker_enable_prefetch_count_reduction:
self.initial_prefetch_count = max(
self.prefetch_multiplier,
self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
)
self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
if not self._maximum_prefetch_restored:
logger.info(
f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid "
f"over-fetching since {len(tuple(active_requests))} tasks are currently being processed.\n"
f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
"complete processing."
)
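The temporary reduction above is now gated by a worker setting; a minimal sketch of opting out (the app definition is illustrative, the setting name is taken from this hunk):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')
# Keep the prefetch count static instead of temporarily reducing it
# while previously received tasks are still being processed.
app.conf.worker_enable_prefetch_count_reduction = False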
def register_with_event_loop(self, hub):
self.blueprint.send_all(
self, 'register_with_event_loop', args=(hub,),
@@ -411,6 +413,7 @@ class Consumer:
)
def shutdown(self):
self.perform_pending_operations()
self.blueprint.shutdown(self)
def stop(self):
@@ -475,9 +478,9 @@ class Consumer:
return self.ensure_connected(
self.app.connection_for_read(heartbeat=heartbeat))
def connection_for_write(self, heartbeat=None):
def connection_for_write(self, url=None, heartbeat=None):
return self.ensure_connected(
self.app.connection_for_write(heartbeat=heartbeat))
self.app.connection_for_write(url=url, heartbeat=heartbeat))
def ensure_connected(self, conn):
# Callback called for each retry while the connection
@@ -504,13 +507,14 @@ class Consumer:
# to determine whether connection retries are disabled.
retry_disabled = not self.app.conf.broker_connection_retry
warnings.warn(
CPendingDeprecationWarning(
f"The broker_connection_retry configuration setting will no longer determine\n"
f"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
f"If you wish to retain the existing behavior for retrying connections on startup,\n"
f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.")
)
if retry_disabled:
warnings.warn(
CPendingDeprecationWarning(
"The broker_connection_retry configuration setting will no longer determine\n"
"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
"If you wish to refrain from retrying connections on startup,\n"
"you should set broker_connection_retry_on_startup to False instead.")
)
else:
if self.first_connection_attempt:
retry_disabled = not self.app.conf.broker_connection_retry_on_startup
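Both retry settings referenced in the warnings above can be set explicitly; a small sketch (setting names from this hunk, values illustrative):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')
app.conf.broker_connection_retry = True             # retry when an established connection is lost
app.conf.broker_connection_retry_on_startup = True  # also retry the very first connection attempt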
@@ -696,7 +700,10 @@ class Consumer:
def _restore_prefetch_count_after_connection_restart(self, p, *args):
with self.qos._mutex:
if self._maximum_prefetch_restored:
if any((
not self.app.conf.worker_enable_prefetch_count_reduction,
self._maximum_prefetch_restored,
)):
return
new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
@@ -726,6 +733,29 @@ class Consumer:
self=self, state=self.blueprint.human_state(),
)
def cancel_all_unacked_requests(self):
"""Cancel all active requests that either do not require late acknowledgments or,
if they do, have not been acknowledged yet.
"""
def should_cancel(request):
if not request.task.acks_late:
# Task does not require late acknowledgment, cancel it.
return True
if not request.acknowledged:
# Task is late acknowledged, but it has not been acknowledged yet, cancel it.
return True
# Task is late acknowledged, but it has already been acknowledged.
return False # Do not cancel and allow it to gracefully finish as it has already been acknowledged.
requests_to_cancel = tuple(filter(should_cancel, active_requests))
if requests_to_cancel:
for request in requests_to_cancel:
request.cancel(self.pool)
class Evloop(bootsteps.StartStopStep):
"""Event loop service.

View File

@@ -0,0 +1,247 @@
"""Native delayed delivery functionality for Celery workers.
This module provides the DelayedDelivery bootstep which handles setup and configuration
of native delayed delivery functionality when using quorum queues.
"""
from typing import Iterator, List, Optional, Set, Union, ValuesView
from kombu import Connection, Queue
from kombu.transport.native_delayed_delivery import (bind_queue_to_native_delayed_delivery_exchange,
declare_native_delayed_delivery_exchanges_and_queues)
from kombu.utils.functional import retry_over_time
from celery import Celery, bootsteps
from celery.utils.log import get_logger
from celery.utils.quorum_queues import detect_quorum_queues
from celery.worker.consumer import Consumer, Tasks
__all__ = ('DelayedDelivery',)
logger = get_logger(__name__)
# Default retry settings
RETRY_INTERVAL = 1.0 # seconds between retries
MAX_RETRIES = 3 # maximum number of retries
# Valid queue types for delayed delivery
VALID_QUEUE_TYPES = {'classic', 'quorum'}
class DelayedDelivery(bootsteps.StartStopStep):
"""Bootstep that sets up native delayed delivery functionality.
This component handles the setup and configuration of native delayed delivery
for Celery workers. It is automatically included when quorum queues are
detected in the application configuration.
Responsibilities:
- Declaring native delayed delivery exchanges and queues
- Binding all application queues to the delayed delivery exchanges
- Handling connection failures gracefully with retries
- Validating configuration settings
"""
requires = (Tasks,)
def include_if(self, c: Consumer) -> bool:
"""Determine if this bootstep should be included.
Args:
c: The Celery consumer instance
Returns:
bool: True if quorum queues are detected, False otherwise
"""
return detect_quorum_queues(c.app, c.app.connection_for_write().transport.driver_type)[0]
def start(self, c: Consumer) -> None:
"""Initialize delayed delivery for all broker URLs.
Attempts to set up delayed delivery for each broker URL in the configuration.
Failures are logged but don't prevent attempting remaining URLs.
Args:
c: The Celery consumer instance
Raises:
ValueError: If configuration validation fails
"""
app: Celery = c.app
try:
self._validate_configuration(app)
except ValueError as e:
logger.critical("Configuration validation failed: %s", str(e))
raise
broker_urls = self._validate_broker_urls(app.conf.broker_url)
setup_errors = []
for broker_url in broker_urls:
try:
retry_over_time(
self._setup_delayed_delivery,
args=(c, broker_url),
catch=(ConnectionRefusedError, OSError),
errback=self._on_retry,
interval_start=RETRY_INTERVAL,
max_retries=MAX_RETRIES,
)
except Exception as e:
logger.warning(
"Failed to setup delayed delivery for %r: %s",
broker_url, str(e)
)
setup_errors.append((broker_url, e))
if len(setup_errors) == len(broker_urls):
logger.critical(
"Failed to setup delayed delivery for all broker URLs. "
"Native delayed delivery will not be available."
)
def _setup_delayed_delivery(self, c: Consumer, broker_url: str) -> None:
"""Set up delayed delivery for a specific broker URL.
Args:
c: The Celery consumer instance
broker_url: The broker URL to configure
Raises:
ConnectionRefusedError: If connection to the broker fails
OSError: If there are network-related issues
Exception: For other unexpected errors during setup
"""
connection: Connection = c.app.connection_for_write(url=broker_url)
queue_type = c.app.conf.broker_native_delayed_delivery_queue_type
logger.debug(
"Setting up delayed delivery for broker %r with queue type %r",
broker_url, queue_type
)
try:
declare_native_delayed_delivery_exchanges_and_queues(
connection,
queue_type
)
except Exception as e:
logger.warning(
"Failed to declare exchanges and queues for %r: %s",
broker_url, str(e)
)
raise
try:
self._bind_queues(c.app, connection)
except Exception as e:
logger.warning(
"Failed to bind queues for %r: %s",
broker_url, str(e)
)
raise
def _bind_queues(self, app: Celery, connection: Connection) -> None:
"""Bind all application queues to delayed delivery exchanges.
Args:
app: The Celery application instance
connection: The broker connection to use
Raises:
Exception: If queue binding fails
"""
queues: ValuesView[Queue] = app.amqp.queues.values()
if not queues:
logger.warning("No queues found to bind for delayed delivery")
return
for queue in queues:
try:
logger.debug("Binding queue %r to delayed delivery exchange", queue.name)
bind_queue_to_native_delayed_delivery_exchange(connection, queue)
except Exception as e:
logger.error(
"Failed to bind queue %r: %s",
queue.name, str(e)
)
raise
def _on_retry(self, exc: Exception, interval_range: Iterator[float], intervals_count: int) -> None:
"""Callback for retry attempts.
Args:
exc: The exception that triggered the retry
interval_range: An iterator which returns the time in seconds to sleep next
intervals_count: Number of retry attempts so far
"""
logger.warning(
"Retrying delayed delivery setup (attempt %d/%d) after error: %s",
intervals_count + 1, MAX_RETRIES, str(exc)
)
def _validate_configuration(self, app: Celery) -> None:
"""Validate all required configuration settings.
Args:
app: The Celery application instance
Raises:
ValueError: If any configuration is invalid
"""
# Validate broker URLs
self._validate_broker_urls(app.conf.broker_url)
# Validate queue type
self._validate_queue_type(app.conf.broker_native_delayed_delivery_queue_type)
def _validate_broker_urls(self, broker_urls: Union[str, List[str]]) -> Set[str]:
"""Validate and split broker URLs.
Args:
broker_urls: Broker URLs, either as a semicolon-separated string
or as a list of strings
Returns:
Set of valid broker URLs
Raises:
ValueError: If no valid broker URLs are found or if invalid URLs are provided
"""
if not broker_urls:
raise ValueError("broker_url configuration is empty")
if isinstance(broker_urls, str):
brokers = broker_urls.split(";")
elif isinstance(broker_urls, list):
if not all(isinstance(url, str) for url in broker_urls):
raise ValueError("All broker URLs must be strings")
brokers = broker_urls
else:
raise ValueError(f"broker_url must be a string or list, got {broker_urls!r}")
valid_urls = {url for url in brokers}
if not valid_urls:
raise ValueError("No valid broker URLs found in configuration")
return valid_urls
def _validate_queue_type(self, queue_type: Optional[str]) -> None:
"""Validate the queue type configuration.
Args:
queue_type: The configured queue type
Raises:
ValueError: If queue type is invalid
"""
if not queue_type:
raise ValueError("broker_native_delayed_delivery_queue_type is not configured")
if queue_type not in VALID_QUEUE_TYPES:
sorted_types = sorted(VALID_QUEUE_TYPES)
raise ValueError(
f"Invalid queue type {queue_type!r}. Must be one of: {', '.join(sorted_types)}"
)
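A minimal configuration sketch that would pass the validation above (broker URLs are illustrative; the queue-type setting name and the semicolon-separated URL format are taken from this file):
from celery import Celery

# Multiple broker URLs may be supplied as a semicolon-separated string,
# matching what _validate_broker_urls() expects.
app = Celery('proj', broker='amqp://rabbit1:5672//;amqp://rabbit2:5672//')
# Must be one of VALID_QUEUE_TYPES ('classic' or 'quorum').
app.conf.broker_native_delayed_delivery_queue_type = 'quorum'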

View File

@@ -176,6 +176,7 @@ class Gossip(bootsteps.ConsumerStep):
channel,
queues=[ev.queue],
on_message=partial(self.on_message, ev.event_from_message),
accept=ev.accept,
no_ack=True
)]

View File

@@ -22,7 +22,7 @@ class Mingle(bootsteps.StartStopStep):
label = 'Mingle'
requires = (Events,)
compatible_transports = {'amqp', 'redis'}
compatible_transports = {'amqp', 'redis', 'gcpubsub'}
def __init__(self, c, without_mingle=False, **kwargs):
self.enabled = not without_mingle and self.compatible_transport(c.app)

View File

@@ -1,13 +1,18 @@
"""Worker Task Consumer Bootstep."""
from __future__ import annotations
from kombu.common import QoS, ignore_errors
from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.quorum_queues import detect_quorum_queues
from .mingle import Mingle
__all__ = ('Tasks',)
logger = get_logger(__name__)
debug = logger.debug
@@ -25,10 +30,7 @@ class Tasks(bootsteps.StartStopStep):
"""Start task consumer."""
c.update_strategies()
# - RabbitMQ 3.3 completely redefines how basic_qos works...
# This will detect if the new qos semantics is in effect,
# and if so make sure the 'apply_global' flag is set on qos updates.
qos_global = not c.connection.qos_semantics_matches_spec
qos_global = self.qos_global(c)
# set initial prefetch count
c.connection.default_channel.basic_qos(
@@ -63,3 +65,24 @@ class Tasks(bootsteps.StartStopStep):
def info(self, c):
"""Return task consumer info."""
return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
def qos_global(self, c) -> bool:
"""Determine if global QoS should be applied.
Additional information:
https://www.rabbitmq.com/docs/consumer-prefetch
https://www.rabbitmq.com/docs/quorum-queues#global-qos
"""
# - RabbitMQ 3.3 completely redefines how basic_qos works...
# This will detect if the new qos semantics is in effect,
# and if so make sure the 'apply_global' flag is set on qos updates.
qos_global = not c.connection.qos_semantics_matches_spec
if c.app.conf.worker_detect_quorum_queues:
using_quorum_queues, qname = detect_quorum_queues(c.app, c.connection.transport.driver_type)
if using_quorum_queues:
qos_global = False
logger.info("Global QoS is disabled. Prefetch count in now static.")
return qos_global
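The quorum-queue detection used here is controlled by a worker setting; a minimal sketch (setting name from this hunk, app definition illustrative):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')
# When quorum queues are detected, qos_global() above returns False and the
# prefetch count is applied per consumer instead of per channel.
app.conf.worker_detect_quorum_queues = True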

View File

@@ -7,6 +7,7 @@ from billiard.common import TERM_SIGNAME
from kombu.utils.encoding import safe_repr
from celery.exceptions import WorkerShutdown
from celery.platforms import EX_OK
from celery.platforms import signals as _signals
from celery.utils.functional import maybe_list
from celery.utils.log import get_logger
@@ -580,7 +581,7 @@ def autoscale(state, max=None, min=None):
def shutdown(state, msg='Got shutdown from remote', **kwargs):
"""Shutdown worker(s)."""
logger.warning(msg)
raise WorkerShutdown(msg)
raise WorkerShutdown(EX_OK)
# -- Queues
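For context, a sketch of invoking this remote-control command from a client (standard broadcast API; broker URL illustrative):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')
# Broadcasts 'shutdown' to all workers; each worker now raises
# WorkerShutdown(EX_OK) in the handler above and exits cleanly.
app.control.shutdown()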

View File

@@ -119,8 +119,10 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
obj.on_ready()
while blueprint.state == RUN and obj.connection:
state.maybe_shutdown()
def _loop_cycle():
"""
Perform one iteration of the blocking event loop.
"""
if heartbeat_error[0] is not None:
raise heartbeat_error[0]
if qos.prev != qos.value:
@@ -133,3 +135,9 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
except OSError:
if blueprint.state == RUN:
raise
while blueprint.state == RUN and obj.connection:
try:
state.maybe_shutdown()
finally:
_loop_cycle()

View File

@@ -602,8 +602,8 @@ class Request:
is_worker_lost = isinstance(exc, WorkerLostError)
if self.task.acks_late:
reject = (
self.task.reject_on_worker_lost and
is_worker_lost
(self.task.reject_on_worker_lost and is_worker_lost)
or (isinstance(exc, TimeLimitExceeded) and not self.task.acks_on_failure_or_timeout)
)
ack = self.task.acks_on_failure_or_timeout
if reject:
@@ -777,7 +777,7 @@ def create_request_cls(base, task, pool, hostname, eventer,
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise exc
return self.on_failure(retval, return_ok=True)
task_ready(self)
task_ready(self, successful=True)
if acks_late:
self.acknowledge()
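A sketch of a task definition that would exercise the reject branch above (standard task options; the time limit value is illustrative):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')

@app.task(acks_late=True, acks_on_failure_or_timeout=False, time_limit=30)
def long_running_job():
    # Exceeding the time limit raises TimeLimitExceeded in the worker; with
    # these options the request is rejected (and may be redelivered) instead
    # of being acknowledged, per the condition above.
    ...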

View File

@@ -14,7 +14,8 @@ The worker consists of several components, all managed by bootsteps
import os
import sys
from datetime import datetime
from datetime import datetime, timezone
from time import sleep
from billiard import cpu_count
from kombu.utils.compat import detect_environment
@@ -89,7 +90,7 @@ class WorkController:
def __init__(self, app=None, hostname=None, **kwargs):
self.app = app or self.app
self.hostname = default_nodename(hostname)
self.startup_time = datetime.utcnow()
self.startup_time = datetime.now(timezone.utc)
self.app.loader.init_worker()
self.on_before_init(**kwargs)
self.setup_defaults(**kwargs)
@@ -241,7 +242,7 @@ class WorkController:
not self.app.IS_WINDOWS)
def stop(self, in_sighandler=False, exitcode=None):
"""Graceful shutdown of the worker server."""
"""Graceful shutdown of the worker server (Warm shutdown)."""
if exitcode is not None:
self.exitcode = exitcode
if self.blueprint.state == RUN:
@@ -251,7 +252,7 @@ class WorkController:
self._send_worker_shutdown()
def terminate(self, in_sighandler=False):
"""Not so graceful shutdown of the worker server."""
"""Not so graceful shutdown of the worker server (Cold shutdown)."""
if self.blueprint.state != TERMINATE:
self.signal_consumer_close()
if not in_sighandler or self.pool.signal_safe:
@@ -293,7 +294,7 @@ class WorkController:
return reload_from_cwd(sys.modules[module], reloader)
def info(self):
uptime = datetime.utcnow() - self.startup_time
uptime = datetime.now(timezone.utc) - self.startup_time
return {'total': self.state.total_count,
'pid': os.getpid(),
'clock': str(self.app.clock),
@@ -407,3 +408,28 @@ class WorkController:
'worker_disable_rate_limits', disable_rate_limits,
)
self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait)
def wait_for_soft_shutdown(self):
"""Wait :setting:`worker_soft_shutdown_timeout` if soft shutdown is enabled.
To enable soft shutdown, set the :setting:`worker_soft_shutdown_timeout` in the
configuration. Soft shutdown can be used to allow the worker to finish processing
a few more tasks before initiating a cold shutdown. This mechanism allows the worker
to finish short tasks that are already in progress and requeue long-running tasks
to be picked up by another worker.
.. warning::
If there are no tasks in the worker, the worker will not wait for the
soft shutdown timeout even if it is set as it makes no sense to wait for
the timeout when there are no tasks to process.
"""
app = self.app
requests = tuple(state.active_requests)
if app.conf.worker_enable_soft_shutdown_on_idle:
requests = True
if app.conf.worker_soft_shutdown_timeout > 0 and requests:
log = f"Initiating Soft Shutdown, terminating in {app.conf.worker_soft_shutdown_timeout} seconds"
logger.warning(log)
sleep(app.conf.worker_soft_shutdown_timeout)
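A minimal sketch of enabling the soft shutdown window described in this docstring (setting names from this hunk; values illustrative):
from celery import Celery

app = Celery('proj', broker='amqp://localhost//')
app.conf.worker_soft_shutdown_timeout = 10            # seconds to wait before the cold shutdown proceeds
app.conf.worker_enable_soft_shutdown_on_idle = True   # wait even when no tasks are currently active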