Major fixes and new features

2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

@@ -0,0 +1,172 @@
"""Distributed Task Queue."""
# :copyright: (c) 2017-2026 Asif Saif Uddin, celery core and individual
# contributors, All rights reserved.
# :copyright: (c) 2015-2016 Ask Solem. All rights reserved.
# :copyright: (c) 2012-2014 GoPivotal, Inc., All rights reserved.
# :copyright: (c) 2009 - 2012 Ask Solem and individual contributors,
# All rights reserved.
# :license: BSD (3 Clause), see LICENSE for more details.
import os
import re
import sys
from collections import namedtuple
# Lazy loading
from . import local
SERIES = 'emerald-rush'
__version__ = '5.3.4'
__author__ = 'Ask Solem'
__contact__ = 'auvipy@gmail.com'
__homepage__ = 'https://docs.celeryq.dev/'
__docformat__ = 'restructuredtext'
__keywords__ = 'task job queue distributed messaging actor'
# -eof meta-
__all__ = (
'Celery', 'bugreport', 'shared_task', 'Task',
'current_app', 'current_task', 'maybe_signature',
'chain', 'chord', 'chunks', 'group', 'signature',
'xmap', 'xstarmap', 'uuid',
)
VERSION_BANNER = f'{__version__} ({SERIES})'
version_info_t = namedtuple('version_info_t', (
'major', 'minor', 'micro', 'releaselevel', 'serial',
))
# bumpversion can only search for {current_version}
# so we have to parse the version here.
_temp = re.match(
    r'(\d+)\.(\d+)\.(\d+)(.+)?', __version__).groups()
VERSION = version_info = version_info_t(
int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '')
del _temp
del re
if os.environ.get('C_IMPDEBUG'): # pragma: no cover
import builtins
def debug_import(name, locals=None, globals=None,
fromlist=None, level=-1, real_import=builtins.__import__):
glob = globals or getattr(sys, 'emarfteg_'[::-1])(1).f_globals
importer_name = glob and glob.get('__name__') or 'unknown'
print(f'-- {importer_name} imports {name}')
return real_import(name, locals, globals, fromlist, level)
builtins.__import__ = debug_import
# This is never executed, but tricks static analyzers (PyDev, PyCharm,
# pylint, etc.) into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from celery._state import current_app, current_task
from celery.app import shared_task
from celery.app.base import Celery
from celery.app.task import Task
from celery.app.utils import bugreport
from celery.canvas import (chain, chord, chunks, group, maybe_signature, signature, subtask, xmap, # noqa
xstarmap)
from celery.utils import uuid
# Eventlet/gevent patching must happen before importing
# anything else, so these tools must be at top-level.
def _find_option_with_arg(argv, short_opts=None, long_opts=None):
"""Search argv for options specifying short and longopt alternatives.
Returns:
str: value for option found
Raises:
KeyError: if option not found.
"""
for i, arg in enumerate(argv):
if arg.startswith('-'):
if long_opts and arg.startswith('--'):
name, sep, val = arg.partition('=')
if name in long_opts:
return val if sep else argv[i + 1]
if short_opts and arg in short_opts:
return argv[i + 1]
    raise KeyError('|'.join((short_opts or []) + (long_opts or [])))
def _patch_eventlet():
import eventlet.debug
eventlet.monkey_patch()
blockdetect = float(os.environ.get('EVENTLET_NOBLOCK', 0))
if blockdetect:
eventlet.debug.hub_blocking_detection(blockdetect, blockdetect)
def _patch_gevent():
import gevent.monkey
import gevent.signal
gevent.monkey.patch_all()
def maybe_patch_concurrency(argv=None, short_opts=None,
long_opts=None, patches=None):
"""Apply eventlet/gevent monkeypatches.
    Given the short and long option alternatives that specify the
    command-line option used to set the pool, this makes sure that
    anything that needs to be patched (e.g., the eventlet/gevent
    monkey patches) is applied as early as possible.
"""
argv = argv if argv else sys.argv
short_opts = short_opts if short_opts else ['-P']
long_opts = long_opts if long_opts else ['--pool']
patches = patches if patches else {'eventlet': _patch_eventlet,
'gevent': _patch_gevent}
try:
pool = _find_option_with_arg(argv, short_opts, long_opts)
except KeyError:
pass
else:
try:
patcher = patches[pool]
except KeyError:
pass
else:
patcher()
# set up eventlet/gevent environments ASAP
from celery import concurrency
if pool in concurrency.get_available_pool_names():
concurrency.get_implementation(pool)
# this just creates a new module, that imports stuff on first attribute
# access. This makes the library faster to use.
old_module, new_module = local.recreate_module( # pragma: no cover
__name__,
by_module={
'celery.app': ['Celery', 'bugreport', 'shared_task'],
'celery.app.task': ['Task'],
'celery._state': ['current_app', 'current_task'],
'celery.canvas': [
'Signature', 'chain', 'chord', 'chunks', 'group',
'signature', 'maybe_signature', 'subtask',
'xmap', 'xstarmap',
],
'celery.utils': ['uuid'],
},
__package__='celery', __file__=__file__,
__path__=__path__, __doc__=__doc__, __version__=__version__,
__author__=__author__, __contact__=__contact__,
__homepage__=__homepage__, __docformat__=__docformat__, local=local,
VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER,
version_info_t=version_info_t,
version_info=version_info,
maybe_patch_concurrency=maybe_patch_concurrency,
_find_option_with_arg=_find_option_with_arg,
)
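
A minimal usage sketch of the early-patching helpers above (the argv values are hypothetical; eventlet must be installed for the patch branch to actually run):

from celery import _find_option_with_arg, maybe_patch_concurrency

# Resolve the pool name the same way maybe_patch_concurrency does internally.
argv = ['celery', 'worker', '--pool=eventlet']
print(_find_option_with_arg(argv, short_opts=['-P'], long_opts=['--pool']))  # 'eventlet'

# Apply the matching monkey patch; a no-op when no -P/--pool option is present.
maybe_patch_concurrency(argv)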

@@ -0,0 +1,19 @@
"""Entry-point for the :program:`celery` umbrella command."""
import sys
from . import maybe_patch_concurrency
__all__ = ('main',)
def main() -> None:
"""Entrypoint to the ``celery`` umbrella command."""
if 'multi' not in sys.argv:
maybe_patch_concurrency()
from celery.bin.celery import main as _main
sys.exit(_main())
if __name__ == '__main__': # pragma: no cover
main()
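
For illustration, a hedged sketch of invoking this entry point programmatically, equivalent to running ``python -m celery --version`` on the command line:

import sys
from celery.__main__ import main

sys.argv = ['celery', '--version']
main()   # prints the version banner, then exits via SystemExit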

@@ -0,0 +1,197 @@
"""Internal state.
This is an internal module containing thread state
like the ``current_app``, and ``current_task``.
This module shouldn't be used directly.
"""
import os
import sys
import threading
import weakref
from celery.local import Proxy
from celery.utils.threads import LocalStack
__all__ = (
'set_default_app', 'get_current_app', 'get_current_task',
'get_current_worker_task', 'current_app', 'current_task',
'connect_on_app_finalize',
)
#: Global default app used when no current app.
default_app = None
#: Function returning the app provided or the default app if none.
#:
#: The environment variable :envvar:`CELERY_TRACE_APP` is used to
#: trace app leaks. When enabled an exception is raised if there
#: is no active app.
app_or_default = None
#: List of all app instances (weakrefs), mustn't be used directly.
_apps = weakref.WeakSet()
#: Global set of functions to call whenever a new app is finalized.
#: Shared tasks, and built-in tasks are created by adding callbacks here.
_on_app_finalizers = set()
_task_join_will_block = False
def connect_on_app_finalize(callback):
"""Connect callback to be called when any app is finalized."""
_on_app_finalizers.add(callback)
return callback
def _announce_app_finalized(app):
callbacks = set(_on_app_finalizers)
for callback in callbacks:
callback(app)
def _set_task_join_will_block(blocks):
global _task_join_will_block
_task_join_will_block = blocks
def task_join_will_block():
return _task_join_will_block
class _TLS(threading.local):
#: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
    #: set this, so it will always contain the last instantiated app,
#: and is the default app returned by :func:`app_or_default`.
current_app = None
_tls = _TLS()
_task_stack = LocalStack()
#: Function used to push a task to the thread local stack
#: keeping track of the currently executing task.
#: You must remember to pop the task after.
push_current_task = _task_stack.push
#: Function used to pop a task from the thread local stack
#: keeping track of the currently executing task.
pop_current_task = _task_stack.pop
def set_default_app(app):
"""Set default app."""
global default_app
default_app = app
def _get_current_app():
if default_app is None:
#: creates the global fallback app instance.
from celery.app.base import Celery
set_default_app(Celery(
'default', fixups=[], set_as_current=False,
loader=os.environ.get('CELERY_LOADER') or 'default',
))
return _tls.current_app or default_app
def _set_current_app(app):
_tls.current_app = app
if os.environ.get('C_STRICT_APP'): # pragma: no cover
def get_current_app():
"""Return the current app."""
raise RuntimeError('USES CURRENT APP')
elif os.environ.get('C_WARN_APP'): # pragma: no cover
def get_current_app():
import traceback
print('-- USES CURRENT_APP', file=sys.stderr) # +
traceback.print_stack(file=sys.stderr)
return _get_current_app()
else:
get_current_app = _get_current_app
def get_current_task():
"""Currently executing task."""
return _task_stack.top
def get_current_worker_task():
"""Currently executing task, that was applied by the worker.
This is used to differentiate between the actual task
executed by the worker and any task that was called within
a task (using ``task.__call__`` or ``task.apply``)
"""
for task in reversed(_task_stack.stack):
if not task.request.called_directly:
return task
#: Proxy to current app.
current_app = Proxy(get_current_app)
#: Proxy to current task.
current_task = Proxy(get_current_task)
def _register_app(app):
_apps.add(app)
def _deregister_app(app):
_apps.discard(app)
def _get_active_apps():
return _apps
def _app_or_default(app=None):
if app is None:
return get_current_app()
return app
def _app_or_default_trace(app=None): # pragma: no cover
from traceback import print_stack
try:
from billiard.process import current_process
except ImportError:
current_process = None
if app is None:
if getattr(_tls, 'current_app', None):
print('-- RETURNING TO CURRENT APP --') # +
print_stack()
return _tls.current_app
if not current_process or current_process()._name == 'MainProcess':
raise Exception('DEFAULT APP')
print('-- RETURNING TO DEFAULT APP --') # +
print_stack()
return default_app
return app
def enable_trace():
"""Enable tracing of app instances."""
global app_or_default
app_or_default = _app_or_default_trace
def disable_trace():
"""Disable tracing of app instances."""
global app_or_default
app_or_default = _app_or_default
if os.environ.get('CELERY_TRACE_APP'): # pragma: no cover
enable_trace()
else:
disable_trace()
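
A small sketch of the thread-local state defined above (the app name is illustrative; a newly created app is installed as the current app by default):

from celery import Celery, current_app
from celery._state import get_current_task

app = Celery('demo')                              # set_as_current=True by default
assert current_app._get_current_object() is app   # proxy resolves to the new app
assert get_current_task() is None                 # nothing on the task stack yet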

@@ -0,0 +1,76 @@
"""Celery Application."""
from celery import _state
from celery._state import app_or_default, disable_trace, enable_trace, pop_current_task, push_current_task
from celery.local import Proxy
from .base import Celery
from .utils import AppPickler
__all__ = (
'Celery', 'AppPickler', 'app_or_default', 'default_app',
'bugreport', 'enable_trace', 'disable_trace', 'shared_task',
'push_current_task', 'pop_current_task',
)
#: Proxy always returning the app set as default.
default_app = Proxy(lambda: _state.default_app)
def bugreport(app=None):
"""Return information useful in bug reports."""
return (app or _state.get_current_app()).bugreport()
def shared_task(*args, **kwargs):
"""Create shared task (decorator).
This can be used by library authors to create tasks that'll work
for any app environment.
Returns:
~celery.local.Proxy: A proxy that always takes the task from the
        current app's task registry.
Example:
>>> from celery import Celery, shared_task
>>> @shared_task
... def add(x, y):
... return x + y
...
>>> app1 = Celery(broker='amqp://')
>>> add.app is app1
True
>>> app2 = Celery(broker='redis://')
>>> add.app is app2
True
"""
def create_shared_task(**options):
def __inner(fun):
name = options.get('name')
# Set as shared task so that unfinalized apps,
# and future apps will register a copy of this task.
_state.connect_on_app_finalize(
lambda app: app._task_from_fun(fun, **options)
)
# Force all finalized apps to take this task as well.
for app in _state._get_active_apps():
if app.finalized:
with app._finalize_mutex:
app._task_from_fun(fun, **options)
# Return a proxy that always gets the task from the current
            # app's task registry.
def task_by_cons():
app = _state.get_current_app()
return app.tasks[
name or app.gen_task_name(fun.__name__, fun.__module__)
]
return Proxy(task_by_cons)
return __inner
if len(args) == 1 and callable(args[0]):
return create_shared_task(**kwargs)(args[0])
return create_shared_task(*args, **kwargs)
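
A brief sketch of the module-level ``bugreport`` helper falling back to the current app (the app name is illustrative):

from celery import Celery
from celery.app import bugreport

app = Celery('demo')    # becomes the current app
print(bugreport())      # same as app.bugreport(); software/platform/settings info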

@@ -0,0 +1,614 @@
"""Sending/Receiving Messages (Kombu integration)."""
import numbers
from collections import namedtuple
from collections.abc import Mapping
from datetime import timedelta
from weakref import WeakValueDictionary
from kombu import Connection, Consumer, Exchange, Producer, Queue, pools
from kombu.common import Broadcast
from kombu.utils.functional import maybe_list
from kombu.utils.objects import cached_property
from celery import signals
from celery.utils.nodenames import anon_nodename
from celery.utils.saferepr import saferepr
from celery.utils.text import indent as textindent
from celery.utils.time import maybe_make_aware
from . import routes as _routes
__all__ = ('AMQP', 'Queues', 'task_message')
#: earliest date supported by time.mktime.
INT_MIN = -2147483648
#: Human readable queue declaration.
QUEUE_FORMAT = """
.> {0.name:<16} exchange={0.exchange.name}({0.exchange.type}) \
key={0.routing_key}
"""
task_message = namedtuple('task_message',
('headers', 'properties', 'body', 'sent_event'))
def utf8dict(d, encoding='utf-8'):
return {k.decode(encoding) if isinstance(k, bytes) else k: v
for k, v in d.items()}
class Queues(dict):
"""Queue name⇒ declaration mapping.
Arguments:
queues (Iterable): Initial list/tuple or dict of queues.
create_missing (bool): By default any unknown queues will be
added automatically, but if this flag is disabled the occurrence
of unknown queues in `wanted` will raise :exc:`KeyError`.
max_priority (int): Default x-max-priority for queues with none set.
"""
#: If set, this is a subset of queues to consume from.
#: The rest of the queues are then used for routing only.
_consume_from = None
def __init__(self, queues=None, default_exchange=None,
create_missing=True, autoexchange=None,
max_priority=None, default_routing_key=None):
super().__init__()
self.aliases = WeakValueDictionary()
self.default_exchange = default_exchange
self.default_routing_key = default_routing_key
self.create_missing = create_missing
self.autoexchange = Exchange if autoexchange is None else autoexchange
self.max_priority = max_priority
if queues is not None and not isinstance(queues, Mapping):
queues = {q.name: q for q in queues}
queues = queues or {}
for name, q in queues.items():
self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)
def __getitem__(self, name):
try:
return self.aliases[name]
except KeyError:
return super().__getitem__(name)
def __setitem__(self, name, queue):
if self.default_exchange and not queue.exchange:
queue.exchange = self.default_exchange
super().__setitem__(name, queue)
if queue.alias:
self.aliases[queue.alias] = queue
def __missing__(self, name):
if self.create_missing:
return self.add(self.new_missing(name))
raise KeyError(name)
def add(self, queue, **kwargs):
"""Add new queue.
The first argument can either be a :class:`kombu.Queue` instance,
        or the name of a queue. If the former, the rest of the keyword
arguments are ignored, and options are simply taken from the queue
instance.
Arguments:
queue (kombu.Queue, str): Queue to add.
exchange (kombu.Exchange, str):
if queue is str, specifies exchange name.
routing_key (str): if queue is str, specifies binding key.
exchange_type (str): if queue is str, specifies type of exchange.
**options (Any): Additional declaration options used when
queue is a str.
"""
if not isinstance(queue, Queue):
return self.add_compat(queue, **kwargs)
return self._add(queue)
def add_compat(self, name, **options):
# docs used to use binding_key as routing key
options.setdefault('routing_key', options.get('binding_key'))
if options['routing_key'] is None:
options['routing_key'] = name
return self._add(Queue.from_dict(name, **options))
def _add(self, queue):
if queue.exchange is None or queue.exchange.name == '':
queue.exchange = self.default_exchange
if not queue.routing_key:
queue.routing_key = self.default_routing_key
if self.max_priority is not None:
if queue.queue_arguments is None:
queue.queue_arguments = {}
self._set_max_priority(queue.queue_arguments)
self[queue.name] = queue
return queue
def _set_max_priority(self, args):
if 'x-max-priority' not in args and self.max_priority is not None:
return args.update({'x-max-priority': self.max_priority})
def format(self, indent=0, indent_first=True):
"""Format routing table into string for log dumps."""
active = self.consume_from
if not active:
return ''
info = [QUEUE_FORMAT.strip().format(q)
for _, q in sorted(active.items())]
if indent_first:
return textindent('\n'.join(info), indent)
return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)
def select_add(self, queue, **kwargs):
"""Add new task queue that'll be consumed from.
The queue will be active even when a subset has been selected
using the :option:`celery worker -Q` option.
"""
q = self.add(queue, **kwargs)
if self._consume_from is not None:
self._consume_from[q.name] = q
return q
def select(self, include):
"""Select a subset of currently defined queues to consume from.
Arguments:
include (Sequence[str], str): Names of queues to consume from.
"""
if include:
self._consume_from = {
name: self[name] for name in maybe_list(include)
}
def deselect(self, exclude):
"""Deselect queues so that they won't be consumed from.
Arguments:
exclude (Sequence[str], str): Names of queues to avoid
consuming from.
"""
if exclude:
exclude = maybe_list(exclude)
if self._consume_from is None:
# using all queues
return self.select(k for k in self if k not in exclude)
# using selection
for queue in exclude:
self._consume_from.pop(queue, None)
def new_missing(self, name):
return Queue(name, self.autoexchange(name), name)
@property
def consume_from(self):
if self._consume_from is not None:
return self._consume_from
return self
class AMQP:
"""App AMQP API: app.amqp."""
Connection = Connection
Consumer = Consumer
Producer = Producer
#: compat alias to Connection
BrokerConnection = Connection
queues_cls = Queues
#: Cached and prepared routing table.
_rtable = None
#: Underlying producer pool instance automatically
#: set by the :attr:`producer_pool`.
_producer_pool = None
# Exchange class/function used when defining automatic queues.
# For example, you can use ``autoexchange = lambda n: None`` to use the
# AMQP default exchange: a shortcut to bypass routing
# and instead send directly to the queue named in the routing key.
autoexchange = None
#: Max size of positional argument representation used for
#: logging purposes.
argsrepr_maxsize = 1024
#: Max size of keyword argument representation used for logging purposes.
kwargsrepr_maxsize = 1024
def __init__(self, app):
self.app = app
self.task_protocols = {
1: self.as_task_v1,
2: self.as_task_v2,
}
self.app._conf.bind_to(self._handle_conf_update)
@cached_property
def create_task_message(self):
return self.task_protocols[self.app.conf.task_protocol]
@cached_property
def send_task_message(self):
return self._create_task_sender()
def Queues(self, queues, create_missing=None,
autoexchange=None, max_priority=None):
# Create new :class:`Queues` instance, using queue defaults
# from the current configuration.
conf = self.app.conf
default_routing_key = conf.task_default_routing_key
if create_missing is None:
create_missing = conf.task_create_missing_queues
if max_priority is None:
max_priority = conf.task_queue_max_priority
if not queues and conf.task_default_queue:
queues = (Queue(conf.task_default_queue,
exchange=self.default_exchange,
routing_key=default_routing_key),)
autoexchange = (self.autoexchange if autoexchange is None
else autoexchange)
return self.queues_cls(
queues, self.default_exchange, create_missing,
autoexchange, max_priority, default_routing_key,
)
def Router(self, queues=None, create_missing=None):
"""Return the current task router."""
return _routes.Router(self.routes, queues or self.queues,
self.app.either('task_create_missing_queues',
create_missing), app=self.app)
def flush_routes(self):
self._rtable = _routes.prepare(self.app.conf.task_routes)
def TaskConsumer(self, channel, queues=None, accept=None, **kw):
if accept is None:
accept = self.app.conf.accept_content
return self.Consumer(
channel, accept=accept,
queues=queues or list(self.queues.consume_from.values()),
**kw
)
def as_task_v2(self, task_id, name, args=None, kwargs=None,
countdown=None, eta=None, group_id=None, group_index=None,
expires=None, retries=0, chord=None,
callbacks=None, errbacks=None, reply_to=None,
time_limit=None, soft_time_limit=None,
create_sent_event=False, root_id=None, parent_id=None,
shadow=None, chain=None, now=None, timezone=None,
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
**options):
args = args or ()
kwargs = kwargs or {}
if not isinstance(args, (list, tuple)):
raise TypeError('task args must be a list or tuple')
if not isinstance(kwargs, Mapping):
raise TypeError('task keyword arguments must be a mapping')
if countdown: # convert countdown to ETA
self._verify_seconds(countdown, 'countdown')
now = now or self.app.now()
timezone = timezone or self.app.timezone
eta = maybe_make_aware(
now + timedelta(seconds=countdown), tz=timezone,
)
if isinstance(expires, numbers.Real):
self._verify_seconds(expires, 'expires')
now = now or self.app.now()
timezone = timezone or self.app.timezone
expires = maybe_make_aware(
now + timedelta(seconds=expires), tz=timezone,
)
if not isinstance(eta, str):
eta = eta and eta.isoformat()
# If we retry a task `expires` will already be ISO8601-formatted.
if not isinstance(expires, str):
expires = expires and expires.isoformat()
if argsrepr is None:
argsrepr = saferepr(args, self.argsrepr_maxsize)
if kwargsrepr is None:
kwargsrepr = saferepr(kwargs, self.kwargsrepr_maxsize)
if not root_id: # empty root_id defaults to task_id
root_id = task_id
stamps = {header: options[header] for header in stamped_headers or []}
headers = {
'lang': 'py',
'task': name,
'id': task_id,
'shadow': shadow,
'eta': eta,
'expires': expires,
'group': group_id,
'group_index': group_index,
'retries': retries,
'timelimit': [time_limit, soft_time_limit],
'root_id': root_id,
'parent_id': parent_id,
'argsrepr': argsrepr,
'kwargsrepr': kwargsrepr,
'origin': origin or anon_nodename(),
'ignore_result': ignore_result,
'stamped_headers': stamped_headers,
'stamps': stamps,
}
return task_message(
headers=headers,
properties={
'correlation_id': task_id,
'reply_to': reply_to or '',
},
body=(
args, kwargs, {
'callbacks': callbacks,
'errbacks': errbacks,
'chain': chain,
'chord': chord,
},
),
sent_event={
'uuid': task_id,
'root_id': root_id,
'parent_id': parent_id,
'name': name,
'args': argsrepr,
'kwargs': kwargsrepr,
'retries': retries,
'eta': eta,
'expires': expires,
} if create_sent_event else None,
)
def as_task_v1(self, task_id, name, args=None, kwargs=None,
countdown=None, eta=None, group_id=None, group_index=None,
expires=None, retries=0,
chord=None, callbacks=None, errbacks=None, reply_to=None,
time_limit=None, soft_time_limit=None,
create_sent_event=False, root_id=None, parent_id=None,
shadow=None, now=None, timezone=None,
**compat_kwargs):
args = args or ()
kwargs = kwargs or {}
utc = self.utc
if not isinstance(args, (list, tuple)):
raise TypeError('task args must be a list or tuple')
if not isinstance(kwargs, Mapping):
raise TypeError('task keyword arguments must be a mapping')
if countdown: # convert countdown to ETA
self._verify_seconds(countdown, 'countdown')
now = now or self.app.now()
eta = now + timedelta(seconds=countdown)
if isinstance(expires, numbers.Real):
self._verify_seconds(expires, 'expires')
now = now or self.app.now()
expires = now + timedelta(seconds=expires)
eta = eta and eta.isoformat()
expires = expires and expires.isoformat()
return task_message(
headers={},
properties={
'correlation_id': task_id,
'reply_to': reply_to or '',
},
body={
'task': name,
'id': task_id,
'args': args,
'kwargs': kwargs,
'group': group_id,
'group_index': group_index,
'retries': retries,
'eta': eta,
'expires': expires,
'utc': utc,
'callbacks': callbacks,
'errbacks': errbacks,
'timelimit': (time_limit, soft_time_limit),
'taskset': group_id,
'chord': chord,
},
sent_event={
'uuid': task_id,
'name': name,
'args': saferepr(args),
'kwargs': saferepr(kwargs),
'retries': retries,
'eta': eta,
'expires': expires,
} if create_sent_event else None,
)
def _verify_seconds(self, s, what):
if s < INT_MIN:
raise ValueError(f'{what} is out of range: {s!r}')
return s
def _create_task_sender(self):
default_retry = self.app.conf.task_publish_retry
default_policy = self.app.conf.task_publish_retry_policy
default_delivery_mode = self.app.conf.task_default_delivery_mode
default_queue = self.default_queue
queues = self.queues
send_before_publish = signals.before_task_publish.send
before_receivers = signals.before_task_publish.receivers
send_after_publish = signals.after_task_publish.send
after_receivers = signals.after_task_publish.receivers
send_task_sent = signals.task_sent.send # XXX compat
sent_receivers = signals.task_sent.receivers
default_evd = self._event_dispatcher
default_exchange = self.default_exchange
default_rkey = self.app.conf.task_default_routing_key
default_serializer = self.app.conf.task_serializer
default_compressor = self.app.conf.task_compression
def send_task_message(producer, name, message,
exchange=None, routing_key=None, queue=None,
event_dispatcher=None,
retry=None, retry_policy=None,
serializer=None, delivery_mode=None,
compression=None, declare=None,
headers=None, exchange_type=None, **kwargs):
retry = default_retry if retry is None else retry
headers2, properties, body, sent_event = message
if headers:
headers2.update(headers)
if kwargs:
properties.update(kwargs)
qname = queue
if queue is None and exchange is None:
queue = default_queue
if queue is not None:
if isinstance(queue, str):
qname, queue = queue, queues[queue]
else:
qname = queue.name
if delivery_mode is None:
try:
delivery_mode = queue.exchange.delivery_mode
except AttributeError:
pass
delivery_mode = delivery_mode or default_delivery_mode
if exchange_type is None:
try:
exchange_type = queue.exchange.type
except AttributeError:
exchange_type = 'direct'
            # convert to anon-exchange when the exchange isn't set and the type is direct
if (not exchange or not routing_key) and exchange_type == 'direct':
exchange, routing_key = '', qname
elif exchange is None:
# not topic exchange, and exchange not undefined
exchange = queue.exchange.name or default_exchange
routing_key = routing_key or queue.routing_key or default_rkey
if declare is None and queue and not isinstance(queue, Broadcast):
declare = [queue]
# merge default and custom policy
retry = default_retry if retry is None else retry
_rp = (dict(default_policy, **retry_policy) if retry_policy
else default_policy)
if before_receivers:
send_before_publish(
sender=name, body=body,
exchange=exchange, routing_key=routing_key,
declare=declare, headers=headers2,
properties=properties, retry_policy=retry_policy,
)
ret = producer.publish(
body,
exchange=exchange,
routing_key=routing_key,
serializer=serializer or default_serializer,
compression=compression or default_compressor,
retry=retry, retry_policy=_rp,
delivery_mode=delivery_mode, declare=declare,
headers=headers2,
**properties
)
if after_receivers:
send_after_publish(sender=name, body=body, headers=headers2,
exchange=exchange, routing_key=routing_key)
if sent_receivers: # XXX deprecated
if isinstance(body, tuple): # protocol version 2
send_task_sent(
sender=name, task_id=headers2['id'], task=name,
args=body[0], kwargs=body[1],
eta=headers2['eta'], taskset=headers2['group'],
)
else: # protocol version 1
send_task_sent(
sender=name, task_id=body['id'], task=name,
args=body['args'], kwargs=body['kwargs'],
eta=body['eta'], taskset=body['taskset'],
)
if sent_event:
evd = event_dispatcher or default_evd
exname = exchange
if isinstance(exname, Exchange):
exname = exname.name
sent_event.update({
'queue': qname,
'exchange': exname,
'routing_key': routing_key,
})
evd.publish('task-sent', sent_event,
producer, retry=retry, retry_policy=retry_policy)
return ret
return send_task_message
@cached_property
def default_queue(self):
return self.queues[self.app.conf.task_default_queue]
@cached_property
def queues(self):
"""Queue name⇒ declaration mapping."""
return self.Queues(self.app.conf.task_queues)
@queues.setter
def queues(self, queues):
return self.Queues(queues)
@property
def routes(self):
if self._rtable is None:
self.flush_routes()
return self._rtable
@cached_property
def router(self):
return self.Router()
@router.setter
def router(self, value):
return value
@property
def producer_pool(self):
if self._producer_pool is None:
self._producer_pool = pools.producers[
self.app.connection_for_write()]
self._producer_pool.limit = self.app.pool.limit
return self._producer_pool
publisher_pool = producer_pool # compat alias
@cached_property
def default_exchange(self):
return Exchange(self.app.conf.task_default_exchange,
self.app.conf.task_default_exchange_type)
@cached_property
def utc(self):
return self.app.conf.enable_utc
@cached_property
def _event_dispatcher(self):
# We call Dispatcher.publish with a custom producer
# so don't need the dispatcher to be enabled.
return self.app.events.Dispatcher(enabled=False)
def _handle_conf_update(self, *args, **kwargs):
if ('task_routes' in kwargs or 'task_routes' in args):
self.flush_routes()
self.router = self.Router()
return
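
A hedged sketch of the ``Queues`` mapping defined above, assuming this module is importable as ``celery.app.amqp`` (queue and exchange names are illustrative):

from kombu import Exchange, Queue
from celery.app.amqp import Queues

queues = Queues(default_exchange=Exchange('tasks'),
                default_routing_key='task.default')
queues.add(Queue('hipri', Exchange('hipri'), routing_key='hipri'))
queues['auto']                      # unknown name: auto-created via __missing__
print(sorted(queues.consume_from))  # ['auto', 'hipri']

queues.select(['hipri'])            # restrict consumption to a subset
print(sorted(queues.consume_from))  # ['hipri']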

@@ -0,0 +1,52 @@
"""Task Annotations.
Annotations is a nice term for monkey-patching task classes
in the configuration.
This prepares and performs the annotations in the
:setting:`task_annotations` setting.
"""
from celery.utils.functional import firstmethod, mlazy
from celery.utils.imports import instantiate
_first_match = firstmethod('annotate')
_first_match_any = firstmethod('annotate_any')
__all__ = ('MapAnnotation', 'prepare', 'resolve_all')
class MapAnnotation(dict):
"""Annotation map: task_name => attributes."""
def annotate_any(self):
try:
return dict(self['*'])
except KeyError:
pass
def annotate(self, task):
try:
return dict(self[task.name])
except KeyError:
pass
def prepare(annotations):
"""Expand the :setting:`task_annotations` setting."""
def expand_annotation(annotation):
if isinstance(annotation, dict):
return MapAnnotation(annotation)
elif isinstance(annotation, str):
return mlazy(instantiate, annotation)
return annotation
if annotations is None:
return ()
elif not isinstance(annotations, (list, tuple)):
annotations = (annotations,)
return [expand_annotation(anno) for anno in annotations]
def resolve_all(anno, task):
"""Resolve all pending annotations."""
return (x for x in (_first_match(anno, task), _first_match_any(anno)) if x)
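
A minimal sketch of how the ``task_annotations`` machinery above expands and resolves a mapping, assuming this module is ``celery.app.annotations`` (the task name, rate limit, and ``FakeTask`` stand-in are illustrative):

from celery.app.annotations import MapAnnotation, prepare, resolve_all

class FakeTask:
    name = 'tasks.add'      # the only attribute resolve_all needs here

anno = prepare({'tasks.add': {'rate_limit': '10/s'}})
assert isinstance(anno[0], MapAnnotation)
for attrs in resolve_all(anno, FakeTask()):
    print(attrs)            # {'rate_limit': '10/s'}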

@@ -0,0 +1,66 @@
"""Tasks auto-retry functionality."""
from vine.utils import wraps
from celery.exceptions import Ignore, Retry
from celery.utils.time import get_exponential_backoff_interval
def add_autoretry_behaviour(task, **options):
"""Wrap task's `run` method with auto-retry functionality."""
autoretry_for = tuple(
options.get('autoretry_for',
getattr(task, 'autoretry_for', ()))
)
dont_autoretry_for = tuple(
options.get('dont_autoretry_for',
getattr(task, 'dont_autoretry_for', ()))
)
retry_kwargs = options.get(
'retry_kwargs', getattr(task, 'retry_kwargs', {})
)
retry_backoff = float(
options.get('retry_backoff',
getattr(task, 'retry_backoff', False))
)
retry_backoff_max = int(
options.get('retry_backoff_max',
getattr(task, 'retry_backoff_max', 600))
)
retry_jitter = options.get(
'retry_jitter', getattr(task, 'retry_jitter', True)
)
if autoretry_for and not hasattr(task, '_orig_run'):
@wraps(task.run)
def run(*args, **kwargs):
try:
return task._orig_run(*args, **kwargs)
except Ignore:
# If Ignore signal occurs task shouldn't be retried,
# even if it suits autoretry_for list
raise
except Retry:
raise
except dont_autoretry_for:
raise
except autoretry_for as exc:
if retry_backoff:
retry_kwargs['countdown'] = \
get_exponential_backoff_interval(
factor=int(max(1.0, retry_backoff)),
retries=task.request.retries,
maximum=retry_backoff_max,
full_jitter=retry_jitter)
# Override max_retries
if hasattr(task, 'override_max_retries'):
retry_kwargs['max_retries'] = getattr(task,
'override_max_retries',
task.max_retries)
ret = task.retry(exc=exc, **retry_kwargs)
# Stop propagation
if hasattr(task, 'override_max_retries'):
delattr(task, 'override_max_retries')
raise ret
task._orig_run, task.run = task.run, run
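
These options are normally supplied through the task decorator; a hedged sketch (the task body and option values are illustrative):

from celery import Celery

app = Celery('demo')

@app.task(autoretry_for=(ConnectionError,),
          retry_backoff=2,                    # exponential backoff starting at 2s
          retry_backoff_max=600,              # cap the countdown at 10 minutes
          retry_jitter=True,                  # randomize the countdown
          retry_kwargs={'max_retries': 5})
def fetch(url):
    ...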

@@ -0,0 +1,68 @@
"""Backend selection."""
import sys
import types
from celery._state import current_app
from celery.exceptions import ImproperlyConfigured, reraise
from celery.utils.imports import load_extension_class_names, symbol_by_name
__all__ = ('by_name', 'by_url')
UNKNOWN_BACKEND = """
Unknown result backend: {0!r}. Did you spell that correctly? ({1!r})
"""
BACKEND_ALIASES = {
'rpc': 'celery.backends.rpc.RPCBackend',
'cache': 'celery.backends.cache:CacheBackend',
'redis': 'celery.backends.redis:RedisBackend',
'rediss': 'celery.backends.redis:RedisBackend',
'sentinel': 'celery.backends.redis:SentinelBackend',
'mongodb': 'celery.backends.mongodb:MongoBackend',
'db': 'celery.backends.database:DatabaseBackend',
'database': 'celery.backends.database:DatabaseBackend',
'elasticsearch': 'celery.backends.elasticsearch:ElasticsearchBackend',
'cassandra': 'celery.backends.cassandra:CassandraBackend',
'couchbase': 'celery.backends.couchbase:CouchbaseBackend',
'couchdb': 'celery.backends.couchdb:CouchBackend',
'cosmosdbsql': 'celery.backends.cosmosdbsql:CosmosDBSQLBackend',
'riak': 'celery.backends.riak:RiakBackend',
'file': 'celery.backends.filesystem:FilesystemBackend',
'disabled': 'celery.backends.base:DisabledBackend',
'consul': 'celery.backends.consul:ConsulBackend',
'dynamodb': 'celery.backends.dynamodb:DynamoDBBackend',
'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend',
'arangodb': 'celery.backends.arangodb:ArangoDbBackend',
's3': 'celery.backends.s3:S3Backend',
}
def by_name(backend=None, loader=None,
extension_namespace='celery.result_backends'):
"""Get backend class by name/alias."""
backend = backend or 'disabled'
loader = loader or current_app.loader
aliases = dict(BACKEND_ALIASES, **loader.override_backends)
aliases.update(load_extension_class_names(extension_namespace))
try:
cls = symbol_by_name(backend, aliases)
except ValueError as exc:
reraise(ImproperlyConfigured, ImproperlyConfigured(
UNKNOWN_BACKEND.strip().format(backend, exc)), sys.exc_info()[2])
if isinstance(cls, types.ModuleType):
raise ImproperlyConfigured(UNKNOWN_BACKEND.strip().format(
backend, 'is a Python module, not a backend class.'))
return cls
def by_url(backend=None, loader=None):
"""Get backend class by URL."""
url = None
if backend and '://' in backend:
url = backend
scheme, _, _ = url.partition('://')
if '+' in scheme:
backend, url = url.split('+', 1)
else:
backend = scheme
return by_name(backend, loader), url
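
A hedged sketch of ``by_url`` resolving a backend from a URL, assuming this module is ``celery.app.backends`` (the URL is illustrative; a backend's own optional dependencies must still be installed to use it):

from celery.app.backends import by_url

cls, url = by_url('redis://localhost:6379/0')
print(cls.__name__, url)    # RedisBackend redis://localhost:6379/0
# A 'transport+scheme' URL such as 'db+postgresql://...' keeps everything after
# the '+' as the backend URL and uses the part before it as the backend alias.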

File diff suppressed because it is too large.

@@ -0,0 +1,187 @@
"""Built-in Tasks.
The built-in tasks are always available in all app instances.
"""
from celery._state import connect_on_app_finalize
from celery.utils.log import get_logger
__all__ = ()
logger = get_logger(__name__)
@connect_on_app_finalize
def add_backend_cleanup_task(app):
"""Task used to clean up expired results.
If the configured backend requires periodic cleanup this task is also
automatically configured to run every day at 4am (requires
:program:`celery beat` to be running).
"""
@app.task(name='celery.backend_cleanup', shared=False, lazy=False)
def backend_cleanup():
app.backend.cleanup()
return backend_cleanup
@connect_on_app_finalize
def add_accumulate_task(app):
"""Task used by Task.replace when replacing task with group."""
@app.task(bind=True, name='celery.accumulate', shared=False, lazy=False)
def accumulate(self, *args, **kwargs):
index = kwargs.get('index')
return args[index] if index is not None else args
return accumulate
@connect_on_app_finalize
def add_unlock_chord_task(app):
"""Task used by result backends without native chord support.
    Will join the chord by creating a task chain that polls the header
for completion.
"""
from celery.canvas import maybe_signature
from celery.exceptions import ChordError
from celery.result import allow_join_result, result_from_tuple
@app.task(name='celery.chord_unlock', max_retries=None, shared=False,
default_retry_delay=app.conf.result_chord_retry_interval, ignore_result=True, lazy=False, bind=True)
def unlock_chord(self, group_id, callback, interval=None,
max_retries=None, result=None,
Result=app.AsyncResult, GroupResult=app.GroupResult,
result_from_tuple=result_from_tuple, **kwargs):
if interval is None:
interval = self.default_retry_delay
# check if the task group is ready, and if so apply the callback.
callback = maybe_signature(callback, app)
deps = GroupResult(
group_id,
[result_from_tuple(r, app=app) for r in result],
app=app,
)
j = deps.join_native if deps.supports_native_join else deps.join
try:
ready = deps.ready()
except Exception as exc:
raise self.retry(
exc=exc, countdown=interval, max_retries=max_retries,
)
else:
if not ready:
raise self.retry(countdown=interval, max_retries=max_retries)
callback = maybe_signature(callback, app=app)
try:
with allow_join_result():
ret = j(
timeout=app.conf.result_chord_join_timeout,
propagate=True,
)
except Exception as exc: # pylint: disable=broad-except
try:
culprit = next(deps._failed_join_report())
reason = f'Dependency {culprit.id} raised {exc!r}'
except StopIteration:
reason = repr(exc)
logger.exception('Chord %r raised: %r', group_id, exc)
app.backend.chord_error_from_stack(callback, ChordError(reason))
else:
try:
callback.delay(ret)
except Exception as exc: # pylint: disable=broad-except
logger.exception('Chord %r raised: %r', group_id, exc)
app.backend.chord_error_from_stack(
callback,
exc=ChordError(f'Callback error: {exc!r}'),
)
return unlock_chord
@connect_on_app_finalize
def add_map_task(app):
from celery.canvas import signature
@app.task(name='celery.map', shared=False, lazy=False)
def xmap(task, it):
task = signature(task, app=app).type
return [task(item) for item in it]
return xmap
@connect_on_app_finalize
def add_starmap_task(app):
from celery.canvas import signature
@app.task(name='celery.starmap', shared=False, lazy=False)
def xstarmap(task, it):
task = signature(task, app=app).type
return [task(*item) for item in it]
return xstarmap
@connect_on_app_finalize
def add_chunk_task(app):
from celery.canvas import chunks as _chunks
@app.task(name='celery.chunks', shared=False, lazy=False)
def chunks(task, it, n):
return _chunks.apply_chunks(task, it, n)
return chunks
@connect_on_app_finalize
def add_group_task(app):
"""No longer used, but here for backwards compatibility."""
from celery.canvas import maybe_signature
from celery.result import result_from_tuple
@app.task(name='celery.group', bind=True, shared=False, lazy=False)
def group(self, tasks, result, group_id, partial_args, add_to_parent=True):
app = self.app
result = result_from_tuple(result, app)
# any partial args are added to all tasks in the group
taskit = (maybe_signature(task, app=app).clone(partial_args)
for i, task in enumerate(tasks))
with app.producer_or_acquire() as producer:
[stask.apply_async(group_id=group_id, producer=producer,
add_to_parent=False) for stask in taskit]
parent = app.current_worker_task
if add_to_parent and parent:
parent.add_trail(result)
return result
return group
@connect_on_app_finalize
def add_chain_task(app):
"""No longer used, but here for backwards compatibility."""
@app.task(name='celery.chain', shared=False, lazy=False)
def chain(*args, **kwargs):
raise NotImplementedError('chain is not a real task')
return chain
@connect_on_app_finalize
def add_chord_task(app):
"""No longer used, but here for backwards compatibility."""
from celery import chord as _chord
from celery import group
from celery.canvas import maybe_signature
@app.task(name='celery.chord', bind=True, ignore_result=False,
shared=False, lazy=False)
def chord(self, header, body, partial_args=(), interval=None,
countdown=1, max_retries=None, eager=False, **kwargs):
app = self.app
# - convert back to group if serialized
tasks = header.tasks if isinstance(header, group) else header
header = group([
maybe_signature(s, app=app) for s in tasks
], app=self.app)
body = maybe_signature(body, app=app)
ch = _chord(header, body)
return ch.run(header, body, partial_args, app, interval,
countdown, max_retries, **kwargs)
return chord
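
A small sketch of the finalize hook these built-ins rely on: callbacks registered with ``connect_on_app_finalize`` run for every app that finalizes, which is how the tasks above get installed (the ``demo.hello`` task is hypothetical):

from celery import Celery
from celery._state import connect_on_app_finalize

@connect_on_app_finalize
def add_hello_task(app):
    @app.task(name='demo.hello', shared=False)
    def hello():
        return 'hello'
    return hello

app = Celery('demo')
app.finalize()
print('demo.hello' in app.tasks)              # True
print('celery.backend_cleanup' in app.tasks)  # True, installed by this module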

@@ -0,0 +1,779 @@
"""Worker Remote Control Client.
Client for worker remote control commands.
Server implementation is in :mod:`celery.worker.control`.
There are two types of remote control commands:
* Inspect commands: Have no side effects; they usually just return some value
  found in the worker, like the list of currently registered tasks or the list of
  active tasks. These commands are accessible via the :class:`Inspect` class.
* Control commands: Perform side effects, like adding a new queue to consume from.
  These commands are accessible via the :class:`Control` class.
"""
import warnings
from billiard.common import TERM_SIGNAME
from kombu.matcher import match
from kombu.pidbox import Mailbox
from kombu.utils.compat import register_after_fork
from kombu.utils.functional import lazy
from kombu.utils.objects import cached_property
from celery.exceptions import DuplicateNodenameWarning
from celery.utils.log import get_logger
from celery.utils.text import pluralize
__all__ = ('Inspect', 'Control', 'flatten_reply')
logger = get_logger(__name__)
W_DUPNODE = """\
Received multiple replies from node {0}: {1}.
Please make sure you give each node a unique nodename using
the celery worker `-n` option.\
"""
def flatten_reply(reply):
"""Flatten node replies.
Convert from a list of replies in this format::
[{'a@example.com': reply},
{'b@example.com': reply}]
into this format::
{'a@example.com': reply,
'b@example.com': reply}
"""
nodes, dupes = {}, set()
for item in reply:
[dupes.add(name) for name in item if name in nodes]
nodes.update(item)
if dupes:
warnings.warn(DuplicateNodenameWarning(
W_DUPNODE.format(
pluralize(len(dupes), 'name'), ', '.join(sorted(dupes)),
),
))
return nodes
def _after_fork_cleanup_control(control):
try:
control._after_fork()
except Exception as exc: # pylint: disable=broad-except
logger.info('after fork raised exception: %r', exc, exc_info=1)
class Inspect:
"""API for inspecting workers.
    This class provides a proxy for accessing the Inspect API of workers.
    The API is defined in :py:mod:`celery.worker.control`.
"""
app = None
def __init__(self, destination=None, timeout=1.0, callback=None,
connection=None, app=None, limit=None, pattern=None,
matcher=None):
self.app = app or self.app
self.destination = destination
self.timeout = timeout
self.callback = callback
self.connection = connection
self.limit = limit
self.pattern = pattern
self.matcher = matcher
def _prepare(self, reply):
if reply:
by_node = flatten_reply(reply)
if (self.destination and
not isinstance(self.destination, (list, tuple))):
return by_node.get(self.destination)
if self.pattern:
pattern = self.pattern
matcher = self.matcher
return {node: reply for node, reply in by_node.items()
if match(node, pattern, matcher)}
return by_node
def _request(self, command, **kwargs):
return self._prepare(self.app.control.broadcast(
command,
arguments=kwargs,
destination=self.destination,
callback=self.callback,
connection=self.connection,
limit=self.limit,
timeout=self.timeout, reply=True,
pattern=self.pattern, matcher=self.matcher,
))
def report(self):
"""Return human readable report for each worker.
Returns:
Dict: Dictionary ``{HOSTNAME: {'ok': REPORT_STRING}}``.
"""
return self._request('report')
def clock(self):
"""Get the Clock value on workers.
>>> app.control.inspect().clock()
{'celery@node1': {'clock': 12}}
Returns:
Dict: Dictionary ``{HOSTNAME: CLOCK_VALUE}``.
"""
return self._request('clock')
def active(self, safe=None):
"""Return list of tasks currently executed by workers.
Arguments:
safe (Boolean): Set to True to disable deserialization.
Returns:
Dict: Dictionary ``{HOSTNAME: [TASK_INFO,...]}``.
See Also:
For ``TASK_INFO`` details see :func:`query_task` return value.
"""
return self._request('active', safe=safe)
def scheduled(self, safe=None):
"""Return list of scheduled tasks with details.
Returns:
Dict: Dictionary ``{HOSTNAME: [TASK_SCHEDULED_INFO,...]}``.
Here is the list of ``TASK_SCHEDULED_INFO`` fields:
* ``eta`` - scheduled time for task execution as string in ISO 8601 format
* ``priority`` - priority of the task
* ``request`` - field containing ``TASK_INFO`` value.
See Also:
For more details about ``TASK_INFO`` see :func:`query_task` return value.
"""
return self._request('scheduled')
def reserved(self, safe=None):
"""Return list of currently reserved tasks, not including scheduled/active.
Returns:
Dict: Dictionary ``{HOSTNAME: [TASK_INFO,...]}``.
See Also:
For ``TASK_INFO`` details see :func:`query_task` return value.
"""
return self._request('reserved')
def stats(self):
"""Return statistics of worker.
Returns:
Dict: Dictionary ``{HOSTNAME: STAT_INFO}``.
Here is the list of ``STAT_INFO`` fields:
* ``broker`` - Section for broker information.
* ``connect_timeout`` - Timeout in seconds (int/float) for establishing a new connection.
* ``heartbeat`` - Current heartbeat value (set by client).
* ``hostname`` - Node name of the remote broker.
* ``insist`` - No longer used.
* ``login_method`` - Login method used to connect to the broker.
* ``port`` - Port of the remote broker.
* ``ssl`` - SSL enabled/disabled.
* ``transport`` - Name of transport used (e.g., amqp or redis)
* ``transport_options`` - Options passed to transport.
          * ``uri_prefix`` - Some transports expect the host name to be a URL.
            E.g. ``redis+socket:///tmp/redis.sock``.
            In this example the URI prefix will be ``redis``.
          * ``userid`` - User id used to connect to the broker.
* ``virtual_host`` - Virtual host used.
* ``clock`` - Value of the workers logical clock. This is a positive integer
and should be increasing every time you receive statistics.
* ``uptime`` - Numbers of seconds since the worker controller was started
* ``pid`` - Process id of the worker instance (Main process).
* ``pool`` - Pool-specific section.
* ``max-concurrency`` - Max number of processes/threads/green threads.
* ``max-tasks-per-child`` - Max number of tasks a thread may execute before being recycled.
* ``processes`` - List of PIDs (or thread-ids).
* ``put-guarded-by-semaphore`` - Internal
* ``timeouts`` - Default values for time limits.
* ``writes`` - Specific to the prefork pool, this shows the distribution
of writes to each process in the pool when using async I/O.
* ``prefetch_count`` - Current prefetch count value for the task consumer.
* ``rusage`` - System usage statistics. The fields available may be different on your platform.
From :manpage:`getrusage(2)`:
* ``stime`` - Time spent in operating system code on behalf of this process.
* ``utime`` - Time spent executing user instructions.
* ``maxrss`` - The maximum resident size used by this process (in kilobytes).
* ``idrss`` - Amount of non-shared memory used for data (in kilobytes times
ticks of execution)
* ``isrss`` - Amount of non-shared memory used for stack space
(in kilobytes times ticks of execution)
* ``ixrss`` - Amount of memory shared with other processes
(in kilobytes times ticks of execution).
* ``inblock`` - Number of times the file system had to read from the disk
on behalf of this process.
* ``oublock`` - Number of times the file system has to write to disk
on behalf of this process.
* ``majflt`` - Number of page faults that were serviced by doing I/O.
* ``minflt`` - Number of page faults that were serviced without doing I/O.
* ``msgrcv`` - Number of IPC messages received.
* ``msgsnd`` - Number of IPC messages sent.
* ``nvcsw`` - Number of times this process voluntarily invoked a context switch.
* ``nivcsw`` - Number of times an involuntary context switch took place.
* ``nsignals`` - Number of signals received.
* ``nswap`` - The number of times this process was swapped entirely
out of memory.
* ``total`` - Map of task names and the total number of tasks with that type
the worker has accepted since start-up.
"""
return self._request('stats')
def revoked(self):
"""Return list of revoked tasks.
>>> app.control.inspect().revoked()
{'celery@node1': ['16f527de-1c72-47a6-b477-c472b92fef7a']}
Returns:
Dict: Dictionary ``{HOSTNAME: [TASK_ID, ...]}``.
"""
return self._request('revoked')
def registered(self, *taskinfoitems):
"""Return all registered tasks per worker.
>>> app.control.inspect().registered()
{'celery@node1': ['task1', 'task1']}
>>> app.control.inspect().registered('serializer', 'max_retries')
        {'celery@node1': ['task_foo [serializer=json max_retries=3]', 'task_bar [serializer=json max_retries=3]']}
Arguments:
taskinfoitems (Sequence[str]): List of :class:`~celery.app.task.Task`
attributes to include.
Returns:
Dict: Dictionary ``{HOSTNAME: [TASK1_INFO, ...]}``.
"""
return self._request('registered', taskinfoitems=taskinfoitems)
registered_tasks = registered
def ping(self, destination=None):
"""Ping all (or specific) workers.
>>> app.control.inspect().ping()
{'celery@node1': {'ok': 'pong'}, 'celery@node2': {'ok': 'pong'}}
>>> app.control.inspect().ping(destination=['celery@node1'])
{'celery@node1': {'ok': 'pong'}}
Arguments:
            destination (List): If set, a list of the hosts to send the
                command to; when empty, broadcast to all workers.
Returns:
Dict: Dictionary ``{HOSTNAME: {'ok': 'pong'}}``.
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
if destination:
self.destination = destination
return self._request('ping')
def active_queues(self):
"""Return information about queues from which worker consumes tasks.
Returns:
Dict: Dictionary ``{HOSTNAME: [QUEUE_INFO, QUEUE_INFO,...]}``.
Here is the list of ``QUEUE_INFO`` fields:
* ``name``
* ``exchange``
* ``name``
* ``type``
* ``arguments``
* ``durable``
* ``passive``
* ``auto_delete``
* ``delivery_mode``
* ``no_declare``
* ``routing_key``
* ``queue_arguments``
* ``binding_arguments``
* ``consumer_arguments``
* ``durable``
* ``exclusive``
* ``auto_delete``
* ``no_ack``
* ``alias``
* ``bindings``
* ``no_declare``
* ``expires``
* ``message_ttl``
* ``max_length``
* ``max_length_bytes``
* ``max_priority``
See Also:
See the RabbitMQ/AMQP documentation for more details about
``queue_info`` fields.
Note:
The ``queue_info`` fields are RabbitMQ/AMQP oriented.
            Not all fields apply to other transports.
"""
return self._request('active_queues')
def query_task(self, *ids):
"""Return detail of tasks currently executed by workers.
Arguments:
*ids (str): IDs of tasks to be queried.
Returns:
Dict: Dictionary ``{HOSTNAME: {TASK_ID: [STATE, TASK_INFO]}}``.
Here is the list of ``TASK_INFO`` fields:
* ``id`` - ID of the task
* ``name`` - Name of the task
        * ``args`` - Positional arguments passed to the task
* ``kwargs`` - Keyword arguments passed to the task
* ``type`` - Type of the task
* ``hostname`` - Hostname of the worker processing the task
* ``time_start`` - Time of processing start
* ``acknowledged`` - True when task was acknowledged to broker
* ``delivery_info`` - Dictionary containing delivery information
* ``exchange`` - Name of exchange where task was published
* ``routing_key`` - Routing key used when task was published
* ``priority`` - Priority used when task was published
* ``redelivered`` - True if the task was redelivered
        * ``worker_pid`` - PID of the worker processing the task
"""
        # signature used to be unary: query_task(ids=[id1, id2])
# we need this to preserve backward compatibility.
if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
ids = ids[0]
return self._request('query_task', ids=ids)
def conf(self, with_defaults=False):
"""Return configuration of each worker.
Arguments:
with_defaults (bool): if set to True, method returns also
configuration options with default values.
Returns:
Dict: Dictionary ``{HOSTNAME: WORKER_CONFIGURATION}``.
See Also:
``WORKER_CONFIGURATION`` is a dictionary containing current configuration options.
See :ref:`configuration` for possible values.
"""
return self._request('conf', with_defaults=with_defaults)
def hello(self, from_node, revoked=None):
return self._request('hello', from_node=from_node, revoked=revoked)
def memsample(self):
"""Return sample current RSS memory usage.
Note:
            Requires the psutil library.
"""
return self._request('memsample')
def memdump(self, samples=10):
"""Dump statistics of previous memsample requests.
Note:
            Requires the psutil library.
"""
return self._request('memdump', samples=samples)
def objgraph(self, type='Request', n=200, max_depth=10):
"""Create graph of uncollected objects (memory-leak debugging).
Arguments:
n (int): Max number of objects to graph.
max_depth (int): Traverse at most n levels deep.
type (str): Name of object to graph. Default is ``"Request"``.
Returns:
Dict: Dictionary ``{'filename': FILENAME}``
Note:
Requires the objgraph library.
"""
return self._request('objgraph', num=n, max_depth=max_depth, type=type)
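# A typical Inspect round-trip against a running worker (hedged sketch;
# assumes a configured app ``app`` and at least one online worker):
#
#     insp = app.control.inspect(timeout=5.0)
#     insp.ping()                      # {'celery@host': {'ok': 'pong'}}
#     insp.active()                    # tasks currently being executed
#     insp.registered('max_retries')   # registered task names + attribute
#     insp.stats()                     # broker/pool/rusage statistics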
class Control:
"""Worker remote control client."""
Mailbox = Mailbox
def __init__(self, app=None):
self.app = app
self.mailbox = self.Mailbox(
app.conf.control_exchange,
type='fanout',
accept=app.conf.accept_content,
serializer=app.conf.task_serializer,
producer_pool=lazy(lambda: self.app.amqp.producer_pool),
queue_ttl=app.conf.control_queue_ttl,
reply_queue_ttl=app.conf.control_queue_ttl,
queue_expires=app.conf.control_queue_expires,
reply_queue_expires=app.conf.control_queue_expires,
)
register_after_fork(self, _after_fork_cleanup_control)
def _after_fork(self):
del self.mailbox.producer_pool
@cached_property
def inspect(self):
"""Create new :class:`Inspect` instance."""
return self.app.subclass_with_self(Inspect, reverse='control.inspect')
def purge(self, connection=None):
"""Discard all waiting tasks.
This will ignore all tasks waiting for execution, and they will
be deleted from the messaging server.
Arguments:
connection (kombu.Connection): Optional specific connection
instance to use. If not provided a connection will
be acquired from the connection pool.
Returns:
int: the number of tasks discarded.
"""
with self.app.connection_or_acquire(connection) as conn:
return self.app.amqp.TaskConsumer(conn).purge()
discard_all = purge
def election(self, id, topic, action=None, connection=None):
self.broadcast(
'election', connection=connection, destination=None,
arguments={
'id': id, 'topic': topic, 'action': action,
},
)
def revoke(self, task_id, destination=None, terminate=False,
signal=TERM_SIGNAME, **kwargs):
"""Tell all (or specific) workers to revoke a task by id (or list of ids).
If a task is revoked, the workers will ignore the task and
not execute it after all.
Arguments:
task_id (Union(str, list)): Id of the task to revoke
(or list of ids).
terminate (bool): Also terminate the process currently working
on the task (if any).
signal (str): Name of signal to send to process if terminate.
Default is TERM.
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
return self.broadcast('revoke', destination=destination, arguments={
'task_id': task_id,
'terminate': terminate,
'signal': signal,
}, **kwargs)
def revoke_by_stamped_headers(self, headers, destination=None, terminate=False,
signal=TERM_SIGNAME, **kwargs):
"""
Tell all (or specific) workers to revoke a task by headers.
If a task is revoked, the workers will ignore the task and
not execute it after all.
Arguments:
headers (dict[str, Union(str, list)]): Headers to match when revoking tasks.
terminate (bool): Also terminate the process currently working
on the task (if any).
signal (str): Name of signal to send to process if terminate.
Default is TERM.
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
result = self.broadcast('revoke_by_stamped_headers', destination=destination, arguments={
'headers': headers,
'terminate': terminate,
'signal': signal,
}, **kwargs)
task_ids = set()
if result:
for host in result:
for response in host.values():
task_ids.update(response['ok'])
if task_ids:
return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs)
else:
return result
def terminate(self, task_id,
destination=None, signal=TERM_SIGNAME, **kwargs):
"""Tell all (or specific) workers to terminate a task by id (or list of ids).
See Also:
This is just a shortcut to :meth:`revoke` with the terminate
argument enabled.
"""
return self.revoke(
task_id,
destination=destination, terminate=True, signal=signal, **kwargs)
def ping(self, destination=None, timeout=1.0, **kwargs):
"""Ping all (or specific) workers.
>>> app.control.ping()
[{'celery@node1': {'ok': 'pong'}}, {'celery@node2': {'ok': 'pong'}}]
>>> app.control.ping(destination=['celery@node2'])
[{'celery@node2': {'ok': 'pong'}}]
Returns:
List[Dict]: List of ``{HOSTNAME: {'ok': 'pong'}}`` dictionaries.
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
return self.broadcast(
'ping', reply=True, arguments={}, destination=destination,
timeout=timeout, **kwargs)
def rate_limit(self, task_name, rate_limit, destination=None, **kwargs):
"""Tell workers to set a new rate limit for task by type.
Arguments:
task_name (str): Name of task to change rate limit for.
rate_limit (int, str): The rate limit as tasks per second,
or a rate limit string (`'100/m'`, etc.
see :attr:`celery.app.task.Task.rate_limit` for
more information).
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
return self.broadcast(
'rate_limit',
destination=destination,
arguments={
'task_name': task_name,
'rate_limit': rate_limit,
},
**kwargs)
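    # Illustrative sketch (``tasks.add`` is a hypothetical task name and
    # ``app`` an already-configured application):
    #
    #     app.control.rate_limit('tasks.add', '10/m')    # ten tasks per minute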
def add_consumer(self, queue,
exchange=None, exchange_type='direct', routing_key=None,
options=None, destination=None, **kwargs):
"""Tell all (or specific) workers to start consuming from a new queue.
        Only the queue name is required: if only the queue is specified,
        then the exchange and routing key will be set to the same name
        (as automatic queues do).
Note:
This command does not respect the default queue/exchange
options in the configuration.
Arguments:
queue (str): Name of queue to start consuming from.
exchange (str): Optional name of exchange.
            exchange_type (str): Type of exchange (defaults to 'direct').
routing_key (str): Optional routing key.
options (Dict): Additional options as supported
by :meth:`kombu.entity.Queue.from_dict`.
See Also:
:meth:`broadcast` for supported keyword arguments.
"""
return self.broadcast(
'add_consumer',
destination=destination,
arguments=dict({
'queue': queue,
'exchange': exchange,
'exchange_type': exchange_type,
'routing_key': routing_key,
}, **options or {}),
**kwargs
)
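    # Illustrative sketch (queue, exchange and routing-key names are
    # hypothetical): ask all workers to also consume from an extra queue and
    # wait for their replies.
    #
    #     app.control.add_consumer('priority', exchange='priority',
    #                              routing_key='priority.high', reply=True)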
def cancel_consumer(self, queue, destination=None, **kwargs):
"""Tell all (or specific) workers to stop consuming from ``queue``.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'cancel_consumer', destination=destination,
arguments={'queue': queue}, **kwargs)
def time_limit(self, task_name, soft=None, hard=None,
destination=None, **kwargs):
"""Tell workers to set time limits for a task by type.
Arguments:
task_name (str): Name of task to change time limits for.
soft (float): New soft time limit (in seconds).
hard (float): New hard time limit (in seconds).
**kwargs (Any): arguments passed on to :meth:`broadcast`.
"""
return self.broadcast(
'time_limit',
arguments={
'task_name': task_name,
'hard': hard,
'soft': soft,
},
destination=destination,
**kwargs)
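    # Illustrative sketch (``tasks.crawl`` is a hypothetical task name):
    #
    #     app.control.time_limit('tasks.crawl', soft=60, hard=120, reply=True)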
def enable_events(self, destination=None, **kwargs):
"""Tell all (or specific) workers to enable events.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'enable_events', arguments={}, destination=destination, **kwargs)
def disable_events(self, destination=None, **kwargs):
"""Tell all (or specific) workers to disable events.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'disable_events', arguments={}, destination=destination, **kwargs)
def pool_grow(self, n=1, destination=None, **kwargs):
"""Tell all (or specific) workers to grow the pool by ``n``.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'pool_grow', arguments={'n': n}, destination=destination, **kwargs)
def pool_shrink(self, n=1, destination=None, **kwargs):
"""Tell all (or specific) workers to shrink the pool by ``n``.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'pool_shrink', arguments={'n': n},
destination=destination, **kwargs)
def autoscale(self, max, min, destination=None, **kwargs):
"""Change worker(s) autoscale setting.
See Also:
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast(
'autoscale', arguments={'max': max, 'min': min},
destination=destination, **kwargs)
def shutdown(self, destination=None, **kwargs):
"""Shutdown worker(s).
See Also:
Supports the same arguments as :meth:`broadcast`
"""
return self.broadcast(
'shutdown', arguments={}, destination=destination, **kwargs)
def pool_restart(self, modules=None, reload=False, reloader=None,
destination=None, **kwargs):
"""Restart the execution pools of all or specific workers.
Keyword Arguments:
modules (Sequence[str]): List of modules to reload.
reload (bool): Flag to enable module reloading. Default is False.
reloader (Any): Function to reload a module.
destination (Sequence[str]): List of worker names to send this
command to.
See Also:
Supports the same arguments as :meth:`broadcast`
"""
return self.broadcast(
'pool_restart',
arguments={
'modules': modules,
'reload': reload,
'reloader': reloader,
},
destination=destination, **kwargs)
def heartbeat(self, destination=None, **kwargs):
"""Tell worker(s) to send a heartbeat immediately.
See Also:
Supports the same arguments as :meth:`broadcast`
"""
return self.broadcast(
'heartbeat', arguments={}, destination=destination, **kwargs)
def broadcast(self, command, arguments=None, destination=None,
connection=None, reply=False, timeout=1.0, limit=None,
callback=None, channel=None, pattern=None, matcher=None,
**extra_kwargs):
"""Broadcast a control command to the celery workers.
Arguments:
command (str): Name of command to send.
arguments (Dict): Keyword arguments for the command.
destination (List): If set, a list of the hosts to send the
                command to; when empty, broadcast to all workers.
connection (kombu.Connection): Custom broker connection to use,
if not set, a connection will be acquired from the pool.
reply (bool): Wait for and return the reply.
timeout (float): Timeout in seconds to wait for the reply.
limit (int): Limit number of replies.
callback (Callable): Callback called immediately for
each reply received.
            pattern (str): Custom pattern string to match.
            matcher (Callable): Custom matcher used to match the pattern.
"""
with self.app.connection_or_acquire(connection) as conn:
arguments = dict(arguments or {}, **extra_kwargs)
if pattern and matcher:
# tests pass easier without requiring pattern/matcher to
# always be sent in
return self.mailbox(conn)._broadcast(
command, arguments, destination, reply, timeout,
limit, callback, channel=channel,
pattern=pattern, matcher=matcher,
)
else:
return self.mailbox(conn)._broadcast(
command, arguments, destination, reply, timeout,
limit, callback, channel=channel,
)
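    # Illustrative sketch of a raw broadcast call (most of the helpers above
    # are thin wrappers around this); the command shown mirrors the built-in
    # ``ping`` control command:
    #
    #     replies = app.control.broadcast('ping', reply=True, timeout=1.0)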

View File

@@ -0,0 +1,414 @@
"""Configuration introspection and defaults."""
from collections import deque, namedtuple
from datetime import timedelta
from celery.utils.functional import memoize
from celery.utils.serialization import strtobool
__all__ = ('Option', 'NAMESPACES', 'flatten', 'find')
DEFAULT_POOL = 'prefork'
DEFAULT_ACCEPT_CONTENT = ('json',)
DEFAULT_PROCESS_LOG_FMT = """
[%(asctime)s: %(levelname)s/%(processName)s] %(message)s
""".strip()
DEFAULT_TASK_LOG_FMT = """[%(asctime)s: %(levelname)s/%(processName)s] \
%(task_name)s[%(task_id)s]: %(message)s"""
DEFAULT_SECURITY_DIGEST = 'sha256'
OLD_NS = {'celery_{0}'}
OLD_NS_BEAT = {'celerybeat_{0}'}
OLD_NS_WORKER = {'celeryd_{0}'}
searchresult = namedtuple('searchresult', ('namespace', 'key', 'type'))
def Namespace(__old__=None, **options):
if __old__ is not None:
for key, opt in options.items():
if not opt.old:
opt.old = {o.format(key) for o in __old__}
return options
def old_ns(ns):
return {f'{ns}_{{0}}'}
class Option:
"""Describes a Celery configuration option."""
alt = None
deprecate_by = None
remove_by = None
old = set()
typemap = {'string': str, 'int': int, 'float': float, 'any': lambda v: v,
'bool': strtobool, 'dict': dict, 'tuple': tuple}
def __init__(self, default=None, *args, **kwargs):
self.default = default
self.type = kwargs.get('type') or 'string'
for attr, value in kwargs.items():
setattr(self, attr, value)
def to_python(self, value):
return self.typemap[self.type](value)
def __repr__(self):
return '<Option: type->{} default->{!r}>'.format(self.type,
self.default)
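# Illustrative sketch of how an Option coerces raw (e.g. environment-derived)
# values through its typemap:
#
#     timeout = Option(4.0, type='float')
#     timeout.to_python('2.5')    # -> 2.5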
NAMESPACES = Namespace(
accept_content=Option(DEFAULT_ACCEPT_CONTENT, type='list', old=OLD_NS),
result_accept_content=Option(None, type='list'),
enable_utc=Option(True, type='bool'),
imports=Option((), type='tuple', old=OLD_NS),
include=Option((), type='tuple', old=OLD_NS),
timezone=Option(type='string', old=OLD_NS),
beat=Namespace(
__old__=OLD_NS_BEAT,
max_loop_interval=Option(0, type='float'),
schedule=Option({}, type='dict'),
scheduler=Option('celery.beat:PersistentScheduler'),
schedule_filename=Option('celerybeat-schedule'),
sync_every=Option(0, type='int'),
        cron_starting_deadline=Option(None, type='int')
),
broker=Namespace(
url=Option(None, type='string'),
read_url=Option(None, type='string'),
write_url=Option(None, type='string'),
transport=Option(type='string'),
transport_options=Option({}, type='dict'),
connection_timeout=Option(4, type='float'),
connection_retry=Option(True, type='bool'),
connection_retry_on_startup=Option(None, type='bool'),
connection_max_retries=Option(100, type='int'),
channel_error_retry=Option(False, type='bool'),
failover_strategy=Option(None, type='string'),
heartbeat=Option(120, type='int'),
heartbeat_checkrate=Option(3.0, type='int'),
login_method=Option(None, type='string'),
pool_limit=Option(10, type='int'),
use_ssl=Option(False, type='bool'),
host=Option(type='string'),
port=Option(type='int'),
user=Option(type='string'),
password=Option(type='string'),
vhost=Option(type='string'),
),
cache=Namespace(
__old__=old_ns('celery_cache'),
backend=Option(),
backend_options=Option({}, type='dict'),
),
cassandra=Namespace(
entry_ttl=Option(type='float'),
keyspace=Option(type='string'),
port=Option(type='string'),
read_consistency=Option(type='string'),
servers=Option(type='list'),
bundle_path=Option(type='string'),
table=Option(type='string'),
write_consistency=Option(type='string'),
auth_provider=Option(type='string'),
auth_kwargs=Option(type='string'),
options=Option({}, type='dict'),
),
s3=Namespace(
access_key_id=Option(type='string'),
secret_access_key=Option(type='string'),
bucket=Option(type='string'),
base_path=Option(type='string'),
endpoint_url=Option(type='string'),
region=Option(type='string'),
),
azureblockblob=Namespace(
container_name=Option('celery', type='string'),
retry_initial_backoff_sec=Option(2, type='int'),
retry_increment_base=Option(2, type='int'),
retry_max_attempts=Option(3, type='int'),
base_path=Option('', type='string'),
connection_timeout=Option(20, type='int'),
read_timeout=Option(120, type='int'),
),
control=Namespace(
queue_ttl=Option(300.0, type='float'),
queue_expires=Option(10.0, type='float'),
exchange=Option('celery', type='string'),
),
couchbase=Namespace(
__old__=old_ns('celery_couchbase'),
backend_settings=Option(None, type='dict'),
),
arangodb=Namespace(
__old__=old_ns('celery_arangodb'),
backend_settings=Option(None, type='dict')
),
mongodb=Namespace(
__old__=old_ns('celery_mongodb'),
backend_settings=Option(type='dict'),
),
cosmosdbsql=Namespace(
database_name=Option('celerydb', type='string'),
collection_name=Option('celerycol', type='string'),
consistency_level=Option('Session', type='string'),
max_retry_attempts=Option(9, type='int'),
max_retry_wait_time=Option(30, type='int'),
),
event=Namespace(
__old__=old_ns('celery_event'),
queue_expires=Option(60.0, type='float'),
queue_ttl=Option(5.0, type='float'),
queue_prefix=Option('celeryev'),
serializer=Option('json'),
exchange=Option('celeryev', type='string'),
),
redis=Namespace(
__old__=old_ns('celery_redis'),
backend_use_ssl=Option(type='dict'),
db=Option(type='int'),
host=Option(type='string'),
max_connections=Option(type='int'),
username=Option(type='string'),
password=Option(type='string'),
port=Option(type='int'),
socket_timeout=Option(120.0, type='float'),
socket_connect_timeout=Option(None, type='float'),
retry_on_timeout=Option(False, type='bool'),
socket_keepalive=Option(False, type='bool'),
),
result=Namespace(
__old__=old_ns('celery_result'),
backend=Option(type='string'),
cache_max=Option(
-1,
type='int', old={'celery_max_cached_results'},
),
        compression=Option(type='string'),
exchange=Option('celeryresults'),
exchange_type=Option('direct'),
expires=Option(
timedelta(days=1),
type='float', old={'celery_task_result_expires'},
),
persistent=Option(None, type='bool'),
extended=Option(False, type='bool'),
serializer=Option('json'),
backend_transport_options=Option({}, type='dict'),
chord_retry_interval=Option(1.0, type='float'),
chord_join_timeout=Option(3.0, type='float'),
backend_max_sleep_between_retries_ms=Option(10000, type='int'),
backend_max_retries=Option(float("inf"), type='float'),
backend_base_sleep_between_retries_ms=Option(10, type='int'),
backend_always_retry=Option(False, type='bool'),
),
elasticsearch=Namespace(
__old__=old_ns('celery_elasticsearch'),
retry_on_timeout=Option(type='bool'),
max_retries=Option(type='int'),
timeout=Option(type='float'),
save_meta_as_text=Option(True, type='bool'),
),
security=Namespace(
__old__=old_ns('celery_security'),
certificate=Option(type='string'),
cert_store=Option(type='string'),
key=Option(type='string'),
key_password=Option(type='bytes'),
digest=Option(DEFAULT_SECURITY_DIGEST, type='string'),
),
database=Namespace(
url=Option(old={'celery_result_dburi'}),
engine_options=Option(
type='dict', old={'celery_result_engine_options'},
),
short_lived_sessions=Option(
False, type='bool', old={'celery_result_db_short_lived_sessions'},
),
table_schemas=Option(type='dict'),
table_names=Option(type='dict', old={'celery_result_db_tablenames'}),
),
task=Namespace(
__old__=OLD_NS,
acks_late=Option(False, type='bool'),
acks_on_failure_or_timeout=Option(True, type='bool'),
always_eager=Option(False, type='bool'),
annotations=Option(type='any'),
compression=Option(type='string', old={'celery_message_compression'}),
create_missing_queues=Option(True, type='bool'),
inherit_parent_priority=Option(False, type='bool'),
default_delivery_mode=Option(2, type='string'),
default_queue=Option('celery'),
default_exchange=Option(None, type='string'), # taken from queue
default_exchange_type=Option('direct'),
default_routing_key=Option(None, type='string'), # taken from queue
default_rate_limit=Option(type='string'),
default_priority=Option(None, type='string'),
eager_propagates=Option(
False, type='bool', old={'celery_eager_propagates_exceptions'},
),
ignore_result=Option(False, type='bool'),
store_eager_result=Option(False, type='bool'),
protocol=Option(2, type='int', old={'celery_task_protocol'}),
publish_retry=Option(
True, type='bool', old={'celery_task_publish_retry'},
),
publish_retry_policy=Option(
{'max_retries': 3,
'interval_start': 0,
'interval_max': 1,
'interval_step': 0.2},
type='dict', old={'celery_task_publish_retry_policy'},
),
queues=Option(type='dict'),
queue_max_priority=Option(None, type='int'),
reject_on_worker_lost=Option(type='bool'),
remote_tracebacks=Option(False, type='bool'),
routes=Option(type='any'),
send_sent_event=Option(
False, type='bool', old={'celery_send_task_sent_event'},
),
serializer=Option('json', old={'celery_task_serializer'}),
soft_time_limit=Option(
type='float', old={'celeryd_task_soft_time_limit'},
),
time_limit=Option(
type='float', old={'celeryd_task_time_limit'},
),
store_errors_even_if_ignored=Option(False, type='bool'),
track_started=Option(False, type='bool'),
allow_error_cb_on_chord_header=Option(False, type='bool'),
),
worker=Namespace(
__old__=OLD_NS_WORKER,
agent=Option(None, type='string'),
autoscaler=Option('celery.worker.autoscale:Autoscaler'),
cancel_long_running_tasks_on_connection_loss=Option(
False, type='bool'
),
concurrency=Option(None, type='int'),
consumer=Option('celery.worker.consumer:Consumer', type='string'),
direct=Option(False, type='bool', old={'celery_worker_direct'}),
disable_rate_limits=Option(
False, type='bool', old={'celery_disable_rate_limits'},
),
deduplicate_successful_tasks=Option(
False, type='bool'
),
enable_remote_control=Option(
True, type='bool', old={'celery_enable_remote_control'},
),
hijack_root_logger=Option(True, type='bool'),
log_color=Option(type='bool'),
log_format=Option(DEFAULT_PROCESS_LOG_FMT),
lost_wait=Option(10.0, type='float', old={'celeryd_worker_lost_wait'}),
max_memory_per_child=Option(type='int'),
max_tasks_per_child=Option(type='int'),
pool=Option(DEFAULT_POOL),
pool_putlocks=Option(True, type='bool'),
pool_restarts=Option(False, type='bool'),
proc_alive_timeout=Option(4.0, type='float'),
prefetch_multiplier=Option(4, type='int'),
redirect_stdouts=Option(
True, type='bool', old={'celery_redirect_stdouts'},
),
redirect_stdouts_level=Option(
'WARNING', old={'celery_redirect_stdouts_level'},
),
send_task_events=Option(
False, type='bool', old={'celery_send_events'},
),
state_db=Option(),
task_log_format=Option(DEFAULT_TASK_LOG_FMT),
timer=Option(type='string'),
timer_precision=Option(1.0, type='float'),
),
)
def _flatten_keys(ns, key, opt):
return [(ns + key, opt)]
def _to_compat(ns, key, opt):
if opt.old:
return [
(oldkey.format(key).upper(), ns + key, opt)
for oldkey in opt.old
]
return [((ns + key).upper(), ns + key, opt)]
def flatten(d, root='', keyfilter=_flatten_keys):
"""Flatten settings."""
stack = deque([(root, d)])
while stack:
ns, options = stack.popleft()
for key, opt in options.items():
if isinstance(opt, dict):
stack.append((ns + key + '_', opt))
else:
yield from keyfilter(ns, key, opt)
DEFAULTS = {
key: opt.default for key, opt in flatten(NAMESPACES)
}
__compat = list(flatten(NAMESPACES, keyfilter=_to_compat))
_OLD_DEFAULTS = {old_key: opt.default for old_key, _, opt in __compat}
_TO_OLD_KEY = {new_key: old_key for old_key, new_key, _ in __compat}
_TO_NEW_KEY = {old_key: new_key for old_key, new_key, _ in __compat}
__compat = None
SETTING_KEYS = set(DEFAULTS.keys())
_OLD_SETTING_KEYS = set(_TO_NEW_KEY.keys())
def find_deprecated_settings(source): # pragma: no cover
from celery.utils import deprecated
for name, opt in flatten(NAMESPACES):
if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None):
deprecated.warn(description=f'The {name!r} setting',
deprecation=opt.deprecate_by,
removal=opt.remove_by,
alternative=f'Use the {opt.alt} instead')
return source
@memoize(maxsize=None)
def find(name, namespace='celery'):
"""Find setting by name."""
# - Try specified name-space first.
namespace = namespace.lower()
try:
return searchresult(
namespace, name.lower(), NAMESPACES[namespace][name.lower()],
)
except KeyError:
# - Try all the other namespaces.
for ns, opts in NAMESPACES.items():
if ns.lower() == name.lower():
return searchresult(None, ns, opts)
elif isinstance(opts, dict):
try:
return searchresult(ns, name.lower(), opts[name.lower()])
except KeyError:
pass
# - See if name is a qualname last.
return searchresult(None, name.lower(), DEFAULTS[name.lower()])
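# Illustrative sketch (values taken from the NAMESPACES mapping above):
#
#     find('disable_rate_limits')
#     # -> searchresult(namespace='worker', key='disable_rate_limits',
#     #                 type=<Option: type->bool default->False>)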

View File

@@ -0,0 +1,40 @@
"""Implementation for the app.events shortcuts."""
from contextlib import contextmanager
from kombu.utils.objects import cached_property
class Events:
"""Implements app.events."""
receiver_cls = 'celery.events.receiver:EventReceiver'
dispatcher_cls = 'celery.events.dispatcher:EventDispatcher'
state_cls = 'celery.events.state:State'
def __init__(self, app=None):
self.app = app
@cached_property
def Receiver(self):
return self.app.subclass_with_self(
self.receiver_cls, reverse='events.Receiver')
@cached_property
def Dispatcher(self):
return self.app.subclass_with_self(
self.dispatcher_cls, reverse='events.Dispatcher')
@cached_property
def State(self):
return self.app.subclass_with_self(
self.state_cls, reverse='events.State')
@contextmanager
def default_dispatcher(self, hostname=None, enabled=True,
buffer_while_offline=False):
with self.app.amqp.producer_pool.acquire(block=True) as prod:
# pylint: disable=too-many-function-args
# This is a property pylint...
with self.Dispatcher(prod.connection, hostname, enabled,
prod.channel, buffer_while_offline) as d:
yield d
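    # Illustrative sketch (assumes a configured ``app``): sending an event
    # through a short-lived dispatcher drawn from the producer pool.
    #
    #     with app.events.default_dispatcher() as dispatcher:
    #         dispatcher.send('worker-heartbeat')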

View File

@@ -0,0 +1,247 @@
"""Logging configuration.
The Celery instances logging section: ``Celery.log``.
Sets up logging for the worker and other programs,
redirects standard outs, colors log output, patches logging
related compatibility fixes, and so on.
"""
import logging
import os
import sys
import warnings
from logging.handlers import WatchedFileHandler
from kombu.utils.encoding import set_default_encoding_file
from celery import signals
from celery._state import get_current_task
from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
from celery.local import class_property
from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel,
reset_multiprocessing_logger)
from celery.utils.nodenames import node_format
from celery.utils.term import colored
__all__ = ('TaskFormatter', 'Logging')
MP_LOG = os.environ.get('MP_LOG', False)
class TaskFormatter(ColorFormatter):
"""Formatter for tasks, adding the task name and id."""
def format(self, record):
task = get_current_task()
if task and task.request:
record.__dict__.update(task_id=task.request.id,
task_name=task.name)
else:
record.__dict__.setdefault('task_name', '???')
record.__dict__.setdefault('task_id', '???')
return super().format(record)
class Logging:
"""Application logging setup (app.log)."""
#: The logging subsystem is only configured once per process.
#: setup_logging_subsystem sets this flag, and subsequent calls
#: will do nothing.
_setup = False
def __init__(self, app):
self.app = app
self.loglevel = mlevel(logging.WARN)
self.format = self.app.conf.worker_log_format
self.task_format = self.app.conf.worker_task_log_format
self.colorize = self.app.conf.worker_log_color
def setup(self, loglevel=None, logfile=None, redirect_stdouts=False,
redirect_level='WARNING', colorize=None, hostname=None):
loglevel = mlevel(loglevel)
handled = self.setup_logging_subsystem(
loglevel, logfile, colorize=colorize, hostname=hostname,
)
if not handled and redirect_stdouts:
self.redirect_stdouts(redirect_level)
os.environ.update(
CELERY_LOG_LEVEL=str(loglevel) if loglevel else '',
CELERY_LOG_FILE=str(logfile) if logfile else '',
)
warnings.filterwarnings('always', category=CDeprecationWarning)
warnings.filterwarnings('always', category=CPendingDeprecationWarning)
logging.captureWarnings(True)
return handled
def redirect_stdouts(self, loglevel=None, name='celery.redirected'):
self.redirect_stdouts_to_logger(
get_logger(name), loglevel=loglevel
)
os.environ.update(
CELERY_LOG_REDIRECT='1',
CELERY_LOG_REDIRECT_LEVEL=str(loglevel or ''),
)
def setup_logging_subsystem(self, loglevel=None, logfile=None, format=None,
colorize=None, hostname=None, **kwargs):
if self.already_setup:
return
if logfile and hostname:
logfile = node_format(logfile, hostname)
Logging._setup = True
loglevel = mlevel(loglevel or self.loglevel)
format = format or self.format
colorize = self.supports_color(colorize, logfile)
reset_multiprocessing_logger()
receivers = signals.setup_logging.send(
sender=None, loglevel=loglevel, logfile=logfile,
format=format, colorize=colorize,
)
if not receivers:
root = logging.getLogger()
if self.app.conf.worker_hijack_root_logger:
root.handlers = []
get_logger('celery').handlers = []
get_logger('celery.task').handlers = []
get_logger('celery.redirected').handlers = []
# Configure root logger
self._configure_logger(
root, logfile, loglevel, format, colorize, **kwargs
)
# Configure the multiprocessing logger
self._configure_logger(
get_multiprocessing_logger(),
logfile, loglevel if MP_LOG else logging.ERROR,
format, colorize, **kwargs
)
signals.after_setup_logger.send(
sender=None, logger=root,
loglevel=loglevel, logfile=logfile,
format=format, colorize=colorize,
)
# then setup the root task logger.
self.setup_task_loggers(loglevel, logfile, colorize=colorize)
try:
stream = logging.getLogger().handlers[0].stream
except (AttributeError, IndexError):
pass
else:
set_default_encoding_file(stream)
# This is a hack for multiprocessing's fork+exec, so that
# logging before Process.run works.
logfile_name = logfile if isinstance(logfile, str) else ''
os.environ.update(_MP_FORK_LOGLEVEL_=str(loglevel),
_MP_FORK_LOGFILE_=logfile_name,
_MP_FORK_LOGFORMAT_=format)
return receivers
def _configure_logger(self, logger, logfile, loglevel,
format, colorize, **kwargs):
if logger is not None:
self.setup_handlers(logger, logfile, format,
colorize, **kwargs)
if loglevel:
logger.setLevel(loglevel)
def setup_task_loggers(self, loglevel=None, logfile=None, format=None,
colorize=None, propagate=False, **kwargs):
"""Setup the task logger.
If `logfile` is not specified, then `sys.stderr` is used.
Will return the base task logger object.
"""
loglevel = mlevel(loglevel or self.loglevel)
format = format or self.task_format
colorize = self.supports_color(colorize, logfile)
logger = self.setup_handlers(
get_logger('celery.task'),
logfile, format, colorize,
formatter=TaskFormatter, **kwargs
)
logger.setLevel(loglevel)
# this is an int for some reason, better to not question why.
logger.propagate = int(propagate)
signals.after_setup_task_logger.send(
sender=None, logger=logger,
loglevel=loglevel, logfile=logfile,
format=format, colorize=colorize,
)
return logger
def redirect_stdouts_to_logger(self, logger, loglevel=None,
stdout=True, stderr=True):
"""Redirect :class:`sys.stdout` and :class:`sys.stderr` to logger.
Arguments:
logger (logging.Logger): Logger instance to redirect to.
loglevel (int, str): The loglevel redirected message
will be logged as.
"""
proxy = LoggingProxy(logger, loglevel)
if stdout:
sys.stdout = proxy
if stderr:
sys.stderr = proxy
return proxy
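    # Illustrative sketch (assumes a configured ``app`` and ``import logging``);
    # after this call, print() output ends up in the Celery log:
    #
    #     logger = app.log.get_default_logger()
    #     app.log.redirect_stdouts_to_logger(logger, loglevel=logging.WARNING)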
def supports_color(self, colorize=None, logfile=None):
colorize = self.colorize if colorize is None else colorize
if self.app.IS_WINDOWS:
# Windows does not support ANSI color codes.
return False
if colorize or colorize is None:
# Only use color if there's no active log file
# and stderr is an actual terminal.
return logfile is None and sys.stderr.isatty()
return colorize
def colored(self, logfile=None, enabled=None):
return colored(enabled=self.supports_color(enabled, logfile))
def setup_handlers(self, logger, logfile, format, colorize,
formatter=ColorFormatter, **kwargs):
if self._is_configured(logger):
return logger
handler = self._detect_handler(logfile)
handler.setFormatter(formatter(format, use_color=colorize))
logger.addHandler(handler)
return logger
def _detect_handler(self, logfile=None):
"""Create handler from filename, an open stream or `None` (stderr)."""
logfile = sys.__stderr__ if logfile is None else logfile
if hasattr(logfile, 'write'):
return logging.StreamHandler(logfile)
return WatchedFileHandler(logfile, encoding='utf-8')
def _has_handler(self, logger):
return any(
not isinstance(h, logging.NullHandler)
for h in logger.handlers or []
)
def _is_configured(self, logger):
return self._has_handler(logger) and not getattr(
logger, '_rudimentary_setup', False)
def get_default_logger(self, name='celery', **kwargs):
return get_logger(name)
@class_property
def already_setup(self):
return self._setup
@already_setup.setter
def already_setup(self, was_setup):
self._setup = was_setup

View File

@@ -0,0 +1,68 @@
"""Registry of available tasks."""
import inspect
from importlib import import_module
from celery._state import get_current_app
from celery.app.autoretry import add_autoretry_behaviour
from celery.exceptions import InvalidTaskError, NotRegistered
__all__ = ('TaskRegistry',)
class TaskRegistry(dict):
"""Map of registered tasks."""
NotRegistered = NotRegistered
def __missing__(self, key):
raise self.NotRegistered(key)
def register(self, task):
"""Register a task in the task registry.
The task will be automatically instantiated if not already an
instance. Name must be configured prior to registration.
"""
if task.name is None:
raise InvalidTaskError(
'Task class {!r} must specify .name attribute'.format(
type(task).__name__))
task = inspect.isclass(task) and task() or task
add_autoretry_behaviour(task)
self[task.name] = task
def unregister(self, name):
"""Unregister task by name.
Arguments:
name (str): name of the task to unregister, or a
:class:`celery.app.task.Task` with a valid `name` attribute.
Raises:
celery.exceptions.NotRegistered: if the task is not registered.
"""
try:
self.pop(getattr(name, 'name', name))
except KeyError:
raise self.NotRegistered(name)
    # -- these methods are no longer relevant and are kept only for
    # backward compatibility
def regular(self):
return self.filter_types('regular')
def periodic(self):
return self.filter_types('periodic')
def filter_types(self, type):
return {name: task for name, task in self.items()
if getattr(task, 'type', 'regular') == type}
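    # Illustrative sketch (``demo.add`` is a hypothetical task name); the
    # registry is normally reached through ``app.tasks``:
    #
    #     @app.task(name='demo.add')
    #     def add(x, y):
    #         return x + y
    #
    #     assert 'demo.add' in app.tasks      # app.tasks is a TaskRegistry
    #     app.tasks.unregister('demo.add')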
def _unpickle_task(name):
return get_current_app().tasks[name]
def _unpickle_task_v2(name, module=None):
if module:
import_module(module)
return get_current_app().tasks[name]

View File

@@ -0,0 +1,136 @@
"""Task Routing.
Contains utilities for working with task routers, (:setting:`task_routes`).
"""
import fnmatch
import re
from collections import OrderedDict
from collections.abc import Mapping
from kombu import Queue
from celery.exceptions import QueueNotFound
from celery.utils.collections import lpmerge
from celery.utils.functional import maybe_evaluate, mlazy
from celery.utils.imports import symbol_by_name
try:
Pattern = re._pattern_type
except AttributeError: # pragma: no cover
    # re._pattern_type was removed in Python 3.8; use re.Pattern instead
Pattern = re.Pattern
__all__ = ('MapRoute', 'Router', 'prepare')
class MapRoute:
"""Creates a router out of a :class:`dict`."""
def __init__(self, map):
map = map.items() if isinstance(map, Mapping) else map
self.map = {}
self.patterns = OrderedDict()
for k, v in map:
if isinstance(k, Pattern):
self.patterns[k] = v
elif '*' in k:
self.patterns[re.compile(fnmatch.translate(k))] = v
else:
self.map[k] = v
def __call__(self, name, *args, **kwargs):
try:
return dict(self.map[name])
except KeyError:
pass
except ValueError:
return {'queue': self.map[name]}
for regex, route in self.patterns.items():
if regex.match(name):
try:
return dict(route)
except ValueError:
return {'queue': route}
class Router:
"""Route tasks based on the :setting:`task_routes` setting."""
def __init__(self, routes=None, queues=None,
create_missing=False, app=None):
self.app = app
self.queues = {} if queues is None else queues
self.routes = [] if routes is None else routes
self.create_missing = create_missing
def route(self, options, name, args=(), kwargs=None, task_type=None):
kwargs = {} if not kwargs else kwargs
options = self.expand_destination(options) # expands 'queue'
if self.routes:
route = self.lookup_route(name, args, kwargs, options, task_type)
if route: # expands 'queue' in route.
return lpmerge(self.expand_destination(route), options)
if 'queue' not in options:
options = lpmerge(self.expand_destination(
self.app.conf.task_default_queue), options)
return options
def expand_destination(self, route):
# Route can be a queue name: convenient for direct exchanges.
if isinstance(route, str):
queue, route = route, {}
else:
# can use defaults from configured queue, but override specific
# things (like the routing_key): great for topic exchanges.
queue = route.pop('queue', None)
if queue:
if isinstance(queue, Queue):
route['queue'] = queue
else:
try:
route['queue'] = self.queues[queue]
except KeyError:
raise QueueNotFound(
f'Queue {queue!r} missing from task_queues')
return route
def lookup_route(self, name,
args=None, kwargs=None, options=None, task_type=None):
query = self.query_router
for router in self.routes:
route = query(router, name, args, kwargs, options, task_type)
if route is not None:
return route
def query_router(self, router, task, args, kwargs, options, task_type):
router = maybe_evaluate(router)
if hasattr(router, 'route_for_task'):
# pre 4.0 router class
return router.route_for_task(task, args, kwargs)
return router(task, args, kwargs, options, task=task_type)
def expand_router_string(router):
router = symbol_by_name(router)
if hasattr(router, 'route_for_task'):
# need to instantiate pre 4.0 router classes
router = router()
return router
def prepare(routes):
"""Expand the :setting:`task_routes` setting."""
def expand_route(route):
if isinstance(route, (Mapping, list, tuple)):
return MapRoute(route)
if isinstance(route, str):
return mlazy(expand_router_string, route)
return route
if routes is None:
return ()
if not isinstance(routes, (list, tuple)):
routes = (routes,)
return [expand_route(route) for route in routes]
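# Illustrative sketch (module and queue names are hypothetical) showing how a
# task_routes mapping is expanded and matched:
#
#     routes = prepare({'feed.tasks.*': {'queue': 'feeds'}})   # -> [MapRoute(...)]
#     routes[0]('feed.tasks.import_feed')    # -> {'queue': 'feeds'}
#     routes[0]('web.tasks.render')          # -> None (no pattern matched)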

File diff suppressed because it is too large

View File

@@ -0,0 +1,763 @@
"""Trace task execution.
This module defines how the task execution is traced:
errors are recorded, handlers are applied and so on.
"""
import logging
import os
import sys
import time
from collections import namedtuple
from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union
from warnings import warn
from billiard.einfo import ExceptionInfo, ExceptionWithTraceback
from kombu.exceptions import EncodeError
from kombu.serialization import loads as loads_message
from kombu.serialization import prepare_accept_content
from kombu.utils.encoding import safe_repr, safe_str
import celery
import celery.loaders.app
from celery import current_app, group, signals, states
from celery._state import _task_stack
from celery.app.task import Context
from celery.app.task import Task as BaseTask
from celery.exceptions import BackendGetMetaError, Ignore, InvalidTaskError, Reject, Retry
from celery.result import AsyncResult
from celery.utils.log import get_logger
from celery.utils.nodenames import gethostname
from celery.utils.objects import mro_lookup
from celery.utils.saferepr import saferepr
from celery.utils.serialization import get_pickleable_etype, get_pickleable_exception, get_pickled_exception
# ## ---
# This is the heart of the worker, the inner loop so to speak.
# It used to be split up into nice little classes and methods,
# but in the end it only resulted in bad performance and horrible tracebacks,
# so instead we now use one closure per task class.
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
# pylint: disable=broad-except
# We know what we're doing...
__all__ = (
'TraceInfo', 'build_tracer', 'trace_task',
'setup_worker_optimizations', 'reset_worker_optimizations',
)
from celery.worker.state import successful_requests
logger = get_logger(__name__)
#: Format string used to log task receipt.
LOG_RECEIVED = """\
Task %(name)s[%(id)s] received\
"""
#: Format string used to log task success.
LOG_SUCCESS = """\
Task %(name)s[%(id)s] succeeded in %(runtime)ss: %(return_value)s\
"""
#: Format string used to log task failure.
LOG_FAILURE = """\
Task %(name)s[%(id)s] %(description)s: %(exc)s\
"""
#: Format string used to log task internal error.
LOG_INTERNAL_ERROR = """\
Task %(name)s[%(id)s] %(description)s: %(exc)s\
"""
#: Format string used to log task ignored.
LOG_IGNORED = """\
Task %(name)s[%(id)s] %(description)s\
"""
#: Format string used to log task rejected.
LOG_REJECTED = """\
Task %(name)s[%(id)s] %(exc)s\
"""
#: Format string used to log task retry.
LOG_RETRY = """\
Task %(name)s[%(id)s] retry: %(exc)s\
"""
log_policy_t = namedtuple(
'log_policy_t',
('format', 'description', 'severity', 'traceback', 'mail'),
)
log_policy_reject = log_policy_t(LOG_REJECTED, 'rejected', logging.WARN, 1, 1)
log_policy_ignore = log_policy_t(LOG_IGNORED, 'ignored', logging.INFO, 0, 0)
log_policy_internal = log_policy_t(
LOG_INTERNAL_ERROR, 'INTERNAL ERROR', logging.CRITICAL, 1, 1,
)
log_policy_expected = log_policy_t(
LOG_FAILURE, 'raised expected', logging.INFO, 0, 0,
)
log_policy_unexpected = log_policy_t(
LOG_FAILURE, 'raised unexpected', logging.ERROR, 1, 1,
)
send_prerun = signals.task_prerun.send
send_postrun = signals.task_postrun.send
send_success = signals.task_success.send
STARTED = states.STARTED
SUCCESS = states.SUCCESS
IGNORED = states.IGNORED
REJECTED = states.REJECTED
RETRY = states.RETRY
FAILURE = states.FAILURE
EXCEPTION_STATES = states.EXCEPTION_STATES
IGNORE_STATES = frozenset({IGNORED, RETRY, REJECTED})
#: set by :func:`setup_worker_optimizations`
_localized = []
_patched = {}
trace_ok_t = namedtuple('trace_ok_t', ('retval', 'info', 'runtime', 'retstr'))
def info(fmt, context):
"""Log 'fmt % context' with severity 'INFO'.
'context' is also passed in extra with key 'data' for custom handlers.
"""
logger.info(fmt, context, extra={'data': context})
def task_has_custom(task, attr):
"""Return true if the task overrides ``attr``."""
return mro_lookup(task.__class__, attr, stop={BaseTask, object},
monkey_patched=['celery.app.task'])
def get_log_policy(task, einfo, exc):
if isinstance(exc, Reject):
return log_policy_reject
elif isinstance(exc, Ignore):
return log_policy_ignore
elif einfo.internal:
return log_policy_internal
else:
if task.throws and isinstance(exc, task.throws):
return log_policy_expected
return log_policy_unexpected
def get_task_name(request, default):
"""Use 'shadow' in request for the task name if applicable."""
# request.shadow could be None or an empty string.
# If so, we should use default.
return getattr(request, 'shadow', None) or default
class TraceInfo:
"""Information about task execution."""
__slots__ = ('state', 'retval')
def __init__(self, state, retval=None):
self.state = state
self.retval = retval
def handle_error_state(self, task, req,
eager=False, call_errbacks=True):
if task.ignore_result:
store_errors = task.store_errors_even_if_ignored
elif eager and task.store_eager_result:
store_errors = True
else:
store_errors = not eager
return {
RETRY: self.handle_retry,
FAILURE: self.handle_failure,
}[self.state](task, req,
store_errors=store_errors,
call_errbacks=call_errbacks)
def handle_reject(self, task, req, **kwargs):
self._log_error(task, req, ExceptionInfo())
def handle_ignore(self, task, req, **kwargs):
self._log_error(task, req, ExceptionInfo())
def handle_retry(self, task, req, store_errors=True, **kwargs):
"""Handle retry exception."""
# the exception raised is the Retry semi-predicate,
        # and its `exc` attribute is the original exception raised (if any).
type_, _, tb = sys.exc_info()
try:
reason = self.retval
einfo = ExceptionInfo((type_, reason, tb))
if store_errors:
task.backend.mark_as_retry(
req.id, reason.exc, einfo.traceback, request=req,
)
task.on_retry(reason.exc, req.id, req.args, req.kwargs, einfo)
signals.task_retry.send(sender=task, request=req,
reason=reason, einfo=einfo)
info(LOG_RETRY, {
'id': req.id,
'name': get_task_name(req, task.name),
'exc': str(reason),
})
return einfo
finally:
del tb
def handle_failure(self, task, req, store_errors=True, call_errbacks=True):
"""Handle exception."""
orig_exc = self.retval
exc = get_pickleable_exception(orig_exc)
if exc.__traceback__ is None:
# `get_pickleable_exception` may have created a new exception without
# a traceback.
_, _, exc.__traceback__ = sys.exc_info()
exc_type = get_pickleable_etype(type(orig_exc))
# make sure we only send pickleable exceptions back to parent.
einfo = ExceptionInfo(exc_info=(exc_type, exc, exc.__traceback__))
task.backend.mark_as_failure(
req.id, exc, einfo.traceback,
request=req, store_result=store_errors,
call_errbacks=call_errbacks,
)
task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
signals.task_failure.send(sender=task, task_id=req.id,
exception=exc, args=req.args,
kwargs=req.kwargs,
traceback=exc.__traceback__,
einfo=einfo)
self._log_error(task, req, einfo)
return einfo
def _log_error(self, task, req, einfo):
eobj = einfo.exception = get_pickled_exception(einfo.exception)
if isinstance(eobj, ExceptionWithTraceback):
eobj = einfo.exception = eobj.exc
exception, traceback, exc_info, sargs, skwargs = (
safe_repr(eobj),
safe_str(einfo.traceback),
einfo.exc_info,
req.get('argsrepr') or safe_repr(req.args),
req.get('kwargsrepr') or safe_repr(req.kwargs),
)
policy = get_log_policy(task, einfo, eobj)
context = {
'hostname': req.hostname,
'id': req.id,
'name': get_task_name(req, task.name),
'exc': exception,
'traceback': traceback,
'args': sargs,
'kwargs': skwargs,
'description': policy.description,
'internal': einfo.internal,
}
logger.log(policy.severity, policy.format.strip(), context,
exc_info=exc_info if policy.traceback else None,
extra={'data': context})
def traceback_clear(exc=None):
    # einfo may still hold a reference to the traceback object itself, so
    # clear each frame here to release its locals as early as possible.
tb = None
if exc is not None:
if hasattr(exc, '__traceback__'):
tb = exc.__traceback__
else:
_, _, tb = sys.exc_info()
else:
_, _, tb = sys.exc_info()
while tb is not None:
try:
tb.tb_frame.clear()
tb.tb_frame.f_locals
except RuntimeError:
# Ignore the exception raised if the frame is still executing.
pass
tb = tb.tb_next
def build_tracer(
name: str,
task: Union[celery.Task, celery.local.PromiseProxy],
loader: Optional[celery.loaders.app.AppLoader] = None,
hostname: Optional[str] = None,
store_errors: bool = True,
Info: Type[TraceInfo] = TraceInfo,
eager: bool = False,
propagate: bool = False,
app: Optional[celery.Celery] = None,
monotonic: Callable[[], int] = time.monotonic,
trace_ok_t: Type[trace_ok_t] = trace_ok_t,
IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \
Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]:
"""Return a function that traces task execution.
Catches all exceptions and updates result backend with the
state and result.
If the call was successful, it saves the result to the task result
backend, and sets the task status to `"SUCCESS"`.
If the call raises :exc:`~@Retry`, it extracts
the original exception, uses that as the result and sets the task state
to `"RETRY"`.
If the call results in an exception, it saves the exception as the task
result, and sets the task state to `"FAILURE"`.
Return a function that takes the following arguments:
:param uuid: The id of the task.
:param args: List of positional args to pass on to the function.
:param kwargs: Keyword arguments mapping to pass on to the function.
:keyword request: Request dict.
"""
# pylint: disable=too-many-statements
# If the task doesn't define a custom __call__ method
# we optimize it away by simply calling the run method directly,
# saving the extra method call and a line less in the stack trace.
fun = task if task_has_custom(task, '__call__') else task.run
loader = loader or app.loader
ignore_result = task.ignore_result
    track_started = not eager and (task.track_started and not ignore_result)
# #6476
if eager and not ignore_result and task.store_eager_result:
publish_result = True
else:
publish_result = not eager and not ignore_result
deduplicate_successful_tasks = ((app.conf.task_acks_late or task.acks_late)
and app.conf.worker_deduplicate_successful_tasks
and app.backend.persistent)
hostname = hostname or gethostname()
inherit_parent_priority = app.conf.task_inherit_parent_priority
loader_task_init = loader.on_task_init
loader_cleanup = loader.on_process_cleanup
task_before_start = None
task_on_success = None
task_after_return = None
if task_has_custom(task, 'before_start'):
task_before_start = task.before_start
if task_has_custom(task, 'on_success'):
task_on_success = task.on_success
if task_has_custom(task, 'after_return'):
task_after_return = task.after_return
pid = os.getpid()
request_stack = task.request_stack
push_request = request_stack.push
pop_request = request_stack.pop
push_task = _task_stack.push
pop_task = _task_stack.pop
_does_info = logger.isEnabledFor(logging.INFO)
resultrepr_maxsize = task.resultrepr_maxsize
prerun_receivers = signals.task_prerun.receivers
postrun_receivers = signals.task_postrun.receivers
success_receivers = signals.task_success.receivers
from celery import canvas
signature = canvas.maybe_signature # maybe_ does not clone if already
def on_error(
request: celery.app.task.Context,
exc: Union[Exception, Type[Exception]],
state: str = FAILURE,
call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]:
"""Handle any errors raised by a `Task`'s execution."""
if propagate:
raise
I = Info(state, exc)
R = I.handle_error_state(
task, request, eager=eager, call_errbacks=call_errbacks,
)
return I, R, I.state, I.retval
def trace_task(
uuid: str,
args: Sequence[Any],
kwargs: Dict[str, Any],
request: Optional[Dict[str, Any]] = None) -> trace_ok_t:
"""Execute and trace a `Task`."""
# R - is the possibly prepared return value.
# I - is the Info object.
# T - runtime
# Rstr - textual representation of return value
# retval - is the always unmodified return value.
# state - is the resulting task state.
# This function is very long because we've unrolled all the calls
# for performance reasons, and because the function is so long
        # we want the main variables (I, and R) to stand out visually from
        # the rest of the variables, so breaking PEP8 is worth it ;)
R = I = T = Rstr = retval = state = None
task_request = None
time_start = monotonic()
try:
try:
kwargs.items
except AttributeError:
raise InvalidTaskError(
                    'Task keyword arguments must be a mapping')
task_request = Context(request or {}, args=args,
called_directly=False, kwargs=kwargs)
redelivered = (task_request.delivery_info
and task_request.delivery_info.get('redelivered', False))
if deduplicate_successful_tasks and redelivered:
if task_request.id in successful_requests:
return trace_ok_t(R, I, T, Rstr)
r = AsyncResult(task_request.id, app=app)
try:
state = r.state
except BackendGetMetaError:
pass
else:
if state == SUCCESS:
info(LOG_IGNORED, {
'id': task_request.id,
'name': get_task_name(task_request, name),
'description': 'Task already completed successfully.'
})
return trace_ok_t(R, I, T, Rstr)
push_task(task)
root_id = task_request.root_id or uuid
task_priority = task_request.delivery_info.get('priority') if \
inherit_parent_priority else None
push_request(task_request)
try:
# -*- PRE -*-
if prerun_receivers:
send_prerun(sender=task, task_id=uuid, task=task,
args=args, kwargs=kwargs)
loader_task_init(uuid, task)
if track_started:
task.backend.store_result(
uuid, {'pid': pid, 'hostname': hostname}, STARTED,
request=task_request,
)
# -*- TRACE -*-
try:
if task_before_start:
task_before_start(uuid, args, kwargs)
R = retval = fun(*args, **kwargs)
state = SUCCESS
except Reject as exc:
I, R = Info(REJECTED, exc), ExceptionInfo(internal=True)
state, retval = I.state, I.retval
I.handle_reject(task, task_request)
traceback_clear(exc)
except Ignore as exc:
I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
state, retval = I.state, I.retval
I.handle_ignore(task, task_request)
traceback_clear(exc)
except Retry as exc:
I, R, state, retval = on_error(
task_request, exc, RETRY, call_errbacks=False)
traceback_clear(exc)
except Exception as exc:
I, R, state, retval = on_error(task_request, exc)
traceback_clear(exc)
except BaseException:
raise
else:
try:
# callback tasks must be applied before the result is
# stored, so that result.children is populated.
# groups are called inline and will store trail
# separately, so need to call them separately
# so that the trail's not added multiple times :(
# (Issue #1936)
callbacks = task.request.callbacks
if callbacks:
if len(task.request.callbacks) > 1:
sigs, groups = [], []
for sig in callbacks:
sig = signature(sig, app=app)
if isinstance(sig, group):
groups.append(sig)
else:
sigs.append(sig)
for group_ in groups:
group_.apply_async(
(retval,),
parent_id=uuid, root_id=root_id,
priority=task_priority
)
if sigs:
group(sigs, app=app).apply_async(
(retval,),
parent_id=uuid, root_id=root_id,
priority=task_priority
)
else:
signature(callbacks[0], app=app).apply_async(
(retval,), parent_id=uuid, root_id=root_id,
priority=task_priority
)
# execute first task in chain
chain = task_request.chain
if chain:
_chsig = signature(chain.pop(), app=app)
_chsig.apply_async(
(retval,), chain=chain,
parent_id=uuid, root_id=root_id,
priority=task_priority
)
task.backend.mark_as_done(
uuid, retval, task_request, publish_result,
)
except EncodeError as exc:
I, R, state, retval = on_error(task_request, exc)
else:
Rstr = saferepr(R, resultrepr_maxsize)
T = monotonic() - time_start
if task_on_success:
task_on_success(retval, uuid, args, kwargs)
if success_receivers:
send_success(sender=task, result=retval)
if _does_info:
info(LOG_SUCCESS, {
'id': uuid,
'name': get_task_name(task_request, name),
'return_value': Rstr,
'runtime': T,
'args': task_request.get('argsrepr') or safe_repr(args),
'kwargs': task_request.get('kwargsrepr') or safe_repr(kwargs),
})
# -* POST *-
if state not in IGNORE_STATES:
if task_after_return:
task_after_return(
state, retval, uuid, args, kwargs, None,
)
finally:
try:
if postrun_receivers:
send_postrun(sender=task, task_id=uuid, task=task,
args=args, kwargs=kwargs,
retval=retval, state=state)
finally:
pop_task()
pop_request()
if not eager:
try:
task.backend.process_cleanup()
loader_cleanup()
except (KeyboardInterrupt, SystemExit, MemoryError):
raise
except Exception as exc:
logger.error('Process cleanup failed: %r', exc,
exc_info=True)
except MemoryError:
raise
except Exception as exc:
_signal_internal_error(task, uuid, args, kwargs, request, exc)
if eager:
raise
R = report_internal_error(task, exc)
if task_request is not None:
I, _, _, _ = on_error(task_request, exc)
return trace_ok_t(R, I, T, Rstr)
return trace_task
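# Rough sketch of how the worker wires a tracer up (heavily simplified; the
# real call sites live in celery.worker and pass many more options --
# ``proj.add``, the task id and the request dict are hypothetical here):
#
#     task = app.tasks['proj.add']
#     tracer = build_tracer(task.name, task, app=app, hostname='worker1@host')
#     retval, info, runtime, retstr = tracer('<task-id>', (2, 2), {}, request={})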
def trace_task(task, uuid, args, kwargs, request=None, **opts):
"""Trace task execution."""
request = {} if not request else request
try:
if task.__trace__ is None:
task.__trace__ = build_tracer(task.name, task, **opts)
return task.__trace__(uuid, args, kwargs, request)
except Exception as exc:
_signal_internal_error(task, uuid, args, kwargs, request, exc)
return trace_ok_t(report_internal_error(task, exc), TraceInfo(FAILURE, exc), 0.0, None)
def _signal_internal_error(task, uuid, args, kwargs, request, exc):
"""Send a special `internal_error` signal to the app for outside body errors."""
try:
_, _, tb = sys.exc_info()
einfo = ExceptionInfo()
einfo.exception = get_pickleable_exception(einfo.exception)
einfo.type = get_pickleable_etype(einfo.type)
signals.task_internal_error.send(
sender=task,
task_id=uuid,
args=args,
kwargs=kwargs,
request=request,
exception=exc,
traceback=tb,
einfo=einfo,
)
finally:
del tb
def trace_task_ret(name, uuid, request, body, content_type,
content_encoding, loads=loads_message, app=None,
**extra_request):
app = app or current_app._get_current_object()
embed = None
if content_type:
accept = prepare_accept_content(app.conf.accept_content)
args, kwargs, embed = loads(
body, content_type, content_encoding, accept=accept,
)
else:
args, kwargs, embed = body
hostname = gethostname()
request.update({
'args': args, 'kwargs': kwargs,
'hostname': hostname, 'is_eager': False,
}, **embed or {})
R, I, T, Rstr = trace_task(app.tasks[name],
uuid, args, kwargs, request, app=app)
return (1, R, T) if I else (0, Rstr, T)
def fast_trace_task(task, uuid, request, body, content_type,
content_encoding, loads=loads_message, _loc=None,
hostname=None, **_):
_loc = _localized if not _loc else _loc
embed = None
tasks, accept, hostname = _loc
if content_type:
args, kwargs, embed = loads(
body, content_type, content_encoding, accept=accept,
)
else:
args, kwargs, embed = body
request.update({
'args': args, 'kwargs': kwargs,
'hostname': hostname, 'is_eager': False,
}, **embed or {})
R, I, T, Rstr = tasks[task].__trace__(
uuid, args, kwargs, request,
)
return (1, R, T) if I else (0, Rstr, T)
def report_internal_error(task, exc):
_type, _value, _tb = sys.exc_info()
try:
_value = task.backend.prepare_exception(exc, 'pickle')
exc_info = ExceptionInfo((_type, _value, _tb), internal=True)
warn(RuntimeWarning(
'Exception raised outside body: {!r}:\n{}'.format(
exc, exc_info.traceback)))
return exc_info
finally:
del _tb
def setup_worker_optimizations(app, hostname=None):
"""Setup worker related optimizations."""
hostname = hostname or gethostname()
    # make sure custom Task.__call__ methods that call super
# won't mess up the request/task stack.
_install_stack_protection()
# all new threads start without a current app, so if an app is not
# passed on to the thread it will fall back to the "default app",
# which then could be the wrong app. So for the worker
# we set this to always return our app. This is a hack,
# and means that only a single app can be used for workers
# running in the same process.
app.set_current()
app.set_default()
# evaluate all task classes by finalizing the app.
app.finalize()
# set fast shortcut to task registry
_localized[:] = [
app._tasks,
prepare_accept_content(app.conf.accept_content),
hostname,
]
app.use_fast_trace_task = True
def reset_worker_optimizations(app=current_app):
"""Reset previously configured optimizations."""
try:
delattr(BaseTask, '_stackprotected')
except AttributeError:
pass
try:
BaseTask.__call__ = _patched.pop('BaseTask.__call__')
except KeyError:
pass
app.use_fast_trace_task = False
def _install_stack_protection():
# Patches BaseTask.__call__ in the worker to handle the edge case
# where people override it and also call super.
#
# - The worker optimizes away BaseTask.__call__ and instead
# calls task.run directly.
# - so with the addition of current_task and the request stack
# BaseTask.__call__ now pushes to those stacks so that
# they work when tasks are called directly.
#
# The worker only optimizes away __call__ in the case
# where it hasn't been overridden, so the request/task stack
# will blow if a custom task class defines __call__ and also
# calls super().
if not getattr(BaseTask, '_stackprotected', False):
_patched['BaseTask.__call__'] = orig = BaseTask.__call__
def __protected_call__(self, *args, **kwargs):
stack = self.request_stack
req = stack.top
if req and not req._protected and \
len(stack) == 1 and not req.called_directly:
req._protected = 1
return self.run(*args, **kwargs)
return orig(self, *args, **kwargs)
BaseTask.__call__ = __protected_call__
BaseTask._stackprotected = True

View File

@@ -0,0 +1,415 @@
"""App utilities: Compat settings, bug-report tool, pickling apps."""
import os
import platform as _platform
import re
from collections import namedtuple
from collections.abc import Mapping
from copy import deepcopy
from types import ModuleType
from kombu.utils.url import maybe_sanitize_url
from celery.exceptions import ImproperlyConfigured
from celery.platforms import pyimplementation
from celery.utils.collections import ConfigurationView
from celery.utils.imports import import_from_cwd, qualname, symbol_by_name
from celery.utils.text import pretty
from .defaults import _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY, DEFAULTS, SETTING_KEYS, find
__all__ = (
'Settings', 'appstr', 'bugreport',
'filter_hidden_settings', 'find_app',
)
#: Format used to generate bug-report information.
BUGREPORT_INFO = """
software -> celery:{celery_v} kombu:{kombu_v} py:{py_v}
billiard:{billiard_v} {driver_v}
platform -> system:{system} arch:{arch}
kernel version:{kernel_version} imp:{py_i}
loader -> {loader}
settings -> transport:{transport} results:{results}
{human_settings}
"""
HIDDEN_SETTINGS = re.compile(
'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE',
re.IGNORECASE,
)
E_MIX_OLD_INTO_NEW = """
Cannot mix new and old setting keys, please rename the
following settings to the new format:
{renames}
"""
E_MIX_NEW_INTO_OLD = """
Cannot mix new setting names with old setting names, please
rename the following settings to use the old format:
{renames}
Or change all of the settings to use the new format :)
"""
FMT_REPLACE_SETTING = '{replace:<36} -> {with_}'
def appstr(app):
"""String used in __repr__ etc, to id app instances."""
return f'{app.main or "__main__"} at {id(app):#x}'
class Settings(ConfigurationView):
"""Celery settings object.
    .. seealso::
:ref:`configuration` for a full list of configuration keys.
"""
def __init__(self, *args, deprecated_settings=None, **kwargs):
super().__init__(*args, **kwargs)
self.deprecated_settings = deprecated_settings
@property
def broker_read_url(self):
return (
os.environ.get('CELERY_BROKER_READ_URL') or
self.get('broker_read_url') or
self.broker_url
)
@property
def broker_write_url(self):
return (
os.environ.get('CELERY_BROKER_WRITE_URL') or
self.get('broker_write_url') or
self.broker_url
)
@property
def broker_url(self):
return (
os.environ.get('CELERY_BROKER_URL') or
self.first('broker_url', 'broker_host')
)
@property
def result_backend(self):
return (
os.environ.get('CELERY_RESULT_BACKEND') or
self.first('result_backend', 'CELERY_RESULT_BACKEND')
)
@property
def task_default_exchange(self):
return self.first(
'task_default_exchange',
'task_default_queue',
)
@property
def task_default_routing_key(self):
return self.first(
'task_default_routing_key',
'task_default_queue',
)
@property
def timezone(self):
# this way we also support django's time zone.
return self.first('timezone', 'TIME_ZONE')
def without_defaults(self):
"""Return the current configuration, but without defaults."""
# the last stash is the default settings, so just skip that
return Settings({}, self.maps[:-1])
def value_set_for(self, key):
return key in self.without_defaults()
def find_option(self, name, namespace=''):
"""Search for option by name.
Example:
>>> from proj.celery import app
>>> app.conf.find_option('disable_rate_limits')
            ('worker', 'disable_rate_limits',
             <Option: type->bool default->False>)
Arguments:
name (str): Name of option, cannot be partial.
            namespace (str): Preferred name-space (``''`` by default).
Returns:
Tuple: of ``(namespace, key, type)``.
"""
return find(name, namespace)
def find_value_for_key(self, name, namespace='celery'):
"""Shortcut to ``get_by_parts(*find_option(name)[:-1])``."""
return self.get_by_parts(*self.find_option(name, namespace)[:-1])
def get_by_parts(self, *parts):
"""Return the current value for setting specified as a path.
Example:
>>> from proj.celery import app
>>> app.conf.get_by_parts('worker', 'disable_rate_limits')
False
"""
return self['_'.join(part for part in parts if part)]
def finalize(self):
# See PendingConfiguration in celery/app/base.py
# first access will read actual configuration.
try:
self['__bogus__']
except KeyError:
pass
return self
def table(self, with_defaults=False, censored=True):
filt = filter_hidden_settings if censored else lambda v: v
dict_members = dir(dict)
self.finalize()
settings = self if with_defaults else self.without_defaults()
return filt({
k: v for k, v in settings.items()
if not k.startswith('_') and k not in dict_members
})
def humanize(self, with_defaults=False, censored=True):
"""Return a human readable text showing configuration changes."""
return '\n'.join(
f'{key}: {pretty(value, width=50)}'
for key, value in self.table(with_defaults, censored).items())
def maybe_warn_deprecated_settings(self):
# TODO: Remove this method in Celery 6.0
if self.deprecated_settings:
from celery.app.defaults import _TO_NEW_KEY
from celery.utils import deprecated
for setting in self.deprecated_settings:
deprecated.warn(description=f'The {setting!r} setting',
removal='6.0.0',
alternative=f'Use the {_TO_NEW_KEY[setting]} instead')
return True
return False
def _new_key_to_old(key, convert=_TO_OLD_KEY.get):
return convert(key, key)
def _old_key_to_new(key, convert=_TO_NEW_KEY.get):
return convert(key, key)
_settings_info_t = namedtuple('settings_info_t', (
'defaults', 'convert', 'key_t', 'mix_error',
))
_settings_info = _settings_info_t(
DEFAULTS, _TO_NEW_KEY, _old_key_to_new, E_MIX_OLD_INTO_NEW,
)
_old_settings_info = _settings_info_t(
_OLD_DEFAULTS, _TO_OLD_KEY, _new_key_to_old, E_MIX_NEW_INTO_OLD,
)
def detect_settings(conf, preconf=None, ignore_keys=None, prefix=None,
all_keys=None, old_keys=None):
preconf = {} if not preconf else preconf
ignore_keys = set() if not ignore_keys else ignore_keys
all_keys = SETTING_KEYS if not all_keys else all_keys
old_keys = _OLD_SETTING_KEYS if not old_keys else old_keys
source = conf
if conf is None:
source, conf = preconf, {}
have = set(source.keys()) - ignore_keys
is_in_new = have.intersection(all_keys)
is_in_old = have.intersection(old_keys)
info = None
if is_in_new:
# have new setting names
info, left = _settings_info, is_in_old
if is_in_old and len(is_in_old) > len(is_in_new):
# Majority of the settings are old.
info, left = _old_settings_info, is_in_new
if is_in_old:
# have old setting names, or a majority of the names are old.
if not info:
info, left = _old_settings_info, is_in_new
if is_in_new and len(is_in_new) > len(is_in_old):
# Majority of the settings are new
info, left = _settings_info, is_in_old
else:
# no settings, just use new format.
info, left = _settings_info, is_in_old
if prefix:
# always use new format if prefix is used.
info, left = _settings_info, set()
    # Only raise an error for keys the user hasn't also provided in the
    # other format (e.g., both ``result_expires`` and ``CELERY_TASK_RESULT_EXPIRES``).
really_left = {key for key in left if info.convert[key] not in have}
if really_left:
# user is mixing old/new, or new/old settings, give renaming
# suggestions.
raise ImproperlyConfigured(info.mix_error.format(renames='\n'.join(
FMT_REPLACE_SETTING.format(replace=key, with_=info.convert[key])
for key in sorted(really_left)
)))
preconf = {info.convert.get(k, k): v for k, v in preconf.items()}
defaults = dict(deepcopy(info.defaults), **preconf)
return Settings(
preconf, [conf, defaults],
(_old_key_to_new, _new_key_to_old),
deprecated_settings=is_in_old,
prefix=prefix,
)
class AppPickler:
"""Old application pickler/unpickler (< 3.1)."""
def __call__(self, cls, *args):
kwargs = self.build_kwargs(*args)
app = self.construct(cls, **kwargs)
self.prepare(app, **kwargs)
return app
def prepare(self, app, **kwargs):
app.conf.update(kwargs['changes'])
def build_kwargs(self, *args):
return self.build_standard_kwargs(*args)
def build_standard_kwargs(self, main, changes, loader, backend, amqp,
events, log, control, accept_magic_kwargs,
config_source=None):
return {'main': main, 'loader': loader, 'backend': backend,
'amqp': amqp, 'changes': changes, 'events': events,
'log': log, 'control': control, 'set_as_current': False,
'config_source': config_source}
def construct(self, cls, **kwargs):
return cls(**kwargs)
def _unpickle_app(cls, pickler, *args):
"""Rebuild app for versions 2.5+."""
return pickler()(cls, *args)
def _unpickle_app_v2(cls, kwargs):
"""Rebuild app for versions 3.1+."""
kwargs['set_as_current'] = False
return cls(**kwargs)
def filter_hidden_settings(conf):
"""Filter sensitive settings."""
def maybe_censor(key, value, mask='*' * 8):
if isinstance(value, Mapping):
return filter_hidden_settings(value)
if isinstance(key, str):
if HIDDEN_SETTINGS.search(key):
return mask
elif 'broker_url' in key.lower():
from kombu import Connection
return Connection(value).as_uri(mask=mask)
elif 'backend' in key.lower():
return maybe_sanitize_url(value, mask=mask)
return value
return {k: maybe_censor(k, v) for k, v in conf.items()}
def bugreport(app):
"""Return a string containing information useful in bug-reports."""
import billiard
import kombu
import celery
try:
conn = app.connection()
driver_v = '{}:{}'.format(conn.transport.driver_name,
conn.transport.driver_version())
transport = conn.transport_cls
except Exception: # pylint: disable=broad-except
transport = driver_v = ''
return BUGREPORT_INFO.format(
system=_platform.system(),
arch=', '.join(x for x in _platform.architecture() if x),
kernel_version=_platform.release(),
py_i=pyimplementation(),
celery_v=celery.VERSION_BANNER,
kombu_v=kombu.__version__,
billiard_v=billiard.__version__,
py_v=_platform.python_version(),
driver_v=driver_v,
transport=transport,
results=maybe_sanitize_url(app.conf.result_backend or 'disabled'),
human_settings=app.conf.humanize(),
loader=qualname(app.loader.__class__),
)
def find_app(app, symbol_by_name=symbol_by_name, imp=import_from_cwd):
"""Find app by name."""
from .base import Celery
try:
sym = symbol_by_name(app, imp=imp)
except AttributeError:
# last part was not an attribute, but a module
sym = imp(app)
if isinstance(sym, ModuleType) and ':' not in app:
try:
found = sym.app
if isinstance(found, ModuleType):
raise AttributeError()
except AttributeError:
try:
found = sym.celery
if isinstance(found, ModuleType):
raise AttributeError(
"attribute 'celery' is the celery module not the instance of celery")
except AttributeError:
if getattr(sym, '__path__', None):
try:
return find_app(
f'{app}.celery',
symbol_by_name=symbol_by_name, imp=imp,
)
except ImportError:
pass
for suspect in vars(sym).values():
if isinstance(suspect, Celery):
return suspect
raise
else:
return found
else:
return found
return sym
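# Editor's note: a hedged usage sketch appended by the editor (not part of the
# original module).  The app name, broker URL and extra setting are
# illustrative only; the public names exercised are humanize(),
# filter_hidden_settings() and bugreport() defined above.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('sketch', broker='memory://', backend='cache+memory://')
    app.conf.my_api_key = 'super-secret'   # key matches HIDDEN_SETTINGS above

    # humanize()/table() only show changed settings; sensitive values are
    # masked through filter_hidden_settings() because censored=True by default.
    print(app.conf.humanize())

    # Full bug-report text built by bugreport() above, with the result
    # backend URL sanitized.
    print(bugreport(app))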

View File

@@ -0,0 +1,160 @@
"""Beat command-line program.
This module is the 'program-version' of :mod:`celery.beat`.
It does everything necessary to run that module
as an actual application, like installing signal handlers
and so on.
"""
from __future__ import annotations
import numbers
import socket
import sys
from datetime import datetime
from signal import Signals
from types import FrameType
from typing import Any
from celery import VERSION_BANNER, Celery, beat, platforms
from celery.utils.imports import qualname
from celery.utils.log import LOG_LEVELS, get_logger
from celery.utils.time import humanize_seconds
__all__ = ('Beat',)
STARTUP_INFO_FMT = """
LocalTime -> {timestamp}
Configuration ->
. broker -> {conninfo}
. loader -> {loader}
. scheduler -> {scheduler}
{scheduler_info}
. logfile -> {logfile}@%{loglevel}
. maxinterval -> {hmax_interval} ({max_interval}s)
""".strip()
logger = get_logger('celery.beat')
class Beat:
"""Beat as a service."""
Service = beat.Service
app: Celery = None
def __init__(self, max_interval: int | None = None, app: Celery | None = None,
socket_timeout: int = 30, pidfile: str | None = None, no_color: bool | None = None,
loglevel: str = 'WARN', logfile: str | None = None, schedule: str | None = None,
scheduler: str | None = None,
scheduler_cls: str | None = None, # XXX use scheduler
redirect_stdouts: bool | None = None,
redirect_stdouts_level: str | None = None,
quiet: bool = False, **kwargs: Any) -> None:
self.app = app = app or self.app
either = self.app.either
self.loglevel = loglevel
self.logfile = logfile
self.schedule = either('beat_schedule_filename', schedule)
self.scheduler_cls = either(
'beat_scheduler', scheduler, scheduler_cls)
self.redirect_stdouts = either(
'worker_redirect_stdouts', redirect_stdouts)
self.redirect_stdouts_level = either(
'worker_redirect_stdouts_level', redirect_stdouts_level)
self.quiet = quiet
self.max_interval = max_interval
self.socket_timeout = socket_timeout
self.no_color = no_color
self.colored = app.log.colored(
self.logfile,
enabled=not no_color if no_color is not None else no_color,
)
self.pidfile = pidfile
if not isinstance(self.loglevel, numbers.Integral):
self.loglevel = LOG_LEVELS[self.loglevel.upper()]
def run(self) -> None:
if not self.quiet:
print(str(self.colored.cyan(
f'celery beat v{VERSION_BANNER} is starting.')))
self.init_loader()
self.set_process_title()
self.start_scheduler()
def setup_logging(self, colorize: bool | None = None) -> None:
if colorize is None and self.no_color is not None:
colorize = not self.no_color
self.app.log.setup(self.loglevel, self.logfile,
self.redirect_stdouts, self.redirect_stdouts_level,
colorize=colorize)
def start_scheduler(self) -> None:
if self.pidfile:
platforms.create_pidlock(self.pidfile)
service = self.Service(
app=self.app,
max_interval=self.max_interval,
scheduler_cls=self.scheduler_cls,
schedule_filename=self.schedule,
)
if not self.quiet:
print(self.banner(service))
self.setup_logging()
if self.socket_timeout:
logger.debug('Setting default socket timeout to %r',
self.socket_timeout)
socket.setdefaulttimeout(self.socket_timeout)
try:
self.install_sync_handler(service)
service.start()
except Exception as exc:
logger.critical('beat raised exception %s: %r',
exc.__class__, exc,
exc_info=True)
raise
def banner(self, service: beat.Service) -> str:
c = self.colored
return str(
c.blue('__ ', c.magenta('-'),
c.blue(' ... __ '), c.magenta('-'),
c.blue(' _\n'),
c.reset(self.startup_info(service))),
)
def init_loader(self) -> None:
# Run the worker init handler.
# (Usually imports task modules and such.)
self.app.loader.init_worker()
self.app.finalize()
def startup_info(self, service: beat.Service) -> str:
scheduler = service.get_scheduler(lazy=True)
return STARTUP_INFO_FMT.format(
conninfo=self.app.connection().as_uri(),
timestamp=datetime.now().replace(microsecond=0),
logfile=self.logfile or '[stderr]',
loglevel=LOG_LEVELS[self.loglevel],
loader=qualname(self.app.loader),
scheduler=qualname(scheduler),
scheduler_info=scheduler.info,
hmax_interval=humanize_seconds(scheduler.max_interval),
max_interval=scheduler.max_interval,
)
def set_process_title(self) -> None:
        arg_start = 2 if 'manage' in sys.argv[0] else 1
platforms.set_process_title(
'celery beat', info=' '.join(sys.argv[arg_start:]),
)
def install_sync_handler(self, service: beat.Service) -> None:
"""Install a `SIGTERM` + `SIGINT` handler saving the schedule."""
def _sync(signum: Signals, frame: FrameType) -> None:
service.sync()
raise SystemExit()
platforms.signals.update(SIGTERM=_sync, SIGINT=_sync)
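# Editor's note: a hedged sketch of programmatic use, appended by the editor
# (not part of the original file).  ``celery beat`` builds this object from
# CLI options; the app, task name and paths below are illustrative only.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('sketch', broker='memory://')
    app.conf.beat_schedule = {
        'tick-every-30s': {'task': 'sketch.tick', 'schedule': 30.0},
    }
    app.conf.beat_schedule_filename = '/tmp/celerybeat-schedule-sketch'

    # Roughly equivalent to: celery -A sketch beat --loglevel=INFO
    # run() blocks and keeps sending due tasks until interrupted.
    Beat(app=app, loglevel='INFO', max_interval=5).run()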

View File

@@ -0,0 +1,506 @@
"""Start/stop/manage workers."""
import errno
import os
import shlex
import signal
import sys
from collections import OrderedDict, UserList, defaultdict
from functools import partial
from subprocess import Popen
from time import sleep
from kombu.utils.encoding import from_utf8
from kombu.utils.objects import cached_property
from celery.platforms import IS_WINDOWS, Pidfile, signal_name
from celery.utils.nodenames import gethostname, host_format, node_format, nodesplit
from celery.utils.saferepr import saferepr
__all__ = ('Cluster', 'Node')
CELERY_EXE = 'celery'
def celery_exe(*args):
return ' '.join((CELERY_EXE,) + args)
def build_nodename(name, prefix, suffix):
hostname = suffix
if '@' in name:
nodename = host_format(name)
shortname, hostname = nodesplit(nodename)
name = shortname
else:
shortname = f'{prefix}{name}'
nodename = host_format(
f'{shortname}@{hostname}',
)
return name, nodename, hostname
def build_expander(nodename, shortname, hostname):
return partial(
node_format,
name=nodename,
N=shortname,
d=hostname,
h=nodename,
i='%i',
I='%I',
)
def format_opt(opt, value):
if not value:
return opt
if opt.startswith('--'):
return f'{opt}={value}'
return f'{opt} {value}'
def _kwargs_to_command_line(kwargs):
return {
('--{}'.format(k.replace('_', '-'))
if len(k) > 1 else f'-{k}'): f'{v}'
for k, v in kwargs.items()
}
class NamespacedOptionParser:
def __init__(self, args):
self.args = args
self.options = OrderedDict()
self.values = []
self.passthrough = ''
self.namespaces = defaultdict(lambda: OrderedDict())
def parse(self):
rargs = [arg for arg in self.args if arg]
pos = 0
while pos < len(rargs):
arg = rargs[pos]
if arg == '--':
self.passthrough = ' '.join(rargs[pos:])
break
elif arg[0] == '-':
if arg[1] == '-':
self.process_long_opt(arg[2:])
else:
value = None
if len(rargs) > pos + 1 and rargs[pos + 1][0] != '-':
value = rargs[pos + 1]
pos += 1
self.process_short_opt(arg[1:], value)
else:
self.values.append(arg)
pos += 1
def process_long_opt(self, arg, value=None):
if '=' in arg:
arg, value = arg.split('=', 1)
self.add_option(arg, value, short=False)
def process_short_opt(self, arg, value=None):
self.add_option(arg, value, short=True)
def optmerge(self, ns, defaults=None):
if defaults is None:
defaults = self.options
return OrderedDict(defaults, **self.namespaces[ns])
def add_option(self, name, value, short=False, ns=None):
        prefix = '-' if short else '--'
dest = self.options
if ':' in name:
name, ns = name.split(':')
dest = self.namespaces[ns]
dest[prefix + name] = value
class Node:
"""Represents a node in a cluster."""
def __init__(self, name,
cmd=None, append=None, options=None, extra_args=None):
self.name = name
self.cmd = cmd or f"-m {celery_exe('worker', '--detach')}"
self.append = append
self.extra_args = extra_args or ''
self.options = self._annotate_with_default_opts(
options or OrderedDict())
self.expander = self._prepare_expander()
self.argv = self._prepare_argv()
self._pid = None
def _annotate_with_default_opts(self, options):
options['-n'] = self.name
self._setdefaultopt(options, ['--pidfile', '-p'], '/var/run/celery/%n.pid')
self._setdefaultopt(options, ['--logfile', '-f'], '/var/log/celery/%n%I.log')
self._setdefaultopt(options, ['--executable'], sys.executable)
return options
def _setdefaultopt(self, d, alt, value):
for opt in alt[1:]:
try:
return d[opt]
except KeyError:
pass
value = d.setdefault(alt[0], os.path.normpath(value))
dir_path = os.path.dirname(value)
if dir_path and not os.path.exists(dir_path):
os.makedirs(dir_path)
return value
def _prepare_expander(self):
shortname, hostname = self.name.split('@', 1)
return build_expander(
self.name, shortname, hostname)
def _prepare_argv(self):
cmd = self.expander(self.cmd).split(' ')
i = cmd.index('celery') + 1
options = self.options.copy()
for opt, value in self.options.items():
if opt in (
'-A', '--app',
'-b', '--broker',
'--result-backend',
'--loader',
'--config',
'--workdir',
'-C', '--no-color',
'-q', '--quiet',
):
cmd.insert(i, format_opt(opt, self.expander(value)))
options.pop(opt)
cmd = [' '.join(cmd)]
argv = tuple(
cmd +
[format_opt(opt, self.expander(value))
for opt, value in options.items()] +
[self.extra_args]
)
if self.append:
argv += (self.expander(self.append),)
return argv
def alive(self):
return self.send(0)
def send(self, sig, on_error=None):
pid = self.pid
if pid:
try:
os.kill(pid, sig)
except OSError as exc:
if exc.errno != errno.ESRCH:
raise
maybe_call(on_error, self)
return False
return True
maybe_call(on_error, self)
def start(self, env=None, **kwargs):
return self._waitexec(
self.argv, path=self.executable, env=env, **kwargs)
def _waitexec(self, argv, path=sys.executable, env=None,
on_spawn=None, on_signalled=None, on_failure=None):
argstr = self.prepare_argv(argv, path)
maybe_call(on_spawn, self, argstr=' '.join(argstr), env=env)
pipe = Popen(argstr, env=env)
return self.handle_process_exit(
pipe.wait(),
on_signalled=on_signalled,
on_failure=on_failure,
)
def handle_process_exit(self, retcode, on_signalled=None, on_failure=None):
if retcode < 0:
maybe_call(on_signalled, self, -retcode)
return -retcode
elif retcode > 0:
maybe_call(on_failure, self, retcode)
return retcode
def prepare_argv(self, argv, path):
args = ' '.join([path] + list(argv))
return shlex.split(from_utf8(args), posix=not IS_WINDOWS)
def getopt(self, *alt):
for opt in alt:
try:
return self.options[opt]
except KeyError:
pass
raise KeyError(alt[0])
def __repr__(self):
return f'<{type(self).__name__}: {self.name}>'
@cached_property
def pidfile(self):
return self.expander(self.getopt('--pidfile', '-p'))
@cached_property
def logfile(self):
return self.expander(self.getopt('--logfile', '-f'))
@property
def pid(self):
if self._pid is not None:
return self._pid
try:
return Pidfile(self.pidfile).read_pid()
except ValueError:
pass
@pid.setter
def pid(self, value):
self._pid = value
@cached_property
def executable(self):
return self.options['--executable']
@cached_property
def argv_with_executable(self):
return (self.executable,) + self.argv
@classmethod
def from_kwargs(cls, name, **kwargs):
return cls(name, options=_kwargs_to_command_line(kwargs))
def maybe_call(fun, *args, **kwargs):
if fun is not None:
fun(*args, **kwargs)
class MultiParser:
Node = Node
def __init__(self, cmd='celery worker',
append='', prefix='', suffix='',
range_prefix='celery'):
self.cmd = cmd
self.append = append
self.prefix = prefix
self.suffix = suffix
self.range_prefix = range_prefix
def parse(self, p):
names = p.values
options = dict(p.options)
ranges = len(names) == 1
prefix = self.prefix
cmd = options.pop('--cmd', self.cmd)
append = options.pop('--append', self.append)
hostname = options.pop('--hostname', options.pop('-n', gethostname()))
prefix = options.pop('--prefix', prefix) or ''
suffix = options.pop('--suffix', self.suffix) or hostname
suffix = '' if suffix in ('""', "''") else suffix
range_prefix = options.pop('--range-prefix', '') or self.range_prefix
if ranges:
try:
names, prefix = self._get_ranges(names), range_prefix
except ValueError:
pass
self._update_ns_opts(p, names)
self._update_ns_ranges(p, ranges)
return (
self._node_from_options(
p, name, prefix, suffix, cmd, append, options)
for name in names
)
def _node_from_options(self, p, name, prefix,
suffix, cmd, append, options):
namespace, nodename, _ = build_nodename(name, prefix, suffix)
namespace = nodename if nodename in p.namespaces else namespace
return Node(nodename, cmd, append,
p.optmerge(namespace, options), p.passthrough)
def _get_ranges(self, names):
noderange = int(names[0])
return [str(n) for n in range(1, noderange + 1)]
def _update_ns_opts(self, p, names):
# Numbers in args always refers to the index in the list of names.
# (e.g., `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
for ns_name, ns_opts in list(p.namespaces.items()):
if ns_name.isdigit():
ns_index = int(ns_name) - 1
if ns_index < 0:
raise KeyError(f'Indexes start at 1 got: {ns_name!r}')
try:
p.namespaces[names[ns_index]].update(ns_opts)
except IndexError:
raise KeyError(f'No node at index {ns_name!r}')
def _update_ns_ranges(self, p, ranges):
for ns_name, ns_opts in list(p.namespaces.items()):
if ',' in ns_name or (ranges and '-' in ns_name):
for subns in self._parse_ns_range(ns_name, ranges):
p.namespaces[subns].update(ns_opts)
p.namespaces.pop(ns_name)
def _parse_ns_range(self, ns, ranges=False):
ret = []
        for space in (ns.split(',') if ',' in ns else [ns]):
if ranges and '-' in space:
start, stop = space.split('-')
ret.extend(
str(n) for n in range(int(start), int(stop) + 1)
)
else:
ret.append(space)
return ret
class Cluster(UserList):
"""Represent a cluster of workers."""
def __init__(self, nodes, cmd=None, env=None,
on_stopping_preamble=None,
on_send_signal=None,
on_still_waiting_for=None,
on_still_waiting_progress=None,
on_still_waiting_end=None,
on_node_start=None,
on_node_restart=None,
on_node_shutdown_ok=None,
on_node_status=None,
on_node_signal=None,
on_node_signal_dead=None,
on_node_down=None,
on_child_spawn=None,
on_child_signalled=None,
on_child_failure=None):
self.nodes = nodes
self.cmd = cmd or celery_exe('worker')
self.env = env
self.on_stopping_preamble = on_stopping_preamble
self.on_send_signal = on_send_signal
self.on_still_waiting_for = on_still_waiting_for
self.on_still_waiting_progress = on_still_waiting_progress
self.on_still_waiting_end = on_still_waiting_end
self.on_node_start = on_node_start
self.on_node_restart = on_node_restart
self.on_node_shutdown_ok = on_node_shutdown_ok
self.on_node_status = on_node_status
self.on_node_signal = on_node_signal
self.on_node_signal_dead = on_node_signal_dead
self.on_node_down = on_node_down
self.on_child_spawn = on_child_spawn
self.on_child_signalled = on_child_signalled
self.on_child_failure = on_child_failure
def start(self):
return [self.start_node(node) for node in self]
def start_node(self, node):
maybe_call(self.on_node_start, node)
retcode = self._start_node(node)
maybe_call(self.on_node_status, node, retcode)
return retcode
def _start_node(self, node):
return node.start(
self.env,
on_spawn=self.on_child_spawn,
on_signalled=self.on_child_signalled,
on_failure=self.on_child_failure,
)
def send_all(self, sig):
for node in self.getpids(on_down=self.on_node_down):
maybe_call(self.on_node_signal, node, signal_name(sig))
node.send(sig, self.on_node_signal_dead)
def kill(self):
return self.send_all(signal.SIGKILL)
def restart(self, sig=signal.SIGTERM):
retvals = []
def restart_on_down(node):
maybe_call(self.on_node_restart, node)
retval = self._start_node(node)
maybe_call(self.on_node_status, node, retval)
retvals.append(retval)
self._stop_nodes(retry=2, on_down=restart_on_down, sig=sig)
return retvals
def stop(self, retry=None, callback=None, sig=signal.SIGTERM):
return self._stop_nodes(retry=retry, on_down=callback, sig=sig)
def stopwait(self, retry=2, callback=None, sig=signal.SIGTERM):
return self._stop_nodes(retry=retry, on_down=callback, sig=sig)
def _stop_nodes(self, retry=None, on_down=None, sig=signal.SIGTERM):
on_down = on_down if on_down is not None else self.on_node_down
nodes = list(self.getpids(on_down=on_down))
if nodes:
for node in self.shutdown_nodes(nodes, sig=sig, retry=retry):
maybe_call(on_down, node)
def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None):
P = set(nodes)
maybe_call(self.on_stopping_preamble, nodes)
to_remove = set()
for node in P:
maybe_call(self.on_send_signal, node, signal_name(sig))
if not node.send(sig, self.on_node_signal_dead):
to_remove.add(node)
yield node
P -= to_remove
if retry:
maybe_call(self.on_still_waiting_for, P)
its = 0
while P:
to_remove = set()
for node in P:
its += 1
maybe_call(self.on_still_waiting_progress, P)
if not node.alive():
maybe_call(self.on_node_shutdown_ok, node)
to_remove.add(node)
yield node
maybe_call(self.on_still_waiting_for, P)
break
P -= to_remove
if P and not its % len(P):
sleep(float(retry))
maybe_call(self.on_still_waiting_end)
def find(self, name):
for node in self:
if node.name == name:
return node
raise KeyError(name)
def getpids(self, on_down=None):
for node in self:
if node.pid:
yield node
else:
maybe_call(on_down, node)
def __repr__(self):
return '<{name}({0}): {1}>'.format(
len(self), saferepr([n.name for n in self]),
name=type(self).__name__,
)
@property
def data(self):
return self.nodes
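# Editor's note: a hedged sketch appended by the editor (not part of the
# original file) showing how "celery multi"-style arguments are expanded
# into Node objects; the argument values and paths are illustrative only.
if __name__ == '__main__':  # pragma: no cover
    argv = ['2', '-A', 'proj', '--loglevel=INFO',
            '--pidfile=/tmp/celery-%n.pid', '--logfile=/tmp/celery-%n%I.log']
    parser = NamespacedOptionParser(argv)
    parser.parse()
    # A single numeric name expands to a range: celery1@host, celery2@host.
    for node in MultiParser(cmd=celery_exe('worker')).parse(parser):
        print(node.name)
        print('  ', ' '.join(node.argv))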

View File

@@ -0,0 +1,387 @@
"""Worker command-line program.
This module is the 'program-version' of :mod:`celery.worker`.
It does everything necessary to run that module
as an actual application, like installing signal handlers,
platform tweaks, and so on.
"""
import logging
import os
import platform as _platform
import sys
from datetime import datetime
from functools import partial
from billiard.common import REMAP_SIGTERM
from billiard.process import current_process
from kombu.utils.encoding import safe_str
from celery import VERSION_BANNER, platforms, signals
from celery.app import trace
from celery.loaders.app import AppLoader
from celery.platforms import EX_FAILURE, EX_OK, check_privileges
from celery.utils import static, term
from celery.utils.debug import cry
from celery.utils.imports import qualname
from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
from celery.utils.text import pluralize
from celery.worker import WorkController
__all__ = ('Worker',)
logger = get_logger(__name__)
is_jython = sys.platform.startswith('java')
is_pypy = hasattr(sys, 'pypy_version_info')
ARTLINES = [
' --------------',
'--- ***** -----',
'-- ******* ----',
'- *** --- * ---',
'- ** ----------',
'- ** ----------',
'- ** ----------',
'- ** ----------',
'- *** --- * ---',
'-- ******* ----',
'--- ***** -----',
' --------------',
]
BANNER = """\
{hostname} v{version}
{platform} {timestamp}
[config]
.> app: {app}
.> transport: {conninfo}
.> results: {results}
.> concurrency: {concurrency}
.> task events: {events}
[queues]
{queues}
"""
EXTRA_INFO_FMT = """
[tasks]
{tasks}
"""
def active_thread_count():
from threading import enumerate
return sum(1 for t in enumerate()
if not t.name.startswith('Dummy-'))
def safe_say(msg):
print(f'\n{msg}', file=sys.__stderr__, flush=True)
class Worker(WorkController):
"""Worker as a program."""
def on_before_init(self, quiet=False, **kwargs):
self.quiet = quiet
trace.setup_worker_optimizations(self.app, self.hostname)
# this signal can be used to set up configuration for
# workers by name.
signals.celeryd_init.send(
sender=self.hostname, instance=self,
conf=self.app.conf, options=kwargs,
)
check_privileges(self.app.conf.accept_content)
def on_after_init(self, purge=False, no_color=None,
redirect_stdouts=None, redirect_stdouts_level=None,
**kwargs):
self.redirect_stdouts = self.app.either(
'worker_redirect_stdouts', redirect_stdouts)
self.redirect_stdouts_level = self.app.either(
'worker_redirect_stdouts_level', redirect_stdouts_level)
super().setup_defaults(**kwargs)
self.purge = purge
self.no_color = no_color
self._isatty = sys.stdout.isatty()
self.colored = self.app.log.colored(
self.logfile,
enabled=not no_color if no_color is not None else no_color
)
def on_init_blueprint(self):
self._custom_logging = self.setup_logging()
# apply task execution optimizations
# -- This will finalize the app!
trace.setup_worker_optimizations(self.app, self.hostname)
def on_start(self):
app = self.app
super().on_start()
# this signal can be used to, for example, change queues after
# the -Q option has been applied.
signals.celeryd_after_setup.send(
sender=self.hostname, instance=self, conf=app.conf,
)
if self.purge:
self.purge_messages()
if not self.quiet:
self.emit_banner()
self.set_process_status('-active-')
self.install_platform_tweaks(self)
if not self._custom_logging and self.redirect_stdouts:
app.log.redirect_stdouts(self.redirect_stdouts_level)
# TODO: Remove the following code in Celery 6.0
# This qualifies as a hack for issue #6366.
warn_deprecated = True
config_source = app._config_source
if isinstance(config_source, str):
# Don't raise the warning when the settings originate from
# django.conf:settings
warn_deprecated = config_source.lower() not in [
'django.conf:settings',
]
if warn_deprecated:
if app.conf.maybe_warn_deprecated_settings():
logger.warning(
"Please run `celery upgrade settings path/to/settings.py` "
"to avoid these warnings and to allow a smoother upgrade "
"to Celery 6.0."
)
def emit_banner(self):
# Dump configuration to screen so we have some basic information
        # for when users send bug reports.
use_image = term.supports_images()
if use_image:
print(term.imgcat(static.logo()))
print(safe_str(''.join([
str(self.colored.cyan(
' \n', self.startup_info(artlines=not use_image))),
str(self.colored.reset(self.extra_info() or '')),
])), file=sys.__stdout__, flush=True)
def on_consumer_ready(self, consumer):
signals.worker_ready.send(sender=consumer)
logger.info('%s ready.', safe_str(self.hostname))
def setup_logging(self, colorize=None):
if colorize is None and self.no_color is not None:
colorize = not self.no_color
return self.app.log.setup(
self.loglevel, self.logfile,
redirect_stdouts=False, colorize=colorize, hostname=self.hostname,
)
def purge_messages(self):
with self.app.connection_for_write() as connection:
count = self.app.control.purge(connection=connection)
if count: # pragma: no cover
print(f"purge: Erased {count} {pluralize(count, 'message')} from the queue.\n", flush=True)
def tasklist(self, include_builtins=True, sep='\n', int_='celery.'):
return sep.join(
f' . {task}' for task in sorted(self.app.tasks)
if (not task.startswith(int_) if not include_builtins else task)
)
def extra_info(self):
if self.loglevel is None:
return
if self.loglevel <= logging.INFO:
include_builtins = self.loglevel <= logging.DEBUG
tasklist = self.tasklist(include_builtins=include_builtins)
return EXTRA_INFO_FMT.format(tasks=tasklist)
def startup_info(self, artlines=True):
app = self.app
concurrency = str(self.concurrency)
appr = '{}:{:#x}'.format(app.main or '__main__', id(app))
if not isinstance(app.loader, AppLoader):
loader = qualname(app.loader)
if loader.startswith('celery.loaders'): # pragma: no cover
loader = loader[14:]
appr += f' ({loader})'
if self.autoscale:
max, min = self.autoscale
concurrency = f'{{min={min}, max={max}}}'
pool = self.pool_cls
if not isinstance(pool, str):
pool = pool.__module__
concurrency += f" ({pool.split('.')[-1]})"
events = 'ON'
if not self.task_events:
events = 'OFF (enable -E to monitor tasks in this worker)'
banner = BANNER.format(
app=appr,
hostname=safe_str(self.hostname),
timestamp=datetime.now().replace(microsecond=0),
version=VERSION_BANNER,
conninfo=self.app.connection().as_uri(),
results=self.app.backend.as_uri(),
concurrency=concurrency,
platform=safe_str(_platform.platform()),
events=events,
queues=app.amqp.queues.format(indent=0, indent_first=False),
).splitlines()
# integrate the ASCII art.
if artlines:
for i, _ in enumerate(banner):
try:
banner[i] = ' '.join([ARTLINES[i], banner[i]])
except IndexError:
banner[i] = ' ' * 16 + banner[i]
return '\n'.join(banner) + '\n'
def install_platform_tweaks(self, worker):
"""Install platform specific tweaks and workarounds."""
if self.app.IS_macOS:
self.macOS_proxy_detection_workaround()
# Install signal handler so SIGHUP restarts the worker.
if not self._isatty:
# only install HUP handler if detached from terminal,
# so closing the terminal window doesn't restart the worker
# into the background.
if self.app.IS_macOS:
# macOS can't exec from a process using threads.
# See https://github.com/celery/celery/issues#issue/152
install_HUP_not_supported_handler(worker)
else:
install_worker_restart_handler(worker)
install_worker_term_handler(worker)
install_worker_term_hard_handler(worker)
install_worker_int_handler(worker)
install_cry_handler()
install_rdb_handler()
def macOS_proxy_detection_workaround(self):
"""See https://github.com/celery/celery/issues#issue/161."""
os.environ.setdefault('celery_dummy_proxy', 'set_by_celeryd')
def set_process_status(self, info):
return platforms.set_mp_process_title(
'celeryd',
info=f'{info} ({platforms.strargv(sys.argv)})',
hostname=self.hostname,
)
def _shutdown_handler(worker, sig='TERM', how='Warm',
callback=None, exitcode=EX_OK):
def _handle_request(*args):
with in_sighandler():
from celery.worker import state
if current_process()._name == 'MainProcess':
if callback:
callback(worker)
safe_say(f'worker: {how} shutdown (MainProcess)')
signals.worker_shutting_down.send(
sender=worker.hostname, sig=sig, how=how,
exitcode=exitcode,
)
setattr(state, {'Warm': 'should_stop',
'Cold': 'should_terminate'}[how], exitcode)
_handle_request.__name__ = str(f'worker_{how}')
platforms.signals[sig] = _handle_request
if REMAP_SIGTERM == "SIGQUIT":
install_worker_term_handler = partial(
_shutdown_handler, sig='SIGTERM', how='Cold', exitcode=EX_FAILURE,
)
else:
install_worker_term_handler = partial(
_shutdown_handler, sig='SIGTERM', how='Warm',
)
if not is_jython: # pragma: no cover
install_worker_term_hard_handler = partial(
_shutdown_handler, sig='SIGQUIT', how='Cold',
exitcode=EX_FAILURE,
)
else: # pragma: no cover
install_worker_term_handler = \
install_worker_term_hard_handler = lambda *a, **kw: None
def on_SIGINT(worker):
safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!')
install_worker_term_hard_handler(worker, sig='SIGINT')
if not is_jython: # pragma: no cover
install_worker_int_handler = partial(
_shutdown_handler, sig='SIGINT', callback=on_SIGINT,
exitcode=EX_FAILURE,
)
else: # pragma: no cover
def install_worker_int_handler(*args, **kwargs):
pass
def _reload_current_worker():
platforms.close_open_fds([
sys.__stdin__, sys.__stdout__, sys.__stderr__,
])
os.execv(sys.executable, [sys.executable] + sys.argv)
def install_worker_restart_handler(worker, sig='SIGHUP'):
def restart_worker_sig_handler(*args):
"""Signal handler restarting the current python program."""
set_in_sighandler(True)
safe_say(f"Restarting celery worker ({' '.join(sys.argv)})")
import atexit
atexit.register(_reload_current_worker)
from celery.worker import state
state.should_stop = EX_OK
platforms.signals[sig] = restart_worker_sig_handler
def install_cry_handler(sig='SIGUSR1'):
# PyPy does not have sys._current_frames
if is_pypy: # pragma: no cover
return
def cry_handler(*args):
"""Signal handler logging the stack-trace of all active threads."""
with in_sighandler():
safe_say(cry())
platforms.signals[sig] = cry_handler
def install_rdb_handler(envvar='CELERY_RDBSIG',
sig='SIGUSR2'): # pragma: no cover
def rdb_handler(*args):
"""Signal handler setting a rdb breakpoint at the current frame."""
with in_sighandler():
from celery.contrib.rdb import _frame, set_trace
# gevent does not pass standard signal handler args
frame = args[1] if args else _frame().f_back
set_trace(frame)
if os.environ.get(envvar):
platforms.signals[sig] = rdb_handler
def install_HUP_not_supported_handler(worker, sig='SIGHUP'):
def warn_on_HUP_handler(signum, frame):
with in_sighandler():
safe_say('{sig} not supported: Restarting with {sig} is '
'unstable on this platform!'.format(sig=sig))
platforms.signals[sig] = warn_on_HUP_handler
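# Editor's note: a hedged sketch appended by the editor (not part of the
# original file).  ``celery worker`` builds this class from CLI options; the
# lines below are a rough programmatic equivalent, with an illustrative app
# and the 'solo' pool so no forking is involved.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('sketch', broker='memory://', backend='cache+memory://')

    @app.task
    def add(x, y):
        return x + y

    # Roughly: celery -A sketch worker --pool=solo --loglevel=INFO
    # app.Worker is a subclass of the Worker class defined above; start()
    # blocks until one of the shutdown handlers installed above fires.
    worker = app.Worker(pool_cls='solo', loglevel=logging.INFO)
    worker.start()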

View File

@@ -0,0 +1 @@
"""Result Backends."""

View File

@@ -0,0 +1,190 @@
"""ArangoDb result store backend."""
# pylint: disable=W1202,W0703
from datetime import timedelta
from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
from pyArango import connection as py_arango_connection
from pyArango.theExceptions import AQLQueryError
except ImportError:
py_arango_connection = AQLQueryError = None
__all__ = ('ArangoDbBackend',)
class ArangoDbBackend(KeyValueStoreBackend):
"""ArangoDb backend.
    Sample URL:
    "arangodb://username:password@host:port/database/collection"
    Additional settings are read from *arangodb_backend_settings*, a dict
    in ``app.conf``.  It may provide the host, port, username, password,
    database name and collection name; any value not given falls back to
    the defaults.  The default database and collection name are both ``celery``.
Raises
------
celery.exceptions.ImproperlyConfigured:
if module :pypi:`pyArango` is not available.
"""
host = '127.0.0.1'
port = '8529'
database = 'celery'
collection = 'celery'
username = None
password = None
    # The protocol cannot be set in the backend URL (http is used by default).
http_protocol = 'http'
verify = False
# Use str as arangodb key not bytes
key_t = str
def __init__(self, url=None, *args, **kwargs):
"""Parse the url or load the settings from settings object."""
super().__init__(*args, **kwargs)
if py_arango_connection is None:
raise ImproperlyConfigured(
'You need to install the pyArango library to use the '
'ArangoDb backend.',
)
self.url = url
if url is None:
host = port = database = collection = username = password = None
else:
(
_schema, host, port, username, password,
database_collection, _query
) = _parse_url(url)
if database_collection is None:
database = collection = None
else:
database, collection = database_collection.split('/')
config = self.app.conf.get('arangodb_backend_settings', None)
if config is not None:
if not isinstance(config, dict):
raise ImproperlyConfigured(
'ArangoDb backend settings should be grouped in a dict',
)
else:
config = {}
self.host = host or config.get('host', self.host)
self.port = int(port or config.get('port', self.port))
self.http_protocol = config.get('http_protocol', self.http_protocol)
self.verify = config.get('verify', self.verify)
self.database = database or config.get('database', self.database)
self.collection = \
collection or config.get('collection', self.collection)
self.username = username or config.get('username', self.username)
self.password = password or config.get('password', self.password)
self.arangodb_url = "{http_protocol}://{host}:{port}".format(
http_protocol=self.http_protocol, host=self.host, port=self.port
)
self._connection = None
@property
def connection(self):
"""Connect to the arangodb server."""
if self._connection is None:
self._connection = py_arango_connection.Connection(
arangoURL=self.arangodb_url, username=self.username,
password=self.password, verify=self.verify
)
return self._connection
@property
def db(self):
"""Database Object to the given database."""
return self.connection[self.database]
@cached_property
def expires_delta(self):
return timedelta(seconds=0 if self.expires is None else self.expires)
def get(self, key):
if key is None:
return None
query = self.db.AQLQuery(
"RETURN DOCUMENT(@@collection, @key).task",
rawResults=True,
bindVars={
"@collection": self.collection,
"key": key,
},
)
return next(query) if len(query) > 0 else None
def set(self, key, value):
self.db.AQLQuery(
"""
UPSERT {_key: @key}
INSERT {_key: @key, task: @value}
UPDATE {task: @value} IN @@collection
""",
bindVars={
"@collection": self.collection,
"key": key,
"value": value,
},
)
def mget(self, keys):
if keys is None:
return
query = self.db.AQLQuery(
"FOR k IN @keys RETURN DOCUMENT(@@collection, k).task",
rawResults=True,
bindVars={
"@collection": self.collection,
"keys": keys if isinstance(keys, list) else list(keys),
},
)
while True:
yield from query
try:
query.nextBatch()
except StopIteration:
break
def delete(self, key):
if key is None:
return
self.db.AQLQuery(
"REMOVE {_key: @key} IN @@collection",
bindVars={
"@collection": self.collection,
"key": key,
},
)
def cleanup(self):
if not self.expires:
return
checkpoint = (self.app.now() - self.expires_delta).isoformat()
self.db.AQLQuery(
"""
FOR record IN @@collection
FILTER record.task.date_done < @checkpoint
REMOVE record IN @@collection
""",
bindVars={
"@collection": self.collection,
"checkpoint": checkpoint,
},
)
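# Editor's note: a hedged configuration sketch appended by the editor (not
# part of the original file).  Credentials and names are placeholders; the
# URL format matches the class docstring above.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery(
        'sketch',
        broker='memory://',
        backend='arangodb://user:pass@127.0.0.1:8529/celery/celery',
    )
    # Optional overrides are read from arangodb_backend_settings in __init__:
    app.conf.arangodb_backend_settings = {'verify': False}
    # Instantiating the backend requires pyArango; otherwise
    # ImproperlyConfigured is raised, as documented above.
    print(type(app.backend).__name__)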

View File

@@ -0,0 +1,333 @@
"""Async I/O backend support utilities."""
import socket
import threading
import time
from collections import deque
from queue import Empty
from time import sleep
from weakref import WeakKeyDictionary
from kombu.utils.compat import detect_environment
from celery import states
from celery.exceptions import TimeoutError
from celery.utils.threads import THREAD_TIMEOUT_MAX
__all__ = (
'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
'register_drainer',
)
drainers = {}
def register_drainer(name):
"""Decorator used to register a new result drainer type."""
def _inner(cls):
drainers[name] = cls
return cls
return _inner
@register_drainer('default')
class Drainer:
"""Result draining service."""
def __init__(self, result_consumer):
self.result_consumer = result_consumer
def start(self):
pass
def stop(self):
pass
def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None):
wait = wait or self.result_consumer.drain_events
time_start = time.monotonic()
while 1:
# Total time spent may exceed a single call to wait()
if timeout and time.monotonic() - time_start >= timeout:
raise socket.timeout()
try:
yield self.wait_for(p, wait, timeout=interval)
except socket.timeout:
pass
if on_interval:
on_interval()
if p.ready: # got event on the wanted channel.
break
def wait_for(self, p, wait, timeout=None):
wait(timeout=timeout)
class greenletDrainer(Drainer):
spawn = None
_g = None
    _drain_complete_event = None  # event, sent (and recreated) after every drain_events iteration
def _create_drain_complete_event(self):
"""create new self._drain_complete_event object"""
pass
def _send_drain_complete_event(self):
"""raise self._drain_complete_event for wakeup .wait_for"""
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._started = threading.Event()
self._stopped = threading.Event()
self._shutdown = threading.Event()
self._create_drain_complete_event()
def run(self):
self._started.set()
while not self._stopped.is_set():
try:
self.result_consumer.drain_events(timeout=1)
self._send_drain_complete_event()
self._create_drain_complete_event()
except socket.timeout:
pass
self._shutdown.set()
def start(self):
if not self._started.is_set():
self._g = self.spawn(self.run)
self._started.wait()
def stop(self):
self._stopped.set()
self._send_drain_complete_event()
self._shutdown.wait(THREAD_TIMEOUT_MAX)
def wait_for(self, p, wait, timeout=None):
self.start()
if not p.ready:
self._drain_complete_event.wait(timeout=timeout)
@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):
def spawn(self, func):
from eventlet import sleep, spawn
g = spawn(func)
sleep(0)
return g
def _create_drain_complete_event(self):
from eventlet.event import Event
self._drain_complete_event = Event()
def _send_drain_complete_event(self):
self._drain_complete_event.send()
@register_drainer('gevent')
class geventDrainer(greenletDrainer):
def spawn(self, func):
import gevent
g = gevent.spawn(func)
gevent.sleep(0)
return g
def _create_drain_complete_event(self):
from gevent.event import Event
self._drain_complete_event = Event()
def _send_drain_complete_event(self):
self._drain_complete_event.set()
self._create_drain_complete_event()
class AsyncBackendMixin:
"""Mixin for backends that enables the async API."""
def _collect_into(self, result, bucket):
self.result_consumer.buckets[result] = bucket
def iter_native(self, result, no_ack=True, **kwargs):
self._ensure_not_eager()
results = result.results
if not results:
            return
# we tell the result consumer to put consumed results
# into these buckets.
bucket = deque()
for node in results:
if not hasattr(node, '_cache'):
bucket.append(node)
elif node._cache:
bucket.append(node)
else:
self._collect_into(node, bucket)
for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
while bucket:
node = bucket.popleft()
if not hasattr(node, '_cache'):
yield node.id, node.children
else:
yield node.id, node._cache
while bucket:
node = bucket.popleft()
yield node.id, node._cache
def add_pending_result(self, result, weak=False, start_drainer=True):
if start_drainer:
self.result_consumer.drainer.start()
try:
self._maybe_resolve_from_buffer(result)
except Empty:
self._add_pending_result(result.id, result, weak=weak)
return result
def _maybe_resolve_from_buffer(self, result):
result._maybe_set_cache(self._pending_messages.take(result.id))
def _add_pending_result(self, task_id, result, weak=False):
concrete, weak_ = self._pending_results
if task_id not in weak_ and result.id not in concrete:
(weak_ if weak else concrete)[task_id] = result
self.result_consumer.consume_from(task_id)
def add_pending_results(self, results, weak=False):
self.result_consumer.drainer.start()
return [self.add_pending_result(result, weak=weak, start_drainer=False)
for result in results]
def remove_pending_result(self, result):
self._remove_pending_result(result.id)
self.on_result_fulfilled(result)
return result
def _remove_pending_result(self, task_id):
for mapping in self._pending_results:
mapping.pop(task_id, None)
def on_result_fulfilled(self, result):
self.result_consumer.cancel_for(result.id)
def wait_for_pending(self, result,
callback=None, propagate=True, **kwargs):
self._ensure_not_eager()
for _ in self._wait_for_pending(result, **kwargs):
pass
return result.maybe_throw(callback=callback, propagate=propagate)
def _wait_for_pending(self, result,
timeout=None, on_interval=None, on_message=None,
**kwargs):
return self.result_consumer._wait_for_pending(
result, timeout=timeout,
on_interval=on_interval, on_message=on_message,
**kwargs
)
@property
def is_async(self):
return True
class BaseResultConsumer:
"""Manager responsible for consuming result messages."""
def __init__(self, backend, app, accept,
pending_results, pending_messages):
self.backend = backend
self.app = app
self.accept = accept
self._pending_results = pending_results
self._pending_messages = pending_messages
self.on_message = None
self.buckets = WeakKeyDictionary()
self.drainer = drainers[detect_environment()](self)
def start(self, initial_task_id, **kwargs):
raise NotImplementedError()
def stop(self):
pass
def drain_events(self, timeout=None):
raise NotImplementedError()
def consume_from(self, task_id):
raise NotImplementedError()
def cancel_for(self, task_id):
raise NotImplementedError()
def _after_fork(self):
self.buckets.clear()
self.buckets = WeakKeyDictionary()
self.on_message = None
self.on_after_fork()
def on_after_fork(self):
pass
def drain_events_until(self, p, timeout=None, on_interval=None):
return self.drainer.drain_events_until(
p, timeout=timeout, on_interval=on_interval)
def _wait_for_pending(self, result,
timeout=None, on_interval=None, on_message=None,
**kwargs):
self.on_wait_for_pending(result, timeout=timeout, **kwargs)
prev_on_m, self.on_message = self.on_message, on_message
try:
for _ in self.drain_events_until(
result.on_ready, timeout=timeout,
on_interval=on_interval):
yield
sleep(0)
except socket.timeout:
raise TimeoutError('The operation timed out.')
finally:
self.on_message = prev_on_m
def on_wait_for_pending(self, result, timeout=None, **kwargs):
pass
def on_out_of_band_result(self, message):
self.on_state_change(message.payload, message)
def _get_pending_result(self, task_id):
for mapping in self._pending_results:
try:
return mapping[task_id]
except KeyError:
pass
raise KeyError(task_id)
def on_state_change(self, meta, message):
if self.on_message:
self.on_message(meta)
if meta['status'] in states.READY_STATES:
task_id = meta['task_id']
try:
result = self._get_pending_result(task_id)
except KeyError:
# send to buffer in case we received this result
# before it was added to _pending_results.
self._pending_messages.put(task_id, meta)
else:
result._maybe_set_cache(meta)
buckets = self.buckets
try:
# remove bucket for this result, since it's fulfilled
bucket = buckets.pop(result)
except KeyError:
pass
else:
# send to waiter via bucket
bucket.append(result)
sleep(0)
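# Editor's note: a hedged sketch appended by the editor (not part of the
# original file).  It registers an illustrative drainer under the made-up
# name 'sketch'; it would only ever be selected if detect_environment()
# returned that name, so this is purely to show the extension point.
@register_drainer('sketch')
class SketchDrainer(Drainer):
    """Drainer that logs once before delegating to the default behaviour."""

    def drain_events_until(self, p, timeout=None, interval=1,
                           on_interval=None, wait=None):
        print('draining result events until the promise is ready...')
        return super().drain_events_until(
            p, timeout=timeout, interval=interval,
            on_interval=on_interval, wait=wait)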

View File

@@ -0,0 +1,165 @@
"""The Azure Storage Block Blob backend for Celery."""
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
import azure.storage.blob as azurestorage
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import BlobServiceClient
except ImportError:
azurestorage = None
__all__ = ("AzureBlockBlobBackend",)
LOGGER = get_logger(__name__)
AZURE_BLOCK_BLOB_CONNECTION_PREFIX = 'azureblockblob://'
class AzureBlockBlobBackend(KeyValueStoreBackend):
"""Azure Storage Block Blob backend for Celery."""
def __init__(self,
url=None,
container_name=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
if azurestorage is None or azurestorage.__version__ < '12':
raise ImproperlyConfigured(
"You need to install the azure-storage-blob v12 library to"
"use the AzureBlockBlob backend")
conf = self.app.conf
self._connection_string = self._parse_url(url)
self._container_name = (
container_name or
conf["azureblockblob_container_name"])
self.base_path = conf.get('azureblockblob_base_path', '')
self._connection_timeout = conf.get(
'azureblockblob_connection_timeout', 20
)
self._read_timeout = conf.get('azureblockblob_read_timeout', 120)
@classmethod
def _parse_url(cls, url, prefix=AZURE_BLOCK_BLOB_CONNECTION_PREFIX):
connection_string = url[len(prefix):]
if not connection_string:
raise ImproperlyConfigured("Invalid URL")
return connection_string
@cached_property
def _blob_service_client(self):
"""Return the Azure Storage Blob service client.
If this is the first call to the property, the client is created and
the container is created if it doesn't yet exist.
"""
client = BlobServiceClient.from_connection_string(
self._connection_string,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout
)
try:
client.create_container(name=self._container_name)
msg = f"Container created with name {self._container_name}."
except ResourceExistsError:
msg = f"Container with name {self._container_name} already." \
"exists. This will not be created."
LOGGER.info(msg)
return client
def get(self, key):
"""Read the value stored at the given key.
Args:
key: The key for which to read the value.
"""
key = bytes_to_str(key)
LOGGER.debug("Getting Azure Block Blob %s/%s", self._container_name, key)
blob_client = self._blob_service_client.get_blob_client(
container=self._container_name,
blob=f'{self.base_path}{key}',
)
try:
return blob_client.download_blob().readall().decode()
except ResourceNotFoundError:
return None
def set(self, key, value):
"""Store a value for a given key.
Args:
key: The key at which to store the value.
value: The value to store.
"""
key = bytes_to_str(key)
LOGGER.debug(f"Creating azure blob at {self._container_name}/{key}")
blob_client = self._blob_service_client.get_blob_client(
container=self._container_name,
blob=f'{self.base_path}{key}',
)
blob_client.upload_blob(value, overwrite=True)
def mget(self, keys):
"""Read all the values for the provided keys.
Args:
keys: The list of keys to read.
"""
return [self.get(key) for key in keys]
def delete(self, key):
"""Delete the value at a given key.
Args:
key: The key of the value to delete.
"""
key = bytes_to_str(key)
LOGGER.debug(f"Deleting azure blob at {self._container_name}/{key}")
blob_client = self._blob_service_client.get_blob_client(
container=self._container_name,
blob=f'{self.base_path}{key}',
)
blob_client.delete_blob()
def as_uri(self, include_password=False):
if include_password:
return (
f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
f'{self._connection_string}'
)
connection_string_parts = self._connection_string.split(';')
account_key_prefix = 'AccountKey='
redacted_connection_string_parts = [
f'{account_key_prefix}**' if part.startswith(account_key_prefix)
else part
for part in connection_string_parts
]
return (
f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
f'{";".join(redacted_connection_string_parts)}'
)
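# Editor's note: a hedged configuration sketch appended by the editor (not
# part of the original file).  The connection string is a placeholder;
# instantiating the backend requires azure-storage-blob v12, as checked above.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery(
        'sketch',
        broker='memory://',
        backend=('azureblockblob://'
                 'DefaultEndpointsProtocol=https;AccountName=example;'
                 'AccountKey=placeholder;EndpointSuffix=core.windows.net'),
    )
    app.conf.azureblockblob_container_name = 'celery'
    app.conf.azureblockblob_base_path = 'results/'
    # as_uri() (defined above) redacts the AccountKey part by default.
    print(app.backend.as_uri())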

File diff suppressed because it is too large

View File

@@ -0,0 +1,163 @@
"""Memcached and in-memory cache result backend."""
from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.objects import cached_property
from celery.exceptions import ImproperlyConfigured
from celery.utils.functional import LRUCache
from .base import KeyValueStoreBackend
__all__ = ('CacheBackend',)
_imp = [None]
REQUIRES_BACKEND = """\
The Memcached backend requires either pylibmc or python-memcached.\
"""
UNKNOWN_BACKEND = """\
The cache backend {0!r} is unknown.
Please use one of the following backends instead: {1}\
"""
# Global shared in-memory cache for in-memory cache client
# This is to share cache between threads
_DUMMY_CLIENT_CACHE = LRUCache(limit=5000)
def import_best_memcache():
if _imp[0] is None:
is_pylibmc, memcache_key_t = False, bytes_to_str
try:
import pylibmc as memcache
is_pylibmc = True
except ImportError:
try:
import memcache
except ImportError:
raise ImproperlyConfigured(REQUIRES_BACKEND)
_imp[0] = (is_pylibmc, memcache, memcache_key_t)
return _imp[0]
def get_best_memcache(*args, **kwargs):
# pylint: disable=unpacking-non-sequence
# This is most definitely a sequence, but pylint thinks it's not.
is_pylibmc, memcache, key_t = import_best_memcache()
Client = _Client = memcache.Client
if not is_pylibmc:
def Client(*args, **kwargs): # noqa: F811
kwargs.pop('behaviors', None)
return _Client(*args, **kwargs)
return Client, key_t
class DummyClient:
def __init__(self, *args, **kwargs):
self.cache = _DUMMY_CLIENT_CACHE
def get(self, key, *args, **kwargs):
return self.cache.get(key)
def get_multi(self, keys):
cache = self.cache
return {k: cache[k] for k in keys if k in cache}
def set(self, key, value, *args, **kwargs):
self.cache[key] = value
def delete(self, key, *args, **kwargs):
self.cache.pop(key, None)
def incr(self, key, delta=1):
return self.cache.incr(key, delta)
def touch(self, key, expire):
pass
backends = {
'memcache': get_best_memcache,
'memcached': get_best_memcache,
'pylibmc': get_best_memcache,
'memory': lambda: (DummyClient, ensure_bytes),
}
class CacheBackend(KeyValueStoreBackend):
"""Cache result backend."""
servers = None
supports_autoexpire = True
supports_native_join = True
implements_incr = True
def __init__(self, app, expires=None, backend=None,
options=None, url=None, **kwargs):
options = {} if not options else options
super().__init__(app, **kwargs)
self.url = url
self.options = dict(self.app.conf.cache_backend_options,
**options)
self.backend = url or backend or self.app.conf.cache_backend
if self.backend:
self.backend, _, servers = self.backend.partition('://')
self.servers = servers.rstrip('/').split(';')
self.expires = self.prepare_expires(expires, type=int)
try:
self.Client, self.key_t = backends[self.backend]()
except KeyError:
raise ImproperlyConfigured(UNKNOWN_BACKEND.format(
self.backend, ', '.join(backends)))
        self._encode_prefixes()  # re-encode the key prefixes
def get(self, key):
return self.client.get(key)
def mget(self, keys):
return self.client.get_multi(keys)
def set(self, key, value):
return self.client.set(key, value, self.expires)
def delete(self, key):
return self.client.delete(key)
def _apply_chord_incr(self, header_result_args, body, **kwargs):
chord_key = self.get_key_for_chord(header_result_args[0])
self.client.set(chord_key, 0, time=self.expires)
return super()._apply_chord_incr(
header_result_args, body, **kwargs)
def incr(self, key):
return self.client.incr(key)
def expire(self, key, value):
return self.client.touch(key, value)
@cached_property
def client(self):
return self.Client(self.servers, **self.options)
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
servers = ';'.join(self.servers)
backend = f'{self.backend}://{servers}/'
kwargs.update(
{'backend': backend,
'expires': self.expires,
'options': self.options})
return super().__reduce__(args, kwargs)
def as_uri(self, *args, **kwargs):
"""Return the backend as an URI.
This properly handles the case of multiple servers.
"""
servers = ';'.join(self.servers)
return f'{self.backend}://{servers}/'
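# Editor's note: a hedged sketch appended by the editor (not part of the
# original file), using the in-memory DummyClient so no memcached server or
# client library is needed; the task id is made up.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('sketch', broker='memory://', backend='cache+memory://')
    backend = app.backend                       # CacheBackend using DummyClient
    backend.store_result('demo-task-id', 42, 'SUCCESS')
    print(backend.get_result('demo-task-id'))   # -> 42, read back from the LRUCache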

View File

@@ -0,0 +1,256 @@
"""Apache Cassandra result store backend using the DataStax driver."""
import threading
from celery import states
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
from .base import BaseBackend
try: # pragma: no cover
import cassandra
import cassandra.auth
import cassandra.cluster
import cassandra.query
except ImportError:
cassandra = None
__all__ = ('CassandraBackend',)
logger = get_logger(__name__)
E_NO_CASSANDRA = """
You need to install the cassandra-driver library to
use the Cassandra backend. See https://github.com/datastax/python-driver
"""
E_NO_SUCH_CASSANDRA_AUTH_PROVIDER = """
CASSANDRA_AUTH_PROVIDER you provided is not a valid auth_provider class.
See https://datastax.github.io/python-driver/api/cassandra/auth.html.
"""
E_CASSANDRA_MISCONFIGURED = 'Cassandra backend improperly configured.'
E_CASSANDRA_NOT_CONFIGURED = 'Cassandra backend not configured.'
Q_INSERT_RESULT = """
INSERT INTO {table} (
task_id, status, result, date_done, traceback, children) VALUES (
%s, %s, %s, %s, %s, %s) {expires};
"""
Q_SELECT_RESULT = """
SELECT status, result, date_done, traceback, children
FROM {table}
WHERE task_id=%s
LIMIT 1
"""
Q_CREATE_RESULT_TABLE = """
CREATE TABLE {table} (
task_id text,
status text,
result blob,
date_done timestamp,
traceback blob,
children blob,
PRIMARY KEY ((task_id), date_done)
) WITH CLUSTERING ORDER BY (date_done DESC);
"""
Q_EXPIRES = """
USING TTL {0}
"""
def buf_t(x):
return bytes(x, 'utf8')
class CassandraBackend(BaseBackend):
"""Cassandra/AstraDB backend utilizing DataStax driver.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`cassandra-driver` is not available,
or not-exactly-one of the :setting:`cassandra_servers` and
the :setting:`cassandra_secure_bundle_path` settings is set.
"""
#: List of Cassandra servers with format: ``hostname``.
servers = None
#: Location of the secure connect bundle zipfile (absolute path).
bundle_path = None
supports_autoexpire = True # autoexpire supported via entry_ttl
def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
port=9042, bundle_path=None, **kwargs):
super().__init__(**kwargs)
if not cassandra:
raise ImproperlyConfigured(E_NO_CASSANDRA)
conf = self.app.conf
self.servers = servers or conf.get('cassandra_servers', None)
self.bundle_path = bundle_path or conf.get(
'cassandra_secure_bundle_path', None)
self.port = port or conf.get('cassandra_port', None)
self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
self.table = table or conf.get('cassandra_table', None)
self.cassandra_options = conf.get('cassandra_options', {})
# either servers or bundle path must be provided...
db_directions = self.servers or self.bundle_path
if not db_directions or not self.keyspace or not self.table:
raise ImproperlyConfigured(E_CASSANDRA_NOT_CONFIGURED)
# ...but not both:
if self.servers and self.bundle_path:
raise ImproperlyConfigured(E_CASSANDRA_MISCONFIGURED)
expires = entry_ttl or conf.get('cassandra_entry_ttl', None)
self.cqlexpires = (
Q_EXPIRES.format(expires) if expires is not None else '')
read_cons = conf.get('cassandra_read_consistency') or 'LOCAL_QUORUM'
write_cons = conf.get('cassandra_write_consistency') or 'LOCAL_QUORUM'
self.read_consistency = getattr(
cassandra.ConsistencyLevel, read_cons,
cassandra.ConsistencyLevel.LOCAL_QUORUM)
self.write_consistency = getattr(
cassandra.ConsistencyLevel, write_cons,
cassandra.ConsistencyLevel.LOCAL_QUORUM)
self.auth_provider = None
auth_provider = conf.get('cassandra_auth_provider', None)
auth_kwargs = conf.get('cassandra_auth_kwargs', None)
if auth_provider and auth_kwargs:
auth_provider_class = getattr(cassandra.auth, auth_provider, None)
if not auth_provider_class:
raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER)
self.auth_provider = auth_provider_class(**auth_kwargs)
self._cluster = None
self._session = None
self._write_stmt = None
self._read_stmt = None
self._lock = threading.RLock()
def _get_connection(self, write=False):
"""Prepare the connection for action.
Arguments:
write (bool): are we a writer?
"""
if self._session is not None:
return
self._lock.acquire()
try:
if self._session is not None:
return
# using either 'servers' or 'bundle_path' here:
if self.servers:
self._cluster = cassandra.cluster.Cluster(
self.servers, port=self.port,
auth_provider=self.auth_provider,
**self.cassandra_options)
else:
# 'bundle_path' is guaranteed to be set
self._cluster = cassandra.cluster.Cluster(
cloud={
'secure_connect_bundle': self.bundle_path,
},
auth_provider=self.auth_provider,
**self.cassandra_options)
self._session = self._cluster.connect(self.keyspace)
# We're forced to do concatenation below, as formatting would
# blow up on superficial %s that'll be processed by Cassandra
self._write_stmt = cassandra.query.SimpleStatement(
Q_INSERT_RESULT.format(
table=self.table, expires=self.cqlexpires),
)
self._write_stmt.consistency_level = self.write_consistency
self._read_stmt = cassandra.query.SimpleStatement(
Q_SELECT_RESULT.format(table=self.table),
)
self._read_stmt.consistency_level = self.read_consistency
if write:
# Only potential writers ("workers") are allowed to issue
# CREATE TABLE. This prevents conflicting situations
# where both the task creator and the task executor would
# issue it at the same time.
# In any case, if you're doing anything critical you should
# have created this table in advance, in which case
# this query will be a no-op (AlreadyExists)
make_stmt = cassandra.query.SimpleStatement(
Q_CREATE_RESULT_TABLE.format(table=self.table),
)
make_stmt.consistency_level = self.write_consistency
try:
self._session.execute(make_stmt)
except cassandra.AlreadyExists:
pass
except cassandra.OperationTimedOut:
# a heavily loaded or gone Cassandra cluster failed to respond.
# leave this class in a consistent state
if self._cluster is not None:
self._cluster.shutdown() # also shuts down _session
self._cluster = None
self._session = None
raise # we did fail after all - reraise
finally:
self._lock.release()
def _store_result(self, task_id, result, state,
traceback=None, request=None, **kwargs):
"""Store return value and state of an executed task."""
self._get_connection(write=True)
self._session.execute(self._write_stmt, (
task_id,
state,
buf_t(self.encode(result)),
self.app.now(),
buf_t(self.encode(traceback)),
buf_t(self.encode(self.current_task_children(request)))
))
def as_uri(self, include_password=True):
return 'cassandra://'
def _get_task_meta_for(self, task_id):
"""Get task meta-data for a task by id."""
self._get_connection()
res = self._session.execute(self._read_stmt, (task_id, )).one()
if not res:
return {'status': states.PENDING, 'result': None}
status, result, date_done, traceback, children = res
return self.meta_from_decoded({
'task_id': task_id,
'status': status,
'result': self.decode(result),
'date_done': date_done,
'traceback': self.decode(traceback),
'children': self.decode(children),
})
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
kwargs.update(
{'servers': self.servers,
'keyspace': self.keyspace,
'table': self.table})
return super().__reduce__(args, kwargs)
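
A minimal configuration sketch follows; the server list, keyspace, table name and credentials are assumptions, and each setting maps onto the conf.get() lookups in __init__ above.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'cassandra://'
app.conf.cassandra_servers = ['127.0.0.1']     # becomes `servers` above (assumed host)
app.conf.cassandra_port = 9042
app.conf.cassandra_keyspace = 'celery'         # the keyspace must already exist
app.conf.cassandra_table = 'tasks'
app.conf.cassandra_entry_ttl = 86400           # seconds; rendered into the "USING TTL" clause
app.conf.cassandra_auth_provider = 'PlainTextAuthProvider'   # class looked up in cassandra.auth
app.conf.cassandra_auth_kwargs = {'username': 'cassandra', 'password': 'cassandra'}  # assumed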

View File

@@ -0,0 +1,116 @@
"""Consul result store backend.
- :class:`ConsulBackend` implements KeyValueStoreBackend to store results
in the key-value store of Consul.
"""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import parse_url
from celery.backends.base import KeyValueStoreBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
try:
import consul
except ImportError:
consul = None
logger = get_logger(__name__)
__all__ = ('ConsulBackend',)
CONSUL_MISSING = """\
You need to install the python-consul library in order to use \
the Consul result store backend."""
class ConsulBackend(KeyValueStoreBackend):
"""Consul.io K/V store backend for Celery."""
consul = consul
supports_autoexpire = True
consistency = 'consistent'
path = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.consul is None:
raise ImproperlyConfigured(CONSUL_MISSING)
#
# By default, for correctness, we use a client connection per
# operation. If set, self.one_client will be used for all operations.
# This provides for the original behaviour to be selected, and is
# also convenient for mocking in the unit tests.
#
self.one_client = None
self._init_from_params(**parse_url(self.url))
def _init_from_params(self, hostname, port, virtual_host, **params):
logger.debug('Setting up Consul client to connect to %s:%d',
hostname, port)
self.path = virtual_host
self.hostname = hostname
self.port = port
#
# Optionally, allow a single client connection to be used to reduce
# the connection load on Consul by adding a "one_client=1" parameter
# to the URL.
#
if params.get('one_client', None):
self.one_client = self.client()
def client(self):
return self.one_client or consul.Consul(host=self.hostname,
port=self.port,
consistency=self.consistency)
def _key_to_consul_key(self, key):
key = bytes_to_str(key)
return key if self.path is None else f'{self.path}/{key}'
def get(self, key):
key = self._key_to_consul_key(key)
logger.debug('Trying to fetch key %s from Consul', key)
try:
_, data = self.client().kv.get(key)
return data['Value']
except TypeError:
pass
def mget(self, keys):
for key in keys:
yield self.get(key)
def set(self, key, value):
"""Set a key in Consul.
Before creating the key, a session with a TTL is created in Consul.
The key created afterwards will reference that session's ID.
If the session expires, Consul removes the key, so results
can auto-expire from the K/V store.
"""
session_name = bytes_to_str(key)
key = self._key_to_consul_key(key)
logger.debug('Trying to create Consul session %s with TTL %d',
session_name, self.expires)
client = self.client()
session_id = client.session.create(name=session_name,
behavior='delete',
ttl=self.expires)
logger.debug('Created Consul session %s', session_id)
logger.debug('Writing key %s to Consul', key)
return client.kv.put(key=key, value=value, acquire=session_id)
def delete(self, key):
key = self._key_to_consul_key(key)
logger.debug('Removing key %s from Consul', key)
return self.client().kv.delete(key)
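
A minimal configuration sketch follows; the host, port and key prefix are assumptions, and the one_client flag corresponds to the URL parameter handled in _init_from_params() above.

from celery import Celery

app = Celery('tasks')
# The path ('celery') becomes the key prefix; one_client=1 reuses a single connection.
app.conf.result_backend = 'consul://127.0.0.1:8500/celery?one_client=1'
app.conf.result_expires = 3600   # seconds; used as the Consul session TTL in set()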

View File

@@ -0,0 +1,218 @@
"""The CosmosDB/SQL backend for Celery (experimental)."""
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
import pydocumentdb
from pydocumentdb.document_client import DocumentClient
from pydocumentdb.documents import ConnectionPolicy, ConsistencyLevel, PartitionKind
from pydocumentdb.errors import HTTPFailure
from pydocumentdb.retry_options import RetryOptions
except ImportError:
pydocumentdb = DocumentClient = ConsistencyLevel = PartitionKind = \
HTTPFailure = ConnectionPolicy = RetryOptions = None
__all__ = ("CosmosDBSQLBackend",)
ERROR_NOT_FOUND = 404
ERROR_EXISTS = 409
LOGGER = get_logger(__name__)
class CosmosDBSQLBackend(KeyValueStoreBackend):
"""CosmosDB/SQL backend for Celery."""
def __init__(self,
url=None,
database_name=None,
collection_name=None,
consistency_level=None,
max_retry_attempts=None,
max_retry_wait_time=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
if pydocumentdb is None:
raise ImproperlyConfigured(
"You need to install the pydocumentdb library to use the "
"CosmosDB backend.")
conf = self.app.conf
self._endpoint, self._key = self._parse_url(url)
self._database_name = (
database_name or
conf["cosmosdbsql_database_name"])
self._collection_name = (
collection_name or
conf["cosmosdbsql_collection_name"])
try:
self._consistency_level = getattr(
ConsistencyLevel,
consistency_level or
conf["cosmosdbsql_consistency_level"])
except AttributeError:
raise ImproperlyConfigured("Unknown CosmosDB consistency level")
self._max_retry_attempts = (
max_retry_attempts or
conf["cosmosdbsql_max_retry_attempts"])
self._max_retry_wait_time = (
max_retry_wait_time or
conf["cosmosdbsql_max_retry_wait_time"])
@classmethod
def _parse_url(cls, url):
_, host, port, _, password, _, _ = _parse_url(url)
if not host or not password:
raise ImproperlyConfigured("Invalid URL")
if not port:
port = 443
scheme = "https" if port == 443 else "http"
endpoint = f"{scheme}://{host}:{port}"
return endpoint, password
@cached_property
def _client(self):
"""Return the CosmosDB/SQL client.
If this is the first call to the property, the client is created and
the database and collection are initialized if they don't yet exist.
"""
connection_policy = ConnectionPolicy()
connection_policy.RetryOptions = RetryOptions(
max_retry_attempt_count=self._max_retry_attempts,
max_wait_time_in_seconds=self._max_retry_wait_time)
client = DocumentClient(
self._endpoint,
{"masterKey": self._key},
connection_policy=connection_policy,
consistency_level=self._consistency_level)
self._create_database_if_not_exists(client)
self._create_collection_if_not_exists(client)
return client
def _create_database_if_not_exists(self, client):
try:
client.CreateDatabase({"id": self._database_name})
except HTTPFailure as ex:
if ex.status_code != ERROR_EXISTS:
raise
else:
LOGGER.info("Created CosmosDB database %s",
self._database_name)
def _create_collection_if_not_exists(self, client):
try:
client.CreateCollection(
self._database_link,
{"id": self._collection_name,
"partitionKey": {"paths": ["/id"],
"kind": PartitionKind.Hash}})
except HTTPFailure as ex:
if ex.status_code != ERROR_EXISTS:
raise
else:
LOGGER.info("Created CosmosDB collection %s/%s",
self._database_name, self._collection_name)
@cached_property
def _database_link(self):
return "dbs/" + self._database_name
@cached_property
def _collection_link(self):
return self._database_link + "/colls/" + self._collection_name
def _get_document_link(self, key):
return self._collection_link + "/docs/" + key
@classmethod
def _get_partition_key(cls, key):
if not key or key.isspace():
raise ValueError("Key cannot be none, empty or whitespace.")
return {"partitionKey": key}
def get(self, key):
"""Read the value stored at the given key.
Args:
key: The key for which to read the value.
"""
key = bytes_to_str(key)
LOGGER.debug("Getting CosmosDB document %s/%s/%s",
self._database_name, self._collection_name, key)
try:
document = self._client.ReadDocument(
self._get_document_link(key),
self._get_partition_key(key))
except HTTPFailure as ex:
if ex.status_code != ERROR_NOT_FOUND:
raise
return None
else:
return document.get("value")
def set(self, key, value):
"""Store a value for a given key.
Args:
key: The key at which to store the value.
value: The value to store.
"""
key = bytes_to_str(key)
LOGGER.debug("Creating CosmosDB document %s/%s/%s",
self._database_name, self._collection_name, key)
self._client.CreateDocument(
self._collection_link,
{"id": key, "value": value},
self._get_partition_key(key))
def mget(self, keys):
"""Read all the values for the provided keys.
Args:
keys: The list of keys to read.
"""
return [self.get(key) for key in keys]
def delete(self, key):
"""Delete the value at a given key.
Args:
key: The key of the value to delete.
"""
key = bytes_to_str(key)
LOGGER.debug("Deleting CosmosDB document %s/%s/%s",
self._database_name, self._collection_name, key)
self._client.DeleteDocument(
self._get_document_link(key),
self._get_partition_key(key))
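
A minimal configuration sketch follows; the account name and key are placeholders, and the database/collection names mirror the cosmosdbsql_* settings read in __init__ above.

from celery import Celery

app = Celery('tasks')
# The "password" part of the URL is the account key; port 443 implies https.
app.conf.result_backend = (
    'cosmosdbsql://:<account_key>@<account>.documents.azure.com:443'
)
app.conf.cosmosdbsql_database_name = 'celerydb'
app.conf.cosmosdbsql_collection_name = 'celerycol'
app.conf.cosmosdbsql_consistency_level = 'Session'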

View File

@@ -0,0 +1,114 @@
"""Couchbase result store backend."""
from kombu.utils.url import _parse_url
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
except ImportError:
Cluster = PasswordAuthenticator = None
try:
from couchbase_core._libcouchbase import FMT_AUTO
except ImportError:
FMT_AUTO = None
__all__ = ('CouchbaseBackend',)
class CouchbaseBackend(KeyValueStoreBackend):
"""Couchbase backend.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`couchbase` is not available.
"""
bucket = 'default'
host = 'localhost'
port = 8091
username = None
password = None
quiet = False
supports_autoexpire = True
timeout = 2.5
# Use str as couchbase key not bytes
key_t = str
def __init__(self, url=None, *args, **kwargs):
kwargs.setdefault('expires_type', int)
super().__init__(*args, **kwargs)
self.url = url
if Cluster is None:
raise ImproperlyConfigured(
'You need to install the couchbase library to use the '
'Couchbase backend.',
)
uhost = uport = uname = upass = ubucket = None
if url:
_, uhost, uport, uname, upass, ubucket, _ = _parse_url(url)
ubucket = ubucket.strip('/') if ubucket else None
config = self.app.conf.get('couchbase_backend_settings', None)
if config is not None:
if not isinstance(config, dict):
raise ImproperlyConfigured(
'Couchbase backend settings should be grouped in a dict',
)
else:
config = {}
self.host = uhost or config.get('host', self.host)
self.port = int(uport or config.get('port', self.port))
self.bucket = ubucket or config.get('bucket', self.bucket)
self.username = uname or config.get('username', self.username)
self.password = upass or config.get('password', self.password)
self._connection = None
def _get_connection(self):
"""Connect to the Couchbase server."""
if self._connection is None:
if self.host and self.port:
uri = f"couchbase://{self.host}:{self.port}"
else:
uri = f"couchbase://{self.host}"
if self.username and self.password:
opt = PasswordAuthenticator(self.username, self.password)
else:
opt = None
cluster = Cluster(uri, opt)
bucket = cluster.bucket(self.bucket)
self._connection = bucket.default_collection()
return self._connection
@property
def connection(self):
return self._get_connection()
def get(self, key):
return self.connection.get(key).content
def set(self, key, value):
# Since couchbase 4.0.0 the value is JSONType, so the format parameter isn't needed
if FMT_AUTO is not None:
self.connection.upsert(key, value, ttl=self.expires, format=FMT_AUTO)
else:
self.connection.upsert(key, value, ttl=self.expires)
def mget(self, keys):
return self.connection.get_multi(keys)
def delete(self, key):
self.connection.remove(key)
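
A minimal configuration sketch follows; host, credentials and bucket are assumptions, shown both as a URL and as the couchbase_backend_settings dict handled in __init__ above.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'couchbase://user:password@127.0.0.1:8091/default'
# The same values may instead be grouped in a dict (these are assumed placeholders):
app.conf.couchbase_backend_settings = {
    'host': '127.0.0.1',
    'port': 8091,
    'username': 'user',
    'password': 'password',
    'bucket': 'default',
}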

View File

@@ -0,0 +1,99 @@
"""CouchDB result store backend."""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
import pycouchdb
except ImportError:
pycouchdb = None
__all__ = ('CouchBackend',)
ERR_LIB_MISSING = """\
You need to install the pycouchdb library to use the CouchDB result backend\
"""
class CouchBackend(KeyValueStoreBackend):
"""CouchDB backend.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`pycouchdb` is not available.
"""
container = 'default'
scheme = 'http'
host = 'localhost'
port = 5984
username = None
password = None
def __init__(self, url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.url = url
if pycouchdb is None:
raise ImproperlyConfigured(ERR_LIB_MISSING)
uscheme = uhost = uport = uname = upass = ucontainer = None
if url:
_, uhost, uport, uname, upass, ucontainer, _ = _parse_url(url)
ucontainer = ucontainer.strip('/') if ucontainer else None
self.scheme = uscheme or self.scheme
self.host = uhost or self.host
self.port = int(uport or self.port)
self.container = ucontainer or self.container
self.username = uname or self.username
self.password = upass or self.password
self._connection = None
def _get_connection(self):
"""Connect to the CouchDB server."""
if self.username and self.password:
conn_string = f'{self.scheme}://{self.username}:{self.password}@{self.host}:{self.port}'
server = pycouchdb.Server(conn_string, authmethod='basic')
else:
conn_string = f'{self.scheme}://{self.host}:{self.port}'
server = pycouchdb.Server(conn_string)
try:
return server.database(self.container)
except pycouchdb.exceptions.NotFound:
return server.create(self.container)
@property
def connection(self):
if self._connection is None:
self._connection = self._get_connection()
return self._connection
def get(self, key):
key = bytes_to_str(key)
try:
return self.connection.get(key)['value']
except pycouchdb.exceptions.NotFound:
return None
def set(self, key, value):
key = bytes_to_str(key)
data = {'_id': key, 'value': value}
try:
self.connection.save(data)
except pycouchdb.exceptions.Conflict:
# document already exists, update it
data = self.connection.get(key)
data['value'] = value
self.connection.save(data)
def mget(self, keys):
return [self.get(key) for key in keys]
def delete(self, key):
self.connection.delete(key)
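
A minimal configuration sketch follows; credentials, host and container name are assumptions.

from celery import Celery

app = Celery('tasks')
# scheme://user:password@host:port/container, mirroring _parse_url() above
app.conf.result_backend = 'couchdb://admin:password@127.0.0.1:5984/celery'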

View File

@@ -0,0 +1,222 @@
"""SQLAlchemy result store backend."""
import logging
from contextlib import contextmanager
from vine.utils import wraps
from celery import states
from celery.backends.base import BaseBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.time import maybe_timedelta
from .models import Task, TaskExtended, TaskSet
from .session import SessionManager
try:
from sqlalchemy.exc import DatabaseError, InvalidRequestError
from sqlalchemy.orm.exc import StaleDataError
except ImportError:
raise ImproperlyConfigured(
'The database result backend requires SQLAlchemy to be installed. '
'See https://pypi.org/project/SQLAlchemy/')
logger = logging.getLogger(__name__)
__all__ = ('DatabaseBackend',)
@contextmanager
def session_cleanup(session):
try:
yield
except Exception:
session.rollback()
raise
finally:
session.close()
def retry(fun):
@wraps(fun)
def _inner(*args, **kwargs):
max_retries = kwargs.pop('max_retries', 3)
for retries in range(max_retries):
try:
return fun(*args, **kwargs)
except (DatabaseError, InvalidRequestError, StaleDataError):
logger.warning(
'Failed operation %s. Retrying %s more times.',
fun.__name__, max_retries - retries - 1,
exc_info=True)
if retries + 1 >= max_retries:
raise
return _inner
class DatabaseBackend(BaseBackend):
"""The database result backend."""
# ResultSet.iterate should sleep this much between each poll,
# to not bombard the database with queries.
subpolling_interval = 0.5
task_cls = Task
taskset_cls = TaskSet
def __init__(self, dburi=None, engine_options=None, url=None, **kwargs):
# The `url` argument was added later and is used by
# the app to set backend by url (celery.app.backends.by_url)
super().__init__(expires_type=maybe_timedelta,
url=url, **kwargs)
conf = self.app.conf
if self.extended_result:
self.task_cls = TaskExtended
self.url = url or dburi or conf.database_url
self.engine_options = dict(
engine_options or {},
**conf.database_engine_options or {})
self.short_lived_sessions = kwargs.get(
'short_lived_sessions',
conf.database_short_lived_sessions)
schemas = conf.database_table_schemas or {}
tablenames = conf.database_table_names or {}
self.task_cls.configure(
schema=schemas.get('task'),
name=tablenames.get('task'))
self.taskset_cls.configure(
schema=schemas.get('group'),
name=tablenames.get('group'))
if not self.url:
raise ImproperlyConfigured(
'Missing connection string! Do you have the'
' database_url setting set to a real value?')
@property
def extended_result(self):
return self.app.conf.find_value_for_key('extended', 'result')
def ResultSession(self, session_manager=SessionManager()):
return session_manager.session_factory(
dburi=self.url,
short_lived_sessions=self.short_lived_sessions,
**self.engine_options)
@retry
def _store_result(self, task_id, result, state, traceback=None,
request=None, **kwargs):
"""Store return value and state of an executed task."""
session = self.ResultSession()
with session_cleanup(session):
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
task = task and task[0]
if not task:
task = self.task_cls(task_id)
task.task_id = task_id
session.add(task)
session.flush()
self._update_result(task, result, state, traceback=traceback, request=request)
session.commit()
def _update_result(self, task, result, state, traceback=None,
request=None):
meta = self._get_result_meta(result=result, state=state,
traceback=traceback, request=request,
format_date=False, encode=True)
# Exclude the primary key id and task_id columns
# as we should not set them to None
columns = [column.name for column in self.task_cls.__table__.columns
if column.name not in {'id', 'task_id'}]
# Iterate through the column names of the table
# to set the value from meta.
# If a value is not present in meta, set it to None.
for column in columns:
value = meta.get(column)
setattr(task, column, value)
@retry
def _get_task_meta_for(self, task_id):
"""Get task meta-data for a task by id."""
session = self.ResultSession()
with session_cleanup(session):
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
task = task and task[0]
if not task:
task = self.task_cls(task_id)
task.status = states.PENDING
task.result = None
data = task.to_dict()
if data.get('args', None) is not None:
data['args'] = self.decode(data['args'])
if data.get('kwargs', None) is not None:
data['kwargs'] = self.decode(data['kwargs'])
return self.meta_from_decoded(data)
@retry
def _save_group(self, group_id, result):
"""Store the result of an executed group."""
session = self.ResultSession()
with session_cleanup(session):
group = self.taskset_cls(group_id, result)
session.add(group)
session.flush()
session.commit()
return result
@retry
def _restore_group(self, group_id):
"""Get meta-data for group by id."""
session = self.ResultSession()
with session_cleanup(session):
group = session.query(self.taskset_cls).filter(
self.taskset_cls.taskset_id == group_id).first()
if group:
return group.to_dict()
@retry
def _delete_group(self, group_id):
"""Delete meta-data for group by id."""
session = self.ResultSession()
with session_cleanup(session):
session.query(self.taskset_cls).filter(
self.taskset_cls.taskset_id == group_id).delete()
session.flush()
session.commit()
@retry
def _forget(self, task_id):
"""Forget about result."""
session = self.ResultSession()
with session_cleanup(session):
session.query(self.task_cls).filter(self.task_cls.task_id == task_id).delete()
session.commit()
def cleanup(self):
"""Delete expired meta-data."""
session = self.ResultSession()
expires = self.expires
now = self.app.now()
with session_cleanup(session):
session.query(self.task_cls).filter(
self.task_cls.date_done < (now - expires)).delete()
session.query(self.taskset_cls).filter(
self.taskset_cls.date_done < (now - expires)).delete()
session.commit()
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
kwargs.update(
{'dburi': self.url,
'expires': self.expires,
'engine_options': self.engine_options})
return super().__reduce__(args, kwargs)
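
A minimal configuration sketch follows; the DSN is an assumption, and the 'db+' prefix is the usual way to route a SQLAlchemy URL to this backend.

from celery import Celery

app = Celery('tasks')
# The 'db+' prefix selects DatabaseBackend; the rest is a plain SQLAlchemy URL (assumed).
app.conf.result_backend = 'db+postgresql://scott:tiger@127.0.0.1/celery_results'
app.conf.database_engine_options = {'echo': False}   # merged into engine_options above
app.conf.database_short_lived_sessions = True        # recreate sessions per operation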

View File

@@ -0,0 +1,108 @@
"""Database models used by the SQLAlchemy result store backend."""
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy.types import PickleType
from celery import states
from .session import ResultModelBase
__all__ = ('Task', 'TaskExtended', 'TaskSet')
class Task(ResultModelBase):
"""Task result/status."""
__tablename__ = 'celery_taskmeta'
__table_args__ = {'sqlite_autoincrement': True}
id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'),
primary_key=True, autoincrement=True)
task_id = sa.Column(sa.String(155), unique=True)
status = sa.Column(sa.String(50), default=states.PENDING)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
onupdate=datetime.utcnow, nullable=True)
traceback = sa.Column(sa.Text, nullable=True)
def __init__(self, task_id):
self.task_id = task_id
def to_dict(self):
return {
'task_id': self.task_id,
'status': self.status,
'result': self.result,
'traceback': self.traceback,
'date_done': self.date_done,
}
def __repr__(self):
return '<Task {0.task_id} state: {0.status}>'.format(self)
@classmethod
def configure(cls, schema=None, name=None):
cls.__table__.schema = schema
cls.id.default.schema = schema
cls.__table__.name = name or cls.__tablename__
class TaskExtended(Task):
"""For the extend result."""
__tablename__ = 'celery_taskmeta'
__table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True}
name = sa.Column(sa.String(155), nullable=True)
args = sa.Column(sa.LargeBinary, nullable=True)
kwargs = sa.Column(sa.LargeBinary, nullable=True)
worker = sa.Column(sa.String(155), nullable=True)
retries = sa.Column(sa.Integer, nullable=True)
queue = sa.Column(sa.String(155), nullable=True)
def to_dict(self):
task_dict = super().to_dict()
task_dict.update({
'name': self.name,
'args': self.args,
'kwargs': self.kwargs,
'worker': self.worker,
'retries': self.retries,
'queue': self.queue,
})
return task_dict
class TaskSet(ResultModelBase):
"""TaskSet result."""
__tablename__ = 'celery_tasksetmeta'
__table_args__ = {'sqlite_autoincrement': True}
id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'),
autoincrement=True, primary_key=True)
taskset_id = sa.Column(sa.String(155), unique=True)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
nullable=True)
def __init__(self, taskset_id, result):
self.taskset_id = taskset_id
self.result = result
def to_dict(self):
return {
'taskset_id': self.taskset_id,
'result': self.result,
'date_done': self.date_done,
}
def __repr__(self):
return f'<TaskSet: {self.taskset_id}>'
@classmethod
def configure(cls, schema=None, name=None):
cls.__table__.schema = schema
cls.id.default.schema = schema
cls.__table__.name = name or cls.__tablename__
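
A minimal sketch of renaming these tables follows; the table names are illustrative assumptions and are applied through the configure() classmethods above.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'db+sqlite:///results.sqlite'
# Fed into Task.configure()/TaskSet.configure() by the database backend:
app.conf.database_table_names = {
    'task': 'myapp_taskmeta',
    'group': 'myapp_groupmeta',
}
# database_table_schemas works the same way for schema-qualified tables.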

View File

@@ -0,0 +1,89 @@
"""SQLAlchemy session."""
import time
from kombu.utils.compat import register_after_fork
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from celery.utils.time import get_exponential_backoff_interval
try:
from sqlalchemy.orm import declarative_base
except ImportError:
# TODO: Remove this once we drop support for SQLAlchemy < 1.4.
from sqlalchemy.ext.declarative import declarative_base
ResultModelBase = declarative_base()
__all__ = ('SessionManager',)
PREPARE_MODELS_MAX_RETRIES = 10
def _after_fork_cleanup_session(session):
session._after_fork()
class SessionManager:
"""Manage SQLAlchemy sessions."""
def __init__(self):
self._engines = {}
self._sessions = {}
self.forked = False
self.prepared = False
if register_after_fork is not None:
register_after_fork(self, _after_fork_cleanup_session)
def _after_fork(self):
self.forked = True
def get_engine(self, dburi, **kwargs):
if self.forked:
try:
return self._engines[dburi]
except KeyError:
engine = self._engines[dburi] = create_engine(dburi, **kwargs)
return engine
else:
kwargs = {k: v for k, v in kwargs.items() if
not k.startswith('pool')}
return create_engine(dburi, poolclass=NullPool, **kwargs)
def create_session(self, dburi, short_lived_sessions=False, **kwargs):
engine = self.get_engine(dburi, **kwargs)
if self.forked:
if short_lived_sessions or dburi not in self._sessions:
self._sessions[dburi] = sessionmaker(bind=engine)
return engine, self._sessions[dburi]
return engine, sessionmaker(bind=engine)
def prepare_models(self, engine):
if not self.prepared:
# SQLAlchemy will check if the items exist before trying to
# create them, which is a race condition. If it raises an error
# in one iteration, the next may pass all the existence checks
# and the call will succeed.
retries = 0
while True:
try:
ResultModelBase.metadata.create_all(engine)
except DatabaseError:
if retries < PREPARE_MODELS_MAX_RETRIES:
sleep_amount_ms = get_exponential_backoff_interval(
10, retries, 1000, True
)
time.sleep(sleep_amount_ms / 1000)
retries += 1
else:
raise
else:
break
self.prepared = True
def session_factory(self, dburi, **kwargs):
engine, session = self.create_session(dburi, **kwargs)
self.prepare_models(engine)
return session()
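
A minimal sketch of using SessionManager directly follows, assuming the usual celery.backends.database module paths; the SQLite DSN is an assumption.

from celery.backends.database.models import Task
from celery.backends.database.session import SessionManager

manager = SessionManager()
# session_factory() creates the engine, prepares the models and returns a session.
session = manager.session_factory(dburi='sqlite:///results.sqlite')
try:
    pending = session.query(Task).filter(Task.status == 'PENDING').count()
    print(f'{pending} pending task rows')
finally:
    session.close()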

View File

@@ -0,0 +1,493 @@
"""AWS DynamoDB result store backend."""
from collections import namedtuple
from time import sleep, time
from kombu.utils.url import _parse_url as parse_url
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
boto3 = ClientError = None
__all__ = ('DynamoDBBackend',)
# Helper class that describes a DynamoDB attribute
DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type'))
logger = get_logger(__name__)
class DynamoDBBackend(KeyValueStoreBackend):
"""AWS DynamoDB result backend.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`boto3` is not available.
"""
#: default DynamoDB table name (`default`)
table_name = 'celery'
#: Read Provisioned Throughput (`default`)
read_capacity_units = 1
#: Write Provisioned Throughput (`default`)
write_capacity_units = 1
#: AWS region (`default`)
aws_region = None
#: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`)
endpoint_url = None
#: Item time-to-live in seconds (`default`)
time_to_live_seconds = None
# DynamoDB supports Time to Live as an auto-expiry mechanism.
supports_autoexpire = True
_key_field = DynamoDBAttribute(name='id', data_type='S')
_value_field = DynamoDBAttribute(name='result', data_type='B')
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
_available_fields = None
def __init__(self, url=None, table_name=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.url = url
self.table_name = table_name or self.table_name
if not boto3:
raise ImproperlyConfigured(
'You need to install the boto3 library to use the '
'DynamoDB backend.')
aws_credentials_given = False
aws_access_key_id = None
aws_secret_access_key = None
if url is not None:
scheme, region, port, username, password, table, query = \
parse_url(url)
aws_access_key_id = username
aws_secret_access_key = password
access_key_given = aws_access_key_id is not None
secret_key_given = aws_secret_access_key is not None
if access_key_given != secret_key_given:
raise ImproperlyConfigured(
'You need to specify both the Access Key ID '
'and Secret.')
aws_credentials_given = access_key_given
if region == 'localhost':
# We are using the downloadable, local version of DynamoDB
self.endpoint_url = f'http://localhost:{port}'
self.aws_region = 'us-east-1'
logger.warning(
'Using local-only DynamoDB endpoint URL: {}'.format(
self.endpoint_url
)
)
else:
self.aws_region = region
# If endpoint_url is explicitly set use it instead
_get = self.app.conf.get
config_endpoint_url = _get('dynamodb_endpoint_url')
if config_endpoint_url:
self.endpoint_url = config_endpoint_url
self.read_capacity_units = int(
query.get(
'read',
self.read_capacity_units
)
)
self.write_capacity_units = int(
query.get(
'write',
self.write_capacity_units
)
)
ttl = query.get('ttl_seconds', self.time_to_live_seconds)
if ttl:
try:
self.time_to_live_seconds = int(ttl)
except ValueError as e:
logger.error(
f'TTL must be a number; got "{ttl}"',
exc_info=e
)
raise e
self.table_name = table or self.table_name
self._available_fields = (
self._key_field,
self._value_field,
self._timestamp_field
)
self._client = None
if aws_credentials_given:
self._get_client(
access_key_id=aws_access_key_id,
secret_access_key=aws_secret_access_key
)
def _get_client(self, access_key_id=None, secret_access_key=None):
"""Get client connection."""
if self._client is None:
client_parameters = {
'region_name': self.aws_region
}
if access_key_id is not None:
client_parameters.update({
'aws_access_key_id': access_key_id,
'aws_secret_access_key': secret_access_key
})
if self.endpoint_url is not None:
client_parameters['endpoint_url'] = self.endpoint_url
self._client = boto3.client(
'dynamodb',
**client_parameters
)
self._get_or_create_table()
if self._has_ttl() is not None:
self._validate_ttl_methods()
self._set_table_ttl()
return self._client
def _get_table_schema(self):
"""Get the boto3 structure describing the DynamoDB table schema."""
return {
'AttributeDefinitions': [
{
'AttributeName': self._key_field.name,
'AttributeType': self._key_field.data_type
}
],
'TableName': self.table_name,
'KeySchema': [
{
'AttributeName': self._key_field.name,
'KeyType': 'HASH'
}
],
'ProvisionedThroughput': {
'ReadCapacityUnits': self.read_capacity_units,
'WriteCapacityUnits': self.write_capacity_units
}
}
def _get_or_create_table(self):
"""Create table if not exists, otherwise return the description."""
table_schema = self._get_table_schema()
try:
return self._client.describe_table(TableName=self.table_name)
except ClientError as e:
error_code = e.response['Error'].get('Code', 'Unknown')
if error_code == 'ResourceNotFoundException':
table_description = self._client.create_table(**table_schema)
logger.info(
'DynamoDB Table {} did not exist, creating.'.format(
self.table_name
)
)
# In case we created the table, wait until it becomes available.
self._wait_for_table_status('ACTIVE')
logger.info(
'DynamoDB Table {} is now available.'.format(
self.table_name
)
)
return table_description
else:
raise e
def _has_ttl(self):
"""Return the desired Time to Live config.
- True: Enable TTL on the table; use expiry.
- False: Disable TTL on the table; don't use expiry.
- None: Ignore TTL on the table; don't use expiry.
"""
return None if self.time_to_live_seconds is None \
else self.time_to_live_seconds >= 0
def _validate_ttl_methods(self):
"""Verify boto support for the DynamoDB Time to Live methods."""
# Required TTL methods.
required_methods = (
'update_time_to_live',
'describe_time_to_live',
)
# Find missing methods.
missing_methods = []
for method in list(required_methods):
if not hasattr(self._client, method):
missing_methods.append(method)
if missing_methods:
logger.error(
(
'boto3 method(s) {methods} not found; ensure that '
'boto3>=1.9.178 and botocore>=1.12.178 are installed'
).format(
methods=','.join(missing_methods)
)
)
raise AttributeError(
'boto3 method(s) {methods} not found'.format(
methods=','.join(missing_methods)
)
)
def _get_ttl_specification(self, ttl_attr_name):
"""Get the boto3 structure describing the DynamoDB TTL specification."""
return {
'TableName': self.table_name,
'TimeToLiveSpecification': {
'Enabled': self._has_ttl(),
'AttributeName': ttl_attr_name
}
}
def _get_table_ttl_description(self):
# Get the current TTL description.
try:
description = self._client.describe_time_to_live(
TableName=self.table_name
)
except ClientError as e:
error_code = e.response['Error'].get('Code', 'Unknown')
error_message = e.response['Error'].get('Message', 'Unknown')
logger.error((
'Error describing Time to Live on DynamoDB table {table}: '
'{code}: {message}'
).format(
table=self.table_name,
code=error_code,
message=error_message,
))
raise e
return description
def _set_table_ttl(self):
"""Enable or disable Time to Live on the table."""
# Get the table TTL description, and return early when possible.
description = self._get_table_ttl_description()
status = description['TimeToLiveDescription']['TimeToLiveStatus']
if status in ('ENABLED', 'ENABLING'):
cur_attr_name = \
description['TimeToLiveDescription']['AttributeName']
if self._has_ttl():
if cur_attr_name == self._ttl_field.name:
# We want TTL enabled, and it is currently enabled or being
# enabled, and on the correct attribute.
logger.debug((
'DynamoDB Time to Live is {situation} '
'on table {table}'
).format(
situation='already enabled'
if status == 'ENABLED'
else 'currently being enabled',
table=self.table_name
))
return description
elif status in ('DISABLED', 'DISABLING'):
if not self._has_ttl():
# We want TTL disabled, and it is currently disabled or being
# disabled.
logger.debug((
'DynamoDB Time to Live is {situation} '
'on table {table}'
).format(
situation='already disabled'
if status == 'DISABLED'
else 'currently being disabled',
table=self.table_name
))
return description
# The state shouldn't ever have any value beyond the four handled
# above, but to ease troubleshooting of potential future changes, emit
# a log showing the unknown state.
else: # pragma: no cover
logger.warning((
'Unknown DynamoDB Time to Live status {status} '
'on table {table}. Attempting to continue.'
).format(
status=status,
table=self.table_name
))
# At this point, we have one of the following situations:
#
# We want TTL enabled,
#
# - and it's currently disabled: Try to enable.
#
# - and it's being disabled: Try to enable, but this is almost sure to
# raise ValidationException with message:
#
# Time to live has been modified multiple times within a fixed
# interval
#
# - and it's currently enabling or being enabled, but on the wrong
# attribute: Try to enable, but this will raise ValidationException
# with message:
#
# TimeToLive is active on a different AttributeName: current
# AttributeName is ttlx
#
# We want TTL disabled,
#
# - and it's currently enabled: Try to disable.
#
# - and it's being enabled: Try to disable, but this is almost sure to
# raise ValidationException with message:
#
# Time to live has been modified multiple times within a fixed
# interval
#
attr_name = \
cur_attr_name if status == 'ENABLED' else self._ttl_field.name
try:
specification = self._client.update_time_to_live(
**self._get_ttl_specification(
ttl_attr_name=attr_name
)
)
logger.info(
(
'DynamoDB table Time to Live updated: '
'table={table} enabled={enabled} attribute={attr}'
).format(
table=self.table_name,
enabled=self._has_ttl(),
attr=self._ttl_field.name
)
)
return specification
except ClientError as e:
error_code = e.response['Error'].get('Code', 'Unknown')
error_message = e.response['Error'].get('Message', 'Unknown')
logger.error((
'Error {action} Time to Live on DynamoDB table {table}: '
'{code}: {message}'
).format(
action='enabling' if self._has_ttl() else 'disabling',
table=self.table_name,
code=error_code,
message=error_message,
))
raise e
def _wait_for_table_status(self, expected='ACTIVE'):
"""Poll for the expected table status."""
achieved_state = False
while not achieved_state:
table_description = self.client.describe_table(
TableName=self.table_name
)
logger.debug(
'Waiting for DynamoDB table {} to become {}.'.format(
self.table_name,
expected
)
)
current_status = table_description['Table']['TableStatus']
achieved_state = current_status == expected
sleep(1)
def _prepare_get_request(self, key):
"""Construct the item retrieval request parameters."""
return {
'TableName': self.table_name,
'Key': {
self._key_field.name: {
self._key_field.data_type: key
}
}
}
def _prepare_put_request(self, key, value):
"""Construct the item creation request parameters."""
timestamp = time()
put_request = {
'TableName': self.table_name,
'Item': {
self._key_field.name: {
self._key_field.data_type: key
},
self._value_field.name: {
self._value_field.data_type: value
},
self._timestamp_field.name: {
self._timestamp_field.data_type: str(timestamp)
}
}
}
if self._has_ttl():
put_request['Item'].update({
self._ttl_field.name: {
self._ttl_field.data_type:
str(int(timestamp + self.time_to_live_seconds))
}
})
return put_request
def _item_to_dict(self, raw_response):
"""Convert get_item() response to field-value pairs."""
if 'Item' not in raw_response:
return {}
return {
field.name: raw_response['Item'][field.name][field.data_type]
for field in self._available_fields
}
@property
def client(self):
return self._get_client()
def get(self, key):
key = str(key)
request_parameters = self._prepare_get_request(key)
item_response = self.client.get_item(**request_parameters)
item = self._item_to_dict(item_response)
return item.get(self._value_field.name)
def set(self, key, value):
key = str(key)
request_parameters = self._prepare_put_request(key, value)
self.client.put_item(**request_parameters)
def mget(self, keys):
return [self.get(key) for key in keys]
def delete(self, key):
key = str(key)
request_parameters = self._prepare_get_request(key)
self.client.delete_item(**request_parameters)
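
A minimal configuration sketch follows; the credentials, region, table name and throughput values are assumptions.

from celery import Celery

app = Celery('tasks')
# Real AWS: credentials as user:password, region in the host position,
# table name in the path, throughput and TTL as query parameters.
app.conf.result_backend = (
    'dynamodb://aws_access_key_id:aws_secret_access_key@us-east-1/celery'
    '?read=5&write=5&ttl_seconds=86400'
)
# DynamoDB Local for development (triggers the localhost branch above):
# app.conf.result_backend = 'dynamodb://@localhost:8000/celery'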

View File

@@ -0,0 +1,248 @@
"""Elasticsearch result store backend."""
from datetime import datetime
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url
from celery import states
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
import elasticsearch
except ImportError:
elasticsearch = None
__all__ = ('ElasticsearchBackend',)
E_LIB_MISSING = """\
You need to install the elasticsearch library to use the Elasticsearch \
result backend.\
"""
class ElasticsearchBackend(KeyValueStoreBackend):
"""Elasticsearch Backend.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`elasticsearch` is not available.
"""
index = 'celery'
doc_type = 'backend'
scheme = 'http'
host = 'localhost'
port = 9200
username = None
password = None
es_retry_on_timeout = False
es_timeout = 10
es_max_retries = 3
def __init__(self, url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.url = url
_get = self.app.conf.get
if elasticsearch is None:
raise ImproperlyConfigured(E_LIB_MISSING)
index = doc_type = scheme = host = port = username = password = None
if url:
scheme, host, port, username, password, path, _ = _parse_url(url)
if scheme == 'elasticsearch':
scheme = None
if path:
path = path.strip('/')
index, _, doc_type = path.partition('/')
self.index = index or self.index
self.doc_type = doc_type or self.doc_type
self.scheme = scheme or self.scheme
self.host = host or self.host
self.port = port or self.port
self.username = username or self.username
self.password = password or self.password
self.es_retry_on_timeout = (
_get('elasticsearch_retry_on_timeout') or self.es_retry_on_timeout
)
es_timeout = _get('elasticsearch_timeout')
if es_timeout is not None:
self.es_timeout = es_timeout
es_max_retries = _get('elasticsearch_max_retries')
if es_max_retries is not None:
self.es_max_retries = es_max_retries
self.es_save_meta_as_text = _get('elasticsearch_save_meta_as_text', True)
self._server = None
def exception_safe_to_retry(self, exc):
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
# 401: Unauthorized
# 409: Conflict
# 429: Too Many Requests
# 500: Internal Server Error
# 502: Bad Gateway
# 503: Service Unavailable
# 504: Gateway Timeout
# N/A: Low level exception (i.e. socket exception)
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
return True
return False
def get(self, key):
try:
res = self._get(key)
try:
if res['found']:
return res['_source']['result']
except (TypeError, KeyError):
pass
except elasticsearch.exceptions.NotFoundError:
pass
def _get(self, key):
return self.server.get(
index=self.index,
doc_type=self.doc_type,
id=key,
)
def _set_with_state(self, key, value, state):
body = {
'result': value,
'@timestamp': '{}Z'.format(
datetime.utcnow().isoformat()[:-3]
),
}
try:
self._index(
id=key,
body=body,
)
except elasticsearch.exceptions.ConflictError:
# document already exists, update it
self._update(key, body, state)
def set(self, key, value):
return self._set_with_state(key, value, None)
def _index(self, id, body, **kwargs):
body = {bytes_to_str(k): v for k, v in body.items()}
return self.server.index(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body=body,
params={'op_type': 'create'},
**kwargs
)
def _update(self, id, body, state, **kwargs):
"""Update state in a conflict free manner.
If state is defined (not None), this will not update ES server if either:
* existing state is success
* existing state is a ready state and current state in not a ready state
This way, a Retry state cannot override a Success or Failure, and chord_unlock
will not retry indefinitely.
"""
body = {bytes_to_str(k): v for k, v in body.items()}
try:
res_get = self._get(key=id)
if not res_get.get('found'):
return self._index(id, body, **kwargs)
# document disappeared between index and get calls.
except elasticsearch.exceptions.NotFoundError:
return self._index(id, body, **kwargs)
try:
meta_present_on_backend = self.decode_result(res_get['_source']['result'])
except (TypeError, KeyError):
pass
else:
if meta_present_on_backend['status'] == states.SUCCESS:
# if stored state is already in success, do nothing
return {'result': 'noop'}
elif meta_present_on_backend['status'] in states.READY_STATES and state in states.UNREADY_STATES:
# if stored state is in ready state and current not, do nothing
return {'result': 'noop'}
# get current sequence number and primary term
# https://www.elastic.co/guide/en/elasticsearch/reference/current/optimistic-concurrency-control.html
seq_no = res_get.get('_seq_no', 1)
prim_term = res_get.get('_primary_term', 1)
# try to update document with current seq_no and primary_term
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
# result is elastic search update query result
# noop = query did not update any document
# updated = at least one document got updated
if res['result'] == 'noop':
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
return res
def encode(self, data):
if self.es_save_meta_as_text:
return super().encode(data)
else:
if not isinstance(data, dict):
return super().encode(data)
if data.get("result"):
data["result"] = self._encode(data["result"])[2]
if data.get("traceback"):
data["traceback"] = self._encode(data["traceback"])[2]
return data
def decode(self, payload):
if self.es_save_meta_as_text:
return super().decode(payload)
else:
if not isinstance(payload, dict):
return super().decode(payload)
if payload.get("result"):
payload["result"] = super().decode(payload["result"])
if payload.get("traceback"):
payload["traceback"] = super().decode(payload["traceback"])
return payload
def mget(self, keys):
return [self.get(key) for key in keys]
def delete(self, key):
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
def _get_server(self):
"""Connect to the Elasticsearch server."""
http_auth = None
if self.username and self.password:
http_auth = (self.username, self.password)
return elasticsearch.Elasticsearch(
f'{self.host}:{self.port}',
retry_on_timeout=self.es_retry_on_timeout,
max_retries=self.es_max_retries,
timeout=self.es_timeout,
scheme=self.scheme,
http_auth=http_auth,
)
@property
def server(self):
if self._server is None:
self._server = self._get_server()
return self._server
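
A minimal configuration sketch follows; host, index and doc_type are assumptions, and the elasticsearch_* settings mirror the conf lookups in __init__ above.

from celery import Celery

app = Celery('tasks')
# The URL path maps to index/doc_type, as parsed in __init__ above.
app.conf.result_backend = 'elasticsearch://127.0.0.1:9200/celery/backend'
app.conf.elasticsearch_retry_on_timeout = True
app.conf.elasticsearch_max_retries = 3
app.conf.elasticsearch_timeout = 10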

View File

@@ -0,0 +1,112 @@
"""File-system result store backend."""
import locale
import os
from datetime import datetime
from kombu.utils.encoding import ensure_bytes
from celery import uuid
from celery.backends.base import KeyValueStoreBackend
from celery.exceptions import ImproperlyConfigured
default_encoding = locale.getpreferredencoding(False)
E_NO_PATH_SET = 'You need to configure a path for the file-system backend'
E_PATH_NON_CONFORMING_SCHEME = (
'A path for the file-system backend should conform to the file URI scheme'
)
E_PATH_INVALID = """\
The configured path for the file-system backend does not
work correctly, please make sure that it exists and has
the correct permissions.\
"""
class FilesystemBackend(KeyValueStoreBackend):
"""File-system result backend.
Arguments:
url (str): URL to the directory we should use
open (Callable): open function to use when opening files
unlink (Callable): unlink function to use when deleting files
sep (str): directory separator (to join the directory with the key)
encoding (str): encoding used on the file-system
"""
def __init__(self, url=None, open=open, unlink=os.unlink, sep=os.sep,
encoding=default_encoding, *args, **kwargs):
super().__init__(*args, **kwargs)
self.url = url
path = self._find_path(url)
# Remove the leading "/" on Windows
if os.name == "nt" and path.startswith("/"):
path = path[1:]
# We need the path and separator as bytes objects
self.path = path.encode(encoding)
self.sep = sep.encode(encoding)
self.open = open
self.unlink = unlink
# Let's verify that we have everything set up right
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
return super().__reduce__(args, {**kwargs, 'url': self.url})
def _find_path(self, url):
if not url:
raise ImproperlyConfigured(E_NO_PATH_SET)
if url.startswith('file://localhost/'):
return url[16:]
if url.startswith('file://'):
return url[7:]
raise ImproperlyConfigured(E_PATH_NON_CONFORMING_SCHEME)
def _do_directory_test(self, key):
try:
self.set(key, b'test value')
assert self.get(key) == b'test value'
self.delete(key)
except OSError:
raise ImproperlyConfigured(E_PATH_INVALID)
def _filename(self, key):
return self.sep.join((self.path, key))
def get(self, key):
try:
with self.open(self._filename(key), 'rb') as infile:
return infile.read()
except FileNotFoundError:
pass
def set(self, key, value):
with self.open(self._filename(key), 'wb') as outfile:
outfile.write(ensure_bytes(value))
def mget(self, keys):
for key in keys:
yield self.get(key)
def delete(self, key):
self.unlink(self._filename(key))
def cleanup(self):
"""Delete expired meta-data."""
if not self.expires:
return
epoch = datetime(1970, 1, 1, tzinfo=self.app.timezone)
now_ts = (self.app.now() - epoch).total_seconds()
cutoff_ts = now_ts - self.expires
for filename in os.listdir(self.path):
for prefix in (self.task_keyprefix, self.group_keyprefix,
self.chord_keyprefix):
if filename.startswith(prefix):
path = os.path.join(self.path, filename)
if os.stat(path).st_mtime < cutoff_ts:
self.unlink(path)
break
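
A minimal configuration sketch follows; the directory is an assumption and must exist, be writable, and be shared by every node that reads or writes results.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'file:///var/run/celery/results'
app.conf.result_expires = 3600   # seconds; consumed by cleanup() above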

View File

@@ -0,0 +1,333 @@
"""MongoDB result store backend."""
from datetime import datetime, timedelta
from kombu.exceptions import EncodeError
from kombu.utils.objects import cached_property
from kombu.utils.url import maybe_sanitize_url, urlparse
from celery import states
from celery.exceptions import ImproperlyConfigured
from .base import BaseBackend
try:
import pymongo
except ImportError:
pymongo = None
if pymongo:
try:
from bson.binary import Binary
except ImportError:
from pymongo.binary import Binary
from pymongo.errors import InvalidDocument
else: # pragma: no cover
Binary = None
class InvalidDocument(Exception):
pass
__all__ = ('MongoBackend',)
BINARY_CODECS = frozenset(['pickle', 'msgpack'])
class MongoBackend(BaseBackend):
"""MongoDB result backend.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`pymongo` is not available.
"""
mongo_host = None
host = 'localhost'
port = 27017
user = None
password = None
database_name = 'celery'
taskmeta_collection = 'celery_taskmeta'
groupmeta_collection = 'celery_groupmeta'
max_pool_size = 10
options = None
supports_autoexpire = False
_connection = None
def __init__(self, app=None, **kwargs):
self.options = {}
super().__init__(app, **kwargs)
if not pymongo:
raise ImproperlyConfigured(
'You need to install the pymongo library to use the '
'MongoDB backend.')
# Set option defaults
for key, value in self._prepare_client_options().items():
self.options.setdefault(key, value)
# update conf with mongo uri data, only if uri was given
if self.url:
self.url = self._ensure_mongodb_uri_compliance(self.url)
uri_data = pymongo.uri_parser.parse_uri(self.url)
# build the hosts list to create a mongo connection
hostslist = [
f'{x[0]}:{x[1]}' for x in uri_data['nodelist']
]
self.user = uri_data['username']
self.password = uri_data['password']
self.mongo_host = hostslist
if uri_data['database']:
# use the database name from the URI when one is provided
self.database_name = uri_data['database']
self.options.update(uri_data['options'])
# update conf with specific settings
config = self.app.conf.get('mongodb_backend_settings')
if config is not None:
if not isinstance(config, dict):
raise ImproperlyConfigured(
'MongoDB backend settings should be grouped in a dict')
config = dict(config) # don't modify original
if 'host' in config or 'port' in config:
# these should take over uri conf
self.mongo_host = None
self.host = config.pop('host', self.host)
self.port = config.pop('port', self.port)
self.mongo_host = config.pop('mongo_host', self.mongo_host)
self.user = config.pop('user', self.user)
self.password = config.pop('password', self.password)
self.database_name = config.pop('database', self.database_name)
self.taskmeta_collection = config.pop(
'taskmeta_collection', self.taskmeta_collection,
)
self.groupmeta_collection = config.pop(
'groupmeta_collection', self.groupmeta_collection,
)
self.options.update(config.pop('options', {}))
self.options.update(config)
@staticmethod
def _ensure_mongodb_uri_compliance(url):
parsed_url = urlparse(url)
if not parsed_url.scheme.startswith('mongodb'):
url = f'mongodb+{url}'
if url == 'mongodb://':
url += 'localhost'
return url
def _prepare_client_options(self):
if pymongo.version_tuple >= (3,):
return {'maxPoolSize': self.max_pool_size}
else: # pragma: no cover
return {'max_pool_size': self.max_pool_size,
'auto_start_request': False}
def _get_connection(self):
"""Connect to the MongoDB server."""
if self._connection is None:
from pymongo import MongoClient
host = self.mongo_host
if not host:
# The first pymongo.Connection() argument (host) can be
# a list of ['host:port'] elements or a mongodb connection
# URI. If this is the case, don't use self.port
# but let pymongo get the port(s) from the URI instead.
# This enables the use of replica sets and sharding.
# See pymongo.Connection() for more info.
host = self.host
if isinstance(host, str) \
and not host.startswith('mongodb://'):
host = f'mongodb://{host}:{self.port}'
# don't change self.options
conf = dict(self.options)
conf['host'] = host
if self.user:
conf['username'] = self.user
if self.password:
conf['password'] = self.password
self._connection = MongoClient(**conf)
return self._connection
def encode(self, data):
if self.serializer == 'bson':
# mongodb handles serialization
return data
payload = super().encode(data)
# serializers that produce a binary payload (pickle/msgpack) need wrapping
if self.serializer in BINARY_CODECS:
payload = Binary(payload)
return payload
def decode(self, data):
if self.serializer == 'bson':
return data
return super().decode(data)
def _store_result(self, task_id, result, state,
traceback=None, request=None, **kwargs):
"""Store return value and state of an executed task."""
meta = self._get_result_meta(result=self.encode(result), state=state,
traceback=traceback, request=request,
format_date=False)
# Add the _id for mongodb
meta['_id'] = task_id
try:
self.collection.replace_one({'_id': task_id}, meta, upsert=True)
except InvalidDocument as exc:
raise EncodeError(exc)
return result
def _get_task_meta_for(self, task_id):
"""Get task meta-data for a task by id."""
obj = self.collection.find_one({'_id': task_id})
if obj:
if self.app.conf.find_value_for_key('extended', 'result'):
return self.meta_from_decoded({
'name': obj['name'],
'args': obj['args'],
'task_id': obj['_id'],
'queue': obj['queue'],
'kwargs': obj['kwargs'],
'status': obj['status'],
'worker': obj['worker'],
'retries': obj['retries'],
'children': obj['children'],
'date_done': obj['date_done'],
'traceback': obj['traceback'],
'result': self.decode(obj['result']),
})
return self.meta_from_decoded({
'task_id': obj['_id'],
'status': obj['status'],
'result': self.decode(obj['result']),
'date_done': obj['date_done'],
'traceback': obj['traceback'],
'children': obj['children'],
})
return {'status': states.PENDING, 'result': None}
def _save_group(self, group_id, result):
"""Save the group result."""
meta = {
'_id': group_id,
'result': self.encode([i.id for i in result]),
'date_done': datetime.utcnow(),
}
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
return result
def _restore_group(self, group_id):
"""Get the result for a group by id."""
obj = self.group_collection.find_one({'_id': group_id})
if obj:
return {
'task_id': obj['_id'],
'date_done': obj['date_done'],
'result': [
self.app.AsyncResult(task)
for task in self.decode(obj['result'])
],
}
def _delete_group(self, group_id):
"""Delete a group by id."""
self.group_collection.delete_one({'_id': group_id})
def _forget(self, task_id):
"""Remove result from MongoDB.
Raises:
            pymongo.errors.OperationFailure:
                if the task_id could not be removed.
        """
        # delete_one() uses acknowledged writes by default, so this waits
        # for a response from the server and raises OperationFailure if
        # the delete could not be completed.
self.collection.delete_one({'_id': task_id})
def cleanup(self):
"""Delete expired meta-data."""
if not self.expires:
return
self.collection.delete_many(
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
)
self.group_collection.delete_many(
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
)
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
return super().__reduce__(
args, dict(kwargs, expires=self.expires, url=self.url))
def _get_database(self):
conn = self._get_connection()
return conn[self.database_name]
@cached_property
def database(self):
"""Get database from MongoDB connection.
performs authentication if necessary.
"""
return self._get_database()
@cached_property
def collection(self):
"""Get the meta-data task collection."""
collection = self.database[self.taskmeta_collection]
# Ensure an index on date_done is there, if not process the index
# in the background. Once completed cleanup will be much faster
collection.create_index('date_done', background=True)
return collection
@cached_property
def group_collection(self):
"""Get the meta-data task collection."""
collection = self.database[self.groupmeta_collection]
# Ensure an index on date_done is there, if not process the index
# in the background. Once completed cleanup will be much faster
collection.create_index('date_done', background=True)
return collection
@cached_property
def expires_delta(self):
return timedelta(seconds=self.expires)
def as_uri(self, include_password=False):
"""Return the backend as an URI.
Arguments:
include_password (bool): Password censored if disabled.
"""
if not self.url:
return 'mongodb://'
if include_password:
return self.url
if ',' not in self.url:
return maybe_sanitize_url(self.url)
uri1, remainder = self.url.split(',', 1)
return ','.join([maybe_sanitize_url(uri1), remainder])

View File

@@ -0,0 +1,668 @@
"""Redis result store backend."""
import time
from contextlib import contextmanager
from functools import partial
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from urllib.parse import unquote
from kombu.utils.functional import retry_over_time
from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url, maybe_sanitize_url
from celery import states
from celery._state import task_join_will_block
from celery.canvas import maybe_signature
from celery.exceptions import BackendStoreError, ChordError, ImproperlyConfigured
from celery.result import GroupResult, allow_join_result
from celery.utils.functional import _regen, dictfilter
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
from .base import BaseKeyValueStoreBackend
try:
import redis.connection
from kombu.transport.redis import get_redis_error_classes
except ImportError:
redis = None
get_redis_error_classes = None
try:
import redis.sentinel
except ImportError:
pass
__all__ = ('RedisBackend', 'SentinelBackend')
E_REDIS_MISSING = """
You need to install the redis library in order to use \
the Redis result store backend.
"""
E_REDIS_SENTINEL_MISSING = """
You need to install the redis library with sentinel support \
in order to use the Redis result store backend.
"""
W_REDIS_SSL_CERT_OPTIONAL = """
Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \
celery might not validate the identity of the redis broker when connecting. \
This leaves you vulnerable to man in the middle attacks.
"""
W_REDIS_SSL_CERT_NONE = """
Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \
will not validate the identity of the redis broker when connecting. This \
leaves you vulnerable to man in the middle attacks.
"""
E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """
SSL connection parameters have been provided but the specified URL scheme \
is redis://. A Redis SSL connection URL should use the scheme rediss://.
"""
E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """
A rediss:// URL must have parameter ssl_cert_reqs and this must be set to \
CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE
"""
E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'
E_RETRY_LIMIT_EXCEEDED = """
Retry limit exceeded while trying to reconnect to the Celery redis result \
store backend. The Celery application must be restarted.
"""
logger = get_logger(__name__)
class ResultConsumer(BaseResultConsumer):
_pubsub = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._get_key_for_task = self.backend.get_key_for_task
self._decode_result = self.backend.decode_result
self._ensure = self.backend.ensure
self._connection_errors = self.backend.connection_errors
self.subscribed_to = set()
def on_after_fork(self):
try:
self.backend.client.connection_pool.reset()
if self._pubsub is not None:
self._pubsub.close()
except KeyError as e:
logger.warning(str(e))
super().on_after_fork()
def _reconnect_pubsub(self):
self._pubsub = None
self.backend.client.connection_pool.reset()
# task state might have changed when the connection was down so we
# retrieve meta for all subscribed tasks before going into pubsub mode
if self.subscribed_to:
metas = self.backend.client.mget(self.subscribed_to)
metas = [meta for meta in metas if meta]
for meta in metas:
self.on_state_change(self._decode_result(meta), None)
self._pubsub = self.backend.client.pubsub(
ignore_subscribe_messages=True,
)
        # subscribed_to may be empty after on_state_change
if self.subscribed_to:
self._pubsub.subscribe(*self.subscribed_to)
else:
self._pubsub.connection = self._pubsub.connection_pool.get_connection(
'pubsub', self._pubsub.shard_hint
)
            # Even if there is nothing to subscribe to, we should not lose the
            # callback after connecting. The on_connect callback will
            # re-subscribe to any channels we previously subscribed to.
self._pubsub.connection.register_connect_callback(self._pubsub.on_connect)
@contextmanager
def reconnect_on_error(self):
try:
yield
except self._connection_errors:
try:
self._ensure(self._reconnect_pubsub, ())
except self._connection_errors:
logger.critical(E_RETRY_LIMIT_EXCEEDED)
raise
def _maybe_cancel_ready_task(self, meta):
if meta['status'] in states.READY_STATES:
self.cancel_for(meta['task_id'])
def on_state_change(self, meta, message):
super().on_state_change(meta, message)
self._maybe_cancel_ready_task(meta)
def start(self, initial_task_id, **kwargs):
self._pubsub = self.backend.client.pubsub(
ignore_subscribe_messages=True,
)
self._consume_from(initial_task_id)
def on_wait_for_pending(self, result, **kwargs):
for meta in result._iter_meta(**kwargs):
if meta is not None:
self.on_state_change(meta, None)
def stop(self):
if self._pubsub is not None:
self._pubsub.close()
def drain_events(self, timeout=None):
if self._pubsub:
with self.reconnect_on_error():
message = self._pubsub.get_message(timeout=timeout)
if message and message['type'] == 'message':
self.on_state_change(self._decode_result(message['data']), message)
elif timeout:
time.sleep(timeout)
def consume_from(self, task_id):
if self._pubsub is None:
return self.start(task_id)
self._consume_from(task_id)
def _consume_from(self, task_id):
key = self._get_key_for_task(task_id)
if key not in self.subscribed_to:
self.subscribed_to.add(key)
with self.reconnect_on_error():
self._pubsub.subscribe(key)
def cancel_for(self, task_id):
key = self._get_key_for_task(task_id)
self.subscribed_to.discard(key)
if self._pubsub:
with self.reconnect_on_error():
self._pubsub.unsubscribe(key)
class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
"""Redis task result store.
It makes use of the following commands:
GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX
"""
ResultConsumer = ResultConsumer
#: :pypi:`redis` client module.
redis = redis
connection_class_ssl = redis.SSLConnection if redis else None
#: Maximum number of connections in the pool.
max_connections = None
supports_autoexpire = True
supports_native_join = True
#: Maximal length of string value in Redis.
#: 512 MB - https://redis.io/topics/data-types
_MAX_STR_VALUE_SIZE = 536870912
def __init__(self, host=None, port=None, db=None, password=None,
max_connections=None, url=None,
connection_pool=None, **kwargs):
super().__init__(expires_type=int, **kwargs)
_get = self.app.conf.get
if self.redis is None:
raise ImproperlyConfigured(E_REDIS_MISSING.strip())
if host and '://' in host:
url, host = host, None
self.max_connections = (
max_connections or
_get('redis_max_connections') or
self.max_connections)
self._ConnectionPool = connection_pool
socket_timeout = _get('redis_socket_timeout')
socket_connect_timeout = _get('redis_socket_connect_timeout')
retry_on_timeout = _get('redis_retry_on_timeout')
socket_keepalive = _get('redis_socket_keepalive')
health_check_interval = _get('redis_backend_health_check_interval')
self.connparams = {
'host': _get('redis_host') or 'localhost',
'port': _get('redis_port') or 6379,
'db': _get('redis_db') or 0,
'password': _get('redis_password'),
'max_connections': self.max_connections,
'socket_timeout': socket_timeout and float(socket_timeout),
'retry_on_timeout': retry_on_timeout or False,
'socket_connect_timeout':
socket_connect_timeout and float(socket_connect_timeout),
}
username = _get('redis_username')
if username:
# We're extra careful to avoid including this configuration value
# if it wasn't specified since older versions of py-redis
# don't support specifying a username.
            # Only Redis >= 6.0 supports username/password authentication.
# TODO: Include this in connparams' definition once we drop
# support for py-redis<3.4.0.
self.connparams['username'] = username
if health_check_interval:
self.connparams["health_check_interval"] = health_check_interval
# absent in redis.connection.UnixDomainSocketConnection
if socket_keepalive:
self.connparams['socket_keepalive'] = socket_keepalive
# "redis_backend_use_ssl" must be a dict with the keys:
# 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile'
# (the same as "broker_use_ssl")
ssl = _get('redis_backend_use_ssl')
if ssl:
self.connparams.update(ssl)
self.connparams['connection_class'] = self.connection_class_ssl
if url:
self.connparams = self._params_from_url(url, self.connparams)
# If we've received SSL parameters via query string or the
# redis_backend_use_ssl dict, check ssl_cert_reqs is valid. If set
# via query string ssl_cert_reqs will be a string so convert it here
if ('connection_class' in self.connparams and
issubclass(self.connparams['connection_class'], redis.SSLConnection)):
ssl_cert_reqs_missing = 'MISSING'
ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED,
'CERT_OPTIONAL': CERT_OPTIONAL,
'CERT_NONE': CERT_NONE,
'required': CERT_REQUIRED,
'optional': CERT_OPTIONAL,
'none': CERT_NONE}
ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing)
ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs)
if ssl_cert_reqs not in ssl_string_to_constant.values():
raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID)
if ssl_cert_reqs == CERT_OPTIONAL:
logger.warning(W_REDIS_SSL_CERT_OPTIONAL)
elif ssl_cert_reqs == CERT_NONE:
logger.warning(W_REDIS_SSL_CERT_NONE)
self.connparams['ssl_cert_reqs'] = ssl_cert_reqs
self.url = url
self.connection_errors, self.channel_errors = (
get_redis_error_classes() if get_redis_error_classes
else ((), ()))
self.result_consumer = self.ResultConsumer(
self, self.app, self.accept,
self._pending_results, self._pending_messages,
)
def _params_from_url(self, url, defaults):
scheme, host, port, username, password, path, query = _parse_url(url)
connparams = dict(
defaults, **dictfilter({
'host': host, 'port': port, 'username': username,
'password': password, 'db': query.pop('virtual_host', None)})
)
if scheme == 'socket':
# use 'path' as path to the socket… in this case
# the database number should be given in 'query'
connparams.update({
'connection_class': self.redis.UnixDomainSocketConnection,
'path': '/' + path,
})
# host+port are invalid options when using this connection type.
connparams.pop('host', None)
connparams.pop('port', None)
connparams.pop('socket_connect_timeout')
else:
connparams['db'] = path
ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile',
'ssl_cert_reqs']
if scheme == 'redis':
# If connparams or query string contain ssl params, raise error
if (any(key in connparams for key in ssl_param_keys) or
any(key in query for key in ssl_param_keys)):
raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH)
if scheme == 'rediss':
connparams['connection_class'] = redis.SSLConnection
# The following parameters, if present in the URL, are encoded. We
# must add the decoded values to connparams.
for ssl_setting in ssl_param_keys:
ssl_val = query.pop(ssl_setting, None)
if ssl_val:
connparams[ssl_setting] = unquote(ssl_val)
        # db may be a string and start with '/' like in kombu.
db = connparams.get('db') or 0
db = db.strip('/') if isinstance(db, str) else db
connparams['db'] = int(db)
for key, value in query.items():
if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS:
query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key](
value
)
# Query parameters override other parameters
connparams.update(query)
return connparams
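    # Illustrative (assumed) examples of the URL parsing above:
    #   'redis://:s3cr3t@redis.example.com:6390/2'
    #       -> host='redis.example.com', port=6390, db=2, password='s3cr3t'
    #   'socket:///var/run/redis.sock?virtual_host=1'
    #       -> connection_class=UnixDomainSocketConnection,
    #          path='/var/run/redis.sock', db=1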
@cached_property
def retry_policy(self):
retry_policy = super().retry_policy
if "retry_policy" in self._transport_options:
retry_policy = retry_policy.copy()
retry_policy.update(self._transport_options['retry_policy'])
return retry_policy
def on_task_call(self, producer, task_id):
if not task_join_will_block():
self.result_consumer.consume_from(task_id)
def get(self, key):
return self.client.get(key)
def mget(self, keys):
return self.client.mget(keys)
def ensure(self, fun, args, **policy):
retry_policy = dict(self.retry_policy, **policy)
max_retries = retry_policy.get('max_retries')
return retry_over_time(
fun, self.connection_errors, args, {},
partial(self.on_connection_error, max_retries),
**retry_policy)
def on_connection_error(self, max_retries, exc, intervals, retries):
tts = next(intervals)
logger.error(
E_LOST.strip(),
retries, max_retries or 'Inf', humanize_seconds(tts, 'in '))
return tts
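    # Results are both stored (SET/SETEX) and published on a channel named
    # after the result key, so the pubsub-based ResultConsumer above is
    # notified of state changes without polling.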
def set(self, key, value, **retry_policy):
if isinstance(value, str) and len(value) > self._MAX_STR_VALUE_SIZE:
raise BackendStoreError('value too large for Redis backend')
return self.ensure(self._set, (key, value), **retry_policy)
def _set(self, key, value):
with self.client.pipeline() as pipe:
if self.expires:
pipe.setex(key, self.expires, value)
else:
pipe.set(key, value)
pipe.publish(key, value)
pipe.execute()
def forget(self, task_id):
super().forget(task_id)
self.result_consumer.cancel_for(task_id)
def delete(self, key):
self.client.delete(key)
def incr(self, key):
return self.client.incr(key)
def expire(self, key, value):
return self.client.expire(key, value)
def add_to_chord(self, group_id, result):
self.client.incr(self.get_key_for_group(group_id, '.t'), 1)
def _unpack_chord_result(self, tup, decode,
EXCEPTION_STATES=states.EXCEPTION_STATES,
PROPAGATE_STATES=states.PROPAGATE_STATES):
_, tid, state, retval = decode(tup)
if state in EXCEPTION_STATES:
retval = self.exception_to_python(retval)
if state in PROPAGATE_STATES:
raise ChordError(f'Dependency {tid} raised {retval!r}')
return retval
def set_chord_size(self, group_id, chord_size):
self.set(self.get_key_for_group(group_id, '.s'), chord_size)
def apply_chord(self, header_result_args, body, **kwargs):
# If any of the child results of this chord are complex (ie. group
# results themselves), we need to save `header_result` to ensure that
# the expected structure is retained when we finish the chord and pass
# the results onward to the body in `on_chord_part_return()`. We don't
# do this is all cases to retain an optimisation in the common case
# where a chord header is comprised of simple result objects.
if not isinstance(header_result_args[1], _regen):
header_result = self.app.GroupResult(*header_result_args)
if any(isinstance(nr, GroupResult) for nr in header_result.results):
header_result.save(backend=self)
@cached_property
def _chord_zset(self):
return self._transport_options.get('result_chord_ordered', True)
@cached_property
def _transport_options(self):
return self.app.conf.get('result_backend_transport_options', {})
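    # A minimal configuration sketch (placeholder values) for the transport
    # options consulted by ``retry_policy`` and ``_chord_zset`` above:
    #
    #   app.conf.result_backend_transport_options = {
    #       'retry_policy': {'max_retries': 5},
    #       'result_chord_ordered': True,  # count chord parts in a ZSET
    #   }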
def on_chord_part_return(self, request, state, result,
propagate=None, **kwargs):
app = self.app
tid, gid, group_index = request.id, request.group, request.group_index
if not gid or not tid:
return
if group_index is None:
group_index = '+inf'
client = self.client
jkey = self.get_key_for_group(gid, '.j')
tkey = self.get_key_for_group(gid, '.t')
skey = self.get_key_for_group(gid, '.s')
result = self.encode_result(result, state)
encoded = self.encode([1, tid, state, result])
with client.pipeline() as pipe:
pipeline = (
pipe.zadd(jkey, {encoded: group_index}).zcount(jkey, "-inf", "+inf")
if self._chord_zset
else pipe.rpush(jkey, encoded).llen(jkey)
).get(tkey).get(skey)
if self.expires:
pipeline = pipeline \
.expire(jkey, self.expires) \
.expire(tkey, self.expires) \
.expire(skey, self.expires)
_, readycount, totaldiff, chord_size_bytes = pipeline.execute()[:4]
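            # The four values unpacked above are, in order: the zadd/rpush
            # result, the ready count (zcount/llen), the group's total-diff
            # counter (.t) and the stored chord size (.s); [:4] skips any
            # expire results appended to the pipeline.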
totaldiff = int(totaldiff or 0)
if chord_size_bytes:
try:
callback = maybe_signature(request.chord, app=app)
total = int(chord_size_bytes) + totaldiff
if readycount == total:
header_result = GroupResult.restore(gid)
if header_result is not None:
# If we manage to restore a `GroupResult`, then it must
# have been complex and saved by `apply_chord()` earlier.
#
# Before we can join the `GroupResult`, it needs to be
# manually marked as ready to avoid blocking
header_result.on_ready()
# We'll `join()` it to get the results and ensure they are
# structured as intended rather than the flattened version
# we'd construct without any other information.
join_func = (
header_result.join_native
if header_result.supports_native_join
else header_result.join
)
with allow_join_result():
resl = join_func(
timeout=app.conf.result_chord_join_timeout,
propagate=True
)
else:
# Otherwise simply extract and decode the results we
# stashed along the way, which should be faster for large
# numbers of simple results in the chord header.
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
if self._chord_zset:
pipeline = pipe.zrange(jkey, 0, -1)
else:
pipeline = pipe.lrange(jkey, 0, total)
resl, = pipeline.execute()
resl = [unpack(tup, decode) for tup in resl]
try:
callback.delay(resl)
except Exception as exc: # pylint: disable=broad-except
logger.exception(
'Chord callback for %r raised: %r', request.group, exc)
return self.chord_error_from_stack(
callback,
ChordError(f'Callback error: {exc!r}'),
)
finally:
with client.pipeline() as pipe:
pipe \
.delete(jkey) \
.delete(tkey) \
.delete(skey) \
.execute()
except ChordError as exc:
logger.exception('Chord %r raised: %r', request.group, exc)
return self.chord_error_from_stack(callback, exc)
except Exception as exc: # pylint: disable=broad-except
logger.exception('Chord %r raised: %r', request.group, exc)
return self.chord_error_from_stack(
callback,
ChordError(f'Join error: {exc!r}'),
)
def _create_client(self, **params):
return self._get_client()(
connection_pool=self._get_pool(**params),
)
def _get_client(self):
return self.redis.StrictRedis
def _get_pool(self, **params):
return self.ConnectionPool(**params)
@property
def ConnectionPool(self):
if self._ConnectionPool is None:
self._ConnectionPool = self.redis.ConnectionPool
return self._ConnectionPool
@cached_property
def client(self):
return self._create_client(**self.connparams)
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
return super().__reduce__(
args, dict(kwargs, expires=self.expires, url=self.url))
if getattr(redis, "sentinel", None):
class SentinelManagedSSLConnection(
redis.sentinel.SentinelManagedConnection,
redis.SSLConnection):
"""Connect to a Redis server using Sentinel + TLS.
Use Sentinel to identify which Redis server is the current master
to connect to and when connecting to the Master server, use an
SSL Connection.
"""
class SentinelBackend(RedisBackend):
"""Redis sentinel task result store."""
# URL looks like `sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3`
_SERVER_URI_SEPARATOR = ";"
sentinel = getattr(redis, "sentinel", None)
connection_class_ssl = SentinelManagedSSLConnection if sentinel else None
def __init__(self, *args, **kwargs):
if self.sentinel is None:
raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip())
super().__init__(*args, **kwargs)
def as_uri(self, include_password=False):
"""Return the server addresses as URIs, sanitizing the password or not."""
# Allow superclass to do work if we don't need to force sanitization
if include_password:
return super().as_uri(
include_password=include_password,
)
# Otherwise we need to ensure that all components get sanitized rather
# by passing them one by one to the `kombu` helper
uri_chunks = (
maybe_sanitize_url(chunk)
for chunk in (self.url or "").split(self._SERVER_URI_SEPARATOR)
)
# Similar to the superclass, strip the trailing slash from URIs with
# all components empty other than the scheme
return self._SERVER_URI_SEPARATOR.join(
uri[:-1] if uri.endswith(":///") else uri
for uri in uri_chunks
)
def _params_from_url(self, url, defaults):
chunks = url.split(self._SERVER_URI_SEPARATOR)
connparams = dict(defaults, hosts=[])
for chunk in chunks:
data = super()._params_from_url(
url=chunk, defaults=defaults)
connparams['hosts'].append(data)
for param in ("host", "port", "db", "password"):
connparams.pop(param)
# Adding db/password in connparams to connect to the correct instance
for param in ("db", "password"):
if connparams['hosts'] and param in connparams['hosts'][0]:
connparams[param] = connparams['hosts'][0].get(param)
return connparams
def _get_sentinel_instance(self, **params):
connparams = params.copy()
hosts = connparams.pop("hosts")
min_other_sentinels = self._transport_options.get("min_other_sentinels", 0)
sentinel_kwargs = self._transport_options.get("sentinel_kwargs", {})
sentinel_instance = self.sentinel.Sentinel(
[(cp['host'], cp['port']) for cp in hosts],
min_other_sentinels=min_other_sentinels,
sentinel_kwargs=sentinel_kwargs,
**connparams)
return sentinel_instance
def _get_pool(self, **params):
sentinel_instance = self._get_sentinel_instance(**params)
master_name = self._transport_options.get("master_name", None)
return sentinel_instance.master_for(
service_name=master_name,
redis_class=self._get_client(),
).connection_pool
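    # A minimal configuration sketch (placeholder values) for the Sentinel
    # backend defined above:
    #
    #   app.conf.result_backend = (
    #       'sentinel://10.0.0.1:26379/0;sentinel://10.0.0.2:26379/0')
    #   app.conf.result_backend_transport_options = {
    #       'master_name': 'mymaster',
    #       'sentinel_kwargs': {'password': 's3cr3t'},
    #   }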

View File

@@ -0,0 +1,342 @@
"""The ``RPC`` result backend for AMQP brokers.
RPC-style result backend, using reply-to and one queue per client.
"""
import time
import kombu
from kombu.common import maybe_declare
from kombu.utils.compat import register_after_fork
from kombu.utils.objects import cached_property
from celery import states
from celery._state import current_task, task_join_will_block
from . import base
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
__all__ = ('BacklogLimitExceeded', 'RPCBackend')
E_NO_CHORD_SUPPORT = """
The "rpc" result backend does not support chords!
Note that a group chained with a task is also upgraded to be a chord,
as this pattern requires synchronization.
Result backends that support chords: Redis, Database, Memcached, and more.
"""
class BacklogLimitExceeded(Exception):
"""Too much state history to fast-forward."""
def _on_after_fork_cleanup_backend(backend):
backend._after_fork()
class ResultConsumer(BaseResultConsumer):
Consumer = kombu.Consumer
_connection = None
_consumer = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._create_binding = self.backend._create_binding
def start(self, initial_task_id, no_ack=True, **kwargs):
self._connection = self.app.connection()
initial_queue = self._create_binding(initial_task_id)
self._consumer = self.Consumer(
self._connection.default_channel, [initial_queue],
callbacks=[self.on_state_change], no_ack=no_ack,
accept=self.accept)
self._consumer.consume()
def drain_events(self, timeout=None):
if self._connection:
return self._connection.drain_events(timeout=timeout)
elif timeout:
time.sleep(timeout)
def stop(self):
try:
self._consumer.cancel()
finally:
self._connection.close()
def on_after_fork(self):
self._consumer = None
if self._connection is not None:
self._connection.collect()
self._connection = None
def consume_from(self, task_id):
if self._consumer is None:
return self.start(task_id)
queue = self._create_binding(task_id)
if not self._consumer.consuming_from(queue):
self._consumer.add_queue(queue)
self._consumer.consume()
def cancel_for(self, task_id):
if self._consumer:
self._consumer.cancel_by_queue(self._create_binding(task_id).name)
class RPCBackend(base.Backend, AsyncBackendMixin):
"""Base class for the RPC result backend."""
Exchange = kombu.Exchange
Producer = kombu.Producer
ResultConsumer = ResultConsumer
#: Exception raised when there are too many messages for a task id.
BacklogLimitExceeded = BacklogLimitExceeded
persistent = False
supports_autoexpire = True
supports_native_join = True
retry_policy = {
'max_retries': 20,
'interval_start': 0,
'interval_step': 1,
'interval_max': 1,
}
class Consumer(kombu.Consumer):
"""Consumer that requires manual declaration of queues."""
auto_declare = False
class Queue(kombu.Queue):
"""Queue that never caches declaration."""
can_cache_declaration = False
def __init__(self, app, connection=None, exchange=None, exchange_type=None,
persistent=None, serializer=None, auto_delete=True, **kwargs):
super().__init__(app, **kwargs)
conf = self.app.conf
self._connection = connection
self._out_of_band = {}
self.persistent = self.prepare_persistent(persistent)
self.delivery_mode = 2 if self.persistent else 1
exchange = exchange or conf.result_exchange
exchange_type = exchange_type or conf.result_exchange_type
self.exchange = self._create_exchange(
exchange, exchange_type, self.delivery_mode,
)
self.serializer = serializer or conf.result_serializer
self.auto_delete = auto_delete
self.result_consumer = self.ResultConsumer(
self, self.app, self.accept,
self._pending_results, self._pending_messages,
)
if register_after_fork is not None:
register_after_fork(self, _on_after_fork_cleanup_backend)
def _after_fork(self):
# clear state for child processes.
self._pending_results.clear()
self.result_consumer._after_fork()
def _create_exchange(self, name, type='direct', delivery_mode=2):
# uses direct to queue routing (anon exchange).
return self.Exchange(None)
def _create_binding(self, task_id):
"""Create new binding for task with id."""
# RPC backend caches the binding, as one queue is used for all tasks.
return self.binding
def ensure_chords_allowed(self):
raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())
def on_task_call(self, producer, task_id):
# Called every time a task is sent when using this backend.
# We declare the queue we receive replies on in advance of sending
# the message, but we skip this if running in the prefork pool
# (task_join_will_block), as we know the queue is already declared.
if not task_join_will_block():
maybe_declare(self.binding(producer.channel), retry=True)
def destination_for(self, task_id, request):
"""Get the destination for result by task id.
Returns:
Tuple[str, str]: tuple of ``(reply_to, correlation_id)``.
"""
# Backends didn't always receive the `request`, so we must still
# support old code that relies on current_task.
try:
request = request or current_task.request
except AttributeError:
raise RuntimeError(
f'RPC backend missing task request for {task_id!r}')
return request.reply_to, request.correlation_id or task_id
def on_reply_declare(self, task_id):
# Return value here is used as the `declare=` argument
# for Producer.publish.
# By default we don't have to declare anything when sending a result.
pass
def on_result_fulfilled(self, result):
# This usually cancels the queue after the result is received,
# but we don't have to cancel since we have one queue per process.
pass
def as_uri(self, include_password=True):
return 'rpc://'
def store_result(self, task_id, result, state,
traceback=None, request=None, **kwargs):
"""Send task return value and state."""
routing_key, correlation_id = self.destination_for(task_id, request)
if not routing_key:
return
with self.app.amqp.producer_pool.acquire(block=True) as producer:
producer.publish(
self._to_result(task_id, state, result, traceback, request),
exchange=self.exchange,
routing_key=routing_key,
correlation_id=correlation_id,
serializer=self.serializer,
retry=True, retry_policy=self.retry_policy,
declare=self.on_reply_declare(task_id),
delivery_mode=self.delivery_mode,
)
return result
def _to_result(self, task_id, state, result, traceback, request):
return {
'task_id': task_id,
'status': state,
'result': self.encode_result(result, state),
'traceback': traceback,
'children': self.current_task_children(request),
}
def on_out_of_band_result(self, task_id, message):
# Callback called when a reply for a task is received,
        # but we have no idea what to do with it.
# Since the result is not pending, we put it in a separate
# buffer: probably it will become pending later.
if self.result_consumer:
self.result_consumer.on_out_of_band_result(message)
self._out_of_band[task_id] = message
def get_task_meta(self, task_id, backlog_limit=1000):
buffered = self._out_of_band.pop(task_id, None)
if buffered:
return self._set_cache_by_message(task_id, buffered)
# Polling and using basic_get
latest_by_id = {}
prev = None
for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit):
tid = self._get_message_task_id(acc)
prev, latest_by_id[tid] = latest_by_id.get(tid), acc
if prev:
# backends aren't expected to keep history,
# so we delete everything except the most recent state.
prev.ack()
prev = None
latest = latest_by_id.pop(task_id, None)
for tid, msg in latest_by_id.items():
self.on_out_of_band_result(tid, msg)
if latest:
latest.requeue()
return self._set_cache_by_message(task_id, latest)
else:
# no new state, use previous
try:
return self._cache[task_id]
except KeyError:
# result probably pending.
return {'status': states.PENDING, 'result': None}
poll = get_task_meta # XXX compat
def _set_cache_by_message(self, task_id, message):
payload = self._cache[task_id] = self.meta_from_decoded(
message.payload)
return payload
def _slurp_from_queue(self, task_id, accept,
limit=1000, no_ack=False):
with self.app.pool.acquire_channel(block=True) as (_, channel):
binding = self._create_binding(task_id)(channel)
binding.declare()
for _ in range(limit):
msg = binding.get(accept=accept, no_ack=no_ack)
if not msg:
break
yield msg
else:
raise self.BacklogLimitExceeded(task_id)
def _get_message_task_id(self, message):
try:
# try property first so we don't have to deserialize
# the payload.
return message.properties['correlation_id']
except (AttributeError, KeyError):
# message sent by old Celery version, need to deserialize.
return message.payload['task_id']
def revive(self, channel):
pass
def reload_task_result(self, task_id):
raise NotImplementedError(
'reload_task_result is not supported by this backend.')
def reload_group_result(self, task_id):
"""Reload group result, even if it has been previously fetched."""
raise NotImplementedError(
'reload_group_result is not supported by this backend.')
def save_group(self, group_id, result):
raise NotImplementedError(
'save_group is not supported by this backend.')
def restore_group(self, group_id, cache=True):
raise NotImplementedError(
'restore_group is not supported by this backend.')
def delete_group(self, group_id):
raise NotImplementedError(
'delete_group is not supported by this backend.')
def __reduce__(self, args=(), kwargs=None):
kwargs = {} if not kwargs else kwargs
return super().__reduce__(args, dict(
kwargs,
connection=self._connection,
exchange=self.exchange.name,
exchange_type=self.exchange.type,
persistent=self.persistent,
serializer=self.serializer,
auto_delete=self.auto_delete,
expires=self.expires,
))
@property
def binding(self):
return self.Queue(
self.oid, self.exchange, self.oid,
durable=False,
auto_delete=True,
expires=self.expires,
)
@cached_property
def oid(self):
        # Cached here is the app's thread OID: the name of the queue we receive results on.
return self.app.thread_oid
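    # A minimal configuration sketch for this backend. Results travel over
    # the broker using one reply queue per client, so a result can only be
    # retrieved (once) by the client that sent the task:
    #
    #   app.conf.result_backend = 'rpc://'
    #   app.conf.result_persistent = False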

View File

@@ -0,0 +1,87 @@
"""s3 result store backend."""
from kombu.utils.encoding import bytes_to_str
from celery.exceptions import ImproperlyConfigured
from .base import KeyValueStoreBackend
try:
import boto3
import botocore
except ImportError:
boto3 = None
botocore = None
__all__ = ('S3Backend',)
class S3Backend(KeyValueStoreBackend):
"""An S3 task result store.
Raises:
celery.exceptions.ImproperlyConfigured:
if module :pypi:`boto3` is not available,
            if the :setting:`s3_access_key_id` or
            :setting:`s3_secret_access_key` are not set,
            or if the :setting:`s3_bucket` is not set.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not boto3 or not botocore:
            raise ImproperlyConfigured('You must install boto3 '
                                       'to use the s3 backend')
conf = self.app.conf
self.endpoint_url = conf.get('s3_endpoint_url', None)
self.aws_region = conf.get('s3_region', None)
self.aws_access_key_id = conf.get('s3_access_key_id', None)
self.aws_secret_access_key = conf.get('s3_secret_access_key', None)
self.bucket_name = conf.get('s3_bucket', None)
if not self.bucket_name:
raise ImproperlyConfigured('Missing bucket name')
self.base_path = conf.get('s3_base_path', None)
self._s3_resource = self._connect_to_s3()
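    # A minimal configuration sketch (placeholder values) for the settings
    # read above:
    #
    #   app.conf.result_backend = 's3'
    #   app.conf.s3_bucket = 'my-results-bucket'
    #   app.conf.s3_region = 'eu-west-1'
    #   app.conf.s3_base_path = 'celery/'
    #   app.conf.s3_access_key_id = '...'        # or rely on the default
    #   app.conf.s3_secret_access_key = '...'    # boto3 credential chain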
def _get_s3_object(self, key):
key_bucket_path = self.base_path + key if self.base_path else key
return self._s3_resource.Object(self.bucket_name, key_bucket_path)
def get(self, key):
key = bytes_to_str(key)
s3_object = self._get_s3_object(key)
try:
s3_object.load()
data = s3_object.get()['Body'].read()
return data if self.content_encoding == 'binary' else data.decode('utf-8')
except botocore.exceptions.ClientError as error:
if error.response['Error']['Code'] == "404":
return None
raise error
def set(self, key, value):
key = bytes_to_str(key)
s3_object = self._get_s3_object(key)
s3_object.put(Body=value)
def delete(self, key):
key = bytes_to_str(key)
s3_object = self._get_s3_object(key)
s3_object.delete()
def _connect_to_s3(self):
session = boto3.Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
region_name=self.aws_region
)
if session.get_credentials() is None:
raise ImproperlyConfigured('Missing aws s3 creds')
return session.resource('s3', endpoint_url=self.endpoint_url)

View File

@@ -0,0 +1,736 @@
"""The periodic task scheduler."""
import copy
import errno
import heapq
import os
import shelve
import sys
import time
import traceback
from calendar import timegm
from collections import namedtuple
from functools import total_ordering
from threading import Event, Thread
from billiard import ensure_multiprocessing
from billiard.common import reset_signals
from billiard.context import Process
from kombu.utils.functional import maybe_evaluate, reprcall
from kombu.utils.objects import cached_property
from . import __version__, platforms, signals
from .exceptions import reraise
from .schedules import crontab, maybe_schedule
from .utils.functional import is_numeric_value
from .utils.imports import load_extension_class_names, symbol_by_name
from .utils.log import get_logger, iter_open_logger_fds
from .utils.time import humanize_seconds, maybe_make_aware
__all__ = (
'SchedulingError', 'ScheduleEntry', 'Scheduler',
'PersistentScheduler', 'Service', 'EmbeddedService',
)
event_t = namedtuple('event_t', ('time', 'priority', 'entry'))
logger = get_logger(__name__)
debug, info, error, warning = (logger.debug, logger.info,
logger.error, logger.warning)
DEFAULT_MAX_INTERVAL = 300 # 5 minutes
class SchedulingError(Exception):
"""An error occurred while scheduling a task."""
class BeatLazyFunc:
"""A lazy function declared in 'beat_schedule' and called before sending to worker.
Example:
beat_schedule = {
'test-every-5-minutes': {
'task': 'test',
'schedule': 300,
'kwargs': {
"current": BeatCallBack(datetime.datetime.now)
}
}
}
"""
def __init__(self, func, *args, **kwargs):
self._func = func
self._func_params = {
"args": args,
"kwargs": kwargs
}
def __call__(self):
return self.delay()
def delay(self):
return self._func(*self._func_params["args"], **self._func_params["kwargs"])
@total_ordering
class ScheduleEntry:
"""An entry in the scheduler.
Arguments:
name (str): see :attr:`name`.
schedule (~celery.schedules.schedule): see :attr:`schedule`.
args (Tuple): see :attr:`args`.
kwargs (Dict): see :attr:`kwargs`.
options (Dict): see :attr:`options`.
last_run_at (~datetime.datetime): see :attr:`last_run_at`.
total_run_count (int): see :attr:`total_run_count`.
relative (bool): Is the time relative to when the server starts?
"""
#: The task name
name = None
#: The schedule (:class:`~celery.schedules.schedule`)
schedule = None
#: Positional arguments to apply.
args = None
#: Keyword arguments to apply.
kwargs = None
#: Task execution options.
options = None
#: The time and date of when this task was last scheduled.
last_run_at = None
#: Total number of times this task has been scheduled.
total_run_count = 0
def __init__(self, name=None, task=None, last_run_at=None,
total_run_count=None, schedule=None, args=(), kwargs=None,
options=None, relative=False, app=None):
self.app = app
self.name = name
self.task = task
self.args = args
self.kwargs = kwargs if kwargs else {}
self.options = options if options else {}
self.schedule = maybe_schedule(schedule, relative, app=self.app)
self.last_run_at = last_run_at or self.default_now()
self.total_run_count = total_run_count or 0
def default_now(self):
return self.schedule.now() if self.schedule else self.app.now()
_default_now = default_now # compat
def _next_instance(self, last_run_at=None):
"""Return new instance, with date and count fields updated."""
return self.__class__(**dict(
self,
last_run_at=last_run_at or self.default_now(),
total_run_count=self.total_run_count + 1,
))
__next__ = next = _next_instance # for 2to3
def __reduce__(self):
return self.__class__, (
self.name, self.task, self.last_run_at, self.total_run_count,
self.schedule, self.args, self.kwargs, self.options,
)
def update(self, other):
"""Update values from another entry.
Will only update "editable" fields:
``task``, ``schedule``, ``args``, ``kwargs``, ``options``.
"""
self.__dict__.update({
'task': other.task, 'schedule': other.schedule,
'args': other.args, 'kwargs': other.kwargs,
'options': other.options,
})
def is_due(self):
"""See :meth:`~celery.schedules.schedule.is_due`."""
return self.schedule.is_due(self.last_run_at)
def __iter__(self):
return iter(vars(self).items())
def __repr__(self):
        return '<{name}: {0.name} {call} {0.schedule}>'.format(
self,
call=reprcall(self.task, self.args or (), self.kwargs or {}),
name=type(self).__name__,
)
def __lt__(self, other):
if isinstance(other, ScheduleEntry):
# How the object is ordered doesn't really matter, as
# in the scheduler heap, the order is decided by the
# preceding members of the tuple ``(time, priority, entry)``.
#
# If all that's left to order on is the entry then it can
# just as well be random.
return id(self) < id(other)
return NotImplemented
def editable_fields_equal(self, other):
for attr in ('task', 'args', 'kwargs', 'options', 'schedule'):
if getattr(self, attr) != getattr(other, attr):
return False
return True
def __eq__(self, other):
"""Test schedule entries equality.
Will only compare "editable" fields:
``task``, ``schedule``, ``args``, ``kwargs``, ``options``.
"""
return self.editable_fields_equal(other)
def _evaluate_entry_args(entry_args):
if not entry_args:
return []
return [
v() if isinstance(v, BeatLazyFunc) else v
for v in entry_args
]
def _evaluate_entry_kwargs(entry_kwargs):
if not entry_kwargs:
return {}
return {
k: v() if isinstance(v, BeatLazyFunc) else v
for k, v in entry_kwargs.items()
}
class Scheduler:
"""Scheduler for periodic tasks.
The :program:`celery beat` program may instantiate this class
multiple times for introspection purposes, but then with the
``lazy`` argument set. It's important for subclasses to
be idempotent when this argument is set.
Arguments:
schedule (~celery.schedules.schedule): see :attr:`schedule`.
max_interval (int): see :attr:`max_interval`.
lazy (bool): Don't set up the schedule.
"""
Entry = ScheduleEntry
#: The schedule dict/shelve.
schedule = None
#: Maximum time to sleep between re-checking the schedule.
max_interval = DEFAULT_MAX_INTERVAL
#: How often to sync the schedule (3 minutes by default)
sync_every = 3 * 60
#: How many tasks can be called before a sync is forced.
sync_every_tasks = None
_last_sync = None
_tasks_since_sync = 0
logger = logger # compat
def __init__(self, app, schedule=None, max_interval=None,
Producer=None, lazy=False, sync_every_tasks=None, **kwargs):
self.app = app
self.data = maybe_evaluate({} if schedule is None else schedule)
self.max_interval = (max_interval or
app.conf.beat_max_loop_interval or
self.max_interval)
self.Producer = Producer or app.amqp.Producer
self._heap = None
self.old_schedulers = None
self.sync_every_tasks = (
app.conf.beat_sync_every if sync_every_tasks is None
else sync_every_tasks)
if not lazy:
self.setup_schedule()
def install_default_entries(self, data):
entries = {}
if self.app.conf.result_expires and \
not self.app.backend.supports_autoexpire:
if 'celery.backend_cleanup' not in data:
entries['celery.backend_cleanup'] = {
'task': 'celery.backend_cleanup',
'schedule': crontab('0', '4', '*'),
'options': {'expires': 12 * 3600}}
self.update_from_dict(entries)
def apply_entry(self, entry, producer=None):
info('Scheduler: Sending due task %s (%s)', entry.name, entry.task)
try:
result = self.apply_async(entry, producer=producer, advance=False)
except Exception as exc: # pylint: disable=broad-except
error('Message Error: %s\n%s',
exc, traceback.format_stack(), exc_info=True)
else:
if result and hasattr(result, 'id'):
debug('%s sent. id->%s', entry.task, result.id)
else:
debug('%s sent.', entry.task)
def adjust(self, n, drift=-0.010):
if n and n > 0:
return n + drift
return n
def is_due(self, entry):
return entry.is_due()
def _when(self, entry, next_time_to_run, mktime=timegm):
"""Return a utc timestamp, make sure heapq in correct order."""
adjust = self.adjust
as_now = maybe_make_aware(entry.default_now())
return (mktime(as_now.utctimetuple()) +
as_now.microsecond / 1e6 +
(adjust(next_time_to_run) or 0))
def populate_heap(self, event_t=event_t, heapify=heapq.heapify):
"""Populate the heap with the data contained in the schedule."""
priority = 5
self._heap = []
for entry in self.schedule.values():
is_due, next_call_delay = entry.is_due()
self._heap.append(event_t(
self._when(
entry,
0 if is_due else next_call_delay
) or 0,
priority, entry
))
heapify(self._heap)
    # pylint: disable=redefined-outer-name
def tick(self, event_t=event_t, min=min, heappop=heapq.heappop,
heappush=heapq.heappush):
"""Run a tick - one iteration of the scheduler.
Executes one due task per call.
Returns:
float: preferred delay in seconds for next call.
"""
adjust = self.adjust
max_interval = self.max_interval
if (self._heap is None or
not self.schedules_equal(self.old_schedulers, self.schedule)):
self.old_schedulers = copy.copy(self.schedule)
self.populate_heap()
H = self._heap
if not H:
return max_interval
event = H[0]
entry = event[2]
is_due, next_time_to_run = self.is_due(entry)
if is_due:
verify = heappop(H)
if verify is event:
next_entry = self.reserve(entry)
self.apply_entry(entry, producer=self.producer)
heappush(H, event_t(self._when(next_entry, next_time_to_run),
event[1], next_entry))
return 0
else:
heappush(H, verify)
return min(verify[0], max_interval)
adjusted_next_time_to_run = adjust(next_time_to_run)
return min(adjusted_next_time_to_run if is_numeric_value(adjusted_next_time_to_run) else max_interval,
max_interval)
def schedules_equal(self, old_schedules, new_schedules):
if old_schedules is new_schedules is None:
return True
if old_schedules is None or new_schedules is None:
return False
if set(old_schedules.keys()) != set(new_schedules.keys()):
return False
for name, old_entry in old_schedules.items():
new_entry = new_schedules.get(name)
if not new_entry:
return False
if new_entry != old_entry:
return False
return True
def should_sync(self):
return (
(not self._last_sync or
(time.monotonic() - self._last_sync) > self.sync_every) or
(self.sync_every_tasks and
self._tasks_since_sync >= self.sync_every_tasks)
)
def reserve(self, entry):
new_entry = self.schedule[entry.name] = next(entry)
return new_entry
def apply_async(self, entry, producer=None, advance=True, **kwargs):
# Update time-stamps and run counts before we actually execute,
# so we have that done if an exception is raised (doesn't schedule
# forever.)
entry = self.reserve(entry) if advance else entry
task = self.app.tasks.get(entry.task)
try:
entry_args = _evaluate_entry_args(entry.args)
entry_kwargs = _evaluate_entry_kwargs(entry.kwargs)
if task:
return task.apply_async(entry_args, entry_kwargs,
producer=producer,
**entry.options)
else:
return self.send_task(entry.task, entry_args, entry_kwargs,
producer=producer,
**entry.options)
except Exception as exc: # pylint: disable=broad-except
reraise(SchedulingError, SchedulingError(
"Couldn't apply scheduled task {0.name}: {exc}".format(
entry, exc=exc)), sys.exc_info()[2])
finally:
self._tasks_since_sync += 1
if self.should_sync():
self._do_sync()
def send_task(self, *args, **kwargs):
return self.app.send_task(*args, **kwargs)
def setup_schedule(self):
self.install_default_entries(self.data)
self.merge_inplace(self.app.conf.beat_schedule)
def _do_sync(self):
try:
debug('beat: Synchronizing schedule...')
self.sync()
finally:
self._last_sync = time.monotonic()
self._tasks_since_sync = 0
def sync(self):
pass
def close(self):
self.sync()
def add(self, **kwargs):
entry = self.Entry(app=self.app, **kwargs)
self.schedule[entry.name] = entry
return entry
def _maybe_entry(self, name, entry):
if isinstance(entry, self.Entry):
entry.app = self.app
return entry
return self.Entry(**dict(entry, name=name, app=self.app))
def update_from_dict(self, dict_):
self.schedule.update({
name: self._maybe_entry(name, entry)
for name, entry in dict_.items()
})
def merge_inplace(self, b):
schedule = self.schedule
A, B = set(schedule), set(b)
# Remove items from disk not in the schedule anymore.
for key in A ^ B:
schedule.pop(key, None)
# Update and add new items in the schedule
for key in B:
entry = self.Entry(**dict(b[key], name=key, app=self.app))
if schedule.get(key):
schedule[key].update(entry)
else:
schedule[key] = entry
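    # A minimal beat_schedule sketch (entry name and task path are
    # placeholders) of the mapping merged above:
    #
    #   app.conf.beat_schedule = {
    #       'add-every-30-seconds': {
    #           'task': 'proj.tasks.add',
    #           'schedule': 30.0,
    #           'args': (16, 16),
    #       },
    #   }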
def _ensure_connected(self):
# callback called for each retry while the connection
# can't be established.
def _error_handler(exc, interval):
error('beat: Connection error: %s. '
'Trying again in %s seconds...', exc, interval)
return self.connection.ensure_connection(
_error_handler, self.app.conf.broker_connection_max_retries
)
def get_schedule(self):
return self.data
def set_schedule(self, schedule):
self.data = schedule
schedule = property(get_schedule, set_schedule)
@cached_property
def connection(self):
return self.app.connection_for_write()
@cached_property
def producer(self):
return self.Producer(self._ensure_connected(), auto_declare=False)
@property
def info(self):
return ''
class PersistentScheduler(Scheduler):
"""Scheduler backed by :mod:`shelve` database."""
persistence = shelve
known_suffixes = ('', '.db', '.dat', '.bak', '.dir')
_store = None
def __init__(self, *args, **kwargs):
self.schedule_filename = kwargs.get('schedule_filename')
super().__init__(*args, **kwargs)
def _remove_db(self):
for suffix in self.known_suffixes:
with platforms.ignore_errno(errno.ENOENT):
os.remove(self.schedule_filename + suffix)
def _open_schedule(self):
return self.persistence.open(self.schedule_filename, writeback=True)
def _destroy_open_corrupted_schedule(self, exc):
error('Removing corrupted schedule file %r: %r',
self.schedule_filename, exc, exc_info=True)
self._remove_db()
return self._open_schedule()
def setup_schedule(self):
try:
self._store = self._open_schedule()
# In some cases there may be different errors from a storage
# backend for corrupted files. Example - DBPageNotFoundError
# exception from bsddb. In such case the file will be
# successfully opened but the error will be raised on first key
# retrieving.
self._store.keys()
except Exception as exc: # pylint: disable=broad-except
self._store = self._destroy_open_corrupted_schedule(exc)
self._create_schedule()
tz = self.app.conf.timezone
stored_tz = self._store.get('tz')
if stored_tz is not None and stored_tz != tz:
warning('Reset: Timezone changed from %r to %r', stored_tz, tz)
self._store.clear() # Timezone changed, reset db!
utc = self.app.conf.enable_utc
stored_utc = self._store.get('utc_enabled')
if stored_utc is not None and stored_utc != utc:
choices = {True: 'enabled', False: 'disabled'}
warning('Reset: UTC changed from %s to %s',
choices[stored_utc], choices[utc])
self._store.clear() # UTC setting changed, reset db!
entries = self._store.setdefault('entries', {})
self.merge_inplace(self.app.conf.beat_schedule)
self.install_default_entries(self.schedule)
self._store.update({
'__version__': __version__,
'tz': tz,
'utc_enabled': utc,
})
self.sync()
debug('Current schedule:\n' + '\n'.join(
repr(entry) for entry in entries.values()))
def _create_schedule(self):
for _ in (1, 2):
try:
self._store['entries']
except KeyError:
# new schedule db
try:
self._store['entries'] = {}
except KeyError as exc:
self._store = self._destroy_open_corrupted_schedule(exc)
continue
else:
if '__version__' not in self._store:
warning('DB Reset: Account for new __version__ field')
self._store.clear() # remove schedule at 2.2.2 upgrade.
elif 'tz' not in self._store:
warning('DB Reset: Account for new tz field')
self._store.clear() # remove schedule at 3.0.8 upgrade
elif 'utc_enabled' not in self._store:
warning('DB Reset: Account for new utc_enabled field')
self._store.clear() # remove schedule at 3.0.9 upgrade
break
def get_schedule(self):
return self._store['entries']
def set_schedule(self, schedule):
self._store['entries'] = schedule
schedule = property(get_schedule, set_schedule)
def sync(self):
if self._store is not None:
self._store.sync()
def close(self):
self.sync()
self._store.close()
@property
def info(self):
return f' . db -> {self.schedule_filename}'
class Service:
"""Celery periodic task service."""
scheduler_cls = PersistentScheduler
def __init__(self, app, max_interval=None, schedule_filename=None,
scheduler_cls=None):
self.app = app
self.max_interval = (max_interval or
app.conf.beat_max_loop_interval)
self.scheduler_cls = scheduler_cls or self.scheduler_cls
self.schedule_filename = (
schedule_filename or app.conf.beat_schedule_filename)
self._is_shutdown = Event()
self._is_stopped = Event()
def __reduce__(self):
return self.__class__, (self.max_interval, self.schedule_filename,
self.scheduler_cls, self.app)
def start(self, embedded_process=False):
info('beat: Starting...')
debug('beat: Ticking with max interval->%s',
humanize_seconds(self.scheduler.max_interval))
signals.beat_init.send(sender=self)
if embedded_process:
signals.beat_embedded_init.send(sender=self)
platforms.set_process_title('celery beat')
try:
while not self._is_shutdown.is_set():
interval = self.scheduler.tick()
if interval and interval > 0.0:
debug('beat: Waking up %s.',
humanize_seconds(interval, prefix='in '))
time.sleep(interval)
if self.scheduler.should_sync():
self.scheduler._do_sync()
except (KeyboardInterrupt, SystemExit):
self._is_shutdown.set()
finally:
self.sync()
def sync(self):
self.scheduler.close()
self._is_stopped.set()
def stop(self, wait=False):
info('beat: Shutting down...')
self._is_shutdown.set()
wait and self._is_stopped.wait() # block until shutdown done.
def get_scheduler(self, lazy=False,
extension_namespace='celery.beat_schedulers'):
filename = self.schedule_filename
aliases = dict(load_extension_class_names(extension_namespace))
return symbol_by_name(self.scheduler_cls, aliases=aliases)(
app=self.app,
schedule_filename=filename,
max_interval=self.max_interval,
lazy=lazy,
)
@cached_property
def scheduler(self):
return self.get_scheduler()
class _Threaded(Thread):
"""Embedded task scheduler using threading."""
def __init__(self, app, **kwargs):
super().__init__()
self.app = app
self.service = Service(app, **kwargs)
self.daemon = True
self.name = 'Beat'
def run(self):
self.app.set_current()
self.service.start()
def stop(self):
self.service.stop(wait=True)
try:
ensure_multiprocessing()
except NotImplementedError: # pragma: no cover
_Process = None
else:
class _Process(Process):
def __init__(self, app, **kwargs):
super().__init__()
self.app = app
self.service = Service(app, **kwargs)
self.name = 'Beat'
def run(self):
reset_signals(full=False)
platforms.close_open_fds([
sys.__stdin__, sys.__stdout__, sys.__stderr__,
] + list(iter_open_logger_fds()))
self.app.set_default()
self.app.set_current()
self.service.start(embedded_process=True)
def stop(self):
self.service.stop()
self.terminate()
def EmbeddedService(app, max_interval=None, **kwargs):
"""Return embedded clock service.
Arguments:
thread (bool): Run threaded instead of as a separate process.
Uses :mod:`multiprocessing` by default, if available.
"""
if kwargs.pop('thread', False) or _Process is None:
# Need short max interval to be able to stop thread
# in reasonable time.
return _Threaded(app, max_interval=1, **kwargs)
return _Process(app, max_interval=max_interval, **kwargs)
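# A minimal usage sketch (assumed ``app`` instance) for an in-process beat
# service running in a daemon thread:
#
#   beat = EmbeddedService(app, thread=True)
#   beat.start()   # runs Service.start() in a daemon thread
#   ...
#   beat.stop()    # blocks until the service has shut down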

View File

@@ -0,0 +1,312 @@
"""AMQP 0.9.1 REPL."""
import pprint
import click
from amqp import Connection, Message
from click_repl import register_repl
__all__ = ('amqp',)
from celery.bin.base import handle_preload_options
def dump_message(message):
if message is None:
return 'No messages in queue. basic.publish something.'
return {'body': message.body,
'properties': message.properties,
'delivery_info': message.delivery_info}
class AMQPContext:
def __init__(self, cli_context):
self.cli_context = cli_context
self.connection = self.cli_context.app.connection()
self.channel = None
self.reconnect()
@property
def app(self):
return self.cli_context.app
def respond(self, retval):
if isinstance(retval, str):
self.cli_context.echo(retval)
else:
self.cli_context.echo(pprint.pformat(retval))
def echo_error(self, exception):
self.cli_context.error(f'{self.cli_context.ERROR}: {exception}')
def echo_ok(self):
self.cli_context.echo(self.cli_context.OK)
def reconnect(self):
if self.connection:
self.connection.close()
else:
self.connection = self.cli_context.app.connection()
self.cli_context.echo(f'-> connecting to {self.connection.as_uri()}.')
try:
self.connection.connect()
except (ConnectionRefusedError, ConnectionResetError) as e:
self.echo_error(e)
else:
self.cli_context.secho('-> connected.', fg='green', bold=True)
self.channel = self.connection.default_channel
@click.group(invoke_without_command=True)
@click.pass_context
@handle_preload_options
def amqp(ctx):
"""AMQP Administration Shell.
Also works for non-AMQP transports (but not ones that
store declarations in memory).
"""
if not isinstance(ctx.obj, AMQPContext):
ctx.obj = AMQPContext(ctx.obj)
@amqp.command(name='exchange.declare')
@click.argument('exchange',
type=str)
@click.argument('type',
type=str)
@click.argument('passive',
type=bool,
default=False)
@click.argument('durable',
type=bool,
default=False)
@click.argument('auto_delete',
type=bool,
default=False)
@click.pass_obj
def exchange_declare(amqp_context, exchange, type, passive, durable,
auto_delete):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
amqp_context.channel.exchange_declare(exchange=exchange,
type=type,
passive=passive,
durable=durable,
auto_delete=auto_delete)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.echo_ok()
@amqp.command(name='exchange.delete')
@click.argument('exchange',
type=str)
@click.argument('if_unused',
type=bool)
@click.pass_obj
def exchange_delete(amqp_context, exchange, if_unused):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
amqp_context.channel.exchange_delete(exchange=exchange,
if_unused=if_unused)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.echo_ok()
@amqp.command(name='queue.bind')
@click.argument('queue',
type=str)
@click.argument('exchange',
type=str)
@click.argument('routing_key',
type=str)
@click.pass_obj
def queue_bind(amqp_context, queue, exchange, routing_key):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
amqp_context.channel.queue_bind(queue=queue,
exchange=exchange,
routing_key=routing_key)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.echo_ok()
@amqp.command(name='queue.declare')
@click.argument('queue',
type=str)
@click.argument('passive',
type=bool,
default=False)
@click.argument('durable',
type=bool,
default=False)
@click.argument('auto_delete',
type=bool,
default=False)
@click.pass_obj
def queue_declare(amqp_context, queue, passive, durable, auto_delete):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
retval = amqp_context.channel.queue_declare(queue=queue,
passive=passive,
durable=durable,
auto_delete=auto_delete)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.cli_context.secho(
'queue:{} messages:{} consumers:{}'.format(*retval),
fg='cyan', bold=True)
amqp_context.echo_ok()
@amqp.command(name='queue.delete')
@click.argument('queue',
type=str)
@click.argument('if_unused',
type=bool,
default=False)
@click.argument('if_empty',
type=bool,
default=False)
@click.pass_obj
def queue_delete(amqp_context, queue, if_unused, if_empty):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
retval = amqp_context.channel.queue_delete(queue=queue,
if_unused=if_unused,
if_empty=if_empty)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.cli_context.secho(
f'{retval} messages deleted.',
fg='cyan', bold=True)
amqp_context.echo_ok()
@amqp.command(name='queue.purge')
@click.argument('queue',
type=str)
@click.pass_obj
def queue_purge(amqp_context, queue):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
retval = amqp_context.channel.queue_purge(queue=queue)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.cli_context.secho(
f'{retval} messages deleted.',
fg='cyan', bold=True)
amqp_context.echo_ok()
@amqp.command(name='basic.get')
@click.argument('queue',
type=str)
@click.argument('no_ack',
type=bool,
default=False)
@click.pass_obj
def basic_get(amqp_context, queue, no_ack):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
message = amqp_context.channel.basic_get(queue, no_ack=no_ack)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.respond(dump_message(message))
amqp_context.echo_ok()
@amqp.command(name='basic.publish')
@click.argument('msg',
type=str)
@click.argument('exchange',
type=str)
@click.argument('routing_key',
type=str)
@click.argument('mandatory',
type=bool,
default=False)
@click.argument('immediate',
type=bool,
default=False)
@click.pass_obj
def basic_publish(amqp_context, msg, exchange, routing_key, mandatory,
immediate):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
# XXX Hack to fix Issue #2013
if isinstance(amqp_context.connection.connection, Connection):
msg = Message(msg)
try:
amqp_context.channel.basic_publish(msg,
exchange=exchange,
routing_key=routing_key,
mandatory=mandatory,
immediate=immediate)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.echo_ok()
@amqp.command(name='basic.ack')
@click.argument('delivery_tag',
type=int)
@click.pass_obj
def basic_ack(amqp_context, delivery_tag):
if amqp_context.channel is None:
amqp_context.echo_error('Not connected to broker. Please retry...')
amqp_context.reconnect()
else:
try:
amqp_context.channel.basic_ack(delivery_tag)
except Exception as e:
amqp_context.echo_error(e)
amqp_context.reconnect()
else:
amqp_context.echo_ok()
register_repl(amqp)
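For illustration only (assuming an AMQP broker is reachable and an app is importable as ``proj.app``), the REPL commands above map onto the same channel calls that can be issued directly:

from amqp import Message
from celery.bin.amqp import dump_message
from proj import app  # hypothetical Celery application

conn = app.connection()
conn.connect()
channel = conn.default_channel

# queue.declare / basic.publish / basic.get, exactly as typed in the REPL:
channel.queue_declare(queue='demo', durable=False, auto_delete=True)
channel.basic_publish(Message('hello world'), exchange='', routing_key='demo')
print(dump_message(channel.basic_get('demo', no_ack=True)))
conn.close()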


@@ -0,0 +1,287 @@
"""Click customizations for Celery."""
import json
import numbers
from collections import OrderedDict
from functools import update_wrapper
from pprint import pformat
import click
from click import ParamType
from kombu.utils.objects import cached_property
from celery._state import get_current_app
from celery.signals import user_preload_options
from celery.utils import text
from celery.utils.log import mlevel
from celery.utils.time import maybe_iso8601
try:
from pygments import highlight
from pygments.formatters import Terminal256Formatter
from pygments.lexers import PythonLexer
except ImportError:
def highlight(s, *args, **kwargs):
"""Place holder function in case pygments is missing."""
return s
LEXER = None
FORMATTER = None
else:
LEXER = PythonLexer()
FORMATTER = Terminal256Formatter()
class CLIContext:
"""Context Object for the CLI."""
def __init__(self, app, no_color, workdir, quiet=False):
"""Initialize the CLI context."""
self.app = app or get_current_app()
self.no_color = no_color
self.quiet = quiet
self.workdir = workdir
@cached_property
def OK(self):
return self.style("OK", fg="green", bold=True)
@cached_property
def ERROR(self):
return self.style("ERROR", fg="red", bold=True)
def style(self, message=None, **kwargs):
if self.no_color:
return message
else:
return click.style(message, **kwargs)
def secho(self, message=None, **kwargs):
if self.no_color:
kwargs['color'] = False
click.echo(message, **kwargs)
else:
click.secho(message, **kwargs)
def echo(self, message=None, **kwargs):
if self.no_color:
kwargs['color'] = False
click.echo(message, **kwargs)
else:
click.echo(message, **kwargs)
def error(self, message=None, **kwargs):
kwargs['err'] = True
if self.no_color:
kwargs['color'] = False
click.echo(message, **kwargs)
else:
click.secho(message, **kwargs)
def pretty(self, n):
if isinstance(n, list):
return self.OK, self.pretty_list(n)
if isinstance(n, dict):
if 'ok' in n or 'error' in n:
return self.pretty_dict_ok_error(n)
else:
s = json.dumps(n, sort_keys=True, indent=4)
if not self.no_color:
s = highlight(s, LEXER, FORMATTER)
return self.OK, s
if isinstance(n, str):
return self.OK, n
return self.OK, pformat(n)
def pretty_list(self, n):
if not n:
return '- empty -'
return '\n'.join(
f'{self.style("*", fg="white")} {item}' for item in n
)
def pretty_dict_ok_error(self, n):
try:
return (self.OK,
text.indent(self.pretty(n['ok'])[1], 4))
except KeyError:
pass
return (self.ERROR,
text.indent(self.pretty(n['error'])[1], 4))
def say_chat(self, direction, title, body='', show_body=False):
if direction == '<-' and self.quiet:
return
dirstr = not self.quiet and f'{self.style(direction, fg="white", bold=True)} ' or ''
self.echo(f'{dirstr} {title}')
if body and show_body:
self.echo(body)
def handle_preload_options(f):
"""Extract preload options and return a wrapped callable."""
def caller(ctx, *args, **kwargs):
app = ctx.obj.app
preload_options = [o.name for o in app.user_options.get('preload', [])]
if preload_options:
user_options = {
preload_option: kwargs[preload_option]
for preload_option in preload_options
}
user_preload_options.send(sender=f, app=app, options=user_options)
return f(ctx, *args, **kwargs)
return update_wrapper(caller, f)
class CeleryOption(click.Option):
"""Customized option for Celery."""
def get_default(self, ctx, *args, **kwargs):
if self.default_value_from_context:
self.default = ctx.obj[self.default_value_from_context]
return super().get_default(ctx, *args, **kwargs)
def __init__(self, *args, **kwargs):
"""Initialize a Celery option."""
self.help_group = kwargs.pop('help_group', None)
self.default_value_from_context = kwargs.pop('default_value_from_context', None)
super().__init__(*args, **kwargs)
class CeleryCommand(click.Command):
"""Customized command for Celery."""
def format_options(self, ctx, formatter):
"""Write all the options into the formatter if they exist."""
opts = OrderedDict()
for param in self.get_params(ctx):
rv = param.get_help_record(ctx)
if rv is not None:
if hasattr(param, 'help_group') and param.help_group:
opts.setdefault(str(param.help_group), []).append(rv)
else:
opts.setdefault('Options', []).append(rv)
for name, opts_group in opts.items():
with formatter.section(name):
formatter.write_dl(opts_group)
class CeleryDaemonCommand(CeleryCommand):
"""Daemon commands."""
def __init__(self, *args, **kwargs):
"""Initialize a Celery command with common daemon options."""
super().__init__(*args, **kwargs)
self.params.append(CeleryOption(('-f', '--logfile'), help_group="Daemonization Options",
help="Log destination; defaults to stderr"))
self.params.append(CeleryOption(('--pidfile',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--uid',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--gid',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--umask',), help_group="Daemonization Options"))
self.params.append(CeleryOption(('--executable',), help_group="Daemonization Options"))
class CommaSeparatedList(ParamType):
"""Comma separated list argument."""
name = "comma separated list"
def convert(self, value, param, ctx):
return text.str_to_list(value)
class JsonArray(ParamType):
"""JSON formatted array argument."""
name = "json array"
def convert(self, value, param, ctx):
if isinstance(value, list):
return value
try:
v = json.loads(value)
except ValueError as e:
self.fail(str(e))
if not isinstance(v, list):
self.fail(f"{value} was not an array")
return v
class JsonObject(ParamType):
"""JSON formatted object argument."""
name = "json object"
def convert(self, value, param, ctx):
if isinstance(value, dict):
return value
try:
v = json.loads(value)
except ValueError as e:
self.fail(str(e))
if not isinstance(v, dict):
self.fail(f"{value} was not an object")
return v
class ISO8601DateTime(ParamType):
"""ISO 8601 Date Time argument."""
name = "iso-86091"
def convert(self, value, param, ctx):
try:
return maybe_iso8601(value)
except (TypeError, ValueError) as e:
self.fail(e)
class ISO8601DateTimeOrFloat(ParamType):
"""ISO 8601 Date Time or float argument."""
name = "iso-86091 or float"
def convert(self, value, param, ctx):
try:
return float(value)
except (TypeError, ValueError):
pass
try:
return maybe_iso8601(value)
except (TypeError, ValueError) as e:
self.fail(e)
class LogLevel(click.Choice):
"""Log level option."""
def __init__(self):
"""Initialize the log level option with the relevant choices."""
super().__init__(('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', 'FATAL'))
def convert(self, value, param, ctx):
if isinstance(value, numbers.Integral):
return value
value = value.upper()
value = super().convert(value, param, ctx)
return mlevel(value)
JSON_ARRAY = JsonArray()
JSON_OBJECT = JsonObject()
ISO8601 = ISO8601DateTime()
ISO8601_OR_FLOAT = ISO8601DateTimeOrFloat()
LOG_LEVEL = LogLevel()
COMMA_SEPARATED_LIST = CommaSeparatedList()
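A quick sketch of how these parameter types behave when Click hands them raw command-line strings (values are illustrative; ``param`` and ``ctx`` may be ``None`` for a direct call):

JSON_ARRAY.convert('[1, 2, 3]', None, None)           # -> [1, 2, 3]
JSON_OBJECT.convert('{"retries": 3}', None, None)     # -> {'retries': 3}
COMMA_SEPARATED_LIST.convert('a,b,c', None, None)     # -> ['a', 'b', 'c']
ISO8601.convert('2009-11-04T21:55:28', None, None)    # -> datetime(2009, 11, 4, 21, 55, 28)
LOG_LEVEL.convert('info', None, None)                 # -> 20 (logging.INFO via mlevel)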


@@ -0,0 +1,72 @@
"""The :program:`celery beat` command."""
from functools import partial
import click
from celery.bin.base import LOG_LEVEL, CeleryDaemonCommand, CeleryOption, handle_preload_options
from celery.platforms import detached, maybe_drop_privileges
@click.command(cls=CeleryDaemonCommand, context_settings={
'allow_extra_args': True
})
@click.option('--detach',
cls=CeleryOption,
is_flag=True,
default=False,
help_group="Beat Options",
help="Detach and run in the background as a daemon.")
@click.option('-s',
'--schedule',
cls=CeleryOption,
callback=lambda ctx, _, value: value or ctx.obj.app.conf.beat_schedule_filename,
help_group="Beat Options",
help="Path to the schedule database."
" Defaults to `celerybeat-schedule`."
"The extension '.db' may be appended to the filename.")
@click.option('-S',
'--scheduler',
cls=CeleryOption,
callback=lambda ctx, _, value: value or ctx.obj.app.conf.beat_scheduler,
help_group="Beat Options",
help="Scheduler class to use.")
@click.option('--max-interval',
cls=CeleryOption,
type=int,
help_group="Beat Options",
help="Max seconds to sleep between schedule iterations.")
@click.option('-l',
'--loglevel',
default='WARNING',
cls=CeleryOption,
type=LOG_LEVEL,
help_group="Beat Options",
help="Logging level.")
@click.pass_context
@handle_preload_options
def beat(ctx, detach=False, logfile=None, pidfile=None, uid=None,
gid=None, umask=None, workdir=None, **kwargs):
"""Start the beat periodic task scheduler."""
app = ctx.obj.app
if ctx.args:
try:
app.config_from_cmdline(ctx.args)
except (KeyError, ValueError) as e:
# TODO: Improve the error messages
raise click.UsageError("Unable to parse extra configuration"
" from command line.\n"
f"Reason: {e}", ctx=ctx)
if not detach:
maybe_drop_privileges(uid=uid, gid=gid)
beat = partial(app.Beat,
logfile=logfile, pidfile=pidfile,
quiet=ctx.obj.quiet, **kwargs)
if detach:
with detached(logfile, pidfile, uid, gid, umask, workdir):
return beat().run()
else:
return beat().run()
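Roughly what the command builds once its options are resolved (the path and values below are placeholders, not defaults):

from proj import app  # hypothetical Celery application

beat = app.Beat(schedule='/tmp/celerybeat-schedule',
                max_interval=30, loglevel='INFO')
beat.run()  # blocks, running the scheduler loop until interrupted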


@@ -0,0 +1,71 @@
"""The ``celery call`` program used to send tasks from the command-line."""
import click
from celery.bin.base import (ISO8601, ISO8601_OR_FLOAT, JSON_ARRAY, JSON_OBJECT, CeleryCommand, CeleryOption,
handle_preload_options)
@click.command(cls=CeleryCommand)
@click.argument('name')
@click.option('-a',
'--args',
cls=CeleryOption,
type=JSON_ARRAY,
default='[]',
help_group="Calling Options",
help="Positional arguments.")
@click.option('-k',
'--kwargs',
cls=CeleryOption,
type=JSON_OBJECT,
default='{}',
help_group="Calling Options",
help="Keyword arguments.")
@click.option('--eta',
cls=CeleryOption,
type=ISO8601,
help_group="Calling Options",
help="scheduled time.")
@click.option('--countdown',
cls=CeleryOption,
type=float,
help_group="Calling Options",
help="eta in seconds from now.")
@click.option('--expires',
cls=CeleryOption,
type=ISO8601_OR_FLOAT,
help_group="Calling Options",
help="expiry time.")
@click.option('--serializer',
cls=CeleryOption,
default='json',
help_group="Calling Options",
help="task serializer.")
@click.option('--queue',
cls=CeleryOption,
help_group="Routing Options",
help="custom queue name.")
@click.option('--exchange',
cls=CeleryOption,
help_group="Routing Options",
help="custom exchange name.")
@click.option('--routing-key',
cls=CeleryOption,
help_group="Routing Options",
help="custom routing key.")
@click.pass_context
@handle_preload_options
def call(ctx, name, args, kwargs, eta, countdown, expires, serializer, queue, exchange, routing_key):
"""Call a task by name."""
task_id = ctx.obj.app.send_task(
name,
args=args, kwargs=kwargs,
countdown=countdown,
serializer=serializer,
queue=queue,
exchange=exchange,
routing_key=routing_key,
eta=eta,
expires=expires
).id
ctx.obj.echo(task_id)
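In essence the command is a thin wrapper over ``send_task``; a hedged sketch with a made-up task name:

from proj import app  # hypothetical Celery application

result = app.send_task(
    'proj.tasks.add',        # hypothetical task name
    args=[2, 2], kwargs={},
    countdown=10,            # run roughly ten seconds from now
    queue='celery',
)
print(result.id)             # the CLI prints exactly this id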


@@ -0,0 +1,236 @@
"""Celery Command Line Interface."""
import os
import pathlib
import sys
import traceback
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points
import click
import click.exceptions
from click.types import ParamType
from click_didyoumean import DYMGroup
from click_plugins import with_plugins
from celery import VERSION_BANNER
from celery.app.utils import find_app
from celery.bin.amqp import amqp
from celery.bin.base import CeleryCommand, CeleryOption, CLIContext
from celery.bin.beat import beat
from celery.bin.call import call
from celery.bin.control import control, inspect, status
from celery.bin.events import events
from celery.bin.graph import graph
from celery.bin.list import list_
from celery.bin.logtool import logtool
from celery.bin.migrate import migrate
from celery.bin.multi import multi
from celery.bin.purge import purge
from celery.bin.result import result
from celery.bin.shell import shell
from celery.bin.upgrade import upgrade
from celery.bin.worker import worker
UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND = click.style("""
Unable to load celery application.
The module {0} was not found.""", fg='red')
UNABLE_TO_LOAD_APP_ERROR_OCCURRED = click.style("""
Unable to load celery application.
While trying to load the module {0} the following error occurred:
{1}""", fg='red')
UNABLE_TO_LOAD_APP_APP_MISSING = click.style("""
Unable to load celery application.
{0}""")
class App(ParamType):
"""Application option."""
name = "application"
def convert(self, value, param, ctx):
try:
return find_app(value)
except ModuleNotFoundError as e:
if e.name != value:
exc = traceback.format_exc()
self.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
)
self.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name))
except AttributeError as e:
attribute_name = e.args[0].capitalize()
self.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name))
except Exception:
exc = traceback.format_exc()
self.fail(
UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc)
)
APP = App()
if sys.version_info >= (3, 10):
_PLUGINS = entry_points(group='celery.commands')
else:
try:
_PLUGINS = entry_points().get('celery.commands', [])
except AttributeError:
_PLUGINS = entry_points().select(group='celery.commands')
@with_plugins(_PLUGINS)
@click.group(cls=DYMGroup, invoke_without_command=True)
@click.option('-A',
'--app',
envvar='APP',
cls=CeleryOption,
type=APP,
help_group="Global Options")
@click.option('-b',
'--broker',
envvar='BROKER_URL',
cls=CeleryOption,
help_group="Global Options")
@click.option('--result-backend',
envvar='RESULT_BACKEND',
cls=CeleryOption,
help_group="Global Options")
@click.option('--loader',
envvar='LOADER',
cls=CeleryOption,
help_group="Global Options")
@click.option('--config',
envvar='CONFIG_MODULE',
cls=CeleryOption,
help_group="Global Options")
@click.option('--workdir',
cls=CeleryOption,
type=pathlib.Path,
callback=lambda _, __, wd: os.chdir(wd) if wd else None,
is_eager=True,
help_group="Global Options")
@click.option('-C',
'--no-color',
envvar='NO_COLOR',
is_flag=True,
cls=CeleryOption,
help_group="Global Options")
@click.option('-q',
'--quiet',
is_flag=True,
cls=CeleryOption,
help_group="Global Options")
@click.option('--version',
cls=CeleryOption,
is_flag=True,
help_group="Global Options")
@click.option('--skip-checks',
envvar='SKIP_CHECKS',
cls=CeleryOption,
is_flag=True,
help_group="Global Options",
help="Skip Django core checks on startup. Setting the SKIP_CHECKS environment "
"variable to any non-empty string will have the same effect.")
@click.pass_context
def celery(ctx, app, broker, result_backend, loader, config, workdir,
no_color, quiet, version, skip_checks):
"""Celery command entrypoint."""
if version:
click.echo(VERSION_BANNER)
ctx.exit()
elif ctx.invoked_subcommand is None:
click.echo(ctx.get_help())
ctx.exit()
if loader:
# Default app takes loader from this env (Issue #1066).
os.environ['CELERY_LOADER'] = loader
if broker:
os.environ['CELERY_BROKER_URL'] = broker
if result_backend:
os.environ['CELERY_RESULT_BACKEND'] = result_backend
if config:
os.environ['CELERY_CONFIG_MODULE'] = config
if skip_checks:
os.environ['CELERY_SKIP_CHECKS'] = 'true'
ctx.obj = CLIContext(app=app, no_color=no_color, workdir=workdir,
quiet=quiet)
# User options
worker.params.extend(ctx.obj.app.user_options.get('worker', []))
beat.params.extend(ctx.obj.app.user_options.get('beat', []))
events.params.extend(ctx.obj.app.user_options.get('events', []))
for command in celery.commands.values():
command.params.extend(ctx.obj.app.user_options.get('preload', []))
@celery.command(cls=CeleryCommand)
@click.pass_context
def report(ctx, **kwargs):
"""Shows information useful to include in bug-reports."""
app = ctx.obj.app
app.loader.import_default_modules()
ctx.obj.echo(app.bugreport())
celery.add_command(purge)
celery.add_command(call)
celery.add_command(beat)
celery.add_command(list_)
celery.add_command(result)
celery.add_command(migrate)
celery.add_command(status)
celery.add_command(worker)
celery.add_command(events)
celery.add_command(inspect)
celery.add_command(control)
celery.add_command(graph)
celery.add_command(upgrade)
celery.add_command(logtool)
celery.add_command(amqp)
celery.add_command(shell)
celery.add_command(multi)
# Monkey-patch click to display a custom error
# when -A or --app are used as sub-command options instead of as options
# of the global command.
previous_show_implementation = click.exceptions.NoSuchOption.show
WRONG_APP_OPTION_USAGE_MESSAGE = """You are using `{option_name}` as an option of the {info_name} sub-command:
celery {info_name} {option_name} celeryapp <...>
The support for this usage was removed in Celery 5.0. Instead you should use `{option_name}` as a global option:
celery {option_name} celeryapp {info_name} <...>"""
def _show(self, file=None):
if self.option_name in ('-A', '--app'):
self.ctx.obj.error(
WRONG_APP_OPTION_USAGE_MESSAGE.format(
option_name=self.option_name,
info_name=self.ctx.info_name),
fg='red'
)
previous_show_implementation(self, file=file)
click.exceptions.NoSuchOption.show = _show
def main() -> int:
"""Start celery umbrella command.
This function is the main entrypoint for the CLI.
:return: The exit code of the CLI.
"""
return celery(auto_envvar_prefix="CELERY")
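The console-script entry point simply calls ``main()``; a wrapper script might look like this (a sketch, not the packaged entry point itself):

import sys
from celery.bin.celery import main

if __name__ == '__main__':
    sys.exit(main())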


@@ -0,0 +1,203 @@
"""The ``celery control``, ``. inspect`` and ``. status`` programs."""
from functools import partial
import click
from kombu.utils.json import dumps
from celery.bin.base import COMMA_SEPARATED_LIST, CeleryCommand, CeleryOption, handle_preload_options
from celery.exceptions import CeleryCommandException
from celery.platforms import EX_UNAVAILABLE
from celery.utils import text
from celery.worker.control import Panel
def _say_remote_command_reply(ctx, replies, show_reply=False):
node = next(iter(replies)) # <-- take first.
reply = replies[node]
node = ctx.obj.style(f'{node}: ', fg='cyan', bold=True)
status, preply = ctx.obj.pretty(reply)
ctx.obj.say_chat('->', f'{node}{status}',
text.indent(preply, 4) if show_reply else '',
show_body=show_reply)
def _consume_arguments(meta, method, args):
i = 0
try:
for i, arg in enumerate(args):
try:
name, typ = meta.args[i]
except IndexError:
if meta.variadic:
break
raise click.UsageError(
'Command {!r} takes arguments: {}'.format(
method, meta.signature))
else:
yield name, typ(arg) if typ is not None else arg
finally:
args[:] = args[i:]
def _compile_arguments(action, args):
meta = Panel.meta[action]
arguments = {}
if meta.args:
arguments.update({
k: v for k, v in _consume_arguments(meta, action, args)
})
if meta.variadic:
arguments.update({meta.variadic: args})
return arguments
@click.command(cls=CeleryCommand)
@click.option('-t',
'--timeout',
cls=CeleryOption,
type=float,
default=1.0,
help_group='Remote Control Options',
help='Timeout in seconds waiting for reply.')
@click.option('-d',
'--destination',
cls=CeleryOption,
type=COMMA_SEPARATED_LIST,
help_group='Remote Control Options',
help='Comma separated list of destination node names.')
@click.option('-j',
'--json',
cls=CeleryOption,
is_flag=True,
help_group='Remote Control Options',
help='Use json as output format.')
@click.pass_context
@handle_preload_options
def status(ctx, timeout, destination, json, **kwargs):
"""Show list of workers that are online."""
callback = None if json else partial(_say_remote_command_reply, ctx)
replies = ctx.obj.app.control.inspect(timeout=timeout,
destination=destination,
callback=callback).ping()
if not replies:
raise CeleryCommandException(
message='No nodes replied within time constraint',
exit_code=EX_UNAVAILABLE
)
if json:
ctx.obj.echo(dumps(replies))
nodecount = len(replies)
if not kwargs.get('quiet', False):
ctx.obj.echo('\n{} {} online.'.format(
nodecount, text.pluralize(nodecount, 'node')))
@click.command(cls=CeleryCommand,
context_settings={'allow_extra_args': True})
@click.argument("action", type=click.Choice([
name for name, info in Panel.meta.items()
if info.type == 'inspect' and info.visible
]))
@click.option('-t',
'--timeout',
cls=CeleryOption,
type=float,
default=1.0,
help_group='Remote Control Options',
help='Timeout in seconds waiting for reply.')
@click.option('-d',
'--destination',
cls=CeleryOption,
type=COMMA_SEPARATED_LIST,
help_group='Remote Control Options',
help='Comma separated list of destination node names.')
@click.option('-j',
'--json',
cls=CeleryOption,
is_flag=True,
help_group='Remote Control Options',
help='Use json as output format.')
@click.pass_context
@handle_preload_options
def inspect(ctx, action, timeout, destination, json, **kwargs):
"""Inspect the worker at runtime.
Availability: RabbitMQ (AMQP) and Redis transports.
"""
callback = None if json else partial(_say_remote_command_reply, ctx,
show_reply=True)
arguments = _compile_arguments(action, ctx.args)
inspect = ctx.obj.app.control.inspect(timeout=timeout,
destination=destination,
callback=callback)
replies = inspect._request(action,
**arguments)
if not replies:
raise CeleryCommandException(
message='No nodes replied within time constraint',
exit_code=EX_UNAVAILABLE
)
if json:
ctx.obj.echo(dumps(replies))
return
nodecount = len(replies)
if not ctx.obj.quiet:
ctx.obj.echo('\n{} {} online.'.format(
nodecount, text.pluralize(nodecount, 'node')))
@click.command(cls=CeleryCommand,
context_settings={'allow_extra_args': True})
@click.argument("action", type=click.Choice([
name for name, info in Panel.meta.items()
if info.type == 'control' and info.visible
]))
@click.option('-t',
'--timeout',
cls=CeleryOption,
type=float,
default=1.0,
help_group='Remote Control Options',
help='Timeout in seconds waiting for reply.')
@click.option('-d',
'--destination',
cls=CeleryOption,
type=COMMA_SEPARATED_LIST,
help_group='Remote Control Options',
help='Comma separated list of destination node names.')
@click.option('-j',
'--json',
cls=CeleryOption,
is_flag=True,
help_group='Remote Control Options',
help='Use json as output format.')
@click.pass_context
@handle_preload_options
def control(ctx, action, timeout, destination, json):
"""Workers remote control.
Availability: RabbitMQ (AMQP), Redis, and MongoDB transports.
"""
callback = None if json else partial(_say_remote_command_reply, ctx,
show_reply=True)
args = ctx.args
arguments = _compile_arguments(action, args)
replies = ctx.obj.app.control.broadcast(action, timeout=timeout,
destination=destination,
callback=callback,
reply=True,
arguments=arguments)
if not replies:
raise CeleryCommandException(
message='No nodes replied within time constraint',
exit_code=EX_UNAVAILABLE
)
if json:
ctx.obj.echo(dumps(replies))
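The three commands above wrap the remote-control API, which can also be used directly; a sketch with illustrative names and timeouts:

from proj import app  # hypothetical Celery application

insp = app.control.inspect(timeout=1.0)
print(insp.ping())      # what `celery status` relies on
print(insp.active())    # one of the `celery inspect` actions

# A `celery control` action, broadcast to all workers:
app.control.broadcast('rate_limit', reply=True,
                      arguments={'task_name': 'proj.tasks.add',  # hypothetical task
                                 'rate_limit': '10/m'})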


@@ -0,0 +1,94 @@
"""The ``celery events`` program."""
import sys
from functools import partial
import click
from celery.bin.base import LOG_LEVEL, CeleryDaemonCommand, CeleryOption, handle_preload_options
from celery.platforms import detached, set_process_title, strargv
def _set_process_status(prog, info=''):
prog = '{}:{}'.format('celery events', prog)
info = f'{info} {strargv(sys.argv)}'
return set_process_title(prog, info=info)
def _run_evdump(app):
from celery.events.dumper import evdump
_set_process_status('dump')
return evdump(app=app)
def _run_evcam(camera, app, logfile=None, pidfile=None, uid=None,
gid=None, umask=None, workdir=None,
detach=False, **kwargs):
from celery.events.snapshot import evcam
_set_process_status('cam')
kwargs['app'] = app
cam = partial(evcam, camera,
logfile=logfile, pidfile=pidfile, **kwargs)
if detach:
with detached(logfile, pidfile, uid, gid, umask, workdir):
return cam()
else:
return cam()
def _run_evtop(app):
try:
from celery.events.cursesmon import evtop
_set_process_status('top')
return evtop(app=app)
except ModuleNotFoundError as e:
if e.name == '_curses':
# TODO: Improve this error message
raise click.UsageError("The curses module is required for this command.")
@click.command(cls=CeleryDaemonCommand)
@click.option('-d',
'--dump',
cls=CeleryOption,
is_flag=True,
help_group='Dumper')
@click.option('-c',
'--camera',
cls=CeleryOption,
help_group='Snapshot')
@click.option('-d',
'--detach',
cls=CeleryOption,
is_flag=True,
help_group='Snapshot')
@click.option('-F', '--frequency', '--freq',
type=float,
default=1.0,
cls=CeleryOption,
help_group='Snapshot')
@click.option('-r', '--maxrate',
cls=CeleryOption,
help_group='Snapshot')
@click.option('-l',
'--loglevel',
default='WARNING',
cls=CeleryOption,
type=LOG_LEVEL,
help_group="Snapshot",
help="Logging level.")
@click.pass_context
@handle_preload_options
def events(ctx, dump, camera, detach, frequency, maxrate, loglevel, **kwargs):
"""Event-stream utilities."""
app = ctx.obj.app
if dump:
return _run_evdump(app)
if camera:
return _run_evcam(camera, app=app, freq=frequency, maxrate=maxrate,
loglevel=loglevel,
detach=detach,
**kwargs)
return _run_evtop(app)
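The dump mode is just ``evdump``; the same event stream can be consumed directly (assuming a running broker and workers started with ``-E``):

from proj import app  # hypothetical Celery application
from celery.events.dumper import evdump

evdump(app=app)  # prints every event received on the event exchange until interrupted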


@@ -0,0 +1,197 @@
"""The ``celery graph`` command."""
import sys
from operator import itemgetter
import click
from celery.bin.base import CeleryCommand, handle_preload_options
from celery.utils.graph import DependencyGraph, GraphFormatter
@click.group()
@click.pass_context
@handle_preload_options
def graph(ctx):
"""The ``celery graph`` command."""
@graph.command(cls=CeleryCommand, context_settings={'allow_extra_args': True})
@click.pass_context
def bootsteps(ctx):
"""Display bootsteps graph."""
worker = ctx.obj.app.WorkController()
include = {arg.lower() for arg in ctx.args or ['worker', 'consumer']}
if 'worker' in include:
worker_graph = worker.blueprint.graph
if 'consumer' in include:
worker.blueprint.connect_with(worker.consumer.blueprint)
else:
worker_graph = worker.consumer.blueprint.graph
worker_graph.to_dot(sys.stdout)
@graph.command(cls=CeleryCommand, context_settings={'allow_extra_args': True})
@click.pass_context
def workers(ctx):
"""Display workers graph."""
def simplearg(arg):
return maybe_list(itemgetter(0, 2)(arg.partition(':')))
def maybe_list(l, sep=','):
return l[0], l[1].split(sep) if sep in l[1] else l[1]
args = dict(simplearg(arg) for arg in ctx.args)
generic = 'generic' in args
def generic_label(node):
return '{} ({}://)'.format(type(node).__name__,
node._label.split('://')[0])
class Node:
force_label = None
scheme = {}
def __init__(self, label, pos=None):
self._label = label
self.pos = pos
def label(self):
return self._label
def __str__(self):
return self.label()
class Thread(Node):
scheme = {
'fillcolor': 'lightcyan4',
'fontcolor': 'yellow',
'shape': 'oval',
'fontsize': 10,
'width': 0.3,
'color': 'black',
}
def __init__(self, label, **kwargs):
self.real_label = label
super().__init__(
label=f'thr-{next(tids)}',
pos=0,
)
class Formatter(GraphFormatter):
def label(self, obj):
return obj and obj.label()
def node(self, obj):
scheme = dict(obj.scheme) if obj.pos else obj.scheme
if isinstance(obj, Thread):
scheme['label'] = obj.real_label
return self.draw_node(
obj, dict(self.node_scheme, **scheme),
)
def terminal_node(self, obj):
return self.draw_node(
obj, dict(self.term_scheme, **obj.scheme),
)
def edge(self, a, b, **attrs):
if isinstance(a, Thread):
attrs.update(arrowhead='none', arrowtail='tee')
return self.draw_edge(a, b, self.edge_scheme, attrs)
def subscript(n):
        S = {'0': '₀', '1': '₁', '2': '₂', '3': '₃', '4': '₄',
             '5': '₅', '6': '₆', '7': '₇', '8': '₈', '9': '₉'}
return ''.join([S[i] for i in str(n)])
class Worker(Node):
pass
class Backend(Node):
scheme = {
'shape': 'folder',
'width': 2,
'height': 1,
'color': 'black',
'fillcolor': 'peachpuff3',
}
def label(self):
return generic_label(self) if generic else self._label
class Broker(Node):
scheme = {
'shape': 'circle',
'fillcolor': 'cadetblue3',
'color': 'cadetblue4',
'height': 1,
}
def label(self):
return generic_label(self) if generic else self._label
from itertools import count
tids = count(1)
Wmax = int(args.get('wmax', 4) or 0)
Tmax = int(args.get('tmax', 3) or 0)
def maybe_abbr(l, name, max=Wmax):
size = len(l)
abbr = max and size > max
if 'enumerate' in args:
l = [f'{name}{subscript(i + 1)}'
for i, obj in enumerate(l)]
if abbr:
l = l[0:max - 1] + [l[size - 1]]
            l[max - 2] = '{}⎨…{}⎬'.format(
                name[0], subscript(size - (max - 1)))
return l
app = ctx.obj.app
try:
workers = args['nodes']
threads = args.get('threads') or []
except KeyError:
replies = app.control.inspect().stats() or {}
workers, threads = [], []
for worker, reply in replies.items():
workers.append(worker)
threads.append(reply['pool']['max-concurrency'])
wlen = len(workers)
backend = args.get('backend', app.conf.result_backend)
threads_for = {}
workers = maybe_abbr(workers, 'Worker')
if Wmax and wlen > Wmax:
threads = threads[0:3] + [threads[-1]]
for i, threads in enumerate(threads):
threads_for[workers[i]] = maybe_abbr(
list(range(int(threads))), 'P', Tmax,
)
broker = Broker(args.get(
'broker', app.connection_for_read().as_uri()))
backend = Backend(backend) if backend else None
deps = DependencyGraph(formatter=Formatter())
deps.add_arc(broker)
if backend:
deps.add_arc(backend)
curworker = [0]
for i, worker in enumerate(workers):
worker = Worker(worker, pos=i)
deps.add_arc(worker)
deps.add_edge(worker, broker)
if backend:
deps.add_edge(worker, backend)
threads = threads_for.get(worker._label)
if threads:
for thread in threads:
thread = Thread(thread)
deps.add_arc(thread)
deps.add_edge(thread, worker)
curworker[0] += 1
deps.to_dot(sys.stdout)
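For the simpler ``bootsteps`` variant, the same graph can be written to a file and rendered with Graphviz (a sketch; ``app`` is assumed to exist):

from proj import app  # hypothetical Celery application

worker = app.WorkController()
with open('bootsteps.dot', 'w') as fh:
    worker.blueprint.graph.to_dot(fh)  # render with: dot -Tpng bootsteps.dot -o bootsteps.png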


@@ -0,0 +1,38 @@
"""The ``celery list bindings`` command, used to inspect queue bindings."""
import click
from celery.bin.base import CeleryCommand, handle_preload_options
@click.group(name="list")
@click.pass_context
@handle_preload_options
def list_(ctx):
"""Get info from broker.
Note:
For RabbitMQ the management plugin is required.
"""
@list_.command(cls=CeleryCommand)
@click.pass_context
def bindings(ctx):
"""Inspect queue bindings."""
# TODO: Consider using a table formatter for this command.
app = ctx.obj.app
with app.connection() as conn:
app.amqp.TaskConsumer(conn).declare()
try:
bindings = conn.manager.get_bindings()
except NotImplementedError:
raise click.UsageError('Your transport cannot list bindings.')
def fmt(q, e, r):
ctx.obj.echo(f'{q:<28} {e:<28} {r}')
fmt('Queue', 'Exchange', 'Routing Key')
fmt('-' * 16, '-' * 16, '-' * 16)
for b in bindings:
fmt(b['destination'], b['source'], b['routing_key'])


@@ -0,0 +1,157 @@
"""The ``celery logtool`` command."""
import re
from collections import Counter
from fileinput import FileInput
import click
from celery.bin.base import CeleryCommand, handle_preload_options
__all__ = ('logtool',)
RE_LOG_START = re.compile(r'^\[\d\d\d\d\-\d\d-\d\d ')
RE_TASK_RECEIVED = re.compile(r'.+?\] Received')
RE_TASK_READY = re.compile(r'.+?\] Task')
RE_TASK_INFO = re.compile(r'.+?([\w\.]+)\[(.+?)\].+')
RE_TASK_RESULT = re.compile(r'.+?[\w\.]+\[.+?\] (.+)')
REPORT_FORMAT = """
Report
======
Task total: {task[total]}
Task errors: {task[errors]}
Task success: {task[succeeded]}
Task completed: {task[completed]}
Tasks
=====
{task[types].format}
"""
class _task_counts(list):
@property
def format(self):
return '\n'.join('{}: {}'.format(*i) for i in self)
def task_info(line):
m = RE_TASK_INFO.match(line)
return m.groups()
class Audit:
def __init__(self, on_task_error=None, on_trace=None, on_debug=None):
self.ids = set()
self.names = {}
self.results = {}
self.ready = set()
self.task_types = Counter()
self.task_errors = 0
self.on_task_error = on_task_error
self.on_trace = on_trace
self.on_debug = on_debug
self.prev_line = None
def run(self, files):
for line in FileInput(files):
self.feed(line)
return self
def task_received(self, line, task_name, task_id):
self.names[task_id] = task_name
self.ids.add(task_id)
self.task_types[task_name] += 1
def task_ready(self, line, task_name, task_id, result):
self.ready.add(task_id)
self.results[task_id] = result
if 'succeeded' not in result:
self.task_error(line, task_name, task_id, result)
def task_error(self, line, task_name, task_id, result):
self.task_errors += 1
if self.on_task_error:
self.on_task_error(line, task_name, task_id, result)
def feed(self, line):
if RE_LOG_START.match(line):
if RE_TASK_RECEIVED.match(line):
task_name, task_id = task_info(line)
self.task_received(line, task_name, task_id)
elif RE_TASK_READY.match(line):
task_name, task_id = task_info(line)
result = RE_TASK_RESULT.match(line)
if result:
result, = result.groups()
self.task_ready(line, task_name, task_id, result)
else:
if self.on_debug:
self.on_debug(line)
self.prev_line = line
else:
if self.on_trace:
self.on_trace('\n'.join(filter(None, [self.prev_line, line])))
self.prev_line = None
def incomplete_tasks(self):
return self.ids ^ self.ready
def report(self):
return {
'task': {
'types': _task_counts(self.task_types.most_common()),
'total': len(self.ids),
'errors': self.task_errors,
'completed': len(self.ready),
'succeeded': len(self.ready) - self.task_errors,
}
}
@click.group()
@click.pass_context
@handle_preload_options
def logtool(ctx):
"""The ``celery logtool`` command."""
@logtool.command(cls=CeleryCommand)
@click.argument('files', nargs=-1)
@click.pass_context
def stats(ctx, files):
ctx.obj.echo(REPORT_FORMAT.format(
**Audit().run(files).report()
))
@logtool.command(cls=CeleryCommand)
@click.argument('files', nargs=-1)
@click.pass_context
def traces(ctx, files):
Audit(on_trace=ctx.obj.echo).run(files)
@logtool.command(cls=CeleryCommand)
@click.argument('files', nargs=-1)
@click.pass_context
def errors(ctx, files):
Audit(on_task_error=lambda line, *_: ctx.obj.echo(line)).run(files)
@logtool.command(cls=CeleryCommand)
@click.argument('files', nargs=-1)
@click.pass_context
def incomplete(ctx, files):
audit = Audit()
audit.run(files)
for task_id in audit.incomplete_tasks():
ctx.obj.echo(f'Did not complete: {task_id}')
@logtool.command(cls=CeleryCommand)
@click.argument('files', nargs=-1)
@click.pass_context
def debug(ctx, files):
Audit(on_debug=ctx.obj.echo).run(files)
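The ``Audit`` parser can also be driven directly, e.g. from a test or an ad-hoc script ('worker.log' below is a placeholder path):

audit = Audit(on_task_error=lambda line, *_: print(line)).run(['worker.log'])
report = audit.report()
print('{task[total]} tasks, {task[errors]} errors'.format(**report))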


@@ -0,0 +1,63 @@
"""The ``celery migrate`` command, used to filter and move messages."""
import click
from kombu import Connection
from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options
from celery.contrib.migrate import migrate_tasks
@click.command(cls=CeleryCommand)
@click.argument('source')
@click.argument('destination')
@click.option('-n',
'--limit',
cls=CeleryOption,
type=int,
help_group='Migration Options',
help='Number of tasks to consume.')
@click.option('-t',
'--timeout',
cls=CeleryOption,
type=float,
help_group='Migration Options',
help='Timeout in seconds waiting for tasks.')
@click.option('-a',
'--ack-messages',
cls=CeleryOption,
is_flag=True,
help_group='Migration Options',
help='Ack messages from source broker.')
@click.option('-T',
'--tasks',
cls=CeleryOption,
help_group='Migration Options',
help='List of task names to filter on.')
@click.option('-Q',
'--queues',
cls=CeleryOption,
help_group='Migration Options',
help='List of queues to migrate.')
@click.option('-F',
'--forever',
cls=CeleryOption,
is_flag=True,
help_group='Migration Options',
help='Continually migrate tasks until killed.')
@click.pass_context
@handle_preload_options
def migrate(ctx, source, destination, **kwargs):
"""Migrate tasks from one broker to another.
Warning:
This command is experimental, make sure you have a backup of
the tasks before you continue.
"""
# TODO: Use a progress bar
def on_migrate_task(state, body, message):
ctx.obj.echo(f"Migrating task {state.count}/{state.strtotal}: {body}")
migrate_tasks(Connection(source),
Connection(destination),
callback=on_migrate_task,
**kwargs)
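Stripped of the CLI layer, the command reduces to ``migrate_tasks`` (the broker URLs below are placeholders):

from kombu import Connection
from celery.contrib.migrate import migrate_tasks
from proj import app  # hypothetical Celery application

migrate_tasks(Connection('amqp://guest@old-broker//'),
              Connection('amqp://guest@new-broker//'),
              app=app, timeout=1.0)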


@@ -0,0 +1,480 @@
"""Start multiple worker instances from the command-line.
.. program:: celery multi
Examples
========
.. code-block:: console
$ # Single worker with explicit name and events enabled.
$ celery multi start Leslie -E
$ # Pidfiles and logfiles are stored in the current directory
$ # by default. Use --pidfile and --logfile argument to change
$ # this. The abbreviation %n will be expanded to the current
$ # node name.
$ celery multi start Leslie -E --pidfile=/var/run/celery/%n.pid
--logfile=/var/log/celery/%n%I.log
$ # You need to add the same arguments when you restart,
$ # as these aren't persisted anywhere.
$ celery multi restart Leslie -E --pidfile=/var/run/celery/%n.pid
--logfile=/var/log/celery/%n%I.log
$ # To stop the node, you need to specify the same pidfile.
$ celery multi stop Leslie --pidfile=/var/run/celery/%n.pid
$ # 3 workers, with 3 processes each
$ celery multi start 3 -c 3
celery worker -n celery1@myhost -c 3
celery worker -n celery2@myhost -c 3
celery worker -n celery3@myhost -c 3
$ # override name prefix when using range
$ celery multi start 3 --range-prefix=worker -c 3
celery worker -n worker1@myhost -c 3
celery worker -n worker2@myhost -c 3
celery worker -n worker3@myhost -c 3
$ # start 3 named workers
$ celery multi start image video data -c 3
celery worker -n image@myhost -c 3
celery worker -n video@myhost -c 3
celery worker -n data@myhost -c 3
$ # specify custom hostname
$ celery multi start 2 --hostname=worker.example.com -c 3
celery worker -n celery1@worker.example.com -c 3
celery worker -n celery2@worker.example.com -c 3
$ # specify fully qualified nodenames
$ celery multi start foo@worker.example.com bar@worker.example.com -c 3
$ # fully qualified nodenames but using the current hostname
$ celery multi start foo@%h bar@%h
$ # Advanced example starting 10 workers in the background:
$ # * Three of the workers processes the images and video queue
$ # * Two of the workers processes the data queue with loglevel DEBUG
    $ # * the rest processes the 'default' queue.
$ celery multi start 10 -l INFO -Q:1-3 images,video -Q:4,5 data
-Q default -L:4,5 DEBUG
$ # You can show the commands necessary to start the workers with
$ # the 'show' command:
$ celery multi show 10 -l INFO -Q:1-3 images,video -Q:4,5 data
-Q default -L:4,5 DEBUG
$ # Additional options are added to each celery worker's command,
$ # but you can also modify the options for ranges of, or specific workers
$ # 3 workers: Two with 3 processes, and one with 10 processes.
$ celery multi start 3 -c 3 -c:1 10
celery worker -n celery1@myhost -c 10
celery worker -n celery2@myhost -c 3
celery worker -n celery3@myhost -c 3
$ # can also specify options for named workers
$ celery multi start image video data -c 3 -c:image 10
celery worker -n image@myhost -c 10
celery worker -n video@myhost -c 3
celery worker -n data@myhost -c 3
$ # ranges and lists of workers in options is also allowed:
$ # (-c:1-3 can also be written as -c:1,2,3)
$ celery multi start 5 -c 3 -c:1-3 10
celery worker -n celery1@myhost -c 10
celery worker -n celery2@myhost -c 10
celery worker -n celery3@myhost -c 10
celery worker -n celery4@myhost -c 3
celery worker -n celery5@myhost -c 3
$ # lists also works with named workers
$ celery multi start foo bar baz xuzzy -c 3 -c:foo,bar,baz 10
celery worker -n foo@myhost -c 10
celery worker -n bar@myhost -c 10
celery worker -n baz@myhost -c 10
celery worker -n xuzzy@myhost -c 3
"""
import os
import signal
import sys
from functools import wraps
import click
from kombu.utils.objects import cached_property
from celery import VERSION_BANNER
from celery.apps.multi import Cluster, MultiParser, NamespacedOptionParser
from celery.bin.base import CeleryCommand, handle_preload_options
from celery.platforms import EX_FAILURE, EX_OK, signals
from celery.utils import term
from celery.utils.text import pluralize
__all__ = ('MultiTool',)
USAGE = """\
usage: {prog_name} start <node1 node2 nodeN|range> [worker options]
{prog_name} stop <n1 n2 nN|range> [-SIG (default: -TERM)]
{prog_name} restart <n1 n2 nN|range> [-SIG] [worker options]
{prog_name} kill <n1 n2 nN|range>
{prog_name} show <n1 n2 nN|range> [worker options]
{prog_name} get hostname <n1 n2 nN|range> [-qv] [worker options]
{prog_name} names <n1 n2 nN|range>
{prog_name} expand template <n1 n2 nN|range>
{prog_name} help
additional options (must appear after command name):
* --nosplash: Don't display program info.
* --quiet: Don't show as much output.
* --verbose: Show more output.
* --no-color: Don't display colors.
"""
def main():
sys.exit(MultiTool().execute_from_commandline(sys.argv))
def splash(fun):
@wraps(fun)
def _inner(self, *args, **kwargs):
self.splash()
return fun(self, *args, **kwargs)
return _inner
def using_cluster(fun):
@wraps(fun)
def _inner(self, *argv, **kwargs):
return fun(self, self.cluster_from_argv(argv), **kwargs)
return _inner
def using_cluster_and_sig(fun):
@wraps(fun)
def _inner(self, *argv, **kwargs):
p, cluster = self._cluster_from_argv(argv)
sig = self._find_sig_argument(p)
return fun(self, cluster, sig, **kwargs)
return _inner
class TermLogger:
splash_text = 'celery multi v{version}'
splash_context = {'version': VERSION_BANNER}
#: Final exit code.
retcode = 0
def setup_terminal(self, stdout, stderr,
nosplash=False, quiet=False, verbose=False,
no_color=False, **kwargs):
self.stdout = stdout or sys.stdout
self.stderr = stderr or sys.stderr
self.nosplash = nosplash
self.quiet = quiet
self.verbose = verbose
self.no_color = no_color
def ok(self, m, newline=True, file=None):
self.say(m, newline=newline, file=file)
return EX_OK
def say(self, m, newline=True, file=None):
print(m, file=file or self.stdout, end='\n' if newline else '')
def carp(self, m, newline=True, file=None):
return self.say(m, newline, file or self.stderr)
def error(self, msg=None):
if msg:
self.carp(msg)
self.usage()
return EX_FAILURE
def info(self, msg, newline=True):
if self.verbose:
self.note(msg, newline=newline)
def note(self, msg, newline=True):
if not self.quiet:
self.say(str(msg), newline=newline)
@splash
def usage(self):
self.say(USAGE.format(prog_name=self.prog_name))
def splash(self):
if not self.nosplash:
self.note(self.colored.cyan(
self.splash_text.format(**self.splash_context)))
@cached_property
def colored(self):
return term.colored(enabled=not self.no_color)
class MultiTool(TermLogger):
"""The ``celery multi`` program."""
MultiParser = MultiParser
OptionParser = NamespacedOptionParser
reserved_options = [
('--nosplash', 'nosplash'),
('--quiet', 'quiet'),
('-q', 'quiet'),
('--verbose', 'verbose'),
('--no-color', 'no_color'),
]
def __init__(self, env=None, cmd=None,
fh=None, stdout=None, stderr=None, **kwargs):
# fh is an old alias to stdout.
self.env = env
self.cmd = cmd
self.setup_terminal(stdout or fh, stderr, **kwargs)
self.fh = self.stdout
self.prog_name = 'celery multi'
self.commands = {
'start': self.start,
'show': self.show,
'stop': self.stop,
'stopwait': self.stopwait,
'stop_verify': self.stopwait, # compat alias
'restart': self.restart,
'kill': self.kill,
'names': self.names,
'expand': self.expand,
'get': self.get,
'help': self.help,
}
def execute_from_commandline(self, argv, cmd=None):
# Reserve the --nosplash|--quiet|-q/--verbose options.
argv = self._handle_reserved_options(argv)
self.cmd = cmd if cmd is not None else self.cmd
self.prog_name = os.path.basename(argv.pop(0))
if not self.validate_arguments(argv):
return self.error()
return self.call_command(argv[0], argv[1:])
def validate_arguments(self, argv):
return argv and argv[0][0] != '-'
def call_command(self, command, argv):
try:
return self.commands[command](*argv) or EX_OK
except KeyError:
return self.error(f'Invalid command: {command}')
def _handle_reserved_options(self, argv):
argv = list(argv) # don't modify callers argv.
for arg, attr in self.reserved_options:
if arg in argv:
setattr(self, attr, bool(argv.pop(argv.index(arg))))
return argv
@splash
@using_cluster
def start(self, cluster):
self.note('> Starting nodes...')
return int(any(cluster.start()))
@splash
@using_cluster_and_sig
def stop(self, cluster, sig, **kwargs):
return cluster.stop(sig=sig, **kwargs)
@splash
@using_cluster_and_sig
def stopwait(self, cluster, sig, **kwargs):
return cluster.stopwait(sig=sig, **kwargs)
stop_verify = stopwait # compat
@splash
@using_cluster_and_sig
def restart(self, cluster, sig, **kwargs):
return int(any(cluster.restart(sig=sig, **kwargs)))
@using_cluster
def names(self, cluster):
self.say('\n'.join(n.name for n in cluster))
def get(self, wanted, *argv):
try:
node = self.cluster_from_argv(argv).find(wanted)
except KeyError:
return EX_FAILURE
else:
return self.ok(' '.join(node.argv))
@using_cluster
def show(self, cluster):
return self.ok('\n'.join(
' '.join(node.argv_with_executable)
for node in cluster
))
@splash
@using_cluster
def kill(self, cluster):
return cluster.kill()
def expand(self, template, *argv):
return self.ok('\n'.join(
node.expander(template)
for node in self.cluster_from_argv(argv)
))
def help(self, *argv):
self.say(__doc__)
def _find_sig_argument(self, p, default=signal.SIGTERM):
args = p.args[len(p.values):]
for arg in reversed(args):
if len(arg) == 2 and arg[0] == '-':
try:
return int(arg[1])
except ValueError:
pass
if arg[0] == '-':
try:
return signals.signum(arg[1:])
except (AttributeError, TypeError):
pass
return default
def _nodes_from_argv(self, argv, cmd=None):
cmd = cmd if cmd is not None else self.cmd
p = self.OptionParser(argv)
p.parse()
return p, self.MultiParser(cmd=cmd).parse(p)
def cluster_from_argv(self, argv, cmd=None):
_, cluster = self._cluster_from_argv(argv, cmd=cmd)
return cluster
def _cluster_from_argv(self, argv, cmd=None):
p, nodes = self._nodes_from_argv(argv, cmd=cmd)
return p, self.Cluster(list(nodes), cmd=cmd)
def Cluster(self, nodes, cmd=None):
return Cluster(
nodes,
cmd=cmd,
env=self.env,
on_stopping_preamble=self.on_stopping_preamble,
on_send_signal=self.on_send_signal,
on_still_waiting_for=self.on_still_waiting_for,
on_still_waiting_progress=self.on_still_waiting_progress,
on_still_waiting_end=self.on_still_waiting_end,
on_node_start=self.on_node_start,
on_node_restart=self.on_node_restart,
on_node_shutdown_ok=self.on_node_shutdown_ok,
on_node_status=self.on_node_status,
on_node_signal_dead=self.on_node_signal_dead,
on_node_signal=self.on_node_signal,
on_node_down=self.on_node_down,
on_child_spawn=self.on_child_spawn,
on_child_signalled=self.on_child_signalled,
on_child_failure=self.on_child_failure,
)
def on_stopping_preamble(self, nodes):
self.note(self.colored.blue('> Stopping nodes...'))
def on_send_signal(self, node, sig):
self.note('\t> {0.name}: {1} -> {0.pid}'.format(node, sig))
def on_still_waiting_for(self, nodes):
num_left = len(nodes)
if num_left:
self.note(self.colored.blue(
'> Waiting for {} {} -> {}...'.format(
num_left, pluralize(num_left, 'node'),
', '.join(str(node.pid) for node in nodes)),
), newline=False)
def on_still_waiting_progress(self, nodes):
self.note('.', newline=False)
def on_still_waiting_end(self):
self.note('')
def on_node_signal_dead(self, node):
self.note(
'Could not signal {0.name} ({0.pid}): No such process'.format(
node))
def on_node_start(self, node):
self.note(f'\t> {node.name}: ', newline=False)
def on_node_restart(self, node):
self.note(self.colored.blue(
f'> Restarting node {node.name}: '), newline=False)
def on_node_down(self, node):
self.note(f'> {node.name}: {self.DOWN}')
def on_node_shutdown_ok(self, node):
self.note(f'\n\t> {node.name}: {self.OK}')
def on_node_status(self, node, retval):
self.note(retval and self.FAILED or self.OK)
def on_node_signal(self, node, sig):
self.note('Sending {sig} to node {0.name} ({0.pid})'.format(
node, sig=sig))
def on_child_spawn(self, node, argstr, env):
self.info(f' {argstr}')
def on_child_signalled(self, node, signum):
self.note(f'* Child was terminated by signal {signum}')
def on_child_failure(self, node, retcode):
self.note(f'* Child terminated with exit code {retcode}')
@cached_property
def OK(self):
return str(self.colored.green('OK'))
@cached_property
def FAILED(self):
return str(self.colored.red('FAILED'))
@cached_property
def DOWN(self):
return str(self.colored.magenta('DOWN'))
@click.command(
cls=CeleryCommand,
context_settings={
'allow_extra_args': True,
'ignore_unknown_options': True
}
)
@click.pass_context
@handle_preload_options
def multi(ctx, **kwargs):
"""Start multiple worker instances."""
cmd = MultiTool(quiet=ctx.obj.quiet, no_color=ctx.obj.no_color)
# In 4.x, celery multi ignores the global --app option.
# Since in 5.0 the --app option is global only we
# rearrange the arguments so that the MultiTool will parse them correctly.
args = sys.argv[1:]
args = args[args.index('multi'):] + args[:args.index('multi')]
return cmd.execute_from_commandline(args)
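``MultiTool`` itself is reusable; ``show`` is the safe subcommand to experiment with, since it only prints the worker command lines it would run:

tool = MultiTool(quiet=False, no_color=True)
tool.execute_from_commandline(
    ['celery multi', 'show', '3', '-l', 'INFO'])  # prints three `celery worker ...` commands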


@@ -0,0 +1,70 @@
"""The ``celery purge`` program, used to delete messages from queues."""
import click
from celery.bin.base import COMMA_SEPARATED_LIST, CeleryCommand, CeleryOption, handle_preload_options
from celery.utils import text
@click.command(cls=CeleryCommand, context_settings={
'allow_extra_args': True
})
@click.option('-f',
'--force',
cls=CeleryOption,
is_flag=True,
help_group='Purging Options',
help="Don't prompt for verification.")
@click.option('-Q',
'--queues',
cls=CeleryOption,
type=COMMA_SEPARATED_LIST,
help_group='Purging Options',
help="Comma separated list of queue names to purge.")
@click.option('-X',
'--exclude-queues',
cls=CeleryOption,
type=COMMA_SEPARATED_LIST,
help_group='Purging Options',
help="Comma separated list of queues names not to purge.")
@click.pass_context
@handle_preload_options
def purge(ctx, force, queues, exclude_queues, **kwargs):
"""Erase all messages from all known task queues.
Warning:
There's no undo operation for this command.
"""
app = ctx.obj.app
queues = set(queues or app.amqp.queues.keys())
exclude_queues = set(exclude_queues or [])
names = queues - exclude_queues
qnum = len(names)
if names:
queues_headline = text.pluralize(qnum, 'queue')
if not force:
queue_names = ', '.join(sorted(names))
click.confirm(f"{ctx.obj.style('WARNING', fg='red')}:"
"This will remove all tasks from "
f"{queues_headline}: {queue_names}.\n"
" There is no undo for this operation!\n\n"
"(to skip this prompt use the -f option)\n"
"Are you sure you want to delete all tasks?",
abort=True)
def _purge(conn, queue):
try:
return conn.default_channel.queue_purge(queue) or 0
except conn.channel_errors:
return 0
with app.connection_for_write() as conn:
messages = sum(_purge(conn, queue) for queue in names)
if messages:
messages_headline = text.pluralize(messages, 'message')
ctx.obj.echo(f"Purged {messages} {messages_headline} from "
f"{qnum} known task {queues_headline}.")
else:
ctx.obj.echo(f"No messages purged from {qnum} {queues_headline}.")


@@ -0,0 +1,30 @@
"""The ``celery result`` program, used to inspect task results."""
import click
from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options
@click.command(cls=CeleryCommand)
@click.argument('task_id')
@click.option('-t',
'--task',
cls=CeleryOption,
help_group='Result Options',
help="Name of task (if custom backend).")
@click.option('--traceback',
cls=CeleryOption,
is_flag=True,
help_group='Result Options',
help="Show traceback instead.")
@click.pass_context
@handle_preload_options
def result(ctx, task_id, task, traceback):
"""Print the return value for a given task id."""
app = ctx.obj.app
result_cls = app.tasks[task].AsyncResult if task else app.AsyncResult
task_result = result_cls(task_id)
value = task_result.traceback if traceback else task_result.get()
# TODO: Prettify result
ctx.obj.echo(value)
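Programmatic equivalent (the task id below is made up):

from proj import app  # hypothetical Celery application

res = app.AsyncResult('d9078da5-9915-40a0-bfa1-392c7bde42ed')  # placeholder id
print(res.get(timeout=5))   # or res.traceback when inspecting a failure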


@@ -0,0 +1,173 @@
"""The ``celery shell`` program, used to start a REPL."""
import os
import sys
from importlib import import_module
import click
from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options
def _invoke_fallback_shell(locals):
import code
try:
import readline
except ImportError:
pass
else:
import rlcompleter
readline.set_completer(
rlcompleter.Completer(locals).complete)
readline.parse_and_bind('tab:complete')
code.interact(local=locals)
def _invoke_bpython_shell(locals):
import bpython
bpython.embed(locals)
def _invoke_ipython_shell(locals):
for ip in (_ipython, _ipython_pre_10,
_ipython_terminal, _ipython_010,
_no_ipython):
try:
return ip(locals)
except ImportError:
pass
def _ipython(locals):
from IPython import start_ipython
start_ipython(argv=[], user_ns=locals)
def _ipython_pre_10(locals): # pragma: no cover
from IPython.frontend.terminal.ipapp import TerminalIPythonApp
app = TerminalIPythonApp.instance()
app.initialize(argv=[])
app.shell.user_ns.update(locals)
app.start()
def _ipython_terminal(locals): # pragma: no cover
from IPython.terminal import embed
embed.TerminalInteractiveShell(user_ns=locals).mainloop()
def _ipython_010(locals): # pragma: no cover
from IPython.Shell import IPShell
IPShell(argv=[], user_ns=locals).mainloop()
def _no_ipython(self): # pragma: no cover
raise ImportError('no suitable ipython found')
def _invoke_default_shell(locals):
try:
import IPython # noqa
except ImportError:
try:
import bpython # noqa
except ImportError:
_invoke_fallback_shell(locals)
else:
_invoke_bpython_shell(locals)
else:
_invoke_ipython_shell(locals)
@click.command(cls=CeleryCommand, context_settings={
'allow_extra_args': True
})
@click.option('-I',
'--ipython',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Force IPython.")
@click.option('-B',
'--bpython',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Force bpython.")
@click.option('--python',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Force default Python shell.")
@click.option('-T',
'--without-tasks',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Don't add tasks to locals.")
@click.option('--eventlet',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Use eventlet.")
@click.option('--gevent',
is_flag=True,
cls=CeleryOption,
help_group="Shell Options",
help="Use gevent.")
@click.pass_context
@handle_preload_options
def shell(ctx, ipython=False, bpython=False,
python=False, without_tasks=False, eventlet=False,
gevent=False, **kwargs):
"""Start shell session with convenient access to celery symbols.
The following symbols will be added to the main globals:
- ``celery``: the current application.
- ``chord``, ``group``, ``chain``, ``chunks``,
``xmap``, ``xstarmap``, ``subtask``, ``Task``
- all registered tasks.
"""
sys.path.insert(0, os.getcwd())
if eventlet:
import_module('celery.concurrency.eventlet')
if gevent:
import_module('celery.concurrency.gevent')
import celery
app = ctx.obj.app
app.loader.import_default_modules()
# pylint: disable=attribute-defined-outside-init
locals = {
'app': app,
'celery': app,
'Task': celery.Task,
'chord': celery.chord,
'group': celery.group,
'chain': celery.chain,
'chunks': celery.chunks,
'xmap': celery.xmap,
'xstarmap': celery.xstarmap,
'subtask': celery.subtask,
'signature': celery.signature,
}
if not without_tasks:
locals.update({
task.__name__: task for task in app.tasks.values()
if not task.name.startswith('celery.')
})
if python:
_invoke_fallback_shell(locals)
elif bpython:
try:
_invoke_bpython_shell(locals)
except ImportError:
ctx.obj.echo(f'{ctx.obj.ERROR}: bpython is not installed')
elif ipython:
try:
_invoke_ipython_shell(locals)
except ImportError as e:
ctx.obj.echo(f'{ctx.obj.ERROR}: {e}')
else:
_invoke_default_shell(locals)

View File

@@ -0,0 +1,91 @@
"""The ``celery upgrade`` command, used to upgrade from previous versions."""
import codecs
import sys
import click
from celery.app import defaults
from celery.bin.base import CeleryCommand, CeleryOption, handle_preload_options
from celery.utils.functional import pass1
@click.group()
@click.pass_context
@handle_preload_options
def upgrade(ctx):
"""Perform upgrade between versions."""
def _slurp(filename):
# TODO: Handle case when file does not exist
with codecs.open(filename, 'r', 'utf-8') as read_fh:
return [line for line in read_fh]
def _compat_key(key, namespace='CELERY'):
key = key.upper()
if not key.startswith(namespace):
key = '_'.join([namespace, key])
return key
def _backup(filename, suffix='.orig'):
lines = []
backup_filename = ''.join([filename, suffix])
print(f'writing backup to {backup_filename}...',
file=sys.stderr)
with codecs.open(filename, 'r', 'utf-8') as read_fh:
with codecs.open(backup_filename, 'w', 'utf-8') as backup_fh:
for line in read_fh:
backup_fh.write(line)
lines.append(line)
return lines
def _to_new_key(line, keyfilter=pass1, source=defaults._TO_NEW_KEY):
# sort by length to avoid, for example, broker_transport overriding
# broker_transport_options.
for old_key in reversed(sorted(source, key=lambda x: len(x))):
new_line = line.replace(old_key, keyfilter(source[old_key]))
if line != new_line and 'CELERY_CELERY' not in new_line:
return 1, new_line # only one match per line.
return 0, line
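# Illustrative effect of ``_to_new_key`` (example lines, not from the source;
# assumes the standard ``CELERY_RESULT_BACKEND`` -> ``result_backend`` entry
# in ``defaults._TO_NEW_KEY``):
#
#     _to_new_key("CELERY_RESULT_BACKEND = 'redis://'\n")
#     # -> (1, "result_backend = 'redis://'\n")
#     _to_new_key("DEBUG = True\n")
#     # -> (0, "DEBUG = True\n")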
@upgrade.command(cls=CeleryCommand)
@click.argument('filename')
@click.option('--django',
cls=CeleryOption,
is_flag=True,
help_group='Upgrading Options',
help='Upgrade Django project.')
@click.option('--compat',
cls=CeleryOption,
is_flag=True,
help_group='Upgrading Options',
help='Maintain backwards compatibility.')
@click.option('--no-backup',
cls=CeleryOption,
is_flag=True,
help_group='Upgrading Options',
help="Don't backup original files.")
def settings(filename, django, compat, no_backup):
"""Migrate settings from Celery 3.x to Celery 4.x."""
lines = _slurp(filename)
keyfilter = _compat_key if django or compat else pass1
print(f'processing {filename}...', file=sys.stderr)
# gives list of tuples: ``(did_change, line_contents)``
new_lines = [
_to_new_key(line, keyfilter) for line in lines
]
if any(n[0] for n in new_lines): # did have changes
if not no_backup:
_backup(filename)
with codecs.open(filename, 'w', 'utf-8') as write_fh:
for _, line in new_lines:
write_fh.write(line)
print('Changes to your settings have been made!',
file=sys.stdout)
else:
print('Does not seem to require any changes :-)',
file=sys.stdout)

View File

@@ -0,0 +1,360 @@
"""Program used to start a Celery worker instance."""
import os
import sys
import click
from click import ParamType
from click.types import StringParamType
from celery import concurrency
from celery.bin.base import (COMMA_SEPARATED_LIST, LOG_LEVEL, CeleryDaemonCommand, CeleryOption,
handle_preload_options)
from celery.concurrency.base import BasePool
from celery.exceptions import SecurityError
from celery.platforms import EX_FAILURE, EX_OK, detached, maybe_drop_privileges
from celery.utils.log import get_logger
from celery.utils.nodenames import default_nodename, host_format, node_format
logger = get_logger(__name__)
class CeleryBeat(ParamType):
"""Celery Beat flag."""
name = "beat"
def convert(self, value, param, ctx):
if ctx.obj.app.IS_WINDOWS and value:
self.fail('-B option does not work on Windows. '
'Please run celery beat as a separate service.')
return value
class WorkersPool(click.Choice):
"""Workers pool option."""
name = "pool"
def __init__(self):
"""Initialize the workers pool option with the relevant choices."""
super().__init__(concurrency.get_available_pool_names())
def convert(self, value, param, ctx):
# Pools like eventlet/gevent needs to patch libs as early
# as possible.
if isinstance(value, type) and issubclass(value, BasePool):
return value
value = super().convert(value, param, ctx)
worker_pool = ctx.obj.app.conf.worker_pool
if value == 'prefork' and worker_pool:
# If we got the default pool through the CLI
# we need to check if the worker pool was configured.
# If the worker pool was configured, we shouldn't use the default.
value = concurrency.get_implementation(worker_pool)
else:
value = concurrency.get_implementation(value)
if not value:
value = concurrency.get_implementation(worker_pool)
return value
class Hostname(StringParamType):
"""Hostname option."""
name = "hostname"
def convert(self, value, param, ctx):
return host_format(default_nodename(value))
class Autoscale(ParamType):
"""Autoscaling parameter."""
name = "<min workers>, <max workers>"
def convert(self, value, param, ctx):
value = value.split(',')
if len(value) > 2:
self.fail("Expected two comma separated integers or one integer."
f"Got {len(value)} instead.")
if len(value) == 1:
try:
value = (int(value[0]), 0)
except ValueError:
self.fail(f"Expected an integer. Got {value} instead.")
try:
return tuple(reversed(sorted(map(int, value))))
except ValueError:
self.fail("Expected two comma separated integers."
f"Got {value.join(',')} instead.")
CELERY_BEAT = CeleryBeat()
WORKERS_POOL = WorkersPool()
HOSTNAME = Hostname()
AUTOSCALE = Autoscale()
C_FAKEFORK = os.environ.get('C_FAKEFORK')
def detach(path, argv, logfile=None, pidfile=None, uid=None,
gid=None, umask=None, workdir=None, fake=False, app=None,
executable=None, hostname=None):
"""Detach program by argv."""
fake = 1 if C_FAKEFORK else fake
# `detached()` will attempt to touch the logfile to confirm that error
# messages won't be lost after detaching stdout/err, but this means we need
# to pre-format it rather than relying on `setup_logging_subsystem()` like
# we can elsewhere.
logfile = node_format(logfile, hostname)
with detached(logfile, pidfile, uid, gid, umask, workdir, fake,
after_forkers=False):
try:
if executable is not None:
path = executable
os.execv(path, [path] + argv)
return EX_OK
except Exception: # pylint: disable=broad-except
if app is None:
from celery import current_app
app = current_app
app.log.setup_logging_subsystem(
'ERROR', logfile, hostname=hostname)
logger.critical("Can't exec %r", ' '.join([path] + argv),
exc_info=True)
return EX_FAILURE
@click.command(cls=CeleryDaemonCommand,
context_settings={'allow_extra_args': True})
@click.option('-n',
'--hostname',
default=host_format(default_nodename(None)),
cls=CeleryOption,
type=HOSTNAME,
help_group="Worker Options",
help="Set custom hostname (e.g., 'w1@%%h'). "
"Expands: %%h (hostname), %%n (name) and %%d, (domain).")
@click.option('-D',
'--detach',
cls=CeleryOption,
is_flag=True,
default=False,
help_group="Worker Options",
help="Start worker as a background process.")
@click.option('-S',
'--statedb',
cls=CeleryOption,
type=click.Path(),
callback=lambda ctx, _,
value: value or ctx.obj.app.conf.worker_state_db,
help_group="Worker Options",
help="Path to the state database. The extension '.db' may be "
"appended to the filename.")
@click.option('-l',
'--loglevel',
default='WARNING',
cls=CeleryOption,
type=LOG_LEVEL,
help_group="Worker Options",
help="Logging level.")
@click.option('-O',
'--optimization',
default='default',
cls=CeleryOption,
type=click.Choice(('default', 'fair')),
help_group="Worker Options",
help="Apply optimization profile.")
@click.option('--prefetch-multiplier',
type=int,
metavar="<prefetch multiplier>",
callback=lambda ctx, _,
value: value or ctx.obj.app.conf.worker_prefetch_multiplier,
cls=CeleryOption,
help_group="Worker Options",
help="Set custom prefetch multiplier value "
"for this worker instance.")
@click.option('-c',
'--concurrency',
type=int,
metavar="<concurrency>",
callback=lambda ctx, _,
value: value or ctx.obj.app.conf.worker_concurrency,
cls=CeleryOption,
help_group="Pool Options",
help="Number of child processes processing the queue. "
"The default is the number of CPUs available"
" on your system.")
@click.option('-P',
'--pool',
default='prefork',
type=WORKERS_POOL,
cls=CeleryOption,
help_group="Pool Options",
help="Pool implementation.")
@click.option('-E',
'--task-events',
'--events',
is_flag=True,
default=None,
cls=CeleryOption,
help_group="Pool Options",
help="Send task-related events that can be captured by monitors"
" like celery events, celerymon, and others.")
@click.option('--time-limit',
type=float,
cls=CeleryOption,
help_group="Pool Options",
help="Enables a hard time limit "
"(in seconds int/float) for tasks.")
@click.option('--soft-time-limit',
type=float,
cls=CeleryOption,
help_group="Pool Options",
help="Enables a soft time limit "
"(in seconds int/float) for tasks.")
@click.option('--max-tasks-per-child',
type=int,
cls=CeleryOption,
help_group="Pool Options",
help="Maximum number of tasks a pool worker can execute before "
"it's terminated and replaced by a new worker.")
@click.option('--max-memory-per-child',
type=int,
cls=CeleryOption,
help_group="Pool Options",
help="Maximum amount of resident memory, in KiB, that may be "
"consumed by a child process before it will be replaced "
"by a new one. If a single task causes a child process "
"to exceed this limit, the task will be completed and "
"the child process will be replaced afterwards.\n"
"Default: no limit.")
@click.option('--purge',
'--discard',
is_flag=True,
cls=CeleryOption,
help_group="Queue Options")
@click.option('--queues',
'-Q',
type=COMMA_SEPARATED_LIST,
cls=CeleryOption,
help_group="Queue Options")
@click.option('--exclude-queues',
'-X',
type=COMMA_SEPARATED_LIST,
cls=CeleryOption,
help_group="Queue Options")
@click.option('--include',
'-I',
type=COMMA_SEPARATED_LIST,
cls=CeleryOption,
help_group="Queue Options")
@click.option('--without-gossip',
is_flag=True,
cls=CeleryOption,
help_group="Features")
@click.option('--without-mingle',
is_flag=True,
cls=CeleryOption,
help_group="Features")
@click.option('--without-heartbeat',
is_flag=True,
cls=CeleryOption,
help_group="Features", )
@click.option('--heartbeat-interval',
type=int,
cls=CeleryOption,
help_group="Features", )
@click.option('--autoscale',
type=AUTOSCALE,
cls=CeleryOption,
help_group="Features", )
@click.option('-B',
'--beat',
type=CELERY_BEAT,
cls=CeleryOption,
is_flag=True,
help_group="Embedded Beat Options")
@click.option('-s',
'--schedule-filename',
'--schedule',
callback=lambda ctx, _,
value: value or ctx.obj.app.conf.beat_schedule_filename,
cls=CeleryOption,
help_group="Embedded Beat Options")
@click.option('--scheduler',
cls=CeleryOption,
help_group="Embedded Beat Options")
@click.pass_context
@handle_preload_options
def worker(ctx, hostname=None, pool_cls=None, app=None, uid=None, gid=None,
loglevel=None, logfile=None, pidfile=None, statedb=None,
**kwargs):
"""Start worker instance.
\b
Examples
--------
\b
$ celery --app=proj worker -l INFO
$ celery -A proj worker -l INFO -Q hipri,lopri
$ celery -A proj worker --concurrency=4
$ celery -A proj worker --concurrency=1000 -P eventlet
$ celery worker --autoscale=10,0
"""
try:
app = ctx.obj.app
if ctx.args:
try:
app.config_from_cmdline(ctx.args, namespace='worker')
except (KeyError, ValueError) as e:
# TODO: Improve the error messages
raise click.UsageError(
"Unable to parse extra configuration from command line.\n"
f"Reason: {e}", ctx=ctx)
if kwargs.get('detach', False):
argv = ['-m', 'celery'] + sys.argv[1:]
if '--detach' in argv:
argv.remove('--detach')
if '-D' in argv:
argv.remove('-D')
if "--uid" in argv:
argv.remove('--uid')
if "--gid" in argv:
argv.remove('--gid')
return detach(sys.executable,
argv,
logfile=logfile,
pidfile=pidfile,
uid=uid, gid=gid,
umask=kwargs.get('umask', None),
workdir=kwargs.get('workdir', None),
app=app,
executable=kwargs.get('executable', None),
hostname=hostname)
maybe_drop_privileges(uid=uid, gid=gid)
worker = app.Worker(
hostname=hostname, pool_cls=pool_cls, loglevel=loglevel,
logfile=logfile, # node format handled by celery.app.log.setup
pidfile=node_format(pidfile, hostname),
statedb=node_format(statedb, hostname),
no_color=ctx.obj.no_color,
quiet=ctx.obj.quiet,
**kwargs)
worker.start()
ctx.exit(worker.exitcode)
except SecurityError as e:
ctx.obj.error(e.args[0])
ctx.exit(1)

View File

@@ -0,0 +1,415 @@
"""A directed acyclic graph of reusable components."""
from collections import deque
from threading import Event
from kombu.common import ignore_errors
from kombu.utils.encoding import bytes_to_str
from kombu.utils.imports import symbol_by_name
from .utils.graph import DependencyGraph, GraphFormatter
from .utils.imports import instantiate, qualname
from .utils.log import get_logger
try:
from greenlet import GreenletExit
except ImportError:
IGNORE_ERRORS = ()
else:
IGNORE_ERRORS = (GreenletExit,)
__all__ = ('Blueprint', 'Step', 'StartStopStep', 'ConsumerStep')
#: States
RUN = 0x1
CLOSE = 0x2
TERMINATE = 0x3
logger = get_logger(__name__)
def _pre(ns, fmt):
return f'| {ns.alias}: {fmt}'
def _label(s):
return s.name.rsplit('.', 1)[-1]
class StepFormatter(GraphFormatter):
"""Graph formatter for :class:`Blueprint`."""
blueprint_prefix = ''
conditional_prefix = ''
blueprint_scheme = {
'shape': 'parallelogram',
'color': 'slategray4',
'fillcolor': 'slategray3',
}
def label(self, step):
return step and '{}{}'.format(
self._get_prefix(step),
bytes_to_str(
(step.label or _label(step)).encode('utf-8', 'ignore')),
)
def _get_prefix(self, step):
if step.last:
return self.blueprint_prefix
if step.conditional:
return self.conditional_prefix
return ''
def node(self, obj, **attrs):
scheme = self.blueprint_scheme if obj.last else self.node_scheme
return self.draw_node(obj, scheme, attrs)
def edge(self, a, b, **attrs):
if a.last:
attrs.update(arrowhead='none', color='darkseagreen3')
return self.draw_edge(a, b, self.edge_scheme, attrs)
class Blueprint:
"""Blueprint containing bootsteps that can be applied to objects.
Arguments:
steps (Sequence[Union[str, Step]]): List of steps.
name (str): Set explicit name for this blueprint.
on_start (Callable): Optional callback applied after blueprint start.
on_close (Callable): Optional callback applied before blueprint close.
on_stopped (Callable): Optional callback applied after
blueprint stopped.
"""
GraphFormatter = StepFormatter
name = None
state = None
started = 0
default_steps = set()
state_to_name = {
0: 'initializing',
RUN: 'running',
CLOSE: 'closing',
TERMINATE: 'terminating',
}
def __init__(self, steps=None, name=None,
on_start=None, on_close=None, on_stopped=None):
self.name = name or self.name or qualname(type(self))
self.types = set(steps or []) | set(self.default_steps)
self.on_start = on_start
self.on_close = on_close
self.on_stopped = on_stopped
self.shutdown_complete = Event()
self.steps = {}
def start(self, parent):
self.state = RUN
if self.on_start:
self.on_start()
for i, step in enumerate(s for s in parent.steps if s is not None):
self._debug('Starting %s', step.alias)
self.started = i + 1
step.start(parent)
logger.debug('^-- substep ok')
def human_state(self):
return self.state_to_name[self.state or 0]
def info(self, parent):
info = {}
for step in parent.steps:
info.update(step.info(parent) or {})
return info
def close(self, parent):
if self.on_close:
self.on_close()
self.send_all(parent, 'close', 'closing', reverse=False)
def restart(self, parent, method='stop',
description='restarting', propagate=False):
self.send_all(parent, method, description, propagate=propagate)
def send_all(self, parent, method,
description=None, reverse=True, propagate=True, args=()):
description = description or method.replace('_', ' ')
steps = reversed(parent.steps) if reverse else parent.steps
for step in steps:
if step:
fun = getattr(step, method, None)
if fun is not None:
self._debug('%s %s...',
description.capitalize(), step.alias)
try:
fun(parent, *args)
except Exception as exc: # pylint: disable=broad-except
if propagate:
raise
logger.exception(
'Error on %s %s: %r', description, step.alias, exc)
def stop(self, parent, close=True, terminate=False):
what = 'terminating' if terminate else 'stopping'
if self.state in (CLOSE, TERMINATE):
return
if self.state != RUN or self.started != len(parent.steps):
# Not fully started, can safely exit.
self.state = TERMINATE
self.shutdown_complete.set()
return
self.close(parent)
self.state = CLOSE
self.restart(
parent, 'terminate' if terminate else 'stop',
description=what, propagate=False,
)
if self.on_stopped:
self.on_stopped()
self.state = TERMINATE
self.shutdown_complete.set()
def join(self, timeout=None):
try:
# Will only get here if running green,
# makes sure all greenthreads have exited.
self.shutdown_complete.wait(timeout=timeout)
except IGNORE_ERRORS:
pass
def apply(self, parent, **kwargs):
"""Apply the steps in this blueprint to an object.
This will apply the ``__init__`` and ``include`` methods
of each step, with the object as argument::
step = Step(obj)
...
step.include(obj)
For :class:`StartStopStep` the services created
will also be added to the object's ``steps`` attribute.
"""
self._debug('Preparing bootsteps.')
order = self.order = []
steps = self.steps = self.claim_steps()
self._debug('Building graph...')
for S in self._finalize_steps(steps):
step = S(parent, **kwargs)
steps[step.name] = step
order.append(step)
self._debug('New boot order: {%s}',
', '.join(s.alias for s in self.order))
for step in order:
step.include(parent)
return self
def connect_with(self, other):
self.graph.adjacent.update(other.graph.adjacent)
self.graph.add_edge(type(other.order[0]), type(self.order[-1]))
def __getitem__(self, name):
return self.steps[name]
def _find_last(self):
return next((C for C in self.steps.values() if C.last), None)
def _firstpass(self, steps):
for step in steps.values():
step.requires = [symbol_by_name(dep) for dep in step.requires]
stream = deque(step.requires for step in steps.values())
while stream:
for node in stream.popleft():
node = symbol_by_name(node)
if node.name not in self.steps:
steps[node.name] = node
stream.append(node.requires)
def _finalize_steps(self, steps):
last = self._find_last()
self._firstpass(steps)
it = ((C, C.requires) for C in steps.values())
G = self.graph = DependencyGraph(
it, formatter=self.GraphFormatter(root=last),
)
if last:
for obj in G:
if obj != last:
G.add_edge(last, obj)
try:
return G.topsort()
except KeyError as exc:
raise KeyError('unknown bootstep: %s' % exc)
def claim_steps(self):
return dict(self.load_step(step) for step in self.types)
def load_step(self, step):
step = symbol_by_name(step)
return step.name, step
def _debug(self, msg, *args):
return logger.debug(_pre(self, msg), *args)
@property
def alias(self):
return _label(self)
class StepType(type):
"""Meta-class for steps."""
name = None
requires = None
def __new__(cls, name, bases, attrs):
module = attrs.get('__module__')
qname = f'{module}.{name}' if module else name
attrs.update(
__qualname__=qname,
name=attrs.get('name') or qname,
)
return super().__new__(cls, name, bases, attrs)
def __str__(cls):
return cls.name
def __repr__(cls):
return 'step:{0.name}{{{0.requires!r}}}'.format(cls)
class Step(metaclass=StepType):
"""A Bootstep.
The :meth:`__init__` method is called when the step
is bound to a parent object, and can as such be used
to initialize attributes in the parent object at
parent instantiation-time.
"""
#: Optional step name, will use ``qualname`` if not specified.
name = None
#: Optional short name used for graph outputs and in logs.
label = None
#: Set this to true if the step is enabled based on some condition.
conditional = False
#: List of other steps that must be started before this step.
#: Note that all dependencies must be in the same blueprint.
requires = ()
#: This flag is reserved for the workers Consumer,
#: since it is required to always be started last.
#: There can only be one object marked last
#: in every blueprint.
last = False
#: This provides the default for :meth:`include_if`.
enabled = True
def __init__(self, parent, **kwargs):
pass
def include_if(self, parent):
"""Return true if bootstep should be included.
You can define this as an optional predicate that decides whether
this step should be created.
"""
return self.enabled
def instantiate(self, name, *args, **kwargs):
return instantiate(name, *args, **kwargs)
def _should_include(self, parent):
if self.include_if(parent):
return True, self.create(parent)
return False, None
def include(self, parent):
return self._should_include(parent)[0]
def create(self, parent):
"""Create the step."""
def __repr__(self):
return f'<step: {self.alias}>'
@property
def alias(self):
return self.label or _label(self)
def info(self, obj):
pass
class StartStopStep(Step):
"""Bootstep that must be started and stopped in order."""
#: Optional obj created by the :meth:`create` method.
#: This is used by :class:`StartStopStep` to keep the
#: original service object.
obj = None
def start(self, parent):
if self.obj:
return self.obj.start()
def stop(self, parent):
if self.obj:
return self.obj.stop()
def close(self, parent):
pass
def terminate(self, parent):
if self.obj:
return getattr(self.obj, 'terminate', self.obj.stop)()
def include(self, parent):
inc, ret = self._should_include(parent)
if inc:
self.obj = ret
parent.steps.append(self)
return inc
class ConsumerStep(StartStopStep):
"""Bootstep that starts a message consumer."""
requires = ('celery.worker.consumer:Connection',)
consumers = None
def get_consumers(self, channel):
raise NotImplementedError('missing get_consumers')
def start(self, c):
channel = c.connection.channel()
self.consumers = self.get_consumers(channel)
for consumer in self.consumers or []:
consumer.consume()
def stop(self, c):
self._close(c, True)
def shutdown(self, c):
self._close(c, False)
def _close(self, c, cancel_consumers=True):
channels = set()
for consumer in self.consumers or []:
if cancel_consumers:
ignore_errors(c.connection, consumer.cancel)
if consumer.channel:
channels.add(consumer.channel)
for channel in channels:
ignore_errors(c.connection, channel.close)
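# Illustrative sketch (not part of this module): a custom worker bootstep,
# registered on a hypothetical app, could look like this:
#
#     from celery import Celery, bootsteps
#
#     class ExampleStep(bootsteps.StartStopStep):
#         requires = {'celery.worker.components:Pool'}
#
#         def start(self, worker):
#             print('starting step on', worker.hostname)
#
#         def stop(self, worker):
#             print('stopping step')
#
#     app = Celery(broker='amqp://')
#     app.steps['worker'].add(ExampleStep)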

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,48 @@
"""Pool implementation abstract factory, and alias definitions."""
import os
# Import from kombu directly as it's used
# early in the import stage, where celery.utils loads
# too much (e.g., for eventlet patching)
from kombu.utils.imports import symbol_by_name
__all__ = ('get_implementation', 'get_available_pool_names',)
ALIASES = {
'prefork': 'celery.concurrency.prefork:TaskPool',
'eventlet': 'celery.concurrency.eventlet:TaskPool',
'gevent': 'celery.concurrency.gevent:TaskPool',
'solo': 'celery.concurrency.solo:TaskPool',
'processes': 'celery.concurrency.prefork:TaskPool', # XXX compat alias
}
try:
import concurrent.futures # noqa
except ImportError:
pass
else:
ALIASES['threads'] = 'celery.concurrency.thread:TaskPool'
#
# Allow for an out-of-tree worker pool implementation. This is used as follows:
#
# - Set the environment variable CELERY_CUSTOM_WORKER_POOL to the name of
# an implementation of :class:`celery.concurrency.base.BasePool` in the
# standard Celery format of "package:class".
# - Select this pool using '--pool custom'.
#
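# Example (hypothetical module path, shown for illustration only):
#
#     $ export CELERY_CUSTOM_WORKER_POOL='myproj.pools:CustomPool'
#     $ celery -A proj worker --pool custom
#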
try:
custom = os.environ['CELERY_CUSTOM_WORKER_POOL']
except KeyError:
pass
else:
ALIASES['custom'] = custom
def get_implementation(cls):
"""Return pool implementation by name."""
return symbol_by_name(cls, ALIASES)
def get_available_pool_names():
"""Return all available pool type names."""
return tuple(ALIASES.keys())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,180 @@
"""Base Execution Pool."""
import logging
import os
import sys
import time
from typing import Any, Dict
from billiard.einfo import ExceptionInfo
from billiard.exceptions import WorkerLostError
from kombu.utils.encoding import safe_repr
from celery.exceptions import WorkerShutdown, WorkerTerminate, reraise
from celery.utils import timer2
from celery.utils.log import get_logger
from celery.utils.text import truncate
__all__ = ('BasePool', 'apply_target')
logger = get_logger('celery.pool')
def apply_target(target, args=(), kwargs=None, callback=None,
accept_callback=None, pid=None, getpid=os.getpid,
propagate=(), monotonic=time.monotonic, **_):
"""Apply function within pool context."""
kwargs = {} if not kwargs else kwargs
if accept_callback:
accept_callback(pid or getpid(), monotonic())
try:
ret = target(*args, **kwargs)
except propagate:
raise
except Exception:
raise
except (WorkerShutdown, WorkerTerminate):
raise
except BaseException as exc:
try:
reraise(WorkerLostError, WorkerLostError(repr(exc)),
sys.exc_info()[2])
except WorkerLostError:
callback(ExceptionInfo())
else:
callback(ret)
class BasePool:
"""Task pool."""
RUN = 0x1
CLOSE = 0x2
TERMINATE = 0x3
Timer = timer2.Timer
#: set to true if the pool can be shutdown from within
#: a signal handler.
signal_safe = True
#: set to true if pool uses greenlets.
is_green = False
_state = None
_pool = None
_does_debug = True
#: only used by multiprocessing pool
uses_semaphore = False
task_join_will_block = True
body_can_be_buffer = False
def __init__(self, limit=None, putlocks=True, forking_enable=True,
callbacks_propagate=(), app=None, **options):
self.limit = limit
self.putlocks = putlocks
self.options = options
self.forking_enable = forking_enable
self.callbacks_propagate = callbacks_propagate
self.app = app
def on_start(self):
pass
def did_start_ok(self):
return True
def flush(self):
pass
def on_stop(self):
pass
def register_with_event_loop(self, loop):
pass
def on_apply(self, *args, **kwargs):
pass
def on_terminate(self):
pass
def on_soft_timeout(self, job):
pass
def on_hard_timeout(self, job):
pass
def maintain_pool(self, *args, **kwargs):
pass
def terminate_job(self, pid, signal=None):
raise NotImplementedError(
f'{type(self)} does not implement kill_job')
def restart(self):
raise NotImplementedError(
f'{type(self)} does not implement restart')
def stop(self):
self.on_stop()
self._state = self.TERMINATE
def terminate(self):
self._state = self.TERMINATE
self.on_terminate()
def start(self):
self._does_debug = logger.isEnabledFor(logging.DEBUG)
self.on_start()
self._state = self.RUN
def close(self):
self._state = self.CLOSE
self.on_close()
def on_close(self):
pass
def apply_async(self, target, args=None, kwargs=None, **options):
"""Equivalent of the :func:`apply` built-in function.
Callbacks should optimally return as soon as possible since
otherwise the thread which handles the result will get blocked.
"""
kwargs = {} if not kwargs else kwargs
args = [] if not args else args
if self._does_debug:
logger.debug('TaskPool: Apply %s (args:%s kwargs:%s)',
target, truncate(safe_repr(args), 1024),
truncate(safe_repr(kwargs), 1024))
return self.on_apply(target, args, kwargs,
waitforslot=self.putlocks,
callbacks_propagate=self.callbacks_propagate,
**options)
def _get_info(self) -> Dict[str, Any]:
"""
Return configuration and statistics information. Subclasses should
augment the data as required.
:return: The returned value must be JSON-friendly.
"""
return {
'implementation': self.__class__.__module__ + ':' + self.__class__.__name__,
'max-concurrency': self.limit,
}
@property
def info(self):
return self._get_info()
@property
def active(self):
return self._state == self.RUN
@property
def num_processes(self):
return self.limit

View File

@@ -0,0 +1,181 @@
"""Eventlet execution pool."""
import sys
from time import monotonic
from greenlet import GreenletExit
from kombu.asynchronous import timer as _timer
from celery import signals
from . import base
__all__ = ('TaskPool',)
W_RACE = """\
Celery module with %s imported before eventlet patched\
"""
RACE_MODS = ('billiard.', 'celery.', 'kombu.')
#: Warn if we couldn't patch early enough,
#: and thread/socket depending celery modules have already been loaded.
for mod in (mod for mod in sys.modules if mod.startswith(RACE_MODS)):
for side in ('thread', 'threading', 'socket'): # pragma: no cover
if getattr(sys.modules[mod], side, None):
import warnings
warnings.warn(RuntimeWarning(W_RACE % side))
def apply_target(target, args=(), kwargs=None, callback=None,
accept_callback=None, getpid=None):
kwargs = {} if not kwargs else kwargs
return base.apply_target(target, args, kwargs, callback, accept_callback,
pid=getpid())
class Timer(_timer.Timer):
"""Eventlet Timer."""
def __init__(self, *args, **kwargs):
from eventlet.greenthread import spawn_after
from greenlet import GreenletExit
super().__init__(*args, **kwargs)
self.GreenletExit = GreenletExit
self._spawn_after = spawn_after
self._queue = set()
def _enter(self, eta, priority, entry, **kwargs):
secs = max(eta - monotonic(), 0)
g = self._spawn_after(secs, entry)
self._queue.add(g)
g.link(self._entry_exit, entry)
g.entry = entry
g.eta = eta
g.priority = priority
g.canceled = False
return g
def _entry_exit(self, g, entry):
try:
try:
g.wait()
except self.GreenletExit:
entry.cancel()
g.canceled = True
finally:
self._queue.discard(g)
def clear(self):
queue = self._queue
while queue:
try:
queue.pop().cancel()
except (KeyError, self.GreenletExit):
pass
def cancel(self, tref):
try:
tref.cancel()
except self.GreenletExit:
pass
@property
def queue(self):
return self._queue
class TaskPool(base.BasePool):
"""Eventlet Task Pool."""
Timer = Timer
signal_safe = False
is_green = True
task_join_will_block = False
_pool = None
_pool_map = None
_quick_put = None
def __init__(self, *args, **kwargs):
from eventlet import greenthread
from eventlet.greenpool import GreenPool
self.Pool = GreenPool
self.getcurrent = greenthread.getcurrent
self.getpid = lambda: id(greenthread.getcurrent())
self.spawn_n = greenthread.spawn_n
super().__init__(*args, **kwargs)
def on_start(self):
self._pool = self.Pool(self.limit)
self._pool_map = {}
signals.eventlet_pool_started.send(sender=self)
self._quick_put = self._pool.spawn
self._quick_apply_sig = signals.eventlet_pool_apply.send
def on_stop(self):
signals.eventlet_pool_preshutdown.send(sender=self)
if self._pool is not None:
self._pool.waitall()
signals.eventlet_pool_postshutdown.send(sender=self)
def on_apply(self, target, args=None, kwargs=None, callback=None,
accept_callback=None, **_):
target = TaskPool._make_killable_target(target)
self._quick_apply_sig(sender=self, target=target, args=args, kwargs=kwargs,)
greenlet = self._quick_put(
apply_target,
target, args,
kwargs,
callback,
accept_callback,
self.getpid
)
self._add_to_pool_map(id(greenlet), greenlet)
def grow(self, n=1):
limit = self.limit + n
self._pool.resize(limit)
self.limit = limit
def shrink(self, n=1):
limit = self.limit - n
self._pool.resize(limit)
self.limit = limit
def terminate_job(self, pid, signal=None):
if pid in self._pool_map:
greenlet = self._pool_map[pid]
greenlet.kill()
greenlet.wait()
def _get_info(self):
info = super()._get_info()
info.update({
'max-concurrency': self.limit,
'free-threads': self._pool.free(),
'running-threads': self._pool.running(),
})
return info
@staticmethod
def _make_killable_target(target):
def killable_target(*args, **kwargs):
try:
return target(*args, **kwargs)
except GreenletExit:
return (False, None, None)
return killable_target
def _add_to_pool_map(self, pid, greenlet):
self._pool_map[pid] = greenlet
greenlet.link(
TaskPool._cleanup_after_job_finish,
self._pool_map,
pid
)
@staticmethod
def _cleanup_after_job_finish(greenlet, pool_map, pid):
del pool_map[pid]

View File

@@ -0,0 +1,122 @@
"""Gevent execution pool."""
from time import monotonic
from kombu.asynchronous import timer as _timer
from . import base
try:
from gevent import Timeout
except ImportError:
Timeout = None
__all__ = ('TaskPool',)
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
def apply_timeout(target, args=(), kwargs=None, callback=None,
accept_callback=None, pid=None, timeout=None,
timeout_callback=None, Timeout=Timeout,
apply_target=base.apply_target, **rest):
kwargs = {} if not kwargs else kwargs
try:
with Timeout(timeout):
return apply_target(target, args, kwargs, callback,
accept_callback, pid,
propagate=(Timeout,), **rest)
except Timeout:
return timeout_callback(False, timeout)
class Timer(_timer.Timer):
def __init__(self, *args, **kwargs):
from gevent import Greenlet, GreenletExit
class _Greenlet(Greenlet):
cancel = Greenlet.kill
self._Greenlet = _Greenlet
self._GreenletExit = GreenletExit
super().__init__(*args, **kwargs)
self._queue = set()
def _enter(self, eta, priority, entry, **kwargs):
secs = max(eta - monotonic(), 0)
g = self._Greenlet.spawn_later(secs, entry)
self._queue.add(g)
g.link(self._entry_exit)
g.entry = entry
g.eta = eta
g.priority = priority
g.canceled = False
return g
def _entry_exit(self, g):
try:
g.kill()
finally:
self._queue.discard(g)
def clear(self):
queue = self._queue
while queue:
try:
queue.pop().kill()
except KeyError:
pass
@property
def queue(self):
return self._queue
class TaskPool(base.BasePool):
"""GEvent Pool."""
Timer = Timer
signal_safe = False
is_green = True
task_join_will_block = False
_pool = None
_quick_put = None
def __init__(self, *args, **kwargs):
from gevent import spawn_raw
from gevent.pool import Pool
self.Pool = Pool
self.spawn_n = spawn_raw
self.timeout = kwargs.get('timeout')
super().__init__(*args, **kwargs)
def on_start(self):
self._pool = self.Pool(self.limit)
self._quick_put = self._pool.spawn
def on_stop(self):
if self._pool is not None:
self._pool.join()
def on_apply(self, target, args=None, kwargs=None, callback=None,
accept_callback=None, timeout=None,
timeout_callback=None, apply_target=base.apply_target, **_):
timeout = self.timeout if timeout is None else timeout
return self._quick_put(apply_timeout if timeout else apply_target,
target, args, kwargs, callback, accept_callback,
timeout=timeout,
timeout_callback=timeout_callback)
def grow(self, n=1):
self._pool._semaphore.counter += n
self._pool.size += n
def shrink(self, n=1):
self._pool._semaphore.counter -= n
self._pool.size -= n
@property
def num_processes(self):
return len(self._pool)

View File

@@ -0,0 +1,172 @@
"""Prefork execution pool.
Pool implementation using :mod:`multiprocessing`.
"""
import os
from billiard import forking_enable
from billiard.common import REMAP_SIGTERM, TERM_SIGNAME
from billiard.pool import CLOSE, RUN
from billiard.pool import Pool as BlockingPool
from celery import platforms, signals
from celery._state import _set_task_join_will_block, set_default_app
from celery.app import trace
from celery.concurrency.base import BasePool
from celery.utils.functional import noop
from celery.utils.log import get_logger
from .asynpool import AsynPool
__all__ = ('TaskPool', 'process_initializer', 'process_destructor')
#: List of signals to reset when a child process starts.
WORKER_SIGRESET = {
'SIGTERM', 'SIGHUP', 'SIGTTIN', 'SIGTTOU', 'SIGUSR1',
}
#: List of signals to ignore when a child process starts.
if REMAP_SIGTERM:
WORKER_SIGIGNORE = {'SIGINT', TERM_SIGNAME}
else:
WORKER_SIGIGNORE = {'SIGINT'}
logger = get_logger(__name__)
warning, debug = logger.warning, logger.debug
def process_initializer(app, hostname):
"""Pool child process initializer.
Initialize the child pool process to ensure the correct
app instance is used and things like logging works.
"""
# Each running worker gets SIGKILL by OS when main process exits.
platforms.set_pdeathsig('SIGKILL')
_set_task_join_will_block(True)
platforms.signals.reset(*WORKER_SIGRESET)
platforms.signals.ignore(*WORKER_SIGIGNORE)
platforms.set_mp_process_title('celeryd', hostname=hostname)
# This is for Windows and other platforms not supporting
# fork(). Note that init_worker makes sure it's only
# run once per process.
app.loader.init_worker()
app.loader.init_worker_process()
logfile = os.environ.get('CELERY_LOG_FILE') or None
if logfile and '%i' in logfile.lower():
# logfile path will differ so need to set up logging again.
app.log.already_setup = False
app.log.setup(int(os.environ.get('CELERY_LOG_LEVEL', 0) or 0),
logfile,
bool(os.environ.get('CELERY_LOG_REDIRECT', False)),
str(os.environ.get('CELERY_LOG_REDIRECT_LEVEL')),
hostname=hostname)
if os.environ.get('FORKED_BY_MULTIPROCESSING'):
# pool did execv after fork
trace.setup_worker_optimizations(app, hostname)
else:
app.set_current()
set_default_app(app)
app.finalize()
trace._tasks = app._tasks # enables fast_trace_task optimization.
# rebuild execution handler for all tasks.
from celery.app.trace import build_tracer
for name, task in app.tasks.items():
task.__trace__ = build_tracer(name, task, app.loader, hostname,
app=app)
from celery.worker import state as worker_state
worker_state.reset_state()
signals.worker_process_init.send(sender=None)
def process_destructor(pid, exitcode):
"""Pool child process destructor.
Dispatch the :signal:`worker_process_shutdown` signal.
"""
signals.worker_process_shutdown.send(
sender=None, pid=pid, exitcode=exitcode,
)
class TaskPool(BasePool):
"""Multiprocessing Pool implementation."""
Pool = AsynPool
BlockingPool = BlockingPool
uses_semaphore = True
write_stats = None
def on_start(self):
forking_enable(self.forking_enable)
Pool = (self.BlockingPool if self.options.get('threads', True)
else self.Pool)
proc_alive_timeout = (
self.app.conf.worker_proc_alive_timeout if self.app
else None
)
P = self._pool = Pool(processes=self.limit,
initializer=process_initializer,
on_process_exit=process_destructor,
enable_timeouts=True,
synack=False,
proc_alive_timeout=proc_alive_timeout,
**self.options)
# Create proxy methods
self.on_apply = P.apply_async
self.maintain_pool = P.maintain_pool
self.terminate_job = P.terminate_job
self.grow = P.grow
self.shrink = P.shrink
self.flush = getattr(P, 'flush', None) # FIXME add to billiard
def restart(self):
self._pool.restart()
self._pool.apply_async(noop)
def did_start_ok(self):
return self._pool.did_start_ok()
def register_with_event_loop(self, loop):
try:
reg = self._pool.register_with_event_loop
except AttributeError:
return
return reg(loop)
def on_stop(self):
"""Gracefully stop the pool."""
if self._pool is not None and self._pool._state in (RUN, CLOSE):
self._pool.close()
self._pool.join()
self._pool = None
def on_terminate(self):
"""Force terminate the pool."""
if self._pool is not None:
self._pool.terminate()
self._pool = None
def on_close(self):
if self._pool is not None and self._pool._state == RUN:
self._pool.close()
def _get_info(self):
write_stats = getattr(self._pool, 'human_write_stats', None)
info = super()._get_info()
info.update({
'max-concurrency': self.limit,
'processes': [p.pid for p in self._pool._pool],
'max-tasks-per-child': self._pool._maxtasksperchild or 'N/A',
'put-guarded-by-semaphore': self.putlocks,
'timeouts': (self._pool.soft_timeout or 0,
self._pool.timeout or 0),
'writes': write_stats() if write_stats is not None else 'N/A',
})
return info
@property
def num_processes(self):
return self._pool._processes

View File

@@ -0,0 +1,31 @@
"""Single-threaded execution pool."""
import os
from celery import signals
from .base import BasePool, apply_target
__all__ = ('TaskPool',)
class TaskPool(BasePool):
"""Solo task pool (blocking, inline, fast)."""
body_can_be_buffer = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.on_apply = apply_target
self.limit = 1
signals.worker_process_init.send(sender=None)
def _get_info(self):
info = super()._get_info()
info.update({
'max-concurrency': 1,
'processes': [os.getpid()],
'max-tasks-per-child': None,
'put-guarded-by-semaphore': True,
'timeouts': (),
})
return info

View File

@@ -0,0 +1,64 @@
"""Thread execution pool."""
from __future__ import annotations
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import TYPE_CHECKING, Any, Callable
from .base import BasePool, apply_target
__all__ = ('TaskPool',)
if TYPE_CHECKING:
from typing import TypedDict
PoolInfo = TypedDict('PoolInfo', {'max-concurrency': int, 'threads': int})
# `TargetFunction` should be a Protocol that represents fast_trace_task and
# trace_task_ret.
TargetFunction = Callable[..., Any]
class ApplyResult:
def __init__(self, future: Future) -> None:
self.f = future
self.get = self.f.result
def wait(self, timeout: float | None = None) -> None:
wait([self.f], timeout)
class TaskPool(BasePool):
"""Thread Task Pool."""
limit: int
body_can_be_buffer = True
signal_safe = False
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.executor = ThreadPoolExecutor(max_workers=self.limit)
def on_stop(self) -> None:
self.executor.shutdown()
super().on_stop()
def on_apply(
self,
target: TargetFunction,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
callback: Callable[..., Any] | None = None,
accept_callback: Callable[..., Any] | None = None,
**_: Any
) -> ApplyResult:
f = self.executor.submit(apply_target, target, args, kwargs,
callback, accept_callback)
return ApplyResult(f)
def _get_info(self) -> PoolInfo:
info = super()._get_info()
info.update({
'max-concurrency': self.limit,
'threads': len(self.executor._threads)
})
return info

View File

@@ -0,0 +1,165 @@
"""Abortable Tasks.
Abortable tasks overview
=========================
For long-running :class:`Task`'s, it can be desirable to support
aborting during execution. Of course, such tasks need to be designed
to support being aborted explicitly.
The :class:`AbortableTask` serves as a base class for all :class:`Task`
objects that should support abortion by producers.
* Producers may invoke the :meth:`abort` method on
:class:`AbortableAsyncResult` instances, to request abortion.
* Consumers (workers) should periodically check (and honor!) the
:meth:`is_aborted` method at controlled points in their task's
:meth:`run` method. The more often, the better.
The necessary intermediate communication is dealt with by the
:class:`AbortableTask` implementation.
Usage example
-------------
In the consumer:
.. code-block:: python
from celery.contrib.abortable import AbortableTask
from celery.utils.log import get_task_logger
from proj.celery import app
logger = get_task_logger(__name__)
@app.task(bind=True, base=AbortableTask)
def long_running_task(self):
results = []
for i in range(100):
# check after every 5 iterations...
# (or alternatively, check when some timer is due)
if not i % 5:
if self.is_aborted():
# respect aborted state, and terminate gracefully.
logger.warning('Task aborted')
return
value = do_something_expensive(i)
results.append(value)
logger.info('Task complete')
return results
In the producer:
.. code-block:: python
import time
from proj.tasks import long_running_task
def myview(request):
# result is of type AbortableAsyncResult
result = long_running_task.delay()
# abort the task after 10 seconds
time.sleep(10)
result.abort()
After the `result.abort()` call, the task execution isn't
aborted immediately. In fact, it's not guaranteed to abort at all.
Keep checking `result.state` status, or call `result.get(timeout=)` to
have it block until the task is finished.
.. note::
In order to abort tasks, there needs to be communication between the
producer and the consumer. This is currently implemented through the
database backend. Therefore, this class will only work with the
database backends.
"""
from celery import Task
from celery.result import AsyncResult
__all__ = ('AbortableAsyncResult', 'AbortableTask')
"""
Task States
-----------
.. state:: ABORTED
ABORTED
~~~~~~~
Task has been aborted (typically by the producer) and its execution
should terminate as soon as possible.
"""
ABORTED = 'ABORTED'
class AbortableAsyncResult(AsyncResult):
"""Represents an abortable result.
Specifically, this gives the `AsyncResult` a :meth:`abort()` method,
that sets the state of the underlying Task to `'ABORTED'`.
"""
def is_aborted(self):
"""Return :const:`True` if the task is (being) aborted."""
return self.state == ABORTED
def abort(self):
"""Set the state of the task to :const:`ABORTED`.
Abortable tasks monitor their state at regular intervals and
terminate execution if aborted.
Warning:
Be aware that invoking this method does not guarantee when the
task will be aborted (or even if the task will be aborted at all).
"""
# TODO: store_result requires all four arguments to be set,
# but only state should be updated here
return self.backend.store_result(self.id, result=None,
state=ABORTED, traceback=None)
class AbortableTask(Task):
"""Task that can be aborted.
This serves as a base class for all :class:`Task`'s
that support aborting during execution.
All subclasses of :class:`AbortableTask` must call the
:meth:`is_aborted` method periodically and act accordingly when
the call evaluates to :const:`True`.
"""
abstract = True
def AsyncResult(self, task_id):
"""Return the accompanying AbortableAsyncResult instance."""
return AbortableAsyncResult(task_id, backend=self.backend)
def is_aborted(self, **kwargs):
"""Return true if task is aborted.
Checks against the backend whether this
:class:`AbortableAsyncResult` is :const:`ABORTED`.
Always return :const:`False` in case the `task_id` parameter
refers to a regular (non-abortable) :class:`Task`.
Be aware that invoking this method will cause a hit in the
backend (for example a database query), so find a good balance
between calling it regularly (for responsiveness), but not too
often (for performance).
"""
task_id = kwargs.get('task_id', self.request.id)
result = self.AsyncResult(task_id)
if not isinstance(result, AbortableAsyncResult):
return False
return result.is_aborted()

View File

@@ -0,0 +1,416 @@
"""Message migration tools (Broker <-> Broker)."""
import socket
from functools import partial
from itertools import cycle, islice
from kombu import Queue, eventloop
from kombu.common import maybe_declare
from kombu.utils.encoding import ensure_bytes
from celery.app import app_or_default
from celery.utils.nodenames import worker_direct
from celery.utils.text import str_to_list
__all__ = (
'StopFiltering', 'State', 'republish', 'migrate_task',
'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
'start_filter', 'move_task_by_id', 'move_by_idmap',
'move_by_taskmap', 'move_direct', 'move_direct_by_id',
)
MOVING_PROGRESS_FMT = """\
Moving task {state.filtered}/{state.strtotal}: \
{body[task]}[{body[id]}]\
"""
class StopFiltering(Exception):
"""Semi-predicate used to signal filter stop."""
class State:
"""Migration progress state."""
count = 0
filtered = 0
total_apx = 0
@property
def strtotal(self):
if not self.total_apx:
return '?'
return str(self.total_apx)
def __repr__(self):
if self.filtered:
return f'^{self.filtered}'
return f'{self.count}/{self.strtotal}'
def republish(producer, message, exchange=None, routing_key=None,
remove_props=None):
"""Republish message."""
if not remove_props:
remove_props = ['application_headers', 'content_type',
'content_encoding', 'headers']
body = ensure_bytes(message.body) # use raw message body.
info, headers, props = (message.delivery_info,
message.headers, message.properties)
exchange = info['exchange'] if exchange is None else exchange
routing_key = info['routing_key'] if routing_key is None else routing_key
ctype, enc = message.content_type, message.content_encoding
# remove compression header, as this will be inserted again
# when the message is recompressed.
compression = headers.pop('compression', None)
expiration = props.pop('expiration', None)
# ensure expiration is a float
expiration = float(expiration) if expiration is not None else None
for key in remove_props:
props.pop(key, None)
producer.publish(ensure_bytes(body), exchange=exchange,
routing_key=routing_key, compression=compression,
headers=headers, content_type=ctype,
content_encoding=enc, expiration=expiration,
**props)
def migrate_task(producer, body_, message, queues=None):
"""Migrate single task message."""
info = message.delivery_info
queues = {} if queues is None else queues
republish(producer, message,
exchange=queues.get(info['exchange']),
routing_key=queues.get(info['routing_key']))
def filter_callback(callback, tasks):
def filtered(body, message):
if tasks and body['task'] not in tasks:
return
return callback(body, message)
return filtered
def migrate_tasks(source, dest, migrate=migrate_task, app=None,
queues=None, **kwargs):
"""Migrate tasks from one broker to another."""
app = app_or_default(app)
queues = prepare_queues(queues)
producer = app.amqp.Producer(dest, auto_declare=False)
migrate = partial(migrate, producer, queues=queues)
def on_declare_queue(queue):
new_queue = queue(producer.channel)
new_queue.name = queues.get(queue.name, queue.name)
if new_queue.routing_key == queue.name:
new_queue.routing_key = queues.get(queue.name,
new_queue.routing_key)
if new_queue.exchange.name == queue.name:
new_queue.exchange.name = queues.get(queue.name, queue.name)
new_queue.declare()
return start_filter(app, source, migrate, queues=queues,
on_declare_queue=on_declare_queue, **kwargs)
def _maybe_queue(app, q):
if isinstance(q, str):
return app.amqp.queues[q]
return q
def move(predicate, connection=None, exchange=None, routing_key=None,
source=None, app=None, callback=None, limit=None, transform=None,
**kwargs):
"""Find tasks by filtering them and move the tasks to a new queue.
Arguments:
predicate (Callable): Filter function used to decide the messages
to move. Must accept the standard signature of ``(body, message)``
used by Kombu consumer callbacks. If the predicate wants the
message to be moved it must return either:
1) a tuple of ``(exchange, routing_key)``, or
2) a :class:`~kombu.entity.Queue` instance, or
3) any other true value means the specified
``exchange`` and ``routing_key`` arguments will be used.
connection (kombu.Connection): Custom connection to use.
source (List[Union[str, kombu.Queue]]): Optional list of source
queues to use instead of the default (queues
in :setting:`task_queues`). This list can also contain
:class:`~kombu.entity.Queue` instances.
exchange (str, kombu.Exchange): Default destination exchange.
routing_key (str): Default destination routing key.
limit (int): Limit number of messages to filter.
callback (Callable): Callback called after message moved,
with signature ``(state, body, message)``.
transform (Callable): Optional function to transform the return
value (destination) of the filter function.
Also supports the same keyword arguments as :func:`start_filter`.
To demonstrate, the :func:`move_task_by_id` operation can be implemented
like this:
.. code-block:: python
def is_wanted_task(body, message):
if body['id'] == wanted_id:
return Queue('foo', exchange=Exchange('foo'),
routing_key='foo')
move(is_wanted_task)
or with a transform:
.. code-block:: python
def transform(value):
if isinstance(value, str):
return Queue(value, Exchange(value), value)
return value
move(is_wanted_task, transform=transform)
Note:
The predicate may also return a tuple of ``(exchange, routing_key)``
to specify the destination to where the task should be moved,
or a :class:`~kombu.entity.Queue` instance.
Any other true value means that the task will be moved to the
default exchange/routing_key.
"""
app = app_or_default(app)
queues = [_maybe_queue(app, queue) for queue in source or []] or None
with app.connection_or_acquire(connection, pool=False) as conn:
producer = app.amqp.Producer(conn)
state = State()
def on_task(body, message):
ret = predicate(body, message)
if ret:
if transform:
ret = transform(ret)
if isinstance(ret, Queue):
maybe_declare(ret, conn.default_channel)
ex, rk = ret.exchange.name, ret.routing_key
else:
ex, rk = expand_dest(ret, exchange, routing_key)
republish(producer, message,
exchange=ex, routing_key=rk)
message.ack()
state.filtered += 1
if callback:
callback(state, body, message)
if limit and state.filtered >= limit:
raise StopFiltering()
return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
def expand_dest(ret, exchange, routing_key):
try:
ex, rk = ret
except (TypeError, ValueError):
ex, rk = exchange, routing_key
return ex, rk
def task_id_eq(task_id, body, message):
"""Return true if task id equals task_id'."""
return body['id'] == task_id
def task_id_in(ids, body, message):
"""Return true if task id is member of set ids'."""
return body['id'] in ids
def prepare_queues(queues):
if isinstance(queues, str):
queues = queues.split(',')
if isinstance(queues, list):
queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
for q in queues)
if queues is None:
queues = {}
return queues
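# Illustration of ``prepare_queues`` (example input, not from the source):
#
#     prepare_queues('default:new_default,hipri')
#     # -> {'default': 'new_default', 'hipri': 'hipri'}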
class Filterer:
def __init__(self, app, conn, filter,
limit=None, timeout=1.0,
ack_messages=False, tasks=None, queues=None,
callback=None, forever=False, on_declare_queue=None,
consume_from=None, state=None, accept=None, **kwargs):
self.app = app
self.conn = conn
self.filter = filter
self.limit = limit
self.timeout = timeout
self.ack_messages = ack_messages
self.tasks = set(str_to_list(tasks) or [])
self.queues = prepare_queues(queues)
self.callback = callback
self.forever = forever
self.on_declare_queue = on_declare_queue
self.consume_from = [
_maybe_queue(self.app, q)
for q in consume_from or list(self.queues)
]
self.state = state or State()
self.accept = accept
def start(self):
# start migrating messages.
with self.prepare_consumer(self.create_consumer()):
try:
for _ in eventloop(self.conn, # pragma: no cover
timeout=self.timeout,
ignore_timeouts=self.forever):
pass
except socket.timeout:
pass
except StopFiltering:
pass
return self.state
def update_state(self, body, message):
self.state.count += 1
if self.limit and self.state.count >= self.limit:
raise StopFiltering()
def ack_message(self, body, message):
message.ack()
def create_consumer(self):
return self.app.amqp.TaskConsumer(
self.conn,
queues=self.consume_from,
accept=self.accept,
)
def prepare_consumer(self, consumer):
filter = self.filter
update_state = self.update_state
ack_message = self.ack_message
if self.tasks:
filter = filter_callback(filter, self.tasks)
update_state = filter_callback(update_state, self.tasks)
ack_message = filter_callback(ack_message, self.tasks)
consumer.register_callback(filter)
consumer.register_callback(update_state)
if self.ack_messages:
consumer.register_callback(self.ack_message)
if self.callback is not None:
callback = partial(self.callback, self.state)
if self.tasks:
callback = filter_callback(callback, self.tasks)
consumer.register_callback(callback)
self.declare_queues(consumer)
return consumer
def declare_queues(self, consumer):
# declare all queues on the new broker.
for queue in consumer.queues:
if self.queues and queue.name not in self.queues:
continue
if self.on_declare_queue is not None:
self.on_declare_queue(queue)
try:
_, mcount, _ = queue(
consumer.channel).queue_declare(passive=True)
if mcount:
self.state.total_apx += mcount
except self.conn.channel_errors:
pass
def start_filter(app, conn, filter, limit=None, timeout=1.0,
ack_messages=False, tasks=None, queues=None,
callback=None, forever=False, on_declare_queue=None,
consume_from=None, state=None, accept=None, **kwargs):
"""Filter tasks."""
return Filterer(
app, conn, filter,
limit=limit,
timeout=timeout,
ack_messages=ack_messages,
tasks=tasks,
queues=queues,
callback=callback,
forever=forever,
on_declare_queue=on_declare_queue,
consume_from=consume_from,
state=state,
accept=accept,
**kwargs).start()
def move_task_by_id(task_id, dest, **kwargs):
"""Find a task by id and move it to another queue.
Arguments:
task_id (str): Id of task to find and move.
dest (str, kombu.Queue): Destination queue.
transform (Callable): Optional function to transform the return
value (destination) of the filter function.
**kwargs (Any): Also supports the same keyword
arguments as :func:`move`.
"""
return move_by_idmap({task_id: dest}, **kwargs)
def move_by_idmap(map, **kwargs):
"""Move tasks by matching from a ``task_id: queue`` mapping.
Where ``queue`` is a queue to move the task to.
Example:
>>> move_by_idmap({
... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
... queues=['hipri'])
"""
def task_id_in_map(body, message):
return map.get(message.properties['correlation_id'])
# adding the limit means that we don't have to consume any more
# when we've found everything.
return move(task_id_in_map, limit=len(map), **kwargs)
def move_by_taskmap(map, **kwargs):
"""Move tasks by matching from a ``task_name: queue`` mapping.
``queue`` is the queue to move the task to.
Example:
>>> move_by_taskmap({
... 'tasks.add': Queue('name'),
... 'tasks.mul': Queue('name'),
... })
"""
def task_name_in_map(body, message):
return map.get(body['task']) # <- name of task
return move(task_name_in_map, **kwargs)
def filter_status(state, body, message, **kwargs):
print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs))
move_direct = partial(move, transform=worker_direct)
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)
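A minimal usage sketch for the helpers above; the task id and queue names are placeholders, and a running broker is assumed:
.. code-block:: python

    from kombu import Queue

    from celery.contrib.migrate import filter_status, move_task_by_id

    # Relocate a single task (placeholder id) to the 'hipri' queue,
    # printing a progress line for each task that gets moved.
    move_task_by_id(
        '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f',
        Queue('hipri'),
        source=['celery'],        # queue(s) to scan for the task
        callback=filter_status,   # print progress per moved task
    )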

View File

@@ -0,0 +1,216 @@
"""Fixtures and testing utilities for :pypi:`pytest <pytest>`."""
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union # noqa
import pytest
if TYPE_CHECKING:
from celery import Celery
from ..worker import WorkController
else:
Celery = WorkController = object
NO_WORKER = os.environ.get('NO_WORKER')
# pylint: disable=redefined-outer-name
# Well, they're called fixtures....
def pytest_configure(config):
"""Register additional pytest configuration."""
# add the pytest.mark.celery() marker registration to the pytest.ini [markers] section
# this prevents pytest 4.5 and newer from issuing a warning about an unknown marker
# and shows helpful marker documentation when running pytest --markers.
config.addinivalue_line(
"markers", "celery(**overrides): override celery configuration for a test case"
)
@contextmanager
def _create_app(enable_logging=False,
use_trap=False,
parameters=None,
**config):
# type: (Any, Any, Any, **Any) -> Celery
"""Utility context used to setup Celery app for pytest fixtures."""
from .testing.app import TestApp, setup_default_app
parameters = {} if not parameters else parameters
test_app = TestApp(
set_as_current=False,
enable_logging=enable_logging,
config=config,
**parameters
)
with setup_default_app(test_app, use_trap=use_trap):
yield test_app
@pytest.fixture(scope='session')
def use_celery_app_trap():
# type: () -> bool
"""You can override this fixture to enable the app trap.
The app trap raises an exception whenever something attempts
to use the current or default apps.
"""
return False
@pytest.fixture(scope='session')
def celery_session_app(request,
celery_config,
celery_parameters,
celery_enable_logging,
use_celery_app_trap):
# type: (Any, Any, Any, Any, Any) -> Celery
"""Session Fixture: Return app for session fixtures."""
mark = request.node.get_closest_marker('celery')
config = dict(celery_config, **mark.kwargs if mark else {})
with _create_app(enable_logging=celery_enable_logging,
use_trap=use_celery_app_trap,
parameters=celery_parameters,
**config) as app:
if not use_celery_app_trap:
app.set_default()
app.set_current()
yield app
@pytest.fixture(scope='session')
def celery_session_worker(
request, # type: Any
celery_session_app, # type: Celery
celery_includes, # type: Sequence[str]
celery_class_tasks, # type: str
celery_worker_pool, # type: Any
celery_worker_parameters, # type: Mapping[str, Any]
):
# type: (...) -> WorkController
"""Session Fixture: Start worker that lives throughout test suite."""
from .testing import worker
if not NO_WORKER:
for module in celery_includes:
celery_session_app.loader.import_task_module(module)
for class_task in celery_class_tasks:
celery_session_app.register_task(class_task)
with worker.start_worker(celery_session_app,
pool=celery_worker_pool,
**celery_worker_parameters) as w:
yield w
@pytest.fixture(scope='session')
def celery_enable_logging():
# type: () -> bool
"""You can override this fixture to enable logging."""
return False
@pytest.fixture(scope='session')
def celery_includes():
# type: () -> Sequence[str]
"""You can override this include modules when a worker start.
You can have this return a list of module names to import,
these can be task modules, modules registering signals, and so on.
"""
return ()
@pytest.fixture(scope='session')
def celery_worker_pool():
# type: () -> Union[str, Any]
"""You can override this fixture to set the worker pool.
The "solo" pool is used by default, but you can set this to
return e.g. "prefork".
"""
return 'solo'
@pytest.fixture(scope='session')
def celery_config():
# type: () -> Mapping[str, Any]
"""Redefine this fixture to configure the test Celery app.
The config returned by your fixture will then be used
to configure the :func:`celery_app` fixture.
"""
return {}
@pytest.fixture(scope='session')
def celery_parameters():
# type: () -> Mapping[str, Any]
"""Redefine this fixture to change the init parameters of test Celery app.
The dict returned by your fixture will then be used
as parameters when instantiating :class:`~celery.Celery`.
"""
return {}
@pytest.fixture(scope='session')
def celery_worker_parameters():
# type: () -> Mapping[str, Any]
"""Redefine this fixture to change the init parameters of Celery workers.
This can be used, e.g., to define the queues the worker will consume tasks from.
The dict returned by your fixture will then be used
as parameters when instantiating :class:`~celery.worker.WorkController`.
"""
return {}
@pytest.fixture()
def celery_app(request,
celery_config,
celery_parameters,
celery_enable_logging,
use_celery_app_trap):
"""Fixture creating a Celery application instance."""
mark = request.node.get_closest_marker('celery')
config = dict(celery_config, **mark.kwargs if mark else {})
with _create_app(enable_logging=celery_enable_logging,
use_trap=use_celery_app_trap,
parameters=celery_parameters,
**config) as app:
yield app
@pytest.fixture(scope='session')
def celery_class_tasks():
"""Redefine this fixture to register tasks with the test Celery app."""
return []
@pytest.fixture()
def celery_worker(request,
celery_app,
celery_includes,
celery_worker_pool,
celery_worker_parameters):
# type: (Any, Celery, Sequence[str], str, Any) -> WorkController
"""Fixture: Start worker in a thread, stop it when the test returns."""
from .testing import worker
if not NO_WORKER:
for module in celery_includes:
celery_app.loader.import_task_module(module)
with worker.start_worker(celery_app,
pool=celery_worker_pool,
**celery_worker_parameters) as w:
yield w
@pytest.fixture()
def depends_on_current_app(celery_app):
"""Fixture that sets app as current."""
celery_app.set_current()
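A minimal sketch of overriding these fixtures from a project's ``conftest.py``; the module name in ``celery_includes`` is a placeholder:
.. code-block:: python

    import pytest

    @pytest.fixture(scope='session')
    def celery_config():
        return {
            'broker_url': 'memory://',
            'result_backend': 'cache+memory://',
        }

    @pytest.fixture(scope='session')
    def celery_includes():
        # Placeholder module containing the task definitions to import.
        return ['myproj.tasks']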

View File

@@ -0,0 +1,187 @@
"""Remote Debugger.
Introduction
============
This is a remote debugger for Celery tasks running in multiprocessing
pool workers. Inspired by a lost post on dzone.com.
Usage
-----
.. code-block:: python
from celery.contrib import rdb
from celery import task
@task()
def add(x, y):
result = x + y
rdb.set_trace()
return result
Environment Variables
=====================
.. envvar:: CELERY_RDB_HOST
``CELERY_RDB_HOST``
-------------------
Hostname to bind to. Default is '127.0.0.1' (only accessible from
localhost).
.. envvar:: CELERY_RDB_PORT
``CELERY_RDB_PORT``
-------------------
Base port to bind to. Default is 6899.
The debugger will try to find an available port starting from the
base port. The selected port will be logged by the worker.
"""
import errno
import os
import socket
import sys
from pdb import Pdb
from billiard.process import current_process
__all__ = (
'CELERY_RDB_HOST', 'CELERY_RDB_PORT', 'DEFAULT_PORT',
'Rdb', 'debugger', 'set_trace',
)
DEFAULT_PORT = 6899
CELERY_RDB_HOST = os.environ.get('CELERY_RDB_HOST') or '127.0.0.1'
CELERY_RDB_PORT = int(os.environ.get('CELERY_RDB_PORT') or DEFAULT_PORT)
#: Holds the currently active debugger.
_current = [None]
_frame = getattr(sys, '_getframe')
NO_AVAILABLE_PORT = """\
{self.ident}: Couldn't find an available port.
Please specify one using the CELERY_RDB_PORT environment variable.
"""
BANNER = """\
{self.ident}: Ready to connect: telnet {self.host} {self.port}
Type `exit` in session to continue.
{self.ident}: Waiting for client...
"""
SESSION_STARTED = '{self.ident}: Now in session with {self.remote_addr}.'
SESSION_ENDED = '{self.ident}: Session with {self.remote_addr} ended.'
class Rdb(Pdb):
"""Remote debugger."""
me = 'Remote Debugger'
_prev_outs = None
_sock = None
def __init__(self, host=CELERY_RDB_HOST, port=CELERY_RDB_PORT,
port_search_limit=100, port_skew=+0, out=sys.stdout):
self.active = True
self.out = out
self._prev_handles = sys.stdin, sys.stdout
self._sock, this_port = self.get_avail_port(
host, port, port_search_limit, port_skew,
)
self._sock.setblocking(1)
self._sock.listen(1)
self.ident = f'{self.me}:{this_port}'
self.host = host
self.port = this_port
self.say(BANNER.format(self=self))
self._client, address = self._sock.accept()
self._client.setblocking(1)
self.remote_addr = ':'.join(str(v) for v in address)
self.say(SESSION_STARTED.format(self=self))
self._handle = sys.stdin = sys.stdout = self._client.makefile('rw')
super().__init__(completekey='tab',
stdin=self._handle, stdout=self._handle)
def get_avail_port(self, host, port, search_limit=100, skew=+0):
try:
_, skew = current_process().name.split('-')
skew = int(skew)
except ValueError:
pass
this_port = None
for i in range(search_limit):
_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
this_port = port + skew + i
try:
_sock.bind((host, this_port))
except OSError as exc:
if exc.errno in [errno.EADDRINUSE, errno.EINVAL]:
continue
raise
else:
return _sock, this_port
raise Exception(NO_AVAILABLE_PORT.format(self=self))
def say(self, m):
print(m, file=self.out)
def __enter__(self):
return self
def __exit__(self, *exc_info):
self._close_session()
def _close_session(self):
self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
if self.active:
if self._handle is not None:
self._handle.close()
if self._client is not None:
self._client.close()
if self._sock is not None:
self._sock.close()
self.active = False
self.say(SESSION_ENDED.format(self=self))
def do_continue(self, arg):
self._close_session()
self.set_continue()
return 1
do_c = do_cont = do_continue
def do_quit(self, arg):
self._close_session()
self.set_quit()
return 1
do_q = do_exit = do_quit
def set_quit(self):
# this raises a BdbQuit exception that we're unable to catch.
sys.settrace(None)
def debugger():
"""Return the current debugger instance, or create if none."""
rdb = _current[0]
if rdb is None or not rdb.active:
rdb = _current[0] = Rdb()
return rdb
def set_trace(frame=None):
"""Set break-point at current location, or a specified frame."""
if frame is None:
frame = _frame().f_back
return debugger().set_trace(frame)
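A short sketch of picking a different bind address and port through the environment variables documented above; binding to all interfaces is only reasonable on a trusted network:
.. code-block:: python

    import os

    # Must be set in the environment inherited by the worker (and its pool).
    os.environ.setdefault('CELERY_RDB_HOST', '0.0.0.0')
    os.environ.setdefault('CELERY_RDB_PORT', '6900')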

View File

@@ -0,0 +1,105 @@
"""Sphinx documentation plugin used to document tasks.
Introduction
============
Usage
-----
The Celery extension for Sphinx requires Sphinx 2.0 or later.
Add the extension to your :file:`docs/conf.py` configuration module:
.. code-block:: python
extensions = (...,
'celery.contrib.sphinx')
If you'd like to change the prefix for tasks in reference documentation
then you can change the ``celery_task_prefix`` configuration value:
.. code-block:: python
celery_task_prefix = '(task)' # < default
With the extension installed, `autodoc` will automatically find
task-decorated objects (e.g. when using the automodule directive)
and generate the correct documentation (as well as add a ``(task)`` prefix),
and you can also refer to the tasks using `:task:proj.tasks.add`
syntax.
Use ``.. autotask::`` to manually document a task instead.
"""
from inspect import signature
from docutils import nodes
from sphinx.domains.python import PyFunction
from sphinx.ext.autodoc import FunctionDocumenter
from celery.app.task import BaseTask
class TaskDocumenter(FunctionDocumenter):
"""Document task definitions."""
objtype = 'task'
member_order = 11
@classmethod
def can_document_member(cls, member, membername, isattr, parent):
return isinstance(member, BaseTask) and getattr(member, '__wrapped__')
def format_args(self):
wrapped = getattr(self.object, '__wrapped__', None)
if wrapped is not None:
sig = signature(wrapped)
if "self" in sig.parameters or "cls" in sig.parameters:
sig = sig.replace(parameters=list(sig.parameters.values())[1:])
return str(sig)
return ''
def document_members(self, all_members=False):
pass
def check_module(self):
# Normally checks if *self.object* is really defined in the module
# given by *self.modname*. But since functions decorated with the @task
# decorator are instances living in the celery.local, we have to check
# the wrapped function instead.
wrapped = getattr(self.object, '__wrapped__', None)
if wrapped and getattr(wrapped, '__module__') == self.modname:
return True
return super().check_module()
class TaskDirective(PyFunction):
"""Sphinx task directive."""
def get_signature_prefix(self, sig):
return [nodes.Text(self.env.config.celery_task_prefix)]
def autodoc_skip_member_handler(app, what, name, obj, skip, options):
"""Handler for autodoc-skip-member event."""
# Celery tasks created with the @task decorator have the property
# that *obj.__doc__* and *obj.__class__.__doc__* are equal, which
# trips up the logic in sphinx.ext.autodoc that is supposed to
# suppress repetition of class documentation in an instance of the
# class. This overrides that behavior.
if isinstance(obj, BaseTask) and getattr(obj, '__wrapped__'):
if skip:
return False
return None
def setup(app):
"""Setup Sphinx extension."""
app.setup_extension('sphinx.ext.autodoc')
app.add_autodocumenter(TaskDocumenter)
app.add_directive_to_domain('py', 'task', TaskDirective)
app.add_config_value('celery_task_prefix', '(task)', True)
app.connect('autodoc-skip-member', autodoc_skip_member_handler)
return {
'parallel_read_safe': True
}
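A minimal ``docs/conf.py`` sketch for enabling the extension; the prefix override is optional and shown with its default value:
.. code-block:: python

    extensions = [
        'sphinx.ext.autodoc',
        'celery.contrib.sphinx',
    ]

    # Optional: change the prefix shown in front of documented tasks.
    celery_task_prefix = '(task)'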

View File

@@ -0,0 +1,112 @@
"""Create Celery app instances used for testing."""
import weakref
from contextlib import contextmanager
from copy import deepcopy
from kombu.utils.imports import symbol_by_name
from celery import Celery, _state
#: Contains the default configuration values for the test app.
DEFAULT_TEST_CONFIG = {
'worker_hijack_root_logger': False,
'worker_log_color': False,
'accept_content': {'json'},
'enable_utc': True,
'timezone': 'UTC',
'broker_url': 'memory://',
'result_backend': 'cache+memory://',
'broker_heartbeat': 0,
}
class Trap:
"""Trap that pretends to be an app but raises an exception instead.
This is to protect against code that does not properly pass app instances,
and instead falls back to the current_app.
"""
def __getattr__(self, name):
# Workaround to allow unittest.mock to patch this object
# in Python 3.8 and above.
if name == '_is_coroutine' or name == '__func__':
return None
print(name)
raise RuntimeError('Test depends on current_app')
class UnitLogging(symbol_by_name(Celery.log_cls)):
"""Sets up logging for the test application."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.already_setup = True
def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
log=UnitLogging, backend=None, broker=None, **kwargs):
"""App used for testing."""
from . import tasks # noqa
config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
if broker is not None:
config.pop('broker_url', None)
if backend is not None:
config.pop('result_backend', None)
log = None if enable_logging else log
test_app = Celery(
name or 'celery.tests',
set_as_current=set_as_current,
log=log,
broker=broker,
backend=backend,
**kwargs)
test_app.add_defaults(config)
return test_app
@contextmanager
def set_trap(app):
"""Contextmanager that installs the trap app.
The trap means that anything trying to use the current or default app
will raise an exception.
"""
trap = Trap()
prev_tls = _state._tls
_state.set_default_app(trap)
class NonTLS:
current_app = trap
_state._tls = NonTLS()
try:
yield
finally:
_state._tls = prev_tls
@contextmanager
def setup_default_app(app, use_trap=False):
"""Setup default app for testing.
Ensures state is clean after the test returns.
"""
prev_current_app = _state.get_current_app()
prev_default_app = _state.default_app
prev_finalizers = set(_state._on_app_finalizers)
prev_apps = weakref.WeakSet(_state._apps)
try:
if use_trap:
with set_trap(app):
yield
else:
yield
finally:
_state.set_default_app(prev_default_app)
_state._tls.current_app = prev_current_app
if app is not prev_current_app:
app.close()
_state._on_app_finalizers = prev_finalizers
_state._apps = prev_apps
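A minimal sketch of using ``TestApp`` and ``setup_default_app`` directly, outside the pytest fixtures; the configuration value is an example only:
.. code-block:: python

    from celery.contrib.testing.app import TestApp, setup_default_app

    app = TestApp(config={'task_always_eager': True})
    with setup_default_app(app, use_trap=True):
        # Code under test runs here; anything that falls back to
        # current_app will raise RuntimeError because of the trap.
        ...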

View File

@@ -0,0 +1,239 @@
"""Integration testing utilities."""
import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count
from typing import Any, Callable, Dict, Sequence, TextIO, Tuple # noqa
from kombu.exceptions import ContentDisallowed
from kombu.utils.functional import retry_over_time
from celery import states
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, ResultSet # noqa
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds
E_STILL_WAITING = 'Still waiting for {0}. Trying again {when}: {exc!r}'
humanize_seconds = partial(_humanize_seconds, microseconds=True)
class Sentinel(Exception):
"""Signifies the end of something."""
class ManagerMixin:
"""Mixin that adds :class:`Manager` capabilities."""
def _init_manager(self,
block_timeout=30 * 60.0, no_join=False,
stdout=None, stderr=None):
# type: (float, bool, TextIO, TextIO) -> None
self.stdout = sys.stdout if stdout is None else stdout
self.stderr = sys.stderr if stderr is None else stderr
self.connerrors = self.app.connection().recoverable_connection_errors
self.block_timeout = block_timeout
self.no_join = no_join
def remark(self, s, sep='-'):
# type: (str, str) -> None
print(f'{sep}{s}', file=self.stdout)
def missing_results(self, r):
# type: (Sequence[AsyncResult]) -> Sequence[str]
return [res.id for res in r if res.id not in res.backend._cache]
def wait_for(
self,
fun, # type: Callable
catch, # type: Sequence[Any]
desc="thing", # type: str
args=(), # type: Tuple
kwargs=None, # type: Dict
errback=None, # type: Callable
max_retries=10, # type: int
interval_start=0.1, # type: float
interval_step=0.5, # type: float
interval_max=5.0, # type: float
emit_warning=False, # type: bool
**options # type: Any
):
# type: (...) -> Any
"""Wait for event to happen.
The `catch` argument specifies the exception that means the event
has not happened yet.
"""
kwargs = {} if not kwargs else kwargs
def on_error(exc, intervals, retries):
interval = next(intervals)
if emit_warning:
self.warn(E_STILL_WAITING.format(
desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
))
if errback:
errback(exc, interval, retries)
return interval
return self.retry_over_time(
fun, catch,
args=args, kwargs=kwargs,
errback=on_error, max_retries=max_retries,
interval_start=interval_start, interval_step=interval_step,
**options
)
def ensure_not_for_a_while(self, fun, catch,
desc='thing', max_retries=20,
interval_start=0.1, interval_step=0.02,
interval_max=1.0, emit_warning=False,
**options):
"""Make sure something does not happen (at least for a while)."""
try:
return self.wait_for(
fun, catch, desc=desc, max_retries=max_retries,
interval_start=interval_start, interval_step=interval_step,
interval_max=interval_max, emit_warning=emit_warning,
)
except catch:
pass
else:
raise AssertionError(f'Should not have happened: {desc}')
def retry_over_time(self, *args, **kwargs):
return retry_over_time(*args, **kwargs)
def join(self, r, propagate=False, max_retries=10, **kwargs):
if self.no_join:
return
if not isinstance(r, ResultSet):
r = self.app.ResultSet([r])
received = []
def on_result(task_id, value):
received.append(task_id)
for i in range(max_retries) if max_retries else count(0):
received[:] = []
try:
return r.get(callback=on_result, propagate=propagate, **kwargs)
except (socket.timeout, TimeoutError) as exc:
waiting_for = self.missing_results(r)
self.remark(
'Still waiting for {}/{}: [{}]: {!r}'.format(
len(r) - len(received), len(r),
truncate(', '.join(waiting_for)), exc), '!',
)
except self.connerrors as exc:
self.remark(f'join: connection lost: {exc!r}', '!')
raise AssertionError('Test failed: Missing task results')
def inspect(self, timeout=3.0):
return self.app.control.inspect(timeout=timeout)
def query_tasks(self, ids, timeout=0.5):
tasks = self.inspect(timeout).query_task(*ids) or {}
yield from tasks.items()
def query_task_states(self, ids, timeout=0.5):
states = defaultdict(set)
for hostname, reply in self.query_tasks(ids, timeout=timeout):
for task_id, (state, _) in reply.items():
states[state].add(task_id)
return states
def assert_accepted(self, ids, interval=0.5,
desc='waiting for tasks to be accepted', **policy):
return self.assert_task_worker_state(
self.is_accepted, ids, interval=interval, desc=desc, **policy
)
def assert_received(self, ids, interval=0.5,
desc='waiting for tasks to be received', **policy):
return self.assert_task_worker_state(
self.is_received, ids, interval=interval, desc=desc, **policy
)
def assert_result_tasks_in_progress_or_completed(
self,
async_results,
interval=0.5,
desc='waiting for tasks to be started or completed',
**policy
):
return self.assert_task_state_from_result(
self.is_result_task_in_progress,
async_results,
interval=interval, desc=desc, **policy
)
def assert_task_state_from_result(self, fun, results,
interval=0.5, **policy):
return self.wait_for(
partial(self.true_or_raise, fun, results, timeout=interval),
(Sentinel,), **policy
)
@staticmethod
def is_result_task_in_progress(results, **kwargs):
possible_states = (states.STARTED, states.SUCCESS)
return all(result.state in possible_states for result in results)
def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
return self.wait_for(
partial(self.true_or_raise, fun, ids, timeout=interval),
(Sentinel,), **policy
)
def is_received(self, ids, **kwargs):
return self._ids_matches_state(
['reserved', 'active', 'ready'], ids, **kwargs)
def is_accepted(self, ids, **kwargs):
return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
def _ids_matches_state(self, expected_states, ids, timeout=0.5):
states = self.query_task_states(ids, timeout=timeout)
return all(
any(t in s for s in [states[k] for k in expected_states])
for t in ids
)
def true_or_raise(self, fun, *args, **kwargs):
res = fun(*args, **kwargs)
if not res:
raise Sentinel()
return res
def wait_until_idle(self):
control = self.app.control
with self.app.connection() as connection:
# Try to purge the queue before we start
# to attempt to avoid interference from other tests
while True:
count = control.purge(connection=connection)
if count == 0:
break
# Wait until worker is idle
inspect = control.inspect()
inspect.connection = connection
while True:
try:
count = sum(len(t) for t in inspect.active().values())
except ContentDisallowed:
# test_security_task_done may trigger this exception
break
if count == 0:
break
class Manager(ManagerMixin):
"""Test helpers for task integration tests."""
def __init__(self, app, **kwargs):
self.app = app
self._init_manager(**kwargs)
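A minimal sketch of driving the ``Manager`` helpers from an integration test; ``app`` and ``some_task`` are placeholders for a configured app and one of its registered tasks:
.. code-block:: python

    from celery.contrib.testing.manager import Manager

    manager = Manager(app)
    result = some_task.delay()
    manager.assert_accepted([result.id])   # wait until a worker accepts it
    manager.join(result, timeout=10)       # wait for the result to arrive
    manager.wait_until_idle()              # drain queues before the next test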

View File

@@ -0,0 +1,137 @@
"""Useful mocks for unit testing."""
import numbers
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence # noqa
from unittest.mock import Mock
from celery import Celery # noqa
from celery.canvas import Signature # noqa
def TaskMessage(
name, # type: str
id=None, # type: str
args=(), # type: Sequence
kwargs=None, # type: Mapping
callbacks=None, # type: Sequence[Signature]
errbacks=None, # type: Sequence[Signature]
chain=None, # type: Sequence[Signature]
shadow=None, # type: str
utc=None, # type: bool
**options # type: Any
):
# type: (...) -> Any
"""Create task message in protocol 2 format."""
kwargs = {} if not kwargs else kwargs
from kombu.serialization import dumps
from celery import uuid
id = id or uuid()
message = Mock(name=f'TaskMessage-{id}')
message.headers = {
'id': id,
'task': name,
'shadow': shadow,
}
embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
message.headers.update(options)
message.content_type, message.content_encoding, message.body = dumps(
(args, kwargs, embed), serializer='json',
)
message.payload = (args, kwargs, embed)
return message
def TaskMessage1(
name, # type: str
id=None, # type: str
args=(), # type: Sequence
kwargs=None, # type: Mapping
callbacks=None, # type: Sequence[Signature]
errbacks=None, # type: Sequence[Signature]
chain=None, # type: Sequence[Signature]
**options # type: Any
):
# type: (...) -> Any
"""Create task message in protocol 1 format."""
kwargs = {} if not kwargs else kwargs
from kombu.serialization import dumps
from celery import uuid
id = id or uuid()
message = Mock(name=f'TaskMessage-{id}')
message.headers = {}
message.payload = {
'task': name,
'id': id,
'args': args,
'kwargs': kwargs,
'callbacks': callbacks,
'errbacks': errbacks,
}
message.payload.update(options)
message.content_type, message.content_encoding, message.body = dumps(
message.payload,
)
return message
def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
# type: (Celery, Signature, bool, Any) -> Any
"""Create task message from :class:`celery.Signature`.
Example:
>>> m = task_message_from_sig(app, add.s(2, 2))
>>> amqp_client.basic_publish(m, exchange='ex', routing_key='rkey')
"""
sig.freeze()
callbacks = sig.options.pop('link', None)
errbacks = sig.options.pop('link_error', None)
countdown = sig.options.pop('countdown', None)
if countdown:
eta = app.now() + timedelta(seconds=countdown)
else:
eta = sig.options.pop('eta', None)
if eta and isinstance(eta, datetime):
eta = eta.isoformat()
expires = sig.options.pop('expires', None)
if expires and isinstance(expires, numbers.Real):
expires = app.now() + timedelta(seconds=expires)
if expires and isinstance(expires, datetime):
expires = expires.isoformat()
return TaskMessage(
sig.task, id=sig.id, args=sig.args,
kwargs=sig.kwargs,
callbacks=[dict(s) for s in callbacks] if callbacks else None,
errbacks=[dict(s) for s in errbacks] if errbacks else None,
eta=eta,
expires=expires,
utc=utc,
**sig.options
)
class _ContextMock(Mock):
"""Dummy class implementing __enter__ and __exit__.
The :keyword:`with` statement requires these to be implemented
in the class, not just the instance.
"""
def __enter__(self):
return self
def __exit__(self, *exc_info):
pass
def ContextMock(*args, **kwargs):
"""Mock that mocks :keyword:`with` statement contexts."""
obj = _ContextMock(*args, **kwargs)
obj.attach_mock(_ContextMock(), '__enter__')
obj.attach_mock(_ContextMock(), '__exit__')
obj.__enter__.return_value = obj
# if __exit__ return a value the exception is ignored,
# so it must return None here.
obj.__exit__.return_value = None
return obj
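A short sketch of ``ContextMock`` in use; the ``publish`` call is purely illustrative:
.. code-block:: python

    from celery.contrib.testing.mocks import ContextMock

    producer = ContextMock(name='producer')
    with producer as p:
        p.publish('hello')

    # Calls made inside the with-block are recorded as usual.
    producer.publish.assert_called_once_with('hello')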

View File

@@ -0,0 +1,9 @@
"""Helper tasks for integration tests."""
from celery import shared_task
@shared_task(name='celery.ping')
def ping():
# type: () -> str
"""Simple task that just returns 'pong'."""
return 'pong'

View File

@@ -0,0 +1,221 @@
"""Embedded workers for integration tests."""
import logging
import logging.handlers  # QueueHandler/QueueListener are used by TestWorkController
import os
import threading
from contextlib import contextmanager
from typing import Any, Iterable, Union # noqa
import celery.worker.consumer # noqa
from celery import Celery, worker # noqa
from celery.result import _set_task_join_will_block, allow_join_result
from celery.utils.dispatch import Signal
from celery.utils.nodenames import anon_nodename
WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
test_worker_starting = Signal(
name='test_worker_starting',
providing_args={},
)
test_worker_started = Signal(
name='test_worker_started',
providing_args={'worker', 'consumer'},
)
test_worker_stopped = Signal(
name='test_worker_stopped',
providing_args={'worker'},
)
class TestWorkController(worker.WorkController):
"""Worker that can synchronize on being fully started."""
logger_queue = None
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
self._on_started = threading.Event()
super().__init__(*args, **kwargs)
if self.pool_cls.__module__.split('.')[-1] == 'prefork':
from billiard import Queue
self.logger_queue = Queue()
self.pid = os.getpid()
try:
from tblib import pickling_support
pickling_support.install()
except ImportError:
pass
# collect logs from forked process.
# XXX: those logs will appear twice in the live log
self.queue_listener = logging.handlers.QueueListener(self.logger_queue, logging.getLogger())
self.queue_listener.start()
class QueueHandler(logging.handlers.QueueHandler):
def prepare(self, record):
record.from_queue = True
# Keep origin record.
return record
def handleError(self, record):
if logging.raiseExceptions:
raise
def start(self):
if self.logger_queue:
handler = self.QueueHandler(self.logger_queue)
handler.addFilter(lambda r: r.process != self.pid and not getattr(r, 'from_queue', False))
logger = logging.getLogger()
logger.addHandler(handler)
return super().start()
def on_consumer_ready(self, consumer):
# type: (celery.worker.consumer.Consumer) -> None
"""Callback called when the Consumer blueprint is fully started."""
self._on_started.set()
test_worker_started.send(
sender=self.app, worker=self, consumer=consumer)
def ensure_started(self):
# type: () -> None
"""Wait for worker to be fully up and running.
Warning:
Worker must be started within a thread for this to work,
or it will block forever.
"""
self._on_started.wait()
@contextmanager
def start_worker(
app, # type: Celery
concurrency=1, # type: int
pool='solo', # type: str
loglevel=WORKER_LOGLEVEL, # type: Union[str, int]
logfile=None, # type: str
perform_ping_check=True, # type: bool
ping_task_timeout=10.0, # type: float
shutdown_timeout=10.0, # type: float
**kwargs # type: Any
):
# type: (...) -> Iterable
"""Start embedded worker.
Yields:
celery.app.worker.Worker: worker instance.
"""
test_worker_starting.send(sender=app)
worker = None
try:
with _start_worker_thread(app,
concurrency=concurrency,
pool=pool,
loglevel=loglevel,
logfile=logfile,
perform_ping_check=perform_ping_check,
shutdown_timeout=shutdown_timeout,
**kwargs) as worker:
if perform_ping_check:
from .tasks import ping
with allow_join_result():
assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
yield worker
finally:
test_worker_stopped.send(sender=app, worker=worker)
@contextmanager
def _start_worker_thread(app,
concurrency=1,
pool='solo',
loglevel=WORKER_LOGLEVEL,
logfile=None,
WorkController=TestWorkController,
perform_ping_check=True,
shutdown_timeout=10.0,
**kwargs):
# type: (Celery, int, str, Union[str, int], str, Any, **Any) -> Iterable
"""Start Celery worker in a thread.
Yields:
celery.worker.Worker: worker instance.
"""
setup_app_for_worker(app, loglevel, logfile)
if perform_ping_check:
assert 'celery.ping' in app.tasks
# Make sure we can connect to the broker
with app.connection(hostname=os.environ.get('TEST_BROKER')) as conn:
conn.default_channel.queue_declare
worker = WorkController(
app=app,
concurrency=concurrency,
hostname=anon_nodename(),
pool=pool,
loglevel=loglevel,
logfile=logfile,
# not allowed to override TestWorkController.on_consumer_ready
ready_callback=None,
without_heartbeat=kwargs.pop("without_heartbeat", True),
without_mingle=True,
without_gossip=True,
**kwargs)
t = threading.Thread(target=worker.start, daemon=True)
t.start()
worker.ensure_started()
_set_task_join_will_block(False)
try:
yield worker
finally:
from celery.worker import state
state.should_terminate = 0
t.join(shutdown_timeout)
if t.is_alive():
raise RuntimeError(
"Worker thread failed to exit within the allocated timeout. "
"Consider raising `shutdown_timeout` if your tasks take longer "
"to execute."
)
state.should_terminate = None
@contextmanager
def _start_worker_process(app,
concurrency=1,
pool='solo',
loglevel=WORKER_LOGLEVEL,
logfile=None,
**kwargs):
# type (Celery, int, str, Union[int, str], str, **Any) -> Iterable
"""Start worker in separate process.
Yields:
celery.app.worker.Worker: worker instance.
"""
from celery.apps.multi import Cluster, Node
app.set_current()
cluster = Cluster([Node('testworker1@%h')])
cluster.start()
try:
yield
finally:
cluster.stopwait()
def setup_app_for_worker(app, loglevel, logfile) -> None:
# type: (Celery, Union[str, int], str) -> None
"""Setup the app to be used for starting an embedded worker."""
app.finalize()
app.set_current()
app.set_default()
type(app.log)._setup = False
app.log.setup(loglevel=loglevel, logfile=logfile)
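A minimal sketch of embedding a worker directly with ``start_worker``; ``app`` and ``add`` are placeholders for a configured app and one of its tasks, and a result backend is assumed:
.. code-block:: python

    from celery.contrib.testing.worker import start_worker

    with start_worker(app, perform_ping_check=False):
        assert add.delay(2, 2).get(timeout=10) == 4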

View File

@@ -0,0 +1,15 @@
"""Monitoring Event Receiver+Dispatcher.
Events is a stream of messages sent for certain actions occurring
in the worker (and clients if :setting:`task_send_sent_event`
is enabled), used for monitoring purposes.
"""
from .dispatcher import EventDispatcher
from .event import Event, event_exchange, get_exchange, group_from
from .receiver import EventReceiver
__all__ = (
'Event', 'EventDispatcher', 'EventReceiver',
'event_exchange', 'get_exchange', 'group_from',
)

View File

@@ -0,0 +1,534 @@
"""Graphical monitor of Celery events using curses."""
import curses
import sys
import threading
from datetime import datetime
from itertools import count
from math import ceil
from textwrap import wrap
from time import time
from celery import VERSION_BANNER, states
from celery.app import app_or_default
from celery.utils.text import abbr, abbrtask
__all__ = ('CursesMonitor', 'evtop')
BORDER_SPACING = 4
LEFT_BORDER_OFFSET = 3
UUID_WIDTH = 36
STATE_WIDTH = 8
TIMESTAMP_WIDTH = 8
MIN_WORKER_WIDTH = 15
MIN_TASK_WIDTH = 16
# this module is considered experimental
# we don't care about coverage.
STATUS_SCREEN = """\
events: {s.event_count} tasks:{s.task_count} workers:{w_alive}/{w_all}
"""
class CursesMonitor: # pragma: no cover
"""A curses based Celery task monitor."""
keymap = {}
win = None
screen_delay = 10
selected_task = None
selected_position = 0
selected_str = 'Selected: '
foreground = curses.COLOR_BLACK
background = curses.COLOR_WHITE
online_str = 'Workers online: '
help_title = 'Keys: '
help = ('j:down k:up i:info t:traceback r:result c:revoke ^c: quit')
greet = f'celery events {VERSION_BANNER}'
info_str = 'Info: '
def __init__(self, state, app, keymap=None):
self.app = app
self.keymap = keymap or self.keymap
self.state = state
default_keymap = {
'J': self.move_selection_down,
'K': self.move_selection_up,
'C': self.revoke_selection,
'T': self.selection_traceback,
'R': self.selection_result,
'I': self.selection_info,
'L': self.selection_rate_limit,
}
self.keymap = dict(default_keymap, **self.keymap)
self.lock = threading.RLock()
def format_row(self, uuid, task, worker, timestamp, state):
mx = self.display_width
# include spacing
detail_width = mx - 1 - STATE_WIDTH - 1 - TIMESTAMP_WIDTH
uuid_space = detail_width - 1 - MIN_TASK_WIDTH - 1 - MIN_WORKER_WIDTH
if uuid_space < UUID_WIDTH:
uuid_width = uuid_space
else:
uuid_width = UUID_WIDTH
detail_width = detail_width - uuid_width - 1
task_width = int(ceil(detail_width / 2.0))
worker_width = detail_width - task_width - 1
uuid = abbr(uuid, uuid_width).ljust(uuid_width)
worker = abbr(worker, worker_width).ljust(worker_width)
task = abbrtask(task, task_width).ljust(task_width)
state = abbr(state, STATE_WIDTH).ljust(STATE_WIDTH)
timestamp = timestamp.ljust(TIMESTAMP_WIDTH)
row = f'{uuid} {worker} {task} {timestamp} {state} '
if self.screen_width is None:
self.screen_width = len(row[:mx])
return row[:mx]
@property
def screen_width(self):
_, mx = self.win.getmaxyx()
return mx
@property
def screen_height(self):
my, _ = self.win.getmaxyx()
return my
@property
def display_width(self):
_, mx = self.win.getmaxyx()
return mx - BORDER_SPACING
@property
def display_height(self):
my, _ = self.win.getmaxyx()
return my - 10
@property
def limit(self):
return self.display_height
def find_position(self):
if not self.tasks:
return 0
for i, e in enumerate(self.tasks):
if self.selected_task == e[0]:
return i
return 0
def move_selection_up(self):
self.move_selection(-1)
def move_selection_down(self):
self.move_selection(1)
def move_selection(self, direction=1):
if not self.tasks:
return
pos = self.find_position()
try:
self.selected_task = self.tasks[pos + direction][0]
except IndexError:
self.selected_task = self.tasks[0][0]
keyalias = {curses.KEY_DOWN: 'J',
curses.KEY_UP: 'K',
curses.KEY_ENTER: 'I'}
def handle_keypress(self):
try:
key = self.win.getkey().upper()
except Exception: # pylint: disable=broad-except
return
key = self.keyalias.get(key) or key
handler = self.keymap.get(key)
if handler is not None:
handler()
def alert(self, callback, title=None):
self.win.erase()
my, mx = self.win.getmaxyx()
y = blank_line = count(2)
if title:
self.win.addstr(next(y), 3, title,
curses.A_BOLD | curses.A_UNDERLINE)
next(blank_line)
callback(my, mx, next(y))
self.win.addstr(my - 1, 0, 'Press any key to continue...',
curses.A_BOLD)
self.win.refresh()
while 1:
try:
return self.win.getkey().upper()
except Exception: # pylint: disable=broad-except
pass
def selection_rate_limit(self):
if not self.selected_task:
return curses.beep()
task = self.state.tasks[self.selected_task]
if not task.name:
return curses.beep()
my, mx = self.win.getmaxyx()
r = 'New rate limit: '
self.win.addstr(my - 2, 3, r, curses.A_BOLD | curses.A_UNDERLINE)
self.win.addstr(my - 2, len(r) + 3, ' ' * (mx - len(r)))
rlimit = self.readline(my - 2, 3 + len(r))
if rlimit:
reply = self.app.control.rate_limit(task.name,
rlimit.strip(), reply=True)
self.alert_remote_control_reply(reply)
def alert_remote_control_reply(self, reply):
def callback(my, mx, xs):
y = count(xs)
if not reply:
self.win.addstr(
next(y), 3, 'No replies received in 1s deadline.',
curses.A_BOLD + curses.color_pair(2),
)
return
for subreply in reply:
curline = next(y)
host, response = next(iter(subreply.items()))
host = f'{host}: '
self.win.addstr(curline, 3, host, curses.A_BOLD)
attr = curses.A_NORMAL
text = ''
if 'error' in response:
text = response['error']
attr |= curses.color_pair(2)
elif 'ok' in response:
text = response['ok']
attr |= curses.color_pair(3)
self.win.addstr(curline, 3 + len(host), text, attr)
return self.alert(callback, 'Remote Control Command Replies')
def readline(self, x, y):
buffer = ''
curses.echo()
try:
i = 0
while 1:
ch = self.win.getch(x, y + i)
if ch != -1:
if ch in (10, curses.KEY_ENTER): # enter
break
if ch in (27,):
buffer = ''
break
buffer += chr(ch)
i += 1
finally:
curses.noecho()
return buffer
def revoke_selection(self):
if not self.selected_task:
return curses.beep()
reply = self.app.control.revoke(self.selected_task, reply=True)
self.alert_remote_control_reply(reply)
def selection_info(self):
if not self.selected_task:
return
def alert_callback(mx, my, xs):
my, mx = self.win.getmaxyx()
y = count(xs)
task = self.state.tasks[self.selected_task]
info = task.info(extra=['state'])
infoitems = [
('args', info.pop('args', None)),
('kwargs', info.pop('kwargs', None))
] + list(info.items())
for key, value in infoitems:
if key is None:
continue
value = str(value)
curline = next(y)
keys = key + ': '
self.win.addstr(curline, 3, keys, curses.A_BOLD)
wrapped = wrap(value, mx - 2)
if len(wrapped) == 1:
self.win.addstr(
curline, len(keys) + 3,
abbr(wrapped[0],
self.screen_width - (len(keys) + 3)))
else:
for subline in wrapped:
nexty = next(y)
if nexty >= my - 1:
subline = ' ' * 4 + '[...]'
self.win.addstr(
nexty, 3,
abbr(' ' * 4 + subline, self.screen_width - 4),
curses.A_NORMAL,
)
return self.alert(
alert_callback, f'Task details for {self.selected_task}',
)
def selection_traceback(self):
if not self.selected_task:
return curses.beep()
task = self.state.tasks[self.selected_task]
if task.state not in states.EXCEPTION_STATES:
return curses.beep()
def alert_callback(my, mx, xs):
y = count(xs)
for line in task.traceback.split('\n'):
self.win.addstr(next(y), 3, line)
return self.alert(
alert_callback,
f'Task Exception Traceback for {self.selected_task}',
)
def selection_result(self):
if not self.selected_task:
return
def alert_callback(my, mx, xs):
y = count(xs)
task = self.state.tasks[self.selected_task]
result = (getattr(task, 'result', None) or
getattr(task, 'exception', None))
for line in wrap(result or '', mx - 2):
self.win.addstr(next(y), 3, line)
return self.alert(
alert_callback,
f'Task Result for {self.selected_task}',
)
def display_task_row(self, lineno, task):
state_color = self.state_colors.get(task.state)
attr = curses.A_NORMAL
if task.uuid == self.selected_task:
attr = curses.A_STANDOUT
timestamp = datetime.utcfromtimestamp(
task.timestamp or time(),
)
timef = timestamp.strftime('%H:%M:%S')
hostname = task.worker.hostname if task.worker else '*NONE*'
line = self.format_row(task.uuid, task.name,
hostname,
timef, task.state)
self.win.addstr(lineno, LEFT_BORDER_OFFSET, line, attr)
if state_color:
self.win.addstr(lineno,
len(line) - STATE_WIDTH + BORDER_SPACING - 1,
task.state, state_color | attr)
def draw(self):
with self.lock:
win = self.win
self.handle_keypress()
x = LEFT_BORDER_OFFSET
y = blank_line = count(2)
my, _ = win.getmaxyx()
win.erase()
win.bkgd(' ', curses.color_pair(1))
win.border()
win.addstr(1, x, self.greet, curses.A_DIM | curses.color_pair(5))
next(blank_line)
win.addstr(next(y), x, self.format_row('UUID', 'TASK',
'WORKER', 'TIME', 'STATE'),
curses.A_BOLD | curses.A_UNDERLINE)
tasks = self.tasks
if tasks:
for row, (_, task) in enumerate(tasks):
if row > self.display_height:
break
if task.uuid:
lineno = next(y)
self.display_task_row(lineno, task)
# -- Footer
next(blank_line)
win.hline(my - 6, x, curses.ACS_HLINE, self.screen_width - 4)
# Selected Task Info
if self.selected_task:
win.addstr(my - 5, x, self.selected_str, curses.A_BOLD)
info = 'Missing extended info'
detail = ''
try:
selection = self.state.tasks[self.selected_task]
except KeyError:
pass
else:
info = selection.info()
if 'runtime' in info:
info['runtime'] = '{:.2f}'.format(info['runtime'])
if 'result' in info:
info['result'] = abbr(info['result'], 16)
info = ' '.join(
f'{key}={value}'
for key, value in info.items()
)
detail = '... -> key i'
infowin = abbr(info,
self.screen_width - len(self.selected_str) - 2,
detail)
win.addstr(my - 5, x + len(self.selected_str), infowin)
# Make ellipsis bold
if detail in infowin:
detailpos = len(infowin) - len(detail)
win.addstr(my - 5, x + len(self.selected_str) + detailpos,
detail, curses.A_BOLD)
else:
win.addstr(my - 5, x, 'No task selected', curses.A_NORMAL)
# Workers
if self.workers:
win.addstr(my - 4, x, self.online_str, curses.A_BOLD)
win.addstr(my - 4, x + len(self.online_str),
', '.join(sorted(self.workers)), curses.A_NORMAL)
else:
win.addstr(my - 4, x, 'No workers discovered.')
# Info
win.addstr(my - 3, x, self.info_str, curses.A_BOLD)
win.addstr(
my - 3, x + len(self.info_str),
STATUS_SCREEN.format(
s=self.state,
w_alive=len([w for w in self.state.workers.values()
if w.alive]),
w_all=len(self.state.workers),
),
curses.A_DIM,
)
# Help
self.safe_add_str(my - 2, x, self.help_title, curses.A_BOLD)
self.safe_add_str(my - 2, x + len(self.help_title), self.help,
curses.A_DIM)
win.refresh()
def safe_add_str(self, y, x, string, *args, **kwargs):
if x + len(string) > self.screen_width:
string = string[:self.screen_width - x]
self.win.addstr(y, x, string, *args, **kwargs)
def init_screen(self):
with self.lock:
self.win = curses.initscr()
self.win.nodelay(True)
self.win.keypad(True)
curses.start_color()
curses.init_pair(1, self.foreground, self.background)
# exception states
curses.init_pair(2, curses.COLOR_RED, self.background)
# successful state
curses.init_pair(3, curses.COLOR_GREEN, self.background)
# revoked state
curses.init_pair(4, curses.COLOR_MAGENTA, self.background)
# greeting
curses.init_pair(5, curses.COLOR_BLUE, self.background)
# started state
curses.init_pair(6, curses.COLOR_YELLOW, self.foreground)
self.state_colors = {states.SUCCESS: curses.color_pair(3),
states.REVOKED: curses.color_pair(4),
states.STARTED: curses.color_pair(6)}
for state in states.EXCEPTION_STATES:
self.state_colors[state] = curses.color_pair(2)
curses.cbreak()
def resetscreen(self):
with self.lock:
curses.nocbreak()
self.win.keypad(False)
curses.echo()
curses.endwin()
def nap(self):
curses.napms(self.screen_delay)
@property
def tasks(self):
return list(self.state.tasks_by_time(limit=self.limit))
@property
def workers(self):
return [hostname for hostname, w in self.state.workers.items()
if w.alive]
class DisplayThread(threading.Thread): # pragma: no cover
def __init__(self, display):
self.display = display
self.shutdown = False
super().__init__()
def run(self):
while not self.shutdown:
self.display.draw()
self.display.nap()
def capture_events(app, state, display): # pragma: no cover
def on_connection_error(exc, interval):
print('Connection Error: {!r}. Retry in {}s.'.format(
exc, interval), file=sys.stderr)
while 1:
print('-> evtop: starting capture...', file=sys.stderr)
with app.connection_for_read() as conn:
try:
conn.ensure_connection(on_connection_error,
app.conf.broker_connection_max_retries)
recv = app.events.Receiver(conn, handlers={'*': state.event})
display.resetscreen()
display.init_screen()
recv.capture()
except conn.connection_errors + conn.channel_errors as exc:
print(f'Connection lost: {exc!r}', file=sys.stderr)
def evtop(app=None): # pragma: no cover
"""Start curses monitor."""
app = app_or_default(app)
state = app.events.State()
display = CursesMonitor(state, app)
display.init_screen()
refresher = DisplayThread(display)
refresher.start()
try:
capture_events(app, state, display)
except Exception:
refresher.shutdown = True
refresher.join()
display.resetscreen()
raise
except (KeyboardInterrupt, SystemExit):
refresher.shutdown = True
refresher.join()
display.resetscreen()
if __name__ == '__main__': # pragma: no cover
evtop()
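A short sketch of starting the monitor programmatically (it is normally launched via the ``celery events`` command); ``proj.celery`` is a placeholder module:
.. code-block:: python

    from celery.events.cursesmon import evtop

    from proj.celery import app  # placeholder application module

    evtop(app)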

View File

@@ -0,0 +1,229 @@
"""Event dispatcher sends events."""
import os
import threading
import time
from collections import defaultdict, deque
from kombu import Producer
from celery.app import app_or_default
from celery.utils.nodenames import anon_nodename
from celery.utils.time import utcoffset
from .event import Event, get_exchange, group_from
__all__ = ('EventDispatcher',)
class EventDispatcher:
"""Dispatches event messages.
Arguments:
connection (kombu.Connection): Connection to the broker.
hostname (str): Hostname to identify ourselves as,
by default uses the hostname returned by
:func:`~celery.utils.anon_nodename`.
groups (Sequence[str]): List of groups to send events for.
:meth:`send` will ignore send requests to groups not in this list.
If this is :const:`None`, all events will be sent.
Example groups include ``"task"`` and ``"worker"``.
enabled (bool): Set to :const:`False` to not actually publish any
events, making :meth:`send` a no-op.
channel (kombu.Channel): Can be used instead of `connection` to specify
an exact channel to use when sending events.
buffer_while_offline (bool): If enabled, events will be buffered
while the connection is down. :meth:`flush` must be called
as soon as the connection is re-established.
Note:
You need to :meth:`close` this after use.
"""
DISABLED_TRANSPORTS = {'sql'}
app = None
# set of callbacks to be called when :meth:`enabled`.
on_enabled = None
# set of callbacks to be called when :meth:`disabled`.
on_disabled = None
def __init__(self, connection=None, hostname=None, enabled=True,
channel=None, buffer_while_offline=True, app=None,
serializer=None, groups=None, delivery_mode=1,
buffer_group=None, buffer_limit=24, on_send_buffered=None):
self.app = app_or_default(app or self.app)
self.connection = connection
self.channel = channel
self.hostname = hostname or anon_nodename()
self.buffer_while_offline = buffer_while_offline
self.buffer_group = buffer_group or frozenset()
self.buffer_limit = buffer_limit
self.on_send_buffered = on_send_buffered
self._group_buffer = defaultdict(list)
self.mutex = threading.Lock()
self.producer = None
self._outbound_buffer = deque()
self.serializer = serializer or self.app.conf.event_serializer
self.on_enabled = set()
self.on_disabled = set()
self.groups = set(groups or [])
self.tzoffset = [-time.timezone, -time.altzone]
self.clock = self.app.clock
self.delivery_mode = delivery_mode
if not connection and channel:
self.connection = channel.connection.client
self.enabled = enabled
conninfo = self.connection or self.app.connection_for_write()
self.exchange = get_exchange(conninfo,
name=self.app.conf.event_exchange)
if conninfo.transport.driver_type in self.DISABLED_TRANSPORTS:
self.enabled = False
if self.enabled:
self.enable()
self.headers = {'hostname': self.hostname}
self.pid = os.getpid()
def __enter__(self):
return self
def __exit__(self, *exc_info):
self.close()
def enable(self):
self.producer = Producer(self.channel or self.connection,
exchange=self.exchange,
serializer=self.serializer,
auto_declare=False)
self.enabled = True
for callback in self.on_enabled:
callback()
def disable(self):
if self.enabled:
self.enabled = False
self.close()
for callback in self.on_disabled:
callback()
def publish(self, type, fields, producer,
blind=False, Event=Event, **kwargs):
"""Publish event using custom :class:`~kombu.Producer`.
Arguments:
type (str): Event type name, with group separated by dash (`-`).
fields: Dictionary of event fields, must be json serializable.
producer (kombu.Producer): Producer instance to use:
only the ``publish`` method will be called.
retry (bool): Retry in the event of connection failure.
retry_policy (Mapping): Map of custom retry policy options.
See :meth:`~kombu.Connection.ensure`.
blind (bool): Don't set logical clock value (also don't forward
the internal logical clock).
Event (Callable): Event type used to create event.
Defaults to :func:`Event`.
utcoffset (Callable): Function returning the current
utc offset in hours.
"""
clock = None if blind else self.clock.forward()
event = Event(type, hostname=self.hostname, utcoffset=utcoffset(),
pid=self.pid, clock=clock, **fields)
with self.mutex:
return self._publish(event, producer,
routing_key=type.replace('-', '.'), **kwargs)
def _publish(self, event, producer, routing_key, retry=False,
retry_policy=None, utcoffset=utcoffset):
exchange = self.exchange
try:
producer.publish(
event,
routing_key=routing_key,
exchange=exchange.name,
retry=retry,
retry_policy=retry_policy,
declare=[exchange],
serializer=self.serializer,
headers=self.headers,
delivery_mode=self.delivery_mode,
)
except Exception as exc: # pylint: disable=broad-except
if not self.buffer_while_offline:
raise
self._outbound_buffer.append((event, routing_key, exc))
def send(self, type, blind=False, utcoffset=utcoffset, retry=False,
retry_policy=None, Event=Event, **fields):
"""Send event.
Arguments:
type (str): Event type name, with group separated by dash (`-`).
retry (bool): Retry in the event of connection failure.
retry_policy (Mapping): Map of custom retry policy options.
See :meth:`~kombu.Connection.ensure`.
blind (bool): Don't set logical clock value (also don't forward
the internal logical clock).
Event (Callable): Event type used to create event,
defaults to :func:`Event`.
utcoffset (Callable): Function returning the current utc offset
in hours.
**fields (Any): Event fields -- must be json serializable.
"""
if self.enabled:
groups, group = self.groups, group_from(type)
if groups and group not in groups:
return
if group in self.buffer_group:
clock = self.clock.forward()
event = Event(type, hostname=self.hostname,
utcoffset=utcoffset(),
pid=self.pid, clock=clock, **fields)
buf = self._group_buffer[group]
buf.append(event)
if len(buf) >= self.buffer_limit:
self.flush()
elif self.on_send_buffered:
self.on_send_buffered()
else:
return self.publish(type, fields, self.producer, blind=blind,
Event=Event, retry=retry,
retry_policy=retry_policy)
def flush(self, errors=True, groups=True):
"""Flush the outbound buffer."""
if errors:
buf = list(self._outbound_buffer)
try:
with self.mutex:
for event, routing_key, _ in buf:
self._publish(event, self.producer, routing_key)
finally:
self._outbound_buffer.clear()
if groups:
with self.mutex:
for group, events in self._group_buffer.items():
self._publish(events, self.producer, '%s.multi' % group)
events[:] = [] # list.clear
def extend_buffer(self, other):
"""Copy the outbound buffer of another instance."""
self._outbound_buffer.extend(other._outbound_buffer)
def close(self):
"""Close the event dispatcher."""
self.mutex.locked() and self.mutex.release()
self.producer = None
def _get_publisher(self):
return self.producer
def _set_publisher(self, producer):
self.producer = producer
publisher = property(_get_publisher, _set_publisher) # XXX compat
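A minimal sketch of sending a custom event with the dispatcher; ``app`` is a placeholder for a configured application and the event type name is made up:
.. code-block:: python

    with app.connection_for_write() as connection:
        with app.events.Dispatcher(connection, enabled=True) as dispatcher:
            dispatcher.send('task-progress', current=50, total=100)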

View File

@@ -0,0 +1,103 @@
"""Utility to dump events to screen.
This is a simple program that dumps events to the console
as they happen. Think of it like a `tcpdump` for Celery events.
"""
import sys
from datetime import datetime
from celery.app import app_or_default
from celery.utils.functional import LRUCache
from celery.utils.time import humanize_seconds
__all__ = ('Dumper', 'evdump')
TASK_NAMES = LRUCache(limit=0xFFF)
HUMAN_TYPES = {
'worker-offline': 'shutdown',
'worker-online': 'started',
'worker-heartbeat': 'heartbeat',
}
CONNECTION_ERROR = """\
-> Cannot connect to %s: %s.
Trying again %s
"""
def humanize_type(type):
try:
return HUMAN_TYPES[type.lower()]
except KeyError:
return type.lower().replace('-', ' ')
class Dumper:
"""Monitor events."""
def __init__(self, out=sys.stdout):
self.out = out
def say(self, msg):
print(msg, file=self.out)
# need to flush so that output can be piped.
try:
self.out.flush()
except AttributeError: # pragma: no cover
pass
def on_event(self, ev):
timestamp = datetime.utcfromtimestamp(ev.pop('timestamp'))
type = ev.pop('type').lower()
hostname = ev.pop('hostname')
if type.startswith('task-'):
uuid = ev.pop('uuid')
if type in ('task-received', 'task-sent'):
task = TASK_NAMES[uuid] = '{}({}) args={} kwargs={}' \
.format(ev.pop('name'), uuid,
ev.pop('args'),
ev.pop('kwargs'))
else:
task = TASK_NAMES.get(uuid, '')
return self.format_task_event(hostname, timestamp,
type, task, ev)
fields = ', '.join(
f'{key}={ev[key]}' for key in sorted(ev)
)
sep = fields and ':' or ''
self.say(f'{hostname} [{timestamp}] {humanize_type(type)}{sep} {fields}')
def format_task_event(self, hostname, timestamp, type, task, event):
fields = ', '.join(
f'{key}={event[key]}' for key in sorted(event)
)
sep = fields and ':' or ''
self.say(f'{hostname} [{timestamp}] {humanize_type(type)}{sep} {task} {fields}')
def evdump(app=None, out=sys.stdout):
"""Start event dump."""
app = app_or_default(app)
dumper = Dumper(out=out)
dumper.say('-> evdump: starting capture...')
conn = app.connection_for_read().clone()
def _error_handler(exc, interval):
dumper.say(CONNECTION_ERROR % (
conn.as_uri(), exc, humanize_seconds(interval, 'in', ' ')
))
while 1:
try:
conn.ensure_connection(_error_handler)
recv = app.events.Receiver(conn, handlers={'*': dumper.on_event})
recv.capture()
except (KeyboardInterrupt, SystemExit):
return conn and conn.close()
except conn.connection_errors + conn.channel_errors:
dumper.say('-> Connection lost, attempting reconnect')
if __name__ == '__main__': # pragma: no cover
evdump()
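
A short way to exercise this dumper programmatically; the broker URL is illustrative, and the same behaviour is normally reached from the command line via ``celery events --dump``:

from celery import Celery
from celery.events.dumper import evdump

app = Celery('proj', broker='redis://localhost:6379/0')
evdump(app=app)   # blocks, printing events until interrupted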

View File

@@ -0,0 +1,63 @@
"""Creating events, and event exchange definition."""
import time
from copy import copy
from kombu import Exchange
__all__ = (
'Event', 'event_exchange', 'get_exchange', 'group_from',
)
EVENT_EXCHANGE_NAME = 'celeryev'
#: Exchange used to send events on.
#: Note: Use :func:`get_exchange` instead, as the type of
#: exchange will vary depending on the broker connection.
event_exchange = Exchange(EVENT_EXCHANGE_NAME, type='topic')
def Event(type, _fields=None, __dict__=dict, __now__=time.time, **fields):
"""Create an event.
Notes:
An event is simply a dictionary: the only required field is ``type``.
A ``timestamp`` field will be set to the current time if not provided.
"""
event = __dict__(_fields, **fields) if _fields else fields
if 'timestamp' not in event:
event.update(timestamp=__now__(), type=type)
else:
event['type'] = type
return event
def group_from(type):
"""Get the group part of an event type name.
Example:
>>> group_from('task-sent')
'task'
>>> group_from('custom-my-event')
'custom'
"""
return type.split('-', 1)[0]
def get_exchange(conn, name=EVENT_EXCHANGE_NAME):
"""Get exchange used for sending events.
Arguments:
conn (kombu.Connection): Connection used for sending/receiving events.
name (str): Name of the exchange. Default is ``celeryev``.
    Note:
        The exchange type changes if Redis is used as the transport
        (from topic -> fanout).
"""
ex = copy(event_exchange)
if conn.transport.driver_type == 'redis':
# quick hack for Issue #436
ex.type = 'fanout'
if name != ex.name:
ex.name = name
return ex
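
A small sketch of the helpers defined above; the field values are made up:

from celery.events.event import Event, group_from

ev = Event('task-sent', uuid='abc123', name='proj.tasks.add')
assert ev['type'] == 'task-sent' and 'timestamp' in ev
assert group_from(ev['type']) == 'task'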

View File

@@ -0,0 +1,135 @@
"""Event receiver implementation."""
import time
from operator import itemgetter
from kombu import Queue
from kombu.connection import maybe_channel
from kombu.mixins import ConsumerMixin
from celery import uuid
from celery.app import app_or_default
from celery.utils.time import adjust_timestamp
from .event import get_exchange
__all__ = ('EventReceiver',)
CLIENT_CLOCK_SKEW = -1
_TZGETTER = itemgetter('utcoffset', 'timestamp')
class EventReceiver(ConsumerMixin):
"""Capture events.
Arguments:
connection (kombu.Connection): Connection to the broker.
handlers (Mapping[Callable]): Event handlers.
This is a map of event type names and their handlers.
The special handler `"*"` captures all events that don't have a
handler.
"""
app = None
def __init__(self, channel, handlers=None, routing_key='#',
node_id=None, app=None, queue_prefix=None,
accept=None, queue_ttl=None, queue_expires=None):
self.app = app_or_default(app or self.app)
self.channel = maybe_channel(channel)
self.handlers = {} if handlers is None else handlers
self.routing_key = routing_key
self.node_id = node_id or uuid()
self.queue_prefix = queue_prefix or self.app.conf.event_queue_prefix
self.exchange = get_exchange(
self.connection or self.app.connection_for_write(),
name=self.app.conf.event_exchange)
if queue_ttl is None:
queue_ttl = self.app.conf.event_queue_ttl
if queue_expires is None:
queue_expires = self.app.conf.event_queue_expires
self.queue = Queue(
'.'.join([self.queue_prefix, self.node_id]),
exchange=self.exchange,
routing_key=self.routing_key,
auto_delete=True, durable=False,
message_ttl=queue_ttl,
expires=queue_expires,
)
self.clock = self.app.clock
self.adjust_clock = self.clock.adjust
self.forward_clock = self.clock.forward
if accept is None:
accept = {self.app.conf.event_serializer, 'json'}
self.accept = accept
def process(self, type, event):
"""Process event by dispatching to configured handler."""
handler = self.handlers.get(type) or self.handlers.get('*')
handler and handler(event)
def get_consumers(self, Consumer, channel):
return [Consumer(queues=[self.queue],
callbacks=[self._receive], no_ack=True,
accept=self.accept)]
def on_consume_ready(self, connection, channel, consumers,
wakeup=True, **kwargs):
if wakeup:
self.wakeup_workers(channel=channel)
def itercapture(self, limit=None, timeout=None, wakeup=True):
return self.consume(limit=limit, timeout=timeout, wakeup=wakeup)
def capture(self, limit=None, timeout=None, wakeup=True):
"""Open up a consumer capturing events.
This has to run in the main process, and it will never stop
unless :attr:`EventDispatcher.should_stop` is set to True, or
forced via :exc:`KeyboardInterrupt` or :exc:`SystemExit`.
"""
for _ in self.consume(limit=limit, timeout=timeout, wakeup=wakeup):
pass
def wakeup_workers(self, channel=None):
self.app.control.broadcast('heartbeat',
connection=self.connection,
channel=channel)
def event_from_message(self, body, localize=True,
now=time.time, tzfields=_TZGETTER,
adjust_timestamp=adjust_timestamp,
CLIENT_CLOCK_SKEW=CLIENT_CLOCK_SKEW):
type = body['type']
if type == 'task-sent':
# clients never sync so cannot use their clock value
_c = body['clock'] = (self.clock.value or 1) + CLIENT_CLOCK_SKEW
self.adjust_clock(_c)
else:
try:
clock = body['clock']
except KeyError:
body['clock'] = self.forward_clock()
else:
self.adjust_clock(clock)
if localize:
try:
offset, timestamp = tzfields(body)
except KeyError:
pass
else:
body['timestamp'] = adjust_timestamp(timestamp, offset)
body['local_received'] = now()
return type, body
def _receive(self, body, message, list=list, isinstance=isinstance):
if isinstance(body, list): # celery 4.0+: List of events
process, from_message = self.process, self.event_from_message
[process(*from_message(event)) for event in body]
else:
self.process(*self.event_from_message(body))
@property
def connection(self):
return self.channel.connection.client if self.channel else None
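
A minimal consumer built on the receiver above, following the pattern from the monitoring documentation; the handler and broker URL are illustrative:

from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')

def on_task_failed(event):
    print('task failed: %s' % event['uuid'])

with app.connection_for_read() as connection:
    recv = app.events.Receiver(connection, handlers={
        'task-failed': on_task_failed,
        '*': lambda event: None,   # ignore everything else
    })
    recv.capture(limit=None, timeout=None, wakeup=True)   # blocks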

View File

@@ -0,0 +1,111 @@
"""Periodically store events in a database.
Consuming the events as a stream isn't always suitable
so this module implements a system to take snapshots of the
state of a cluster at regular intervals. There's a full
implementation of this writing the snapshots to a database
in :mod:`djcelery.snapshots` in the `django-celery` distribution.
"""
from kombu.utils.limits import TokenBucket
from celery import platforms
from celery.app import app_or_default
from celery.utils.dispatch import Signal
from celery.utils.imports import instantiate
from celery.utils.log import get_logger
from celery.utils.time import rate
from celery.utils.timer2 import Timer
__all__ = ('Polaroid', 'evcam')
logger = get_logger('celery.evcam')
class Polaroid:
"""Record event snapshots."""
timer = None
shutter_signal = Signal(name='shutter_signal', providing_args={'state'})
cleanup_signal = Signal(name='cleanup_signal')
clear_after = False
_tref = None
_ctref = None
def __init__(self, state, freq=1.0, maxrate=None,
cleanup_freq=3600.0, timer=None, app=None):
self.app = app_or_default(app)
self.state = state
self.freq = freq
self.cleanup_freq = cleanup_freq
self.timer = timer or self.timer or Timer()
self.logger = logger
self.maxrate = maxrate and TokenBucket(rate(maxrate))
def install(self):
self._tref = self.timer.call_repeatedly(self.freq, self.capture)
self._ctref = self.timer.call_repeatedly(
self.cleanup_freq, self.cleanup,
)
def on_shutter(self, state):
pass
def on_cleanup(self):
pass
def cleanup(self):
logger.debug('Cleanup: Running...')
self.cleanup_signal.send(sender=self.state)
self.on_cleanup()
def shutter(self):
if self.maxrate is None or self.maxrate.can_consume():
logger.debug('Shutter: %s', self.state)
self.shutter_signal.send(sender=self.state)
self.on_shutter(self.state)
def capture(self):
self.state.freeze_while(self.shutter, clear_after=self.clear_after)
def cancel(self):
if self._tref:
self._tref() # flush all received events.
self._tref.cancel()
if self._ctref:
self._ctref.cancel()
def __enter__(self):
self.install()
return self
def __exit__(self, *exc_info):
self.cancel()
def evcam(camera, freq=1.0, maxrate=None, loglevel=0,
logfile=None, pidfile=None, timer=None, app=None,
**kwargs):
"""Start snapshot recorder."""
app = app_or_default(app)
if pidfile:
platforms.create_pidlock(pidfile)
app.log.setup_logging_subsystem(loglevel, logfile)
print(f'-> evcam: Taking snapshots with {camera} (every {freq} secs.)')
state = app.events.State()
cam = instantiate(camera, state, app=app, freq=freq,
maxrate=maxrate, timer=timer)
cam.install()
conn = app.connection_for_read()
recv = app.events.Receiver(conn, handlers={'*': state.event})
try:
try:
recv.capture(limit=None)
except KeyboardInterrupt:
raise SystemExit
finally:
cam.cancel()
conn.close()
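
A sketch of a custom camera built on ``Polaroid``, mirroring the example in the monitoring guide; the module path in the final comment is a placeholder:

from pprint import pformat

from celery.events.snapshot import Polaroid

class DumpCam(Polaroid):
    clear_after = True   # clear after flush (incl. state.event_count)

    def on_shutter(self, state):
        if not state.event_count:
            return   # no new events since last snapshot
        print('Workers: {}'.format(pformat(state.workers, indent=4)))
        print('Tasks: {}'.format(pformat(state.tasks, indent=4)))
        print('Total: {0.event_count} events, {0.task_count} tasks'.format(state))

# started with e.g.:  evcam('proj.cam.DumpCam', freq=2.0, app=app)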

View File

@@ -0,0 +1,730 @@
"""In-memory representation of cluster state.
This module implements a data-structure used to keep
track of the state of a cluster of workers and the tasks
it is working on (by consuming events).
For every event consumed the state is updated,
so the state represents the state of the cluster
at the time of the last event.
Snapshots (:mod:`celery.events.snapshot`) can be used to
take "pictures" of this state at regular intervals
to, for example, store it in a database.
"""
import bisect
import sys
import threading
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime
from decimal import Decimal
from itertools import islice
from operator import itemgetter
from time import time
from typing import Mapping, Optional # noqa
from weakref import WeakSet, ref
from kombu.clocks import timetuple
from kombu.utils.objects import cached_property
from celery import states
from celery.utils.functional import LRUCache, memoize, pass1
from celery.utils.log import get_logger
__all__ = ('Worker', 'Task', 'State', 'heartbeat_expires')
# pylint: disable=redefined-outer-name
# We cache globals and attribute lookups, so disable this warning.
# pylint: disable=too-many-function-args
# For some reason pylint thinks ._event is a method, when it's a property.
#: Set if running PyPy
PYPY = hasattr(sys, 'pypy_version_info')
#: The window (in percentage) is added to the worker's heartbeat
#: frequency. If the time between updates exceeds this window,
#: then the worker is considered to be offline.
HEARTBEAT_EXPIRE_WINDOW = 200
#: Max drift between event timestamp and time of event received
#: before we alert that clocks may be unsynchronized.
HEARTBEAT_DRIFT_MAX = 16
DRIFT_WARNING = (
"Substantial drift from %s may mean clocks are out of sync. Current drift is "
"%s seconds. [orig: %s recv: %s]"
)
logger = get_logger(__name__)
warn = logger.warning
R_STATE = '<State: events={0.event_count} tasks={0.task_count}>'
R_WORKER = '<Worker: {0.hostname} ({0.status_string} clock:{0.clock})'
R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'
#: Mapping of task event names to task state.
TASK_EVENT_TO_STATE = {
'sent': states.PENDING,
'received': states.RECEIVED,
'started': states.STARTED,
'failed': states.FAILURE,
'retried': states.RETRY,
'succeeded': states.SUCCESS,
'revoked': states.REVOKED,
'rejected': states.REJECTED,
}
class CallableDefaultdict(defaultdict):
""":class:`~collections.defaultdict` with configurable __call__.
We use this for backwards compatibility in State.tasks_by_type
etc, which used to be a method but is now an index instead.
So you can do::
>>> add_tasks = state.tasks_by_type['proj.tasks.add']
while still supporting the method call::
>>> add_tasks = list(state.tasks_by_type(
... 'proj.tasks.add', reverse=True))
"""
def __init__(self, fun, *args, **kwargs):
self.fun = fun
super().__init__(*args, **kwargs)
def __call__(self, *args, **kwargs):
return self.fun(*args, **kwargs)
Callable.register(CallableDefaultdict)
@memoize(maxsize=1000, keyfun=lambda a, _: a[0])
def _warn_drift(hostname, drift, local_received, timestamp):
# we use memoize here so the warning is only logged once per hostname
warn(DRIFT_WARNING, hostname, drift,
datetime.fromtimestamp(local_received),
datetime.fromtimestamp(timestamp))
def heartbeat_expires(timestamp, freq=60,
expire_window=HEARTBEAT_EXPIRE_WINDOW,
Decimal=Decimal, float=float, isinstance=isinstance):
"""Return time when heartbeat expires."""
# some json implementations returns decimal.Decimal objects,
# which aren't compatible with float.
freq = float(freq) if isinstance(freq, Decimal) else freq
if isinstance(timestamp, Decimal):
timestamp = float(timestamp)
return timestamp + (freq * (expire_window / 1e2))
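
# Worked example (illustrative, not part of the original module): with the
# default 200% expire window and freq=60 a heartbeat lives for 60 * 2 = 120s,
# so heartbeat_expires(1_000_000.0, freq=60) == 1_000_120.0.
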
def _depickle_task(cls, fields):
return cls(**fields)
def with_unique_field(attr):
def _decorate_cls(cls):
def __eq__(this, other):
if isinstance(other, this.__class__):
return getattr(this, attr) == getattr(other, attr)
return NotImplemented
cls.__eq__ = __eq__
def __hash__(this):
return hash(getattr(this, attr))
cls.__hash__ = __hash__
return cls
return _decorate_cls
@with_unique_field('hostname')
class Worker:
"""Worker State."""
heartbeat_max = 4
expire_window = HEARTBEAT_EXPIRE_WINDOW
_fields = ('hostname', 'pid', 'freq', 'heartbeats', 'clock',
'active', 'processed', 'loadavg', 'sw_ident',
'sw_ver', 'sw_sys')
if not PYPY: # pragma: no cover
__slots__ = _fields + ('event', '__dict__', '__weakref__')
def __init__(self, hostname=None, pid=None, freq=60,
heartbeats=None, clock=0, active=None, processed=None,
loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None):
self.hostname = hostname
self.pid = pid
self.freq = freq
self.heartbeats = [] if heartbeats is None else heartbeats
self.clock = clock or 0
self.active = active
self.processed = processed
self.loadavg = loadavg
self.sw_ident = sw_ident
self.sw_ver = sw_ver
self.sw_sys = sw_sys
self.event = self._create_event_handler()
def __reduce__(self):
return self.__class__, (self.hostname, self.pid, self.freq,
self.heartbeats, self.clock, self.active,
self.processed, self.loadavg, self.sw_ident,
self.sw_ver, self.sw_sys)
def _create_event_handler(self):
_set = object.__setattr__
hbmax = self.heartbeat_max
heartbeats = self.heartbeats
hb_pop = self.heartbeats.pop
hb_append = self.heartbeats.append
def event(type_, timestamp=None,
local_received=None, fields=None,
max_drift=HEARTBEAT_DRIFT_MAX, abs=abs, int=int,
insort=bisect.insort, len=len):
fields = fields or {}
for k, v in fields.items():
_set(self, k, v)
if type_ == 'offline':
heartbeats[:] = []
else:
if not local_received or not timestamp:
return
drift = abs(int(local_received) - int(timestamp))
if drift > max_drift:
_warn_drift(self.hostname, drift,
local_received, timestamp)
if local_received: # pragma: no cover
hearts = len(heartbeats)
if hearts > hbmax - 1:
hb_pop(0)
if hearts and local_received > heartbeats[-1]:
hb_append(local_received)
else:
insort(heartbeats, local_received)
return event
def update(self, f, **kw):
d = dict(f, **kw) if kw else f
for k, v in d.items():
setattr(self, k, v)
def __repr__(self):
return R_WORKER.format(self)
@property
def status_string(self):
return 'ONLINE' if self.alive else 'OFFLINE'
@property
def heartbeat_expires(self):
return heartbeat_expires(self.heartbeats[-1],
self.freq, self.expire_window)
@property
def alive(self, nowfun=time):
return bool(self.heartbeats and nowfun() < self.heartbeat_expires)
@property
def id(self):
return '{0.hostname}.{0.pid}'.format(self)
@with_unique_field('uuid')
class Task:
"""Task State."""
name = received = sent = started = succeeded = failed = retried = \
revoked = rejected = args = kwargs = eta = expires = retries = \
worker = result = exception = timestamp = runtime = traceback = \
exchange = routing_key = root_id = parent_id = client = None
state = states.PENDING
clock = 0
_fields = (
'uuid', 'name', 'state', 'received', 'sent', 'started', 'rejected',
'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
'eta', 'expires', 'retries', 'worker', 'result', 'exception',
'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
'clock', 'client', 'root', 'root_id', 'parent', 'parent_id',
'children',
)
if not PYPY: # pragma: no cover
__slots__ = ('__dict__', '__weakref__')
#: How to merge out of order events.
#: Disorder is detected by logical ordering (e.g., :event:`task-received`
#: must've happened before a :event:`task-failed` event).
#:
    #: A merge rule consists of a state and a list of fields to keep from
    #: that state. ``(RECEIVED, ('name', 'args'))`` means the name and args
    #: fields are always taken from the RECEIVED state, and any values for
    #: these fields received before or after are simply ignored.
merge_rules = {
states.RECEIVED: (
'name', 'args', 'kwargs', 'parent_id',
'root_id', 'retries', 'eta', 'expires',
),
}
    #: :meth:`info` displays these fields by default.
_info_fields = (
'args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
'expires', 'exception', 'exchange', 'routing_key',
'root_id', 'parent_id',
)
def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
self.uuid = uuid
self.cluster_state = cluster_state
if self.cluster_state is not None:
self.children = WeakSet(
self.cluster_state.tasks.get(task_id)
for task_id in children or ()
if task_id in self.cluster_state.tasks
)
else:
self.children = WeakSet()
self._serializer_handlers = {
'children': self._serializable_children,
'root': self._serializable_root,
'parent': self._serializable_parent,
}
if kwargs:
self.__dict__.update(kwargs)
def event(self, type_, timestamp=None, local_received=None, fields=None,
precedence=states.precedence, setattr=setattr,
task_event_to_state=TASK_EVENT_TO_STATE.get, RETRY=states.RETRY):
fields = fields or {}
# using .get is faster than catching KeyError in this case.
state = task_event_to_state(type_)
if state is not None:
# sets, for example, self.succeeded to the timestamp.
setattr(self, type_, timestamp)
else:
state = type_.upper() # custom state
# note that precedence here is reversed
# see implementation in celery.states.state.__lt__
if state != RETRY and self.state != RETRY and \
precedence(state) > precedence(self.state):
# this state logically happens-before the current state, so merge.
keep = self.merge_rules.get(state)
if keep is not None:
fields = {
k: v for k, v in fields.items() if k in keep
}
else:
fields.update(state=state, timestamp=timestamp)
# update current state with info from this event.
self.__dict__.update(fields)
def info(self, fields=None, extra=None):
"""Information about this task suitable for on-screen display."""
extra = [] if not extra else extra
fields = self._info_fields if fields is None else fields
def _keys():
for key in list(fields) + list(extra):
value = getattr(self, key, None)
if value is not None:
yield key, value
return dict(_keys())
def __repr__(self):
return R_TASK.format(self)
def as_dict(self):
get = object.__getattribute__
handler = self._serializer_handlers.get
return {
k: handler(k, pass1)(get(self, k)) for k in self._fields
}
def _serializable_children(self, value):
return [task.id for task in self.children]
def _serializable_root(self, value):
return self.root_id
def _serializable_parent(self, value):
return self.parent_id
def __reduce__(self):
return _depickle_task, (self.__class__, self.as_dict())
@property
def id(self):
return self.uuid
@property
def origin(self):
return self.client if self.worker is None else self.worker.id
@property
def ready(self):
return self.state in states.READY_STATES
@cached_property
def parent(self):
# issue github.com/mher/flower/issues/648
try:
return self.parent_id and self.cluster_state.tasks.data[self.parent_id]
except KeyError:
return None
@cached_property
def root(self):
# issue github.com/mher/flower/issues/648
try:
return self.root_id and self.cluster_state.tasks.data[self.root_id]
except KeyError:
return None
class State:
"""Records clusters state."""
Worker = Worker
Task = Task
event_count = 0
task_count = 0
heap_multiplier = 4
def __init__(self, callback=None,
workers=None, tasks=None, taskheap=None,
max_workers_in_memory=5000, max_tasks_in_memory=10000,
on_node_join=None, on_node_leave=None,
tasks_by_type=None, tasks_by_worker=None):
self.event_callback = callback
self.workers = (LRUCache(max_workers_in_memory)
if workers is None else workers)
self.tasks = (LRUCache(max_tasks_in_memory)
if tasks is None else tasks)
self._taskheap = [] if taskheap is None else taskheap
self.max_workers_in_memory = max_workers_in_memory
self.max_tasks_in_memory = max_tasks_in_memory
self.on_node_join = on_node_join
self.on_node_leave = on_node_leave
self._mutex = threading.Lock()
self.handlers = {}
self._seen_types = set()
self._tasks_to_resolve = {}
self.rebuild_taskheap()
self.tasks_by_type = CallableDefaultdict(
self._tasks_by_type, WeakSet) # type: Mapping[str, WeakSet[Task]]
self.tasks_by_type.update(
_deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
self.tasks_by_worker = CallableDefaultdict(
self._tasks_by_worker, WeakSet) # type: Mapping[str, WeakSet[Task]]
self.tasks_by_worker.update(
_deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
@cached_property
def _event(self):
return self._create_dispatcher()
def freeze_while(self, fun, *args, **kwargs):
clear_after = kwargs.pop('clear_after', False)
with self._mutex:
try:
return fun(*args, **kwargs)
finally:
if clear_after:
self._clear()
def clear_tasks(self, ready=True):
with self._mutex:
return self._clear_tasks(ready)
def _clear_tasks(self, ready: bool = True):
if ready:
in_progress = {
uuid: task for uuid, task in self.itertasks()
if task.state not in states.READY_STATES
}
self.tasks.clear()
self.tasks.update(in_progress)
else:
self.tasks.clear()
self._taskheap[:] = []
def _clear(self, ready=True):
self.workers.clear()
self._clear_tasks(ready)
self.event_count = 0
self.task_count = 0
def clear(self, ready: bool = True):
with self._mutex:
return self._clear(ready)
def get_or_create_worker(self, hostname, **kwargs):
"""Get or create worker by hostname.
Returns:
Tuple: of ``(worker, was_created)`` pairs.
"""
try:
worker = self.workers[hostname]
if kwargs:
worker.update(kwargs)
return worker, False
except KeyError:
worker = self.workers[hostname] = self.Worker(
hostname, **kwargs)
return worker, True
def get_or_create_task(self, uuid):
"""Get or create task by uuid."""
try:
return self.tasks[uuid], False
except KeyError:
task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
return task, True
def event(self, event):
with self._mutex:
return self._event(event)
def task_event(self, type_, fields):
"""Deprecated, use :meth:`event`."""
return self._event(dict(fields, type='-'.join(['task', type_])))[0]
def worker_event(self, type_, fields):
"""Deprecated, use :meth:`event`."""
return self._event(dict(fields, type='-'.join(['worker', type_])))[0]
def _create_dispatcher(self):
# pylint: disable=too-many-statements
# This code is highly optimized, but not for reusability.
get_handler = self.handlers.__getitem__
event_callback = self.event_callback
wfields = itemgetter('hostname', 'timestamp', 'local_received')
tfields = itemgetter('uuid', 'hostname', 'timestamp',
'local_received', 'clock')
taskheap = self._taskheap
th_append = taskheap.append
th_pop = taskheap.pop
        # Removing events from task heap is an O(n) operation,
        # so easier to just account for the common number of events
        # for each task (PENDING->RECEIVED->STARTED->final)
max_events_in_heap = self.max_tasks_in_memory * self.heap_multiplier
add_type = self._seen_types.add
on_node_join, on_node_leave = self.on_node_join, self.on_node_leave
tasks, Task = self.tasks, self.Task
workers, Worker = self.workers, self.Worker
# avoid updating LRU entry at getitem
get_worker, get_task = workers.data.__getitem__, tasks.data.__getitem__
get_task_by_type_set = self.tasks_by_type.__getitem__
get_task_by_worker_set = self.tasks_by_worker.__getitem__
def _event(event,
timetuple=timetuple, KeyError=KeyError,
insort=bisect.insort, created=True):
self.event_count += 1
if event_callback:
event_callback(self, event)
group, _, subject = event['type'].partition('-')
try:
handler = get_handler(group)
except KeyError:
pass
else:
return handler(subject, event), subject
if group == 'worker':
try:
hostname, timestamp, local_received = wfields(event)
except KeyError:
pass
else:
is_offline = subject == 'offline'
try:
worker, created = get_worker(hostname), False
except KeyError:
if is_offline:
worker, created = Worker(hostname), False
else:
worker = workers[hostname] = Worker(hostname)
worker.event(subject, timestamp, local_received, event)
if on_node_join and (created or subject == 'online'):
on_node_join(worker)
if on_node_leave and is_offline:
on_node_leave(worker)
workers.pop(hostname, None)
return (worker, created), subject
elif group == 'task':
(uuid, hostname, timestamp,
local_received, clock) = tfields(event)
# task-sent event is sent by client, not worker
is_client_event = subject == 'sent'
try:
task, task_created = get_task(uuid), False
except KeyError:
task = tasks[uuid] = Task(uuid, cluster_state=self)
task_created = True
if is_client_event:
task.client = hostname
else:
try:
worker = get_worker(hostname)
except KeyError:
worker = workers[hostname] = Worker(hostname)
task.worker = worker
if worker is not None and local_received:
worker.event(None, local_received, timestamp)
origin = hostname if is_client_event else worker.id
# remove oldest event if exceeding the limit.
heaps = len(taskheap)
if heaps + 1 > max_events_in_heap:
th_pop(0)
# most events will be dated later than the previous.
timetup = timetuple(clock, timestamp, origin, ref(task))
if heaps and timetup > taskheap[-1]:
th_append(timetup)
else:
insort(taskheap, timetup)
if subject == 'received':
self.task_count += 1
task.event(subject, timestamp, local_received, event)
task_name = task.name
if task_name is not None:
add_type(task_name)
if task_created: # add to tasks_by_type index
get_task_by_type_set(task_name).add(task)
get_task_by_worker_set(hostname).add(task)
if task.parent_id:
try:
parent_task = self.tasks[task.parent_id]
except KeyError:
self._add_pending_task_child(task)
else:
parent_task.children.add(task)
try:
_children = self._tasks_to_resolve.pop(uuid)
except KeyError:
pass
else:
task.children.update(_children)
return (task, task_created), subject
return _event
def _add_pending_task_child(self, task):
try:
ch = self._tasks_to_resolve[task.parent_id]
except KeyError:
ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
ch.add(task)
def rebuild_taskheap(self, timetuple=timetuple):
heap = self._taskheap[:] = [
timetuple(t.clock, t.timestamp, t.origin, ref(t))
for t in self.tasks.values()
]
heap.sort()
def itertasks(self, limit: Optional[int] = None):
for index, row in enumerate(self.tasks.items()):
yield row
if limit and index + 1 >= limit:
break
def tasks_by_time(self, limit=None, reverse: bool = True):
"""Generator yielding tasks ordered by time.
Yields:
Tuples of ``(uuid, Task)``.
"""
_heap = self._taskheap
if reverse:
_heap = reversed(_heap)
seen = set()
for evtup in islice(_heap, 0, limit):
task = evtup[3]()
if task is not None:
uuid = task.uuid
if uuid not in seen:
yield uuid, task
seen.add(uuid)
tasks_by_timestamp = tasks_by_time
def _tasks_by_type(self, name, limit=None, reverse=True):
"""Get all tasks by type.
This is slower than accessing :attr:`tasks_by_type`,
but will be ordered by time.
Returns:
Generator: giving ``(uuid, Task)`` pairs.
"""
return islice(
((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
if task.name == name),
0, limit,
)
def _tasks_by_worker(self, hostname, limit=None, reverse=True):
"""Get all tasks by worker.
Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
"""
return islice(
((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)
if task.worker.hostname == hostname),
0, limit,
)
def task_types(self):
"""Return a list of all seen task types."""
return sorted(self._seen_types)
def alive_workers(self):
"""Return a list of (seemingly) alive workers."""
return (w for w in self.workers.values() if w.alive)
def __repr__(self):
return R_STATE.format(self)
def __reduce__(self):
return self.__class__, (
self.event_callback, self.workers, self.tasks, None,
self.max_workers_in_memory, self.max_tasks_in_memory,
self.on_node_join, self.on_node_leave,
_serialize_Task_WeakSet_Mapping(self.tasks_by_type),
_serialize_Task_WeakSet_Mapping(self.tasks_by_worker),
)
def _serialize_Task_WeakSet_Mapping(mapping):
return {name: [t.id for t in tasks] for name, tasks in mapping.items()}
def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
mapping = mapping or {}
return {name: WeakSet(tasks[i] for i in ids if i in tasks)
for name, ids in mapping.items()}
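
A short sketch of feeding this state object from an event stream and querying its indexes afterwards, following the real-time processing example in the docs; names and the broker URL are illustrative:

from celery import Celery

app = Celery('proj', broker='amqp://guest@localhost//')
state = app.events.State()

def announce_failed(event):
    state.event(event)
    task = state.tasks.get(event['uuid'])
    print('TASK FAILED: %s[%s]' % (task.name, task.uuid))

with app.connection_for_read() as connection:
    recv = app.events.Receiver(connection, handlers={
        'task-failed': announce_failed,
        '*': state.event,   # keep the rest of the state up to date
    })
    recv.capture(limit=None, timeout=None)

# afterwards the indexes can be queried, e.g.:
#   state.tasks_by_type['proj.tasks.add']   # WeakSet of Task objects
#   list(state.tasks_by_time(limit=10))     # newest (uuid, Task) pairs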

View File

@@ -0,0 +1,312 @@
"""Celery error types.
Error Hierarchy
===============
- :exc:`Exception`
- :exc:`celery.exceptions.CeleryError`
- :exc:`~celery.exceptions.ImproperlyConfigured`
- :exc:`~celery.exceptions.SecurityError`
- :exc:`~celery.exceptions.TaskPredicate`
- :exc:`~celery.exceptions.Ignore`
- :exc:`~celery.exceptions.Reject`
- :exc:`~celery.exceptions.Retry`
- :exc:`~celery.exceptions.TaskError`
- :exc:`~celery.exceptions.QueueNotFound`
- :exc:`~celery.exceptions.IncompleteStream`
- :exc:`~celery.exceptions.NotRegistered`
- :exc:`~celery.exceptions.AlreadyRegistered`
- :exc:`~celery.exceptions.TimeoutError`
- :exc:`~celery.exceptions.MaxRetriesExceededError`
- :exc:`~celery.exceptions.TaskRevokedError`
- :exc:`~celery.exceptions.InvalidTaskError`
- :exc:`~celery.exceptions.ChordError`
- :exc:`~celery.exceptions.BackendError`
- :exc:`~celery.exceptions.BackendGetMetaError`
- :exc:`~celery.exceptions.BackendStoreError`
- :class:`kombu.exceptions.KombuError`
- :exc:`~celery.exceptions.OperationalError`
        Raised when a transport connection error occurs while
        sending a message (be it a task or a remote control command).
.. note::
This exception does not inherit from
:exc:`~celery.exceptions.CeleryError`.
- **billiard errors** (prefork pool)
- :exc:`~celery.exceptions.SoftTimeLimitExceeded`
- :exc:`~celery.exceptions.TimeLimitExceeded`
- :exc:`~celery.exceptions.WorkerLostError`
- :exc:`~celery.exceptions.Terminated`
- :class:`UserWarning`
- :class:`~celery.exceptions.CeleryWarning`
- :class:`~celery.exceptions.AlwaysEagerIgnored`
- :class:`~celery.exceptions.DuplicateNodenameWarning`
- :class:`~celery.exceptions.FixupWarning`
- :class:`~celery.exceptions.NotConfigured`
- :class:`~celery.exceptions.SecurityWarning`
- :exc:`BaseException`
- :exc:`SystemExit`
- :exc:`~celery.exceptions.WorkerTerminate`
- :exc:`~celery.exceptions.WorkerShutdown`
"""
import numbers
from billiard.exceptions import SoftTimeLimitExceeded, Terminated, TimeLimitExceeded, WorkerLostError
from click import ClickException
from kombu.exceptions import OperationalError
__all__ = (
'reraise',
# Warnings
'CeleryWarning',
'AlwaysEagerIgnored', 'DuplicateNodenameWarning',
'FixupWarning', 'NotConfigured', 'SecurityWarning',
# Core errors
'CeleryError',
'ImproperlyConfigured', 'SecurityError',
# Kombu (messaging) errors.
'OperationalError',
# Task semi-predicates
'TaskPredicate', 'Ignore', 'Reject', 'Retry',
# Task related errors.
'TaskError', 'QueueNotFound', 'IncompleteStream',
'NotRegistered', 'AlreadyRegistered', 'TimeoutError',
'MaxRetriesExceededError', 'TaskRevokedError',
'InvalidTaskError', 'ChordError',
# Backend related errors.
'BackendError', 'BackendGetMetaError', 'BackendStoreError',
# Billiard task errors.
'SoftTimeLimitExceeded', 'TimeLimitExceeded',
'WorkerLostError', 'Terminated',
# Deprecation warnings (forcing Python to emit them).
'CPendingDeprecationWarning', 'CDeprecationWarning',
# Worker shutdown semi-predicates (inherits from SystemExit).
'WorkerShutdown', 'WorkerTerminate',
'CeleryCommandException',
)
from celery.utils.serialization import get_pickleable_exception
UNREGISTERED_FMT = """\
Task of kind {0} never registered, please make sure it's imported.\
"""
def reraise(tp, value, tb=None):
"""Reraise exception."""
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
class CeleryWarning(UserWarning):
"""Base class for all Celery warnings."""
class AlwaysEagerIgnored(CeleryWarning):
"""send_task ignores :setting:`task_always_eager` option."""
class DuplicateNodenameWarning(CeleryWarning):
"""Multiple workers are using the same nodename."""
class FixupWarning(CeleryWarning):
"""Fixup related warning."""
class NotConfigured(CeleryWarning):
"""Celery hasn't been configured, as no config module has been found."""
class SecurityWarning(CeleryWarning):
"""Potential security issue found."""
class CeleryError(Exception):
"""Base class for all Celery errors."""
class TaskPredicate(CeleryError):
"""Base class for task-related semi-predicates."""
class Retry(TaskPredicate):
"""The task is to be retried later."""
#: Optional message describing context of retry.
message = None
#: Exception (if any) that caused the retry to happen.
exc = None
#: Time of retry (ETA), either :class:`numbers.Real` or
#: :class:`~datetime.datetime`.
when = None
def __init__(self, message=None, exc=None, when=None, is_eager=False,
sig=None, **kwargs):
from kombu.utils.encoding import safe_repr
self.message = message
if isinstance(exc, str):
self.exc, self.excs = None, exc
else:
self.exc, self.excs = get_pickleable_exception(exc), safe_repr(exc) if exc else None
self.when = when
self.is_eager = is_eager
self.sig = sig
super().__init__(self, exc, when, **kwargs)
def humanize(self):
if isinstance(self.when, numbers.Number):
return f'in {self.when}s'
return f'at {self.when}'
def __str__(self):
if self.message:
return self.message
if self.excs:
return f'Retry {self.humanize()}: {self.excs}'
return f'Retry {self.humanize()}'
def __reduce__(self):
return self.__class__, (self.message, self.exc, self.when)
RetryTaskError = Retry # XXX compat
class Ignore(TaskPredicate):
"""A task can raise this to ignore doing state updates."""
class Reject(TaskPredicate):
"""A task can raise this if it wants to reject/re-queue the message."""
def __init__(self, reason=None, requeue=False):
self.reason = reason
self.requeue = requeue
super().__init__(reason, requeue)
def __repr__(self):
return f'reject requeue={self.requeue}: {self.reason}'
class ImproperlyConfigured(CeleryError):
"""Celery is somehow improperly configured."""
class SecurityError(CeleryError):
"""Security related exception."""
class TaskError(CeleryError):
"""Task related errors."""
class QueueNotFound(KeyError, TaskError):
"""Task routed to a queue not in ``conf.queues``."""
class IncompleteStream(TaskError):
"""Found the end of a stream of data, but the data isn't complete."""
class NotRegistered(KeyError, TaskError):
"""The task is not registered."""
def __repr__(self):
return UNREGISTERED_FMT.format(self)
class AlreadyRegistered(TaskError):
"""The task is already registered."""
# XXX Unused
class TimeoutError(TaskError):
"""The operation timed out."""
class MaxRetriesExceededError(TaskError):
"""The tasks max restart limit has been exceeded."""
def __init__(self, *args, **kwargs):
self.task_args = kwargs.pop("task_args", [])
self.task_kwargs = kwargs.pop("task_kwargs", dict())
super().__init__(*args, **kwargs)
class TaskRevokedError(TaskError):
"""The task has been revoked, so no result available."""
class InvalidTaskError(TaskError):
"""The task has invalid data or ain't properly constructed."""
class ChordError(TaskError):
"""A task part of the chord raised an exception."""
class CPendingDeprecationWarning(PendingDeprecationWarning):
"""Warning of pending deprecation."""
class CDeprecationWarning(DeprecationWarning):
"""Warning of deprecation."""
class WorkerTerminate(SystemExit):
"""Signals that the worker should terminate immediately."""
SystemTerminate = WorkerTerminate # XXX compat
class WorkerShutdown(SystemExit):
"""Signals that the worker should perform a warm shutdown."""
class BackendError(Exception):
"""An issue writing or reading to/from the backend."""
class BackendGetMetaError(BackendError):
"""An issue reading from the backend."""
def __init__(self, *args, **kwargs):
self.task_id = kwargs.get('task_id', "")
def __repr__(self):
return super().__repr__() + " task_id:" + self.task_id
class BackendStoreError(BackendError):
"""An issue writing to the backend."""
def __init__(self, *args, **kwargs):
self.state = kwargs.get('state', "")
self.task_id = kwargs.get('task_id', "")
def __repr__(self):
return super().__repr__() + " state:" + self.state + " task_id:" + self.task_id
class CeleryCommandException(ClickException):
"""A general command exception which stores an exit code."""
def __init__(self, message, exit_code):
super().__init__(message=message)
self.exit_code = exit_code
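
A brief illustration of the task-level semi-predicates above; ``process_order``, ``charge`` and both domain exceptions are made-up names:

from celery import shared_task
from celery.exceptions import Reject

class PaymentGatewayDown(Exception):   # made-up domain exceptions
    pass

class MalformedOrder(Exception):
    pass

def charge(order_id):   # made-up helper
    raise PaymentGatewayDown()

@shared_task(bind=True, max_retries=3, acks_late=True)
def process_order(self, order_id):
    try:
        charge(order_id)
    except PaymentGatewayDown as exc:
        # self.retry() raises Retry (a TaskPredicate) for us
        raise self.retry(exc=exc, countdown=10)
    except MalformedOrder as exc:
        # drop the message instead of retrying it forever
        raise Reject(str(exc), requeue=False)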

View File

@@ -0,0 +1 @@
"""Fixups."""

View File

@@ -0,0 +1,213 @@
"""Django-specific customization."""
import os
import sys
import warnings
from datetime import datetime
from importlib import import_module
from typing import IO, TYPE_CHECKING, Any, List, Optional, cast
from kombu.utils.imports import symbol_by_name
from kombu.utils.objects import cached_property
from celery import _state, signals
from celery.exceptions import FixupWarning, ImproperlyConfigured
if TYPE_CHECKING:
from types import ModuleType
from typing import Protocol
from django.db.utils import ConnectionHandler
from celery.app.base import Celery
from celery.app.task import Task
class DjangoDBModule(Protocol):
connections: ConnectionHandler
__all__ = ('DjangoFixup', 'fixup')
ERR_NOT_INSTALLED = """\
Environment variable DJANGO_SETTINGS_MODULE is defined
but Django isn't installed. Won't apply Django fix-ups!
"""
def _maybe_close_fd(fh: IO) -> None:
try:
os.close(fh.fileno())
except (AttributeError, OSError, TypeError):
# TypeError added for celery#962
pass
def _verify_django_version(django: "ModuleType") -> None:
if django.VERSION < (1, 11):
raise ImproperlyConfigured('Celery 5.x requires Django 1.11 or later.')
def fixup(app: "Celery", env: str = 'DJANGO_SETTINGS_MODULE') -> Optional["DjangoFixup"]:
"""Install Django fixup if settings module environment is set."""
SETTINGS_MODULE = os.environ.get(env)
if SETTINGS_MODULE and 'django' not in app.loader_cls.lower():
try:
import django
except ImportError:
warnings.warn(FixupWarning(ERR_NOT_INSTALLED))
else:
_verify_django_version(django)
return DjangoFixup(app).install()
return None
class DjangoFixup:
"""Fixup installed when using Django."""
def __init__(self, app: "Celery"):
self.app = app
if _state.default_app is None:
self.app.set_default()
self._worker_fixup: Optional["DjangoWorkerFixup"] = None
def install(self) -> "DjangoFixup":
# Need to add project directory to path.
# The project directory has precedence over system modules,
# so we prepend it to the path.
sys.path.insert(0, os.getcwd())
self._settings = symbol_by_name('django.conf:settings')
self.app.loader.now = self.now
signals.import_modules.connect(self.on_import_modules)
signals.worker_init.connect(self.on_worker_init)
return self
@property
def worker_fixup(self) -> "DjangoWorkerFixup":
if self._worker_fixup is None:
self._worker_fixup = DjangoWorkerFixup(self.app)
return self._worker_fixup
@worker_fixup.setter
def worker_fixup(self, value: "DjangoWorkerFixup") -> None:
self._worker_fixup = value
def on_import_modules(self, **kwargs: Any) -> None:
# call django.setup() before task modules are imported
self.worker_fixup.validate_models()
def on_worker_init(self, **kwargs: Any) -> None:
self.worker_fixup.install()
def now(self, utc: bool = False) -> datetime:
return datetime.utcnow() if utc else self._now()
def autodiscover_tasks(self) -> List[str]:
from django.apps import apps
return [config.name for config in apps.get_app_configs()]
@cached_property
def _now(self) -> datetime:
return symbol_by_name('django.utils.timezone:now')
class DjangoWorkerFixup:
_db_recycles = 0
def __init__(self, app: "Celery") -> None:
self.app = app
self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
self._db = cast("DjangoDBModule", import_module('django.db'))
self._cache = import_module('django.core.cache')
self._settings = symbol_by_name('django.conf:settings')
self.interface_errors = (
symbol_by_name('django.db.utils.InterfaceError'),
)
self.DatabaseError = symbol_by_name('django.db:DatabaseError')
def django_setup(self) -> None:
import django
django.setup()
def validate_models(self) -> None:
from django.core.checks import run_checks
self.django_setup()
if not os.environ.get('CELERY_SKIP_CHECKS'):
run_checks()
def install(self) -> "DjangoWorkerFixup":
signals.beat_embedded_init.connect(self.close_database)
signals.task_prerun.connect(self.on_task_prerun)
signals.task_postrun.connect(self.on_task_postrun)
signals.worker_process_init.connect(self.on_worker_process_init)
self.close_database()
self.close_cache()
return self
def on_worker_process_init(self, **kwargs: Any) -> None:
# Child process must validate models again if on Windows,
# or if they were started using execv.
if os.environ.get('FORKED_BY_MULTIPROCESSING'):
self.validate_models()
# close connections:
# the parent process may have established these,
# so need to close them.
# calling db.close() on some DB connections will cause
# the inherited DB conn to also get broken in the parent
# process so we need to remove it without triggering any
# network IO that close() might cause.
for c in self._db.connections.all():
if c and c.connection:
self._maybe_close_db_fd(c.connection)
# use the _ version to avoid DB_REUSE preventing the conn.close() call
self._close_database(force=True)
self.close_cache()
def _maybe_close_db_fd(self, fd: IO) -> None:
try:
_maybe_close_fd(fd)
except self.interface_errors:
pass
def on_task_prerun(self, sender: "Task", **kwargs: Any) -> None:
"""Called before every task."""
if not getattr(sender.request, 'is_eager', False):
self.close_database()
def on_task_postrun(self, sender: "Task", **kwargs: Any) -> None:
# See https://groups.google.com/group/django-users/browse_thread/thread/78200863d0c07c6d/
if not getattr(sender.request, 'is_eager', False):
self.close_database()
self.close_cache()
def close_database(self, **kwargs: Any) -> None:
if not self.db_reuse_max:
return self._close_database()
if self._db_recycles >= self.db_reuse_max * 2:
self._db_recycles = 0
self._close_database()
self._db_recycles += 1
def _close_database(self, force: bool = False) -> None:
for conn in self._db.connections.all():
try:
if force:
conn.close()
else:
conn.close_if_unusable_or_obsolete()
except self.interface_errors:
pass
except self.DatabaseError as exc:
str_exc = str(exc)
if 'closed' not in str_exc and 'not connected' not in str_exc:
raise
def close_cache(self) -> None:
try:
self._cache.close_caches()
except (TypeError, AttributeError):
pass
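
For context, this fixup is what the conventional Django integration module relies on; a minimal ``proj/celery.py`` (project name illustrative) looks like:

import os

from celery import Celery

# must be set before the app is created so the fixup gets installed
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'proj.settings')

app = Celery('proj')
app.config_from_object('django.conf:settings', namespace='CELERY')
app.autodiscover_tasks()   # delegates to DjangoFixup.autodiscover_tasks()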

View File

@@ -0,0 +1,18 @@
"""Get loader by name.
Loaders define how configuration is read, what happens
when workers start, when tasks are executed and so on.
"""
from celery.utils.imports import import_from_cwd, symbol_by_name
__all__ = ('get_loader_cls',)
LOADER_ALIASES = {
'app': 'celery.loaders.app:AppLoader',
'default': 'celery.loaders.default:Loader',
}
def get_loader_cls(loader):
"""Get loader class by name/alias."""
return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd)
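
A tiny usage sketch of the alias table above:

from celery.loaders import get_loader_cls

Loader = get_loader_cls('default')   # -> celery.loaders.default:Loader
# a fully qualified name works as well:
Loader = get_loader_cls('celery.loaders.app:AppLoader')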

View File

@@ -0,0 +1,8 @@
"""The default loader used with custom app instances."""
from .base import BaseLoader
__all__ = ('AppLoader',)
class AppLoader(BaseLoader):
"""Default loader used when an app is specified."""

View File

@@ -0,0 +1,272 @@
"""Loader base class."""
import importlib
import os
import re
import sys
from datetime import datetime
from kombu.utils import json
from kombu.utils.objects import cached_property
from celery import signals
from celery.exceptions import reraise
from celery.utils.collections import DictAttribute, force_mapping
from celery.utils.functional import maybe_list
from celery.utils.imports import NotAPackage, find_module, import_from_cwd, symbol_by_name
__all__ = ('BaseLoader',)
_RACE_PROTECTION = False
CONFIG_INVALID_NAME = """\
Error: Module '{module}' doesn't exist, or it's not a valid \
Python module name.
"""
CONFIG_WITH_SUFFIX = CONFIG_INVALID_NAME + """\
Did you mean '{suggest}'?
"""
unconfigured = object()
class BaseLoader:
"""Base class for loaders.
    Loaders handle:
* Reading celery client/worker configurations.
* What happens when a task starts?
See :meth:`on_task_init`.
* What happens when the worker starts?
See :meth:`on_worker_init`.
* What happens when the worker shuts down?
See :meth:`on_worker_shutdown`.
* What modules are imported to find tasks?
"""
builtin_modules = frozenset()
configured = False
override_backends = {}
worker_initialized = False
_conf = unconfigured
def __init__(self, app, **kwargs):
self.app = app
self.task_modules = set()
def now(self, utc=True):
if utc:
return datetime.utcnow()
return datetime.now()
def on_task_init(self, task_id, task):
"""Called before a task is executed."""
def on_process_cleanup(self):
"""Called after a task is executed."""
def on_worker_init(self):
"""Called when the worker (:program:`celery worker`) starts."""
def on_worker_shutdown(self):
"""Called when the worker (:program:`celery worker`) shuts down."""
def on_worker_process_init(self):
"""Called when a child process starts."""
def import_task_module(self, module):
self.task_modules.add(module)
return self.import_from_cwd(module)
def import_module(self, module, package=None):
return importlib.import_module(module, package=package)
def import_from_cwd(self, module, imp=None, package=None):
return import_from_cwd(
module,
self.import_module if imp is None else imp,
package=package,
)
def import_default_modules(self):
responses = signals.import_modules.send(sender=self.app)
# Prior to this point loggers are not yet set up properly, need to
        # check responses manually and reraise exceptions if any, otherwise
# they'll be silenced, making it incredibly difficult to debug.
for _, response in responses:
if isinstance(response, Exception):
raise response
return [self.import_task_module(m) for m in self.default_modules]
def init_worker(self):
if not self.worker_initialized:
self.worker_initialized = True
self.import_default_modules()
self.on_worker_init()
def shutdown_worker(self):
self.on_worker_shutdown()
def init_worker_process(self):
self.on_worker_process_init()
def config_from_object(self, obj, silent=False):
if isinstance(obj, str):
try:
obj = self._smart_import(obj, imp=self.import_from_cwd)
except (ImportError, AttributeError):
if silent:
return False
raise
self._conf = force_mapping(obj)
if self._conf.get('override_backends') is not None:
self.override_backends = self._conf['override_backends']
return True
def _smart_import(self, path, imp=None):
imp = self.import_module if imp is None else imp
if ':' in path:
# Path includes attribute so can just jump
# here (e.g., ``os.path:abspath``).
return symbol_by_name(path, imp=imp)
# Not sure if path is just a module name or if it includes an
# attribute name (e.g., ``os.path``, vs, ``os.path.abspath``).
try:
return imp(path)
except ImportError:
# Not a module name, so try module + attribute.
return symbol_by_name(path, imp=imp)
def _import_config_module(self, name):
try:
self.find_module(name)
except NotAPackage as exc:
if name.endswith('.py'):
reraise(NotAPackage, NotAPackage(CONFIG_WITH_SUFFIX.format(
module=name, suggest=name[:-3])), sys.exc_info()[2])
raise NotAPackage(CONFIG_INVALID_NAME.format(module=name)) from exc
else:
return self.import_from_cwd(name)
def find_module(self, module):
return find_module(module)
def cmdline_config_parser(self, args, namespace='celery',
re_type=re.compile(r'\((\w+)\)'),
extra_types=None,
override_types=None):
extra_types = extra_types if extra_types else {'json': json.loads}
override_types = override_types if override_types else {
'tuple': 'json',
'list': 'json',
'dict': 'json'
}
from celery.app.defaults import NAMESPACES, Option
namespace = namespace and namespace.lower()
typemap = dict(Option.typemap, **extra_types)
def getarg(arg):
"""Parse single configuration from command-line."""
# ## find key/value
# ns.key=value|ns_key=value (case insensitive)
key, value = arg.split('=', 1)
key = key.lower().replace('.', '_')
# ## find name-space.
# .key=value|_key=value expands to default name-space.
if key[0] == '_':
ns, key = namespace, key[1:]
else:
# find name-space part of key
ns, key = key.split('_', 1)
ns_key = (ns and ns + '_' or '') + key
# (type)value makes cast to custom type.
cast = re_type.match(value)
if cast:
type_ = cast.groups()[0]
type_ = override_types.get(type_, type_)
value = value[len(cast.group()):]
value = typemap[type_](value)
else:
try:
value = NAMESPACES[ns.lower()][key].to_python(value)
except ValueError as exc:
# display key name in error message.
raise ValueError(f'{ns_key!r}: {exc}')
return ns_key, value
return dict(getarg(arg) for arg in args)
def read_configuration(self, env='CELERY_CONFIG_MODULE'):
try:
custom_config = os.environ[env]
except KeyError:
pass
else:
if custom_config:
usercfg = self._import_config_module(custom_config)
return DictAttribute(usercfg)
def autodiscover_tasks(self, packages, related_name='tasks'):
self.task_modules.update(
mod.__name__ for mod in autodiscover_tasks(packages or (),
related_name) if mod)
@cached_property
def default_modules(self):
return (
tuple(self.builtin_modules) +
tuple(maybe_list(self.app.conf.imports)) +
tuple(maybe_list(self.app.conf.include))
)
@property
def conf(self):
"""Loader configuration."""
if self._conf is unconfigured:
self._conf = self.read_configuration()
return self._conf
def autodiscover_tasks(packages, related_name='tasks'):
global _RACE_PROTECTION
if _RACE_PROTECTION:
return ()
_RACE_PROTECTION = True
try:
return [find_related_module(pkg, related_name) for pkg in packages]
finally:
_RACE_PROTECTION = False
def find_related_module(package, related_name):
"""Find module in package."""
# Django 1.7 allows for specifying a class name in INSTALLED_APPS.
# (Issue #2248).
try:
module = importlib.import_module(package)
if not related_name and module:
return module
except ImportError:
package, _, _ = package.rpartition('.')
if not package:
raise
module_name = f'{package}.{related_name}'
try:
return importlib.import_module(module_name)
except ImportError as e:
import_exc_name = getattr(e, 'name', module_name)
if import_exc_name is not None and import_exc_name != module_name:
raise e
return

View File

@@ -0,0 +1,42 @@
"""The default loader used when no custom app has been initialized."""
import os
import warnings
from celery.exceptions import NotConfigured
from celery.utils.collections import DictAttribute
from celery.utils.serialization import strtobool
from .base import BaseLoader
__all__ = ('Loader', 'DEFAULT_CONFIG_MODULE')
DEFAULT_CONFIG_MODULE = 'celeryconfig'
#: Warns if configuration file is missing if :envvar:`C_WNOCONF` is set.
C_WNOCONF = strtobool(os.environ.get('C_WNOCONF', False))
class Loader(BaseLoader):
"""The loader used by the default app."""
def setup_settings(self, settingsdict):
return DictAttribute(settingsdict)
def read_configuration(self, fail_silently=True):
"""Read configuration from :file:`celeryconfig.py`."""
configname = os.environ.get('CELERY_CONFIG_MODULE',
DEFAULT_CONFIG_MODULE)
try:
usercfg = self._import_config_module(configname)
except ImportError:
if not fail_silently:
raise
# billiard sets this if forked using execv
if C_WNOCONF and not os.environ.get('FORKED_BY_MULTIPROCESSING'):
warnings.warn(NotConfigured(
'No {module} module found! Please make sure it exists and '
'is available to Python.'.format(module=configname)))
return self.setup_settings({})
else:
self.configured = True
return self.setup_settings(usercfg)
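
For reference, the module this loader looks for is a plain Python file; a minimal ``celeryconfig.py`` (values illustrative) could be:

# celeryconfig.py -- found via CELERY_CONFIG_MODULE or the default name above
broker_url = 'redis://localhost:6379/0'
result_backend = 'redis://localhost:6379/1'
task_serializer = 'json'
timezone = 'UTC'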

Some files were not shown because too many files have changed in this diff