API refactor
All checks were successful
continuous-integration/drone/push Build is passing

2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions


@@ -1,4 +1,5 @@
"""The Azure Storage Block Blob backend for Celery."""
from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
@@ -28,6 +29,13 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
container_name=None,
*args,
**kwargs):
"""
Supported URL formats:
azureblockblob://CONNECTION_STRING
azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL
azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL
"""
super().__init__(*args, **kwargs)
if azurestorage is None or azurestorage.__version__ < '12':
@@ -65,11 +73,26 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
the container is created if it doesn't yet exist.
"""
client = BlobServiceClient.from_connection_string(
self._connection_string,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout
)
if (
"DefaultAzureCredential" in self._connection_string or
"ManagedIdentityCredential" in self._connection_string
):
# Leveraging the work that Kombu already did for us
credential_, url = AzureStorageQueuesTransport.parse_uri(
self._connection_string
)
client = BlobServiceClient(
account_url=url,
credential=credential_,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout,
)
else:
client = BlobServiceClient.from_connection_string(
self._connection_string,
connection_timeout=self._connection_timeout,
read_timeout=self._read_timeout,
)
try:
client.create_container(name=self._container_name)
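
For illustration, a minimal configuration sketch using the credential-based URL format documented in the new docstring; the storage account URL is a placeholder, and DefaultAzureCredential resolves credentials from the environment (managed identity, workload identity, Azure CLI login, and so on):

from celery import Celery

app = Celery('tasks', broker='redis://localhost:6379/0')  # placeholder broker

# Placeholder account URL; substitute whatever account URL Kombu's
# AzureStorageQueuesTransport.parse_uri() expects, as used in the diff above.
app.conf.result_backend = (
    'azureblockblob://DefaultAzureCredential@'
    'myaccount.blob.core.windows.net'
)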


@@ -9,7 +9,7 @@ import sys
import time
import warnings
from collections import namedtuple
from datetime import datetime, timedelta
from datetime import timedelta
from functools import partial
from weakref import WeakValueDictionary
@@ -460,7 +460,7 @@ class Backend:
state, traceback, request, format_date=True,
encode=False):
if state in self.READY_STATES:
date_done = datetime.utcnow()
date_done = self.app.now()
if format_date:
date_done = date_done.isoformat()
else:
@@ -833,9 +833,11 @@ class BaseKeyValueStoreBackend(Backend):
"""
global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None)
if global_keyprefix:
self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}"
self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}"
self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}"
if global_keyprefix[-1] not in ':_-.':
global_keyprefix += '_'
self.task_keyprefix = f"{global_keyprefix}{self.task_keyprefix}"
self.group_keyprefix = f"{global_keyprefix}{self.group_keyprefix}"
self.chord_keyprefix = f"{global_keyprefix}{self.chord_keyprefix}"
def _encode_prefixes(self):
self.task_keyprefix = self.key_t(self.task_keyprefix)
@@ -1080,7 +1082,7 @@ class BaseKeyValueStoreBackend(Backend):
)
finally:
deps.delete()
self.client.delete(key)
self.delete(key)
else:
self.expire(key, self.expires)
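
A short illustration of the global_keyprefix change above (values are examples): a prefix that already ends in ':', '_', '-' or '.' is now used verbatim, while a bare prefix still gets '_' appended before the standard key prefixes:

from celery import Celery

app = Celery('tasks', backend='redis://localhost:6379/0')

# Ends in a separator, so no extra '_' is inserted any more:
app.conf.result_backend_transport_options = {'global_keyprefix': 'myapp:'}
# Result keys then look roughly like:
#   myapp:celery-task-meta-<task_id>
#
# With 'global_keyprefix': 'myapp' (no trailing separator) the backend
# still appends '_', giving:
#   myapp_celery-task-meta-<task_id>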


@@ -86,7 +86,7 @@ class CassandraBackend(BaseBackend):
supports_autoexpire = True # autoexpire supported via entry_ttl
def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
port=9042, bundle_path=None, **kwargs):
port=None, bundle_path=None, **kwargs):
super().__init__(**kwargs)
if not cassandra:
@@ -96,7 +96,7 @@ class CassandraBackend(BaseBackend):
self.servers = servers or conf.get('cassandra_servers', None)
self.bundle_path = bundle_path or conf.get(
'cassandra_secure_bundle_path', None)
self.port = port or conf.get('cassandra_port', None)
self.port = port or conf.get('cassandra_port', None) or 9042
self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
self.table = table or conf.get('cassandra_table', None)
self.cassandra_options = conf.get('cassandra_options', {})
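
For illustration, a configuration sketch for the port handling changed above: with the hard-coded keyword default removed, an explicit cassandra_port setting is now honoured, and 9042 remains the final fallback (hosts and keyspace are placeholders):

from celery import Celery

app = Celery('tasks')
app.conf.update(
    result_backend='cassandra://',
    cassandra_servers=['cassandra-1.example.com'],  # placeholder contact points
    cassandra_keyspace='celery',
    cassandra_table='tasks',
    cassandra_port=9142,  # previously shadowed by the 9042 keyword default
)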


@@ -98,11 +98,23 @@ class DatabaseBackend(BaseBackend):
'Missing connection string! Do you have the'
' database_url setting set to a real value?')
self.session_manager = SessionManager()
create_tables_at_setup = conf.database_create_tables_at_setup
if create_tables_at_setup is True:
self._create_tables()
@property
def extended_result(self):
return self.app.conf.find_value_for_key('extended', 'result')
def ResultSession(self, session_manager=SessionManager()):
def _create_tables(self):
"""Create the task and taskset tables."""
self.ResultSession()
def ResultSession(self, session_manager=None):
if session_manager is None:
session_manager = self.session_manager
return session_manager.session_factory(
dburi=self.url,
short_lived_sessions=self.short_lived_sessions,
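
A brief sketch of the new database_create_tables_at_setup switch (the connection URL is a placeholder): when True, the task and taskset tables are created while the backend is initialised; when False, creation is skipped at setup time.

from celery import Celery

app = Celery('tasks')
app.conf.update(
    result_backend='db+postgresql://user:password@localhost/celery',  # placeholder
    # New setting shown in the diff above; set to False to skip creating
    # the result tables during backend initialisation.
    database_create_tables_at_setup=True,
)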


@@ -1,5 +1,5 @@
"""Database models used by the SQLAlchemy result store backend."""
from datetime import datetime
from datetime import datetime, timezone
import sqlalchemy as sa
from sqlalchemy.types import PickleType
@@ -22,8 +22,8 @@ class Task(ResultModelBase):
task_id = sa.Column(sa.String(155), unique=True)
status = sa.Column(sa.String(50), default=states.PENDING)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
onupdate=datetime.utcnow, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
onupdate=datetime.now(timezone.utc), nullable=True)
traceback = sa.Column(sa.Text, nullable=True)
def __init__(self, task_id):
@@ -84,7 +84,7 @@ class TaskSet(ResultModelBase):
autoincrement=True, primary_key=True)
taskset_id = sa.Column(sa.String(155), unique=True)
result = sa.Column(PickleType, nullable=True)
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
nullable=True)
def __init__(self, taskset_id, result):
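
One caveat worth flagging with the new column defaults: datetime.utcnow was previously passed as a callable and evaluated on each insert or update, whereas datetime.now(timezone.utc) is called once at import time, so every row would share that module-load timestamp. A minimal sketch of a callable default that keeps per-row, timezone-aware timestamps, if that behaviour is the intent:

from datetime import datetime, timezone

import sqlalchemy as sa

# The lambda defers evaluation to each INSERT/UPDATE, mirroring the old
# datetime.utcnow behaviour while producing an aware UTC datetime.
date_done = sa.Column(
    sa.DateTime,
    default=lambda: datetime.now(timezone.utc),
    onupdate=lambda: datetime.now(timezone.utc),
    nullable=True,
)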


@@ -1,6 +1,8 @@
"""AWS DynamoDB result store backend."""
from collections import namedtuple
from ipaddress import ip_address
from time import sleep, time
from typing import Any, Dict
from kombu.utils.url import _parse_url as parse_url
@@ -54,11 +56,15 @@ class DynamoDBBackend(KeyValueStoreBackend):
supports_autoexpire = True
_key_field = DynamoDBAttribute(name='id', data_type='S')
# Each record has either a value field or count field
_value_field = DynamoDBAttribute(name='result', data_type='B')
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
_available_fields = None
implements_incr = True
def __init__(self, url=None, table_name=None, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -91,9 +97,9 @@ class DynamoDBBackend(KeyValueStoreBackend):
aws_credentials_given = access_key_given
if region == 'localhost':
if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
# We are using the downloadable, local version of DynamoDB
self.endpoint_url = f'http://localhost:{port}'
self.endpoint_url = f'http://{region}:{port}'
self.aws_region = 'us-east-1'
logger.warning(
'Using local-only DynamoDB endpoint URL: {}'.format(
@@ -148,6 +154,14 @@ class DynamoDBBackend(KeyValueStoreBackend):
secret_access_key=aws_secret_access_key
)
@staticmethod
def _is_valid_ip(ip):
try:
ip_address(ip)
return True
except ValueError:
return False
def _get_client(self, access_key_id=None, secret_access_key=None):
"""Get client connection."""
if self._client is None:
@@ -459,6 +473,40 @@ class DynamoDBBackend(KeyValueStoreBackend):
})
return put_request
def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter initialization request parameters"""
timestamp = time()
return {
'TableName': self.table_name,
'Item': {
self._key_field.name: {
self._key_field.data_type: key
},
self._count_filed.name: {
self._count_filed.data_type: "0"
},
self._timestamp_field.name: {
self._timestamp_field.data_type: str(timestamp)
}
}
}
def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter increment request parameters"""
return {
'TableName': self.table_name,
'Key': {
self._key_field.name: {
self._key_field.data_type: key
}
},
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
"ExpressionAttributeValues": {
":num": {"N": "1"},
},
"ReturnValues": "UPDATED_NEW",
}
def _item_to_dict(self, raw_response):
"""Convert get_item() response to field-value pairs."""
if 'Item' not in raw_response:
@@ -491,3 +539,18 @@ class DynamoDBBackend(KeyValueStoreBackend):
key = str(key)
request_parameters = self._prepare_get_request(key)
self.client.delete_item(**request_parameters)
def incr(self, key: bytes) -> int:
"""Atomically increase the chord_count and return the new count"""
key = str(key)
request_parameters = self._prepare_inc_count_request(key)
item_response = self.client.update_item(**request_parameters)
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
return int(new_count)
def _apply_chord_incr(self, header_result_args, body, **kwargs):
chord_key = self.get_key_for_chord(header_result_args[0])
init_count_request = self._prepare_init_count_request(str(chord_key))
self.client.put_item(**init_count_request)
return super()._apply_chord_incr(
header_result_args, body, **kwargs)
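
For context, a sketch of how the new counter machinery gets exercised (backend URL, table and tasks are illustrative): applying a chord first writes a chord_count item initialised to "0", and each completed header task issues the atomic UpdateItem increment built by _prepare_inc_count_request():

from celery import Celery, chord

app = Celery(
    'tasks',
    broker='redis://localhost:6379/0',  # placeholder broker
    backend='dynamodb://aws_access_key_id:aws_secret_access_key@us-east-1/celery_results',
)

@app.task
def add(x, y):
    return x + y

@app.task
def total(results):
    return sum(results)

# Each finished add() increments chord_count atomically; once the count
# reaches the header size, total() is called with the collected results.
result = chord(add.s(i, i) for i in range(10))(total.s())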


@@ -1,5 +1,5 @@
"""Elasticsearch result store backend."""
from datetime import datetime
from datetime import datetime, timezone
from kombu.utils.encoding import bytes_to_str
from kombu.utils.url import _parse_url
@@ -14,6 +14,11 @@ try:
except ImportError:
elasticsearch = None
try:
import elastic_transport
except ImportError:
elastic_transport = None
__all__ = ('ElasticsearchBackend',)
E_LIB_MISSING = """\
@@ -31,7 +36,7 @@ class ElasticsearchBackend(KeyValueStoreBackend):
"""
index = 'celery'
doc_type = 'backend'
doc_type = None
scheme = 'http'
host = 'localhost'
port = 9200
@@ -83,17 +88,17 @@ class ElasticsearchBackend(KeyValueStoreBackend):
self._server = None
def exception_safe_to_retry(self, exc):
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
if isinstance(exc, elasticsearch.exceptions.ApiError):
# 401: Unauthorized
# 409: Conflict
# 429: Too Many Requests
# 500: Internal Server Error
# 502: Bad Gateway
# 503: Service Unavailable
# 504: Gateway Timeout
# N/A: Low level exception (i.e. socket exception)
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}:
return True
if isinstance(exc, elasticsearch.exceptions.TransportError):
return True
return False
def get(self, key):
@@ -108,17 +113,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
pass
def _get(self, key):
return self.server.get(
index=self.index,
doc_type=self.doc_type,
id=key,
)
if self.doc_type:
return self.server.get(
index=self.index,
id=key,
doc_type=self.doc_type,
)
else:
return self.server.get(
index=self.index,
id=key,
)
def _set_with_state(self, key, value, state):
body = {
'result': value,
'@timestamp': '{}Z'.format(
datetime.utcnow().isoformat()[:-3]
datetime.now(timezone.utc).isoformat()[:-9]
),
}
try:
@@ -135,14 +146,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
def _index(self, id, body, **kwargs):
body = {bytes_to_str(k): v for k, v in body.items()}
return self.server.index(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body=body,
params={'op_type': 'create'},
**kwargs
)
if self.doc_type:
return self.server.index(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body=body,
params={'op_type': 'create'},
**kwargs
)
else:
return self.server.index(
id=bytes_to_str(id),
index=self.index,
body=body,
params={'op_type': 'create'},
**kwargs
)
def _update(self, id, body, state, **kwargs):
"""Update state in a conflict free manner.
@@ -182,19 +202,32 @@ class ElasticsearchBackend(KeyValueStoreBackend):
prim_term = res_get.get('_primary_term', 1)
# try to update document with current seq_no and primary_term
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
if self.doc_type:
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
doc_type=self.doc_type,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
else:
res = self.server.update(
id=bytes_to_str(id),
index=self.index,
body={'doc': body},
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
**kwargs
)
# result is elastic search update query result
# noop = query did not update any document
# updated = at least one document got updated
if res['result'] == 'noop':
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
raise elasticsearch.exceptions.ConflictError(
"conflicting update occurred concurrently",
elastic_transport.ApiResponseMeta(409, "HTTP/1.1",
elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig(
self.scheme, self.host, self.port)), None)
return res
def encode(self, data):
@@ -225,7 +258,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
return [self.get(key) for key in keys]
def delete(self, key):
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
if self.doc_type:
self.server.delete(index=self.index, id=key, doc_type=self.doc_type)
else:
self.server.delete(index=self.index, id=key)
def _get_server(self):
"""Connect to the Elasticsearch server."""
@@ -233,11 +269,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
if self.username and self.password:
http_auth = (self.username, self.password)
return elasticsearch.Elasticsearch(
f'{self.host}:{self.port}',
f'{self.scheme}://{self.host}:{self.port}',
retry_on_timeout=self.es_retry_on_timeout,
max_retries=self.es_max_retries,
timeout=self.es_timeout,
scheme=self.scheme,
http_auth=http_auth,
)
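
For illustration, a configuration sketch matching the client construction above, where the full '<scheme>://<host>:<port>' URL is passed to elasticsearch.Elasticsearch() instead of a separate scheme argument (node address and index name are placeholders):

from celery import Celery

app = Celery('tasks')
app.conf.update(
    result_backend='elasticsearch://localhost:9200/celery',  # placeholder node/index
    elasticsearch_retry_on_timeout=True,
    elasticsearch_max_retries=3,
    elasticsearch_timeout=10.0,
)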


@@ -50,7 +50,7 @@ class FilesystemBackend(KeyValueStoreBackend):
self.open = open
self.unlink = unlink
# Lets verify that we've everything setup right
# Let's verify that we've everything setup right
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
def __reduce__(self, args=(), kwargs=None):


@@ -0,0 +1,352 @@
"""Google Cloud Storage result store backend for Celery."""
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from os import getpid
from threading import RLock
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import dictfilter
from kombu.utils.url import url_to_parts
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.result import GroupResult, allow_join_result
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
import requests
from google.api_core import retry
from google.api_core.exceptions import Conflict
from google.api_core.retry import if_exception_type
from google.cloud import storage
from google.cloud.storage import Client
from google.cloud.storage.retry import DEFAULT_RETRY
except ImportError:
storage = None
try:
from google.cloud import firestore, firestore_admin_v1
except ImportError:
firestore = None
firestore_admin_v1 = None
__all__ = ('GCSBackend',)
logger = get_logger(__name__)
class GCSBackendBase(KeyValueStoreBackend):
"""Google Cloud Storage task result backend."""
def __init__(self, **kwargs):
if not storage:
raise ImproperlyConfigured(
'You must install google-cloud-storage to use gcs backend'
)
super().__init__(**kwargs)
self._client_lock = RLock()
self._pid = getpid()
self._retry_policy = DEFAULT_RETRY
self._client = None
conf = self.app.conf
if self.url:
url_params = self._params_from_url()
conf.update(**dictfilter(url_params))
self.bucket_name = conf.get('gcs_bucket')
if not self.bucket_name:
raise ImproperlyConfigured(
'Missing bucket name: specify gcs_bucket to use gcs backend'
)
self.project = conf.get('gcs_project')
if not self.project:
raise ImproperlyConfigured(
'Missing project: specify gcs_project to use gcs backend'
)
self.base_path = conf.get('gcs_base_path', '').strip('/')
self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
self.ttl = float(conf.get('gcs_ttl') or 0)
if self.ttl < 0:
raise ImproperlyConfigured(
f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
)
elif self.ttl:
if not self._is_bucket_lifecycle_rule_exists():
raise ImproperlyConfigured(
f'Missing lifecycle rule to use gcs backend with ttl on '
f'bucket: {self.bucket_name}'
)
def get(self, key):
key = bytes_to_str(key)
blob = self._get_blob(key)
try:
return blob.download_as_bytes(retry=self._retry_policy)
except storage.blob.NotFound:
return None
def set(self, key, value):
key = bytes_to_str(key)
blob = self._get_blob(key)
if self.ttl:
blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl)
blob.upload_from_string(value, retry=self._retry_policy)
def delete(self, key):
key = bytes_to_str(key)
blob = self._get_blob(key)
if blob.exists():
blob.delete(retry=self._retry_policy)
def mget(self, keys):
with ThreadPoolExecutor() as pool:
return list(pool.map(self.get, keys))
@property
def client(self):
"""Returns a storage client."""
# make sure it's thread-safe, as creating a new client is expensive
with self._client_lock:
if self._client and self._pid == getpid():
return self._client
# make sure each process gets its own connection after a fork
self._client = Client(project=self.project)
self._pid = getpid()
# config the number of connections to the server
adapter = requests.adapters.HTTPAdapter(
pool_connections=self._threadpool_maxsize,
pool_maxsize=self._threadpool_maxsize,
max_retries=3,
)
client_http = self._client._http
client_http.mount("https://", adapter)
client_http._auth_request.session.mount("https://", adapter)
return self._client
@property
def bucket(self):
return self.client.bucket(self.bucket_name)
def _get_blob(self, key):
key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
return self.bucket.blob(key_bucket_path)
def _is_bucket_lifecycle_rule_exists(self):
bucket = self.bucket
bucket.reload()
for rule in bucket.lifecycle_rules:
if rule['action']['type'] == 'Delete':
return True
return False
def _params_from_url(self):
url_parts = url_to_parts(self.url)
return {
'gcs_bucket': url_parts.hostname,
'gcs_base_path': url_parts.path,
**url_parts.query,
}
class GCSBackend(GCSBackendBase):
"""Google Cloud Storage task result backend.
Uses Firestore for chord ref count.
"""
implements_incr = True
supports_native_join = True
# Firestore parameters
_collection_name = 'celery'
_field_count = 'chord_count'
_field_expires = 'expires_at'
def __init__(self, **kwargs):
if not (firestore and firestore_admin_v1):
raise ImproperlyConfigured(
'You must install google-cloud-firestore to use gcs backend'
)
super().__init__(**kwargs)
self._firestore_lock = RLock()
self._firestore_client = None
self.firestore_project = self.app.conf.get(
'firestore_project', self.project
)
if not self._is_firestore_ttl_policy_enabled():
raise ImproperlyConfigured(
f'Missing TTL policy to use gcs backend with ttl on '
f'Firestore collection: {self._collection_name} '
f'project: {self.firestore_project}'
)
@property
def firestore_client(self):
"""Returns a firestore client."""
# make sure it's thread-safe, as creating a new client is expensive
with self._firestore_lock:
if self._firestore_client and self._pid == getpid():
return self._firestore_client
# make sure each process gets its own connection after a fork
self._firestore_client = firestore.Client(
project=self.firestore_project
)
self._pid = getpid()
return self._firestore_client
def _is_firestore_ttl_policy_enabled(self):
client = firestore_admin_v1.FirestoreAdminClient()
name = (
f"projects/{self.firestore_project}"
f"/databases/(default)/collectionGroups/{self._collection_name}"
f"/fields/{self._field_expires}"
)
request = firestore_admin_v1.GetFieldRequest(name=name)
field = client.get_field(request=request)
ttl_config = field.ttl_config
return ttl_config and ttl_config.state in {
firestore_admin_v1.Field.TtlConfig.State.ACTIVE,
firestore_admin_v1.Field.TtlConfig.State.CREATING,
}
def _apply_chord_incr(self, header_result_args, body, **kwargs):
key = self.get_key_for_chord(header_result_args[0]).decode()
self._expire_chord_key(key, 86400)
return super()._apply_chord_incr(header_result_args, body, **kwargs)
def incr(self, key: bytes) -> int:
doc = self._firestore_document(key)
resp = doc.set(
{self._field_count: firestore.Increment(1)},
merge=True,
retry=retry.Retry(
predicate=if_exception_type(Conflict),
initial=1.0,
maximum=180.0,
multiplier=2.0,
timeout=180.0,
),
)
return resp.transform_results[0].integer_value
def on_chord_part_return(self, request, state, result, **kwargs):
"""Chord part return callback.
Called for each task in the chord.
Increments the counter stored in Firestore.
If the counter reaches the number of tasks in the chord, the callback
is called.
If the callback raises an exception, the chord is marked as errored.
If the callback returns a value, the chord is marked as successful.
"""
app = self.app
gid = request.group
if not gid:
return
key = self.get_key_for_chord(gid)
val = self.incr(key)
size = request.chord.get("chord_size")
if size is None:
deps = self._restore_deps(gid, request)
if deps is None:
return
size = len(deps)
if val > size: # pragma: no cover
logger.warning(
'Chord counter incremented too many times for %r', gid
)
elif val == size:
# Read the deps once, to reduce the number of reads from GCS ($$)
deps = self._restore_deps(gid, request)
if deps is None:
return
callback = maybe_signature(request.chord, app=app)
j = deps.join_native
try:
with allow_join_result():
ret = j(
timeout=app.conf.result_chord_join_timeout,
propagate=True,
)
except Exception as exc: # pylint: disable=broad-except
try:
culprit = next(deps._failed_join_report())
reason = 'Dependency {0.id} raised {1!r}'.format(
culprit,
exc,
)
except StopIteration:
reason = repr(exc)
logger.exception('Chord %r raised: %r', gid, reason)
self.chord_error_from_stack(callback, ChordError(reason))
else:
try:
callback.delay(ret)
except Exception as exc: # pylint: disable=broad-except
logger.exception('Chord %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'Callback error: {exc!r}'),
)
finally:
deps.delete()
# Firestore doesn't have an exact ttl policy, so delete the key.
self._delete_chord_key(key)
def _restore_deps(self, gid, request):
app = self.app
try:
deps = GroupResult.restore(gid, backend=self)
except Exception as exc: # pylint: disable=broad-except
callback = maybe_signature(request.chord, app=app)
logger.exception('Chord %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'Cannot restore group: {exc!r}'),
)
return
if deps is None:
try:
raise ValueError(gid)
except ValueError as exc:
callback = maybe_signature(request.chord, app=app)
logger.exception('Chord callback %r raised: %r', gid, exc)
self.chord_error_from_stack(
callback,
ChordError(f'GroupResult {gid} no longer exists'),
)
return deps
def _delete_chord_key(self, key):
doc = self._firestore_document(key)
doc.delete()
def _expire_chord_key(self, key, expires):
"""Set TTL policy for a Firestore document.
Firestore ttl data is typically deleted within 24 hours after its
expiration date.
"""
val_expires = datetime.utcnow() + timedelta(seconds=expires)
doc = self._firestore_document(key)
doc.set({self._field_expires: val_expires}, merge=True)
def _firestore_document(self, key):
return self.firestore_client.collection(
self._collection_name
).document(bytes_to_str(key))
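
For illustration, a configuration sketch for the new backend (bucket, project and paths are placeholders), assuming the 'gs://' result-backend alias; per the checks above, a non-zero gcs_ttl requires a Delete lifecycle rule on the bucket, and chord counting requires a Firestore TTL policy on the expires_at field of the celery collection:

from celery import Celery

app = Celery('tasks')
app.conf.update(
    # Bucket and base path come from the URL (see _params_from_url above).
    result_backend='gs://my-results-bucket/celery-results',
    gcs_project='my-gcp-project',        # required
    gcs_ttl=86400,                       # 0 disables TTL handling
    gcs_threadpool_maxsize=10,           # sizes the mget() thread pool / HTTP pool
    firestore_project='my-gcp-project',  # chord counters; defaults to gcs_project
)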


@@ -1,5 +1,5 @@
"""MongoDB result store backend."""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from kombu.exceptions import EncodeError
from kombu.utils.objects import cached_property
@@ -228,7 +228,7 @@ class MongoBackend(BaseBackend):
meta = {
'_id': group_id,
'result': self.encode([i.id for i in result]),
'date_done': datetime.utcnow(),
'date_done': datetime.now(timezone.utc),
}
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
return result


@@ -359,6 +359,11 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
connparams.update(query)
return connparams
def exception_safe_to_retry(self, exc):
if isinstance(exc, self.connection_errors):
return True
return False
@cached_property
def retry_policy(self):
retry_policy = super().retry_policy
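
A short sketch of what the new exception_safe_to_retry hook enables, assuming the base backend's result_backend_always_retry machinery (which consults this hook); the settings below are illustrative:

from celery import Celery

app = Celery('tasks')
app.conf.update(
    result_backend='redis://localhost:6379/1',
    # With always_retry enabled, store/get operations that raise one of the
    # backend's connection_errors (now reported as safe to retry above) are
    # retried instead of failing immediately.
    result_backend_always_retry=True,
    result_backend_max_retries=10,
)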


@@ -222,7 +222,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
def on_out_of_band_result(self, task_id, message):
# Callback called when a reply for a task is received,
# but we have no idea what do do with it.
# but we have no idea what to do with it.
# Since the result is not pending, we put it in a separate
# buffer: probably it will become pending later.
if self.result_consumer: