Major fixes and new features
@@ -0,0 +1,15 @@
"""Worker consumer."""
from .agent import Agent
from .connection import Connection
from .consumer import Consumer
from .control import Control
from .events import Events
from .gossip import Gossip
from .heart import Heart
from .mingle import Mingle
from .tasks import Tasks

__all__ = (
    'Consumer', 'Agent', 'Connection', 'Control',
    'Events', 'Gossip', 'Heart', 'Mingle', 'Tasks',
)
@@ -0,0 +1,21 @@
"""Celery + :pypi:`cell` integration."""
from celery import bootsteps

from .connection import Connection

__all__ = ('Agent',)


class Agent(bootsteps.StartStopStep):
    """Agent starts :pypi:`cell` actors."""

    conditional = True
    requires = (Connection,)

    def __init__(self, c, **kwargs):
        self.agent_cls = self.enabled = c.app.conf.worker_agent
        super().__init__(c, **kwargs)

    def create(self, c):
        agent = c.agent = self.instantiate(self.agent_cls, c.connection)
        return agent
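# Example (illustration only, not from the Celery source): the Agent step is
# conditional, so it is only enabled when an agent class is configured, e.g.
#     app.conf.worker_agent = 'myproject.agents:Agent'   # hypothetical path
# Any importable class accepted by ``self.instantiate()`` can be used here.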
@@ -0,0 +1,36 @@
"""Consumer Broker Connection Bootstep."""
from kombu.common import ignore_errors

from celery import bootsteps
from celery.utils.log import get_logger

__all__ = ('Connection',)

logger = get_logger(__name__)
info = logger.info


class Connection(bootsteps.StartStopStep):
    """Service managing the consumer broker connection."""

    def __init__(self, c, **kwargs):
        c.connection = None
        super().__init__(c, **kwargs)

    def start(self, c):
        c.connection = c.connect()
        info('Connected to %s', c.connection.as_uri())

    def shutdown(self, c):
        # We must set c.connection to None here, so
        # that the green pidbox thread exits.
        connection, c.connection = c.connection, None
        if connection:
            ignore_errors(connection, connection.close)

    def info(self, c):
        params = 'N/A'
        if c.connection:
            params = c.connection.info()
            params.pop('password', None)  # don't send password.
        return {'broker': params}
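# Example (illustration only, not from the Celery source): the same
# "detach the reference, then close while ignoring errors" idiom used by
# Connection.shutdown() above, applied to a standalone kombu connection.
# The in-memory broker URL below is purely illustrative.
if __name__ == '__main__':
    from kombu import Connection as ExampleConnection
    from kombu.common import ignore_errors as example_ignore_errors

    conn = ExampleConnection('memory://')
    conn.connect()
    # Detach first so nothing reuses the connection, then close it while
    # swallowing any connection/channel errors raised on the way down.
    current, conn = conn, None
    example_ignore_errors(current, current.close)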
@@ -0,0 +1,745 @@
"""Worker Consumer Blueprint.

This module contains the components responsible for consuming messages
from the broker, processing the messages and keeping the broker connections
up and running.
"""
import errno
import logging
import os
import warnings
from collections import defaultdict
from time import sleep

from billiard.common import restart_state
from billiard.exceptions import RestartFreqExceeded
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError
from kombu.utils.compat import _detect_environment
from kombu.utils.encoding import safe_repr
from kombu.utils.limits import TokenBucket
from vine import ppartial, promise

from celery import bootsteps, signals
from celery.app.trace import build_tracer
from celery.exceptions import (CPendingDeprecationWarning, InvalidTaskError, NotRegistered, WorkerShutdown,
                               WorkerTerminate)
from celery.utils.functional import noop
from celery.utils.log import get_logger
from celery.utils.nodenames import gethostname
from celery.utils.objects import Bunch
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds, rate
from celery.worker import loops
from celery.worker.state import active_requests, maybe_shutdown, requests, reserved_requests, task_reserved

__all__ = ('Consumer', 'Evloop', 'dump_body')

CLOSE = bootsteps.CLOSE
TERMINATE = bootsteps.TERMINATE
STOP_CONDITIONS = {CLOSE, TERMINATE}
logger = get_logger(__name__)
debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning,
                                  logger.error, logger.critical)

CONNECTION_RETRY = """\
consumer: Connection to broker lost. \
Trying to re-establish the connection...\
"""

CONNECTION_RETRY_STEP = """\
Trying again {when}... ({retries}/{max_retries})\
"""

CONNECTION_ERROR = """\
consumer: Cannot connect to %s: %s.
%s
"""

CONNECTION_FAILOVER = """\
Will retry using next failover.\
"""

UNKNOWN_FORMAT = """\
Received and deleted unknown message. Wrong destination?!?

The full contents of the message body was: %s
"""

#: Error message for when an unregistered task is received.
UNKNOWN_TASK_ERROR = """\
Received unregistered task of type %s.
The message has been ignored and discarded.

Did you remember to import the module containing this task?
Or maybe you're using relative imports?

Please see
https://docs.celeryq.dev/en/latest/internals/protocol.html
for more information.

The full contents of the message body was:
%s

The full contents of the message headers:
%s

The delivery info for this task is:
%s
"""

#: Error message for when an invalid task message is received.
INVALID_TASK_ERROR = """\
Received invalid task message: %s
The message has been ignored and discarded.

Please ensure your message conforms to the task
message protocol as described here:
https://docs.celeryq.dev/en/latest/internals/protocol.html

The full contents of the message body was:
%s
"""

MESSAGE_DECODE_ERROR = """\
Can't decode message body: %r [type:%r encoding:%r headers:%s]

body: %s
"""

MESSAGE_REPORT = """\
body: {0}
{{content_type:{1} content_encoding:{2}
  delivery_info:{3} headers={4}}}
"""

TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS = """\
Task %s cannot be acknowledged after a connection loss since late acknowledgement is enabled for it.
Terminating it instead.
"""

CANCEL_TASKS_BY_DEFAULT = """
In Celery 5.1 we introduced an optional breaking change which
on connection loss cancels all currently executed tasks with late acknowledgement enabled.
These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered
back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss
setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0.
"""


def dump_body(m, body):
    """Format message body for debugging purposes."""
    # v2 protocol does not deserialize body
    body = m.body if body is None else body
    return '{} ({}b)'.format(truncate(safe_repr(body), 1024),
                             len(m.body))

class Consumer:
    """Consumer blueprint."""

    Strategies = dict

    #: Optional callback called the first time the worker
    #: is ready to receive tasks.
    init_callback = None

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    restart_count = -1  # first start is the same as a restart

    #: This flag will be turned off after the first failed
    #: connection attempt.
    first_connection_attempt = True

    class Blueprint(bootsteps.Blueprint):
        """Consumer blueprint."""

        name = 'Consumer'
        default_steps = [
            'celery.worker.consumer.connection:Connection',
            'celery.worker.consumer.mingle:Mingle',
            'celery.worker.consumer.events:Events',
            'celery.worker.consumer.gossip:Gossip',
            'celery.worker.consumer.heart:Heart',
            'celery.worker.consumer.control:Control',
            'celery.worker.consumer.tasks:Tasks',
            'celery.worker.consumer.consumer:Evloop',
            'celery.worker.consumer.agent:Agent',
        ]

        def shutdown(self, parent):
            self.send_all(parent, 'shutdown')

    def __init__(self, on_task_request,
                 init_callback=noop, hostname=None,
                 pool=None, app=None,
                 timer=None, controller=None, hub=None, amqheartbeat=None,
                 worker_options=None, disable_rate_limits=False,
                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
        self.app = app
        self.controller = controller
        self.init_callback = init_callback
        self.hostname = hostname or gethostname()
        self.pid = os.getpid()
        self.pool = pool
        self.timer = timer
        self.strategies = self.Strategies()
        self.conninfo = self.app.connection_for_read()
        self.connection_errors = self.conninfo.connection_errors
        self.channel_errors = self.conninfo.channel_errors
        self._restart_state = restart_state(maxR=5, maxT=1)

        self._does_info = logger.isEnabledFor(logging.INFO)
        self._limit_order = 0
        self.on_task_request = on_task_request
        self.on_task_message = set()
        self.amqheartbeat_rate = self.app.conf.broker_heartbeat_checkrate
        self.disable_rate_limits = disable_rate_limits
        self.initial_prefetch_count = initial_prefetch_count
        self.prefetch_multiplier = prefetch_multiplier
        self._maximum_prefetch_restored = True

        # this contains a tokenbucket for each task type by name, used for
        # rate limits, or None if rate limits are disabled for that task.
        self.task_buckets = defaultdict(lambda: None)
        self.reset_rate_limits()

        self.hub = hub
        if self.hub or getattr(self.pool, 'is_green', False):
            self.amqheartbeat = amqheartbeat
            if self.amqheartbeat is None:
                self.amqheartbeat = self.app.conf.broker_heartbeat
        else:
            self.amqheartbeat = 0

        if not hasattr(self, 'loop'):
            self.loop = loops.asynloop if hub else loops.synloop

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.broker_connection_timeout = None

        self._pending_operations = []

        self.steps = []
        self.blueprint = self.Blueprint(
            steps=self.app.steps['consumer'],
            on_close=self.on_close,
        )
        self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))

    def call_soon(self, p, *args, **kwargs):
        p = ppartial(p, *args, **kwargs)
        if self.hub:
            return self.hub.call_soon(p)
        self._pending_operations.append(p)
        return p

    def perform_pending_operations(self):
        if not self.hub:
            while self._pending_operations:
                try:
                    self._pending_operations.pop()()
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception('Pending callback raised: %r', exc)

    def bucket_for_task(self, type):
        limit = rate(getattr(type, 'rate_limit', None))
        return TokenBucket(limit, capacity=1) if limit else None

    def reset_rate_limits(self):
        self.task_buckets.update(
            (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
        )
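    # Example (illustration only, not from the Celery source): each bucket above
    # is a kombu TokenBucket keyed by task name.  ``rate()`` converts a task's
    # ``rate_limit`` string to tokens per second, so ``rate('10/m')`` is
    # 10 / 60, roughly 0.167 tokens per second, and such a task leaks out of
    # its bucket about once every six seconds.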

    def _update_prefetch_count(self, index=0):
        """Update prefetch count after pool grow/shrink operations.

        Index must be the change in number of processes as a positive
        (increasing) or negative (decreasing) number.

        Note:
            Currently pool grow operations will end up with an offset
            of +1 if the initial size of the pool was 0 (e.g.
            :option:`--autoscale=1,0 <celery worker --autoscale>`).
        """
        num_processes = self.pool.num_processes
        if not self.initial_prefetch_count or not num_processes:
            return  # prefetch disabled
        self.initial_prefetch_count = (
            self.pool.num_processes * self.prefetch_multiplier
        )
        return self._update_qos_eventually(index)

    def _update_qos_eventually(self, index):
        return (self.qos.decrement_eventually if index < 0
                else self.qos.increment_eventually)(
            abs(index) * self.prefetch_multiplier)

    def _limit_move_to_pool(self, request):
        task_reserved(request)
        self.on_task_request(request)

    def _schedule_bucket_request(self, bucket):
        while True:
            try:
                request, tokens = bucket.pop()
            except IndexError:
                # no request, break
                break

            if bucket.can_consume(tokens):
                self._limit_move_to_pool(request)
                continue
            else:
                # requeue to head, keep the order.
                bucket.contents.appendleft((request, tokens))

                pri = self._limit_order = (self._limit_order + 1) % 10
                hold = bucket.expected_time(tokens)
                self.timer.call_after(
                    hold, self._schedule_bucket_request, (bucket,),
                    priority=pri,
                )
                # no tokens, break
                break

    def _limit_task(self, request, bucket, tokens):
        bucket.add((request, tokens))
        return self._schedule_bucket_request(bucket)

    def _limit_post_eta(self, request, bucket, tokens):
        self.qos.decrement_eventually()
        bucket.add((request, tokens))
        return self._schedule_bucket_request(bucket)
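    # Example (illustration only, not from the Celery source): the prefetch
    # target recomputed by _update_prefetch_count() above is simply
    # ``num_processes * prefetch_multiplier``.  With 4 pool processes and
    # ``prefetch_multiplier=4`` that gives a QoS of 16; growing the pool by one
    # process then nudges it up by another 4 via _update_qos_eventually().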

    def start(self):
        blueprint = self.blueprint
        while blueprint.state not in STOP_CONDITIONS:
            maybe_shutdown()
            if self.restart_count:
                try:
                    self._restart_state.step()
                except RestartFreqExceeded as exc:
                    crit('Frequent restarts detected: %r', exc, exc_info=1)
                    sleep(1)
            self.restart_count += 1
            if self.app.conf.broker_channel_error_retry:
                recoverable_errors = (self.connection_errors + self.channel_errors)
            else:
                recoverable_errors = self.connection_errors
            try:
                blueprint.start(self)
            except recoverable_errors as exc:
                # If we're not retrying connections, we need to properly shutdown or terminate
                # the Celery main process instead of abruptly aborting the process without any cleanup.
                is_connection_loss_on_startup = self.first_connection_attempt
                self.first_connection_attempt = False
                connection_retry_type = self._get_connection_retry_type(is_connection_loss_on_startup)
                connection_retry = self.app.conf[connection_retry_type]
                if not connection_retry:
                    crit(
                        f"Retrying to {'establish' if is_connection_loss_on_startup else 're-establish'} "
                        f"a connection to the message broker after a connection loss has "
                        f"been disabled (app.conf.{connection_retry_type}=False). Shutting down..."
                    )
                    raise WorkerShutdown(1) from exc
                if isinstance(exc, OSError) and exc.errno == errno.EMFILE:
                    crit("Too many open files. Aborting...")
                    raise WorkerTerminate(1) from exc
                maybe_shutdown()
                if blueprint.state not in STOP_CONDITIONS:
                    if self.connection:
                        self.on_connection_error_after_connected(exc)
                    else:
                        self.on_connection_error_before_connected(exc)
                    self.on_close()
                    blueprint.restart(self)

    def _get_connection_retry_type(self, is_connection_loss_on_startup):
        return ('broker_connection_retry_on_startup'
                if (is_connection_loss_on_startup
                    and self.app.conf.broker_connection_retry_on_startup is not None)
                else 'broker_connection_retry')

    def on_connection_error_before_connected(self, exc):
        error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
              'Trying to reconnect...')

    def on_connection_error_after_connected(self, exc):
        warn(CONNECTION_RETRY, exc_info=True)
        try:
            self.connection.collect()
        except Exception:  # pylint: disable=broad-except
            pass

        if self.app.conf.worker_cancel_long_running_tasks_on_connection_loss:
            for request in tuple(active_requests):
                if request.task.acks_late and not request.acknowledged:
                    warn(TERMINATING_TASK_ON_RESTART_AFTER_A_CONNECTION_LOSS,
                         request)
                    request.cancel(self.pool)
        else:
            warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning)

        self.initial_prefetch_count = max(
            self.prefetch_multiplier,
            self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier
        )

        self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count
        if not self._maximum_prefetch_restored:
            logger.info(
                f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching "
                f"since {len(tuple(active_requests))} tasks are currently being processed.\n"
                f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks "
                "complete processing."
            )
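    # Example (illustration only, not from the Celery source): the clamp above is
    #     max(prefetch_multiplier, max_prefetch_count - active * prefetch_multiplier)
    # so with max_prefetch_count=16, prefetch_multiplier=4 and 3 tasks still
    # running, the consumer reconnects with a prefetch of 16 - 3 * 4 = 4 and
    # gradually restores it to 16 as those tasks finish.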

    def register_with_event_loop(self, hub):
        self.blueprint.send_all(
            self, 'register_with_event_loop', args=(hub,),
            description='Hub.register',
        )

    def shutdown(self):
        self.blueprint.shutdown(self)

    def stop(self):
        self.blueprint.stop(self)

    def on_ready(self):
        callback, self.init_callback = self.init_callback, None
        if callback:
            callback(self)

    def loop_args(self):
        return (self, self.connection, self.task_consumer,
                self.blueprint, self.hub, self.qos, self.amqheartbeat,
                self.app.clock, self.amqheartbeat_rate)

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding a message.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        Arguments:
            message (kombu.Message): The message received.
            exc (Exception): The exception being handled.
        """
        crit(MESSAGE_DECODE_ERROR,
             exc, message.content_type, message.content_encoding,
             safe_repr(message.headers), dump_body(message, message.body),
             exc_info=1)
        message.ack()

    def on_close(self):
        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        if self.controller and self.controller.semaphore:
            self.controller.semaphore.clear()
        if self.timer:
            self.timer.clear()
        for bucket in self.task_buckets.values():
            if bucket:
                bucket.clear_pending()
        for request_id in reserved_requests:
            if request_id in requests:
                del requests[request_id]
        reserved_requests.clear()
        if self.pool and self.pool.flush:
            self.pool.flush()

    def connect(self):
        """Establish the broker connection used for consuming tasks.

        Retries establishing the connection if the
        :setting:`broker_connection_retry` setting is enabled
        """
        conn = self.connection_for_read(heartbeat=self.amqheartbeat)
        if self.hub:
            conn.transport.register_with_event_loop(conn.connection, self.hub)
        return conn

    def connection_for_read(self, heartbeat=None):
        return self.ensure_connected(
            self.app.connection_for_read(heartbeat=heartbeat))

    def connection_for_write(self, heartbeat=None):
        return self.ensure_connected(
            self.app.connection_for_write(heartbeat=heartbeat))

    def ensure_connected(self, conn):
        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            next_step = next_step.format(
                when=humanize_seconds(interval, 'in', ' '),
                retries=int(interval / 2),
                max_retries=self.app.conf.broker_connection_max_retries)
            error(CONNECTION_ERROR, conn.as_uri(), exc, next_step)

        # Remember that the connection is lazy, it won't establish
        # until needed.

        # TODO: Rely only on broker_connection_retry_on_startup to determine whether connection retries are disabled.
        #       We will make the switch in Celery 6.0.

        retry_disabled = False

        if self.app.conf.broker_connection_retry_on_startup is None:
            # If broker_connection_retry_on_startup is not set, revert to broker_connection_retry
            # to determine whether connection retries are disabled.
            retry_disabled = not self.app.conf.broker_connection_retry

            warnings.warn(
                CPendingDeprecationWarning(
                    f"The broker_connection_retry configuration setting will no longer determine\n"
                    f"whether broker connection retries are made during startup in Celery 6.0 and above.\n"
                    f"If you wish to retain the existing behavior for retrying connections on startup,\n"
                    f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.")
            )
        else:
            if self.first_connection_attempt:
                retry_disabled = not self.app.conf.broker_connection_retry_on_startup
            else:
                retry_disabled = not self.app.conf.broker_connection_retry

        if retry_disabled:
            # Retry disabled, just call connect directly.
            conn.connect()
            self.first_connection_attempt = False
            return conn

        conn = conn.ensure_connection(
            _error_handler, self.app.conf.broker_connection_max_retries,
            callback=maybe_shutdown,
        )
        self.first_connection_attempt = False
        return conn
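    # Example (illustration only, not from the Celery source): to fail fast when
    # the broker is unreachable at startup while still retrying once an
    # established connection drops, the two settings read above can be combined:
    #     app.conf.broker_connection_retry_on_startup = False
    #     app.conf.broker_connection_retry = True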

    def _flush_events(self):
        if self.event_dispatcher:
            self.event_dispatcher.flush()

    def on_send_event_buffered(self):
        if self.hub:
            self.hub._ready.add(self._flush_events)

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
                       routing_key=None, **options):
        cset = self.task_consumer
        queues = self.app.amqp.queues
        # Must use `in` here, as __missing__ will automatically
        # create queues when :setting:`task_create_missing_queues` is enabled.
        # (Issue #1079)
        if queue in queues:
            q = queues[queue]
        else:
            exchange = queue if exchange is None else exchange
            exchange_type = ('direct' if exchange_type is None
                             else exchange_type)
            q = queues.select_add(queue,
                                  exchange=exchange,
                                  exchange_type=exchange_type,
                                  routing_key=routing_key, **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            info('Started consuming from %s', queue)

    def cancel_task_queue(self, queue):
        info('Canceling queue %s', queue)
        self.app.amqp.queues.deselect(queue)
        self.task_consumer.cancel_by_queue(queue)
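    # Example (illustration only, not from the Celery source): add_task_queue()
    # and cancel_task_queue() are the hooks behind the ``add_consumer`` and
    # ``cancel_consumer`` remote control commands, e.g.
    #     celery -A proj control add_consumer some_queue
    # where ``proj`` and ``some_queue`` are placeholder names.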

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an ETA/countdown."""
        task_reserved(task)
        self.on_task_request(task)
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info),
                                     safe_repr(message.headers))

    def on_unknown_message(self, body, message):
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)
        signals.task_rejected.send(sender=self, message=message, exc=None)

    def on_unknown_task(self, body, message, exc):
        error(UNKNOWN_TASK_ERROR,
              exc,
              dump_body(message, body),
              message.headers,
              message.delivery_info,
              exc_info=True)
        try:
            id_, name = message.headers['id'], message.headers['task']
            root_id = message.headers.get('root_id')
        except KeyError:  # proto1
            payload = message.payload
            id_, name = payload['id'], payload['task']
            root_id = None
        request = Bunch(
            name=name, chord=None, root_id=root_id,
            correlation_id=message.properties.get('correlation_id'),
            reply_to=message.properties.get('reply_to'),
            errbacks=None,
        )
        message.reject_log_error(logger, self.connection_errors)
        self.app.backend.mark_as_failure(
            id_, NotRegistered(name), request=request,
        )
        if self.event_dispatcher:
            self.event_dispatcher.send(
                'task-failed', uuid=id_,
                exception=f'NotRegistered({name!r})',
            )
        signals.task_unknown.send(
            sender=self, message=message, exc=exc, name=name, id=id_,
        )

    def on_invalid_task(self, body, message, exc):
        error(INVALID_TASK_ERROR, exc, dump_body(message, body),
              exc_info=True)
        message.reject_log_error(logger, self.connection_errors)
        signals.task_rejected.send(sender=self, message=message, exc=exc)

    def update_strategies(self):
        loader = self.app.loader
        for name, task in self.app.tasks.items():
            self.strategies[name] = task.start_strategy(self.app, self)
            task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                          app=self.app)

    def create_task_handler(self, promise=promise):
        strategies = self.strategies
        on_unknown_message = self.on_unknown_message
        on_unknown_task = self.on_unknown_task
        on_invalid_task = self.on_invalid_task
        callbacks = self.on_task_message
        call_soon = self.call_soon

        def on_task_received(message):
            # payload will only be set for v1 protocol, since v2
            # will defer deserializing the message body to the pool.
            payload = None
            try:
                type_ = message.headers['task']  # protocol v2
            except TypeError:
                return on_unknown_message(None, message)
            except KeyError:
                try:
                    payload = message.decode()
                except Exception as exc:  # pylint: disable=broad-except
                    return self.on_decode_error(message, exc)
                try:
                    type_, payload = payload['task'], payload  # protocol v1
                except (TypeError, KeyError):
                    return on_unknown_message(payload, message)
            try:
                strategy = strategies[type_]
            except KeyError as exc:
                return on_unknown_task(None, message, exc)
            else:
                try:
                    ack_log_error_promise = promise(
                        call_soon,
                        (message.ack_log_error,),
                        on_error=self._restore_prefetch_count_after_connection_restart,
                    )
                    reject_log_error_promise = promise(
                        call_soon,
                        (message.reject_log_error,),
                        on_error=self._restore_prefetch_count_after_connection_restart,
                    )

                    if (
                        not self._maximum_prefetch_restored
                        and self.restart_count > 0
                        and self._new_prefetch_count <= self.max_prefetch_count
                    ):
                        ack_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
                                                   on_error=self._restore_prefetch_count_after_connection_restart)
                        reject_log_error_promise.then(self._restore_prefetch_count_after_connection_restart,
                                                      on_error=self._restore_prefetch_count_after_connection_restart)

                    strategy(
                        message, payload,
                        ack_log_error_promise,
                        reject_log_error_promise,
                        callbacks,
                    )
                except (InvalidTaskError, ContentDisallowed) as exc:
                    return on_invalid_task(payload, message, exc)
                except DecodeError as exc:
                    return self.on_decode_error(message, exc)

        return on_task_received
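    # Example (illustration only, not from the Celery source): the protocol
    # detection above keys off where the task name lives.  A v2 message carries
    # it in the headers, e.g. ``headers={'task': 'proj.add', 'id': '...'}`` with
    # the body left as a raw ``(args, kwargs, embed)`` tuple for the pool to
    # decode, while a v1 message must be decoded here because the body itself is
    # the dict ``{'task': 'proj.add', 'id': '...', 'args': ..., 'kwargs': ...}``.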

    def _restore_prefetch_count_after_connection_restart(self, p, *args):
        with self.qos._mutex:
            if self._maximum_prefetch_restored:
                return

            new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count)
            self.qos.value = self.initial_prefetch_count = new_prefetch_count
            self.qos.set(self.qos.value)

            already_restored = self._maximum_prefetch_restored
            self._maximum_prefetch_restored = new_prefetch_count == self.max_prefetch_count

            if already_restored is False and self._maximum_prefetch_restored is True:
                logger.info(
                    "Resuming normal operations following a restart.\n"
                    f"Prefetch count has been restored to the maximum of {self.max_prefetch_count}"
                )

    @property
    def max_prefetch_count(self):
        return self.pool.num_processes * self.prefetch_multiplier

    @property
    def _new_prefetch_count(self):
        return self.qos.value + self.prefetch_multiplier

    def __repr__(self):
        """``repr(self)``."""
        return '<Consumer: {self.hostname} ({state})>'.format(
            self=self, state=self.blueprint.human_state(),
        )


class Evloop(bootsteps.StartStopStep):
    """Event loop service.

    Note:
        This is always started last.
    """

    label = 'event loop'
    last = True

    def start(self, c):
        self.patch_all(c)
        c.loop(*c.loop_args())

    def patch_all(self, c):
        c.qos._mutex = DummyLock()
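# Example (illustration only, not from the Celery source): a user-defined
# consumer bootstep is attached through ``app.steps['consumer']``, the same set
# Consumer.__init__ feeds into its Blueprint above.  The app name, broker URL
# and step below are hypothetical.
from celery import Celery
from celery import bootsteps as example_bootsteps

example_app = Celery('proj', broker='memory://')


class ExampleStep(example_bootsteps.StartStopStep):
    """Log the broker URI once the consumer connection is up."""

    requires = ('celery.worker.consumer.connection:Connection',)

    def start(self, c):
        # ``c`` is the Consumer blueprint instance defined in this module.
        print('consumer connected to', c.connection.as_uri())


example_app.steps['consumer'].add(ExampleStep)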
@@ -0,0 +1,33 @@
"""Worker Remote Control Bootstep.

``Control`` -> :mod:`celery.worker.pidbox` -> :mod:`kombu.pidbox`.

The actual commands are implemented in :mod:`celery.worker.control`.
"""
from celery import bootsteps
from celery.utils.log import get_logger
from celery.worker import pidbox

from .tasks import Tasks

__all__ = ('Control',)

logger = get_logger(__name__)


class Control(bootsteps.StartStopStep):
    """Remote control command service."""

    requires = (Tasks,)

    def __init__(self, c, **kwargs):
        self.is_green = c.pool is not None and c.pool.is_green
        self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
        self.start = self.box.start
        self.stop = self.box.stop
        self.shutdown = self.box.shutdown
        super().__init__(c, **kwargs)

    def include_if(self, c):
        return (c.app.conf.worker_enable_remote_control and
                c.conninfo.supports_exchange_type('fanout'))
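# Example (illustration only, not from the Celery source): because of
# include_if() above, setting ``app.conf.worker_enable_remote_control = False``
# (or using a transport without fanout exchange support) skips this step
# entirely, so no pidbox machinery is started for the worker.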
@@ -0,0 +1,68 @@
"""Worker Event Dispatcher Bootstep.

``Events`` -> :class:`celery.events.EventDispatcher`.
"""
from kombu.common import ignore_errors

from celery import bootsteps

from .connection import Connection

__all__ = ('Events',)


class Events(bootsteps.StartStopStep):
    """Service used for sending monitoring events."""

    requires = (Connection,)

    def __init__(self, c,
                 task_events=True,
                 without_heartbeat=False,
                 without_gossip=False,
                 **kwargs):
        self.groups = None if task_events else ['worker']
        self.send_events = (
            task_events or
            not without_gossip or
            not without_heartbeat
        )
        self.enabled = self.send_events
        c.event_dispatcher = None
        super().__init__(c, **kwargs)

    def start(self, c):
        # flush events sent while connection was down.
        prev = self._close(c)
        dis = c.event_dispatcher = c.app.events.Dispatcher(
            c.connection_for_write(),
            hostname=c.hostname,
            enabled=self.send_events,
            groups=self.groups,
            # we currently only buffer events when the event loop is enabled
            # XXX This excludes eventlet/gevent, which should actually buffer.
            buffer_group=['task'] if c.hub else None,
            on_send_buffered=c.on_send_event_buffered if c.hub else None,
        )
        if prev:
            dis.extend_buffer(prev)
            dis.flush()

    def stop(self, c):
        pass

    def _close(self, c):
        if c.event_dispatcher:
            dispatcher = c.event_dispatcher
            # remember changes from remote control commands:
            self.groups = dispatcher.groups

            # close custom connection
            if dispatcher.connection:
                ignore_errors(c, dispatcher.connection.close)
            ignore_errors(c, dispatcher.close)
            c.event_dispatcher = None
            return dispatcher

    def shutdown(self, c):
        self._close(c)
@@ -0,0 +1,205 @@
"""Worker <-> Worker communication Bootstep."""
from collections import defaultdict
from functools import partial
from heapq import heappush
from operator import itemgetter

from kombu import Consumer
from kombu.asynchronous.semaphore import DummyLock
from kombu.exceptions import ContentDisallowed, DecodeError

from celery import bootsteps
from celery.utils.log import get_logger
from celery.utils.objects import Bunch

from .mingle import Mingle

__all__ = ('Gossip',)

logger = get_logger(__name__)
debug, info = logger.debug, logger.info


class Gossip(bootsteps.ConsumerStep):
    """Bootstep consuming events from other workers.

    This keeps the logical clock value up to date.
    """

    label = 'Gossip'
    requires = (Mingle,)
    _cons_stamp_fields = itemgetter(
        'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
    )
    compatible_transports = {'amqp', 'redis'}

    def __init__(self, c, without_gossip=False,
                 interval=5.0, heartbeat_interval=2.0, **kwargs):
        self.enabled = not without_gossip and self.compatible_transport(c.app)
        self.app = c.app
        c.gossip = self
        self.Receiver = c.app.events.Receiver
        self.hostname = c.hostname
        self.full_hostname = '.'.join([self.hostname, str(c.pid)])
        self.on = Bunch(
            node_join=set(),
            node_leave=set(),
            node_lost=set(),
        )

        self.timer = c.timer
        if self.enabled:
            self.state = c.app.events.State(
                on_node_join=self.on_node_join,
                on_node_leave=self.on_node_leave,
                max_tasks_in_memory=1,
            )
            if c.hub:
                c._mutex = DummyLock()
            self.update_state = self.state.event
        self.interval = interval
        self.heartbeat_interval = heartbeat_interval
        self._tref = None
        self.consensus_requests = defaultdict(list)
        self.consensus_replies = {}
        self.event_handlers = {
            'worker.elect': self.on_elect,
            'worker.elect.ack': self.on_elect_ack,
        }
        self.clock = c.app.clock

        self.election_handlers = {
            'task': self.call_task
        }

        super().__init__(c, **kwargs)

    def compatible_transport(self, app):
        with app.connection_for_read() as conn:
            return conn.transport.driver_type in self.compatible_transports

    def election(self, id, topic, action=None):
        self.consensus_replies[id] = []
        self.dispatcher.send(
            'worker-elect',
            id=id, topic=topic, action=action, cver=1,
        )

    def call_task(self, task):
        try:
            self.app.signature(task).apply_async()
        except Exception as exc:  # pylint: disable=broad-except
            logger.exception('Could not call task: %r', exc)

    def on_elect(self, event):
        try:
            (id_, clock, hostname, pid,
             topic, action, _) = self._cons_stamp_fields(event)
        except KeyError as exc:
            return logger.exception('election request missing field %s', exc)
        heappush(
            self.consensus_requests[id_],
            (clock, f'{hostname}.{pid}', topic, action),
        )
        self.dispatcher.send('worker-elect-ack', id=id_)

    def start(self, c):
        super().start(c)
        self.dispatcher = c.event_dispatcher

    def on_elect_ack(self, event):
        id = event['id']
        try:
            replies = self.consensus_replies[id]
        except KeyError:
            return  # not for us
        alive_workers = set(self.state.alive_workers())
        replies.append(event['hostname'])

        if len(replies) >= len(alive_workers):
            _, leader, topic, action = self.clock.sort_heap(
                self.consensus_requests[id],
            )
            if leader == self.full_hostname:
                info('I won the election %r', id)
                try:
                    handler = self.election_handlers[topic]
                except KeyError:
                    logger.exception('Unknown election topic %r', topic)
                else:
                    handler(action)
            else:
                info('node %s elected for %r', leader, id)
            self.consensus_requests.pop(id, None)
            self.consensus_replies.pop(id, None)

    def on_node_join(self, worker):
        debug('%s joined the party', worker.hostname)
        self._call_handlers(self.on.node_join, worker)

    def on_node_leave(self, worker):
        debug('%s left', worker.hostname)
        self._call_handlers(self.on.node_leave, worker)

    def on_node_lost(self, worker):
        info('missed heartbeat from %s', worker.hostname)
        self._call_handlers(self.on.node_lost, worker)

    def _call_handlers(self, handlers, *args, **kwargs):
        for handler in handlers:
            try:
                handler(*args, **kwargs)
            except Exception as exc:  # pylint: disable=broad-except
                logger.exception(
                    'Ignored error from handler %r: %r', handler, exc)

    def register_timer(self):
        if self._tref is not None:
            self._tref.cancel()
        self._tref = self.timer.call_repeatedly(self.interval, self.periodic)

    def periodic(self):
        workers = self.state.workers
        dirty = set()
        for worker in workers.values():
            if not worker.alive:
                dirty.add(worker)
                self.on_node_lost(worker)
        for worker in dirty:
            workers.pop(worker.hostname, None)

    def get_consumers(self, channel):
        self.register_timer()
        ev = self.Receiver(channel, routing_key='worker.#',
                           queue_ttl=self.heartbeat_interval)
        return [Consumer(
            channel,
            queues=[ev.queue],
            on_message=partial(self.on_message, ev.event_from_message),
            no_ack=True
        )]

    def on_message(self, prepare, message):
        _type = message.delivery_info['routing_key']

        # For redis when `fanout_patterns=False` (See Issue #1882)
        if _type.split('.', 1)[0] == 'task':
            return
        try:
            handler = self.event_handlers[_type]
        except KeyError:
            pass
        else:
            return handler(message.payload)

        # proto2: hostname in header; proto1: in body
        hostname = (message.headers.get('hostname') or
                    message.payload['hostname'])
        if hostname != self.hostname:
            try:
                _, event = prepare(message.payload)
                self.update_state(event)
            except (DecodeError, ContentDisallowed, TypeError) as exc:
                logger.error(exc)
        else:
            self.clock.forward()
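# Example (illustration only, not from the Celery source): Gossip, Mingle and
# Heart are all optional; the worker's ``--without-gossip``,
# ``--without-mingle`` and ``--without-heartbeat`` options feed the
# ``without_*`` keyword arguments these bootsteps receive and disable the
# corresponding step.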
@@ -0,0 +1,36 @@
"""Worker Event Heartbeat Bootstep."""
from celery import bootsteps
from celery.worker import heartbeat

from .events import Events

__all__ = ('Heart',)


class Heart(bootsteps.StartStopStep):
    """Bootstep sending event heartbeats.

    This service sends a ``worker-heartbeat`` message every n seconds.

    Note:
        Not to be confused with AMQP protocol level heartbeats.
    """

    requires = (Events,)

    def __init__(self, c,
                 without_heartbeat=False, heartbeat_interval=None, **kwargs):
        self.enabled = not without_heartbeat
        self.heartbeat_interval = heartbeat_interval
        c.heart = None
        super().__init__(c, **kwargs)

    def start(self, c):
        c.heart = heartbeat.Heart(
            c.timer, c.event_dispatcher, self.heartbeat_interval,
        )
        c.heart.start()

    def stop(self, c):
        c.heart = c.heart and c.heart.stop()
    shutdown = stop
@@ -0,0 +1,76 @@
"""Worker <-> Worker Sync at startup (Bootstep)."""
from celery import bootsteps
from celery.utils.log import get_logger

from .events import Events

__all__ = ('Mingle',)

logger = get_logger(__name__)
debug, info, exception = logger.debug, logger.info, logger.exception


class Mingle(bootsteps.StartStopStep):
    """Bootstep syncing state with neighbor workers.

    At startup, or upon consumer restart, this will:

    - Sync logical clocks.
    - Sync revoked tasks.

    """

    label = 'Mingle'
    requires = (Events,)
    compatible_transports = {'amqp', 'redis'}

    def __init__(self, c, without_mingle=False, **kwargs):
        self.enabled = not without_mingle and self.compatible_transport(c.app)
        super().__init__(
            c, without_mingle=without_mingle, **kwargs)

    def compatible_transport(self, app):
        with app.connection_for_read() as conn:
            return conn.transport.driver_type in self.compatible_transports

    def start(self, c):
        self.sync(c)

    def sync(self, c):
        info('mingle: searching for neighbors')
        replies = self.send_hello(c)
        if replies:
            info('mingle: sync with %s nodes',
                 len([reply for reply, value in replies.items() if value]))
            [self.on_node_reply(c, nodename, reply)
             for nodename, reply in replies.items() if reply]
            info('mingle: sync complete')
        else:
            info('mingle: all alone')

    def send_hello(self, c):
        inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
        our_revoked = c.controller.state.revoked
        replies = inspect.hello(c.hostname, our_revoked._data) or {}
        replies.pop(c.hostname, None)  # delete my own response
        return replies

    def on_node_reply(self, c, nodename, reply):
        debug('mingle: processing reply from %s', nodename)
        try:
            self.sync_with_node(c, **reply)
        except MemoryError:
            raise
        except Exception as exc:  # pylint: disable=broad-except
            exception('mingle: sync with %s failed: %r', nodename, exc)

    def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
        self.on_clock_event(c, clock)
        self.on_revoked_received(c, revoked)

    def on_clock_event(self, c, clock):
        c.app.clock.adjust(clock) if clock else c.app.clock.forward()

    def on_revoked_received(self, c, revoked):
        if revoked:
            c.controller.state.revoked.update(revoked)
@@ -0,0 +1,65 @@
"""Worker Task Consumer Bootstep."""
from kombu.common import QoS, ignore_errors

from celery import bootsteps
from celery.utils.log import get_logger

from .mingle import Mingle

__all__ = ('Tasks',)

logger = get_logger(__name__)
debug = logger.debug


class Tasks(bootsteps.StartStopStep):
    """Bootstep starting the task message consumer."""

    requires = (Mingle,)

    def __init__(self, c, **kwargs):
        c.task_consumer = c.qos = None
        super().__init__(c, **kwargs)

    def start(self, c):
        """Start task consumer."""
        c.update_strategies()

        # - RabbitMQ 3.3 completely redefines how basic_qos works...
        # This will detect if the new qos semantics is in effect,
        # and if so make sure the 'apply_global' flag is set on qos updates.
        qos_global = not c.connection.qos_semantics_matches_spec

        # set initial prefetch count
        c.connection.default_channel.basic_qos(
            0, c.initial_prefetch_count, qos_global,
        )

        c.task_consumer = c.app.amqp.TaskConsumer(
            c.connection, on_decode_error=c.on_decode_error,
        )

        def set_prefetch_count(prefetch_count):
            return c.task_consumer.qos(
                prefetch_count=prefetch_count,
                apply_global=qos_global,
            )
        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)

    def stop(self, c):
        """Stop task consumer."""
        if c.task_consumer:
            debug('Canceling task consumer...')
            ignore_errors(c, c.task_consumer.cancel)

    def shutdown(self, c):
        """Shutdown task consumer."""
        if c.task_consumer:
            self.stop(c)
            debug('Closing consumer channel...')
            ignore_errors(c, c.task_consumer.close)
            c.task_consumer = None

    def info(self, c):
        """Return task consumer info."""
        return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
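# Example (illustration only, not from the Celery source): kombu's QoS helper,
# as wired up in Tasks.start() above, batches prefetch changes until update()
# pushes the final value to the broker via the supplied callback.
if __name__ == '__main__':
    from kombu.common import QoS as ExampleQoS

    applied = []

    def apply_prefetch(prefetch_count):
        # stands in for c.task_consumer.qos(prefetch_count=...) above
        applied.append(prefetch_count)

    qos = ExampleQoS(apply_prefetch, 4)
    qos.increment_eventually(2)   # pending value: 6
    qos.decrement_eventually()    # pending value: 5
    qos.update()                  # only now is the callback invoked
    assert applied == [5]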