Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing
@@ -0,0 +1 @@
"""Result Backends."""
venv/lib/python3.12/site-packages/celery/backends/arangodb.py (190 lines, new file)
@@ -0,0 +1,190 @@
"""ArangoDb result store backend."""

# pylint: disable=W1202,W0703

from datetime import timedelta

from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    from pyArango import connection as py_arango_connection
    from pyArango.theExceptions import AQLQueryError
except ImportError:
    py_arango_connection = AQLQueryError = None

__all__ = ('ArangoDbBackend',)


class ArangoDbBackend(KeyValueStoreBackend):
    """ArangoDb backend.

    Sample url:
    "arangodb://username:password@host:port/database/collection"

    Settings can also be grouped in the *arangodb_backend_settings* dict
    (in ``app.conf``). It may contain the host, port, username, password,
    database name and collection name; defaults are used for anything
    missing. The default database name and collection name are both
    ``celery``.

    Raises
    ------
    celery.exceptions.ImproperlyConfigured:
        if module :pypi:`pyArango` is not available.

    """

    host = '127.0.0.1'
    port = '8529'
    database = 'celery'
    collection = 'celery'
    username = None
    password = None
    # protocol is not supported in backend url (http is taken as default)
    http_protocol = 'http'
    verify = False

    # Use str as arangodb key not bytes
    key_t = str

    def __init__(self, url=None, *args, **kwargs):
        """Parse the url or load the settings from settings object."""
        super().__init__(*args, **kwargs)

        if py_arango_connection is None:
            raise ImproperlyConfigured(
                'You need to install the pyArango library to use the '
                'ArangoDb backend.',
            )

        self.url = url

        if url is None:
            host = port = database = collection = username = password = None
        else:
            (
                _schema, host, port, username, password,
                database_collection, _query
            ) = _parse_url(url)
            if database_collection is None:
                database = collection = None
            else:
                database, collection = database_collection.split('/')

        config = self.app.conf.get('arangodb_backend_settings', None)
        if config is not None:
            if not isinstance(config, dict):
                raise ImproperlyConfigured(
                    'ArangoDb backend settings should be grouped in a dict',
                )
        else:
            config = {}

        self.host = host or config.get('host', self.host)
        self.port = int(port or config.get('port', self.port))
        self.http_protocol = config.get('http_protocol', self.http_protocol)
        self.verify = config.get('verify', self.verify)
        self.database = database or config.get('database', self.database)
        self.collection = \
            collection or config.get('collection', self.collection)
        self.username = username or config.get('username', self.username)
        self.password = password or config.get('password', self.password)
        self.arangodb_url = "{http_protocol}://{host}:{port}".format(
            http_protocol=self.http_protocol, host=self.host, port=self.port
        )
        self._connection = None

    @property
    def connection(self):
        """Connect to the arangodb server."""
        if self._connection is None:
            self._connection = py_arango_connection.Connection(
                arangoURL=self.arangodb_url, username=self.username,
                password=self.password, verify=self.verify
            )
        return self._connection

    @property
    def db(self):
        """Database Object to the given database."""
        return self.connection[self.database]

    @cached_property
    def expires_delta(self):
        return timedelta(seconds=0 if self.expires is None else self.expires)

    def get(self, key):
        if key is None:
            return None
        query = self.db.AQLQuery(
            "RETURN DOCUMENT(@@collection, @key).task",
            rawResults=True,
            bindVars={
                "@collection": self.collection,
                "key": key,
            },
        )
        return next(query) if len(query) > 0 else None

    def set(self, key, value):
        self.db.AQLQuery(
            """
            UPSERT {_key: @key}
            INSERT {_key: @key, task: @value}
            UPDATE {task: @value} IN @@collection
            """,
            bindVars={
                "@collection": self.collection,
                "key": key,
                "value": value,
            },
        )

    def mget(self, keys):
        if keys is None:
            return
        query = self.db.AQLQuery(
            "FOR k IN @keys RETURN DOCUMENT(@@collection, k).task",
            rawResults=True,
            bindVars={
                "@collection": self.collection,
                "keys": keys if isinstance(keys, list) else list(keys),
            },
        )
        while True:
            yield from query
            try:
                query.nextBatch()
            except StopIteration:
                break

    def delete(self, key):
        if key is None:
            return
        self.db.AQLQuery(
            "REMOVE {_key: @key} IN @@collection",
            bindVars={
                "@collection": self.collection,
                "key": key,
            },
        )

    def cleanup(self):
        if not self.expires:
            return
        checkpoint = (self.app.now() - self.expires_delta).isoformat()
        self.db.AQLQuery(
            """
            FOR record IN @@collection
                FILTER record.task.date_done < @checkpoint
                REMOVE record IN @@collection
            """,
            bindVars={
                "@collection": self.collection,
                "checkpoint": checkpoint,
            },
        )
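
For context, a minimal configuration sketch for the backend above. It uses only the URL format from the class docstring and the keys that __init__ reads from arangodb_backend_settings; the credentials and host are placeholders, and it assumes the arangodb:// scheme is wired up as a result backend alias in the usual way. A sketch, not part of this diff.

from celery import Celery

app = Celery('tasks')

# Option 1: everything in the backend URL (placeholder credentials).
app.conf.result_backend = 'arangodb://user:pass@127.0.0.1:8529/celery/celery'

# Option 2: grouped settings; must be a dict or __init__ raises ImproperlyConfigured.
app.conf.arangodb_backend_settings = {
    'host': '127.0.0.1',
    'port': '8529',
    'username': 'user',
    'password': 'pass',
    'database': 'celery',
    'collection': 'celery',
    'http_protocol': 'http',
    'verify': False,
}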
@@ -0,0 +1,333 @@
"""Async I/O backend support utilities."""
import socket
import threading
import time
from collections import deque
from queue import Empty
from time import sleep
from weakref import WeakKeyDictionary

from kombu.utils.compat import detect_environment

from celery import states
from celery.exceptions import TimeoutError
from celery.utils.threads import THREAD_TIMEOUT_MAX

__all__ = (
    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
    'register_drainer',
)

drainers = {}


def register_drainer(name):
    """Decorator used to register a new result drainer type."""
    def _inner(cls):
        drainers[name] = cls
        return cls
    return _inner


@register_drainer('default')
class Drainer:
    """Result draining service."""

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        pass

    def stop(self):
        pass

    def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None):
        wait = wait or self.result_consumer.drain_events
        time_start = time.monotonic()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and time.monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                yield self.wait_for(p, wait, timeout=interval)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        wait(timeout=timeout)


class greenletDrainer(Drainer):
    spawn = None
    _g = None
    # Event, sent (and recreated) after every drain_events iteration.
    _drain_complete_event = None

    def _create_drain_complete_event(self):
        """Create a new self._drain_complete_event object."""
        pass

    def _send_drain_complete_event(self):
        """Fire self._drain_complete_event to wake up .wait_for()."""
        pass

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._started = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = threading.Event()
        self._create_drain_complete_event()

    def run(self):
        self._started.set()
        while not self._stopped.is_set():
            try:
                self.result_consumer.drain_events(timeout=1)
                self._send_drain_complete_event()
                self._create_drain_complete_event()
            except socket.timeout:
                pass
        self._shutdown.set()

    def start(self):
        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        self._stopped.set()
        self._send_drain_complete_event()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)

    def wait_for(self, p, wait, timeout=None):
        self.start()
        if not p.ready:
            self._drain_complete_event.wait(timeout=timeout)


@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):

    def spawn(self, func):
        from eventlet import sleep, spawn
        g = spawn(func)
        sleep(0)
        return g

    def _create_drain_complete_event(self):
        from eventlet.event import Event
        self._drain_complete_event = Event()

    def _send_drain_complete_event(self):
        self._drain_complete_event.send()


@register_drainer('gevent')
class geventDrainer(greenletDrainer):

    def spawn(self, func):
        import gevent
        g = gevent.spawn(func)
        gevent.sleep(0)
        return g

    def _create_drain_complete_event(self):
        from gevent.event import Event
        self._drain_complete_event = Event()

    def _send_drain_complete_event(self):
        self._drain_complete_event.set()
        self._create_drain_complete_event()


class AsyncBackendMixin:
    """Mixin for backends that enables the async API."""

    def _collect_into(self, result, bucket):
        self.result_consumer.buckets[result] = bucket

    def iter_native(self, result, no_ack=True, **kwargs):
        self._ensure_not_eager()

        results = result.results
        if not results:
            raise StopIteration()

        # we tell the result consumer to put consumed results
        # into these buckets.
        bucket = deque()
        for node in results:
            if not hasattr(node, '_cache'):
                bucket.append(node)
            elif node._cache:
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                node = bucket.popleft()
                if not hasattr(node, '_cache'):
                    yield node.id, node.children
                else:
                    yield node.id, node._cache
        while bucket:
            node = bucket.popleft()
            yield node.id, node._cache

    def add_pending_result(self, result, weak=False, start_drainer=True):
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        concrete, weak_ = self._pending_results
        if task_id not in weak_ and result.id not in concrete:
            (weak_ if weak else concrete)[task_id] = result
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak, start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        for mapping in self._pending_results:
            mapping.pop(task_id, None)

    def on_result_fulfilled(self, result):
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
            **kwargs
        )

    @property
    def is_async(self):
        return True


class BaseResultConsumer:
    """Manager responsible for consuming result messages."""

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend = backend
        self.app = app
        self.accept = accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        self.drainer = drainers[detect_environment()](self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        pass

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None):
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        pass

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                pass
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                result = self._get_pending_result(task_id)
            except KeyError:
                # send to buffer in case we received this result
                # before it was added to _pending_results.
                self._pending_messages.put(task_id, meta)
            else:
                result._maybe_set_cache(meta)
                buckets = self.buckets
                try:
                    # remove bucket for this result, since it's fulfilled
                    bucket = buckets.pop(result)
                except KeyError:
                    pass
                else:
                    # send to waiter via bucket
                    bucket.append(result)
        sleep(0)
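
A small sketch of how the drainer registry above can be extended. The LoggingDrainer class and the 'logging' key are hypothetical; BaseResultConsumer picks its drainer via drainers[detect_environment()], so a custom class only takes effect if it is installed under one of the keys that function can return (for example 'default').

from celery.backends.asynchronous import Drainer, drainers, register_drainer


@register_drainer('logging')  # hypothetical key, not selected automatically
class LoggingDrainer(Drainer):
    """Drainer that reports each wait before delegating to the default logic."""

    def wait_for(self, p, wait, timeout=None):
        print(f'draining result events (timeout={timeout})')
        super().wait_for(p, wait, timeout=timeout)


# Overriding the 'default' entry is what would make prefork/thread workers use it.
drainers['default'] = LoggingDrainer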
@@ -0,0 +1,165 @@
"""The Azure Storage Block Blob backend for Celery."""
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str

from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import KeyValueStoreBackend

try:
    import azure.storage.blob as azurestorage
    from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
    from azure.storage.blob import BlobServiceClient
except ImportError:
    azurestorage = None

__all__ = ("AzureBlockBlobBackend",)

LOGGER = get_logger(__name__)
AZURE_BLOCK_BLOB_CONNECTION_PREFIX = 'azureblockblob://'


class AzureBlockBlobBackend(KeyValueStoreBackend):
    """Azure Storage Block Blob backend for Celery."""

    def __init__(self,
                 url=None,
                 container_name=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        if azurestorage is None or azurestorage.__version__ < '12':
            raise ImproperlyConfigured(
                "You need to install the azure-storage-blob v12 library to "
                "use the AzureBlockBlob backend")

        conf = self.app.conf

        self._connection_string = self._parse_url(url)

        self._container_name = (
            container_name or
            conf["azureblockblob_container_name"])

        self.base_path = conf.get('azureblockblob_base_path', '')
        self._connection_timeout = conf.get(
            'azureblockblob_connection_timeout', 20
        )
        self._read_timeout = conf.get('azureblockblob_read_timeout', 120)

    @classmethod
    def _parse_url(cls, url, prefix=AZURE_BLOCK_BLOB_CONNECTION_PREFIX):
        connection_string = url[len(prefix):]
        if not connection_string:
            raise ImproperlyConfigured("Invalid URL")

        return connection_string

    @cached_property
    def _blob_service_client(self):
        """Return the Azure Storage Blob service client.

        If this is the first call to the property, the client is created and
        the container is created if it doesn't yet exist.

        """
        client = BlobServiceClient.from_connection_string(
            self._connection_string,
            connection_timeout=self._connection_timeout,
            read_timeout=self._read_timeout
        )

        try:
            client.create_container(name=self._container_name)
            msg = f"Container created with name {self._container_name}."
        except ResourceExistsError:
            msg = f"Container with name {self._container_name} already " \
                  "exists. It will not be created."
        LOGGER.info(msg)

        return client

    def get(self, key):
        """Read the value stored at the given key.

        Args:
            key: The key for which to read the value.
        """
        key = bytes_to_str(key)
        LOGGER.debug("Getting Azure Block Blob %s/%s", self._container_name, key)

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        try:
            return blob_client.download_blob().readall().decode()
        except ResourceNotFoundError:
            return None

    def set(self, key, value):
        """Store a value for a given key.

        Args:
            key: The key at which to store the value.
            value: The value to store.

        """
        key = bytes_to_str(key)
        LOGGER.debug(f"Creating azure blob at {self._container_name}/{key}")

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        blob_client.upload_blob(value, overwrite=True)

    def mget(self, keys):
        """Read all the values for the provided keys.

        Args:
            keys: The list of keys to read.

        """
        return [self.get(key) for key in keys]

    def delete(self, key):
        """Delete the value at a given key.

        Args:
            key: The key of the value to delete.

        """
        key = bytes_to_str(key)
        LOGGER.debug(f"Deleting azure blob at {self._container_name}/{key}")

        blob_client = self._blob_service_client.get_blob_client(
            container=self._container_name,
            blob=f'{self.base_path}{key}',
        )

        blob_client.delete_blob()

    def as_uri(self, include_password=False):
        if include_password:
            return (
                f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
                f'{self._connection_string}'
            )

        connection_string_parts = self._connection_string.split(';')
        account_key_prefix = 'AccountKey='
        redacted_connection_string_parts = [
            f'{account_key_prefix}**' if part.startswith(account_key_prefix)
            else part
            for part in connection_string_parts
        ]

        return (
            f'{AZURE_BLOCK_BLOB_CONNECTION_PREFIX}'
            f'{";".join(redacted_connection_string_parts)}'
        )
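
A hedged configuration sketch for the Azure Block Blob backend above: the URL is just the azureblockblob:// prefix followed by an ordinary Azure Storage connection string (that is all _parse_url() strips), and the remaining keys are the ones __init__ reads. The account name and key are placeholders.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = (
    'azureblockblob://'
    'DefaultEndpointsProtocol=https;AccountName=myaccount;'
    'AccountKey=<placeholder>;EndpointSuffix=core.windows.net'
)
app.conf.azureblockblob_container_name = 'celery'    # container used for result blobs
app.conf.azureblockblob_base_path = 'results/'       # optional blob name prefix
app.conf.azureblockblob_connection_timeout = 20      # seconds (default shown above)
app.conf.azureblockblob_read_timeout = 120           # seconds (default shown above)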
venv/lib/python3.12/site-packages/celery/backends/base.py (1110 lines, new file)
File diff suppressed because it is too large.
venv/lib/python3.12/site-packages/celery/backends/cache.py (163 lines, new file)
@@ -0,0 +1,163 @@
"""Memcached and in-memory cache result backend."""
from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.objects import cached_property

from celery.exceptions import ImproperlyConfigured
from celery.utils.functional import LRUCache

from .base import KeyValueStoreBackend

__all__ = ('CacheBackend',)

_imp = [None]

REQUIRES_BACKEND = """\
The Memcached backend requires either pylibmc or python-memcached.\
"""

UNKNOWN_BACKEND = """\
The cache backend {0!r} is unknown.
Please use one of the following backends instead: {1}\
"""

# Global shared in-memory cache for the in-memory cache client.
# This is to share the cache between threads.
_DUMMY_CLIENT_CACHE = LRUCache(limit=5000)


def import_best_memcache():
    if _imp[0] is None:
        is_pylibmc, memcache_key_t = False, bytes_to_str
        try:
            import pylibmc as memcache
            is_pylibmc = True
        except ImportError:
            try:
                import memcache
            except ImportError:
                raise ImproperlyConfigured(REQUIRES_BACKEND)
        _imp[0] = (is_pylibmc, memcache, memcache_key_t)
    return _imp[0]


def get_best_memcache(*args, **kwargs):
    # pylint: disable=unpacking-non-sequence
    #   This is most definitely a sequence, but pylint thinks it's not.
    is_pylibmc, memcache, key_t = import_best_memcache()
    Client = _Client = memcache.Client

    if not is_pylibmc:
        def Client(*args, **kwargs):  # noqa: F811
            kwargs.pop('behaviors', None)
            return _Client(*args, **kwargs)

    return Client, key_t


class DummyClient:

    def __init__(self, *args, **kwargs):
        self.cache = _DUMMY_CLIENT_CACHE

    def get(self, key, *args, **kwargs):
        return self.cache.get(key)

    def get_multi(self, keys):
        cache = self.cache
        return {k: cache[k] for k in keys if k in cache}

    def set(self, key, value, *args, **kwargs):
        self.cache[key] = value

    def delete(self, key, *args, **kwargs):
        self.cache.pop(key, None)

    def incr(self, key, delta=1):
        return self.cache.incr(key, delta)

    def touch(self, key, expire):
        pass


backends = {
    'memcache': get_best_memcache,
    'memcached': get_best_memcache,
    'pylibmc': get_best_memcache,
    'memory': lambda: (DummyClient, ensure_bytes),
}


class CacheBackend(KeyValueStoreBackend):
    """Cache result backend."""

    servers = None
    supports_autoexpire = True
    supports_native_join = True
    implements_incr = True

    def __init__(self, app, expires=None, backend=None,
                 options=None, url=None, **kwargs):
        options = {} if not options else options
        super().__init__(app, **kwargs)
        self.url = url

        self.options = dict(self.app.conf.cache_backend_options,
                            **options)

        self.backend = url or backend or self.app.conf.cache_backend
        if self.backend:
            self.backend, _, servers = self.backend.partition('://')
            self.servers = servers.rstrip('/').split(';')
        self.expires = self.prepare_expires(expires, type=int)
        try:
            self.Client, self.key_t = backends[self.backend]()
        except KeyError:
            raise ImproperlyConfigured(UNKNOWN_BACKEND.format(
                self.backend, ', '.join(backends)))
        self._encode_prefixes()  # re-encode the key prefixes

    def get(self, key):
        return self.client.get(key)

    def mget(self, keys):
        return self.client.get_multi(keys)

    def set(self, key, value):
        return self.client.set(key, value, self.expires)

    def delete(self, key):
        return self.client.delete(key)

    def _apply_chord_incr(self, header_result_args, body, **kwargs):
        chord_key = self.get_key_for_chord(header_result_args[0])
        self.client.set(chord_key, 0, time=self.expires)
        return super()._apply_chord_incr(
            header_result_args, body, **kwargs)

    def incr(self, key):
        return self.client.incr(key)

    def expire(self, key, value):
        return self.client.touch(key, value)

    @cached_property
    def client(self):
        return self.Client(self.servers, **self.options)

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        servers = ';'.join(self.servers)
        backend = f'{self.backend}://{servers}/'
        kwargs.update(
            {'backend': backend,
             'expires': self.expires,
             'options': self.options})
        return super().__reduce__(args, kwargs)

    def as_uri(self, *args, **kwargs):
        """Return the backend as a URI.

        This properly handles the case of multiple servers.
        """
        servers = ';'.join(self.servers)
        return f'{self.backend}://{servers}/'
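
To illustrate the backend above without relying on URL aliasing, here is a direct instantiation sketch: 'memory://' selects DummyClient backed by the module-level LRUCache, while 'memcached://host:port/' would go through get_best_memcache() instead. This is a sketch only, not how Celery normally wires backends together.

from celery import Celery
from celery.backends.cache import CacheBackend

app = Celery('tasks')
backend = CacheBackend(app=app, url='memory://')

# get/set/delete go straight to the selected client (here DummyClient).
backend.set(b'task-result-key', 'value')
assert backend.get(b'task-result-key') == 'value'
backend.delete(b'task-result-key')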
venv/lib/python3.12/site-packages/celery/backends/cassandra.py (256 lines, new file)
@@ -0,0 +1,256 @@
"""Apache Cassandra result store backend using the DataStax driver."""
import threading

from celery import states
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import BaseBackend

try:  # pragma: no cover
    import cassandra
    import cassandra.auth
    import cassandra.cluster
    import cassandra.query
except ImportError:
    cassandra = None


__all__ = ('CassandraBackend',)

logger = get_logger(__name__)

E_NO_CASSANDRA = """
You need to install the cassandra-driver library to
use the Cassandra backend. See https://github.com/datastax/python-driver
"""

E_NO_SUCH_CASSANDRA_AUTH_PROVIDER = """
CASSANDRA_AUTH_PROVIDER you provided is not a valid auth_provider class.
See https://datastax.github.io/python-driver/api/cassandra/auth.html.
"""

E_CASSANDRA_MISCONFIGURED = 'Cassandra backend improperly configured.'

E_CASSANDRA_NOT_CONFIGURED = 'Cassandra backend not configured.'

Q_INSERT_RESULT = """
INSERT INTO {table} (
    task_id, status, result, date_done, traceback, children) VALUES (
    %s, %s, %s, %s, %s, %s) {expires};
"""

Q_SELECT_RESULT = """
SELECT status, result, date_done, traceback, children
FROM {table}
WHERE task_id=%s
LIMIT 1
"""

Q_CREATE_RESULT_TABLE = """
CREATE TABLE {table} (
    task_id text,
    status text,
    result blob,
    date_done timestamp,
    traceback blob,
    children blob,
    PRIMARY KEY ((task_id), date_done)
) WITH CLUSTERING ORDER BY (date_done DESC);
"""

Q_EXPIRES = """
    USING TTL {0}
"""


def buf_t(x):
    return bytes(x, 'utf8')


class CassandraBackend(BaseBackend):
    """Cassandra/AstraDB backend utilizing DataStax driver.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`cassandra-driver` is not available,
            or not-exactly-one of the :setting:`cassandra_servers` and
            the :setting:`cassandra_secure_bundle_path` settings is set.
    """

    #: List of Cassandra servers with format: ``hostname``.
    servers = None
    #: Location of the secure connect bundle zipfile (absolute path).
    bundle_path = None

    supports_autoexpire = True  # autoexpire supported via entry_ttl

    def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
                 port=9042, bundle_path=None, **kwargs):
        super().__init__(**kwargs)

        if not cassandra:
            raise ImproperlyConfigured(E_NO_CASSANDRA)

        conf = self.app.conf
        self.servers = servers or conf.get('cassandra_servers', None)
        self.bundle_path = bundle_path or conf.get(
            'cassandra_secure_bundle_path', None)
        self.port = port or conf.get('cassandra_port', None)
        self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
        self.table = table or conf.get('cassandra_table', None)
        self.cassandra_options = conf.get('cassandra_options', {})

        # either servers or bundle path must be provided...
        db_directions = self.servers or self.bundle_path
        if not db_directions or not self.keyspace or not self.table:
            raise ImproperlyConfigured(E_CASSANDRA_NOT_CONFIGURED)
        # ...but not both:
        if self.servers and self.bundle_path:
            raise ImproperlyConfigured(E_CASSANDRA_MISCONFIGURED)

        expires = entry_ttl or conf.get('cassandra_entry_ttl', None)

        self.cqlexpires = (
            Q_EXPIRES.format(expires) if expires is not None else '')

        read_cons = conf.get('cassandra_read_consistency') or 'LOCAL_QUORUM'
        write_cons = conf.get('cassandra_write_consistency') or 'LOCAL_QUORUM'

        self.read_consistency = getattr(
            cassandra.ConsistencyLevel, read_cons,
            cassandra.ConsistencyLevel.LOCAL_QUORUM)
        self.write_consistency = getattr(
            cassandra.ConsistencyLevel, write_cons,
            cassandra.ConsistencyLevel.LOCAL_QUORUM)

        self.auth_provider = None
        auth_provider = conf.get('cassandra_auth_provider', None)
        auth_kwargs = conf.get('cassandra_auth_kwargs', None)
        if auth_provider and auth_kwargs:
            auth_provider_class = getattr(cassandra.auth, auth_provider, None)
            if not auth_provider_class:
                raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER)
            self.auth_provider = auth_provider_class(**auth_kwargs)

        self._cluster = None
        self._session = None
        self._write_stmt = None
        self._read_stmt = None
        self._lock = threading.RLock()

    def _get_connection(self, write=False):
        """Prepare the connection for action.

        Arguments:
            write (bool): are we a writer?
        """
        if self._session is not None:
            return
        self._lock.acquire()
        try:
            if self._session is not None:
                return
            # using either 'servers' or 'bundle_path' here:
            if self.servers:
                self._cluster = cassandra.cluster.Cluster(
                    self.servers, port=self.port,
                    auth_provider=self.auth_provider,
                    **self.cassandra_options)
            else:
                # 'bundle_path' is guaranteed to be set
                self._cluster = cassandra.cluster.Cluster(
                    cloud={
                        'secure_connect_bundle': self.bundle_path,
                    },
                    auth_provider=self.auth_provider,
                    **self.cassandra_options)
            self._session = self._cluster.connect(self.keyspace)

            # We're forced to do concatenation below, as formatting would
            # blow up on superficial %s that'll be processed by Cassandra
            self._write_stmt = cassandra.query.SimpleStatement(
                Q_INSERT_RESULT.format(
                    table=self.table, expires=self.cqlexpires),
            )
            self._write_stmt.consistency_level = self.write_consistency

            self._read_stmt = cassandra.query.SimpleStatement(
                Q_SELECT_RESULT.format(table=self.table),
            )
            self._read_stmt.consistency_level = self.read_consistency

            if write:
                # Only possible writers "workers" are allowed to issue
                # CREATE TABLE. This is to prevent conflicting situations
                # where both task-creator and task-executor would issue it
                # at the same time.

                # Anyway; if you're doing anything critical, you should
                # have created this table in advance, in which case
                # this query will be a no-op (AlreadyExists)
                make_stmt = cassandra.query.SimpleStatement(
                    Q_CREATE_RESULT_TABLE.format(table=self.table),
                )
                make_stmt.consistency_level = self.write_consistency

                try:
                    self._session.execute(make_stmt)
                except cassandra.AlreadyExists:
                    pass

        except cassandra.OperationTimedOut:
            # a heavily loaded or gone Cassandra cluster failed to respond.
            # leave this class in a consistent state
            if self._cluster is not None:
                self._cluster.shutdown()  # also shuts down _session

            self._cluster = None
            self._session = None
            raise  # we did fail after all - reraise
        finally:
            self._lock.release()

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        """Store return value and state of an executed task."""
        self._get_connection(write=True)

        self._session.execute(self._write_stmt, (
            task_id,
            state,
            buf_t(self.encode(result)),
            self.app.now(),
            buf_t(self.encode(traceback)),
            buf_t(self.encode(self.current_task_children(request)))
        ))

    def as_uri(self, include_password=True):
        return 'cassandra://'

    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        self._get_connection()

        res = self._session.execute(self._read_stmt, (task_id, )).one()
        if not res:
            return {'status': states.PENDING, 'result': None}

        status, result, date_done, traceback, children = res

        return self.meta_from_decoded({
            'task_id': task_id,
            'status': status,
            'result': self.decode(result),
            'date_done': date_done,
            'traceback': self.decode(traceback),
            'children': self.decode(children),
        })

    def __reduce__(self, args=(), kwargs=None):
        kwargs = {} if not kwargs else kwargs
        kwargs.update(
            {'servers': self.servers,
             'keyspace': self.keyspace,
             'table': self.table})
        return super().__reduce__(args, kwargs)
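
A configuration sketch for the Cassandra backend above, using only the settings __init__ reads. Exactly one of cassandra_servers and cassandra_secure_bundle_path may be set; the host, keyspace and table names are placeholders, and the cassandra:// result_backend scheme is assumed to be registered as usual.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'cassandra://'
app.conf.cassandra_servers = ['127.0.0.1']            # hostnames only, port is separate
app.conf.cassandra_port = 9042
app.conf.cassandra_keyspace = 'celery'
app.conf.cassandra_table = 'tasks'
app.conf.cassandra_entry_ttl = 86400                  # seconds; enables the USING TTL clause
app.conf.cassandra_read_consistency = 'LOCAL_QUORUM'
app.conf.cassandra_write_consistency = 'LOCAL_QUORUM'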
venv/lib/python3.12/site-packages/celery/backends/consul.py (116 lines, new file)
@@ -0,0 +1,116 @@
"""Consul result store backend.

- :class:`ConsulBackend` implements KeyValueStoreBackend to store results
  in the key-value store of Consul.
"""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import parse_url

from celery.backends.base import KeyValueStoreBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

try:
    import consul
except ImportError:
    consul = None

logger = get_logger(__name__)

__all__ = ('ConsulBackend',)

CONSUL_MISSING = """\
You need to install the python-consul library in order to use \
the Consul result store backend."""


class ConsulBackend(KeyValueStoreBackend):
    """Consul.io K/V store backend for Celery."""

    consul = consul

    supports_autoexpire = True

    consistency = 'consistent'
    path = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.consul is None:
            raise ImproperlyConfigured(CONSUL_MISSING)
        #
        # By default, for correctness, we use a client connection per
        # operation. If set, self.one_client will be used for all operations.
        # This provides for the original behaviour to be selected, and is
        # also convenient for mocking in the unit tests.
        #
        self.one_client = None
        self._init_from_params(**parse_url(self.url))

    def _init_from_params(self, hostname, port, virtual_host, **params):
        logger.debug('Setting up Consul client to connect to %s:%d',
                     hostname, port)
        self.path = virtual_host
        self.hostname = hostname
        self.port = port
        #
        # Optionally, allow a single client connection to be used to reduce
        # the connection load on Consul by adding a "one_client=1" parameter
        # to the URL.
        #
        if params.get('one_client', None):
            self.one_client = self.client()

    def client(self):
        return self.one_client or consul.Consul(host=self.hostname,
                                                port=self.port,
                                                consistency=self.consistency)

    def _key_to_consul_key(self, key):
        key = bytes_to_str(key)
        return key if self.path is None else f'{self.path}/{key}'

    def get(self, key):
        key = self._key_to_consul_key(key)
        logger.debug('Trying to fetch key %s from Consul', key)
        try:
            _, data = self.client().kv.get(key)
            return data['Value']
        except TypeError:
            pass

    def mget(self, keys):
        for key in keys:
            yield self.get(key)

    def set(self, key, value):
        """Set a key in Consul.

        Before creating the key, a session with a TTL is created in Consul.

        The key created afterwards references that session's ID.

        If the session expires, Consul removes the key, so results
        auto-expire from the K/V store.
        """
        session_name = bytes_to_str(key)

        key = self._key_to_consul_key(key)

        logger.debug('Trying to create Consul session %s with TTL %d',
                     session_name, self.expires)
        client = self.client()
        session_id = client.session.create(name=session_name,
                                           behavior='delete',
                                           ttl=self.expires)
        logger.debug('Created Consul session %s', session_id)

        logger.debug('Writing key %s to Consul', key)
        return client.kv.put(key=key, value=value, acquire=session_id)

    def delete(self, key):
        key = self._key_to_consul_key(key)
        logger.debug('Removing key %s from Consul', key)
        return self.client().kv.delete(key)
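
A configuration sketch for the Consul backend above. The path segment becomes the key prefix (self.path), result_expires feeds the session TTL used in set(), and one_client=1 switches to a single shared client, as handled in _init_from_params(). The consul:// scheme is assumed to be registered as a result backend alias; host and port are placeholders.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'consul://localhost:8500/celery?one_client=1'
app.conf.result_expires = 3600   # seconds; used as the Consul session TTL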
venv/lib/python3.12/site-packages/celery/backends/cosmosdbsql.py (218 lines, new file)
@@ -0,0 +1,218 @@
"""The CosmosDB/SQL backend for Celery (experimental)."""
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger

from .base import KeyValueStoreBackend

try:
    import pydocumentdb
    from pydocumentdb.document_client import DocumentClient
    from pydocumentdb.documents import ConnectionPolicy, ConsistencyLevel, PartitionKind
    from pydocumentdb.errors import HTTPFailure
    from pydocumentdb.retry_options import RetryOptions
except ImportError:
    pydocumentdb = DocumentClient = ConsistencyLevel = PartitionKind = \
        HTTPFailure = ConnectionPolicy = RetryOptions = None

__all__ = ("CosmosDBSQLBackend",)


ERROR_NOT_FOUND = 404
ERROR_EXISTS = 409

LOGGER = get_logger(__name__)


class CosmosDBSQLBackend(KeyValueStoreBackend):
    """CosmosDB/SQL backend for Celery."""

    def __init__(self,
                 url=None,
                 database_name=None,
                 collection_name=None,
                 consistency_level=None,
                 max_retry_attempts=None,
                 max_retry_wait_time=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        if pydocumentdb is None:
            raise ImproperlyConfigured(
                "You need to install the pydocumentdb library to use the "
                "CosmosDB backend.")

        conf = self.app.conf

        self._endpoint, self._key = self._parse_url(url)

        self._database_name = (
            database_name or
            conf["cosmosdbsql_database_name"])

        self._collection_name = (
            collection_name or
            conf["cosmosdbsql_collection_name"])

        try:
            self._consistency_level = getattr(
                ConsistencyLevel,
                consistency_level or
                conf["cosmosdbsql_consistency_level"])
        except AttributeError:
            raise ImproperlyConfigured("Unknown CosmosDB consistency level")

        self._max_retry_attempts = (
            max_retry_attempts or
            conf["cosmosdbsql_max_retry_attempts"])

        self._max_retry_wait_time = (
            max_retry_wait_time or
            conf["cosmosdbsql_max_retry_wait_time"])

    @classmethod
    def _parse_url(cls, url):
        _, host, port, _, password, _, _ = _parse_url(url)

        if not host or not password:
            raise ImproperlyConfigured("Invalid URL")

        if not port:
            port = 443

        scheme = "https" if port == 443 else "http"
        endpoint = f"{scheme}://{host}:{port}"
        return endpoint, password

    @cached_property
    def _client(self):
        """Return the CosmosDB/SQL client.

        If this is the first call to the property, the client is created and
        the database and collection are initialized if they don't yet exist.

        """
        connection_policy = ConnectionPolicy()
        connection_policy.RetryOptions = RetryOptions(
            max_retry_attempt_count=self._max_retry_attempts,
            max_wait_time_in_seconds=self._max_retry_wait_time)

        client = DocumentClient(
            self._endpoint,
            {"masterKey": self._key},
            connection_policy=connection_policy,
            consistency_level=self._consistency_level)

        self._create_database_if_not_exists(client)
        self._create_collection_if_not_exists(client)

        return client

    def _create_database_if_not_exists(self, client):
        try:
            client.CreateDatabase({"id": self._database_name})
        except HTTPFailure as ex:
            if ex.status_code != ERROR_EXISTS:
                raise
        else:
            LOGGER.info("Created CosmosDB database %s",
                        self._database_name)

    def _create_collection_if_not_exists(self, client):
        try:
            client.CreateCollection(
                self._database_link,
                {"id": self._collection_name,
                 "partitionKey": {"paths": ["/id"],
                                  "kind": PartitionKind.Hash}})
        except HTTPFailure as ex:
            if ex.status_code != ERROR_EXISTS:
                raise
        else:
            LOGGER.info("Created CosmosDB collection %s/%s",
                        self._database_name, self._collection_name)

    @cached_property
    def _database_link(self):
        return "dbs/" + self._database_name

    @cached_property
    def _collection_link(self):
        return self._database_link + "/colls/" + self._collection_name

    def _get_document_link(self, key):
        return self._collection_link + "/docs/" + key

    @classmethod
    def _get_partition_key(cls, key):
        if not key or key.isspace():
            raise ValueError("Key cannot be none, empty or whitespace.")

        return {"partitionKey": key}

    def get(self, key):
        """Read the value stored at the given key.

        Args:
            key: The key for which to read the value.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Getting CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        try:
            document = self._client.ReadDocument(
                self._get_document_link(key),
                self._get_partition_key(key))
        except HTTPFailure as ex:
            if ex.status_code != ERROR_NOT_FOUND:
                raise
            return None
        else:
            return document.get("value")

    def set(self, key, value):
        """Store a value for a given key.

        Args:
            key: The key at which to store the value.
            value: The value to store.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Creating CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        self._client.CreateDocument(
            self._collection_link,
            {"id": key, "value": value},
            self._get_partition_key(key))

    def mget(self, keys):
        """Read all the values for the provided keys.

        Args:
            keys: The list of keys to read.

        """
        return [self.get(key) for key in keys]

    def delete(self, key):
        """Delete the value at a given key.

        Args:
            key: The key of the value to delete.

        """
        key = bytes_to_str(key)
        LOGGER.debug("Deleting CosmosDB document %s/%s/%s",
                     self._database_name, self._collection_name, key)

        self._client.DeleteDocument(
            self._get_document_link(key),
            self._get_partition_key(key))
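
A configuration sketch for the CosmosDB/SQL backend above. Per _parse_url(), the host part is taken as the account endpoint and the password part as the master key (both values below are placeholders); port 443 selects https. The remaining keys are the ones __init__ reads, and the cosmosdbsql:// scheme is assumed to be registered as a result backend alias.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = (
    'cosmosdbsql://:<master-key-placeholder>@myaccount.documents.azure.com:443'
)
app.conf.cosmosdbsql_database_name = 'celerydb'
app.conf.cosmosdbsql_collection_name = 'celerycol'
app.conf.cosmosdbsql_consistency_level = 'Session'   # must name a ConsistencyLevel attribute
app.conf.cosmosdbsql_max_retry_attempts = 9
app.conf.cosmosdbsql_max_retry_wait_time = 30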
venv/lib/python3.12/site-packages/celery/backends/couchbase.py (114 lines, new file)
@@ -0,0 +1,114 @@
"""Couchbase result store backend."""

from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    from couchbase.auth import PasswordAuthenticator
    from couchbase.cluster import Cluster
except ImportError:
    Cluster = PasswordAuthenticator = None

try:
    from couchbase_core._libcouchbase import FMT_AUTO
except ImportError:
    FMT_AUTO = None

__all__ = ('CouchbaseBackend',)


class CouchbaseBackend(KeyValueStoreBackend):
    """Couchbase backend.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`couchbase` is not available.
    """

    bucket = 'default'
    host = 'localhost'
    port = 8091
    username = None
    password = None
    quiet = False
    supports_autoexpire = True

    timeout = 2.5

    # Use str as couchbase key not bytes
    key_t = str

    def __init__(self, url=None, *args, **kwargs):
        kwargs.setdefault('expires_type', int)
        super().__init__(*args, **kwargs)
        self.url = url

        if Cluster is None:
            raise ImproperlyConfigured(
                'You need to install the couchbase library to use the '
                'Couchbase backend.',
            )

        uhost = uport = uname = upass = ubucket = None
        if url:
            _, uhost, uport, uname, upass, ubucket, _ = _parse_url(url)
            ubucket = ubucket.strip('/') if ubucket else None

        config = self.app.conf.get('couchbase_backend_settings', None)
        if config is not None:
            if not isinstance(config, dict):
                raise ImproperlyConfigured(
                    'Couchbase backend settings should be grouped in a dict',
                )
        else:
            config = {}

        self.host = uhost or config.get('host', self.host)
        self.port = int(uport or config.get('port', self.port))
        self.bucket = ubucket or config.get('bucket', self.bucket)
        self.username = uname or config.get('username', self.username)
        self.password = upass or config.get('password', self.password)

        self._connection = None

    def _get_connection(self):
        """Connect to the Couchbase server."""
        if self._connection is None:
            if self.host and self.port:
                uri = f"couchbase://{self.host}:{self.port}"
            else:
                uri = f"couchbase://{self.host}"
            if self.username and self.password:
                opt = PasswordAuthenticator(self.username, self.password)
            else:
                opt = None

            cluster = Cluster(uri, opt)

            bucket = cluster.bucket(self.bucket)

            self._connection = bucket.default_collection()
        return self._connection

    @property
    def connection(self):
        return self._get_connection()

    def get(self, key):
        return self.connection.get(key).content

    def set(self, key, value):
        # Since couchbase 4.0.0 the value is a JSONType, so the format
        # parameter is no longer needed.
        if FMT_AUTO is not None:
            self.connection.upsert(key, value, ttl=self.expires, format=FMT_AUTO)
        else:
            self.connection.upsert(key, value, ttl=self.expires)

    def mget(self, keys):
        return self.connection.get_multi(keys)

    def delete(self, key):
        self.connection.remove(key)
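
A configuration sketch for the Couchbase backend above. Values given in the URL take precedence over the couchbase_backend_settings dict, per __init__; host, bucket and credentials are placeholders, and the couchbase:// scheme is assumed to be registered as a result backend alias.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'couchbase://user:pass@localhost:8091/default'

# Equivalent grouped settings (must be a dict):
app.conf.couchbase_backend_settings = {
    'host': 'localhost',
    'port': 8091,
    'bucket': 'default',
    'username': 'user',
    'password': 'pass',
}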
venv/lib/python3.12/site-packages/celery/backends/couchdb.py (99 lines, new file)
@@ -0,0 +1,99 @@
"""CouchDB result store backend."""
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    import pycouchdb
except ImportError:
    pycouchdb = None

__all__ = ('CouchBackend',)

ERR_LIB_MISSING = """\
You need to install the pycouchdb library to use the CouchDB result backend\
"""


class CouchBackend(KeyValueStoreBackend):
    """CouchDB backend.

    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`pycouchdb` is not available.
    """

    container = 'default'
    scheme = 'http'
    host = 'localhost'
    port = 5984
    username = None
    password = None

    def __init__(self, url=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.url = url

        if pycouchdb is None:
            raise ImproperlyConfigured(ERR_LIB_MISSING)

        uscheme = uhost = uport = uname = upass = ucontainer = None
        if url:
            _, uhost, uport, uname, upass, ucontainer, _ = _parse_url(url)
            ucontainer = ucontainer.strip('/') if ucontainer else None

        self.scheme = uscheme or self.scheme
        self.host = uhost or self.host
        self.port = int(uport or self.port)
        self.container = ucontainer or self.container
        self.username = uname or self.username
        self.password = upass or self.password

        self._connection = None

    def _get_connection(self):
        """Connect to the CouchDB server."""
        if self.username and self.password:
            conn_string = f'{self.scheme}://{self.username}:{self.password}@{self.host}:{self.port}'
            server = pycouchdb.Server(conn_string, authmethod='basic')
        else:
            conn_string = f'{self.scheme}://{self.host}:{self.port}'
            server = pycouchdb.Server(conn_string)

        try:
            return server.database(self.container)
        except pycouchdb.exceptions.NotFound:
            return server.create(self.container)

    @property
    def connection(self):
        if self._connection is None:
            self._connection = self._get_connection()
        return self._connection

    def get(self, key):
        key = bytes_to_str(key)
        try:
            return self.connection.get(key)['value']
        except pycouchdb.exceptions.NotFound:
            return None

    def set(self, key, value):
        key = bytes_to_str(key)
        data = {'_id': key, 'value': value}
        try:
            self.connection.save(data)
        except pycouchdb.exceptions.Conflict:
            # document already exists, update it
            data = self.connection.get(key)
            data['value'] = value
            self.connection.save(data)

    def mget(self, keys):
        return [self.get(key) for key in keys]

    def delete(self, key):
        self.connection.delete(key)
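
A configuration sketch for the CouchDB backend above. The path segment selects the container (database), which _get_connection() creates on first use if it is missing; credentials and host are placeholders, and the couchdb:// scheme is assumed to be registered as a result backend alias.

from celery import Celery

app = Celery('tasks')
app.conf.result_backend = 'couchdb://user:pass@localhost:5984/celery'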
@@ -0,0 +1,222 @@
|
||||
"""SQLAlchemy result store backend."""
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
from vine.utils import wraps
|
||||
|
||||
from celery import states
|
||||
from celery.backends.base import BaseBackend
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
from celery.utils.time import maybe_timedelta
|
||||
|
||||
from .models import Task, TaskExtended, TaskSet
|
||||
from .session import SessionManager
|
||||
|
||||
try:
|
||||
from sqlalchemy.exc import DatabaseError, InvalidRequestError
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
except ImportError:
|
||||
raise ImproperlyConfigured(
|
||||
'The database result backend requires SQLAlchemy to be installed.'
|
||||
'See https://pypi.org/project/SQLAlchemy/')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ('DatabaseBackend',)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_cleanup(session):
|
||||
try:
|
||||
yield
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def retry(fun):
|
||||
|
||||
@wraps(fun)
|
||||
def _inner(*args, **kwargs):
|
||||
max_retries = kwargs.pop('max_retries', 3)
|
||||
|
||||
for retries in range(max_retries):
|
||||
try:
|
||||
return fun(*args, **kwargs)
|
||||
except (DatabaseError, InvalidRequestError, StaleDataError):
|
||||
logger.warning(
|
||||
'Failed operation %s. Retrying %s more times.',
|
||||
fun.__name__, max_retries - retries - 1,
|
||||
exc_info=True)
|
||||
if retries + 1 >= max_retries:
|
||||
raise
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
class DatabaseBackend(BaseBackend):
|
||||
"""The database result backend."""
|
||||
|
||||
# ResultSet.iterate should sleep this much between each pool,
|
||||
# to not bombard the database with queries.
|
||||
subpolling_interval = 0.5
|
||||
|
||||
task_cls = Task
|
||||
taskset_cls = TaskSet
|
||||
|
||||
def __init__(self, dburi=None, engine_options=None, url=None, **kwargs):
|
||||
# The `url` argument was added later and is used by
|
||||
# the app to set backend by url (celery.app.backends.by_url)
|
||||
super().__init__(expires_type=maybe_timedelta,
|
||||
url=url, **kwargs)
|
||||
conf = self.app.conf
|
||||
|
||||
if self.extended_result:
|
||||
self.task_cls = TaskExtended
|
||||
|
||||
self.url = url or dburi or conf.database_url
|
||||
self.engine_options = dict(
|
||||
engine_options or {},
|
||||
**conf.database_engine_options or {})
|
||||
self.short_lived_sessions = kwargs.get(
|
||||
'short_lived_sessions',
|
||||
conf.database_short_lived_sessions)
|
||||
|
||||
schemas = conf.database_table_schemas or {}
|
||||
tablenames = conf.database_table_names or {}
|
||||
self.task_cls.configure(
|
||||
schema=schemas.get('task'),
|
||||
name=tablenames.get('task'))
|
||||
self.taskset_cls.configure(
|
||||
schema=schemas.get('group'),
|
||||
name=tablenames.get('group'))
|
||||
|
||||
if not self.url:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing connection string! Do you have the'
|
||||
' database_url setting set to a real value?')
|
||||
|
||||
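# --- Illustrative configuration (not part of this module) -----------------
# A minimal celeryconfig sketch of the settings read by __init__ above; the
# URL and table names are placeholder assumptions:
#
#     result_backend = 'db+sqlite:///results.sqlite'
#     database_engine_options = {'echo': False}
#     database_table_names = {
#         'task': 'myapp_taskmeta',      # applied via task_cls.configure()
#         'group': 'myapp_groupmeta',    # applied via taskset_cls.configure()
#     }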
@property
|
||||
def extended_result(self):
|
||||
return self.app.conf.find_value_for_key('extended', 'result')
|
||||
|
||||
def ResultSession(self, session_manager=SessionManager()):
|
||||
return session_manager.session_factory(
|
||||
dburi=self.url,
|
||||
short_lived_sessions=self.short_lived_sessions,
|
||||
**self.engine_options)
|
||||
|
||||
@retry
|
||||
def _store_result(self, task_id, result, state, traceback=None,
|
||||
request=None, **kwargs):
|
||||
"""Store return value and state of an executed task."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
|
||||
task = task and task[0]
|
||||
if not task:
|
||||
task = self.task_cls(task_id)
|
||||
task.task_id = task_id
|
||||
session.add(task)
|
||||
session.flush()
|
||||
|
||||
self._update_result(task, result, state, traceback=traceback, request=request)
|
||||
session.commit()
|
||||
|
||||
def _update_result(self, task, result, state, traceback=None,
|
||||
request=None):
|
||||
|
||||
meta = self._get_result_meta(result=result, state=state,
|
||||
traceback=traceback, request=request,
|
||||
format_date=False, encode=True)
|
||||
|
||||
# Exclude the primary key id and task_id columns
|
||||
# as we should not set them to None
|
||||
columns = [column.name for column in self.task_cls.__table__.columns
|
||||
if column.name not in {'id', 'task_id'}]
|
||||
|
||||
# Iterate through the column names of the table
|
||||
# to set the value from meta.
|
||||
# If the value is not present in meta, set it to None
|
||||
for column in columns:
|
||||
value = meta.get(column)
|
||||
setattr(task, column, value)
|
||||
|
||||
@retry
|
||||
def _get_task_meta_for(self, task_id):
|
||||
"""Get task meta-data for a task by id."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
task = list(session.query(self.task_cls).filter(self.task_cls.task_id == task_id))
|
||||
task = task and task[0]
|
||||
if not task:
|
||||
task = self.task_cls(task_id)
|
||||
task.status = states.PENDING
|
||||
task.result = None
|
||||
data = task.to_dict()
|
||||
if data.get('args', None) is not None:
|
||||
data['args'] = self.decode(data['args'])
|
||||
if data.get('kwargs', None) is not None:
|
||||
data['kwargs'] = self.decode(data['kwargs'])
|
||||
return self.meta_from_decoded(data)
|
||||
|
||||
@retry
|
||||
def _save_group(self, group_id, result):
|
||||
"""Store the result of an executed group."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
group = self.taskset_cls(group_id, result)
|
||||
session.add(group)
|
||||
session.flush()
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
@retry
|
||||
def _restore_group(self, group_id):
|
||||
"""Get meta-data for group by id."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
group = session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.taskset_id == group_id).first()
|
||||
if group:
|
||||
return group.to_dict()
|
||||
|
||||
@retry
|
||||
def _delete_group(self, group_id):
|
||||
"""Delete meta-data for group by id."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.taskset_id == group_id).delete()
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
@retry
|
||||
def _forget(self, task_id):
|
||||
"""Forget about result."""
|
||||
session = self.ResultSession()
|
||||
with session_cleanup(session):
|
||||
session.query(self.task_cls).filter(self.task_cls.task_id == task_id).delete()
|
||||
session.commit()
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
session = self.ResultSession()
|
||||
expires = self.expires
|
||||
now = self.app.now()
|
||||
with session_cleanup(session):
|
||||
session.query(self.task_cls).filter(
|
||||
self.task_cls.date_done < (now - expires)).delete()
|
||||
session.query(self.taskset_cls).filter(
|
||||
self.taskset_cls.date_done < (now - expires)).delete()
|
||||
session.commit()
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
kwargs.update(
|
||||
{'dburi': self.url,
|
||||
'expires': self.expires,
|
||||
'engine_options': self.engine_options})
|
||||
return super().__reduce__(args, kwargs)
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Database models used by the SQLAlchemy result store backend."""
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.types import PickleType
|
||||
|
||||
from celery import states
|
||||
|
||||
from .session import ResultModelBase
|
||||
|
||||
__all__ = ('Task', 'TaskExtended', 'TaskSet')
|
||||
|
||||
|
||||
class Task(ResultModelBase):
|
||||
"""Task result/status."""
|
||||
|
||||
__tablename__ = 'celery_taskmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = sa.Column(sa.Integer, sa.Sequence('task_id_sequence'),
|
||||
primary_key=True, autoincrement=True)
|
||||
task_id = sa.Column(sa.String(155), unique=True)
|
||||
status = sa.Column(sa.String(50), default=states.PENDING)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow, nullable=True)
|
||||
traceback = sa.Column(sa.Text, nullable=True)
|
||||
|
||||
def __init__(self, task_id):
|
||||
self.task_id = task_id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'task_id': self.task_id,
|
||||
'status': self.status,
|
||||
'result': self.result,
|
||||
'traceback': self.traceback,
|
||||
'date_done': self.date_done,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return '<Task {0.task_id} state: {0.status}>'.format(self)
|
||||
|
||||
@classmethod
|
||||
def configure(cls, schema=None, name=None):
|
||||
cls.__table__.schema = schema
|
||||
cls.id.default.schema = schema
|
||||
cls.__table__.name = name or cls.__tablename__
|
||||
|
||||
|
||||
class TaskExtended(Task):
|
||||
"""For the extend result."""
|
||||
|
||||
__tablename__ = 'celery_taskmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True, 'extend_existing': True}
|
||||
|
||||
name = sa.Column(sa.String(155), nullable=True)
|
||||
args = sa.Column(sa.LargeBinary, nullable=True)
|
||||
kwargs = sa.Column(sa.LargeBinary, nullable=True)
|
||||
worker = sa.Column(sa.String(155), nullable=True)
|
||||
retries = sa.Column(sa.Integer, nullable=True)
|
||||
queue = sa.Column(sa.String(155), nullable=True)
|
||||
|
||||
def to_dict(self):
|
||||
task_dict = super().to_dict()
|
||||
task_dict.update({
|
||||
'name': self.name,
|
||||
'args': self.args,
|
||||
'kwargs': self.kwargs,
|
||||
'worker': self.worker,
|
||||
'retries': self.retries,
|
||||
'queue': self.queue,
|
||||
})
|
||||
return task_dict
|
||||
|
||||
|
||||
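# --- Illustrative usage note (not part of this module) --------------------
# TaskExtended is selected by the database backend when extended results are
# enabled; a minimal celeryconfig sketch, with a placeholder SQLite URL:
#
#     result_backend = 'db+sqlite:///results.sqlite'
#     result_extended = True  # store name, args, kwargs, worker, retries, queue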
class TaskSet(ResultModelBase):
|
||||
"""TaskSet result."""
|
||||
|
||||
__tablename__ = 'celery_tasksetmeta'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = sa.Column(sa.Integer, sa.Sequence('taskset_id_sequence'),
|
||||
autoincrement=True, primary_key=True)
|
||||
taskset_id = sa.Column(sa.String(155), unique=True)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
nullable=True)
|
||||
|
||||
def __init__(self, taskset_id, result):
|
||||
self.taskset_id = taskset_id
|
||||
self.result = result
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'taskset_id': self.taskset_id,
|
||||
'result': self.result,
|
||||
'date_done': self.date_done,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f'<TaskSet: {self.taskset_id}>'
|
||||
|
||||
@classmethod
|
||||
def configure(cls, schema=None, name=None):
|
||||
cls.__table__.schema = schema
|
||||
cls.id.default.schema = schema
|
||||
cls.__table__.name = name or cls.__tablename__
|
||||
@@ -0,0 +1,89 @@
|
||||
"""SQLAlchemy session."""
|
||||
import time
|
||||
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.exc import DatabaseError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from celery.utils.time import get_exponential_backoff_interval
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
# TODO: Remove this once we drop support for SQLAlchemy < 1.4.
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
ResultModelBase = declarative_base()
|
||||
|
||||
__all__ = ('SessionManager',)
|
||||
|
||||
PREPARE_MODELS_MAX_RETRIES = 10
|
||||
|
||||
|
||||
def _after_fork_cleanup_session(session):
|
||||
session._after_fork()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manage SQLAlchemy sessions."""
|
||||
|
||||
def __init__(self):
|
||||
self._engines = {}
|
||||
self._sessions = {}
|
||||
self.forked = False
|
||||
self.prepared = False
|
||||
if register_after_fork is not None:
|
||||
register_after_fork(self, _after_fork_cleanup_session)
|
||||
|
||||
def _after_fork(self):
|
||||
self.forked = True
|
||||
|
||||
def get_engine(self, dburi, **kwargs):
|
||||
if self.forked:
|
||||
try:
|
||||
return self._engines[dburi]
|
||||
except KeyError:
|
||||
engine = self._engines[dburi] = create_engine(dburi, **kwargs)
|
||||
return engine
|
||||
else:
|
||||
kwargs = {k: v for k, v in kwargs.items() if
|
||||
not k.startswith('pool')}
|
||||
return create_engine(dburi, poolclass=NullPool, **kwargs)
|
||||
|
||||
def create_session(self, dburi, short_lived_sessions=False, **kwargs):
|
||||
engine = self.get_engine(dburi, **kwargs)
|
||||
if self.forked:
|
||||
if short_lived_sessions or dburi not in self._sessions:
|
||||
self._sessions[dburi] = sessionmaker(bind=engine)
|
||||
return engine, self._sessions[dburi]
|
||||
return engine, sessionmaker(bind=engine)
|
||||
|
||||
def prepare_models(self, engine):
|
||||
if not self.prepared:
|
||||
# SQLAlchemy will check if the items exist before trying to
|
||||
# create them, which is a race condition. If it raises an error
|
||||
# in one iteration, the next may pass all the existence checks
|
||||
# and the call will succeed.
|
||||
retries = 0
|
||||
while True:
|
||||
try:
|
||||
ResultModelBase.metadata.create_all(engine)
|
||||
except DatabaseError:
|
||||
if retries < PREPARE_MODELS_MAX_RETRIES:
|
||||
sleep_amount_ms = get_exponential_backoff_interval(
|
||||
10, retries, 1000, True
|
||||
)
|
||||
time.sleep(sleep_amount_ms / 1000)
|
||||
retries += 1
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
break
|
||||
self.prepared = True
|
||||
|
||||
def session_factory(self, dburi, **kwargs):
|
||||
engine, session = self.create_session(dburi, **kwargs)
|
||||
self.prepare_models(engine)
|
||||
return session()
|
||||
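# --- Illustrative usage sketch (not part of this module) ------------------
# How the database backend above drives this class, shown standalone; the
# SQLite URL is a placeholder assumption:
#
#     manager = SessionManager()
#     session = manager.session_factory('sqlite:///results.sqlite')
#     try:
#         ...  # query the Task / TaskSet models
#     finally:
#         session.close()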
493
venv/lib/python3.12/site-packages/celery/backends/dynamodb.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""AWS DynamoDB result store backend."""
|
||||
from collections import namedtuple
|
||||
from time import sleep, time
|
||||
|
||||
from kombu.utils.url import _parse_url as parse_url
|
||||
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
boto3 = ClientError = None
|
||||
|
||||
__all__ = ('DynamoDBBackend',)
|
||||
|
||||
|
||||
# Helper class that describes a DynamoDB attribute
|
||||
DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type'))
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DynamoDBBackend(KeyValueStoreBackend):
|
||||
"""AWS DynamoDB result backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`boto3` is not available.
|
||||
"""
|
||||
|
||||
#: default DynamoDB table name (`default`)
|
||||
table_name = 'celery'
|
||||
|
||||
#: Read Provisioned Throughput (`default`)
|
||||
read_capacity_units = 1
|
||||
|
||||
#: Write Provisioned Throughput (`default`)
|
||||
write_capacity_units = 1
|
||||
|
||||
#: AWS region (`default`)
|
||||
aws_region = None
|
||||
|
||||
#: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`)
|
||||
endpoint_url = None
|
||||
|
||||
#: Item time-to-live in seconds (`default`)
|
||||
time_to_live_seconds = None
|
||||
|
||||
# DynamoDB supports Time to Live as an auto-expiry mechanism.
|
||||
supports_autoexpire = True
|
||||
|
||||
_key_field = DynamoDBAttribute(name='id', data_type='S')
|
||||
_value_field = DynamoDBAttribute(name='result', data_type='B')
|
||||
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
|
||||
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
|
||||
_available_fields = None
|
||||
|
||||
def __init__(self, url=None, table_name=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.url = url
|
||||
self.table_name = table_name or self.table_name
|
||||
|
||||
if not boto3:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to install the boto3 library to use the '
|
||||
'DynamoDB backend.')
|
||||
|
||||
aws_credentials_given = False
|
||||
aws_access_key_id = None
|
||||
aws_secret_access_key = None
|
||||
|
||||
if url is not None:
|
||||
scheme, region, port, username, password, table, query = \
|
||||
parse_url(url)
|
||||
|
||||
aws_access_key_id = username
|
||||
aws_secret_access_key = password
|
||||
|
||||
access_key_given = aws_access_key_id is not None
|
||||
secret_key_given = aws_secret_access_key is not None
|
||||
|
||||
if access_key_given != secret_key_given:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to specify both the Access Key ID '
|
||||
'and Secret.')
|
||||
|
||||
aws_credentials_given = access_key_given
|
||||
|
||||
if region == 'localhost':
|
||||
# We are using the downloadable, local version of DynamoDB
|
||||
self.endpoint_url = f'http://localhost:{port}'
|
||||
self.aws_region = 'us-east-1'
|
||||
logger.warning(
|
||||
'Using local-only DynamoDB endpoint URL: {}'.format(
|
||||
self.endpoint_url
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.aws_region = region
|
||||
|
||||
# If endpoint_url is explicitly set use it instead
|
||||
_get = self.app.conf.get
|
||||
config_endpoint_url = _get('dynamodb_endpoint_url')
|
||||
if config_endpoint_url:
|
||||
self.endpoint_url = config_endpoint_url
|
||||
|
||||
self.read_capacity_units = int(
|
||||
query.get(
|
||||
'read',
|
||||
self.read_capacity_units
|
||||
)
|
||||
)
|
||||
self.write_capacity_units = int(
|
||||
query.get(
|
||||
'write',
|
||||
self.write_capacity_units
|
||||
)
|
||||
)
|
||||
|
||||
ttl = query.get('ttl_seconds', self.time_to_live_seconds)
|
||||
if ttl:
|
||||
try:
|
||||
self.time_to_live_seconds = int(ttl)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f'TTL must be a number; got "{ttl}"',
|
||||
exc_info=e
|
||||
)
|
||||
raise e
|
||||
|
||||
self.table_name = table or self.table_name
|
||||
|
||||
self._available_fields = (
|
||||
self._key_field,
|
||||
self._value_field,
|
||||
self._timestamp_field
|
||||
)
|
||||
|
||||
self._client = None
|
||||
if aws_credentials_given:
|
||||
self._get_client(
|
||||
access_key_id=aws_access_key_id,
|
||||
secret_access_key=aws_secret_access_key
|
||||
)
|
||||
|
||||
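# --- Illustrative configuration (not part of this module) -----------------
# URL forms accepted by __init__ above; the keys, region, table name and the
# throughput/TTL values are placeholder assumptions:
#
#     # Real AWS endpoint:
#     result_backend = (
#         'dynamodb://aws_access_key_id:aws_secret_access_key@us-east-1/'
#         'celery_results?read=5&write=5&ttl_seconds=86400'
#     )
#     # Local DynamoDB (region "localhost" switches to a local endpoint):
#     result_backend = 'dynamodb://@localhost:8000/celery'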
def _get_client(self, access_key_id=None, secret_access_key=None):
|
||||
"""Get client connection."""
|
||||
if self._client is None:
|
||||
client_parameters = {
|
||||
'region_name': self.aws_region
|
||||
}
|
||||
if access_key_id is not None:
|
||||
client_parameters.update({
|
||||
'aws_access_key_id': access_key_id,
|
||||
'aws_secret_access_key': secret_access_key
|
||||
})
|
||||
|
||||
if self.endpoint_url is not None:
|
||||
client_parameters['endpoint_url'] = self.endpoint_url
|
||||
|
||||
self._client = boto3.client(
|
||||
'dynamodb',
|
||||
**client_parameters
|
||||
)
|
||||
self._get_or_create_table()
|
||||
|
||||
if self._has_ttl() is not None:
|
||||
self._validate_ttl_methods()
|
||||
self._set_table_ttl()
|
||||
|
||||
return self._client
|
||||
|
||||
def _get_table_schema(self):
|
||||
"""Get the boto3 structure describing the DynamoDB table schema."""
|
||||
return {
|
||||
'AttributeDefinitions': [
|
||||
{
|
||||
'AttributeName': self._key_field.name,
|
||||
'AttributeType': self._key_field.data_type
|
||||
}
|
||||
],
|
||||
'TableName': self.table_name,
|
||||
'KeySchema': [
|
||||
{
|
||||
'AttributeName': self._key_field.name,
|
||||
'KeyType': 'HASH'
|
||||
}
|
||||
],
|
||||
'ProvisionedThroughput': {
|
||||
'ReadCapacityUnits': self.read_capacity_units,
|
||||
'WriteCapacityUnits': self.write_capacity_units
|
||||
}
|
||||
}
|
||||
|
||||
def _get_or_create_table(self):
|
||||
"""Create table if not exists, otherwise return the description."""
|
||||
table_schema = self._get_table_schema()
|
||||
try:
|
||||
return self._client.describe_table(TableName=self.table_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
|
||||
if error_code == 'ResourceNotFoundException':
|
||||
table_description = self._client.create_table(**table_schema)
|
||||
logger.info(
|
||||
'DynamoDB Table {} did not exist, creating.'.format(
|
||||
self.table_name
|
||||
)
|
||||
)
|
||||
# In case we created the table, wait until it becomes available.
|
||||
self._wait_for_table_status('ACTIVE')
|
||||
logger.info(
|
||||
'DynamoDB Table {} is now available.'.format(
|
||||
self.table_name
|
||||
)
|
||||
)
|
||||
return table_description
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _has_ttl(self):
|
||||
"""Return the desired Time to Live config.
|
||||
|
||||
- True: Enable TTL on the table; use expiry.
|
||||
- False: Disable TTL on the table; don't use expiry.
|
||||
- None: Ignore TTL on the table; don't use expiry.
|
||||
"""
|
||||
return None if self.time_to_live_seconds is None \
|
||||
else self.time_to_live_seconds >= 0
|
||||
|
||||
def _validate_ttl_methods(self):
|
||||
"""Verify boto support for the DynamoDB Time to Live methods."""
|
||||
# Required TTL methods.
|
||||
required_methods = (
|
||||
'update_time_to_live',
|
||||
'describe_time_to_live',
|
||||
)
|
||||
|
||||
# Find missing methods.
|
||||
missing_methods = []
|
||||
for method in list(required_methods):
|
||||
if not hasattr(self._client, method):
|
||||
missing_methods.append(method)
|
||||
|
||||
if missing_methods:
|
||||
logger.error(
|
||||
(
|
||||
'boto3 method(s) {methods} not found; ensure that '
|
||||
'boto3>=1.9.178 and botocore>=1.12.178 are installed'
|
||||
).format(
|
||||
methods=','.join(missing_methods)
|
||||
)
|
||||
)
|
||||
raise AttributeError(
|
||||
'boto3 method(s) {methods} not found'.format(
|
||||
methods=','.join(missing_methods)
|
||||
)
|
||||
)
|
||||
|
||||
def _get_ttl_specification(self, ttl_attr_name):
|
||||
"""Get the boto3 structure describing the DynamoDB TTL specification."""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'TimeToLiveSpecification': {
|
||||
'Enabled': self._has_ttl(),
|
||||
'AttributeName': ttl_attr_name
|
||||
}
|
||||
}
|
||||
|
||||
def _get_table_ttl_description(self):
|
||||
# Get the current TTL description.
|
||||
try:
|
||||
description = self._client.describe_time_to_live(
|
||||
TableName=self.table_name
|
||||
)
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', 'Unknown')
|
||||
logger.error((
|
||||
'Error describing Time to Live on DynamoDB table {table}: '
|
||||
'{code}: {message}'
|
||||
).format(
|
||||
table=self.table_name,
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
))
|
||||
raise e
|
||||
|
||||
return description
|
||||
|
||||
def _set_table_ttl(self):
|
||||
"""Enable or disable Time to Live on the table."""
|
||||
# Get the table TTL description, and return early when possible.
|
||||
description = self._get_table_ttl_description()
|
||||
status = description['TimeToLiveDescription']['TimeToLiveStatus']
|
||||
if status in ('ENABLED', 'ENABLING'):
|
||||
cur_attr_name = \
|
||||
description['TimeToLiveDescription']['AttributeName']
|
||||
if self._has_ttl():
|
||||
if cur_attr_name == self._ttl_field.name:
|
||||
# We want TTL enabled, and it is currently enabled or being
|
||||
# enabled, and on the correct attribute.
|
||||
logger.debug((
|
||||
'DynamoDB Time to Live is {situation} '
|
||||
'on table {table}'
|
||||
).format(
|
||||
situation='already enabled'
|
||||
if status == 'ENABLED'
|
||||
else 'currently being enabled',
|
||||
table=self.table_name
|
||||
))
|
||||
return description
|
||||
|
||||
elif status in ('DISABLED', 'DISABLING'):
|
||||
if not self._has_ttl():
|
||||
# We want TTL disabled, and it is currently disabled or being
|
||||
# disabled.
|
||||
logger.debug((
|
||||
'DynamoDB Time to Live is {situation} '
|
||||
'on table {table}'
|
||||
).format(
|
||||
situation='already disabled'
|
||||
if status == 'DISABLED'
|
||||
else 'currently being disabled',
|
||||
table=self.table_name
|
||||
))
|
||||
return description
|
||||
|
||||
# The state shouldn't ever have any value beyond the four handled
|
||||
# above, but to ease troubleshooting of potential future changes, emit
|
||||
# a log showing the unknown state.
|
||||
else: # pragma: no cover
|
||||
logger.warning((
|
||||
'Unknown DynamoDB Time to Live status {status} '
|
||||
'on table {table}. Attempting to continue.'
|
||||
).format(
|
||||
status=status,
|
||||
table=self.table_name
|
||||
))
|
||||
|
||||
# At this point, we have one of the following situations:
|
||||
#
|
||||
# We want TTL enabled,
|
||||
#
|
||||
# - and it's currently disabled: Try to enable.
|
||||
#
|
||||
# - and it's being disabled: Try to enable, but this is almost sure to
|
||||
# raise ValidationException with message:
|
||||
#
|
||||
# Time to live has been modified multiple times within a fixed
|
||||
# interval
|
||||
#
|
||||
# - and it's currently enabling or being enabled, but on the wrong
|
||||
# attribute: Try to enable, but this will raise ValidationException
|
||||
# with message:
|
||||
#
|
||||
# TimeToLive is active on a different AttributeName: current
|
||||
# AttributeName is ttlx
|
||||
#
|
||||
# We want TTL disabled,
|
||||
#
|
||||
# - and it's currently enabled: Try to disable.
|
||||
#
|
||||
# - and it's being enabled: Try to disable, but this is almost sure to
|
||||
# raise ValidationException with message:
|
||||
#
|
||||
# Time to live has been modified multiple times within a fixed
|
||||
# interval
|
||||
#
|
||||
attr_name = \
|
||||
cur_attr_name if status == 'ENABLED' else self._ttl_field.name
|
||||
try:
|
||||
specification = self._client.update_time_to_live(
|
||||
**self._get_ttl_specification(
|
||||
ttl_attr_name=attr_name
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
'DynamoDB table Time to Live updated: '
|
||||
'table={table} enabled={enabled} attribute={attr}'
|
||||
).format(
|
||||
table=self.table_name,
|
||||
enabled=self._has_ttl(),
|
||||
attr=self._ttl_field.name
|
||||
)
|
||||
)
|
||||
return specification
|
||||
except ClientError as e:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', 'Unknown')
|
||||
logger.error((
|
||||
'Error {action} Time to Live on DynamoDB table {table}: '
|
||||
'{code}: {message}'
|
||||
).format(
|
||||
action='enabling' if self._has_ttl() else 'disabling',
|
||||
table=self.table_name,
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
))
|
||||
raise e
|
||||
|
||||
def _wait_for_table_status(self, expected='ACTIVE'):
|
||||
"""Poll for the expected table status."""
|
||||
achieved_state = False
|
||||
while not achieved_state:
|
||||
table_description = self.client.describe_table(
|
||||
TableName=self.table_name
|
||||
)
|
||||
logger.debug(
|
||||
'Waiting for DynamoDB table {} to become {}.'.format(
|
||||
self.table_name,
|
||||
expected
|
||||
)
|
||||
)
|
||||
current_status = table_description['Table']['TableStatus']
|
||||
achieved_state = current_status == expected
|
||||
sleep(1)
|
||||
|
||||
def _prepare_get_request(self, key):
|
||||
"""Construct the item retrieval request parameters."""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Key': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_put_request(self, key, value):
|
||||
"""Construct the item creation request parameters."""
|
||||
timestamp = time()
|
||||
put_request = {
|
||||
'TableName': self.table_name,
|
||||
'Item': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
},
|
||||
self._value_field.name: {
|
||||
self._value_field.data_type: value
|
||||
},
|
||||
self._timestamp_field.name: {
|
||||
self._timestamp_field.data_type: str(timestamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
if self._has_ttl():
|
||||
put_request['Item'].update({
|
||||
self._ttl_field.name: {
|
||||
self._ttl_field.data_type:
|
||||
str(int(timestamp + self.time_to_live_seconds))
|
||||
}
|
||||
})
|
||||
return put_request
|
||||
|
||||
def _item_to_dict(self, raw_response):
|
||||
"""Convert get_item() response to field-value pairs."""
|
||||
if 'Item' not in raw_response:
|
||||
return {}
|
||||
return {
|
||||
field.name: raw_response['Item'][field.name][field.data_type]
|
||||
for field in self._available_fields
|
||||
}
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self._get_client()
|
||||
|
||||
def get(self, key):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
item_response = self.client.get_item(**request_parameters)
|
||||
item = self._item_to_dict(item_response)
|
||||
return item.get(self._value_field.name)
|
||||
|
||||
def set(self, key, value):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_put_request(key, value)
|
||||
self.client.put_item(**request_parameters)
|
||||
|
||||
def mget(self, keys):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
self.client.delete_item(**request_parameters)
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Elasticsearch result store backend."""
|
||||
from datetime import datetime
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.url import _parse_url
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
elasticsearch = None
|
||||
|
||||
__all__ = ('ElasticsearchBackend',)
|
||||
|
||||
E_LIB_MISSING = """\
|
||||
You need to install the elasticsearch library to use the Elasticsearch \
|
||||
result backend.\
|
||||
"""
|
||||
|
||||
|
||||
class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
"""Elasticsearch Backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`elasticsearch` is not available.
|
||||
"""
|
||||
|
||||
index = 'celery'
|
||||
doc_type = 'backend'
|
||||
scheme = 'http'
|
||||
host = 'localhost'
|
||||
port = 9200
|
||||
username = None
|
||||
password = None
|
||||
es_retry_on_timeout = False
|
||||
es_timeout = 10
|
||||
es_max_retries = 3
|
||||
|
||||
def __init__(self, url=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.url = url
|
||||
_get = self.app.conf.get
|
||||
|
||||
if elasticsearch is None:
|
||||
raise ImproperlyConfigured(E_LIB_MISSING)
|
||||
|
||||
index = doc_type = scheme = host = port = username = password = None
|
||||
|
||||
if url:
|
||||
scheme, host, port, username, password, path, _ = _parse_url(url)
|
||||
if scheme == 'elasticsearch':
|
||||
scheme = None
|
||||
if path:
|
||||
path = path.strip('/')
|
||||
index, _, doc_type = path.partition('/')
|
||||
|
||||
self.index = index or self.index
|
||||
self.doc_type = doc_type or self.doc_type
|
||||
self.scheme = scheme or self.scheme
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
self.username = username or self.username
|
||||
self.password = password or self.password
|
||||
|
||||
self.es_retry_on_timeout = (
|
||||
_get('elasticsearch_retry_on_timeout') or self.es_retry_on_timeout
|
||||
)
|
||||
|
||||
es_timeout = _get('elasticsearch_timeout')
|
||||
if es_timeout is not None:
|
||||
self.es_timeout = es_timeout
|
||||
|
||||
es_max_retries = _get('elasticsearch_max_retries')
|
||||
if es_max_retries is not None:
|
||||
self.es_max_retries = es_max_retries
|
||||
|
||||
self.es_save_meta_as_text = _get('elasticsearch_save_meta_as_text', True)
|
||||
self._server = None
|
||||
|
||||
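# --- Illustrative configuration (not part of this module) -----------------
# A celeryconfig sketch of the settings read by __init__ above; the host,
# index and document type are placeholder assumptions:
#
#     result_backend = 'elasticsearch://localhost:9200/celery/backend'
#     elasticsearch_retry_on_timeout = True
#     elasticsearch_timeout = 10
#     elasticsearch_max_retries = 3
#     elasticsearch_save_meta_as_text = False  # store result/traceback as objects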
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
|
||||
# 401: Unauthorized
|
||||
# 409: Conflict
|
||||
# 429: Too Many Requests
|
||||
# 500: Internal Server Error
|
||||
# 502: Bad Gateway
|
||||
# 503: Service Unavailable
|
||||
# 504: Gateway Timeout
|
||||
# N/A: Low level exception (i.e. socket exception)
|
||||
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, key):
|
||||
try:
|
||||
res = self._get(key)
|
||||
try:
|
||||
if res['found']:
|
||||
return res['_source']['result']
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
pass
|
||||
|
||||
def _get(self, key):
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
id=key,
|
||||
)
|
||||
|
||||
def _set_with_state(self, key, value, state):
|
||||
body = {
|
||||
'result': value,
|
||||
'@timestamp': '{}Z'.format(
|
||||
datetime.utcnow().isoformat()[:-3]
|
||||
),
|
||||
}
|
||||
try:
|
||||
self._index(
|
||||
id=key,
|
||||
body=body,
|
||||
)
|
||||
except elasticsearch.exceptions.ConflictError:
|
||||
# document already exists, update it
|
||||
self._update(key, body, state)
|
||||
|
||||
def set(self, key, value):
|
||||
return self._set_with_state(key, value, None)
|
||||
|
||||
def _index(self, id, body, **kwargs):
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _update(self, id, body, state, **kwargs):
|
||||
"""Update state in a conflict free manner.
|
||||
|
||||
If state is defined (not None), this will not update ES server if either:
|
||||
* existing state is success
|
||||
* existing state is a ready state and current state is not a ready state
|
||||
|
||||
This way, a Retry state cannot override a Success or Failure, and chord_unlock
|
||||
will not retry indefinitely.
|
||||
"""
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
|
||||
try:
|
||||
res_get = self._get(key=id)
|
||||
if not res_get.get('found'):
|
||||
return self._index(id, body, **kwargs)
|
||||
# document disappeared between index and get calls.
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
return self._index(id, body, **kwargs)
|
||||
|
||||
try:
|
||||
meta_present_on_backend = self.decode_result(res_get['_source']['result'])
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
else:
|
||||
if meta_present_on_backend['status'] == states.SUCCESS:
|
||||
# if stored state is already in success, do nothing
|
||||
return {'result': 'noop'}
|
||||
elif meta_present_on_backend['status'] in states.READY_STATES and state in states.UNREADY_STATES:
|
||||
# if stored state is a ready state and the current one is not, do nothing
|
||||
return {'result': 'noop'}
|
||||
|
||||
# get current sequence number and primary term
|
||||
# https://www.elastic.co/guide/en/elasticsearch/reference/current/optimistic-concurrency-control.html
|
||||
seq_no = res_get.get('_seq_no', 1)
|
||||
prim_term = res_get.get('_primary_term', 1)
|
||||
|
||||
# try to update document with current seq_no and primary_term
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
# result is the Elasticsearch update query result
|
||||
# noop = query did not update any document
|
||||
# updated = at least one document got updated
|
||||
if res['result'] == 'noop':
|
||||
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
|
||||
return res
|
||||
|
||||
def encode(self, data):
|
||||
if self.es_save_meta_as_text:
|
||||
return super().encode(data)
|
||||
else:
|
||||
if not isinstance(data, dict):
|
||||
return super().encode(data)
|
||||
if data.get("result"):
|
||||
data["result"] = self._encode(data["result"])[2]
|
||||
if data.get("traceback"):
|
||||
data["traceback"] = self._encode(data["traceback"])[2]
|
||||
return data
|
||||
|
||||
def decode(self, payload):
|
||||
if self.es_save_meta_as_text:
|
||||
return super().decode(payload)
|
||||
else:
|
||||
if not isinstance(payload, dict):
|
||||
return super().decode(payload)
|
||||
if payload.get("result"):
|
||||
payload["result"] = super().decode(payload["result"])
|
||||
if payload.get("traceback"):
|
||||
payload["traceback"] = super().decode(payload["traceback"])
|
||||
return payload
|
||||
|
||||
def mget(self, keys):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
|
||||
|
||||
def _get_server(self):
|
||||
"""Connect to the Elasticsearch server."""
|
||||
http_auth = None
|
||||
if self.username and self.password:
|
||||
http_auth = (self.username, self.password)
|
||||
return elasticsearch.Elasticsearch(
|
||||
f'{self.host}:{self.port}',
|
||||
retry_on_timeout=self.es_retry_on_timeout,
|
||||
max_retries=self.es_max_retries,
|
||||
timeout=self.es_timeout,
|
||||
scheme=self.scheme,
|
||||
http_auth=http_auth,
|
||||
)
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
if self._server is None:
|
||||
self._server = self._get_server()
|
||||
return self._server
|
||||
112
venv/lib/python3.12/site-packages/celery/backends/filesystem.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""File-system result store backend."""
|
||||
import locale
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from kombu.utils.encoding import ensure_bytes
|
||||
|
||||
from celery import uuid
|
||||
from celery.backends.base import KeyValueStoreBackend
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
default_encoding = locale.getpreferredencoding(False)
|
||||
|
||||
E_NO_PATH_SET = 'You need to configure a path for the file-system backend'
|
||||
E_PATH_NON_CONFORMING_SCHEME = (
|
||||
'A path for the file-system backend should conform to the file URI scheme'
|
||||
)
|
||||
E_PATH_INVALID = """\
|
||||
The configured path for the file-system backend does not
|
||||
work correctly, please make sure that it exists and has
|
||||
the correct permissions.\
|
||||
"""
|
||||
|
||||
|
||||
class FilesystemBackend(KeyValueStoreBackend):
|
||||
"""File-system result backend.
|
||||
|
||||
Arguments:
|
||||
url (str): URL to the directory we should use
|
||||
open (Callable): open function to use when opening files
|
||||
unlink (Callable): unlink function to use when deleting files
|
||||
sep (str): directory separator (to join the directory with the key)
|
||||
encoding (str): encoding used on the file-system
|
||||
"""
|
||||
|
||||
def __init__(self, url=None, open=open, unlink=os.unlink, sep=os.sep,
|
||||
encoding=default_encoding, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.url = url
|
||||
path = self._find_path(url)
|
||||
|
||||
# Remove the leading "/" on Windows
|
||||
if os.name == "nt" and path.startswith("/"):
|
||||
path = path[1:]
|
||||
|
||||
# We need the path and separator as bytes objects
|
||||
self.path = path.encode(encoding)
|
||||
self.sep = sep.encode(encoding)
|
||||
|
||||
self.open = open
|
||||
self.unlink = unlink
|
||||
|
||||
# Let's verify that everything is set up right
|
||||
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(args, {**kwargs, 'url': self.url})
|
||||
|
||||
def _find_path(self, url):
|
||||
if not url:
|
||||
raise ImproperlyConfigured(E_NO_PATH_SET)
|
||||
if url.startswith('file://localhost/'):
|
||||
return url[16:]
|
||||
if url.startswith('file://'):
|
||||
return url[7:]
|
||||
raise ImproperlyConfigured(E_PATH_NON_CONFORMING_SCHEME)
|
||||
|
||||
def _do_directory_test(self, key):
|
||||
try:
|
||||
self.set(key, b'test value')
|
||||
assert self.get(key) == b'test value'
|
||||
self.delete(key)
|
||||
except OSError:
|
||||
raise ImproperlyConfigured(E_PATH_INVALID)
|
||||
|
||||
def _filename(self, key):
|
||||
return self.sep.join((self.path, key))
|
||||
|
||||
def get(self, key):
|
||||
try:
|
||||
with self.open(self._filename(key), 'rb') as infile:
|
||||
return infile.read()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def set(self, key, value):
|
||||
with self.open(self._filename(key), 'wb') as outfile:
|
||||
outfile.write(ensure_bytes(value))
|
||||
|
||||
def mget(self, keys):
|
||||
for key in keys:
|
||||
yield self.get(key)
|
||||
|
||||
def delete(self, key):
|
||||
self.unlink(self._filename(key))
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
if not self.expires:
|
||||
return
|
||||
epoch = datetime(1970, 1, 1, tzinfo=self.app.timezone)
|
||||
now_ts = (self.app.now() - epoch).total_seconds()
|
||||
cutoff_ts = now_ts - self.expires
|
||||
for filename in os.listdir(self.path):
|
||||
for prefix in (self.task_keyprefix, self.group_keyprefix,
|
||||
self.chord_keyprefix):
|
||||
if filename.startswith(prefix):
|
||||
path = os.path.join(self.path, filename)
|
||||
if os.stat(path).st_mtime < cutoff_ts:
|
||||
self.unlink(path)
|
||||
break
|
||||
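# --- Illustrative configuration (not part of this module) -----------------
# The URL must use the file:// scheme checked in _find_path(); the directory
# below is a placeholder assumption and must already exist and be writable:
#
#     result_backend = 'file:///var/celery/results'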
333
venv/lib/python3.12/site-packages/celery/backends/mongodb.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""MongoDB result store backend."""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from kombu.exceptions import EncodeError
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.url import maybe_sanitize_url, urlparse
|
||||
|
||||
from celery import states
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import BaseBackend
|
||||
|
||||
try:
|
||||
import pymongo
|
||||
except ImportError:
|
||||
pymongo = None
|
||||
|
||||
if pymongo:
|
||||
try:
|
||||
from bson.binary import Binary
|
||||
except ImportError:
|
||||
from pymongo.binary import Binary
|
||||
from pymongo.errors import InvalidDocument
|
||||
else: # pragma: no cover
|
||||
Binary = None
|
||||
|
||||
class InvalidDocument(Exception):
|
||||
pass
|
||||
|
||||
__all__ = ('MongoBackend',)
|
||||
|
||||
BINARY_CODECS = frozenset(['pickle', 'msgpack'])
|
||||
|
||||
|
||||
class MongoBackend(BaseBackend):
|
||||
"""MongoDB result backend.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`pymongo` is not available.
|
||||
"""
|
||||
|
||||
mongo_host = None
|
||||
host = 'localhost'
|
||||
port = 27017
|
||||
user = None
|
||||
password = None
|
||||
database_name = 'celery'
|
||||
taskmeta_collection = 'celery_taskmeta'
|
||||
groupmeta_collection = 'celery_groupmeta'
|
||||
max_pool_size = 10
|
||||
options = None
|
||||
|
||||
supports_autoexpire = False
|
||||
|
||||
_connection = None
|
||||
|
||||
def __init__(self, app=None, **kwargs):
|
||||
self.options = {}
|
||||
|
||||
super().__init__(app, **kwargs)
|
||||
|
||||
if not pymongo:
|
||||
raise ImproperlyConfigured(
|
||||
'You need to install the pymongo library to use the '
|
||||
'MongoDB backend.')
|
||||
|
||||
# Set option defaults
|
||||
for key, value in self._prepare_client_options().items():
|
||||
self.options.setdefault(key, value)
|
||||
|
||||
# update conf with mongo uri data, only if uri was given
|
||||
if self.url:
|
||||
self.url = self._ensure_mongodb_uri_compliance(self.url)
|
||||
|
||||
uri_data = pymongo.uri_parser.parse_uri(self.url)
|
||||
# build the hosts list to create a mongo connection
|
||||
hostslist = [
|
||||
f'{x[0]}:{x[1]}' for x in uri_data['nodelist']
|
||||
]
|
||||
self.user = uri_data['username']
|
||||
self.password = uri_data['password']
|
||||
self.mongo_host = hostslist
|
||||
if uri_data['database']:
|
||||
# override the default database name when the URI provides one
|
||||
self.database_name = uri_data['database']
|
||||
|
||||
self.options.update(uri_data['options'])
|
||||
|
||||
# update conf with specific settings
|
||||
config = self.app.conf.get('mongodb_backend_settings')
|
||||
if config is not None:
|
||||
if not isinstance(config, dict):
|
||||
raise ImproperlyConfigured(
|
||||
'MongoDB backend settings should be grouped in a dict')
|
||||
config = dict(config) # don't modify original
|
||||
|
||||
if 'host' in config or 'port' in config:
|
||||
# these should take over uri conf
|
||||
self.mongo_host = None
|
||||
|
||||
self.host = config.pop('host', self.host)
|
||||
self.port = config.pop('port', self.port)
|
||||
self.mongo_host = config.pop('mongo_host', self.mongo_host)
|
||||
self.user = config.pop('user', self.user)
|
||||
self.password = config.pop('password', self.password)
|
||||
self.database_name = config.pop('database', self.database_name)
|
||||
self.taskmeta_collection = config.pop(
|
||||
'taskmeta_collection', self.taskmeta_collection,
|
||||
)
|
||||
self.groupmeta_collection = config.pop(
|
||||
'groupmeta_collection', self.groupmeta_collection,
|
||||
)
|
||||
|
||||
self.options.update(config.pop('options', {}))
|
||||
self.options.update(config)
|
||||
|
||||
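# --- Illustrative configuration (not part of this module) -----------------
# Either a MongoDB URI, the per-backend settings dict read above, or both;
# the values below are placeholder assumptions:
#
#     result_backend = 'mongodb://user:password@localhost:27017/celery'
#     mongodb_backend_settings = {
#         'database': 'celery',
#         'taskmeta_collection': 'celery_taskmeta',
#         'options': {'maxPoolSize': 10},
#     }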
@staticmethod
|
||||
def _ensure_mongodb_uri_compliance(url):
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme.startswith('mongodb'):
|
||||
url = f'mongodb+{url}'
|
||||
|
||||
if url == 'mongodb://':
|
||||
url += 'localhost'
|
||||
|
||||
return url
|
||||
|
||||
def _prepare_client_options(self):
|
||||
if pymongo.version_tuple >= (3,):
|
||||
return {'maxPoolSize': self.max_pool_size}
|
||||
else: # pragma: no cover
|
||||
return {'max_pool_size': self.max_pool_size,
|
||||
'auto_start_request': False}
|
||||
|
||||
def _get_connection(self):
|
||||
"""Connect to the MongoDB server."""
|
||||
if self._connection is None:
|
||||
from pymongo import MongoClient
|
||||
|
||||
host = self.mongo_host
|
||||
if not host:
|
||||
# The first pymongo.Connection() argument (host) can be
|
||||
# a list of ['host:port'] elements or a mongodb connection
|
||||
# URI. If this is the case, don't use self.port
|
||||
# but let pymongo get the port(s) from the URI instead.
|
||||
# This enables the use of replica sets and sharding.
|
||||
# See pymongo.Connection() for more info.
|
||||
host = self.host
|
||||
if isinstance(host, str) \
|
||||
and not host.startswith('mongodb://'):
|
||||
host = f'mongodb://{host}:{self.port}'
|
||||
# don't change self.options
|
||||
conf = dict(self.options)
|
||||
conf['host'] = host
|
||||
if self.user:
|
||||
conf['username'] = self.user
|
||||
if self.password:
|
||||
conf['password'] = self.password
|
||||
|
||||
self._connection = MongoClient(**conf)
|
||||
|
||||
return self._connection
|
||||
|
||||
def encode(self, data):
|
||||
if self.serializer == 'bson':
|
||||
# mongodb handles serialization
|
||||
return data
|
||||
payload = super().encode(data)
|
||||
|
||||
# serializers that produce binary output (pickle/msgpack) need bson Binary wrapping
|
||||
if self.serializer in BINARY_CODECS:
|
||||
payload = Binary(payload)
|
||||
return payload
|
||||
|
||||
def decode(self, data):
|
||||
if self.serializer == 'bson':
|
||||
return data
|
||||
return super().decode(data)
|
||||
|
||||
def _store_result(self, task_id, result, state,
|
||||
traceback=None, request=None, **kwargs):
|
||||
"""Store return value and state of an executed task."""
|
||||
meta = self._get_result_meta(result=self.encode(result), state=state,
|
||||
traceback=traceback, request=request,
|
||||
format_date=False)
|
||||
# Add the _id for mongodb
|
||||
meta['_id'] = task_id
|
||||
|
||||
try:
|
||||
self.collection.replace_one({'_id': task_id}, meta, upsert=True)
|
||||
except InvalidDocument as exc:
|
||||
raise EncodeError(exc)
|
||||
|
||||
return result
|
||||
|
||||
def _get_task_meta_for(self, task_id):
|
||||
"""Get task meta-data for a task by id."""
|
||||
obj = self.collection.find_one({'_id': task_id})
|
||||
if obj:
|
||||
if self.app.conf.find_value_for_key('extended', 'result'):
|
||||
return self.meta_from_decoded({
|
||||
'name': obj['name'],
|
||||
'args': obj['args'],
|
||||
'task_id': obj['_id'],
|
||||
'queue': obj['queue'],
|
||||
'kwargs': obj['kwargs'],
|
||||
'status': obj['status'],
|
||||
'worker': obj['worker'],
|
||||
'retries': obj['retries'],
|
||||
'children': obj['children'],
|
||||
'date_done': obj['date_done'],
|
||||
'traceback': obj['traceback'],
|
||||
'result': self.decode(obj['result']),
|
||||
})
|
||||
return self.meta_from_decoded({
|
||||
'task_id': obj['_id'],
|
||||
'status': obj['status'],
|
||||
'result': self.decode(obj['result']),
|
||||
'date_done': obj['date_done'],
|
||||
'traceback': obj['traceback'],
|
||||
'children': obj['children'],
|
||||
})
|
||||
return {'status': states.PENDING, 'result': None}
|
||||
|
||||
def _save_group(self, group_id, result):
|
||||
"""Save the group result."""
|
||||
meta = {
|
||||
'_id': group_id,
|
||||
'result': self.encode([i.id for i in result]),
|
||||
'date_done': datetime.utcnow(),
|
||||
}
|
||||
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
|
||||
return result
|
||||
|
||||
def _restore_group(self, group_id):
|
||||
"""Get the result for a group by id."""
|
||||
obj = self.group_collection.find_one({'_id': group_id})
|
||||
if obj:
|
||||
return {
|
||||
'task_id': obj['_id'],
|
||||
'date_done': obj['date_done'],
|
||||
'result': [
|
||||
self.app.AsyncResult(task)
|
||||
for task in self.decode(obj['result'])
|
||||
],
|
||||
}
|
||||
|
||||
def _delete_group(self, group_id):
|
||||
"""Delete a group by id."""
|
||||
self.group_collection.delete_one({'_id': group_id})
|
||||
|
||||
def _forget(self, task_id):
|
||||
"""Remove result from MongoDB.
|
||||
|
||||
Raises:
|
||||
pymongo.exceptions.OperationsError:
|
||||
if the task_id could not be removed.
|
||||
"""
|
||||
# By using safe=True, this will wait until it receives a response from
|
||||
# the server. Likewise, it will raise an OperationsError if the
|
||||
# response was unable to be completed.
|
||||
self.collection.delete_one({'_id': task_id})
|
||||
|
||||
def cleanup(self):
|
||||
"""Delete expired meta-data."""
|
||||
if not self.expires:
|
||||
return
|
||||
|
||||
self.collection.delete_many(
|
||||
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
|
||||
)
|
||||
self.group_collection.delete_many(
|
||||
{'date_done': {'$lt': self.app.now() - self.expires_delta}},
|
||||
)
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(
|
||||
args, dict(kwargs, expires=self.expires, url=self.url))
|
||||
|
||||
def _get_database(self):
|
||||
conn = self._get_connection()
|
||||
return conn[self.database_name]
|
||||
|
||||
@cached_property
|
||||
def database(self):
|
||||
"""Get database from MongoDB connection.
|
||||
|
||||
performs authentication if necessary.
|
||||
"""
|
||||
return self._get_database()
|
||||
|
||||
@cached_property
|
||||
def collection(self):
|
||||
"""Get the meta-data task collection."""
|
||||
collection = self.database[self.taskmeta_collection]
|
||||
|
||||
# Ensure an index on date_done exists; if not, build the index
|
||||
# in the background. Once built, cleanup will be much faster
|
||||
collection.create_index('date_done', background=True)
|
||||
return collection
|
||||
|
||||
@cached_property
|
||||
def group_collection(self):
|
||||
"""Get the meta-data task collection."""
|
||||
collection = self.database[self.groupmeta_collection]
|
||||
|
||||
# Ensure an index on date_done exists; if not, build the index
|
||||
# in the background. Once built, cleanup will be much faster
|
||||
collection.create_index('date_done', background=True)
|
||||
return collection
|
||||
|
||||
@cached_property
|
||||
def expires_delta(self):
|
||||
return timedelta(seconds=self.expires)
|
||||
|
||||
def as_uri(self, include_password=False):
|
||||
"""Return the backend as an URI.
|
||||
|
||||
Arguments:
|
||||
include_password (bool): Include the password if true, otherwise censor it.
|
||||
"""
|
||||
if not self.url:
|
||||
return 'mongodb://'
|
||||
if include_password:
|
||||
return self.url
|
||||
|
||||
if ',' not in self.url:
|
||||
return maybe_sanitize_url(self.url)
|
||||
|
||||
uri1, remainder = self.url.split(',', 1)
|
||||
return ','.join([maybe_sanitize_url(uri1), remainder])
|
||||
668
venv/lib/python3.12/site-packages/celery/backends/redis.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""Redis result store backend."""
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
|
||||
from urllib.parse import unquote
|
||||
|
||||
from kombu.utils.functional import retry_over_time
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.url import _parse_url, maybe_sanitize_url
|
||||
|
||||
from celery import states
|
||||
from celery._state import task_join_will_block
|
||||
from celery.canvas import maybe_signature
|
||||
from celery.exceptions import BackendStoreError, ChordError, ImproperlyConfigured
|
||||
from celery.result import GroupResult, allow_join_result
|
||||
from celery.utils.functional import _regen, dictfilter
|
||||
from celery.utils.log import get_logger
|
||||
from celery.utils.time import humanize_seconds
|
||||
|
||||
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
|
||||
from .base import BaseKeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import redis.connection
|
||||
from kombu.transport.redis import get_redis_error_classes
|
||||
except ImportError:
|
||||
redis = None
|
||||
get_redis_error_classes = None
|
||||
|
||||
try:
|
||||
import redis.sentinel
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__all__ = ('RedisBackend', 'SentinelBackend')
|
||||
|
||||
E_REDIS_MISSING = """
|
||||
You need to install the redis library in order to use \
|
||||
the Redis result store backend.
|
||||
"""
|
||||
|
||||
E_REDIS_SENTINEL_MISSING = """
|
||||
You need to install the redis library with support of \
|
||||
sentinel in order to use the Redis result store backend.
|
||||
"""
|
||||
|
||||
W_REDIS_SSL_CERT_OPTIONAL = """
|
||||
Setting ssl_cert_reqs=CERT_OPTIONAL when connecting to redis means that \
|
||||
celery might not validate the identity of the redis broker when connecting. \
|
||||
This leaves you vulnerable to man in the middle attacks.
|
||||
"""
|
||||
|
||||
W_REDIS_SSL_CERT_NONE = """
|
||||
Setting ssl_cert_reqs=CERT_NONE when connecting to redis means that celery \
|
||||
will not validate the identity of the redis broker when connecting. This \
|
||||
leaves you vulnerable to man in the middle attacks.
|
||||
"""
|
||||
|
||||
E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH = """
|
||||
SSL connection parameters have been provided but the specified URL scheme \
|
||||
is redis://. A Redis SSL connection URL should use the scheme rediss://.
|
||||
"""
|
||||
|
||||
E_REDIS_SSL_CERT_REQS_MISSING_INVALID = """
|
||||
A rediss:// URL must have parameter ssl_cert_reqs and this must be set to \
|
||||
CERT_REQUIRED, CERT_OPTIONAL, or CERT_NONE
|
||||
"""
|
||||
|
||||
E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'
|
||||
|
||||
E_RETRY_LIMIT_EXCEEDED = """
|
||||
Retry limit exceeded while trying to reconnect to the Celery redis result \
|
||||
store backend. The Celery application must be restarted.
|
||||
"""
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultConsumer(BaseResultConsumer):
|
||||
_pubsub = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._get_key_for_task = self.backend.get_key_for_task
|
||||
self._decode_result = self.backend.decode_result
|
||||
self._ensure = self.backend.ensure
|
||||
self._connection_errors = self.backend.connection_errors
|
||||
self.subscribed_to = set()
|
||||
|
||||
def on_after_fork(self):
|
||||
try:
|
||||
self.backend.client.connection_pool.reset()
|
||||
if self._pubsub is not None:
|
||||
self._pubsub.close()
|
||||
except KeyError as e:
|
||||
logger.warning(str(e))
|
||||
super().on_after_fork()
|
||||
|
||||
def _reconnect_pubsub(self):
|
||||
self._pubsub = None
|
||||
self.backend.client.connection_pool.reset()
|
||||
# task state might have changed when the connection was down so we
|
||||
# retrieve meta for all subscribed tasks before going into pubsub mode
|
||||
if self.subscribed_to:
|
||||
metas = self.backend.client.mget(self.subscribed_to)
|
||||
metas = [meta for meta in metas if meta]
|
||||
for meta in metas:
|
||||
self.on_state_change(self._decode_result(meta), None)
|
||||
self._pubsub = self.backend.client.pubsub(
|
||||
ignore_subscribe_messages=True,
|
||||
)
|
||||
# subscribed_to may be empty after on_state_change
|
||||
if self.subscribed_to:
|
||||
self._pubsub.subscribe(*self.subscribed_to)
|
||||
else:
|
||||
self._pubsub.connection = self._pubsub.connection_pool.get_connection(
|
||||
'pubsub', self._pubsub.shard_hint
|
||||
)
|
||||
# even if there is nothing to subscribe to, we should not lose the callback after connecting.
|
||||
# The on_connect callback will re-subscribe to any channels we previously subscribed to.
|
||||
self._pubsub.connection.register_connect_callback(self._pubsub.on_connect)
|
||||
|
||||
@contextmanager
|
||||
def reconnect_on_error(self):
|
||||
try:
|
||||
yield
|
||||
except self._connection_errors:
|
||||
try:
|
||||
self._ensure(self._reconnect_pubsub, ())
|
||||
except self._connection_errors:
|
||||
logger.critical(E_RETRY_LIMIT_EXCEEDED)
|
||||
raise
|
||||
|
||||
def _maybe_cancel_ready_task(self, meta):
|
||||
if meta['status'] in states.READY_STATES:
|
||||
self.cancel_for(meta['task_id'])
|
||||
|
||||
def on_state_change(self, meta, message):
|
||||
super().on_state_change(meta, message)
|
||||
self._maybe_cancel_ready_task(meta)
|
||||
|
||||
def start(self, initial_task_id, **kwargs):
|
||||
self._pubsub = self.backend.client.pubsub(
|
||||
ignore_subscribe_messages=True,
|
||||
)
|
||||
self._consume_from(initial_task_id)
|
||||
|
||||
def on_wait_for_pending(self, result, **kwargs):
|
||||
for meta in result._iter_meta(**kwargs):
|
||||
if meta is not None:
|
||||
self.on_state_change(meta, None)
|
||||
|
||||
def stop(self):
|
||||
if self._pubsub is not None:
|
||||
self._pubsub.close()
|
||||
|
||||
def drain_events(self, timeout=None):
|
||||
if self._pubsub:
|
||||
with self.reconnect_on_error():
|
||||
message = self._pubsub.get_message(timeout=timeout)
|
||||
if message and message['type'] == 'message':
|
||||
self.on_state_change(self._decode_result(message['data']), message)
|
||||
elif timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
def consume_from(self, task_id):
|
||||
if self._pubsub is None:
|
||||
return self.start(task_id)
|
||||
self._consume_from(task_id)
|
||||
|
||||
def _consume_from(self, task_id):
|
||||
key = self._get_key_for_task(task_id)
|
||||
if key not in self.subscribed_to:
|
||||
self.subscribed_to.add(key)
|
||||
with self.reconnect_on_error():
|
||||
self._pubsub.subscribe(key)
|
||||
|
||||
def cancel_for(self, task_id):
|
||||
key = self._get_key_for_task(task_id)
|
||||
self.subscribed_to.discard(key)
|
||||
if self._pubsub:
|
||||
with self.reconnect_on_error():
|
||||
self._pubsub.unsubscribe(key)
|
||||
|
||||
|
||||
class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
|
||||
"""Redis task result store.
|
||||
|
||||
It makes use of the following commands:
|
||||
GET, MGET, DEL, INCRBY, EXPIRE, SET, SETEX
|
||||
"""
|
||||
|
||||
ResultConsumer = ResultConsumer
|
||||
|
||||
#: :pypi:`redis` client module.
|
||||
redis = redis
|
||||
connection_class_ssl = redis.SSLConnection if redis else None
|
||||
|
||||
#: Maximum number of connections in the pool.
|
||||
max_connections = None
|
||||
|
||||
supports_autoexpire = True
|
||||
supports_native_join = True
|
||||
|
||||
#: Maximal length of string value in Redis.
|
||||
#: 512 MB - https://redis.io/topics/data-types
|
||||
_MAX_STR_VALUE_SIZE = 536870912
|
||||
|
||||
def __init__(self, host=None, port=None, db=None, password=None,
|
||||
max_connections=None, url=None,
|
||||
connection_pool=None, **kwargs):
|
||||
super().__init__(expires_type=int, **kwargs)
|
||||
_get = self.app.conf.get
|
||||
if self.redis is None:
|
||||
raise ImproperlyConfigured(E_REDIS_MISSING.strip())
|
||||
|
||||
if host and '://' in host:
|
||||
url, host = host, None
|
||||
|
||||
self.max_connections = (
|
||||
max_connections or
|
||||
_get('redis_max_connections') or
|
||||
self.max_connections)
|
||||
self._ConnectionPool = connection_pool
|
||||
|
||||
socket_timeout = _get('redis_socket_timeout')
|
||||
socket_connect_timeout = _get('redis_socket_connect_timeout')
|
||||
retry_on_timeout = _get('redis_retry_on_timeout')
|
||||
socket_keepalive = _get('redis_socket_keepalive')
|
||||
health_check_interval = _get('redis_backend_health_check_interval')
|
||||
|
||||
self.connparams = {
|
||||
'host': _get('redis_host') or 'localhost',
|
||||
'port': _get('redis_port') or 6379,
|
||||
'db': _get('redis_db') or 0,
|
||||
'password': _get('redis_password'),
|
||||
'max_connections': self.max_connections,
|
||||
'socket_timeout': socket_timeout and float(socket_timeout),
|
||||
'retry_on_timeout': retry_on_timeout or False,
|
||||
'socket_connect_timeout':
|
||||
socket_connect_timeout and float(socket_connect_timeout),
|
||||
}
|
||||
|
||||
username = _get('redis_username')
|
||||
if username:
|
||||
# We're extra careful to avoid including this configuration value
|
||||
# if it wasn't specified since older versions of py-redis
|
||||
# don't support specifying a username.
|
||||
# Only Redis >= 6.0 supports username/password authentication.
|
||||
|
||||
# TODO: Include this in connparams' definition once we drop
|
||||
# support for py-redis<3.4.0.
|
||||
self.connparams['username'] = username
|
||||
|
||||
if health_check_interval:
|
||||
self.connparams["health_check_interval"] = health_check_interval
|
||||
|
||||
# absent in redis.connection.UnixDomainSocketConnection
|
||||
if socket_keepalive:
|
||||
self.connparams['socket_keepalive'] = socket_keepalive
|
||||
|
||||
# "redis_backend_use_ssl" must be a dict with the keys:
|
||||
# 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile'
|
||||
# (the same as "broker_use_ssl")
|
||||
ssl = _get('redis_backend_use_ssl')
|
||||
if ssl:
|
||||
self.connparams.update(ssl)
|
||||
self.connparams['connection_class'] = self.connection_class_ssl
|
||||
|
||||
if url:
|
||||
self.connparams = self._params_from_url(url, self.connparams)
|
||||
|
||||
# If we've received SSL parameters via query string or the
|
||||
# redis_backend_use_ssl dict, check ssl_cert_reqs is valid. If set
|
||||
# via query string ssl_cert_reqs will be a string so convert it here
|
||||
if ('connection_class' in self.connparams and
|
||||
issubclass(self.connparams['connection_class'], redis.SSLConnection)):
|
||||
ssl_cert_reqs_missing = 'MISSING'
|
||||
ssl_string_to_constant = {'CERT_REQUIRED': CERT_REQUIRED,
|
||||
'CERT_OPTIONAL': CERT_OPTIONAL,
|
||||
'CERT_NONE': CERT_NONE,
|
||||
'required': CERT_REQUIRED,
|
||||
'optional': CERT_OPTIONAL,
|
||||
'none': CERT_NONE}
|
||||
ssl_cert_reqs = self.connparams.get('ssl_cert_reqs', ssl_cert_reqs_missing)
|
||||
ssl_cert_reqs = ssl_string_to_constant.get(ssl_cert_reqs, ssl_cert_reqs)
|
||||
if ssl_cert_reqs not in ssl_string_to_constant.values():
|
||||
raise ValueError(E_REDIS_SSL_CERT_REQS_MISSING_INVALID)
|
||||
|
||||
if ssl_cert_reqs == CERT_OPTIONAL:
|
||||
logger.warning(W_REDIS_SSL_CERT_OPTIONAL)
|
||||
elif ssl_cert_reqs == CERT_NONE:
|
||||
logger.warning(W_REDIS_SSL_CERT_NONE)
|
||||
self.connparams['ssl_cert_reqs'] = ssl_cert_reqs
|
||||
|
||||
self.url = url
|
||||
|
||||
self.connection_errors, self.channel_errors = (
|
||||
get_redis_error_classes() if get_redis_error_classes
|
||||
else ((), ()))
|
||||
self.result_consumer = self.ResultConsumer(
|
||||
self, self.app, self.accept,
|
||||
self._pending_results, self._pending_messages,
|
||||
)
|
||||
|
||||
def _params_from_url(self, url, defaults):
|
||||
scheme, host, port, username, password, path, query = _parse_url(url)
|
||||
connparams = dict(
|
||||
defaults, **dictfilter({
|
||||
'host': host, 'port': port, 'username': username,
|
||||
'password': password, 'db': query.pop('virtual_host', None)})
|
||||
)
|
||||
|
||||
if scheme == 'socket':
|
||||
# use 'path' as path to the socket… in this case
|
||||
# the database number should be given in 'query'
|
||||
connparams.update({
|
||||
'connection_class': self.redis.UnixDomainSocketConnection,
|
||||
'path': '/' + path,
|
||||
})
|
||||
# host+port are invalid options when using this connection type.
|
||||
connparams.pop('host', None)
|
||||
connparams.pop('port', None)
|
||||
connparams.pop('socket_connect_timeout')
|
||||
else:
|
||||
connparams['db'] = path
|
||||
|
||||
ssl_param_keys = ['ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile',
|
||||
'ssl_cert_reqs']
|
||||
|
||||
if scheme == 'redis':
|
||||
# If connparams or query string contain ssl params, raise error
|
||||
if (any(key in connparams for key in ssl_param_keys) or
|
||||
any(key in query for key in ssl_param_keys)):
|
||||
raise ValueError(E_REDIS_SSL_PARAMS_AND_SCHEME_MISMATCH)
|
||||
|
||||
if scheme == 'rediss':
|
||||
connparams['connection_class'] = redis.SSLConnection
|
||||
# The following parameters, if present in the URL, are encoded. We
|
||||
# must add the decoded values to connparams.
|
||||
for ssl_setting in ssl_param_keys:
|
||||
ssl_val = query.pop(ssl_setting, None)
|
||||
if ssl_val:
|
||||
connparams[ssl_setting] = unquote(ssl_val)
|
||||
|
||||
# db may be string and start with / like in kombu.
|
||||
db = connparams.get('db') or 0
|
||||
db = db.strip('/') if isinstance(db, str) else db
|
||||
connparams['db'] = int(db)
|
||||
|
||||
for key, value in query.items():
|
||||
if key in redis.connection.URL_QUERY_ARGUMENT_PARSERS:
|
||||
query[key] = redis.connection.URL_QUERY_ARGUMENT_PARSERS[key](
|
||||
value
|
||||
)
|
||||
|
||||
# Query parameters override other parameters
|
||||
connparams.update(query)
|
||||
return connparams
|
||||
|
||||
@cached_property
|
||||
def retry_policy(self):
|
||||
retry_policy = super().retry_policy
|
||||
if "retry_policy" in self._transport_options:
|
||||
retry_policy = retry_policy.copy()
|
||||
retry_policy.update(self._transport_options['retry_policy'])
|
||||
|
||||
return retry_policy
|
||||
|
||||
def on_task_call(self, producer, task_id):
|
||||
if not task_join_will_block():
|
||||
self.result_consumer.consume_from(task_id)
|
||||
|
||||
def get(self, key):
|
||||
return self.client.get(key)
|
||||
|
||||
def mget(self, keys):
|
||||
return self.client.mget(keys)
|
||||
|
||||
def ensure(self, fun, args, **policy):
|
||||
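# Run `fun` with retries on connection errors, using the backend retry
# policy merged with any per-call overrides; on_connection_error() logs
# each attempt and returns the next sleep interval.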
retry_policy = dict(self.retry_policy, **policy)
|
||||
max_retries = retry_policy.get('max_retries')
|
||||
return retry_over_time(
|
||||
fun, self.connection_errors, args, {},
|
||||
partial(self.on_connection_error, max_retries),
|
||||
**retry_policy)
|
||||
|
||||
def on_connection_error(self, max_retries, exc, intervals, retries):
|
||||
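# Log the lost connection and return the next retry interval produced by
# the retry policy.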
tts = next(intervals)
|
||||
logger.error(
|
||||
E_LOST.strip(),
|
||||
retries, max_retries or 'Inf', humanize_seconds(tts, 'in '))
|
||||
return tts
|
||||
|
||||
def set(self, key, value, **retry_policy):
|
||||
if isinstance(value, str) and len(value) > self._MAX_STR_VALUE_SIZE:
|
||||
raise BackendStoreError('value too large for Redis backend')
|
||||
|
||||
return self.ensure(self._set, (key, value), **retry_policy)
|
||||
|
||||
def _set(self, key, value):
|
||||
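# Store the result (with an expiry when one is configured) and publish it
# on the same key so any subscribed ResultConsumer is notified immediately.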
with self.client.pipeline() as pipe:
|
||||
if self.expires:
|
||||
pipe.setex(key, self.expires, value)
|
||||
else:
|
||||
pipe.set(key, value)
|
||||
pipe.publish(key, value)
|
||||
pipe.execute()
|
||||
|
||||
def forget(self, task_id):
|
||||
super().forget(task_id)
|
||||
self.result_consumer.cancel_for(task_id)
|
||||
|
||||
def delete(self, key):
|
||||
self.client.delete(key)
|
||||
|
||||
def incr(self, key):
|
||||
return self.client.incr(key)
|
||||
|
||||
def expire(self, key, value):
|
||||
return self.client.expire(key, value)
|
||||
|
||||
def add_to_chord(self, group_id, result):
|
||||
self.client.incr(self.get_key_for_group(group_id, '.t'), 1)
|
||||
|
||||
def _unpack_chord_result(self, tup, decode,
|
||||
EXCEPTION_STATES=states.EXCEPTION_STATES,
|
||||
PROPAGATE_STATES=states.PROPAGATE_STATES):
|
||||
_, tid, state, retval = decode(tup)
|
||||
if state in EXCEPTION_STATES:
|
||||
retval = self.exception_to_python(retval)
|
||||
if state in PROPAGATE_STATES:
|
||||
raise ChordError(f'Dependency {tid} raised {retval!r}')
|
||||
return retval
|
||||
|
||||
def set_chord_size(self, group_id, chord_size):
|
||||
self.set(self.get_key_for_group(group_id, '.s'), chord_size)
|
||||
|
||||
def apply_chord(self, header_result_args, body, **kwargs):
|
||||
# If any of the child results of this chord are complex (i.e. group
|
||||
# results themselves), we need to save `header_result` to ensure that
|
||||
# the expected structure is retained when we finish the chord and pass
|
||||
# the results onward to the body in `on_chord_part_return()`. We don't
|
||||
# do this in all cases to retain an optimisation in the common case
|
||||
# where a chord header is comprised of simple result objects.
|
||||
if not isinstance(header_result_args[1], _regen):
|
||||
header_result = self.app.GroupResult(*header_result_args)
|
||||
if any(isinstance(nr, GroupResult) for nr in header_result.results):
|
||||
header_result.save(backend=self)
|
||||
|
||||
@cached_property
|
||||
def _chord_zset(self):
|
||||
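# When True (the default), chord results are kept in a sorted set ordered
# by group index; otherwise a plain list is used.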
return self._transport_options.get('result_chord_ordered', True)
|
||||
|
||||
@cached_property
|
||||
def _transport_options(self):
|
||||
return self.app.conf.get('result_backend_transport_options', {})
|
||||
|
||||
def on_chord_part_return(self, request, state, result,
|
||||
propagate=None, **kwargs):
|
||||
app = self.app
|
||||
tid, gid, group_index = request.id, request.group, request.group_index
|
||||
if not gid or not tid:
|
||||
return
|
||||
if group_index is None:
|
||||
group_index = '+inf'
|
||||
|
||||
client = self.client
|
||||
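# Per-group keys: '.j' collects the encoded results, '.t' counts tasks
# added via add_to_chord(), and '.s' holds the expected chord size.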
jkey = self.get_key_for_group(gid, '.j')
|
||||
tkey = self.get_key_for_group(gid, '.t')
|
||||
skey = self.get_key_for_group(gid, '.s')
|
||||
result = self.encode_result(result, state)
|
||||
encoded = self.encode([1, tid, state, result])
|
||||
with client.pipeline() as pipe:
|
||||
pipeline = (
|
||||
pipe.zadd(jkey, {encoded: group_index}).zcount(jkey, "-inf", "+inf")
|
||||
if self._chord_zset
|
||||
else pipe.rpush(jkey, encoded).llen(jkey)
|
||||
).get(tkey).get(skey)
|
||||
if self.expires:
|
||||
pipeline = pipeline \
|
||||
.expire(jkey, self.expires) \
|
||||
.expire(tkey, self.expires) \
|
||||
.expire(skey, self.expires)
|
||||
|
||||
_, readycount, totaldiff, chord_size_bytes = pipeline.execute()[:4]
|
||||
|
||||
totaldiff = int(totaldiff or 0)
|
||||
|
||||
if chord_size_bytes:
|
||||
try:
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
total = int(chord_size_bytes) + totaldiff
|
||||
if readycount == total:
|
||||
header_result = GroupResult.restore(gid)
|
||||
if header_result is not None:
|
||||
# If we manage to restore a `GroupResult`, then it must
|
||||
# have been complex and saved by `apply_chord()` earlier.
|
||||
#
|
||||
# Before we can join the `GroupResult`, it needs to be
|
||||
# manually marked as ready to avoid blocking
|
||||
header_result.on_ready()
|
||||
# We'll `join()` it to get the results and ensure they are
|
||||
# structured as intended rather than the flattened version
|
||||
# we'd construct without any other information.
|
||||
join_func = (
|
||||
header_result.join_native
|
||||
if header_result.supports_native_join
|
||||
else header_result.join
|
||||
)
|
||||
with allow_join_result():
|
||||
resl = join_func(
|
||||
timeout=app.conf.result_chord_join_timeout,
|
||||
propagate=True
|
||||
)
|
||||
else:
|
||||
# Otherwise simply extract and decode the results we
|
||||
# stashed along the way, which should be faster for large
|
||||
# numbers of simple results in the chord header.
|
||||
decode, unpack = self.decode, self._unpack_chord_result
|
||||
with client.pipeline() as pipe:
|
||||
if self._chord_zset:
|
||||
pipeline = pipe.zrange(jkey, 0, -1)
|
||||
else:
|
||||
pipeline = pipe.lrange(jkey, 0, total)
|
||||
resl, = pipeline.execute()
|
||||
resl = [unpack(tup, decode) for tup in resl]
|
||||
try:
|
||||
callback.delay(resl)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
'Chord callback for %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Callback error: {exc!r}'),
|
||||
)
|
||||
finally:
|
||||
with client.pipeline() as pipe:
|
||||
pipe \
|
||||
.delete(jkey) \
|
||||
.delete(tkey) \
|
||||
.delete(skey) \
|
||||
.execute()
|
||||
except ChordError as exc:
|
||||
logger.exception('Chord %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(callback, exc)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Chord %r raised: %r', request.group, exc)
|
||||
return self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Join error: {exc!r}'),
|
||||
)
|
||||
|
||||
def _create_client(self, **params):
|
||||
return self._get_client()(
|
||||
connection_pool=self._get_pool(**params),
|
||||
)
|
||||
|
||||
def _get_client(self):
|
||||
return self.redis.StrictRedis
|
||||
|
||||
def _get_pool(self, **params):
|
||||
return self.ConnectionPool(**params)
|
||||
|
||||
@property
|
||||
def ConnectionPool(self):
|
||||
if self._ConnectionPool is None:
|
||||
self._ConnectionPool = self.redis.ConnectionPool
|
||||
return self._ConnectionPool
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
return self._create_client(**self.connparams)
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(
|
||||
args, dict(kwargs, expires=self.expires, url=self.url))
|
||||
|
||||
|
||||
if getattr(redis, "sentinel", None):
|
||||
class SentinelManagedSSLConnection(
|
||||
redis.sentinel.SentinelManagedConnection,
|
||||
redis.SSLConnection):
|
||||
"""Connect to a Redis server using Sentinel + TLS.
|
||||
|
||||
Use Sentinel to discover which Redis server is the current master,
and connect to that master over an SSL connection.
|
||||
"""
|
||||
|
||||
|
||||
class SentinelBackend(RedisBackend):
|
||||
"""Redis sentinel task result store."""
|
||||
|
||||
# URL looks like `sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3`
|
||||
_SERVER_URI_SEPARATOR = ";"
|
||||
|
||||
sentinel = getattr(redis, "sentinel", None)
|
||||
connection_class_ssl = SentinelManagedSSLConnection if sentinel else None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self.sentinel is None:
|
||||
raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip())
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def as_uri(self, include_password=False):
|
||||
"""Return the server addresses as URIs, sanitizing the password or not."""
|
||||
# Allow superclass to do work if we don't need to force sanitization
|
||||
if include_password:
|
||||
return super().as_uri(
|
||||
include_password=include_password,
|
||||
)
|
||||
# Otherwise we need to ensure that all components get sanitized
# by passing them one by one to the `kombu` helper.
|
||||
uri_chunks = (
|
||||
maybe_sanitize_url(chunk)
|
||||
for chunk in (self.url or "").split(self._SERVER_URI_SEPARATOR)
|
||||
)
|
||||
# Similar to the superclass, strip the trailing slash from URIs with
|
||||
# all components empty other than the scheme
|
||||
return self._SERVER_URI_SEPARATOR.join(
|
||||
uri[:-1] if uri.endswith(":///") else uri
|
||||
for uri in uri_chunks
|
||||
)
|
||||
|
||||
def _params_from_url(self, url, defaults):
|
||||
chunks = url.split(self._SERVER_URI_SEPARATOR)
|
||||
connparams = dict(defaults, hosts=[])
|
||||
for chunk in chunks:
|
||||
data = super()._params_from_url(
|
||||
url=chunk, defaults=defaults)
|
||||
connparams['hosts'].append(data)
|
||||
for param in ("host", "port", "db", "password"):
|
||||
connparams.pop(param)
|
||||
|
||||
# Adding db/password in connparams to connect to the correct instance
|
||||
for param in ("db", "password"):
|
||||
if connparams['hosts'] and param in connparams['hosts'][0]:
|
||||
connparams[param] = connparams['hosts'][0].get(param)
|
||||
return connparams
|
||||
|
||||
def _get_sentinel_instance(self, **params):
|
||||
connparams = params.copy()
|
||||
|
||||
hosts = connparams.pop("hosts")
|
||||
min_other_sentinels = self._transport_options.get("min_other_sentinels", 0)
|
||||
sentinel_kwargs = self._transport_options.get("sentinel_kwargs", {})
|
||||
|
||||
sentinel_instance = self.sentinel.Sentinel(
|
||||
[(cp['host'], cp['port']) for cp in hosts],
|
||||
min_other_sentinels=min_other_sentinels,
|
||||
sentinel_kwargs=sentinel_kwargs,
|
||||
**connparams)
|
||||
|
||||
return sentinel_instance
|
||||
|
||||
def _get_pool(self, **params):
|
||||
sentinel_instance = self._get_sentinel_instance(**params)
|
||||
|
||||
master_name = self._transport_options.get("master_name", None)
|
||||
|
||||
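# Ask Sentinel for the current master of the configured service and reuse
# that master's connection pool.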
return sentinel_instance.master_for(
|
||||
service_name=master_name,
|
||||
redis_class=self._get_client(),
|
||||
).connection_pool
|
||||
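# Illustrative only (not part of this commit): selecting these backends from
# a Celery app; the URLs and master name below are assumptions.
#
#   app.conf.result_backend = 'redis://localhost:6379/0'
#   # or, behind Sentinel:
#   app.conf.result_backend = (
#       'sentinel://localhost:26379/0;sentinel://localhost:26380/0')
#   app.conf.result_backend_transport_options = {'master_name': 'mymaster'}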
342
venv/lib/python3.12/site-packages/celery/backends/rpc.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""The ``RPC`` result backend for AMQP brokers.
|
||||
|
||||
RPC-style result backend, using reply-to and one queue per client.
|
||||
"""
|
||||
import time
|
||||
|
||||
import kombu
|
||||
from kombu.common import maybe_declare
|
||||
from kombu.utils.compat import register_after_fork
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from celery import states
|
||||
from celery._state import current_task, task_join_will_block
|
||||
|
||||
from . import base
|
||||
from .asynchronous import AsyncBackendMixin, BaseResultConsumer
|
||||
|
||||
__all__ = ('BacklogLimitExceeded', 'RPCBackend')
|
||||
|
||||
E_NO_CHORD_SUPPORT = """
|
||||
The "rpc" result backend does not support chords!
|
||||
|
||||
Note that a group chained with a task is also upgraded to be a chord,
|
||||
as this pattern requires synchronization.
|
||||
|
||||
Result backends that support chords: Redis, Database, Memcached, and more.
|
||||
"""
|
||||
|
||||
|
||||
class BacklogLimitExceeded(Exception):
|
||||
"""Too much state history to fast-forward."""
|
||||
|
||||
|
||||
def _on_after_fork_cleanup_backend(backend):
|
||||
backend._after_fork()
|
||||
|
||||
|
||||
class ResultConsumer(BaseResultConsumer):
|
||||
Consumer = kombu.Consumer
|
||||
|
||||
_connection = None
|
||||
_consumer = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._create_binding = self.backend._create_binding
|
||||
|
||||
def start(self, initial_task_id, no_ack=True, **kwargs):
|
||||
self._connection = self.app.connection()
|
||||
initial_queue = self._create_binding(initial_task_id)
|
||||
self._consumer = self.Consumer(
|
||||
self._connection.default_channel, [initial_queue],
|
||||
callbacks=[self.on_state_change], no_ack=no_ack,
|
||||
accept=self.accept)
|
||||
self._consumer.consume()
|
||||
|
||||
def drain_events(self, timeout=None):
|
||||
if self._connection:
|
||||
return self._connection.drain_events(timeout=timeout)
|
||||
elif timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
def stop(self):
|
||||
try:
|
||||
self._consumer.cancel()
|
||||
finally:
|
||||
self._connection.close()
|
||||
|
||||
def on_after_fork(self):
|
||||
self._consumer = None
|
||||
if self._connection is not None:
|
||||
self._connection.collect()
|
||||
self._connection = None
|
||||
|
||||
def consume_from(self, task_id):
|
||||
if self._consumer is None:
|
||||
return self.start(task_id)
|
||||
queue = self._create_binding(task_id)
|
||||
if not self._consumer.consuming_from(queue):
|
||||
self._consumer.add_queue(queue)
|
||||
self._consumer.consume()
|
||||
|
||||
def cancel_for(self, task_id):
|
||||
if self._consumer:
|
||||
self._consumer.cancel_by_queue(self._create_binding(task_id).name)
|
||||
|
||||
|
||||
class RPCBackend(base.Backend, AsyncBackendMixin):
|
||||
"""Base class for the RPC result backend."""
|
||||
|
||||
Exchange = kombu.Exchange
|
||||
Producer = kombu.Producer
|
||||
ResultConsumer = ResultConsumer
|
||||
|
||||
#: Exception raised when there are too many messages for a task id.
|
||||
BacklogLimitExceeded = BacklogLimitExceeded
|
||||
|
||||
persistent = False
|
||||
supports_autoexpire = True
|
||||
supports_native_join = True
|
||||
|
||||
retry_policy = {
|
||||
'max_retries': 20,
|
||||
'interval_start': 0,
|
||||
'interval_step': 1,
|
||||
'interval_max': 1,
|
||||
}
|
||||
|
||||
class Consumer(kombu.Consumer):
|
||||
"""Consumer that requires manual declaration of queues."""
|
||||
|
||||
auto_declare = False
|
||||
|
||||
class Queue(kombu.Queue):
|
||||
"""Queue that never caches declaration."""
|
||||
|
||||
can_cache_declaration = False
|
||||
|
||||
def __init__(self, app, connection=None, exchange=None, exchange_type=None,
|
||||
persistent=None, serializer=None, auto_delete=True, **kwargs):
|
||||
super().__init__(app, **kwargs)
|
||||
conf = self.app.conf
|
||||
self._connection = connection
|
||||
self._out_of_band = {}
|
||||
self.persistent = self.prepare_persistent(persistent)
|
||||
self.delivery_mode = 2 if self.persistent else 1
|
||||
exchange = exchange or conf.result_exchange
|
||||
exchange_type = exchange_type or conf.result_exchange_type
|
||||
self.exchange = self._create_exchange(
|
||||
exchange, exchange_type, self.delivery_mode,
|
||||
)
|
||||
self.serializer = serializer or conf.result_serializer
|
||||
self.auto_delete = auto_delete
|
||||
self.result_consumer = self.ResultConsumer(
|
||||
self, self.app, self.accept,
|
||||
self._pending_results, self._pending_messages,
|
||||
)
|
||||
if register_after_fork is not None:
|
||||
register_after_fork(self, _on_after_fork_cleanup_backend)
|
||||
|
||||
def _after_fork(self):
|
||||
# clear state for child processes.
|
||||
self._pending_results.clear()
|
||||
self.result_consumer._after_fork()
|
||||
|
||||
def _create_exchange(self, name, type='direct', delivery_mode=2):
|
||||
# uses direct to queue routing (anon exchange).
|
||||
return self.Exchange(None)
|
||||
|
||||
def _create_binding(self, task_id):
|
||||
"""Create new binding for task with id."""
|
||||
# RPC backend caches the binding, as one queue is used for all tasks.
|
||||
return self.binding
|
||||
|
||||
def ensure_chords_allowed(self):
|
||||
raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())
|
||||
|
||||
def on_task_call(self, producer, task_id):
|
||||
# Called every time a task is sent when using this backend.
|
||||
# We declare the queue we receive replies on in advance of sending
|
||||
# the message, but we skip this if running in the prefork pool
|
||||
# (task_join_will_block), as we know the queue is already declared.
|
||||
if not task_join_will_block():
|
||||
maybe_declare(self.binding(producer.channel), retry=True)
|
||||
|
||||
def destination_for(self, task_id, request):
|
||||
"""Get the destination for result by task id.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: tuple of ``(reply_to, correlation_id)``.
|
||||
"""
|
||||
# Backends didn't always receive the `request`, so we must still
|
||||
# support old code that relies on current_task.
|
||||
try:
|
||||
request = request or current_task.request
|
||||
except AttributeError:
|
||||
raise RuntimeError(
|
||||
f'RPC backend missing task request for {task_id!r}')
|
||||
return request.reply_to, request.correlation_id or task_id
|
||||
|
||||
def on_reply_declare(self, task_id):
|
||||
# Return value here is used as the `declare=` argument
|
||||
# for Producer.publish.
|
||||
# By default we don't have to declare anything when sending a result.
|
||||
pass
|
||||
|
||||
def on_result_fulfilled(self, result):
|
||||
# This usually cancels the queue after the result is received,
|
||||
# but we don't have to cancel since we have one queue per process.
|
||||
pass
|
||||
|
||||
def as_uri(self, include_password=True):
|
||||
return 'rpc://'
|
||||
|
||||
def store_result(self, task_id, result, state,
|
||||
traceback=None, request=None, **kwargs):
|
||||
"""Send task return value and state."""
|
||||
routing_key, correlation_id = self.destination_for(task_id, request)
|
||||
if not routing_key:
|
||||
return
|
||||
with self.app.amqp.producer_pool.acquire(block=True) as producer:
|
||||
producer.publish(
|
||||
self._to_result(task_id, state, result, traceback, request),
|
||||
exchange=self.exchange,
|
||||
routing_key=routing_key,
|
||||
correlation_id=correlation_id,
|
||||
serializer=self.serializer,
|
||||
retry=True, retry_policy=self.retry_policy,
|
||||
declare=self.on_reply_declare(task_id),
|
||||
delivery_mode=self.delivery_mode,
|
||||
)
|
||||
return result
|
||||
|
||||
def _to_result(self, task_id, state, result, traceback, request):
|
||||
return {
|
||||
'task_id': task_id,
|
||||
'status': state,
|
||||
'result': self.encode_result(result, state),
|
||||
'traceback': traceback,
|
||||
'children': self.current_task_children(request),
|
||||
}
|
||||
|
||||
def on_out_of_band_result(self, task_id, message):
|
||||
# Callback called when a reply for a task is received,
|
||||
# but we have no idea what to do with it.
|
||||
# Since the result is not pending, we put it in a separate
|
||||
# buffer: probably it will become pending later.
|
||||
if self.result_consumer:
|
||||
self.result_consumer.on_out_of_band_result(message)
|
||||
self._out_of_band[task_id] = message
|
||||
|
||||
def get_task_meta(self, task_id, backlog_limit=1000):
|
||||
buffered = self._out_of_band.pop(task_id, None)
|
||||
if buffered:
|
||||
return self._set_cache_by_message(task_id, buffered)
|
||||
|
||||
# Polling and using basic_get
|
||||
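# Drain up to backlog_limit messages from the reply queue, keeping only the
# most recent state per task id; replies for other tasks are buffered out
# of band.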
latest_by_id = {}
|
||||
prev = None
|
||||
for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit):
|
||||
tid = self._get_message_task_id(acc)
|
||||
prev, latest_by_id[tid] = latest_by_id.get(tid), acc
|
||||
if prev:
|
||||
# backends aren't expected to keep history,
|
||||
# so we delete everything except the most recent state.
|
||||
prev.ack()
|
||||
prev = None
|
||||
|
||||
latest = latest_by_id.pop(task_id, None)
|
||||
for tid, msg in latest_by_id.items():
|
||||
self.on_out_of_band_result(tid, msg)
|
||||
|
||||
if latest:
|
||||
latest.requeue()
|
||||
return self._set_cache_by_message(task_id, latest)
|
||||
else:
|
||||
# no new state, use previous
|
||||
try:
|
||||
return self._cache[task_id]
|
||||
except KeyError:
|
||||
# result probably pending.
|
||||
return {'status': states.PENDING, 'result': None}
|
||||
poll = get_task_meta # XXX compat
|
||||
|
||||
def _set_cache_by_message(self, task_id, message):
|
||||
payload = self._cache[task_id] = self.meta_from_decoded(
|
||||
message.payload)
|
||||
return payload
|
||||
|
||||
def _slurp_from_queue(self, task_id, accept,
|
||||
limit=1000, no_ack=False):
|
||||
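# basic_get messages off the reply queue until it is empty, raising
# BacklogLimitExceeded when the backlog exceeds `limit`.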
with self.app.pool.acquire_channel(block=True) as (_, channel):
|
||||
binding = self._create_binding(task_id)(channel)
|
||||
binding.declare()
|
||||
|
||||
for _ in range(limit):
|
||||
msg = binding.get(accept=accept, no_ack=no_ack)
|
||||
if not msg:
|
||||
break
|
||||
yield msg
|
||||
else:
|
||||
raise self.BacklogLimitExceeded(task_id)
|
||||
|
||||
def _get_message_task_id(self, message):
|
||||
try:
|
||||
# try property first so we don't have to deserialize
|
||||
# the payload.
|
||||
return message.properties['correlation_id']
|
||||
except (AttributeError, KeyError):
|
||||
# message sent by old Celery version, need to deserialize.
|
||||
return message.payload['task_id']
|
||||
|
||||
def revive(self, channel):
|
||||
pass
|
||||
|
||||
def reload_task_result(self, task_id):
|
||||
raise NotImplementedError(
|
||||
'reload_task_result is not supported by this backend.')
|
||||
|
||||
def reload_group_result(self, task_id):
|
||||
"""Reload group result, even if it has been previously fetched."""
|
||||
raise NotImplementedError(
|
||||
'reload_group_result is not supported by this backend.')
|
||||
|
||||
def save_group(self, group_id, result):
|
||||
raise NotImplementedError(
|
||||
'save_group is not supported by this backend.')
|
||||
|
||||
def restore_group(self, group_id, cache=True):
|
||||
raise NotImplementedError(
|
||||
'restore_group is not supported by this backend.')
|
||||
|
||||
def delete_group(self, group_id):
|
||||
raise NotImplementedError(
|
||||
'delete_group is not supported by this backend.')
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
kwargs = {} if not kwargs else kwargs
|
||||
return super().__reduce__(args, dict(
|
||||
kwargs,
|
||||
connection=self._connection,
|
||||
exchange=self.exchange.name,
|
||||
exchange_type=self.exchange.type,
|
||||
persistent=self.persistent,
|
||||
serializer=self.serializer,
|
||||
auto_delete=self.auto_delete,
|
||||
expires=self.expires,
|
||||
))
|
||||
|
||||
@property
|
||||
def binding(self):
|
||||
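# One auto-delete reply queue per client; both the queue name and the
# routing key are the OID.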
return self.Queue(
|
||||
self.oid, self.exchange, self.oid,
|
||||
durable=False,
|
||||
auto_delete=True,
|
||||
expires=self.expires,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def oid(self):
|
||||
# Cached here is the app's thread OID: the name of the queue we receive results on.
|
||||
return self.app.thread_oid
|
||||
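# Illustrative only (not part of this commit): the rpc backend is selected
# with ``app.conf.result_backend = 'rpc://'``; results land in a per-client
# reply queue, so they can only be retrieved by the client that sent the task.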
87
venv/lib/python3.12/site-packages/celery/backends/s3.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""s3 result store backend."""
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
|
||||
from celery.exceptions import ImproperlyConfigured
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import botocore
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
botocore = None
|
||||
|
||||
|
||||
__all__ = ('S3Backend',)
|
||||
|
||||
|
||||
class S3Backend(KeyValueStoreBackend):
|
||||
"""An S3 task result store.
|
||||
|
||||
Raises:
|
||||
celery.exceptions.ImproperlyConfigured:
|
||||
if module :pypi:`boto3` is not available,
|
||||
if the :setting:`aws_access_key_id` or
|
||||
:setting:`aws_secret_access_key` are not set,
or if the :setting:`bucket` is not set.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not boto3 or not botocore:
|
||||
raise ImproperlyConfigured('You must install boto3 to use the s3 backend')
|
||||
conf = self.app.conf
|
||||
|
||||
self.endpoint_url = conf.get('s3_endpoint_url', None)
|
||||
self.aws_region = conf.get('s3_region', None)
|
||||
|
||||
self.aws_access_key_id = conf.get('s3_access_key_id', None)
|
||||
self.aws_secret_access_key = conf.get('s3_secret_access_key', None)
|
||||
|
||||
self.bucket_name = conf.get('s3_bucket', None)
|
||||
if not self.bucket_name:
|
||||
raise ImproperlyConfigured('Missing bucket name')
|
||||
|
||||
self.base_path = conf.get('s3_base_path', None)
|
||||
|
||||
self._s3_resource = self._connect_to_s3()
|
||||
|
||||
def _get_s3_object(self, key):
|
||||
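# Prefix the key with the configured base path, if one is set.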
key_bucket_path = self.base_path + key if self.base_path else key
|
||||
return self._s3_resource.Object(self.bucket_name, key_bucket_path)
|
||||
|
||||
def get(self, key):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
try:
|
||||
s3_object.load()
|
||||
data = s3_object.get()['Body'].read()
|
||||
return data if self.content_encoding == 'binary' else data.decode('utf-8')
|
||||
except botocore.exceptions.ClientError as error:
|
||||
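# A 404 from S3 means the key does not exist yet, so report the result
# as missing instead of raising.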
if error.response['Error']['Code'] == "404":
|
||||
return None
|
||||
raise error
|
||||
|
||||
def set(self, key, value):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
s3_object.put(Body=value)
|
||||
|
||||
def delete(self, key):
|
||||
key = bytes_to_str(key)
|
||||
s3_object = self._get_s3_object(key)
|
||||
s3_object.delete()
|
||||
|
||||
def _connect_to_s3(self):
|
||||
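# Fail early if boto3 cannot resolve any credentials (explicit keys,
# environment variables or an instance profile).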
session = boto3.Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
region_name=self.aws_region
|
||||
)
|
||||
if session.get_credentials() is None:
|
||||
raise ImproperlyConfigured('Missing aws s3 creds')
|
||||
return session.resource('s3', endpoint_url=self.endpoint_url)
|
||||
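# Illustrative only (not part of this commit): a minimal configuration for
# the S3 backend sketched above; the bucket name and region are assumptions.
#
#   app.conf.result_backend = 's3'
#   app.conf.s3_bucket = 'my-celery-results'
#   app.conf.s3_region = 'eu-west-1'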