This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""The Azure Storage Block Blob backend for Celery."""
|
||||
from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport
|
||||
from kombu.utils import cached_property
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
|
||||
@@ -28,6 +29,13 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
|
||||
container_name=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""
|
||||
Supported URL formats:
|
||||
|
||||
azureblockblob://CONNECTION_STRING
|
||||
azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL
|
||||
azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if azurestorage is None or azurestorage.__version__ < '12':
|
||||
@@ -65,11 +73,26 @@ class AzureBlockBlobBackend(KeyValueStoreBackend):
|
||||
the container is created if it doesn't yet exist.
|
||||
|
||||
"""
|
||||
client = BlobServiceClient.from_connection_string(
|
||||
self._connection_string,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout
|
||||
)
|
||||
if (
|
||||
"DefaultAzureCredential" in self._connection_string or
|
||||
"ManagedIdentityCredential" in self._connection_string
|
||||
):
|
||||
# Leveraging the work that Kombu already did for us
|
||||
credential_, url = AzureStorageQueuesTransport.parse_uri(
|
||||
self._connection_string
|
||||
)
|
||||
client = BlobServiceClient(
|
||||
account_url=url,
|
||||
credential=credential_,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout,
|
||||
)
|
||||
else:
|
||||
client = BlobServiceClient.from_connection_string(
|
||||
self._connection_string,
|
||||
connection_timeout=self._connection_timeout,
|
||||
read_timeout=self._read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
client.create_container(name=self._container_name)
|
||||
|
||||
@@ -9,7 +9,7 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
@@ -460,7 +460,7 @@ class Backend:
|
||||
state, traceback, request, format_date=True,
|
||||
encode=False):
|
||||
if state in self.READY_STATES:
|
||||
date_done = datetime.utcnow()
|
||||
date_done = self.app.now()
|
||||
if format_date:
|
||||
date_done = date_done.isoformat()
|
||||
else:
|
||||
@@ -833,9 +833,11 @@ class BaseKeyValueStoreBackend(Backend):
|
||||
"""
|
||||
global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None)
|
||||
if global_keyprefix:
|
||||
self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}"
|
||||
self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}"
|
||||
self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}"
|
||||
if global_keyprefix[-1] not in ':_-.':
|
||||
global_keyprefix += '_'
|
||||
self.task_keyprefix = f"{global_keyprefix}{self.task_keyprefix}"
|
||||
self.group_keyprefix = f"{global_keyprefix}{self.group_keyprefix}"
|
||||
self.chord_keyprefix = f"{global_keyprefix}{self.chord_keyprefix}"
|
||||
|
||||
def _encode_prefixes(self):
|
||||
self.task_keyprefix = self.key_t(self.task_keyprefix)
|
||||
@@ -1080,7 +1082,7 @@ class BaseKeyValueStoreBackend(Backend):
|
||||
)
|
||||
finally:
|
||||
deps.delete()
|
||||
self.client.delete(key)
|
||||
self.delete(key)
|
||||
else:
|
||||
self.expire(key, self.expires)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class CassandraBackend(BaseBackend):
|
||||
supports_autoexpire = True # autoexpire supported via entry_ttl
|
||||
|
||||
def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None,
|
||||
port=9042, bundle_path=None, **kwargs):
|
||||
port=None, bundle_path=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not cassandra:
|
||||
@@ -96,7 +96,7 @@ class CassandraBackend(BaseBackend):
|
||||
self.servers = servers or conf.get('cassandra_servers', None)
|
||||
self.bundle_path = bundle_path or conf.get(
|
||||
'cassandra_secure_bundle_path', None)
|
||||
self.port = port or conf.get('cassandra_port', None)
|
||||
self.port = port or conf.get('cassandra_port', None) or 9042
|
||||
self.keyspace = keyspace or conf.get('cassandra_keyspace', None)
|
||||
self.table = table or conf.get('cassandra_table', None)
|
||||
self.cassandra_options = conf.get('cassandra_options', {})
|
||||
|
||||
@@ -98,11 +98,23 @@ class DatabaseBackend(BaseBackend):
|
||||
'Missing connection string! Do you have the'
|
||||
' database_url setting set to a real value?')
|
||||
|
||||
self.session_manager = SessionManager()
|
||||
|
||||
create_tables_at_setup = conf.database_create_tables_at_setup
|
||||
if create_tables_at_setup is True:
|
||||
self._create_tables()
|
||||
|
||||
@property
|
||||
def extended_result(self):
|
||||
return self.app.conf.find_value_for_key('extended', 'result')
|
||||
|
||||
def ResultSession(self, session_manager=SessionManager()):
|
||||
def _create_tables(self):
|
||||
"""Create the task and taskset tables."""
|
||||
self.ResultSession()
|
||||
|
||||
def ResultSession(self, session_manager=None):
|
||||
if session_manager is None:
|
||||
session_manager = self.session_manager
|
||||
return session_manager.session_factory(
|
||||
dburi=self.url,
|
||||
short_lived_sessions=self.short_lived_sessions,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Database models used by the SQLAlchemy result store backend."""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.types import PickleType
|
||||
@@ -22,8 +22,8 @@ class Task(ResultModelBase):
|
||||
task_id = sa.Column(sa.String(155), unique=True)
|
||||
status = sa.Column(sa.String(50), default=states.PENDING)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
onupdate=datetime.now(timezone.utc), nullable=True)
|
||||
traceback = sa.Column(sa.Text, nullable=True)
|
||||
|
||||
def __init__(self, task_id):
|
||||
@@ -84,7 +84,7 @@ class TaskSet(ResultModelBase):
|
||||
autoincrement=True, primary_key=True)
|
||||
taskset_id = sa.Column(sa.String(155), unique=True)
|
||||
result = sa.Column(PickleType, nullable=True)
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.utcnow,
|
||||
date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc),
|
||||
nullable=True)
|
||||
|
||||
def __init__(self, taskset_id, result):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""AWS DynamoDB result store backend."""
|
||||
from collections import namedtuple
|
||||
from ipaddress import ip_address
|
||||
from time import sleep, time
|
||||
from typing import Any, Dict
|
||||
|
||||
from kombu.utils.url import _parse_url as parse_url
|
||||
|
||||
@@ -54,11 +56,15 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
supports_autoexpire = True
|
||||
|
||||
_key_field = DynamoDBAttribute(name='id', data_type='S')
|
||||
# Each record has either a value field or count field
|
||||
_value_field = DynamoDBAttribute(name='result', data_type='B')
|
||||
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
|
||||
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
|
||||
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
|
||||
_available_fields = None
|
||||
|
||||
implements_incr = True
|
||||
|
||||
def __init__(self, url=None, table_name=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -91,9 +97,9 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
|
||||
aws_credentials_given = access_key_given
|
||||
|
||||
if region == 'localhost':
|
||||
if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
|
||||
# We are using the downloadable, local version of DynamoDB
|
||||
self.endpoint_url = f'http://localhost:{port}'
|
||||
self.endpoint_url = f'http://{region}:{port}'
|
||||
self.aws_region = 'us-east-1'
|
||||
logger.warning(
|
||||
'Using local-only DynamoDB endpoint URL: {}'.format(
|
||||
@@ -148,6 +154,14 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
secret_access_key=aws_secret_access_key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_ip(ip):
|
||||
try:
|
||||
ip_address(ip)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _get_client(self, access_key_id=None, secret_access_key=None):
|
||||
"""Get client connection."""
|
||||
if self._client is None:
|
||||
@@ -459,6 +473,40 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
})
|
||||
return put_request
|
||||
|
||||
def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter initialization request parameters"""
|
||||
timestamp = time()
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Item': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
},
|
||||
self._count_filed.name: {
|
||||
self._count_filed.data_type: "0"
|
||||
},
|
||||
self._timestamp_field.name: {
|
||||
self._timestamp_field.data_type: str(timestamp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
|
||||
"""Construct the counter increment request parameters"""
|
||||
return {
|
||||
'TableName': self.table_name,
|
||||
'Key': {
|
||||
self._key_field.name: {
|
||||
self._key_field.data_type: key
|
||||
}
|
||||
},
|
||||
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
|
||||
"ExpressionAttributeValues": {
|
||||
":num": {"N": "1"},
|
||||
},
|
||||
"ReturnValues": "UPDATED_NEW",
|
||||
}
|
||||
|
||||
def _item_to_dict(self, raw_response):
|
||||
"""Convert get_item() response to field-value pairs."""
|
||||
if 'Item' not in raw_response:
|
||||
@@ -491,3 +539,18 @@ class DynamoDBBackend(KeyValueStoreBackend):
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_get_request(key)
|
||||
self.client.delete_item(**request_parameters)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
"""Atomically increase the chord_count and return the new count"""
|
||||
key = str(key)
|
||||
request_parameters = self._prepare_inc_count_request(key)
|
||||
item_response = self.client.update_item(**request_parameters)
|
||||
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
|
||||
return int(new_count)
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
chord_key = self.get_key_for_chord(header_result_args[0])
|
||||
init_count_request = self._prepare_init_count_request(str(chord_key))
|
||||
self.client.put_item(**init_count_request)
|
||||
return super()._apply_chord_incr(
|
||||
header_result_args, body, **kwargs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Elasticsearch result store backend."""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.url import _parse_url
|
||||
@@ -14,6 +14,11 @@ try:
|
||||
except ImportError:
|
||||
elasticsearch = None
|
||||
|
||||
try:
|
||||
import elastic_transport
|
||||
except ImportError:
|
||||
elastic_transport = None
|
||||
|
||||
__all__ = ('ElasticsearchBackend',)
|
||||
|
||||
E_LIB_MISSING = """\
|
||||
@@ -31,7 +36,7 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
"""
|
||||
|
||||
index = 'celery'
|
||||
doc_type = 'backend'
|
||||
doc_type = None
|
||||
scheme = 'http'
|
||||
host = 'localhost'
|
||||
port = 9200
|
||||
@@ -83,17 +88,17 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
self._server = None
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, (elasticsearch.exceptions.TransportError)):
|
||||
if isinstance(exc, elasticsearch.exceptions.ApiError):
|
||||
# 401: Unauthorized
|
||||
# 409: Conflict
|
||||
# 429: Too Many Requests
|
||||
# 500: Internal Server Error
|
||||
# 502: Bad Gateway
|
||||
# 503: Service Unavailable
|
||||
# 504: Gateway Timeout
|
||||
# N/A: Low level exception (i.e. socket exception)
|
||||
if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}:
|
||||
if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}:
|
||||
return True
|
||||
if isinstance(exc, elasticsearch.exceptions.TransportError):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, key):
|
||||
@@ -108,17 +113,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
pass
|
||||
|
||||
def _get(self, key):
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
id=key,
|
||||
)
|
||||
if self.doc_type:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
doc_type=self.doc_type,
|
||||
)
|
||||
else:
|
||||
return self.server.get(
|
||||
index=self.index,
|
||||
id=key,
|
||||
)
|
||||
|
||||
def _set_with_state(self, key, value, state):
|
||||
body = {
|
||||
'result': value,
|
||||
'@timestamp': '{}Z'.format(
|
||||
datetime.utcnow().isoformat()[:-3]
|
||||
datetime.now(timezone.utc).isoformat()[:-9]
|
||||
),
|
||||
}
|
||||
try:
|
||||
@@ -135,14 +146,23 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
|
||||
def _index(self, id, body, **kwargs):
|
||||
body = {bytes_to_str(k): v for k, v in body.items()}
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
if self.doc_type:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
return self.server.index(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body=body,
|
||||
params={'op_type': 'create'},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _update(self, id, body, state, **kwargs):
|
||||
"""Update state in a conflict free manner.
|
||||
@@ -182,19 +202,32 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
prim_term = res_get.get('_primary_term', 1)
|
||||
|
||||
# try to update document with current seq_no and primary_term
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
if self.doc_type:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
doc_type=self.doc_type,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
res = self.server.update(
|
||||
id=bytes_to_str(id),
|
||||
index=self.index,
|
||||
body={'doc': body},
|
||||
params={'if_primary_term': prim_term, 'if_seq_no': seq_no},
|
||||
**kwargs
|
||||
)
|
||||
# result is elastic search update query result
|
||||
# noop = query did not update any document
|
||||
# updated = at least one document got updated
|
||||
if res['result'] == 'noop':
|
||||
raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {})
|
||||
raise elasticsearch.exceptions.ConflictError(
|
||||
"conflicting update occurred concurrently",
|
||||
elastic_transport.ApiResponseMeta(409, "HTTP/1.1",
|
||||
elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig(
|
||||
self.scheme, self.host, self.port)), None)
|
||||
return res
|
||||
|
||||
def encode(self, data):
|
||||
@@ -225,7 +258,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def delete(self, key):
|
||||
self.server.delete(index=self.index, doc_type=self.doc_type, id=key)
|
||||
if self.doc_type:
|
||||
self.server.delete(index=self.index, id=key, doc_type=self.doc_type)
|
||||
else:
|
||||
self.server.delete(index=self.index, id=key)
|
||||
|
||||
def _get_server(self):
|
||||
"""Connect to the Elasticsearch server."""
|
||||
@@ -233,11 +269,10 @@ class ElasticsearchBackend(KeyValueStoreBackend):
|
||||
if self.username and self.password:
|
||||
http_auth = (self.username, self.password)
|
||||
return elasticsearch.Elasticsearch(
|
||||
f'{self.host}:{self.port}',
|
||||
f'{self.scheme}://{self.host}:{self.port}',
|
||||
retry_on_timeout=self.es_retry_on_timeout,
|
||||
max_retries=self.es_max_retries,
|
||||
timeout=self.es_timeout,
|
||||
scheme=self.scheme,
|
||||
http_auth=http_auth,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class FilesystemBackend(KeyValueStoreBackend):
|
||||
self.open = open
|
||||
self.unlink = unlink
|
||||
|
||||
# Lets verify that we've everything setup right
|
||||
# Let's verify that we've everything setup right
|
||||
self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding))
|
||||
|
||||
def __reduce__(self, args=(), kwargs=None):
|
||||
|
||||
352
venv/lib/python3.12/site-packages/celery/backends/gcs.py
Normal file
352
venv/lib/python3.12/site-packages/celery/backends/gcs.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Google Cloud Storage result store backend for Celery."""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from os import getpid
|
||||
from threading import RLock
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str
|
||||
from kombu.utils.functional import dictfilter
|
||||
from kombu.utils.url import url_to_parts
|
||||
|
||||
from celery.canvas import maybe_signature
|
||||
from celery.exceptions import ChordError, ImproperlyConfigured
|
||||
from celery.result import GroupResult, allow_join_result
|
||||
from celery.utils.log import get_logger
|
||||
|
||||
from .base import KeyValueStoreBackend
|
||||
|
||||
try:
|
||||
import requests
|
||||
from google.api_core import retry
|
||||
from google.api_core.exceptions import Conflict
|
||||
from google.api_core.retry import if_exception_type
|
||||
from google.cloud import storage
|
||||
from google.cloud.storage import Client
|
||||
from google.cloud.storage.retry import DEFAULT_RETRY
|
||||
except ImportError:
|
||||
storage = None
|
||||
|
||||
try:
|
||||
from google.cloud import firestore, firestore_admin_v1
|
||||
except ImportError:
|
||||
firestore = None
|
||||
firestore_admin_v1 = None
|
||||
|
||||
|
||||
__all__ = ('GCSBackend',)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GCSBackendBase(KeyValueStoreBackend):
|
||||
"""Google Cloud Storage task result backend."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not storage:
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-storage to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self._client_lock = RLock()
|
||||
self._pid = getpid()
|
||||
self._retry_policy = DEFAULT_RETRY
|
||||
self._client = None
|
||||
|
||||
conf = self.app.conf
|
||||
if self.url:
|
||||
url_params = self._params_from_url()
|
||||
conf.update(**dictfilter(url_params))
|
||||
|
||||
self.bucket_name = conf.get('gcs_bucket')
|
||||
if not self.bucket_name:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing bucket name: specify gcs_bucket to use gcs backend'
|
||||
)
|
||||
self.project = conf.get('gcs_project')
|
||||
if not self.project:
|
||||
raise ImproperlyConfigured(
|
||||
'Missing project:specify gcs_project to use gcs backend'
|
||||
)
|
||||
self.base_path = conf.get('gcs_base_path', '').strip('/')
|
||||
self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
|
||||
self.ttl = float(conf.get('gcs_ttl') or 0)
|
||||
if self.ttl < 0:
|
||||
raise ImproperlyConfigured(
|
||||
f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
|
||||
)
|
||||
elif self.ttl:
|
||||
if not self._is_bucket_lifecycle_rule_exists():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing lifecycle rule to use gcs backend with ttl on '
|
||||
f'bucket: {self.bucket_name}'
|
||||
)
|
||||
|
||||
def get(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
try:
|
||||
return blob.download_as_bytes(retry=self._retry_policy)
|
||||
except storage.blob.NotFound:
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if self.ttl:
|
||||
blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl)
|
||||
blob.upload_from_string(value, retry=self._retry_policy)
|
||||
|
||||
def delete(self, key):
|
||||
key = bytes_to_str(key)
|
||||
blob = self._get_blob(key)
|
||||
if blob.exists():
|
||||
blob.delete(retry=self._retry_policy)
|
||||
|
||||
def mget(self, keys):
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return list(pool.map(self.get, keys))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Returns a storage client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._client_lock:
|
||||
if self._client and self._pid == getpid():
|
||||
return self._client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._client = Client(project=self.project)
|
||||
self._pid = getpid()
|
||||
|
||||
# config the number of connections to the server
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=self._threadpool_maxsize,
|
||||
pool_maxsize=self._threadpool_maxsize,
|
||||
max_retries=3,
|
||||
)
|
||||
client_http = self._client._http
|
||||
client_http.mount("https://", adapter)
|
||||
client_http._auth_request.session.mount("https://", adapter)
|
||||
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return self.client.bucket(self.bucket_name)
|
||||
|
||||
def _get_blob(self, key):
|
||||
key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
|
||||
return self.bucket.blob(key_bucket_path)
|
||||
|
||||
def _is_bucket_lifecycle_rule_exists(self):
|
||||
bucket = self.bucket
|
||||
bucket.reload()
|
||||
for rule in bucket.lifecycle_rules:
|
||||
if rule['action']['type'] == 'Delete':
|
||||
return True
|
||||
return False
|
||||
|
||||
def _params_from_url(self):
|
||||
url_parts = url_to_parts(self.url)
|
||||
|
||||
return {
|
||||
'gcs_bucket': url_parts.hostname,
|
||||
'gcs_base_path': url_parts.path,
|
||||
**url_parts.query,
|
||||
}
|
||||
|
||||
|
||||
class GCSBackend(GCSBackendBase):
|
||||
"""Google Cloud Storage task result backend.
|
||||
|
||||
Uses Firestore for chord ref count.
|
||||
"""
|
||||
|
||||
implements_incr = True
|
||||
supports_native_join = True
|
||||
|
||||
# Firestore parameters
|
||||
_collection_name = 'celery'
|
||||
_field_count = 'chord_count'
|
||||
_field_expires = 'expires_at'
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not (firestore and firestore_admin_v1):
|
||||
raise ImproperlyConfigured(
|
||||
'You must install google-cloud-firestore to use gcs backend'
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._firestore_lock = RLock()
|
||||
self._firestore_client = None
|
||||
|
||||
self.firestore_project = self.app.conf.get(
|
||||
'firestore_project', self.project
|
||||
)
|
||||
if not self._is_firestore_ttl_policy_enabled():
|
||||
raise ImproperlyConfigured(
|
||||
f'Missing TTL policy to use gcs backend with ttl on '
|
||||
f'Firestore collection: {self._collection_name} '
|
||||
f'project: {self.firestore_project}'
|
||||
)
|
||||
|
||||
@property
|
||||
def firestore_client(self):
|
||||
"""Returns a firestore client."""
|
||||
|
||||
# make sure it's thread-safe, as creating a new client is expensive
|
||||
with self._firestore_lock:
|
||||
if self._firestore_client and self._pid == getpid():
|
||||
return self._firestore_client
|
||||
# make sure each process gets its own connection after a fork
|
||||
self._firestore_client = firestore.Client(
|
||||
project=self.firestore_project
|
||||
)
|
||||
self._pid = getpid()
|
||||
return self._firestore_client
|
||||
|
||||
def _is_firestore_ttl_policy_enabled(self):
|
||||
client = firestore_admin_v1.FirestoreAdminClient()
|
||||
|
||||
name = (
|
||||
f"projects/{self.firestore_project}"
|
||||
f"/databases/(default)/collectionGroups/{self._collection_name}"
|
||||
f"/fields/{self._field_expires}"
|
||||
)
|
||||
request = firestore_admin_v1.GetFieldRequest(name=name)
|
||||
field = client.get_field(request=request)
|
||||
|
||||
ttl_config = field.ttl_config
|
||||
return ttl_config and ttl_config.state in {
|
||||
firestore_admin_v1.Field.TtlConfig.State.ACTIVE,
|
||||
firestore_admin_v1.Field.TtlConfig.State.CREATING,
|
||||
}
|
||||
|
||||
def _apply_chord_incr(self, header_result_args, body, **kwargs):
|
||||
key = self.get_key_for_chord(header_result_args[0]).decode()
|
||||
self._expire_chord_key(key, 86400)
|
||||
return super()._apply_chord_incr(header_result_args, body, **kwargs)
|
||||
|
||||
def incr(self, key: bytes) -> int:
|
||||
doc = self._firestore_document(key)
|
||||
resp = doc.set(
|
||||
{self._field_count: firestore.Increment(1)},
|
||||
merge=True,
|
||||
retry=retry.Retry(
|
||||
predicate=if_exception_type(Conflict),
|
||||
initial=1.0,
|
||||
maximum=180.0,
|
||||
multiplier=2.0,
|
||||
timeout=180.0,
|
||||
),
|
||||
)
|
||||
return resp.transform_results[0].integer_value
|
||||
|
||||
def on_chord_part_return(self, request, state, result, **kwargs):
|
||||
"""Chord part return callback.
|
||||
|
||||
Called for each task in the chord.
|
||||
Increments the counter stored in Firestore.
|
||||
If the counter reaches the number of tasks in the chord, the callback
|
||||
is called.
|
||||
If the callback raises an exception, the chord is marked as errored.
|
||||
If the callback returns a value, the chord is marked as successful.
|
||||
"""
|
||||
app = self.app
|
||||
gid = request.group
|
||||
if not gid:
|
||||
return
|
||||
key = self.get_key_for_chord(gid)
|
||||
val = self.incr(key)
|
||||
size = request.chord.get("chord_size")
|
||||
if size is None:
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
size = len(deps)
|
||||
if val > size: # pragma: no cover
|
||||
logger.warning(
|
||||
'Chord counter incremented too many times for %r', gid
|
||||
)
|
||||
elif val == size:
|
||||
# Read the deps once, to reduce the number of reads from GCS ($$)
|
||||
deps = self._restore_deps(gid, request)
|
||||
if deps is None:
|
||||
return
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
j = deps.join_native
|
||||
try:
|
||||
with allow_join_result():
|
||||
ret = j(
|
||||
timeout=app.conf.result_chord_join_timeout,
|
||||
propagate=True,
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
try:
|
||||
culprit = next(deps._failed_join_report())
|
||||
reason = 'Dependency {0.id} raised {1!r}'.format(
|
||||
culprit,
|
||||
exc,
|
||||
)
|
||||
except StopIteration:
|
||||
reason = repr(exc)
|
||||
|
||||
logger.exception('Chord %r raised: %r', gid, reason)
|
||||
self.chord_error_from_stack(callback, ChordError(reason))
|
||||
else:
|
||||
try:
|
||||
callback.delay(ret)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Callback error: {exc!r}'),
|
||||
)
|
||||
finally:
|
||||
deps.delete()
|
||||
# Firestore doesn't have an exact ttl policy, so delete the key.
|
||||
self._delete_chord_key(key)
|
||||
|
||||
def _restore_deps(self, gid, request):
|
||||
app = self.app
|
||||
try:
|
||||
deps = GroupResult.restore(gid, backend=self)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'Cannot restore group: {exc!r}'),
|
||||
)
|
||||
return
|
||||
if deps is None:
|
||||
try:
|
||||
raise ValueError(gid)
|
||||
except ValueError as exc:
|
||||
callback = maybe_signature(request.chord, app=app)
|
||||
logger.exception('Chord callback %r raised: %r', gid, exc)
|
||||
self.chord_error_from_stack(
|
||||
callback,
|
||||
ChordError(f'GroupResult {gid} no longer exists'),
|
||||
)
|
||||
return deps
|
||||
|
||||
def _delete_chord_key(self, key):
|
||||
doc = self._firestore_document(key)
|
||||
doc.delete()
|
||||
|
||||
def _expire_chord_key(self, key, expires):
|
||||
"""Set TTL policy for a Firestore document.
|
||||
|
||||
Firestore ttl data is typically deleted within 24 hours after its
|
||||
expiration date.
|
||||
"""
|
||||
val_expires = datetime.utcnow() + timedelta(seconds=expires)
|
||||
doc = self._firestore_document(key)
|
||||
doc.set({self._field_expires: val_expires}, merge=True)
|
||||
|
||||
def _firestore_document(self, key):
|
||||
return self.firestore_client.collection(
|
||||
self._collection_name
|
||||
).document(bytes_to_str(key))
|
||||
@@ -1,5 +1,5 @@
|
||||
"""MongoDB result store backend."""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from kombu.exceptions import EncodeError
|
||||
from kombu.utils.objects import cached_property
|
||||
@@ -228,7 +228,7 @@ class MongoBackend(BaseBackend):
|
||||
meta = {
|
||||
'_id': group_id,
|
||||
'result': self.encode([i.id for i in result]),
|
||||
'date_done': datetime.utcnow(),
|
||||
'date_done': datetime.now(timezone.utc),
|
||||
}
|
||||
self.group_collection.replace_one({'_id': group_id}, meta, upsert=True)
|
||||
return result
|
||||
|
||||
@@ -359,6 +359,11 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
|
||||
connparams.update(query)
|
||||
return connparams
|
||||
|
||||
def exception_safe_to_retry(self, exc):
|
||||
if isinstance(exc, self.connection_errors):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def retry_policy(self):
|
||||
retry_policy = super().retry_policy
|
||||
|
||||
@@ -222,7 +222,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
|
||||
|
||||
def on_out_of_band_result(self, task_id, message):
|
||||
# Callback called when a reply for a task is received,
|
||||
# but we have no idea what do do with it.
|
||||
# but we have no idea what to do with it.
|
||||
# Since the result is not pending, we put it in a separate
|
||||
# buffer: probably it will become pending later.
|
||||
if self.result_consumer:
|
||||
|
||||
Reference in New Issue
Block a user