Major fixes and new features
venv/lib/python3.12/site-packages/kombu/transport/SLMQ.py
@@ -0,0 +1,202 @@
"""SoftLayer Message Queue transport module for kombu.

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No

Connection String
=================
*Unreviewed*

Transport Options
=================
*Unreviewed*
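
The channel reads its settings from ``transport_options`` and, as a
fallback, from the environment variables used in the code below
(``SLMQ_ACCOUNT``, ``SL_USERNAME``, ``SL_API_KEY``, ``SLMQ_HOST``,
``SLMQ_PORT``, ``SLMQ_SECURE``).  A minimal sketch (the URL and option
values are placeholders; this section is otherwise *Unreviewed*):

.. code-block:: python

    from kombu import Connection

    conn = Connection('slmq://user:apikey@host', transport_options={
        'visibility_timeout': 600,      # seconds a fetched message is hidden
        'queue_name_prefix': 'myapp_',  # prepended to every queue name
    })
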
"""

from __future__ import annotations

import os
import socket
import string
from queue import Empty

from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property

from . import virtual

try:
    from softlayer_messaging import get_client
    from softlayer_messaging.errors import ResponseError
except ImportError:  # pragma: no cover
    get_client = ResponseError = None

# All punctuation except underscore is replaced by an underscore.
CHARS_REPLACE_TABLE = {
    ord(c): 0x5f for c in string.punctuation if c not in '_'
}
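
# For example (illustrative): entity_name('celery.task#queue') yields
# 'celery_task_queue' under this table.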


class Channel(virtual.Channel):
    """SLMQ Channel."""

    default_visibility_timeout = 1800  # 30 minutes.
    domain_format = 'kombu%(vhost)s'
    _slmq = None
    _queue_cache = {}
    _noack_queues = set()

    def __init__(self, *args, **kwargs):
        if get_client is None:
            raise ImportError(
                'SLMQ transport requires the softlayer_messaging library',
            )
        super().__init__(*args, **kwargs)
        queues = self.slmq.queues()
        for queue in queues:
            self._queue_cache[queue] = queue

    def basic_consume(self, queue, no_ack, *args, **kwargs):
        if no_ack:
            self._noack_queues.add(queue)
        return super().basic_consume(queue, no_ack,
                                     *args, **kwargs)

    def basic_cancel(self, consumer_tag):
        if consumer_tag in self._consumers:
            queue = self._tag_to_queue[consumer_tag]
            self._noack_queues.discard(queue)
        return super().basic_cancel(consumer_tag)

    def entity_name(self, name, table=CHARS_REPLACE_TABLE):
        """Format AMQP queue name into a valid SLMQ queue name."""
        return str(safe_str(name)).translate(table)

    def _new_queue(self, queue, **kwargs):
        """Ensure a queue exists in SLMQ."""
        queue = self.entity_name(self.queue_name_prefix + queue)
        try:
            return self._queue_cache[queue]
        except KeyError:
            try:
                self.slmq.create_queue(
                    queue, visibility_timeout=self.visibility_timeout)
            except ResponseError:
                pass
            q = self._queue_cache[queue] = self.slmq.queue(queue)
            return q

    def _delete(self, queue, *args, **kwargs):
        """Delete queue by name."""
        queue_name = self.entity_name(queue)
        self._queue_cache.pop(queue_name, None)
        self.slmq.queue(queue_name).delete(force=True)
        super()._delete(queue_name)

    def _put(self, queue, message, **kwargs):
        """Put message onto queue."""
        q = self._new_queue(queue)
        q.push(dumps(message))

    def _get(self, queue):
        """Try to retrieve a single message off ``queue``."""
        q = self._new_queue(queue)
        rs = q.pop(1)
        if rs['items']:
            m = rs['items'][0]
            payload = loads(bytes_to_str(m['body']))
            if queue in self._noack_queues:
                q.message(m['id']).delete()
            else:
                payload['properties']['delivery_info'].update({
                    'slmq_message_id': m['id'], 'slmq_queue_name': q.name})
            return payload
        raise Empty()

    def basic_ack(self, delivery_tag):
        delivery_info = self.qos.get(delivery_tag).delivery_info
        try:
            queue = delivery_info['slmq_queue_name']
        except KeyError:
            pass
        else:
            self.delete_message(queue, delivery_info['slmq_message_id'])
        super().basic_ack(delivery_tag)

    def _size(self, queue):
        """Return the number of messages in a queue."""
        return self._new_queue(queue).detail()['message_count']

    def _purge(self, queue):
        """Delete all current messages in a queue."""
        q = self._new_queue(queue)
        n = 0
        results = q.pop(10)
        while results['items']:
            for m in results['items']:
                self.delete_message(queue, m['id'])
                n += 1
            results = q.pop(10)
        return n

    def delete_message(self, queue, message_id):
        q = self.slmq.queue(self.entity_name(queue))
        return q.message(message_id).delete()

    @property
    def slmq(self):
        if self._slmq is None:
            conninfo = self.conninfo
            account = os.environ.get('SLMQ_ACCOUNT', conninfo.virtual_host)
            user = os.environ.get('SL_USERNAME', conninfo.userid)
            api_key = os.environ.get('SL_API_KEY', conninfo.password)
            host = os.environ.get('SLMQ_HOST', conninfo.hostname)
            port = os.environ.get('SLMQ_PORT', conninfo.port)
            # default to a secure (https) endpoint unless the 'secure'
            # transport option or the SLMQ_SECURE env var disables it.
            secure = bool(os.environ.get(
                'SLMQ_SECURE', self.transport_options.get('secure', True)),
            )
            endpoint = '{}://{}{}'.format(
                'https' if secure else 'http', host,
                f':{port}' if port else '',
            )

            self._slmq = get_client(account, endpoint=endpoint)
            self._slmq.authenticate(user, api_key)
        return self._slmq

    @property
    def conninfo(self):
        return self.connection.client

    @property
    def transport_options(self):
        return self.connection.client.transport_options

    @cached_property
    def visibility_timeout(self):
        return (self.transport_options.get('visibility_timeout') or
                self.default_visibility_timeout)

    @cached_property
    def queue_name_prefix(self):
        return self.transport_options.get('queue_name_prefix', '')


class Transport(virtual.Transport):
    """SLMQ Transport."""

    Channel = Channel

    polling_interval = 1
    default_port = None
    connection_errors = (
        virtual.Transport.connection_errors + (
            ResponseError, socket.error
        )
    )
venv/lib/python3.12/site-packages/kombu/transport/SQS.py
@@ -0,0 +1,973 @@
"""Amazon SQS transport module for Kombu.

This package implements an AMQP-like interface on top of Amazon's SQS
service, with the goal of being optimized for high performance and
reliability.

The default settings for this module are focused on high performance in
task queue situations where tasks are small, idempotent and run very
fast.

SQS Features supported by this transport
========================================
Long Polling
------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html

Long polling is enabled by setting the `wait_time_seconds` transport
option to a number > 1. Amazon supports up to 20 seconds. It is enabled
with 10 seconds by default.
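
A sketch of setting a custom long-poll interval (the broker URL here is a
placeholder):

.. code-block:: python

    from kombu import Connection

    conn = Connection('sqs://', transport_options={
        'wait_time_seconds': 20,  # the maximum SQS supports
    })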

Batch API Actions
-----------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-batch-api.html

The default behavior of the SQS Channel.drain_events() method is to
request up to the 'prefetch_count' messages on every request to SQS.
These messages are stored locally in a deque object and passed back
to the Transport until the deque is empty, before triggering a new
API call to Amazon.

This behavior dramatically speeds up the rate that you can pull tasks
from SQS when you have short-running tasks (or a large number of workers).

When a Celery worker has multiple queues to monitor, it will pull down
up to 'prefetch_count' messages from queueA and work on them all before
moving on to queueB. If queueB is empty, it will wait until
'polling_interval' expires before moving back and checking on queueA.

Message Attributes
------------------
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html

SQS supports sending message attributes along with the message body.
To use this feature, you can pass 'message_attributes' as a keyword
argument to the `basic_publish` method.
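
A minimal sketch (the attribute shape follows SQS's
``MessageAttributeValue`` structure; the names and values here are
illustrative):

.. code-block:: python

    channel.basic_publish(
        message, exchange='', routing_key='myqueue',
        message_attributes={
            'trace_id': {'DataType': 'String', 'StringValue': 'abc123'},
        },
    )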

Other Features supported by this transport
==========================================
Predefined Queues
-----------------
The default behavior of this transport is to use a single AWS credential
pair in order to manage all SQS queues (e.g. listing queues, creating
queues, polling queues, deleting messages).

If it is preferable for your environment to use multiple AWS credentials, you
can use the 'predefined_queues' setting inside the 'transport_options' map.
This setting allows you to specify the SQS queue URL and AWS credentials for
each of your queues. For example, if you have two queues (which both already
exist in AWS), you can tell this transport about them as follows:

.. code-block:: python

    transport_options = {
        'predefined_queues': {
            'queue-1': {
                'url': 'https://sqs.us-east-1.amazonaws.com/xxx/aaa',
                'access_key_id': 'a',
                'secret_access_key': 'b',
                'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640},  # optional
                'backoff_tasks': ['svc.tasks.tasks.task1'],  # optional
            },
            'queue-2.fifo': {
                'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
                'access_key_id': 'c',
                'secret_access_key': 'd',
                'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640},  # optional
                'backoff_tasks': ['svc.tasks.tasks.task2'],  # optional
            },
        },
        'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest',  # optional
        'sts_token_timeout': 900,  # optional
    }

Note that FIFO and standard queues must be named accordingly (the name of
a FIFO queue must end with the .fifo suffix).

backoff_policy and backoff_tasks are optional arguments. They
automatically change the message visibility timeout so that successive
retries of a task are delayed by different amounts. The policy is
applied after a task failure.

AWS STS authentication is supported by using sts_role_arn and
sts_token_timeout. sts_role_arn is the IAM role ARN to assume;
sts_token_timeout is the token timeout in seconds, which defaults to
(and has a minimum of) 900 seconds. After this period a new token is
created.

If you authenticate using Okta_ (e.g. calling |gac|_), you can also specify
a 'session_token' to connect to a queue. Note that those tokens have a
limited lifetime and are therefore only suited for short-lived tests.

.. _Okta: https://www.okta.com/
.. _gac: https://github.com/Nike-Inc/gimme-aws-creds#readme
.. |gac| replace:: ``gimme-aws-creds``


Client config
-------------
In some cases you may need to override the botocore config. You can do it
as follows:

.. code-block:: python

    transport_options = {
        'client-config': {
            'connect_timeout': 5,
        },
    }

For a complete list of settings you can adjust using this option see
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
"""

from __future__ import annotations

import base64
import socket
import string
import uuid
from datetime import datetime
from queue import Empty

from botocore.client import Config
from botocore.exceptions import ClientError
from vine import ensure_promise, promise, transform

from kombu.asynchronous import get_event_loop
from kombu.asynchronous.aws.ext import boto3, exceptions
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
from kombu.asynchronous.aws.sqs.message import AsyncMessage
from kombu.log import get_logger
from kombu.utils import scheduling
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property

from . import virtual

logger = get_logger(__name__)

# dots are replaced by dash, dash remains dash, all other punctuation
# replaced by underscore.
CHARS_REPLACE_TABLE = {
    ord(c): 0x5f for c in string.punctuation if c not in '-_.'
}
CHARS_REPLACE_TABLE[0x2e] = 0x2d  # '.' -> '-'
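
# For example (illustrative): entity_name('celery.priority#high') yields
# 'celery-priority_high' under this table.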

#: SQS bulk get supports a maximum of 10 messages at a time.
SQS_MAX_MESSAGES = 10


def maybe_int(x):
    """Try to convert ``x`` to int, or return ``x`` unchanged on failure."""
    try:
        return int(x)
    except ValueError:
        return x


class UndefinedQueueException(Exception):
    """Predefined queues are being used and an undefined queue was used."""


class InvalidQueueException(Exception):
    """Predefined queues are being used and configuration is not valid."""


class AccessDeniedQueueException(Exception):
    """Raised when access to the AWS queue is denied.

    This may occur if the permissions are not correctly set or the
    credentials are invalid.
    """


class DoesNotExistQueueException(Exception):
    """The specified queue doesn't exist."""


class QoS(virtual.QoS):
    """Quality of Service guarantees implementation for SQS."""

    def reject(self, delivery_tag, requeue=False):
        super().reject(delivery_tag, requeue=requeue)
        routing_key, message, backoff_tasks, backoff_policy = \
            self._extract_backoff_policy_configuration_and_message(
                delivery_tag)
        if routing_key and message and backoff_tasks and backoff_policy:
            self.apply_backoff_policy(
                routing_key, delivery_tag, backoff_policy, backoff_tasks)

    def _extract_backoff_policy_configuration_and_message(self, delivery_tag):
        try:
            message = self._delivered[delivery_tag]
            routing_key = message.delivery_info['routing_key']
        except KeyError:
            return None, None, None, None
        if not routing_key or not message:
            return None, None, None, None
        queue_config = self.channel.predefined_queues.get(routing_key, {})
        backoff_tasks = queue_config.get('backoff_tasks')
        backoff_policy = queue_config.get('backoff_policy')
        return routing_key, message, backoff_tasks, backoff_policy

    def apply_backoff_policy(self, routing_key, delivery_tag,
                             backoff_policy, backoff_tasks):
        queue_url = self.channel._queue_cache[routing_key]
        task_name, number_of_retries = \
            self.extract_task_name_and_number_of_retries(delivery_tag)
        if not task_name or not number_of_retries:
            return None
        policy_value = backoff_policy.get(number_of_retries)
        if task_name in backoff_tasks and policy_value is not None:
            c = self.channel.sqs(routing_key)
            c.change_message_visibility(
                QueueUrl=queue_url,
                ReceiptHandle=delivery_tag,
                VisibilityTimeout=policy_value
            )

    def extract_task_name_and_number_of_retries(self, delivery_tag):
        message = self._delivered[delivery_tag]
        message_headers = message.headers
        task_name = message_headers['task']
        number_of_retries = int(
            message.properties['delivery_info']['sqs_message']
            ['Attributes']['ApproximateReceiveCount'])
        return task_name, number_of_retries


class Channel(virtual.Channel):
    """SQS Channel."""

    default_region = 'us-east-1'
    default_visibility_timeout = 1800  # 30 minutes.
    default_wait_time_seconds = 10  # up to 20 seconds max
    domain_format = 'kombu%(vhost)s'
    _asynsqs = None
    _predefined_queue_async_clients = {}  # A client for each predefined queue
    _sqs = None
    _predefined_queue_clients = {}  # A client for each predefined queue
    _queue_cache = {}  # SQS queue name => SQS queue URL
    _noack_queues = set()
    QoS = QoS

    def __init__(self, *args, **kwargs):
        if boto3 is None:
            raise ImportError('boto3 is not installed')
        super().__init__(*args, **kwargs)
        self._validate_predefined_queues()

        # SQS blows up if you try to create a new queue when one already
        # exists but with a different visibility_timeout. This prepopulates
        # the queue_cache to protect us from recreating
        # queues that are known to already exist.
        self._update_queue_cache(self.queue_name_prefix)

        self.hub = kwargs.get('hub') or get_event_loop()

    def _validate_predefined_queues(self):
        """Check that standard and FIFO queues are named properly.

        AWS requires FIFO queues to have a name
        that ends with the .fifo suffix.
        """
        for queue_name, q in self.predefined_queues.items():
            fifo_url = q['url'].endswith('.fifo')
            fifo_name = queue_name.endswith('.fifo')
            if fifo_url and not fifo_name:
                raise InvalidQueueException(
                    "Queue with url '{}' must have a name "
                    "ending with .fifo".format(q['url'])
                )
            elif not fifo_url and fifo_name:
                raise InvalidQueueException(
                    "Queue with name '{}' is not a FIFO queue: "
                    "'{}'".format(queue_name, q['url'])
                )

    def _update_queue_cache(self, queue_name_prefix):
        if self.predefined_queues:
            for queue_name, q in self.predefined_queues.items():
                self._queue_cache[queue_name] = q['url']
            return

        resp = self.sqs().list_queues(QueueNamePrefix=queue_name_prefix)
        for url in resp.get('QueueUrls', []):
            queue_name = url.split('/')[-1]
            self._queue_cache[queue_name] = url

    def basic_consume(self, queue, no_ack, *args, **kwargs):
        if no_ack:
            self._noack_queues.add(queue)
        if self.hub:
            self._loop1(queue)
        return super().basic_consume(
            queue, no_ack, *args, **kwargs
        )

    def basic_cancel(self, consumer_tag):
        if consumer_tag in self._consumers:
            queue = self._tag_to_queue[consumer_tag]
            self._noack_queues.discard(queue)
        return super().basic_cancel(consumer_tag)

    def drain_events(self, timeout=None, callback=None, **kwargs):
        """Return a single payload message from one of our queues.

        Raises
        ------
        Queue.Empty: if no messages available.
        """
        # If we're not allowed to consume or have no consumers, raise Empty
        if not self._consumers or not self.qos.can_consume():
            raise Empty()

        # At this point, go and get more messages from SQS
        self._poll(self.cycle, callback, timeout=timeout)

    def _reset_cycle(self):
        """Reset the consume cycle.

        Returns
        -------
        FairCycle: object that points to our _get_bulk() method
            rather than the standard _get() method. This allows for
            multiple messages to be returned at once from SQS (
            based on the prefetch limit).
        """
        self._cycle = scheduling.FairCycle(
            self._get_bulk, self._active_queues, Empty,
        )

    def entity_name(self, name, table=CHARS_REPLACE_TABLE):
        """Format AMQP queue name into a legal SQS queue name."""
        if name.endswith('.fifo'):
            partial = name[:-len('.fifo')]
            partial = str(safe_str(partial)).translate(table)
            return partial + '.fifo'
        else:
            return str(safe_str(name)).translate(table)

    def canonical_queue_name(self, queue_name):
        return self.entity_name(self.queue_name_prefix + queue_name)

    def _resolve_queue_url(self, queue):
        """Try to retrieve the SQS queue URL for a given queue name."""
        # Translate to SQS name for consistency with initial
        # _queue_cache population.
        sqs_qname = self.canonical_queue_name(queue)

        # The SQS ListQueues method only returns 1000 queues. When you have
        # so many queues, it's possible that the queue you are looking for is
        # not cached. In this case, we could update the cache with the exact
        # queue name first.
        if sqs_qname not in self._queue_cache:
            self._update_queue_cache(sqs_qname)
        try:
            return self._queue_cache[sqs_qname]
        except KeyError:
            if self.predefined_queues:
                raise UndefinedQueueException((
                    "Queue with name '{}' must be "
                    "defined in 'predefined_queues'."
                ).format(sqs_qname))

            raise DoesNotExistQueueException(
                f"Queue with name '{sqs_qname}' doesn't exist in SQS"
            )

    def _new_queue(self, queue, **kwargs):
        """Ensure a queue with given name exists in SQS.

        Arguments:
        ---------
            queue (str): the AMQP queue name

        Returns
        -------
            str: the SQS queue URL
        """
        try:
            return self._resolve_queue_url(queue)
        except DoesNotExistQueueException:
            sqs_qname = self.canonical_queue_name(queue)
            attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
            if sqs_qname.endswith('.fifo'):
                attributes['FifoQueue'] = 'true'

            resp = self._create_queue(sqs_qname, attributes)
            self._queue_cache[sqs_qname] = resp['QueueUrl']
            return resp['QueueUrl']

    def _create_queue(self, queue_name, attributes):
        """Create an SQS queue with a given name and nominal attributes."""
        # Allow specifying additional boto create_queue Attributes
        # via transport options
        if self.predefined_queues:
            return None

        attributes.update(
            self.transport_options.get('sqs-creation-attributes') or {},
        )

        return self.sqs(queue=queue_name).create_queue(
            QueueName=queue_name,
            Attributes=attributes,
        )

    def _delete(self, queue, *args, **kwargs):
        """Delete queue by name."""
        if self.predefined_queues:
            return

        q_url = self._resolve_queue_url(queue)
        self.sqs().delete_queue(
            QueueUrl=q_url,
        )
        self._queue_cache.pop(queue, None)

    def _put(self, queue, message, **kwargs):
        """Put message onto queue."""
        q_url = self._new_queue(queue)
        kwargs = {'QueueUrl': q_url}
        if 'properties' in message:
            if 'message_attributes' in message['properties']:
                # we don't want to have the attributes in the body
                kwargs['MessageAttributes'] = \
                    message['properties'].pop('message_attributes')
            if queue.endswith('.fifo'):
                if 'MessageGroupId' in message['properties']:
                    kwargs['MessageGroupId'] = \
                        message['properties']['MessageGroupId']
                else:
                    kwargs['MessageGroupId'] = 'default'
                if 'MessageDeduplicationId' in message['properties']:
                    kwargs['MessageDeduplicationId'] = \
                        message['properties']['MessageDeduplicationId']
                else:
                    kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
            else:
                if "DelaySeconds" in message['properties']:
                    kwargs['DelaySeconds'] = \
                        message['properties']['DelaySeconds']

        if self.sqs_base64_encoding:
            body = AsyncMessage().encode(dumps(message))
        else:
            body = dumps(message)
        kwargs['MessageBody'] = body

        c = self.sqs(queue=self.canonical_queue_name(queue))
        if message.get('redelivered'):
            c.change_message_visibility(
                QueueUrl=q_url,
                ReceiptHandle=message['properties']['delivery_tag'],
                VisibilityTimeout=0
            )
        else:
            c.send_message(**kwargs)

    @staticmethod
    def _optional_b64_decode(byte_string):
        try:
            data = base64.b64decode(byte_string)
            if base64.b64encode(data) == byte_string:
                return data
            # else the base64 module found some embedded base64 content
            # that should be ignored.
        except Exception:  # pylint: disable=broad-except
            pass
        return byte_string

    def _message_to_python(self, message, queue_name, q_url):
        body = self._optional_b64_decode(message['Body'].encode())
        payload = loads(bytes_to_str(body))
        if queue_name in self._noack_queues:
            q_url = self._new_queue(queue_name)
            self.asynsqs(queue=queue_name).delete_message(
                q_url,
                message['ReceiptHandle'],
            )
        else:
            try:
                properties = payload['properties']
                delivery_info = payload['properties']['delivery_info']
            except KeyError:
                # json message not sent by kombu?
                delivery_info = {}
                properties = {'delivery_info': delivery_info}
                payload.update({
                    'body': bytes_to_str(body),
                    'properties': properties,
                })
            # set delivery tag to SQS receipt handle
            delivery_info.update({
                'sqs_message': message, 'sqs_queue': q_url,
            })
            properties['delivery_tag'] = message['ReceiptHandle']
        return payload

    def _messages_to_python(self, messages, queue):
        """Convert a list of SQS Message objects into Payloads.

        This method handles converting SQS Message objects into
        Payloads, and appropriately updating the queue depending on
        the 'ack' settings for that queue.

        Arguments:
        ---------
            messages (SQSMessage): A list of SQS Message objects.
            queue (str): Name representing the queue they came from.

        Returns
        -------
            List: A list of Payload objects
        """
        q_url = self._new_queue(queue)
        return [self._message_to_python(m, queue, q_url) for m in messages]

    def _get_bulk(self, queue,
                  max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
        """Try to retrieve multiple messages off ``queue``.

        Where :meth:`_get` returns a single Payload object, this method
        returns a list of Payload objects. The number of objects returned
        is determined by the total number of messages available in the queue
        and the number of messages the QoS object allows (based on the
        prefetch_count).

        Note:
        ----
            Ignores QoS limits so caller is responsible for checking
            that we are allowed to consume at least one message from the
            queue. get_bulk will then ask QoS for an estimate of
            the number of extra messages that we can consume.

        Arguments:
        ---------
            queue (str): The queue name to pull from.

        Returns
        -------
            List[Message]
        """
        # drain_events calls `can_consume` first, consuming
        # a token, so we know that we are allowed to consume at least
        # one message.

        # Note: ignoring max_messages for SQS with boto3
        max_count = self._get_message_estimate()
        if max_count:
            q_url = self._new_queue(queue)
            resp = self.sqs(queue=queue).receive_message(
                QueueUrl=q_url, MaxNumberOfMessages=max_count,
                WaitTimeSeconds=self.wait_time_seconds)
            if resp.get('Messages'):
                for m in resp['Messages']:
                    m['Body'] = AsyncMessage(body=m['Body']).decode()
                for msg in self._messages_to_python(resp['Messages'], queue):
                    self.connection._deliver(msg, queue)
                return
        raise Empty()

    def _get(self, queue):
        """Try to retrieve a single message off ``queue``."""
        q_url = self._new_queue(queue)
        resp = self.sqs(queue=queue).receive_message(
            QueueUrl=q_url, MaxNumberOfMessages=1,
            WaitTimeSeconds=self.wait_time_seconds)
        if resp.get('Messages'):
            body = AsyncMessage(body=resp['Messages'][0]['Body']).decode()
            resp['Messages'][0]['Body'] = body
            return self._messages_to_python(resp['Messages'], queue)[0]
        raise Empty()

    def _loop1(self, queue, _=None):
        self.hub.call_soon(self._schedule_queue, queue)

    def _schedule_queue(self, queue):
        if queue in self._active_queues:
            if self.qos.can_consume():
                self._get_bulk_async(
                    queue, callback=promise(self._loop1, (queue,)),
                )
            else:
                self._loop1(queue)

    def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
        maxcount = self.qos.can_consume_max_estimate()
        return min(
            max_if_unlimited if maxcount is None else max(maxcount, 1),
            max_if_unlimited,
        )

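    # For example (illustrative): if the prefetch limit would allow three
    # more messages, _get_message_estimate() returns 3; with no prefetch
    # limit it returns SQS_MAX_MESSAGES (10), the SQS per-request ceiling.
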
    def _get_bulk_async(self, queue,
                        max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
        maxcount = self._get_message_estimate()
        if maxcount:
            return self._get_async(queue, maxcount, callback=callback)
        # Not allowed to consume, make sure to notify callback..
        callback = ensure_promise(callback)
        callback([])
        return callback

    def _get_async(self, queue, count=1, callback=None):
        q_url = self._new_queue(queue)
        qname = self.canonical_queue_name(queue)
        return self._get_from_sqs(
            queue_name=qname, queue_url=q_url, count=count,
            connection=self.asynsqs(queue=qname),
            callback=transform(
                self._on_messages_ready, callback, q_url, queue
            ),
        )

    def _on_messages_ready(self, queue, qname, messages):
        if 'Messages' in messages and messages['Messages']:
            callbacks = self.connection._callbacks
            for msg in messages['Messages']:
                msg_parsed = self._message_to_python(msg, qname, queue)
                callbacks[qname](msg_parsed)

    def _get_from_sqs(self, queue_name, queue_url,
                      connection, count=1, callback=None):
        """Retrieve and handle messages from SQS.

        Uses long polling and returns :class:`~vine.promises.promise`.
        """
        return connection.receive_message(
            queue_name, queue_url, number_messages=count,
            wait_time_seconds=self.wait_time_seconds,
            callback=callback,
        )

    def _restore(self, message,
                 unwanted_delivery_info=('sqs_message', 'sqs_queue')):
        for unwanted_key in unwanted_delivery_info:
            # Remove objects that aren't JSON serializable (Issue #1108).
            message.delivery_info.pop(unwanted_key, None)
        return super()._restore(message)

    def basic_ack(self, delivery_tag, multiple=False):
        try:
            message = self.qos.get(delivery_tag).delivery_info
            sqs_message = message['sqs_message']
        except KeyError:
            super().basic_ack(delivery_tag)
        else:
            queue = None
            if 'routing_key' in message:
                queue = self.canonical_queue_name(message['routing_key'])

            try:
                self.sqs(queue=queue).delete_message(
                    QueueUrl=message['sqs_queue'],
                    ReceiptHandle=sqs_message['ReceiptHandle']
                )
            except ClientError as exception:
                if exception.response['Error']['Code'] == 'AccessDenied':
                    raise AccessDeniedQueueException(
                        exception.response["Error"]["Message"]
                    )
                super().basic_reject(delivery_tag)
            else:
                super().basic_ack(delivery_tag)

    def _size(self, queue):
        """Return the number of messages in a queue."""
        q_url = self._new_queue(queue)
        c = self.sqs(queue=self.canonical_queue_name(queue))
        resp = c.get_queue_attributes(
            QueueUrl=q_url,
            AttributeNames=['ApproximateNumberOfMessages'])
        return int(resp['Attributes']['ApproximateNumberOfMessages'])

    def _purge(self, queue):
        """Delete all current messages in a queue."""
        q_url = self._new_queue(queue)
        # SQS is slow at registering messages, so run for a few
        # iterations to ensure messages are detected and deleted.
        size = 0
        for i in range(10):
            size += int(self._size(queue))
            if not size:
                break
            self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
        return size

    def close(self):
        super().close()
        # if self._asynsqs:
        #     try:
        #         self.asynsqs().close()
        #     except AttributeError as exc:  # FIXME ???
        #         if "can't set attribute" not in str(exc):
        #             raise

    def new_sqs_client(self, region, access_key_id,
                       secret_access_key, session_token=None):
        session = boto3.session.Session(
            region_name=region,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            aws_session_token=session_token,
        )
        is_secure = self.is_secure if self.is_secure is not None else True
        client_kwargs = {
            'use_ssl': is_secure
        }
        if self.endpoint_url is not None:
            client_kwargs['endpoint_url'] = self.endpoint_url
        client_config = self.transport_options.get('client-config') or {}
        config = Config(**client_config)
        return session.client('sqs', config=config, **client_kwargs)

    def sqs(self, queue=None):
        if queue is not None and self.predefined_queues:
            if queue not in self.predefined_queues:
                raise UndefinedQueueException(
                    f"Queue with name '{queue}' must be defined"
                    " in 'predefined_queues'.")
            q = self.predefined_queues[queue]
            if self.transport_options.get('sts_role_arn'):
                return self._handle_sts_session(queue, q)
            if queue in self._predefined_queue_clients:
                return self._predefined_queue_clients[queue]
            else:
                c = self._predefined_queue_clients[queue] = \
                    self.new_sqs_client(
                        region=q.get('region', self.region),
                        access_key_id=q.get(
                            'access_key_id', self.conninfo.userid),
                        secret_access_key=q.get(
                            'secret_access_key', self.conninfo.password)
                    )
                return c

        if self._sqs is not None:
            return self._sqs

        c = self._sqs = self.new_sqs_client(
            region=self.region,
            access_key_id=self.conninfo.userid,
            secret_access_key=self.conninfo.password,
        )
        return c

    def _handle_sts_session(self, queue, q):
        region = q.get('region', self.region)
        if not hasattr(self, 'sts_expiration'):  # STS token - token init
            return self._new_predefined_queue_client_with_sts_session(
                queue, region)
        # STS token - refresh if expired
        elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
            return self._new_predefined_queue_client_with_sts_session(
                queue, region)
        else:  # STS token - reuse existing
            if queue not in self._predefined_queue_clients:
                return self._new_predefined_queue_client_with_sts_session(
                    queue, region)
            return self._predefined_queue_clients[queue]

    def _new_predefined_queue_client_with_sts_session(self, queue, region):
        sts_creds = self.generate_sts_session_token(
            self.transport_options.get('sts_role_arn'),
            self.transport_options.get('sts_token_timeout', 900))
        self.sts_expiration = sts_creds['Expiration']
        c = self._predefined_queue_clients[queue] = self.new_sqs_client(
            region=region,
            access_key_id=sts_creds['AccessKeyId'],
            secret_access_key=sts_creds['SecretAccessKey'],
            session_token=sts_creds['SessionToken'],
        )
        return c

    def generate_sts_session_token(self, role_arn, token_expiry_seconds):
        sts_client = boto3.client('sts')
        sts_policy = sts_client.assume_role(
            RoleArn=role_arn,
            RoleSessionName='Celery',
            DurationSeconds=token_expiry_seconds
        )
        return sts_policy['Credentials']

    def asynsqs(self, queue=None):
        if queue is not None and self.predefined_queues:
            if queue in self._predefined_queue_async_clients and \
                    not hasattr(self, 'sts_expiration'):
                return self._predefined_queue_async_clients[queue]
            if queue not in self.predefined_queues:
                raise UndefinedQueueException((
                    "Queue with name '{}' must be defined in "
                    "'predefined_queues'."
                ).format(queue))
            q = self.predefined_queues[queue]
            c = self._predefined_queue_async_clients[queue] = \
                AsyncSQSConnection(
                    sqs_connection=self.sqs(queue=queue),
                    region=q.get('region', self.region),
                    fetch_message_attributes=self.fetch_message_attributes,
                )
            return c

        if self._asynsqs is not None:
            return self._asynsqs

        c = self._asynsqs = AsyncSQSConnection(
            sqs_connection=self.sqs(queue=queue),
            region=self.region,
            fetch_message_attributes=self.fetch_message_attributes,
        )
        return c

    @property
    def conninfo(self):
        return self.connection.client

    @property
    def transport_options(self):
        return self.connection.client.transport_options

    @cached_property
    def visibility_timeout(self):
        return (self.transport_options.get('visibility_timeout') or
                self.default_visibility_timeout)

    @cached_property
    def predefined_queues(self):
        """Map of queue_name to predefined queue settings."""
        return self.transport_options.get('predefined_queues', {})

    @cached_property
    def queue_name_prefix(self):
        return self.transport_options.get('queue_name_prefix', '')

    @cached_property
    def supports_fanout(self):
        return False

    @cached_property
    def region(self):
        return (self.transport_options.get('region') or
                boto3.Session().region_name or
                self.default_region)

    @cached_property
    def regioninfo(self):
        return self.transport_options.get('regioninfo')

    @cached_property
    def is_secure(self):
        return self.transport_options.get('is_secure')

    @cached_property
    def port(self):
        return self.transport_options.get('port')

    @cached_property
    def endpoint_url(self):
        if self.conninfo.hostname is not None:
            scheme = 'https' if self.is_secure else 'http'
            if self.conninfo.port is not None:
                port = f':{self.conninfo.port}'
            else:
                port = ''
            return '{}://{}{}'.format(
                scheme,
                self.conninfo.hostname,
                port
            )

    @cached_property
    def wait_time_seconds(self):
        return self.transport_options.get('wait_time_seconds',
                                          self.default_wait_time_seconds)

    @cached_property
    def sqs_base64_encoding(self):
        return self.transport_options.get('sqs_base64_encoding', True)

    @cached_property
    def fetch_message_attributes(self):
        return self.transport_options.get('fetch_message_attributes')


class Transport(virtual.Transport):
    """SQS Transport.

    Additional queue attributes can be supplied to SQS during queue
    creation by passing an ``sqs-creation-attributes`` key in
    transport_options. ``sqs-creation-attributes`` must be a dict whose
    key-value pairs correspond with Attributes in the
    `CreateQueue SQS API`_.

    For example, to have SQS queues created with server-side encryption
    enabled using the default Amazon Managed Customer Master Key, you
    can set the ``KmsMasterKeyId`` Attribute. When the queue is initially
    created by Kombu, encryption will be enabled.

    .. code-block:: python

        from kombu.transport.SQS import Transport

        transport = Transport(
            ...,
            transport_options={
                'sqs-creation-attributes': {
                    'KmsMasterKeyId': 'alias/aws/sqs',
                },
            }
        )

    .. _CreateQueue SQS API: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_CreateQueue.html#API_CreateQueue_RequestParameters

    The ``ApproximateReceiveCount`` message attribute is fetched by this
    transport by default. Requested message attributes can be changed by
    setting ``fetch_message_attributes`` in the transport options.

    .. code-block:: python

        from kombu.transport.SQS import Transport

        transport = Transport(
            ...,
            transport_options={
                'fetch_message_attributes': ["All"],
            }
        )

    .. _Message Attributes: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames

    """  # noqa: E501

    Channel = Channel

    polling_interval = 1
    wait_time_seconds = 0
    default_port = None
    connection_errors = (
        virtual.Transport.connection_errors +
        (exceptions.BotoCoreError, socket.error)
    )
    channel_errors = (
        virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
    )
    driver_type = 'sqs'
    driver_name = 'sqs'

    implements = virtual.Transport.implements.extend(
        asynchronous=True,
        exchange_type=frozenset(['direct']),
    )

    @property
    def default_connection_params(self):
        return {'port': self.default_port}
venv/lib/python3.12/site-packages/kombu/transport/__init__.py
@@ -0,0 +1,93 @@
"""Built-in transports."""

from __future__ import annotations

from kombu.utils.compat import _detect_environment
from kombu.utils.imports import symbol_by_name


def supports_librabbitmq() -> bool | None:
    """Return true if :pypi:`librabbitmq` can be used."""
    if _detect_environment() == 'default':
        try:
            import librabbitmq  # noqa
        except ImportError:  # pragma: no cover
            pass
        else:  # pragma: no cover
            return True
    return None


TRANSPORT_ALIASES = {
    'amqp': 'kombu.transport.pyamqp:Transport',
    'amqps': 'kombu.transport.pyamqp:SSLTransport',
    'pyamqp': 'kombu.transport.pyamqp:Transport',
    'librabbitmq': 'kombu.transport.librabbitmq:Transport',
    'confluentkafka': 'kombu.transport.confluentkafka:Transport',
    'kafka': 'kombu.transport.confluentkafka:Transport',
    'memory': 'kombu.transport.memory:Transport',
    'redis': 'kombu.transport.redis:Transport',
    'rediss': 'kombu.transport.redis:Transport',
    'SQS': 'kombu.transport.SQS:Transport',
    'sqs': 'kombu.transport.SQS:Transport',
    'mongodb': 'kombu.transport.mongodb:Transport',
    'zookeeper': 'kombu.transport.zookeeper:Transport',
    'sqlalchemy': 'kombu.transport.sqlalchemy:Transport',
    'sqla': 'kombu.transport.sqlalchemy:Transport',
    'SLMQ': 'kombu.transport.SLMQ.Transport',
    'slmq': 'kombu.transport.SLMQ.Transport',
    'filesystem': 'kombu.transport.filesystem:Transport',
    'qpid': 'kombu.transport.qpid:Transport',
    'sentinel': 'kombu.transport.redis:SentinelTransport',
    'consul': 'kombu.transport.consul:Transport',
    'etcd': 'kombu.transport.etcd:Transport',
    'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
    'azureservicebus': 'kombu.transport.azureservicebus:Transport',
    'pyro': 'kombu.transport.pyro:Transport',
    'gcpubsub': 'kombu.transport.gcpubsub:Transport',
}

_transport_cache = {}


def resolve_transport(transport: str | None = None) -> str | None:
    """Get transport by name.

    Arguments:
    ---------
        transport (Union[str, type]): This can be either
            an actual transport class, or the fully qualified
            path to a transport class, or the alias of a transport.
    """
    if isinstance(transport, str):
        try:
            transport = TRANSPORT_ALIASES[transport]
        except KeyError:
            if '.' not in transport and ':' not in transport:
                from kombu.utils.text import fmatch_best
                alt = fmatch_best(transport, TRANSPORT_ALIASES)
                if alt:
                    raise KeyError(
                        'No such transport: {}. Did you mean {}?'.format(
                            transport, alt))
            raise KeyError(f'No such transport: {transport}')
        else:
            if callable(transport):
                transport = transport()
        return symbol_by_name(transport)
    return transport


def get_transport_cls(transport: str | None = None) -> str | None:
    """Get transport class by name.

    The transport string is the full path to a transport class, e.g.::

        "kombu.transport.pyamqp:Transport"

    If the name does not include `"."` (is not fully qualified),
    the alias table will be consulted.
    """
    if transport not in _transport_cache:
        _transport_cache[transport] = resolve_transport(transport)
    return _transport_cache[transport]
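

# Usage sketch (illustrative, not part of the module): resolving an alias
# returns the transport class itself, and repeated lookups are served from
# _transport_cache instead of re-importing:
#
#     from kombu.transport import get_transport_cls
#     cls = get_transport_cls('memory')  # kombu.transport.memory:Transport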
venv/lib/python3.12/site-packages/kombu/transport/azureservicebus.py
@@ -0,0 +1,498 @@
"""Azure Service Bus Message Queue transport module for kombu.

Note that the Shared Access Policy used to connect to Azure Service Bus
requires Manage, Send and Listen claims since the broker will create new
queues and delete old queues as required.


A note for Celery users experiencing issues with programs not terminating
properly: the Azure Service Bus SDK uses the Azure uAMQP library, which in
turn creates some threads. If the AzureServiceBus Channel is closed, those
threads are closed properly, but it seems there are times when Celery does
not do this, so the threads are left running. As the uAMQP threads are not
marked as daemon threads, they will not be killed when the main thread
exits. Setting the ``uamqp_keep_alive_interval`` transport option to 0 will
prevent the keep_alive thread from starting.


More information about Azure Service Bus:
https://azure.microsoft.com/en-us/services/service-bus/

Features
========
* Type: Virtual
* Supports Direct: *Unreviewed*
* Supports Topic: *Unreviewed*
* Supports Fanout: *Unreviewed*
* Supports Priority: *Unreviewed*
* Supports TTL: *Unreviewed*

Connection String
=================

Connection string has the following formats:

.. code-block::

    azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
    azureservicebus://DefaultAzureCredential@SERVICE_BUSNAMESPACE
    azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE

Transport Options
=================

* ``queue_name_prefix`` - String prefix to prepend to queue names in a
  service bus namespace.
* ``wait_time_seconds`` - Number of seconds to wait to receive messages.
  Default ``5``
* ``peek_lock_seconds`` - Number of seconds the message is visible for before
  it is requeued and sent to another consumer. Default ``60``
* ``uamqp_keep_alive_interval`` - Interval in seconds the Azure uAMQP library
  should send keepalive messages. Default ``30``
* ``retry_total`` - Azure SDK retry total. Default ``3``
* ``retry_backoff_factor`` - Azure SDK exponential backoff factor.
  Default ``0.8``
* ``retry_backoff_max`` - Azure SDK retry total time. Default ``120``
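
A short sketch combining the pieces above (the namespace and credentials
are placeholders):

.. code-block:: python

    from kombu import Connection

    conn = Connection(
        'azureservicebus://policy:key@mynamespace',
        transport_options={
            'queue_name_prefix': 'myapp-',
            'peek_lock_seconds': 120,
        },
    )
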
"""

from __future__ import annotations

import string
from queue import Empty
from typing import Any

import azure.core.exceptions
import azure.servicebus.exceptions
import isodate
from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
                              ServiceBusReceiveMode, ServiceBusReceiver,
                              ServiceBusSender)
from azure.servicebus.management import ServiceBusAdministrationClient

try:
    from azure.identity import (DefaultAzureCredential,
                                ManagedIdentityCredential)
except ImportError:
    DefaultAzureCredential = None
    ManagedIdentityCredential = None

from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property

from . import virtual

# Dots are replaced by dash; dashes and underscores are kept; all other
# punctuation is replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
    ord('.'): ord('-'),
    **{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE}
}
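
# For example (illustrative): entity_name('celery.pidbox!reply') yields
# 'celery-pidbox_reply' under this table.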


class SendReceive:
    """Container for Sender and Receiver."""

    def __init__(self,
                 receiver: ServiceBusReceiver | None = None,
                 sender: ServiceBusSender | None = None):
        self.receiver: ServiceBusReceiver = receiver
        self.sender: ServiceBusSender = sender

    def close(self) -> None:
        if self.receiver:
            self.receiver.close()
            self.receiver = None
        if self.sender:
            self.sender.close()
            self.sender = None


class Channel(virtual.Channel):
    """Azure Service Bus channel."""

    default_wait_time_seconds: int = 5  # in seconds
    default_peek_lock_seconds: int = 60  # in seconds (default 60, max 300)
    # in seconds (is the default from service bus repo)
    default_uamqp_keep_alive_interval: int = 30
    # number of retries (is the default from service bus repo)
    default_retry_total: int = 3
    # exponential backoff factor (is the default from service bus repo)
    default_retry_backoff_factor: float = 0.8
    # Max time to backoff (is the default from service bus repo)
    default_retry_backoff_max: int = 120
    domain_format: str = 'kombu%(vhost)s'
    _queue_cache: dict[str, SendReceive] = {}
    _noack_queues: set[str] = set()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._namespace = None
        self._policy = None
        self._sas_key = None
        self._connection_string = None

        self._try_parse_connection_string()

        self.qos.restore_at_shutdown = False

    def _try_parse_connection_string(self) -> None:
        self._namespace, self._credential = Transport.parse_uri(
            self.conninfo.hostname)

        if (
            DefaultAzureCredential is not None
            and isinstance(self._credential, DefaultAzureCredential)
        ) or (
            ManagedIdentityCredential is not None
            and isinstance(self._credential, ManagedIdentityCredential)
        ):
            return None

        if ":" in self._credential:
            self._policy, self._sas_key = self._credential.split(':', 1)

        conn_dict = {
            'Endpoint': 'sb://' + self._namespace,
            'SharedAccessKeyName': self._policy,
            'SharedAccessKey': self._sas_key,
        }
        self._connection_string = ';'.join(
            [key + '=' + value for key, value in conn_dict.items()])

    def basic_consume(self, queue, no_ack, *args, **kwargs):
        if no_ack:
            self._noack_queues.add(queue)
        return super().basic_consume(
            queue, no_ack, *args, **kwargs
        )

    def basic_cancel(self, consumer_tag):
        if consumer_tag in self._consumers:
            queue = self._tag_to_queue[consumer_tag]
            self._noack_queues.discard(queue)
        return super().basic_cancel(consumer_tag)

    def _add_queue_to_cache(
            self, name: str,
            receiver: ServiceBusReceiver | None = None,
            sender: ServiceBusSender | None = None
    ) -> SendReceive:
        if name in self._queue_cache:
            obj = self._queue_cache[name]
            obj.sender = obj.sender or sender
            obj.receiver = obj.receiver or receiver
        else:
            obj = SendReceive(receiver, sender)
            self._queue_cache[name] = obj
        return obj

    def _get_asb_sender(self, queue: str) -> SendReceive:
        queue_obj = self._queue_cache.get(queue, None)
        if queue_obj is None or queue_obj.sender is None:
            sender = self.queue_service.get_queue_sender(
                queue, keep_alive=self.uamqp_keep_alive_interval)
            queue_obj = self._add_queue_to_cache(queue, sender=sender)
        return queue_obj

    def _get_asb_receiver(
            self, queue: str,
            recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
            queue_cache_key: str | None = None) -> SendReceive:
        cache_key = queue_cache_key or queue
        queue_obj = self._queue_cache.get(cache_key, None)
        if queue_obj is None or queue_obj.receiver is None:
            receiver = self.queue_service.get_queue_receiver(
                queue_name=queue, receive_mode=recv_mode,
                keep_alive=self.uamqp_keep_alive_interval)
            queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
        return queue_obj

    def entity_name(
            self, name: str, table: dict[int, int] | None = None) -> str:
        """Format AMQP queue name into a valid ServiceBus queue name."""
        return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)

    def _restore(self, message: virtual.base.Message) -> None:
        # Not needed, as ASB handles unacked messages itself:
        # Remove 'azure_message' as it's not JSON serializable
        # message.delivery_info.pop('azure_message', None)
        # super()._restore(message)
        pass

    def _new_queue(self, queue: str, **kwargs) -> SendReceive:
        """Ensure a queue exists in ServiceBus."""
        queue = self.entity_name(self.queue_name_prefix + queue)

        try:
            return self._queue_cache[queue]
        except KeyError:
            # Convert seconds into ISO 8601 duration format,
            # e.g. 66 seconds -> 'PT1M6S'.
            lock_duration = isodate.duration_isoformat(
                isodate.Duration(seconds=self.peek_lock_seconds))
            try:
                self.queue_mgmt_service.create_queue(
                    queue_name=queue, lock_duration=lock_duration)
            except azure.core.exceptions.ResourceExistsError:
                pass
            return self._add_queue_to_cache(queue)

    def _delete(self, queue: str, *args, **kwargs) -> None:
        """Delete queue by name."""
        queue = self.entity_name(self.queue_name_prefix + queue)

        self.queue_mgmt_service.delete_queue(queue)
        send_receive_obj = self._queue_cache.pop(queue, None)
        if send_receive_obj:
            send_receive_obj.close()

    def _put(self, queue: str, message, **kwargs) -> None:
        """Put message onto queue."""
        queue = self.entity_name(self.queue_name_prefix + queue)
        msg = ServiceBusMessage(dumps(message))

        queue_obj = self._get_asb_sender(queue)
        queue_obj.sender.send_messages(msg)

    def _get(
        self, queue: str,
        timeout: float | int | None = None
    ) -> dict[str, Any]:
        """Try to retrieve a single message off ``queue``."""
        # If we're not ack'ing for this queue, just change receive_mode
        recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \
            if queue in self._noack_queues else ServiceBusReceiveMode.PEEK_LOCK

        queue = self.entity_name(self.queue_name_prefix + queue)

        queue_obj = self._get_asb_receiver(queue, recv_mode)
        messages = queue_obj.receiver.receive_messages(
            max_message_count=1,
            max_wait_time=timeout or self.wait_time_seconds)

        if not messages:
            raise Empty()

        # message.body is either bytes or generator[bytes]
        message = messages[0]
        if not isinstance(message.body, bytes):
            body = b''.join(message.body)
        else:
            body = message.body

        msg = loads(bytes_to_str(body))
        msg['properties']['delivery_info']['azure_message'] = message
        msg['properties']['delivery_info']['azure_queue_name'] = queue

        return msg

def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
|
||||
try:
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
except KeyError:
|
||||
super().basic_ack(delivery_tag)
|
||||
else:
|
||||
queue = delivery_info['azure_queue_name']
|
||||
# recv_mode is PEEK_LOCK when ack'ing messages
|
||||
queue_obj = self._get_asb_receiver(queue)
|
||||
|
||||
try:
|
||||
queue_obj.receiver.complete_message(
|
||||
delivery_info['azure_message'])
|
||||
except azure.servicebus.exceptions.MessageAlreadySettled:
|
||||
super().basic_ack(delivery_tag)
|
||||
except Exception:
|
||||
super().basic_reject(delivery_tag)
|
||||
else:
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue: str) -> int:
|
||||
"""Return the number of messages in a queue."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
|
||||
|
||||
return props.total_message_count
|
||||
|
||||
def _purge(self, queue) -> int:
|
||||
"""Delete all current messages in a queue."""
|
||||
# Azure doesn't provide a purge api yet
|
||||
n = 0
|
||||
max_purge_count = 10
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
# By default all the receivers will be in PEEK_LOCK receive mode
|
||||
queue_obj = self._queue_cache.get(queue, None)
|
||||
if queue not in self._noack_queues or \
|
||||
queue_obj is None or queue_obj.receiver is None:
|
||||
queue_obj = self._get_asb_receiver(
|
||||
queue,
|
||||
ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue
|
||||
)
|
||||
|
||||
while True:
|
||||
messages = queue_obj.receiver.receive_messages(
|
||||
max_message_count=max_purge_count,
|
||||
max_wait_time=0.2
|
||||
)
|
||||
n += len(messages)
|
||||
|
||||
if len(messages) < max_purge_count:
|
||||
break
|
||||
|
||||
return n
|
||||
|
||||
def close(self) -> None:
|
||||
# receivers and senders spawn threads so clean them up
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
for queue_obj in self._queue_cache.values():
|
||||
queue_obj.close()
|
||||
self._queue_cache.clear()
|
||||
|
||||
if self.connection is not None:
|
||||
self.connection.close_channel(self)
|
||||
|
||||
@cached_property
|
||||
def queue_service(self) -> ServiceBusClient:
|
||||
if self._connection_string:
|
||||
return ServiceBusClient.from_connection_string(
|
||||
self._connection_string,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
)
|
||||
|
||||
return ServiceBusClient(
|
||||
self._namespace,
|
||||
self._credential,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
||||
if self._connection_string:
|
||||
return ServiceBusAdministrationClient.from_connection_string(
|
||||
self._connection_string
|
||||
)
|
||||
|
||||
return ServiceBusAdministrationClient(
|
||||
self._namespace, self._credential
|
||||
)
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self) -> str:
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self) -> int:
|
||||
return self.transport_options.get('wait_time_seconds',
|
||||
self.default_wait_time_seconds)
|
||||
|
||||
@cached_property
|
||||
def peek_lock_seconds(self) -> int:
|
||||
return min(self.transport_options.get('peek_lock_seconds',
|
||||
self.default_peek_lock_seconds),
|
||||
300) # Limit upper bounds to 300
|
||||
|
||||
@cached_property
|
||||
def uamqp_keep_alive_interval(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'uamqp_keep_alive_interval',
|
||||
self.default_uamqp_keep_alive_interval
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def retry_total(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'retry_total', self.default_retry_total)
|
||||
|
||||
@cached_property
|
||||
def retry_backoff_factor(self) -> float:
|
||||
return self.transport_options.get(
|
||||
'retry_backoff_factor', self.default_retry_backoff_factor)
|
||||
|
||||
@cached_property
|
||||
def retry_backoff_max(self) -> int:
|
||||
return self.transport_options.get(
|
||||
'retry_backoff_max', self.default_retry_backoff_max)
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Azure Service Bus transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval = 1
|
||||
default_port = None
|
||||
can_parse_url = True
|
||||
|
||||
@staticmethod
|
||||
def parse_uri(uri: str) -> tuple[str, str | DefaultAzureCredential |
|
||||
ManagedIdentityCredential]:
|
||||
# URL like:
|
||||
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
|
||||
# urllib parse does not work as the sas key could contain a slash
|
||||
# e.g.: azureservicebus://rootpolicy:some/key@somenamespace
|
||||
|
||||
# > 'rootpolicy:some/key@somenamespace'
|
||||
uri = uri.replace('azureservicebus://', '')
|
||||
# > 'rootpolicy:some/key', 'somenamespace'
|
||||
credential, namespace = uri.rsplit('@', 1)
|
||||
|
||||
if not namespace.endswith('.net'):
|
||||
namespace += '.servicebus.windows.net'
|
||||
|
||||
if "DefaultAzureCredential".lower() == credential.lower():
|
||||
if DefaultAzureCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'DefaultAzureCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = DefaultAzureCredential()
|
||||
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||
if ManagedIdentityCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'ManagedIdentityCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = ManagedIdentityCredential()
|
||||
else:
|
||||
# > 'rootpolicy', 'some/key'
|
||||
policy, sas_key = credential.split(':', 1)
|
||||
credential = f"{policy}:{sas_key}"
|
||||
|
||||
# Validate ASB connection string
|
||||
if not all([namespace, credential]):
|
||||
raise ValueError(
|
||||
'Need a URI like '
|
||||
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
|
||||
'or the azure Endpoint connection string'
|
||||
)
|
||||
|
||||
return namespace, credential
|
||||
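
    # Illustrative behaviour of ``parse_uri`` (the policy, key and
    # namespace values are hypothetical, taken from the comments above):
    #
    #     >>> Transport.parse_uri(
    #     ...     'azureservicebus://rootpolicy:some/key@somenamespace')
    #     ('somenamespace.servicebus.windows.net', 'rootpolicy:some/key')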

    @classmethod
    def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
        namespace, credential = cls.parse_uri(uri)
        if isinstance(credential, str) and ":" in credential:
            policy, sas_key = credential.split(':', 1)
            return 'azureservicebus://{}:{}@{}'.format(
                policy,
                sas_key if include_password else mask,
                namespace
            )

        return 'azureservicebus://{}@{}'.format(
            credential.__class__.__name__,
            namespace
        )
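
    # A sketch of the masked output, following the code above (the URI is
    # hypothetical, not from the original module):
    #
    #     >>> Transport.as_uri('azureservicebus://rootpolicy:some/key@somenamespace')
    #     'azureservicebus://rootpolicy:**@somenamespace.servicebus.windows.net'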
@@ -0,0 +1,263 @@
"""Azure Storage Queues transport module for kombu.
|
||||
|
||||
More information about Azure Storage Queues:
|
||||
https://azure.microsoft.com/en-us/services/storage/queues/
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: *Unreviewed*
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
Note that if the access key for the storage account contains a forward slash
|
||||
(``/``), it will have to be regenerated before it can be used in the connection
|
||||
URL.
|
||||
|
||||
.. code-block::
|
||||
|
||||
azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
If you wish to use an `Azure Managed Identity` you may use the
|
||||
``DefaultAzureCredential`` format of the connection string which will use
|
||||
``DefaultAzureCredential`` class in the azure-identity package. You may want to
|
||||
read the `azure-identity documentation` for more information on how the
|
||||
``DefaultAzureCredential`` works.
|
||||
|
||||
.. _azure-identity documentation:
|
||||
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
|
||||
.. _Azure Managed Identity:
|
||||
https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
||||
* ``queue_name_prefix``
|
||||
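
Below is a minimal usage sketch (the account URL and key are hypothetical,
and not part of the original module):

.. code-block:: python

    from kombu import Connection

    conn = Connection(
        'azurestoragequeues://<ACCESS_KEY>@https://example.queue.core.windows.net',
        transport_options={'queue_name_prefix': 'kombu-'},
    )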
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from azure.core.exceptions import ResourceExistsError
|
||||
|
||||
from kombu.utils.encoding import safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
from azure.storage.queue import QueueServiceClient
|
||||
except ImportError: # pragma: no cover
|
||||
QueueServiceClient = None
|
||||
|
||||
try:
|
||||
from azure.identity import (DefaultAzureCredential,
|
||||
ManagedIdentityCredential)
|
||||
except ImportError:
|
||||
DefaultAzureCredential = None
|
||||
ManagedIdentityCredential = None
|
||||
|
||||
# Azure storage queues allow only alphanumeric and dashes
|
||||
# so, replace everything with a dash
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord(c): 0x2d for c in string.punctuation
|
||||
}
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Azure Storage Queues channel."""
|
||||
|
||||
domain_format: str = 'kombu%(vhost)s'
|
||||
_queue_service: QueueServiceClient | None = None
|
||||
_queue_name_cache: dict[Any, Any] = {}
|
||||
no_ack: bool = True
|
||||
_noack_queues: set[Any] = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if QueueServiceClient is None:
|
||||
raise ImportError('Azure Storage Queues transport requires the '
|
||||
'azure-storage-queue library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._credential, self._url = Transport.parse_uri(
|
||||
self.conninfo.hostname
|
||||
)
|
||||
|
||||
for queue in self.queue_service.list_queues():
|
||||
self._queue_name_cache[queue['name']] = queue
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
|
||||
return super().basic_consume(queue, no_ack,
|
||||
*args, **kwargs)
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str:
|
||||
"""Format AMQP queue name into a valid Azure Storage Queue name."""
|
||||
return str(safe_str(name)).translate(table)
|
||||
|
||||
def _ensure_queue(self, queue):
|
||||
"""Ensure a queue exists."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
try:
|
||||
q = self._queue_service.get_queue_client(
|
||||
queue=self._queue_name_cache[queue]
|
||||
)
|
||||
except KeyError:
|
||||
try:
|
||||
q = self.queue_service.create_queue(queue)
|
||||
except ResourceExistsError:
|
||||
q = self._queue_service.get_queue_client(queue=queue)
|
||||
|
||||
self._queue_name_cache[queue] = q.get_queue_properties()
|
||||
return q
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete queue by name."""
|
||||
queue_name = self.entity_name(queue)
|
||||
self._queue_name_cache.pop(queue_name, None)
|
||||
self.queue_service.delete_queue(queue_name)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put message onto queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
encoded_message = dumps(message)
|
||||
q.send_message(encoded_message)
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
q = self._ensure_queue(queue)
|
||||
|
||||
messages = q.receive_messages(messages_per_page=1, timeout=timeout)
|
||||
try:
|
||||
message = next(messages)
|
||||
except StopIteration:
|
||||
raise Empty()
|
||||
|
||||
content = loads(message.content)
|
||||
|
||||
q.delete_message(message=message)
|
||||
|
||||
return content
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the number of messages in a queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
return q.get_queue_properties().approximate_message_count
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Delete all current messages in a queue."""
|
||||
q = self._ensure_queue(queue)
|
||||
n = self._size(q.queue_name)
|
||||
q.clear_messages()
|
||||
return n
|
||||
|
||||
@property
|
||||
def queue_service(self) -> QueueServiceClient:
|
||||
if self._queue_service is None:
|
||||
self._queue_service = QueueServiceClient(
|
||||
account_url=self._url, credential=self._credential
|
||||
)
|
||||
|
||||
return self._queue_service
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
||||
@property
|
||||
def transport_options(self):
|
||||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self) -> str:
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""Azure Storage Queues transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
polling_interval: int = 1
|
||||
default_port: int | None = None
|
||||
can_parse_url: bool = True
|
||||
|
||||
@staticmethod
|
||||
def parse_uri(uri: str) -> tuple[str | dict, str]:
|
||||
# URL like:
|
||||
# azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
|
||||
# azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
|
||||
|
||||
# urllib parse does not work as the sas key could contain a slash
|
||||
# e.g.: azurestoragequeues://some/key@someurl
|
||||
|
||||
try:
|
||||
# > 'some/key@url'
|
||||
uri = uri.replace('azurestoragequeues://', '')
|
||||
# > 'some/key', 'url'
|
||||
credential, url = uri.rsplit('@', 1)
|
||||
|
||||
if "DefaultAzureCredential".lower() == credential.lower():
|
||||
if DefaultAzureCredential is None:
|
||||
raise ImportError('Azure Storage Queues transport with a '
|
||||
'DefaultAzureCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = DefaultAzureCredential()
|
||||
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||
if ManagedIdentityCredential is None:
|
||||
raise ImportError('Azure Storage Queues transport with a '
|
||||
'ManagedIdentityCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = ManagedIdentityCredential()
|
||||
elif "devstoreaccount1" in url and ".core.windows.net" not in url:
|
||||
# parse credential as a dict if Azurite is being used
|
||||
credential = {
|
||||
"account_name": "devstoreaccount1",
|
||||
"account_key": credential,
|
||||
}
|
||||
|
||||
# Validate parameters
|
||||
assert all([credential, url])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
'Need a URI like '
|
||||
'azurestoragequeues://{SAS or access key}@{URL}, '
|
||||
'azurestoragequeues://DefaultAzureCredential@{URL}, '
|
||||
', or '
|
||||
'azurestoragequeues://ManagedIdentityCredential@{URL}'
|
||||
)
|
||||
|
||||
return credential, url
|
||||
|
||||
@classmethod
|
||||
def as_uri(
|
||||
cls, uri: str, include_password: bool = False, mask: str = "**"
|
||||
) -> str:
|
||||
credential, url = cls.parse_uri(uri)
|
||||
return "azurestoragequeues://{}@{}".format(
|
||||
credential if include_password else mask, url
|
||||
)
|
||||
271
venv/lib/python3.12/site-packages/kombu/transport/base.py
Normal file
@@ -0,0 +1,271 @@
"""Base transport interface."""
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from amqp.exceptions import RecoverableConnectionError
|
||||
|
||||
from kombu.exceptions import ChannelError, ConnectionError
|
||||
from kombu.message import Message
|
||||
from kombu.utils.functional import dictfilter
|
||||
from kombu.utils.objects import cached_property
|
||||
from kombu.utils.time import maybe_s_to_ms
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = ('Message', 'StdChannel', 'Management', 'Transport')
|
||||
|
||||
RABBITMQ_QUEUE_ARGUMENTS = {
|
||||
'expires': ('x-expires', maybe_s_to_ms),
|
||||
'message_ttl': ('x-message-ttl', maybe_s_to_ms),
|
||||
'max_length': ('x-max-length', int),
|
||||
'max_length_bytes': ('x-max-length-bytes', int),
|
||||
'max_priority': ('x-max-priority', int),
|
||||
} # type: Mapping[str, Tuple[str, Callable]]
|
||||
|
||||
|
||||
def to_rabbitmq_queue_arguments(arguments, **options):
|
||||
# type: (Mapping, **Any) -> Dict
|
||||
"""Convert queue arguments to RabbitMQ queue arguments.
|
||||
|
||||
This is the implementation for Channel.prepare_queue_arguments
|
||||
for AMQP-based transports. It's used by both the pyamqp and librabbitmq
|
||||
transports.
|
||||
|
||||
Arguments:
|
||||
arguments (Mapping):
|
||||
User-supplied arguments (``Queue.queue_arguments``).
|
||||
|
||||
Keyword Arguments:
|
||||
expires (float): Queue expiry time in seconds.
|
||||
This will be converted to ``x-expires`` in int milliseconds.
|
||||
message_ttl (float): Message TTL in seconds.
|
||||
This will be converted to ``x-message-ttl`` in int milliseconds.
|
||||
max_length (int): Max queue length (in number of messages).
|
||||
This will be converted to ``x-max-length`` int.
|
||||
max_length_bytes (int): Max queue size in bytes.
|
||||
This will be converted to ``x-max-length-bytes`` int.
|
||||
max_priority (int): Max priority steps for queue.
|
||||
This will be converted to ``x-max-priority`` int.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict: RabbitMQ compatible queue arguments.
|
||||
"""
|
||||
prepared = dictfilter(dict(
|
||||
_to_rabbitmq_queue_argument(key, value)
|
||||
for key, value in options.items()
|
||||
))
|
||||
return dict(arguments, **prepared) if prepared else arguments
|
||||
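
# A doctest-style sketch of the conversion above (the values are
# hypothetical, not part of the original module):
#
#     >>> to_rabbitmq_queue_arguments({}, expires=60, max_length=10)
#     {'x-expires': 60000, 'x-max-length': 10}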


def _to_rabbitmq_queue_argument(key, value):
    # type: (str, Any) -> Tuple[str, Any]
    opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key]
    return opt, typ(value) if value is not None else value


def _LeftBlank(obj, method):
    return NotImplementedError(
        'Transport {0.__module__}.{0.__name__} does not implement {1}'.format(
            obj.__class__, method))


class StdChannel:
    """Standard channel base class."""

    no_ack_consumers = None

    def Consumer(self, *args, **kwargs):
        from kombu.messaging import Consumer
        return Consumer(self, *args, **kwargs)

    def Producer(self, *args, **kwargs):
        from kombu.messaging import Producer
        return Producer(self, *args, **kwargs)

    def get_bindings(self):
        raise _LeftBlank(self, 'get_bindings')

    def after_reply_message_received(self, queue):
        """Callback called after RPC reply received.

        Notes
        -----
        Reply queue semantics: can be used to delete the queue
        after transient reply message received.
        """

    def prepare_queue_arguments(self, arguments, **kwargs):
        return arguments

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None
    ) -> None:
        self.close()


class Management:
    """AMQP Management API (incomplete)."""

    def __init__(self, transport):
        self.transport = transport

    def get_bindings(self):
        raise _LeftBlank(self, 'get_bindings')


class Implements(dict):
    """Helper class used to define transport features."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value

    def extend(self, **kwargs):
        return self.__class__(self, **kwargs)


default_transport_capabilities = Implements(
    asynchronous=False,
    exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']),
    heartbeats=False,
)
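
# A minimal sketch of how a transport declares its capabilities (the
# ``asynchronous=True`` override here is hypothetical):
#
#     >>> implements = default_transport_capabilities.extend(asynchronous=True)
#     >>> implements.asynchronous
#     True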


class Transport:
    """Base class for transports."""

    Management = Management

    #: The :class:`~kombu.Connection` owning this instance.
    client = None

    #: Set to True if :class:`~kombu.Connection` should pass the URL
    #: unmodified.
    can_parse_url = False

    #: Default port used when no port has been specified.
    default_port = None

    #: Tuple of errors that can happen due to connection failure.
    connection_errors = (ConnectionError,)

    #: Tuple of errors that can happen due to channel/method failure.
    channel_errors = (ChannelError,)

    #: Type of driver, can be used to separate transports
    #: using the AMQP protocol (driver_type: 'amqp'),
    #: Redis (driver_type: 'redis'), etc...
    driver_type = 'N/A'

    #: Name of driver library (e.g. 'py-amqp', 'redis').
    driver_name = 'N/A'

    __reader = None

    implements = default_transport_capabilities.extend()

    def __init__(self, client, **kwargs):
        self.client = client

    def establish_connection(self):
        raise _LeftBlank(self, 'establish_connection')

    def close_connection(self, connection):
        raise _LeftBlank(self, 'close_connection')

    def create_channel(self, connection):
        raise _LeftBlank(self, 'create_channel')

    def close_channel(self, connection):
        raise _LeftBlank(self, 'close_channel')

    def drain_events(self, connection, **kwargs):
        raise _LeftBlank(self, 'drain_events')

    def heartbeat_check(self, connection, rate=2):
        pass

    def driver_version(self):
        return 'N/A'

    def get_heartbeat_interval(self, connection):
        return 0

    def register_with_event_loop(self, connection, loop):
        pass

    def unregister_from_event_loop(self, connection, loop):
        pass

    def verify_connection(self, connection):
        return True

    def _make_reader(self, connection, timeout=socket.timeout,
                     error=socket.error, _unavail=(errno.EAGAIN, errno.EINTR)):
        drain_events = connection.drain_events

        def _read(loop):
            if not connection.connected:
                raise RecoverableConnectionError('Socket was disconnected')
            try:
                drain_events(timeout=0)
            except timeout:
                return
            except error as exc:
                if exc.errno in _unavail:
                    return
                raise
            loop.call_soon(_read, loop)

        return _read

    def qos_semantics_matches_spec(self, connection):
        return True

    def on_readable(self, connection, loop):
        reader = self.__reader
        if reader is None:
            reader = self.__reader = self._make_reader(connection)
        reader(loop)

    def as_uri(self, uri: str, include_password=False, mask='**') -> str:
        """Customise the display format of the URI."""
        raise NotImplementedError()

    @property
    def default_connection_params(self):
        return {}

    def get_manager(self, *args, **kwargs):
        return self.Management(self)

    @cached_property
    def manager(self):
        return self.get_manager()

    @property
    def supports_heartbeats(self):
        return self.implements.heartbeats

    @property
    def supports_ev(self):
        return self.implements.asynchronous
@@ -0,0 +1,380 @@
"""confluent-kafka transport module for Kombu.
|
||||
|
||||
Kafka transport using confluent-kafka library.
|
||||
|
||||
**References**
|
||||
|
||||
- http://docs.confluent.io/current/clients/confluent-kafka-python
|
||||
|
||||
**Limitations**
|
||||
|
||||
The confluent-kafka transport does not support PyPy environment.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT]
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``connection_wait_time_seconds`` - Time in seconds to wait for connection
|
||||
to succeed. Default ``5``
|
||||
* ``wait_time_seconds`` - Time in seconds to wait to receive messages.
|
||||
Default ``5``
|
||||
* ``security_protocol`` - Protocol used to communicate with broker.
|
||||
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
|
||||
an explanation of valid values. Default ``plaintext``
|
||||
* ``sasl_mechanism`` - SASL mechanism to use for authentication.
|
||||
Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
|
||||
an explanation of valid values.
|
||||
* ``num_partitions`` - Number of partitions to create. Default ``1``
|
||||
* ``replication_factor`` - Replication factor of partitions. Default ``1``
|
||||
* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs
|
||||
correspond with attributes in the
|
||||
http://kafka.apache.org/documentation.html#topicconfigs.
|
||||
* ``kafka_common_config`` - Configuration applied to producer, consumer and
|
||||
admin client. Must be a dict whose key-value pairs correspond with attributes
|
||||
in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_producer_config`` - Producer configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose
|
||||
key-value pairs correspond with attributes in the
|
||||
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
|
||||
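
A minimal usage sketch (the broker address and the option values are
hypothetical):

.. code-block:: python

    from kombu import Connection

    conn = Connection(
        'confluentkafka://localhost:9092',
        transport_options={'num_partitions': 3, 'replication_factor': 1},
    )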
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import Empty
|
||||
|
||||
from kombu.transport import virtual
|
||||
from kombu.utils import cached_property
|
||||
from kombu.utils.encoding import str_to_bytes
|
||||
from kombu.utils.json import dumps, loads
|
||||
|
||||
try:
|
||||
import confluent_kafka
|
||||
from confluent_kafka import (Consumer, KafkaException, Producer,
|
||||
TopicPartition)
|
||||
from confluent_kafka.admin import AdminClient, NewTopic
|
||||
|
||||
KAFKA_CONNECTION_ERRORS = ()
|
||||
KAFKA_CHANNEL_ERRORS = ()
|
||||
|
||||
except ImportError:
|
||||
confluent_kafka = None
|
||||
KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = ()
|
||||
|
||||
from kombu.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_PORT = 9092
|
||||
|
||||
|
||||
class NoBrokersAvailable(KafkaException):
|
||||
"""Kafka broker is not available exception."""
|
||||
|
||||
retriable = True
|
||||
|
||||
|
||||
class Message(virtual.Message):
|
||||
"""Message object."""
|
||||
|
||||
def __init__(self, payload, channel=None, **kwargs):
|
||||
self.topic = payload.get('topic')
|
||||
super().__init__(payload, channel=channel, **kwargs)
|
||||
|
||||
|
||||
class QoS(virtual.QoS):
|
||||
"""Quality of Service guarantees."""
|
||||
|
||||
_not_yet_acked = {}
|
||||
|
||||
def can_consume(self):
|
||||
"""Return true if the channel can be consumed from.
|
||||
|
||||
:returns: True, if this QoS object can accept a message.
|
||||
:rtype: bool
|
||||
"""
|
||||
return not self.prefetch_count or len(self._not_yet_acked) < self \
|
||||
.prefetch_count
|
||||
|
||||
def can_consume_max_estimate(self):
|
||||
if self.prefetch_count:
|
||||
return self.prefetch_count - len(self._not_yet_acked)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def append(self, message, delivery_tag):
|
||||
self._not_yet_acked[delivery_tag] = message
|
||||
|
||||
def get(self, delivery_tag):
|
||||
return self._not_yet_acked[delivery_tag]
|
||||
|
||||
def ack(self, delivery_tag):
|
||||
if delivery_tag not in self._not_yet_acked:
|
||||
return
|
||||
message = self._not_yet_acked.pop(delivery_tag)
|
||||
consumer = self.channel._get_consumer(message.topic)
|
||||
consumer.commit()
|
||||
|
||||
def reject(self, delivery_tag, requeue=False):
|
||||
"""Reject a message by delivery tag.
|
||||
|
||||
If requeue is True, then the last consumed message is reverted so
|
||||
it'll be refetched on the next attempt.
|
||||
If False, that message is consumed and ignored.
|
||||
"""
|
||||
if requeue:
|
||||
message = self._not_yet_acked.pop(delivery_tag)
|
||||
consumer = self.channel._get_consumer(message.topic)
|
||||
for assignment in consumer.assignment():
|
||||
topic_partition = TopicPartition(message.topic,
|
||||
assignment.partition)
|
||||
[committed_offset] = consumer.committed([topic_partition])
|
||||
consumer.seek(committed_offset)
|
||||
else:
|
||||
self.ack(delivery_tag)
|
||||
|
||||
def restore_unacked_once(self, stderr=None):
|
||||
pass
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Kafka Channel."""
|
||||
|
||||
QoS = QoS
|
||||
Message = Message
|
||||
|
||||
default_wait_time_seconds = 5
|
||||
default_connection_wait_time_seconds = 5
|
||||
_client = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._kafka_consumers = {}
|
||||
self._kafka_producers = {}
|
||||
|
||||
self._client = self._open()
|
||||
|
||||
def sanitize_queue_name(self, queue):
|
||||
"""Need to sanitize the name, celery sometimes pushes in @ signs."""
|
||||
return str(queue).replace('@', '')
|
||||
|
||||
def _get_producer(self, queue):
|
||||
"""Create/get a producer instance for the given topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
producer = self._kafka_producers.get(queue, None)
|
||||
if producer is None:
|
||||
producer = Producer({
|
||||
**self.common_config,
|
||||
**(self.options.get('kafka_producer_config') or {}),
|
||||
})
|
||||
self._kafka_producers[queue] = producer
|
||||
|
||||
return producer
|
||||
|
||||
def _get_consumer(self, queue):
|
||||
"""Create/get a consumer instance for the given topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
consumer = self._kafka_consumers.get(queue, None)
|
||||
if consumer is None:
|
||||
consumer = Consumer({
|
||||
'group.id': f'{queue}-consumer-group',
|
||||
'auto.offset.reset': 'earliest',
|
||||
'enable.auto.commit': False,
|
||||
**self.common_config,
|
||||
**(self.options.get('kafka_consumer_config') or {}),
|
||||
})
|
||||
consumer.subscribe([queue])
|
||||
self._kafka_consumers[queue] = consumer
|
||||
|
||||
return consumer
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
"""Put a message on the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
producer = self._get_producer(queue)
|
||||
producer.produce(queue, str_to_bytes(dumps(message)))
|
||||
producer.flush()
|
||||
|
||||
def _get(self, queue, **kwargs):
|
||||
"""Get a message from the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
consumer = self._get_consumer(queue)
|
||||
message = None
|
||||
|
||||
try:
|
||||
message = consumer.poll(self.wait_time_seconds)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
if not message:
|
||||
raise Empty()
|
||||
|
||||
error = message.error()
|
||||
if error:
|
||||
logger.error(error)
|
||||
raise Empty()
|
||||
|
||||
return {**loads(message.value()), 'topic': message.topic()}
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
"""Delete a queue/topic."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
self._kafka_consumers[queue].close()
|
||||
self._kafka_consumers.pop(queue)
|
||||
self.client.delete_topics([queue])
|
||||
|
||||
def _size(self, queue):
|
||||
"""Get the number of pending messages in the topic/queue."""
|
||||
queue = self.sanitize_queue_name(queue)
|
||||
|
||||
consumer = self._kafka_consumers.get(queue, None)
|
||||
if consumer is None:
|
||||
return 0
|
||||
|
||||
size = 0
|
||||
for assignment in consumer.assignment():
|
||||
topic_partition = TopicPartition(queue, assignment.partition)
|
||||
(_, end_offset) = consumer.get_watermark_offsets(topic_partition)
|
||||
[committed_offset] = consumer.committed([topic_partition])
|
||||
size += end_offset - committed_offset.offset
|
||||
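        # Worked example (hypothetical offsets): with end_offset 10 and a
        # committed offset of 7, this partition contributes 3 pending
        # messages to the total.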
        return size

    def _new_queue(self, queue, **kwargs):
        """Create a new topic if it does not exist."""
        queue = self.sanitize_queue_name(queue)
        if queue in self.client.list_topics().topics:
            return

        topic = NewTopic(
            queue,
            num_partitions=self.options.get('num_partitions', 1),
            replication_factor=self.options.get('replication_factor', 1),
            config=self.options.get('topic_config', {})
        )
        self.client.create_topics(new_topics=[topic])

    def _has_queue(self, queue, **kwargs):
        """Check if a topic already exists."""
        queue = self.sanitize_queue_name(queue)
        return queue in self.client.list_topics().topics

    def _open(self):
        client = AdminClient({
            **self.common_config,
            **(self.options.get('kafka_admin_config') or {}),
        })

        try:
            # seems to be the only way to check the connection
            client.list_topics(timeout=self.wait_time_seconds)
        except confluent_kafka.KafkaException as e:
            raise NoBrokersAvailable(e)

        return client

    @property
    def client(self):
        if self._client is None:
            self._client = self._open()
        return self._client

    @property
    def options(self):
        return self.connection.client.transport_options

    @property
    def conninfo(self):
        return self.connection.client

    @cached_property
    def wait_time_seconds(self):
        return self.options.get(
            'wait_time_seconds', self.default_wait_time_seconds
        )

    @cached_property
    def connection_wait_time_seconds(self):
        return self.options.get(
            'connection_wait_time_seconds',
            self.default_connection_wait_time_seconds,
        )

    @cached_property
    def common_config(self):
        conninfo = self.connection.client
        config = {
            'bootstrap.servers':
                f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}',
        }
        security_protocol = self.options.get('security_protocol', 'plaintext')
        if security_protocol.lower() != 'plaintext':
            config.update({
                'security.protocol': security_protocol,
                'sasl.username': conninfo.userid,
                'sasl.password': conninfo.password,
                'sasl.mechanism': self.options.get('sasl_mechanism'),
            })

        config.update(self.options.get('kafka_common_config') or {})
        return config

    def close(self):
        super().close()
        self._kafka_producers = {}

        for consumer in self._kafka_consumers.values():
            consumer.close()

        self._kafka_consumers = {}


class Transport(virtual.Transport):
    """Kafka Transport."""

    Channel = Channel

    default_port = DEFAULT_PORT

    driver_type = 'kafka'
    driver_name = 'confluentkafka'

    recoverable_connection_errors = (
        NoBrokersAvailable,
    )

    def __init__(self, client, **kwargs):
        if confluent_kafka is None:
            raise ImportError('The confluent-kafka library is not installed')
        super().__init__(client, **kwargs)

    def driver_version(self):
        return confluent_kafka.__version__

    def establish_connection(self):
        return super().establish_connection()

    def close_connection(self, connection):
        return super().close_connection(connection)

    def as_uri(self, uri: str, include_password=False, mask='**') -> str:
        pass
323
venv/lib/python3.12/site-packages/kombu/transport/consul.py
Normal file
@@ -0,0 +1,323 @@
"""Consul Transport module for Kombu.
|
||||
|
||||
Features
|
||||
========
|
||||
|
||||
It uses Consul.io's Key/Value store to transport messages in Queues
|
||||
|
||||
It uses python-consul for talking to Consul's HTTP API
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Native
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
consul://CONSUL_ADDRESS[:PORT]
|
||||
|
||||
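
A minimal usage sketch (the address is hypothetical):

.. code-block:: python

    from kombu import Connection

    conn = Connection('consul://localhost:8500')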
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from queue import Empty
|
||||
from time import monotonic
|
||||
|
||||
from kombu.exceptions import ChannelError
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import consul
|
||||
except ImportError:
|
||||
consul = None
|
||||
|
||||
logger = get_logger('kombu.transport.consul')
|
||||
|
||||
DEFAULT_PORT = 8500
|
||||
DEFAULT_HOST = 'localhost'
|
||||
|
||||
|
||||
class LockError(Exception):
|
||||
"""An error occurred while trying to acquire the lock."""
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Consul Channel class which talks to the Consul Key/Value store."""
|
||||
|
||||
prefix = 'kombu'
|
||||
index = None
|
||||
timeout = '10s'
|
||||
session_ttl = 30
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if consul is None:
|
||||
raise ImportError('Missing python-consul library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
port = self.connection.client.port or self.connection.default_port
|
||||
host = self.connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
|
||||
|
||||
self.queues = defaultdict(dict)
|
||||
|
||||
self.client = consul.Consul(host=host, port=int(port))
|
||||
|
||||
def _lock_key(self, queue):
|
||||
return f'{self.prefix}/{queue}.lock'
|
||||
|
||||
def _key_prefix(self, queue):
|
||||
return f'{self.prefix}/{queue}'
|
||||
|
||||
def _get_or_create_session(self, queue):
|
||||
"""Get or create consul session.
|
||||
|
||||
Try to renew the session if it exists, otherwise create a new
|
||||
session in Consul.
|
||||
|
||||
This session is used to acquire a lock inside Consul so that we achieve
|
||||
read-consistency between the nodes.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: The ID of the session.
|
||||
"""
|
||||
try:
|
||||
session_id = self.queues[queue]['session_id']
|
||||
except KeyError:
|
||||
session_id = None
|
||||
return (self._renew_existing_session(session_id)
|
||||
if session_id is not None else self._create_new_session())
|
||||
|
||||
def _renew_existing_session(self, session_id):
|
||||
logger.debug('Trying to renew existing session %s', session_id)
|
||||
session = self.client.session.renew(session_id=session_id)
|
||||
return session.get('ID')
|
||||
|
||||
def _create_new_session(self):
|
||||
logger.debug('Creating session %s with TTL %s',
|
||||
self.lock_name, self.session_ttl)
|
||||
session_id = self.client.session.create(
|
||||
name=self.lock_name, ttl=self.session_ttl)
|
||||
logger.debug('Created session %s with id %s',
|
||||
self.lock_name, session_id)
|
||||
return session_id
|
||||
|
||||
@contextmanager
|
||||
def _queue_lock(self, queue, raising=LockError):
|
||||
"""Try to acquire a lock on the Queue.
|
||||
|
||||
It does so by creating a object called 'lock' which is locked by the
|
||||
current session..
|
||||
|
||||
This way other nodes are not able to write to the lock object which
|
||||
means that they have to wait before the lock is released.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
raising (Exception): Set custom lock error class.
|
||||
|
||||
Raises
|
||||
------
|
||||
LockError: if the lock cannot be acquired.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: success?
|
||||
"""
|
||||
self._acquire_lock(queue, raising=raising)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._release_lock(queue)
|
||||
|
||||
def _acquire_lock(self, queue, raising=LockError):
|
||||
session_id = self._get_or_create_session(queue)
|
||||
lock_key = self._lock_key(queue)
|
||||
|
||||
logger.debug('Trying to create lock object %s with session %s',
|
||||
lock_key, session_id)
|
||||
|
||||
if self.client.kv.put(key=lock_key,
|
||||
acquire=session_id,
|
||||
value=self.lock_name):
|
||||
self.queues[queue]['session_id'] = session_id
|
||||
return
|
||||
logger.info('Could not acquire lock on key %s', lock_key)
|
||||
raise raising()
|
||||
|
||||
def _release_lock(self, queue):
|
||||
"""Try to release a lock.
|
||||
|
||||
It does so by simply removing the lock key in Consul.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue we want to release
|
||||
the lock from.
|
||||
"""
|
||||
logger.debug('Removing lock key %s', self._lock_key(queue))
|
||||
self.client.kv.delete(key=self._lock_key(queue))
|
||||
|
||||
def _destroy_session(self, queue):
|
||||
"""Destroy a previously created Consul session.
|
||||
|
||||
Will release all locks it still might hold.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the Queue.
|
||||
"""
|
||||
logger.debug('Destroying session %s', self.queues[queue]['session_id'])
|
||||
self.client.session.destroy(self.queues[queue]['session_id'])
|
||||
|
||||
def _new_queue(self, queue, **_):
|
||||
self.queues[queue] = {'session_id': None}
|
||||
return self.client.kv.put(key=self._key_prefix(queue), value=None)
|
||||
|
||||
def _delete(self, queue, *args, **_):
|
||||
self._destroy_session(queue)
|
||||
self.queues.pop(queue, None)
|
||||
self._purge(queue)
|
||||
|
||||
def _put(self, queue, payload, **_):
|
||||
"""Put `message` onto `queue`.
|
||||
|
||||
This simply writes a key to the K/V store of Consul
|
||||
"""
|
||||
key = '{}/msg/{}_{}'.format(
|
||||
self._key_prefix(queue),
|
||||
int(round(monotonic() * 1000)),
|
||||
uuid.uuid4(),
|
||||
)
|
||||
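        # Illustrative key layout (the timestamp and uuid are hypothetical):
        #   kombu/myqueue/msg/123456789_9b2e8c1a-5f6d-4a58-9c3e-1f2a3b4c5d6e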
        if not self.client.kv.put(key=key, value=dumps(payload), cas=0):
            raise ChannelError(f'Cannot add key {key!r} to consul')

    def _get(self, queue, timeout=None):
        """Get the first available message from the queue.

        Before it does so it acquires a lock on the Key/Value store so
        only one node reads at the same time. This is for read consistency.
        """
        with self._queue_lock(queue, raising=Empty):
            key = f'{self._key_prefix(queue)}/msg/'
            logger.debug('Fetching key %s with index %s', key, self.index)
            self.index, data = self.client.kv.get(
                key=key, recurse=True,
                index=self.index, wait=self.timeout,
            )

            try:
                if data is None:
                    raise Empty()

                logger.debug('Removing key %s with modifyindex %s',
                             data[0]['Key'], data[0]['ModifyIndex'])

                self.client.kv.delete(key=data[0]['Key'],
                                      cas=data[0]['ModifyIndex'])

                return loads(data[0]['Value'])
            except TypeError:
                pass

        raise Empty()

    def _purge(self, queue):
        self._destroy_session(queue)
        return self.client.kv.delete(
            key=f'{self._key_prefix(queue)}/msg/',
            recurse=True,
        )

    def _size(self, queue):
        size = 0
        try:
            key = f'{self._key_prefix(queue)}/msg/'
            logger.debug('Fetching key recursively %s with index %s',
                         key, self.index)
            self.index, data = self.client.kv.get(
                key=key, recurse=True,
                index=self.index, wait=self.timeout,
            )
            size = len(data)
        except TypeError:
            pass

        logger.debug('Found %s keys under %s with index %s',
                     size, key, self.index)
        return size

    @cached_property
    def lock_name(self):
        return f'{socket.gethostname()}'


class Transport(virtual.Transport):
    """Consul K/V storage Transport for Kombu."""

    Channel = Channel

    default_port = DEFAULT_PORT
    driver_type = 'consul'
    driver_name = 'consul'

    if consul:
        connection_errors = (
            virtual.Transport.connection_errors + (
                consul.ConsulException, consul.base.ConsulException
            )
        )

        channel_errors = (
            virtual.Transport.channel_errors + (
                consul.ConsulException, consul.base.ConsulException
            )
        )

    def __init__(self, *args, **kwargs):
        if consul is None:
            raise ImportError('Missing python-consul library')

        super().__init__(*args, **kwargs)

    def verify_connection(self, connection):
        port = connection.client.port or self.default_port
        host = connection.client.hostname or DEFAULT_HOST

        logger.debug('Verify Consul connection to %s:%s', host, port)

        try:
            client = consul.Consul(host=host, port=int(port))
            client.agent.self()
            return True
        except ValueError:
            pass

        return False

    def driver_version(self):
        return consul.__version__
300
venv/lib/python3.12/site-packages/kombu/transport/etcd.py
Normal file
@@ -0,0 +1,300 @@
"""Etcd Transport module for Kombu.
|
||||
|
||||
It uses Etcd as a store to transport messages in Queues
|
||||
|
||||
It uses python-etcd for talking to Etcd's HTTP API
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: *Unreviewed*
|
||||
* Supports Topic: *Unreviewed*
|
||||
* Supports Fanout: *Unreviewed*
|
||||
* Supports Priority: *Unreviewed*
|
||||
* Supports TTL: *Unreviewed*
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
'etcd'://SERVER:PORT
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from queue import Empty
|
||||
|
||||
from kombu.exceptions import ChannelError
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
try:
|
||||
import etcd
|
||||
except ImportError:
|
||||
etcd = None
|
||||
|
||||
logger = get_logger('kombu.transport.etcd')
|
||||
|
||||
DEFAULT_PORT = 2379
|
||||
DEFAULT_HOST = 'localhost'
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Etcd Channel class which talks to the Etcd."""
|
||||
|
||||
prefix = 'kombu'
|
||||
index = None
|
||||
timeout = 10
|
||||
session_ttl = 30
|
||||
lock_ttl = 10
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if etcd is None:
|
||||
raise ImportError('Missing python-etcd library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
port = self.connection.client.port or self.connection.default_port
|
||||
host = self.connection.client.hostname or DEFAULT_HOST
|
||||
|
||||
logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout)
|
||||
|
||||
self.queues = defaultdict(dict)
|
||||
|
||||
self.client = etcd.Client(host=host, port=int(port))
|
||||
|
||||
def _key_prefix(self, queue):
|
||||
"""Create and return the `queue` with the proper prefix.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
return f'{self.prefix}/{queue}'
|
||||
|
||||
@contextmanager
|
||||
def _queue_lock(self, queue):
|
||||
"""Try to acquire a lock on the Queue.
|
||||
|
||||
It does so by creating a object called 'lock' which is locked by the
|
||||
current session..
|
||||
|
||||
This way other nodes are not able to write to the lock object which
|
||||
means that they have to wait before the lock is released.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
lock = etcd.Lock(self.client, queue)
|
||||
lock._uuid = self.lock_value
|
||||
logger.debug(f'Acquiring lock {lock.name}')
|
||||
lock.acquire(blocking=True, lock_ttl=self.lock_ttl)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.debug(f'Releasing lock {lock.name}')
|
||||
lock.release()
|
||||
|
||||
def _new_queue(self, queue, **_):
|
||||
"""Create a new `queue` if the `queue` doesn't already exist.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
self.queues[queue] = queue
|
||||
with self._queue_lock(queue):
|
||||
try:
|
||||
return self.client.write(
|
||||
key=self._key_prefix(queue), dir=True, value=None)
|
||||
except etcd.EtcdNotFile:
|
||||
logger.debug(f'Queue "{queue}" already exists')
|
||||
return self.client.read(key=self._key_prefix(queue))
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
"""Verify that queue exists.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: Should return :const:`True` if the queue exists
|
||||
or :const:`False` otherwise.
|
||||
"""
|
||||
try:
|
||||
self.client.read(self._key_prefix(queue))
|
||||
return True
|
||||
except etcd.EtcdKeyNotFound:
|
||||
return False
|
||||
|
||||
def _delete(self, queue, *args, **_):
|
||||
"""Delete a `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
self.queues.pop(queue, None)
|
||||
self._purge(queue)
|
||||
|
||||
def _put(self, queue, payload, **_):
|
||||
"""Put `message` onto `queue`.
|
||||
|
||||
This simply writes a key to the Etcd store
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
payload (dict): Message data which will be dumped to etcd.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
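            # append=True asks etcd to create an in-order key under the
            # queue directory, so messages are stored (and read) FIFO.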
if not self.client.write(
|
||||
key=key,
|
||||
value=dumps(payload),
|
||||
append=True):
|
||||
raise ChannelError(f'Cannot add key {key!r} to etcd')
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
"""Get the first available message from the queue.
|
||||
|
||||
Before it does so it acquires a lock on the store so
|
||||
only one node reads at the same time. This is for read consistency
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
timeout (int): Optional seconds to wait for a response.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug('Fetching key %s with index %s', key, self.index)
|
||||
|
||||
try:
|
||||
result = self.client.read(
|
||||
key=key, recursive=True,
|
||||
index=self.index, timeout=self.timeout)
|
||||
|
||||
if result is None:
|
||||
raise Empty()
|
||||
|
||||
item = result._children[-1]
|
||||
logger.debug('Removing key {}'.format(item['key']))
|
||||
|
||||
msg_content = loads(item['value'])
|
||||
self.client.delete(key=item['key'])
|
||||
return msg_content
|
||||
except (TypeError, IndexError, etcd.EtcdException) as error:
|
||||
logger.debug(f'_get failed: {type(error)}:{error}')
|
||||
|
||||
raise Empty()
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Remove all `message`s from a `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug(f'Purging queue at key {key}')
|
||||
return self.client.delete(key=key, recursive=True)
|
||||
|
||||
def _size(self, queue):
|
||||
"""Return the size of the `queue`.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
queue (str): The name of the queue.
|
||||
"""
|
||||
with self._queue_lock(queue):
|
||||
size = 0
|
||||
try:
|
||||
key = self._key_prefix(queue)
|
||||
logger.debug('Fetching key recursively %s with index %s',
|
||||
key, self.index)
|
||||
result = self.client.read(
|
||||
key=key, recursive=True,
|
||||
index=self.index)
|
||||
size = len(result._children)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
logger.debug('Found %s keys under %s with index %s',
|
||||
size, key, self.index)
|
||||
return size
|
||||
|
||||
@cached_property
|
||||
def lock_value(self):
|
||||
return f'{socket.gethostname()}.{os.getpid()}'


class Transport(virtual.Transport):
    """Etcd storage Transport for Kombu."""

    Channel = Channel

    default_port = DEFAULT_PORT
    driver_type = 'etcd'
    driver_name = 'python-etcd'
    polling_interval = 3

    implements = virtual.Transport.implements.extend(
        exchange_type=frozenset(['direct']))

    if etcd:
        connection_errors = (
            virtual.Transport.connection_errors + (etcd.EtcdException, )
        )

        channel_errors = (
            virtual.Transport.channel_errors + (etcd.EtcdException, )
        )

    def __init__(self, *args, **kwargs):
        """Create a new instance of etcd.Transport."""
        if etcd is None:
            raise ImportError('Missing python-etcd library')

        super().__init__(*args, **kwargs)

    def verify_connection(self, connection):
        """Verify the connection works."""
        port = connection.client.port or self.default_port
        host = connection.client.hostname or DEFAULT_HOST

        logger.debug('Verify Etcd connection to %s:%s', host, port)

        try:
            etcd.Client(host=host, port=int(port))
            return True
        except ValueError:
            pass

        return False

    def driver_version(self):
        """Return the version of the etcd library.

        .. note::
           python-etcd has no __version__. This is a workaround.
        """
        try:
            import pip.commands.freeze
            for x in pip.commands.freeze.freeze():
                if x.startswith('python-etcd'):
                    return x.split('==')[1]
        except (ImportError, IndexError):
            logger.warning('Unable to find the python-etcd version.')
        return 'Unknown'
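
# Sketch of an alternative version lookup (an assumption, not part of this
# module): on Python 3.8+, importlib.metadata avoids the pip.commands API
# used above, which was removed from pip long ago:
#
#     from importlib.metadata import PackageNotFoundError, version
#
#     try:
#         etcd_version = version('python-etcd')
#     except PackageNotFoundError:
#         etcd_version = 'Unknown'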
352
venv/lib/python3.12/site-packages/kombu/transport/filesystem.py
Normal file
@@ -0,0 +1,352 @@
"""File-system Transport module for kombu.
|
||||
|
||||
Transport using the file-system as the message store. Messages written to the
|
||||
queue are stored in `data_folder_in` directory and
|
||||
messages read from the queue are read from `data_folder_out` directory. Both
|
||||
directories must be created manually. Simple example:
|
||||
|
||||
* Producer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import kombu
|
||||
|
||||
conn = kombu.Connection(
|
||||
'filesystem://', transport_options={
|
||||
'data_folder_in': 'data_in', 'data_folder_out': 'data_out'
|
||||
}
|
||||
)
|
||||
conn.connect()
|
||||
|
||||
test_queue = kombu.Queue('test', routing_key='test')
|
||||
|
||||
with conn as conn:
|
||||
with conn.default_channel as channel:
|
||||
producer = kombu.Producer(channel)
|
||||
producer.publish(
|
||||
{'hello': 'world'},
|
||||
retry=True,
|
||||
exchange=test_queue.exchange,
|
||||
routing_key=test_queue.routing_key,
|
||||
declare=[test_queue],
|
||||
serializer='pickle'
|
||||
)
|
||||
|
||||
* Consumer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import kombu
|
||||
|
||||
conn = kombu.Connection(
|
||||
'filesystem://', transport_options={
|
||||
'data_folder_in': 'data_out', 'data_folder_out': 'data_in'
|
||||
}
|
||||
)
|
||||
conn.connect()
|
||||
|
||||
def callback(body, message):
|
||||
print(body, message)
|
||||
message.ack()
|
||||
|
||||
test_queue = kombu.Queue('test', routing_key='test')
|
||||
|
||||
with conn as conn:
|
||||
with conn.default_channel as channel:
|
||||
consumer = kombu.Consumer(
|
||||
conn, [test_queue], accept=['pickle']
|
||||
)
|
||||
consumer.register_callback(callback)
|
||||
with consumer:
|
||||
conn.drain_events(timeout=1)
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string is in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
filesystem://
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``data_folder_in`` - directory where are messages stored when written
|
||||
to queue.
|
||||
* ``data_folder_out`` - directory from which are messages read when read from
|
||||
queue.
|
||||
* ``store_processed`` - if set to True, all processed messages are backed up to
|
||||
``processed_folder``.
|
||||
* ``processed_folder`` - directory where are backed up processed files.
|
||||
* ``control_folder`` - directory where are exchange-queue table stored.
|
||||
"""

from __future__ import annotations

import os
import shutil
import tempfile
import uuid
from collections import namedtuple
from pathlib import Path
from queue import Empty
from time import monotonic

from kombu.exceptions import ChannelError
from kombu.transport import virtual
from kombu.utils.encoding import bytes_to_str, str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property

VERSION = (1, 0, 0)
__version__ = '.'.join(map(str, VERSION))

# needs win32all to work on Windows
if os.name == 'nt':

    import pywintypes
    import win32con
    import win32file

    LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
    # 0 is the default
    LOCK_SH = 0
    LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
    __overlapped = pywintypes.OVERLAPPED()

    def lock(file, flags):
        """Create file lock."""
        hfile = win32file._get_osfhandle(file.fileno())
        win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)

    def unlock(file):
        """Remove file lock."""
        hfile = win32file._get_osfhandle(file.fileno())
        win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)

elif os.name == 'posix':

    import fcntl
    from fcntl import LOCK_EX, LOCK_SH

    def lock(file, flags):
        """Create file lock."""
        fcntl.flock(file.fileno(), flags)

    def unlock(file):
        """Remove file lock."""
        fcntl.flock(file.fileno(), fcntl.LOCK_UN)

else:
    raise RuntimeError(
        'Filesystem plugin only defined for NT and POSIX platforms')
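
# Illustrative usage sketch (not part of the upstream module): the lock()
# and unlock() helpers above guard every read/write of a shared file, e.g.:
#
#     with open('data_out/example.msg', 'wb', buffering=0) as f:
#         lock(f, LOCK_EX)   # exclusive lock while writing
#         try:
#             f.write(b'payload')
#         finally:
#             unlock(f)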


exchange_queue_t = namedtuple("exchange_queue_t",
                              ["routing_key", "pattern", "queue"])


class Channel(virtual.Channel):
    """Filesystem Channel."""

    supports_fanout = True

    def get_table(self, exchange):
        file = self.control_folder / f"{exchange}.exchange"
        try:
            f_obj = file.open("r")
            try:
                lock(f_obj, LOCK_SH)
                exchange_table = loads(bytes_to_str(f_obj.read()))
                return [exchange_queue_t(*q) for q in exchange_table]
            finally:
                unlock(f_obj)
                f_obj.close()
        except FileNotFoundError:
            return []
        except OSError:
            raise ChannelError(f"Cannot open {file}")

    def _queue_bind(self, exchange, routing_key, pattern, queue):
        file = self.control_folder / f"{exchange}.exchange"
        self.control_folder.mkdir(exist_ok=True)
        queue_val = exchange_queue_t(routing_key or "", pattern or "",
                                     queue or "")
        try:
            if file.exists():
                f_obj = file.open("rb+", buffering=0)
                lock(f_obj, LOCK_EX)
                exchange_table = loads(bytes_to_str(f_obj.read()))
                queues = [exchange_queue_t(*q) for q in exchange_table]
                if queue_val not in queues:
                    queues.insert(0, queue_val)
                    f_obj.seek(0)
                    f_obj.write(str_to_bytes(dumps(queues)))
            else:
                f_obj = file.open("wb", buffering=0)
                lock(f_obj, LOCK_EX)
                queues = [queue_val]
                f_obj.write(str_to_bytes(dumps(queues)))
        finally:
            unlock(f_obj)
            f_obj.close()
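
    # Illustrative note (an assumption derived from the code above): the
    # control file "<exchange>.exchange" holds a JSON list of
    # [routing_key, pattern, queue] triples, e.g. after binding queue 'test'
    # with routing key 'test':
    #
    #     [["test", "", "test"]]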

    def _put_fanout(self, exchange, payload, routing_key, **kwargs):
        for q in self.get_table(exchange):
            self._put(q.queue, payload, **kwargs)

    def _put(self, queue, payload, **kwargs):
        """Put `message` onto `queue`."""
        filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)),
                                         uuid.uuid4(), queue)
        filename = os.path.join(self.data_folder_out, filename)

        try:
            f = open(filename, 'wb', buffering=0)
            lock(f, LOCK_EX)
            f.write(str_to_bytes(dumps(payload)))
        except OSError:
            raise ChannelError(
                f'Cannot add file {filename!r} to directory')
        finally:
            unlock(f)
            f.close()
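
    # Illustrative example (uuid hypothetical): a message published to queue
    # 'test' is written to data_folder_out under a name such as
    #
    #     523847291_6fa459ea-ee8a-4ca4-894e-db77e160355e.test.msg
    #
    # i.e. '<monotonic ms>_<uuid4>.<queue>.msg'; sorting filenames in _get()
    # therefore approximates FIFO order per producer process.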

    def _get(self, queue):
        """Get next message from `queue`."""
        queue_find = '.' + queue + '.msg'
        folder = os.listdir(self.data_folder_in)
        folder = sorted(folder)
        while len(folder) > 0:
            filename = folder.pop(0)

            # only handle message for the requested queue
            if filename.find(queue_find) < 0:
                continue

            if self.store_processed:
                processed_folder = self.processed_folder
            else:
                processed_folder = tempfile.gettempdir()

            try:
                # move the file to the tmp/processed folder
                shutil.move(os.path.join(self.data_folder_in, filename),
                            processed_folder)
            except OSError:
                # file could be locked, or removed in meantime so ignore
                continue

            filename = os.path.join(processed_folder, filename)
            try:
                f = open(filename, 'rb')
                payload = f.read()
                f.close()
                if not self.store_processed:
                    os.remove(filename)
            except OSError:
                raise ChannelError(
                    f'Cannot read file {filename!r} from queue.')

            return loads(bytes_to_str(payload))

        raise Empty()

    def _purge(self, queue):
        """Remove all messages from `queue`."""
        count = 0
        queue_find = '.' + queue + '.msg'

        folder = os.listdir(self.data_folder_in)
        while len(folder) > 0:
            filename = folder.pop()
            try:
                # only purge messages for the requested queue
                if filename.find(queue_find) < 0:
                    continue

                filename = os.path.join(self.data_folder_in, filename)
                os.remove(filename)

                count += 1

            except OSError:
                # we simply ignore its existence, as it was probably
                # processed by another worker
                pass

        return count

    def _size(self, queue):
        """Return the number of messages in `queue` as an :class:`int`."""
        count = 0

        queue_find = f'.{queue}.msg'
        folder = os.listdir(self.data_folder_in)
        while len(folder) > 0:
            filename = folder.pop()

            # only handle message for the requested queue
            if filename.find(queue_find) < 0:
                continue

            count += 1

        return count

    @property
    def transport_options(self):
        return self.connection.client.transport_options

    @cached_property
    def data_folder_in(self):
        return self.transport_options.get('data_folder_in', 'data_in')

    @cached_property
    def data_folder_out(self):
        return self.transport_options.get('data_folder_out', 'data_out')

    @cached_property
    def store_processed(self):
        return self.transport_options.get('store_processed', False)

    @cached_property
    def processed_folder(self):
        return self.transport_options.get('processed_folder', 'processed')

    @property
    def control_folder(self):
        return Path(self.transport_options.get('control_folder', 'control'))


class Transport(virtual.Transport):
    """Filesystem Transport."""

    implements = virtual.Transport.implements.extend(
        asynchronous=False,
        exchange_type=frozenset(['direct', 'topic', 'fanout'])
    )

    Channel = Channel
    # filesystem backend state is global.
    global_state = virtual.BrokerState()
    default_port = 0
    driver_type = 'filesystem'
    driver_name = 'filesystem'

    def __init__(self, client, **kwargs):
        super().__init__(client, **kwargs)
        self.state = self.global_state

    def driver_version(self):
        return 'N/A'
810
venv/lib/python3.12/site-packages/kombu/transport/gcpubsub.py
Normal file
@@ -0,0 +1,810 @@
"""GCP Pub/Sub transport module for kombu.
|
||||
|
||||
More information about GCP Pub/Sub:
|
||||
https://cloud.google.com/pubsub
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: No
|
||||
* Supports Fanout: Yes
|
||||
* Supports Priority: No
|
||||
* Supports TTL: No
|
||||
|
||||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
gcpubsub://projects/project-name
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
* ``queue_name_prefix``: (str) Prefix for queue names.
|
||||
* ``ack_deadline_seconds``: (int) The maximum time after receiving a message
|
||||
and acknowledging it before pub/sub redelivers the message.
|
||||
* ``expiration_seconds``: (int) Subscriptions without any subscriber
|
||||
activity or changes made to their properties are removed after this period.
|
||||
Examples of subscriber activities include open connections,
|
||||
active pulls, or successful pushes.
|
||||
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
|
||||
Defaults to 10.
|
||||
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
|
||||
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
|
||||
Defaults to 32.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import string
|
||||
import threading
|
||||
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
|
||||
wait)
|
||||
from contextlib import suppress
|
||||
from os import getpid
|
||||
from queue import Empty
|
||||
from threading import Lock
|
||||
from time import monotonic, sleep
|
||||
from uuid import NAMESPACE_OID, uuid3
|
||||
|
||||
from _socket import gethostname
|
||||
from _socket import timeout as socket_timeout
|
||||
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
|
||||
PermissionDenied)
|
||||
from google.api_core.retry import Retry
|
||||
from google.cloud import monitoring_v3
|
||||
from google.cloud.monitoring_v3 import query
|
||||
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
|
||||
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
|
||||
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
|
||||
from google.cloud.pubsub_v1.subscriber import \
|
||||
exceptions as subscriber_exceptions
|
||||
from google.pubsub_v1 import gapic_version as package_version
|
||||
|
||||
from kombu.entity import TRANSIENT_DELIVERY_MODE
|
||||
from kombu.log import get_logger
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
|
||||
logger = get_logger('kombu.transport.gcpubsub')
|
||||
|
||||
# dots are replaced by dash, all other punctuation replaced by underscore.
|
||||
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
|
||||
CHARS_REPLACE_TABLE = {
|
||||
ord('.'): ord('-'),
|
||||
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
|
||||
}
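
# Illustrative example (assuming the default 'kombu-' queue_name_prefix):
# Channel.entity_name('celery.pidbox') returns 'kombu-celery-pidbox' --
# dots become dashes, remaining punctuation becomes underscores.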


class UnackedIds:
    """Threadsafe list of ack_ids."""

    def __init__(self):
        self._list = []
        self._lock = Lock()

    def append(self, val):
        # append is atomic
        self._list.append(val)

    def extend(self, vals: list):
        # extend is atomic
        self._list.extend(vals)

    def pop(self, index=-1):
        with self._lock:
            return self._list.pop(index)

    def remove(self, val):
        with self._lock, suppress(ValueError):
            self._list.remove(val)

    def __len__(self):
        with self._lock:
            return len(self._list)

    def __getitem__(self, item):
        # getitem is atomic
        return self._list[item]


class AtomicCounter:
    """Threadsafe counter.

    Returns the value after inc/dec operations.
    """

    def __init__(self, initial=0):
        self._value = initial
        self._lock = Lock()

    def inc(self, n=1):
        with self._lock:
            self._value += n
            return self._value

    def dec(self, n=1):
        with self._lock:
            self._value -= n
            return self._value

    def get(self):
        with self._lock:
            return self._value


@dataclasses.dataclass
class QueueDescriptor:
    """Pub/Sub queue descriptor."""

    name: str
    topic_path: str  # projects/{project_id}/topics/{topic_id}
    subscription_id: str
    subscription_path: str  # projects/{project_id}/subscriptions/{subscription_id}
    unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)


class Channel(virtual.Channel):
    """GCP Pub/Sub channel."""

    supports_fanout = True
    do_restore = False  # pub/sub does that for us
    default_wait_time_seconds = 10
    default_ack_deadline_seconds = 240
    default_expiration_seconds = 86400
    default_retry_timeout_seconds = 300
    default_bulk_max_messages = 32

    _min_ack_deadline = 10
    _fanout_exchanges = set()
    _unacked_extender: threading.Thread = None
    _stop_extender = threading.Event()
    _n_channels = AtomicCounter()
    _queue_cache: dict[str, QueueDescriptor] = {}
    _tmp_subscriptions: set[str] = set()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pool = ThreadPoolExecutor()
        logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)

        self.project_id = Transport.parse_uri(self.conninfo.hostname)
        if self._n_channels.inc() == 1:
            Channel._unacked_extender = threading.Thread(
                target=self._extend_unacked_deadline,
                daemon=True,
            )
            self._stop_extender.clear()
            Channel._unacked_extender.start()

    def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
        """Format AMQP queue name into a valid Pub/Sub queue name."""
        if not name.startswith(self.queue_name_prefix):
            name = self.queue_name_prefix + name

        return str(safe_str(name)).translate(table)

    def _queue_bind(self, exchange, routing_key, pattern, queue):
        exchange_type = self.typeof(exchange).type
        queue = self.entity_name(queue)
        logger.debug(
            'binding queue: %s to %s exchange: %s with routing_key: %s',
            queue,
            exchange_type,
            exchange,
            routing_key,
        )

        filter_args = {}
        if exchange_type == 'direct':
            # Direct exchange is implemented as a single subscription.
            # E.g. for exchange 'test_direct':
            #   - topic: 'test_direct'
            #   - bound queue 'direct1':
            #     - subscription 'direct1' on topic 'test_direct'
            #       with a filter on routing_key
            filter_args = {
                'filter': f'attributes.routing_key="{routing_key}"'
            }
            subscription_path = self.subscriber.subscription_path(
                self.project_id, queue
            )
            message_retention_duration = self.expiration_seconds
        elif exchange_type == 'fanout':
            # Fanout exchange is implemented as a separate subscription.
            # E.g. for exchange 'test_fanout':
            #   - topic: 'test_fanout'
            #   - bound queue 'fanout1':
            #     - subscription 'fanout1-uuid' on topic 'test_fanout'
            #   - bound queue 'fanout2':
            #     - subscription 'fanout2-uuid' on topic 'test_fanout'
            uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
            uniq_sub_name = f'{queue}-{uid}'
            subscription_path = self.subscriber.subscription_path(
                self.project_id, uniq_sub_name
            )
            self._tmp_subscriptions.add(subscription_path)
            self._fanout_exchanges.add(exchange)
            message_retention_duration = 600
        else:
            raise NotImplementedError(
                f'exchange type {exchange_type} not implemented'
            )
        exchange_topic = self._create_topic(
            self.project_id, exchange, message_retention_duration
        )
        self._create_subscription(
            topic_path=exchange_topic,
            subscription_path=subscription_path,
            filter_args=filter_args,
            msg_retention=message_retention_duration,
        )
        qdesc = QueueDescriptor(
            name=queue,
            topic_path=exchange_topic,
            subscription_id=queue,
            subscription_path=subscription_path,
        )
        self._queue_cache[queue] = qdesc

    def _create_topic(
        self,
        project_id: str,
        topic_id: str,
        message_retention_duration: int = None,
    ) -> str:
        topic_path = self.publisher.topic_path(project_id, topic_id)
        if self._is_topic_exists(topic_path):
            # topic creation takes a while, so skip if possible
            logger.debug('topic: %s exists', topic_path)
            return topic_path
        try:
            logger.debug('creating topic: %s', topic_path)
            request = {'name': topic_path}
            if message_retention_duration:
                request[
                    'message_retention_duration'
                ] = f'{message_retention_duration}s'
            self.publisher.create_topic(request=request)
        except AlreadyExists:
            pass

        return topic_path

    def _is_topic_exists(self, topic_path: str) -> bool:
        topics = self.publisher.list_topics(
            request={"project": f'projects/{self.project_id}'}
        )
        for t in topics:
            if t.name == topic_path:
                return True
        return False

    def _create_subscription(
        self,
        project_id: str = None,
        topic_id: str = None,
        topic_path: str = None,
        subscription_path: str = None,
        filter_args=None,
        msg_retention: int = None,
    ) -> str:
        subscription_path = (
            subscription_path
            or self.subscriber.subscription_path(self.project_id, topic_id)
        )
        topic_path = topic_path or self.publisher.topic_path(
            project_id, topic_id
        )
        try:
            logger.debug(
                'creating subscription: %s, topic: %s, filter: %s',
                subscription_path,
                topic_path,
                filter_args,
            )
            msg_retention = msg_retention or self.expiration_seconds
            self.subscriber.create_subscription(
                request={
                    "name": subscription_path,
                    "topic": topic_path,
                    'ack_deadline_seconds': self.ack_deadline_seconds,
                    'expiration_policy': {
                        'ttl': f'{self.expiration_seconds}s'
                    },
                    'message_retention_duration': f'{msg_retention}s',
                    **(filter_args or {}),
                }
            )
        except AlreadyExists:
            pass
        return subscription_path

    def _delete(self, queue, *args, **kwargs):
        """Delete a queue by name."""
        queue = self.entity_name(queue)
        logger.info('deleting queue: %s', queue)
        qdesc = self._queue_cache.get(queue)
        if not qdesc:
            return
        self.subscriber.delete_subscription(
            request={"subscription": qdesc.subscription_path}
        )
        self._queue_cache.pop(queue, None)

    def _put(self, queue, message, **kwargs):
        """Put a message onto the queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache[queue]
        routing_key = self._get_routing_key(message)
        logger.debug(
            'putting message to queue: %s, topic: %s, routing_key: %s',
            queue,
            qdesc.topic_path,
            routing_key,
        )
        encoded_message = dumps(message)
        self.publisher.publish(
            qdesc.topic_path,
            encoded_message.encode("utf-8"),
            routing_key=routing_key,
        )

    def _put_fanout(self, exchange, message, routing_key, **kwargs):
        """Put a message onto fanout exchange."""
        self._lookup(exchange, routing_key)
        topic_path = self.publisher.topic_path(self.project_id, exchange)
        logger.debug(
            'putting msg to fanout exchange: %s, topic: %s',
            exchange,
            topic_path,
        )
        encoded_message = dumps(message)
        self.publisher.publish(
            topic_path,
            encoded_message.encode("utf-8"),
            retry=Retry(deadline=self.retry_timeout_seconds),
        )

    def _get(self, queue: str, timeout: float = None):
        """Retrieves a single message from a queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache[queue]
        try:
            response = self.subscriber.pull(
                request={
                    'subscription': qdesc.subscription_path,
                    'max_messages': 1,
                },
                retry=Retry(deadline=self.retry_timeout_seconds),
                timeout=timeout or self.wait_time_seconds,
            )
        except DeadlineExceeded:
            raise Empty()

        if len(response.received_messages) == 0:
            raise Empty()

        message = response.received_messages[0]
        ack_id = message.ack_id
        payload = loads(message.message.data)
        delivery_info = payload['properties']['delivery_info']
        logger.debug(
            'queue:%s got message, ack_id: %s, payload: %s',
            queue,
            ack_id,
            payload['properties'],
        )
        if self._is_auto_ack(payload['properties']):
            logger.debug('auto acking message ack_id: %s', ack_id)
            self._do_ack([ack_id], qdesc.subscription_path)
        else:
            delivery_info['gcpubsub_message'] = {
                'queue': queue,
                'ack_id': ack_id,
                'message_id': message.message.message_id,
                'subscription_path': qdesc.subscription_path,
            }
            qdesc.unacked_ids.append(ack_id)

        return payload

    def _is_auto_ack(self, payload_properties: dict):
        exchange = payload_properties['delivery_info']['exchange']
        delivery_mode = payload_properties['delivery_mode']
        return (
            delivery_mode == TRANSIENT_DELIVERY_MODE
            or exchange in self._fanout_exchanges
        )

    def _get_bulk(self, queue: str, timeout: float):
        """Retrieves bulk of messages from a queue."""
        prefixed_queue = self.entity_name(queue)
        qdesc = self._queue_cache[prefixed_queue]
        max_messages = self._get_max_messages_estimate()
        if not max_messages:
            raise Empty()
        try:
            response = self.subscriber.pull(
                request={
                    'subscription': qdesc.subscription_path,
                    'max_messages': max_messages,
                },
                retry=Retry(deadline=self.retry_timeout_seconds),
                timeout=timeout or self.wait_time_seconds,
            )
        except DeadlineExceeded:
            raise Empty()

        received_messages = response.received_messages
        if len(received_messages) == 0:
            raise Empty()

        auto_ack_ids = []
        ret_payloads = []
        logger.debug(
            'batching %d messages from queue: %s',
            len(received_messages),
            prefixed_queue,
        )
        for message in received_messages:
            ack_id = message.ack_id
            payload = loads(bytes_to_str(message.message.data))
            delivery_info = payload['properties']['delivery_info']
            delivery_info['gcpubsub_message'] = {
                'queue': prefixed_queue,
                'ack_id': ack_id,
                'message_id': message.message.message_id,
                'subscription_path': qdesc.subscription_path,
            }
            if self._is_auto_ack(payload['properties']):
                auto_ack_ids.append(ack_id)
            else:
                qdesc.unacked_ids.append(ack_id)
            ret_payloads.append(payload)
        if auto_ack_ids:
            logger.debug('auto acking ack_ids: %s', auto_ack_ids)
            self._do_ack(auto_ack_ids, qdesc.subscription_path)

        return queue, ret_payloads

    def _get_max_messages_estimate(self) -> int:
        max_allowed = self.qos.can_consume_max_estimate()
        max_if_unlimited = self.bulk_max_messages
        return max_if_unlimited if max_allowed is None else max_allowed
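
    # Illustrative example (an assumption about the QoS API:
    # can_consume_max_estimate() returns the remaining prefetch capacity, or
    # None when unlimited): with prefetch_count=10 and 4 messages outstanding
    # it yields 6, so at most 6 messages are pulled in one bulk; with no
    # prefetch limit, bulk_max_messages (default 32) is used instead.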

    def _lookup(self, exchange, routing_key, default=None):
        exchange_info = self.state.exchanges.get(exchange, {})
        if not exchange_info:
            return super()._lookup(exchange, routing_key, default)
        ret = self.typeof(exchange).lookup(
            self.get_table(exchange),
            exchange,
            routing_key,
            default,
        )
        if ret:
            return ret
        logger.debug(
            'no queues bound to exchange: %s, binding on the fly',
            exchange,
        )
        self.queue_bind(exchange, exchange, routing_key)
        return [exchange]

    def _size(self, queue: str) -> int:
        """Return the number of messages in a queue.

        This is a *rough* estimation, as Pub/Sub doesn't provide
        an exact API.
        """
        queue = self.entity_name(queue)
        if queue not in self._queue_cache:
            return 0
        qdesc = self._queue_cache[queue]
        result = query.Query(
            self.monitor,
            self.project_id,
            'pubsub.googleapis.com/subscription/num_undelivered_messages',
            end_time=datetime.datetime.now(),
            minutes=1,
        ).select_resources(subscription_id=qdesc.subscription_id)

        # monitoring API requires the caller to have the monitoring.viewer
        # role. Since we can live without the exact number of messages
        # in the queue, we can ignore the exception and allow users to
        # use the transport without this role.
        with suppress(PermissionDenied):
            return sum(
                content.points[0].value.int64_value for content in result
            )
        return -1

    def basic_ack(self, delivery_tag, multiple=False):
        """Acknowledge one message."""
        if multiple:
            raise NotImplementedError('multiple acks not implemented')

        delivery_info = self.qos.get(delivery_tag).delivery_info
        pubsub_message = delivery_info['gcpubsub_message']
        ack_id = pubsub_message['ack_id']
        queue = pubsub_message['queue']
        logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
        subscription_path = pubsub_message['subscription_path']
        self._do_ack([ack_id], subscription_path)
        qdesc = self._queue_cache[queue]
        qdesc.unacked_ids.remove(ack_id)
        super().basic_ack(delivery_tag)

    def _do_ack(self, ack_ids: list[str], subscription_path: str):
        self.subscriber.acknowledge(
            request={"subscription": subscription_path, "ack_ids": ack_ids},
            retry=Retry(deadline=self.retry_timeout_seconds),
        )

    def _purge(self, queue: str):
        """Delete all current messages in a queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache.get(queue)
        if not qdesc:
            return

        n = self._size(queue)
        self.subscriber.seek(
            request={
                "subscription": qdesc.subscription_path,
                "time": datetime.datetime.now(),
            }
        )
        return n

    def _extend_unacked_deadline(self):
        thread_id = threading.get_native_id()
        logger.info(
            'unacked deadline extension thread: [%s] started',
            thread_id,
        )
        min_deadline_sleep = self._min_ack_deadline / 2
        sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
        while not self._stop_extender.wait(sleep_time):
            for qdesc in self._queue_cache.values():
                if len(qdesc.unacked_ids) == 0:
                    logger.debug(
                        'thread [%s]: no unacked messages for %s',
                        thread_id,
                        qdesc.subscription_path,
                    )
                    continue
                logger.debug(
                    'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
                    thread_id,
                    qdesc.subscription_path,
                    len(qdesc.unacked_ids),
                    list(qdesc.unacked_ids),
                )
                self.subscriber.modify_ack_deadline(
                    request={
                        "subscription": qdesc.subscription_path,
                        "ack_ids": list(qdesc.unacked_ids),
                        "ack_deadline_seconds": self.ack_deadline_seconds,
                    }
                )
        logger.info(
            'unacked deadline extension thread [%s] stopped', thread_id
        )

    def after_reply_message_received(self, queue: str):
        queue = self.entity_name(queue)
        sub = self.subscriber.subscription_path(self.project_id, queue)
        logger.debug(
            'after_reply_message_received: queue: %s, sub: %s', queue, sub
        )
        self._tmp_subscriptions.add(sub)

    @cached_property
    def subscriber(self):
        return SubscriberClient()

    @cached_property
    def publisher(self):
        return PublisherClient()

    @cached_property
    def monitor(self):
        return monitoring_v3.MetricServiceClient()

    @property
    def conninfo(self):
        return self.connection.client

    @property
    def transport_options(self):
        return self.connection.client.transport_options

    @cached_property
    def wait_time_seconds(self):
        return self.transport_options.get(
            'wait_time_seconds', self.default_wait_time_seconds
        )

    @cached_property
    def retry_timeout_seconds(self):
        return self.transport_options.get(
            'retry_timeout_seconds', self.default_retry_timeout_seconds
        )

    @cached_property
    def ack_deadline_seconds(self):
        return self.transport_options.get(
            'ack_deadline_seconds', self.default_ack_deadline_seconds
        )

    @cached_property
    def queue_name_prefix(self):
        return self.transport_options.get('queue_name_prefix', 'kombu-')

    @cached_property
    def expiration_seconds(self):
        return self.transport_options.get(
            'expiration_seconds', self.default_expiration_seconds
        )

    @cached_property
    def bulk_max_messages(self):
        return self.transport_options.get(
            'bulk_max_messages', self.default_bulk_max_messages
        )

    def close(self):
        """Close the channel."""
        logger.debug('closing channel')
        while self._tmp_subscriptions:
            sub = self._tmp_subscriptions.pop()
            with suppress(Exception):
                logger.debug('deleting subscription: %s', sub)
                self.subscriber.delete_subscription(
                    request={"subscription": sub}
                )
        if not self._n_channels.dec():
            self._stop_extender.set()
            Channel._unacked_extender.join()
        super().close()

    @staticmethod
    def _get_routing_key(message):
        routing_key = (
            message['properties']
            .get('delivery_info', {})
            .get('routing_key', '')
        )
        return routing_key


class Transport(virtual.Transport):
    """GCP Pub/Sub transport."""

    Channel = Channel

    can_parse_url = True
    polling_interval = 0.1
    connection_errors = virtual.Transport.connection_errors + (
        pubsub_exceptions.TimeoutError,
    )
    channel_errors = (
        virtual.Transport.channel_errors
        + (
            publisher_exceptions.FlowControlLimitError,
            publisher_exceptions.MessageTooLargeError,
            publisher_exceptions.PublishError,
            publisher_exceptions.TimeoutError,
            publisher_exceptions.PublishToPausedOrderingKeyException,
        )
        + (subscriber_exceptions.AcknowledgeError,)
    )

    driver_type = 'gcpubsub'
    driver_name = 'pubsub_v1'

    implements = virtual.Transport.implements.extend(
        exchange_type=frozenset(['direct', 'fanout']),
    )

    def __init__(self, client, **kwargs):
        super().__init__(client, **kwargs)
        self._pool = ThreadPoolExecutor()
        self._get_bulk_future_to_queue: dict[Future, str] = dict()

    def driver_version(self):
        return package_version.__version__

    @staticmethod
    def parse_uri(uri: str) -> str:
        # URL like:
        # gcpubsub://projects/project-name

        project = uri.split('gcpubsub://projects/')[1]
        return project.strip('/')
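
    # Illustrative example: parse_uri('gcpubsub://projects/my-project')
    # returns 'my-project'; a trailing slash is stripped as well.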

    @classmethod
    def as_uri(self, uri: str, include_password=False, mask='**') -> str:
        return uri or 'gcpubsub://'

    def drain_events(self, connection, timeout=None):
        time_start = monotonic()
        polling_interval = self.polling_interval
        if timeout and polling_interval and polling_interval > timeout:
            polling_interval = timeout
        while 1:
            try:
                self._drain_from_active_queues(timeout=timeout)
            except Empty:
                if timeout and monotonic() - time_start >= timeout:
                    raise socket_timeout()
                if polling_interval:
                    sleep(polling_interval)
            else:
                break

    def _drain_from_active_queues(self, timeout):
        # cleanup empty requests from prev run
        self._rm_empty_bulk_requests()

        # submit new requests for all active queues
        # longer timeout means less frequent polling
        # and more messages in a single bulk
        self._submit_get_bulk_requests(timeout=10)

        done, _ = wait(
            self._get_bulk_future_to_queue,
            timeout=timeout,
            return_when=FIRST_COMPLETED,
        )
        empty = {f for f in done if f.exception()}
        done -= empty
        for f in empty:
            self._get_bulk_future_to_queue.pop(f, None)

        if not done:
            raise Empty()

        logger.debug('got %d done get_bulk tasks', len(done))
        for f in done:
            queue, payloads = f.result()
            for payload in payloads:
                logger.debug('consuming message from queue: %s', queue)
                if queue not in self._callbacks:
                    logger.warning(
                        'Message for queue %s without consumers', queue
                    )
                    continue
                self._deliver(payload, queue)
            self._get_bulk_future_to_queue.pop(f, None)

    def _rm_empty_bulk_requests(self):
        empty = {
            f
            for f in self._get_bulk_future_to_queue
            if f.done() and f.exception()
        }
        for f in empty:
            self._get_bulk_future_to_queue.pop(f, None)

    def _submit_get_bulk_requests(self, timeout):
        queues_with_submitted_get_bulk = set(
            self._get_bulk_future_to_queue.values()
        )

        for channel in self.channels:
            for queue in channel._active_queues:
                if queue in queues_with_submitted_get_bulk:
                    continue
                future = self._pool.submit(channel._get_bulk, queue, timeout)
                self._get_bulk_future_to_queue[future] = queue
190
venv/lib/python3.12/site-packages/kombu/transport/librabbitmq.py
Normal file
@@ -0,0 +1,190 @@
"""`librabbitmq`_ transport.
|
||||
|
||||
.. _`librabbitmq`: https://pypi.org/project/librabbitmq/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
import warnings
|
||||
|
||||
import librabbitmq as amqp
|
||||
from librabbitmq import ChannelError, ConnectionError
|
||||
|
||||
from kombu.utils.amq_manager import get_manager
|
||||
from kombu.utils.text import version_string_as_tuple
|
||||
|
||||
from . import base
|
||||
from .base import to_rabbitmq_queue_arguments
|
||||
|
||||
W_VERSION = """
|
||||
librabbitmq version too old to detect RabbitMQ version information
|
||||
so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3
|
||||
"""
|
||||
DEFAULT_PORT = 5672
|
||||
DEFAULT_SSL_PORT = 5671
|
||||
|
||||
NO_SSL_ERROR = """\
|
||||
ssl not supported by librabbitmq, please use pyamqp:// or stunnel\
|
||||
"""
|
||||
|
||||
|
||||
class Message(base.Message):
|
||||
"""AMQP Message (librabbitmq)."""
|
||||
|
||||
def __init__(self, channel, props, info, body):
|
||||
super().__init__(
|
||||
channel=channel,
|
||||
body=body,
|
||||
delivery_info=info,
|
||||
properties=props,
|
||||
delivery_tag=info.get('delivery_tag'),
|
||||
content_type=props.get('content_type'),
|
||||
content_encoding=props.get('content_encoding'),
|
||||
headers=props.get('headers'))
|
||||
|
||||
|
||||
class Channel(amqp.Channel, base.StdChannel):
|
||||
"""AMQP Channel (librabbitmq)."""
|
||||
|
||||
Message = Message
|
||||
|
||||
def prepare_message(self, body, priority=None,
|
||||
content_type=None, content_encoding=None,
|
||||
headers=None, properties=None):
|
||||
"""Encapsulate data into a AMQP message."""
|
||||
properties = properties if properties is not None else {}
|
||||
properties.update({'content_type': content_type,
|
||||
'content_encoding': content_encoding,
|
||||
'headers': headers})
|
||||
# Don't include priority if it's not an integer.
|
||||
# If that's the case librabbitmq will fail
|
||||
# and raise an exception.
|
||||
if priority is not None:
|
||||
properties['priority'] = priority
|
||||
return body, properties
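
    # Illustrative example (hypothetical values): prepare_message(
    #     b'{"x": 1}', priority=5, content_type='application/json')
    # returns (b'{"x": 1}', {'content_type': 'application/json',
    # 'content_encoding': None, 'headers': None, 'priority': 5}).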

    def prepare_queue_arguments(self, arguments, **kwargs):
        arguments = to_rabbitmq_queue_arguments(arguments, **kwargs)
        return {k.encode('utf8'): v for k, v in arguments.items()}


class Connection(amqp.Connection):
    """AMQP Connection (librabbitmq)."""

    Channel = Channel
    Message = Message


class Transport(base.Transport):
    """AMQP Transport (librabbitmq)."""

    Connection = Connection

    default_port = DEFAULT_PORT
    default_ssl_port = DEFAULT_SSL_PORT

    connection_errors = (
        base.Transport.connection_errors + (
            ConnectionError, socket.error, IOError, OSError)
    )
    channel_errors = (
        base.Transport.channel_errors + (ChannelError,)
    )
    driver_type = 'amqp'
    driver_name = 'librabbitmq'

    implements = base.Transport.implements.extend(
        asynchronous=True,
        heartbeats=False,
    )

    def __init__(self, client, **kwargs):
        self.client = client
        self.default_port = kwargs.get('default_port') or self.default_port
        self.default_ssl_port = (kwargs.get('default_ssl_port') or
                                 self.default_ssl_port)
        self.__reader = None

    def driver_version(self):
        return amqp.__version__

    def create_channel(self, connection):
        return connection.channel()

    def drain_events(self, connection, **kwargs):
        return connection.drain_events(**kwargs)

    def establish_connection(self):
        """Establish connection to the AMQP broker."""
        conninfo = self.client
        for name, default_value in self.default_connection_params.items():
            if not getattr(conninfo, name, None):
                setattr(conninfo, name, default_value)
        if conninfo.ssl:
            raise NotImplementedError(NO_SSL_ERROR)
        opts = dict({
            'host': conninfo.host,
            'userid': conninfo.userid,
            'password': conninfo.password,
            'virtual_host': conninfo.virtual_host,
            'login_method': conninfo.login_method,
            'insist': conninfo.insist,
            'ssl': conninfo.ssl,
            'connect_timeout': conninfo.connect_timeout,
        }, **conninfo.transport_options or {})
        conn = self.Connection(**opts)
        conn.client = self.client
        self.client.drain_events = conn.drain_events
        return conn

    def close_connection(self, connection):
        """Close the AMQP broker connection."""
        self.client.drain_events = None
        connection.close()

    def _collect(self, connection):
        if connection is not None:
            for channel in connection.channels.values():
                channel.connection = None
            try:
                os.close(connection.fileno())
            except (OSError, ValueError):
                pass
            connection.channels.clear()
            connection.callbacks.clear()
        self.client.drain_events = None
        self.client = None

    def verify_connection(self, connection):
        return connection.connected

    def register_with_event_loop(self, connection, loop):
        loop.add_reader(
            connection.fileno(), self.on_readable, connection, loop,
        )

    def get_manager(self, *args, **kwargs):
        return get_manager(self.client, *args, **kwargs)

    def qos_semantics_matches_spec(self, connection):
        try:
            props = connection.server_properties
        except AttributeError:
            warnings.warn(UserWarning(W_VERSION))
        else:
            if props.get('product') == 'RabbitMQ':
                return version_string_as_tuple(props['version']) < (3, 3)
        return True

    @property
    def default_connection_params(self):
        return {
            'userid': 'guest',
            'password': 'guest',
            'port': (self.default_ssl_port if self.client.ssl
                     else self.default_port),
            'hostname': 'localhost',
            'login_method': 'PLAIN',
        }
106
venv/lib/python3.12/site-packages/kombu/transport/memory.py
Normal file
@@ -0,0 +1,106 @@
"""In-memory transport module for Kombu.
|
||||
|
||||
Simple transport using memory for storing messages.
|
||||
Messages can be passed only between threads.
|
||||
|
||||
Features
|
||||
========
|
||||
* Type: Virtual
|
||||
* Supports Direct: Yes
|
||||
* Supports Topic: Yes
|
||||
* Supports Fanout: No
|
||||
* Supports Priority: No
|
||||
* Supports TTL: Yes
|
||||
|
||||
Connection String
|
||||
=================
|
||||
Connection string is in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
memory://
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from queue import Queue
|
||||
|
||||
from . import base, virtual
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""In-memory Channel."""
|
||||
|
||||
events = defaultdict(set)
|
||||
queues = {}
|
||||
do_restore = False
|
||||
supports_fanout = True
|
||||
|
||||
def _has_queue(self, queue, **kwargs):
|
||||
return queue in self.queues
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
if queue not in self.queues:
|
||||
self.queues[queue] = Queue()
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
return self._queue_for(queue).get(block=False)
|
||||
|
||||
def _queue_for(self, queue):
|
||||
if queue not in self.queues:
|
||||
self.queues[queue] = Queue()
|
||||
return self.queues[queue]
|
||||
|
||||
def _queue_bind(self, *args):
|
||||
pass
|
||||
|
||||
def _put_fanout(self, exchange, message, routing_key=None, **kwargs):
|
||||
for queue in self._lookup(exchange, routing_key):
|
||||
self._queue_for(queue).put(message)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
self._queue_for(queue).put(message)
|
||||
|
||||
def _size(self, queue):
|
||||
return self._queue_for(queue).qsize()
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
self.queues.pop(queue, None)
|
||||
|
||||
def _purge(self, queue):
|
||||
q = self._queue_for(queue)
|
||||
size = q.qsize()
|
||||
q.queue.clear()
|
||||
return size
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
for queue in self.queues.values():
|
||||
queue.empty()
|
||||
self.queues = {}
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
pass
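
# Illustrative usage sketch (not part of the module): the memory transport
# keeps everything in-process, which makes it handy for unit tests:
#
#     from kombu import Connection
#
#     with Connection('memory://') as conn:
#         q = conn.SimpleQueue('test')
#         q.put({'hello': 'world'})
#         assert q.get(block=False).payload == {'hello': 'world'}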


class Transport(virtual.Transport):
    """In-memory Transport."""

    Channel = Channel

    #: memory backend state is global.
    global_state = virtual.BrokerState()

    implements = base.Transport.implements

    driver_type = 'memory'
    driver_name = 'memory'

    def __init__(self, client, **kwargs):
        super().__init__(client, **kwargs)
        self.state = self.global_state

    def driver_version(self):
        return 'N/A'
534
venv/lib/python3.12/site-packages/kombu/transport/mongodb.py
Normal file
@@ -0,0 +1,534 @@
# copyright: (c) 2010 - 2013 by Flavio Percoco Premoli.
# license: BSD, see LICENSE for more details.

"""MongoDB transport module for kombu.

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: Yes
* Supports TTL: Yes

Connection String
=================
*Unreviewed*

Transport Options
=================

* ``connect_timeout``,
* ``ssl``,
* ``ttl``,
* ``capped_queue_size``,
* ``default_hostname``,
* ``default_port``,
* ``default_database``,
* ``messages_collection``,
* ``routing_collection``,
* ``broadcast_collection``,
* ``queues_collection``,
* ``calc_queue_size``,
"""

from __future__ import annotations

import datetime
from queue import Empty

import pymongo
from pymongo import MongoClient, errors, uri_parser
from pymongo.cursor import CursorType

from kombu.exceptions import VersionMismatch
from kombu.utils.compat import _detect_environment
from kombu.utils.encoding import bytes_to_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from kombu.utils.url import maybe_sanitize_url

from . import virtual
from .base import to_rabbitmq_queue_arguments

E_SERVER_VERSION = """\
Kombu requires MongoDB version 1.3+ (server is {0})\
"""

E_NO_TTL_INDEXES = """\
Kombu requires MongoDB version 2.2+ (server is {0}) for TTL indexes support\
"""


class BroadcastCursor:
    """Cursor for broadcast queues."""

    def __init__(self, cursor):
        self._cursor = cursor
        self._offset = 0
        self.purge(rewind=False)

    def get_size(self):
        return self._cursor.collection.count_documents({}) - self._offset

    def close(self):
        self._cursor.close()

    def purge(self, rewind=True):
        if rewind:
            self._cursor.rewind()

        # Fast-forward the cursor past old events
        self._offset = self._cursor.collection.count_documents({})
        self._cursor = self._cursor.skip(self._offset)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            try:
                msg = next(self._cursor)
            except pymongo.errors.OperationFailure as exc:
                # In some cases a tailable cursor can become invalid
                # and has to be reinitialized
                if 'not valid at server' in str(exc):
                    self.purge()

                    continue

                raise
            else:
                break

        self._offset += 1

        return msg
    next = __next__


class Channel(virtual.Channel):
    """MongoDB Channel."""

    supports_fanout = True

    # Mutable container. Shared by all class instances
    _fanout_queues = {}

    # Options
    ssl = False
    ttl = False
    connect_timeout = None
    capped_queue_size = 100000
    calc_queue_size = True

    default_hostname = '127.0.0.1'
    default_port = 27017
    default_database = 'kombu_default'

    messages_collection = 'messages'
    routing_collection = 'messages.routing'
    broadcast_collection = 'messages.broadcast'
    queues_collection = 'messages.queues'

    from_transport_options = (virtual.Channel.from_transport_options + (
        'connect_timeout', 'ssl', 'ttl', 'capped_queue_size',
        'default_hostname', 'default_port', 'default_database',
        'messages_collection', 'routing_collection',
        'broadcast_collection', 'queues_collection',
        'calc_queue_size',
    ))

    def __init__(self, *vargs, **kwargs):
        super().__init__(*vargs, **kwargs)

        self._broadcast_cursors = {}

        # Evaluate connection
        self.client

    # AbstractChannel/Channel interface implementation

    def _new_queue(self, queue, **kwargs):
        if self.ttl:
            self.queues.update_one(
                {'_id': queue},
                {
                    '$set': {
                        '_id': queue,
                        'options': kwargs,
                        'expire_at': self._get_queue_expire(
                            kwargs, 'x-expires'
                        ),
                    },
                },
                upsert=True)

    def _get(self, queue):
        if queue in self._fanout_queues:
            try:
                msg = next(self._get_broadcast_cursor(queue))
            except StopIteration:
                msg = None
        else:
            msg = self.messages.find_one_and_delete(
                {'queue': queue},
                sort=[('priority', pymongo.ASCENDING)],
            )

        if self.ttl:
            self._update_queues_expire(queue)

        if msg is None:
            raise Empty()

        return loads(bytes_to_str(msg['payload']))

    def _size(self, queue):
        # Do not calculate actual queue size if requested
        # for performance considerations
        if not self.calc_queue_size:
            return super()._size(queue)

        if queue in self._fanout_queues:
            return self._get_broadcast_cursor(queue).get_size()

        return self.messages.count_documents({'queue': queue})

    def _put(self, queue, message, **kwargs):
        data = {
            'payload': dumps(message),
            'queue': queue,
            'priority': self._get_message_priority(message, reverse=True)
        }

        if self.ttl:
            data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl')
            msg_expire = self._get_message_expire(message)
            if msg_expire is not None and (
                    data['expire_at'] is None or msg_expire < data['expire_at']
            ):
                data['expire_at'] = msg_expire

        self.messages.insert_one(data)

    def _put_fanout(self, exchange, message, routing_key, **kwargs):
        self.broadcast.insert_one({'payload': dumps(message),
                                   'queue': exchange})

    def _purge(self, queue):
        size = self._size(queue)

        if queue in self._fanout_queues:
            self._get_broadcast_cursor(queue).purge()
        else:
            self.messages.delete_many({'queue': queue})

        return size

    def get_table(self, exchange):
        localRoutes = frozenset(self.state.exchanges[exchange]['table'])
        brokerRoutes = self.routing.find(
            {'exchange': exchange}
        )

        return localRoutes | frozenset(
            (r['routing_key'], r['pattern'], r['queue'])
            for r in brokerRoutes
        )
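
    # Illustrative example (hypothetical values): get_table('tasks') merges
    # the in-memory routing table with the broker-side 'messages.routing'
    # collection and returns triples such as
    # frozenset({('tasks', '', 'celery')}).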

    def _queue_bind(self, exchange, routing_key, pattern, queue):
        if self.typeof(exchange).type == 'fanout':
            self._create_broadcast_cursor(
                exchange, routing_key, pattern, queue)
            self._fanout_queues[queue] = exchange

        lookup = {
            'exchange': exchange,
            'queue': queue,
            'routing_key': routing_key,
            'pattern': pattern,
        }

        data = lookup.copy()

        if self.ttl:
            data['expire_at'] = self._get_queue_expire(queue, 'x-expires')

        self.routing.update_one(lookup, {'$set': data}, upsert=True)

    def queue_delete(self, queue, **kwargs):
        self.routing.delete_many({'queue': queue})

        if self.ttl:
            self.queues.delete_one({'_id': queue})

        super().queue_delete(queue, **kwargs)

        if queue in self._fanout_queues:
            try:
                cursor = self._broadcast_cursors.pop(queue)
            except KeyError:
                pass
            else:
                cursor.close()

                self._fanout_queues.pop(queue)

    # Implementation details

    def _parse_uri(self, scheme='mongodb://'):
        # See mongodb uri documentation:
        # https://docs.mongodb.org/manual/reference/connection-string/
        client = self.connection.client
        hostname = client.hostname

        if hostname.startswith('srv://'):
            scheme = 'mongodb+srv://'
            hostname = 'mongodb+' + hostname

        if not hostname.startswith(scheme):
            hostname = scheme + hostname

        if not hostname[len(scheme):]:
            hostname += self.default_hostname

        if client.userid and '@' not in hostname:
            head, tail = hostname.split('://')

            credentials = client.userid
            if client.password:
                credentials += ':' + client.password

            hostname = head + '://' + credentials + '@' + tail

        port = client.port if client.port else self.default_port

        # We disable parameter validation and normalization here because
        # pymongo validates and normalizes the parameters later, in
        # MongoClient.__init__.
        parsed = uri_parser.parse_uri(hostname, port, validate=False)

        dbname = parsed['database'] or client.virtual_host

        if dbname in ('/', None):
            dbname = self.default_database

        options = {
            'auto_start_request': True,
            'ssl': self.ssl,
            'connectTimeoutMS': (int(self.connect_timeout * 1000)
                                 if self.connect_timeout else None),
        }
        options.update(parsed['options'])
        options = self._prepare_client_options(options)

        if 'tls' in options:
            options.pop('ssl')

        return hostname, dbname, options
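
    # Illustrative example (hypothetical connection): with a kombu URL of
    # 'mongodb://user:pass@example.com:27017/bus', _parse_uri() returns the
    # full hostname string, the database name 'bus', and the pymongo client
    # options dict.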

    def _prepare_client_options(self, options):
        if pymongo.version_tuple >= (3,):
            options.pop('auto_start_request', None)
            if isinstance(options.get('readpreference'), int):
                modes = pymongo.read_preferences._MONGOS_MODES
                options['readpreference'] = modes[options['readpreference']]
        return options

    def prepare_queue_arguments(self, arguments, **kwargs):
        return to_rabbitmq_queue_arguments(arguments, **kwargs)

    def _open(self, scheme='mongodb://'):
        hostname, dbname, conf = self._parse_uri(scheme=scheme)

        conf['host'] = hostname

        env = _detect_environment()
        if env == 'gevent':
            from gevent import monkey
            monkey.patch_all()
        elif env == 'eventlet':
            from eventlet import monkey_patch
            monkey_patch()

        mongoconn = MongoClient(**conf)
        database = mongoconn[dbname]

        version_str = mongoconn.server_info()['version']
        version_str = version_str.split('-')[0]
        version = tuple(map(int, version_str.split('.')))

        if version < (1, 3):
            raise VersionMismatch(E_SERVER_VERSION.format(version_str))
        elif self.ttl and version < (2, 2):
            raise VersionMismatch(E_NO_TTL_INDEXES.format(version_str))

        return database

    def _create_broadcast(self, database):
        """Create capped collection for broadcast messages."""
        if self.broadcast_collection in database.list_collection_names():
            return

        database.create_collection(self.broadcast_collection,
                                   size=self.capped_queue_size,
                                   capped=True)

    def _ensure_indexes(self, database):
        """Ensure indexes on collections."""
        messages = database[self.messages_collection]
        messages.create_index(
            [('queue', 1), ('priority', 1), ('_id', 1)], background=True,
        )

        database[self.broadcast_collection].create_index([('queue', 1)])

        routing = database[self.routing_collection]
        routing.create_index([('queue', 1), ('exchange', 1)])

        if self.ttl:
            messages.create_index([('expire_at', 1)], expireAfterSeconds=0)
            routing.create_index([('expire_at', 1)], expireAfterSeconds=0)

            database[self.queues_collection].create_index(
                [('expire_at', 1)], expireAfterSeconds=0)

    def _create_client(self):
        """Actually creates connection."""
        database = self._open()
        self._create_broadcast(database)
        self._ensure_indexes(database)

        return database

    @cached_property
    def client(self):
        return self._create_client()

    @cached_property
    def messages(self):
        return self.client[self.messages_collection]

    @cached_property
    def routing(self):
        return self.client[self.routing_collection]

    @cached_property
    def broadcast(self):
        return self.client[self.broadcast_collection]

    @cached_property
    def queues(self):
        return self.client[self.queues_collection]

    def _get_broadcast_cursor(self, queue):
        try:
            return self._broadcast_cursors[queue]
        except KeyError:
            # Cursor may be absent when Channel created more than once.
            # _fanout_queues is a class-level mutable attribute so it's
            # shared over all Channel instances.
            return self._create_broadcast_cursor(
                self._fanout_queues[queue], None, None, queue,
            )

    def _create_broadcast_cursor(self, exchange, routing_key, pattern, queue):
        if pymongo.version_tuple >= (3, ):
            query = {
                'filter': {'queue': exchange},
                'cursor_type': CursorType.TAILABLE,
            }
        else:
            query = {
                'query': {'queue': exchange},
||||
'tailable': True,
|
||||
}
|
||||
|
||||
cursor = self.broadcast.find(**query)
|
||||
ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
|
||||
return ret
|
||||
|
||||
def _get_message_expire(self, message):
|
||||
value = message.get('properties', {}).get('expiration')
|
||||
if value is not None:
|
||||
return self.get_now() + datetime.timedelta(milliseconds=int(value))
|
||||
|
||||
def _get_queue_expire(self, queue, argument):
|
||||
"""Get expiration header named `argument` of queue definition.
|
||||
|
||||
Note:
|
||||
----
|
||||
`queue` must be either queue name or options itself.
|
||||
"""
|
||||
if isinstance(queue, str):
|
||||
doc = self.queues.find_one({'_id': queue})
|
||||
|
||||
if not doc:
|
||||
return
|
||||
|
||||
data = doc['options']
|
||||
else:
|
||||
data = queue
|
||||
|
||||
try:
|
||||
value = data['arguments'][argument]
|
||||
except (KeyError, TypeError):
|
||||
return
|
||||
|
||||
return self.get_now() + datetime.timedelta(milliseconds=value)
|
||||
|
||||
def _update_queues_expire(self, queue):
|
||||
"""Update expiration field on queues documents."""
|
||||
expire_at = self._get_queue_expire(queue, 'x-expires')
|
||||
|
||||
if not expire_at:
|
||||
return
|
||||
|
||||
self.routing.update_many(
|
||||
{'queue': queue}, {'$set': {'expire_at': expire_at}})
|
||||
self.queues.update_many(
|
||||
{'_id': queue}, {'$set': {'expire_at': expire_at}})
|
||||
|
||||
def get_now(self):
|
||||
"""Return current time in UTC."""
|
||||
return datetime.datetime.utcnow()
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
"""MongoDB Transport."""
|
||||
|
||||
Channel = Channel
|
||||
|
||||
can_parse_url = True
|
||||
polling_interval = 1
|
||||
default_port = Channel.default_port
|
||||
connection_errors = (
|
||||
virtual.Transport.connection_errors + (errors.ConnectionFailure,)
|
||||
)
|
||||
channel_errors = (
|
||||
virtual.Transport.channel_errors + (
|
||||
errors.ConnectionFailure,
|
||||
errors.OperationFailure)
|
||||
)
|
||||
driver_type = 'mongodb'
|
||||
driver_name = 'pymongo'
|
||||
|
||||
implements = virtual.Transport.implements.extend(
|
||||
exchange_type=frozenset(['direct', 'topic', 'fanout']),
|
||||
)
|
||||
|
||||
def driver_version(self):
|
||||
return pymongo.version
|
||||
|
||||
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
|
||||
if not uri:
|
||||
return 'mongodb://'
|
||||
if include_password:
|
||||
return uri
|
||||
|
||||
if ',' not in uri:
|
||||
return maybe_sanitize_url(uri)
|
||||
|
||||
uri1, remainder = uri.split(',', 1)
|
||||
return ','.join([maybe_sanitize_url(uri1), remainder])
|
||||
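
    # Illustrative note (not in the original source; the URI below is a
    # made-up example): only the first host part of a multi-host URI is
    # sanitized, so as_uri('mongodb://user:secret@host1,host2/db') would
    # return roughly 'mongodb://user:**@host1,host2/db'.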
@@ -0,0 +1,134 @@
"""Native Delayed Delivery API.

Only relevant for RabbitMQ.
"""
from __future__ import annotations

from kombu import Connection, Exchange, Queue, binding
from kombu.log import get_logger

logger = get_logger(__name__)

MAX_NUMBER_OF_BITS_TO_USE = 28
MAX_LEVEL = MAX_NUMBER_OF_BITS_TO_USE - 1
CELERY_DELAYED_DELIVERY_EXCHANGE = "celery_delayed_delivery"


def level_name(level: int) -> str:
    """Generates the delayed queue/exchange name based on the level."""
    if level < 0:
        raise ValueError("level must be a non-negative number")

    return f"celery_delayed_{level}"
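

# Illustrative (added note, not in the original module): level_name(0)
# returns 'celery_delayed_0' and level_name(27) returns 'celery_delayed_27',
# the highest level used by the declarations below.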


def declare_native_delayed_delivery_exchanges_and_queues(connection: Connection, queue_type: str) -> None:
    """Declares all native delayed delivery exchanges and queues."""
    if queue_type != "classic" and queue_type != "quorum":
        raise ValueError("queue_type must be either classic or quorum")

    channel = connection.channel()

    routing_key: str = "1.#"

    for level in range(27, -1, -1):
        current_level = level_name(level)
        next_level = level_name(level - 1) if level > 0 else None

        delayed_exchange: Exchange = Exchange(
            current_level, type="topic").bind(channel)
        delayed_exchange.declare()

        queue_arguments = {
            "x-queue-type": queue_type,
            "x-overflow": "reject-publish",
            "x-message-ttl": pow(2, level) * 1000,
            "x-dead-letter-exchange": next_level if level > 0 else CELERY_DELAYED_DELIVERY_EXCHANGE,
        }

        if queue_type == 'quorum':
            queue_arguments["x-dead-letter-strategy"] = "at-least-once"

        delayed_queue: Queue = Queue(
            current_level,
            queue_arguments=queue_arguments
        ).bind(channel)
        delayed_queue.declare()
        delayed_queue.bind_to(current_level, routing_key)

        routing_key = "*." + routing_key

    routing_key = "0.#"
    for level in range(27, 0, -1):
        current_level = level_name(level)
        next_level = level_name(level - 1) if level > 0 else None

        next_level_exchange: Exchange = Exchange(
            next_level, type="topic").bind(channel)

        next_level_exchange.bind_to(current_level, routing_key)

        routing_key = "*." + routing_key

    delivery_exchange: Exchange = Exchange(
        CELERY_DELAYED_DELIVERY_EXCHANGE, type="topic").bind(channel)
    delivery_exchange.declare()
    delivery_exchange.bind_to(level_name(0), routing_key)
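

# Usage sketch (illustrative; the broker URL is an assumption and a
# reachable RabbitMQ broker is required):
#
#     from kombu import Connection
#
#     with Connection('amqp://guest:guest@localhost:5672//') as connection:
#         declare_native_delayed_delivery_exchanges_and_queues(
#             connection, queue_type='quorum')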


def bind_queue_to_native_delayed_delivery_exchange(connection: Connection, queue: Queue) -> None:
    """Bind a queue to the native delayed delivery exchange.

    When a message arrives at the delivery exchange, it must be forwarded to
    the original exchange and queue. To accomplish this, the function retrieves
    the exchange or binding objects associated with the queue and binds them to
    the delivery exchange.

    :param connection: The connection object used to create and manage the channel.
    :type connection: Connection
    :param queue: The queue to be bound to the native delayed delivery exchange.
    :type queue: Queue

    Warning:
    -------
        If a direct exchange is detected, a warning will be logged because
        native delayed delivery does not support direct exchanges.
    """
    channel = connection.channel()
    queue = queue.bind(channel)

    bindings: set[binding] = set()

    if queue.exchange:
        bindings.add(binding(
            queue.exchange,
            routing_key=queue.routing_key,
            arguments=queue.binding_arguments
        ))
    elif queue.bindings:
        bindings = queue.bindings

    for binding_entry in bindings:
        exchange: Exchange = binding_entry.exchange.bind(channel)
        if exchange.type == 'direct':
            logger.warning(f"Exchange {exchange.name} is a direct exchange "
                           f"and native delayed delivery does not support direct exchanges.\n"
                           f"ETA tasks published to this exchange will block the worker until the ETA arrives.")
            continue

        routing_key = binding_entry.routing_key if binding_entry.routing_key.startswith(
            '#') else f"#.{binding_entry.routing_key}"
        exchange.bind_to(CELERY_DELAYED_DELIVERY_EXCHANGE, routing_key=routing_key)
        queue.bind_to(exchange.name, routing_key=routing_key)
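

# Usage sketch (illustrative; the 'tasks' names are assumptions):
#
#     from kombu import Connection, Exchange, Queue
#
#     exchange = Exchange('tasks', type='topic')
#     queue = Queue('tasks', exchange, routing_key='tasks')
#     with Connection('amqp://localhost//') as connection:
#         bind_queue_to_native_delayed_delivery_exchange(connection, queue)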


def calculate_routing_key(countdown: int, routing_key: str) -> str:
    """Calculate the routing key for publishing a delayed message based on the countdown."""
    if countdown < 1:
        raise ValueError("countdown must be a positive number")

    if not routing_key:
        raise ValueError("routing_key must be non-empty")

    return '.'.join(list(f'{countdown:028b}')) + f'.{routing_key}'
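

# Illustrative example (not in the original module): the countdown is
# encoded as 28 dot-separated binary digits prefixed to the routing key,
# one digit per delay level:
#
#     >>> calculate_routing_key(5, 'celery').split('.')[-4:]
#     ['1', '0', '1', 'celery']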
253
venv/lib/python3.12/site-packages/kombu/transport/pyamqp.py
Normal file
@@ -0,0 +1,253 @@
"""pyamqp transport module for Kombu.

Pure-Python amqp transport using the py-amqp library.

Features
========
* Type: Native
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: Yes
* Supports Priority: Yes
* Supports TTL: Yes

Connection String
=================
The connection string can have the following formats:

.. code-block::

    amqp://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
    [USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]
    amqp://

For TLS encryption use:

.. code-block::

    amqps://[USER:PASSWORD@]BROKER_ADDRESS[:PORT][/VIRTUALHOST]

Transport Options
=================
Transport Options are passed to the constructor of the underlying py-amqp
:class:`~amqp.connection.Connection` class.

Using TLS
=========
Transport over TLS can be enabled by the ``ssl`` parameter of the
:class:`~kombu.Connection` class. By setting ``ssl=True``, the TLS
transport is used::

    conn = Connection('amqp://', ssl=True)

This is equivalent to the ``amqps://`` transport URI::

    conn = Connection('amqps://')

To pass additional parameters to the underlying TLS layer, set ``ssl``
to a dict instead of ``True``::

    conn = Connection('amqp://broker.example.com', ssl={
        'keyfile': '/path/to/keyfile',
        'certfile': '/path/to/certfile',
        'ca_certs': '/path/to/ca_certfile'
        }
    )

All parameters are passed to the ``ssl`` parameter of the
:class:`amqp.connection.Connection` class.

The SSL option ``server_hostname`` can be set to ``None``, which causes the
hostname from the broker URL to be used. This is useful when failover is
used, to fill ``server_hostname`` with the currently used broker::

    conn = Connection('amqp://broker1.example.com;broker2.example.com', ssl={
        'server_hostname': None
        }
    )
"""


from __future__ import annotations

import amqp

from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple

from . import base
from .base import to_rabbitmq_queue_arguments

DEFAULT_PORT = 5672
DEFAULT_SSL_PORT = 5671


class Message(base.Message):
    """AMQP Message."""

    def __init__(self, msg, channel=None, **kwargs):
        props = msg.properties
        super().__init__(
            body=msg.body,
            channel=channel,
            delivery_tag=msg.delivery_tag,
            content_type=props.get('content_type'),
            content_encoding=props.get('content_encoding'),
            delivery_info=msg.delivery_info,
            properties=msg.properties,
            headers=props.get('application_headers') or {},
            **kwargs)


class Channel(amqp.Channel, base.StdChannel):
    """AMQP Channel."""

    Message = Message

    def prepare_message(self, body, priority=None,
                        content_type=None, content_encoding=None,
                        headers=None, properties=None, _Message=amqp.Message):
        """Prepare message so that it can be sent using this transport."""
        return _Message(
            body,
            priority=priority,
            content_type=content_type,
            content_encoding=content_encoding,
            application_headers=headers,
            **properties or {}
        )

    def prepare_queue_arguments(self, arguments, **kwargs):
        return to_rabbitmq_queue_arguments(arguments, **kwargs)

    def message_to_python(self, raw_message):
        """Convert encoded message body back to a Python value."""
        return self.Message(raw_message, channel=self)


class Connection(amqp.Connection):
    """AMQP Connection."""

    Channel = Channel


class Transport(base.Transport):
    """AMQP Transport."""

    Connection = Connection

    default_port = DEFAULT_PORT
    default_ssl_port = DEFAULT_SSL_PORT

    # it's very annoying that pyamqp sometimes raises AttributeError
    # if the connection is lost, but nothing we can do about that here.
    connection_errors = amqp.Connection.connection_errors
    channel_errors = amqp.Connection.channel_errors
    recoverable_connection_errors = \
        amqp.Connection.recoverable_connection_errors
    recoverable_channel_errors = amqp.Connection.recoverable_channel_errors

    driver_name = 'py-amqp'
    driver_type = 'amqp'

    implements = base.Transport.implements.extend(
        asynchronous=True,
        heartbeats=True,
    )

    def __init__(self, client,
                 default_port=None, default_ssl_port=None, **kwargs):
        self.client = client
        self.default_port = default_port or self.default_port
        self.default_ssl_port = default_ssl_port or self.default_ssl_port

    def driver_version(self):
        return amqp.__version__

    def create_channel(self, connection):
        return connection.channel()

    def drain_events(self, connection, **kwargs):
        return connection.drain_events(**kwargs)

    def _collect(self, connection):
        if connection is not None:
            connection.collect()

    def establish_connection(self):
        """Establish connection to the AMQP broker."""
        conninfo = self.client
        for name, default_value in self.default_connection_params.items():
            if not getattr(conninfo, name, None):
                setattr(conninfo, name, default_value)
        if conninfo.hostname == 'localhost':
            conninfo.hostname = '127.0.0.1'
        # when server_hostname is None, use hostname from URI.
        if isinstance(conninfo.ssl, dict) and \
                'server_hostname' in conninfo.ssl and \
                conninfo.ssl['server_hostname'] is None:
            conninfo.ssl['server_hostname'] = conninfo.hostname
        opts = dict({
            'host': conninfo.host,
            'userid': conninfo.userid,
            'password': conninfo.password,
            'login_method': conninfo.login_method,
            'virtual_host': conninfo.virtual_host,
            'insist': conninfo.insist,
            'ssl': conninfo.ssl,
            'connect_timeout': conninfo.connect_timeout,
            'heartbeat': conninfo.heartbeat,
        }, **conninfo.transport_options or {})
        conn = self.Connection(**opts)
        conn.client = self.client
        conn.connect()
        return conn

    def verify_connection(self, connection):
        return connection.connected

    def close_connection(self, connection):
        """Close the AMQP broker connection."""
        connection.client = None
        connection.close()

    def get_heartbeat_interval(self, connection):
        return connection.heartbeat

    def register_with_event_loop(self, connection, loop):
        connection.transport.raise_on_initial_eintr = True
        loop.add_reader(connection.sock, self.on_readable, connection, loop)

    def heartbeat_check(self, connection, rate=2):
        return connection.heartbeat_tick(rate=rate)

    def qos_semantics_matches_spec(self, connection):
        props = connection.server_properties
        if props.get('product') == 'RabbitMQ':
            return version_string_as_tuple(props['version']) < (3, 3)
        return True

    @property
    def default_connection_params(self):
        return {
            'userid': 'guest',
            'password': 'guest',
            'port': (self.default_ssl_port if self.client.ssl
                     else self.default_port),
            'hostname': 'localhost',
            'login_method': 'PLAIN',
        }

    def get_manager(self, *args, **kwargs):
        return get_manager(self.client, *args, **kwargs)


class SSLTransport(Transport):
    """AMQP SSL Transport."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # ugh, not exactly pure, but hey, it's python.
        if not self.client.ssl:  # not dict or False
            self.client.ssl = True
212
venv/lib/python3.12/site-packages/kombu/transport/pyro.py
Normal file
@@ -0,0 +1,212 @@
"""Pyro transport module for kombu.

Pyro transport, and Kombu Broker daemon.

Requires the :mod:`Pyro4` library to be installed.

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: No
* Supports TTL: No

Connection String
=================

To use the Pyro transport with Kombu, use a URL of the form:

.. code-block::

    pyro://localhost/kombu.broker

The hostname is where the transport will be looking for a Pyro name server,
which is used in turn to locate the kombu.broker Pyro service.
This broker can be launched by simply executing this transport module
directly, with the command: ``python -m kombu.transport.pyro``
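
A minimal usage sketch (illustrative; assumes the name server and the
broker daemon described above are already running, and the queue name is
just an example):

.. code-block:: python

    from kombu import Connection

    with Connection('pyro://localhost/kombu.broker') as conn:
        queue = conn.SimpleQueue('example_queue')
        queue.put('hello')
        print(queue.get(timeout=1).payload)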

Transport Options
=================
"""

from __future__ import annotations

import sys
from queue import Empty, Queue

from kombu.exceptions import reraise
from kombu.log import get_logger
from kombu.utils.objects import cached_property

from . import virtual

try:
    import Pyro4 as pyro
    from Pyro4.errors import NamingError
    from Pyro4.util import SerializerBase
except ImportError:  # pragma: no cover
    pyro = NamingError = SerializerBase = None

DEFAULT_PORT = 9090
E_NAMESERVER = """\
Unable to locate pyro nameserver on host {0.hostname}\
"""
E_LOOKUP = """\
Unable to lookup '{0.virtual_host}' in pyro nameserver on host {0.hostname}\
"""

logger = get_logger(__name__)


class Channel(virtual.Channel):
    """Pyro Channel."""

    def close(self):
        super().close()
        if self.shared_queues:
            self.shared_queues._pyroRelease()

    def queues(self):
        return self.shared_queues.get_queue_names()

    def _new_queue(self, queue, **kwargs):
        if queue not in self.queues():
            self.shared_queues.new_queue(queue)

    def _has_queue(self, queue, **kwargs):
        return self.shared_queues.has_queue(queue)

    def _get(self, queue, timeout=None):
        queue = self._queue_for(queue)
        return self.shared_queues.get(queue)

    def _queue_for(self, queue):
        if queue not in self.queues():
            self.shared_queues.new_queue(queue)
        return queue

    def _put(self, queue, message, **kwargs):
        queue = self._queue_for(queue)
        self.shared_queues.put(queue, message)

    def _size(self, queue):
        return self.shared_queues.size(queue)

    def _delete(self, queue, *args, **kwargs):
        self.shared_queues.delete(queue)

    def _purge(self, queue):
        return self.shared_queues.purge(queue)

    def after_reply_message_received(self, queue):
        pass

    @cached_property
    def shared_queues(self):
        return self.connection.shared_queues


class Transport(virtual.Transport):
    """Pyro Transport."""

    Channel = Channel

    #: memory backend state is global.
    # TODO: To be checked whether state can be per-Transport
    global_state = virtual.BrokerState()

    default_port = DEFAULT_PORT

    driver_type = driver_name = 'pyro'

    def __init__(self, client, **kwargs):
        super().__init__(client, **kwargs)
        self.state = self.global_state

    def _open(self):
        logger.debug("trying Pyro nameserver to find the broker daemon")
        conninfo = self.client
        try:
            nameserver = pyro.locateNS(host=conninfo.hostname,
                                       port=self.default_port)
        except NamingError:
            reraise(NamingError, NamingError(E_NAMESERVER.format(conninfo)),
                    sys.exc_info()[2])
        try:
            # name of registered pyro object
            uri = nameserver.lookup(conninfo.virtual_host)
            return pyro.Proxy(uri)
        except NamingError:
            reraise(NamingError, NamingError(E_LOOKUP.format(conninfo)),
                    sys.exc_info()[2])

    def driver_version(self):
        return pyro.__version__

    @cached_property
    def shared_queues(self):
        return self._open()


if pyro is not None:
    SerializerBase.register_dict_to_class("queue.Empty",
                                          lambda cls, data: Empty())

    @pyro.expose
    @pyro.behavior(instance_mode="single")
    class KombuBroker:
        """Kombu Broker used by the Pyro transport.

        You have to run this as a separate (Pyro) service.
        """

        def __init__(self):
            self.queues = {}

        def get_queue_names(self):
            return list(self.queues)

        def new_queue(self, queue):
            if queue in self.queues:
                return  # silently ignore the fact that queue already exists
            self.queues[queue] = Queue()

        def has_queue(self, queue):
            return queue in self.queues

        def get(self, queue):
            return self.queues[queue].get(block=False)

        def put(self, queue, message):
            self.queues[queue].put(message)

        def size(self, queue):
            return self.queues[queue].qsize()

        def delete(self, queue):
            del self.queues[queue]

        def purge(self, queue):
            # ``queue.Queue.get`` takes ``block``, not ``blocking``.
            while True:
                try:
                    self.queues[queue].get(block=False)
                except Empty:
                    break


# launch a Kombu Broker daemon with the command:
# ``python -m kombu.transport.pyro``
if __name__ == "__main__":
    print("Launching Broker for Kombu's Pyro transport.")
    with pyro.Daemon() as daemon:
        print("(Expecting a Pyro name server at {}:{})"
              .format(pyro.config.NS_HOST, pyro.config.NS_PORT))
        with pyro.locateNS() as ns:
            print("You can connect with Kombu using the url "
                  "'pyro://{}/kombu.broker'".format(pyro.config.NS_HOST))
            uri = daemon.register(KombuBroker)
            ns.register("kombu.broker", uri)
        daemon.requestLoop()
1748
venv/lib/python3.12/site-packages/kombu/transport/qpid.py
Normal file
File diff suppressed because it is too large
1460
venv/lib/python3.12/site-packages/kombu/transport/redis.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,256 @@
"""SQLAlchemy Transport module for kombu.

Kombu transport using a SQL database as the message store.

Features
========
* Type: Virtual
* Supports Direct: yes
* Supports Topic: yes
* Supports Fanout: no
* Supports Priority: no
* Supports TTL: no

Connection String
=================

.. code-block::

    sqla+SQL_ALCHEMY_CONNECTION_STRING
    sqlalchemy+SQL_ALCHEMY_CONNECTION_STRING

For details about ``SQL_ALCHEMY_CONNECTION_STRING`` see the SQLAlchemy
Engine Configuration documentation.

Examples
--------
.. code-block::

    # PostgreSQL with default driver
    sqla+postgresql://scott:tiger@localhost/mydatabase

    # PostgreSQL with psycopg2 driver
    sqla+postgresql+psycopg2://scott:tiger@localhost/mydatabase

    # PostgreSQL with pg8000 driver
    sqla+postgresql+pg8000://scott:tiger@localhost/mydatabase

    # MySQL with default driver
    sqla+mysql://scott:tiger@localhost/foo

    # MySQL with mysqlclient driver (a maintained fork of MySQL-Python)
    sqla+mysql+mysqldb://scott:tiger@localhost/foo

    # MySQL with PyMySQL driver
    sqla+mysql+pymysql://scott:tiger@localhost/foo

Transport Options
=================

* ``queue_tablename``: Name of the table storing queues.
* ``message_tablename``: Name of the table storing messages.

Moreover, parameters of the :func:`sqlalchemy.create_engine` function can
be passed as transport options.
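
For example (the table names and ``echo`` flag below are illustrative,
not defaults)::

    from kombu import Connection

    conn = Connection('sqla+sqlite:///:memory:', transport_options={
        'queue_tablename': 'demo_queue',
        'message_tablename': 'demo_message',
        'echo': True,  # passed through to sqlalchemy.create_engine()
    })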
"""
from __future__ import annotations

import threading
from json import dumps, loads
from queue import Empty

from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker

from kombu.transport import virtual
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str

from .models import Message as MessageBase
from .models import ModelBase
from .models import Queue as QueueBase
from .models import class_registry, metadata

# SQLAlchemy overrides != False to have special meaning and pep8 complains
# flake8: noqa


VERSION = (1, 4, 1)
__version__ = '.'.join(map(str, VERSION))

_MUTEX = threading.RLock()


class Channel(virtual.Channel):
    """The channel class."""

    _session = None
    _engines = {}   # engine cache

    def __init__(self, connection, **kwargs):
        self._configure_entity_tablenames(connection.client.transport_options)
        super().__init__(connection, **kwargs)

    def _configure_entity_tablenames(self, opts):
        self.queue_tablename = opts.get('queue_tablename', 'kombu_queue')
        self.message_tablename = opts.get('message_tablename', 'kombu_message')

        #
        # Define the model definitions.  This registers the declarative
        # classes with the active SQLAlchemy metadata object.  This *must* be
        # done prior to the ``create_engine`` call.
        #
        self.queue_cls and self.message_cls

    def _engine_from_config(self):
        conninfo = self.connection.client
        transport_options = conninfo.transport_options.copy()
        transport_options.pop('queue_tablename', None)
        transport_options.pop('message_tablename', None)
        transport_options.pop('callback', None)
        transport_options.pop('errback', None)
        transport_options.pop('max_retries', None)
        transport_options.pop('interval_start', None)
        transport_options.pop('interval_step', None)
        transport_options.pop('interval_max', None)
        transport_options.pop('retry_errors', None)

        return create_engine(conninfo.hostname, **transport_options)

    def _open(self):
        conninfo = self.connection.client
        if conninfo.hostname not in self._engines:
            with _MUTEX:
                if conninfo.hostname in self._engines:
                    # Engine was created while we were waiting to
                    # acquire the lock.
                    return self._engines[conninfo.hostname]

                engine = self._engine_from_config()
                Session = sessionmaker(bind=engine)
                metadata.create_all(engine)
                self._engines[conninfo.hostname] = engine, Session

        return self._engines[conninfo.hostname]

    @property
    def session(self):
        if self._session is None:
            _, Session = self._open()
            self._session = Session()
        return self._session

    def _get_or_create(self, queue):
        obj = self.session.query(self.queue_cls) \
            .filter(self.queue_cls.name == queue).first()
        if not obj:
            with _MUTEX:
                obj = self.session.query(self.queue_cls) \
                    .filter(self.queue_cls.name == queue).first()
                if obj:
                    # Queue was created while we were waiting to
                    # acquire the lock.
                    return obj

                obj = self.queue_cls(queue)
                self.session.add(obj)
                try:
                    self.session.commit()
                except OperationalError:
                    self.session.rollback()

        return obj

    def _new_queue(self, queue, **kwargs):
        self._get_or_create(queue)

    def _put(self, queue, payload, **kwargs):
        obj = self._get_or_create(queue)
        message = self.message_cls(dumps(payload), obj)
        self.session.add(message)
        try:
            self.session.commit()
        except OperationalError:
            self.session.rollback()

    def _get(self, queue):
        obj = self._get_or_create(queue)
        if self.session.bind.name == 'sqlite':
            self.session.execute(text('BEGIN IMMEDIATE TRANSACTION'))
        try:
            msg = self.session.query(self.message_cls) \
                .with_for_update() \
                .filter(self.message_cls.queue_id == obj.id) \
                .filter(self.message_cls.visible != False) \
                .order_by(self.message_cls.sent_at) \
                .order_by(self.message_cls.id) \
                .limit(1) \
                .first()
            if msg:
                msg.visible = False
                return loads(bytes_to_str(msg.payload))
            raise Empty()
        finally:
            self.session.commit()

    def _query_all(self, queue):
        obj = self._get_or_create(queue)
        return self.session.query(self.message_cls) \
            .filter(self.message_cls.queue_id == obj.id)

    def _purge(self, queue):
        count = self._query_all(queue).delete(synchronize_session=False)
        try:
            self.session.commit()
        except OperationalError:
            self.session.rollback()
        return count

    def _size(self, queue):
        return self._query_all(queue).count()

    def _declarative_cls(self, name, base, ns):
        if name not in class_registry:
            with _MUTEX:
                if name in class_registry:
                    # Class was registered while we were waiting to
                    # acquire the lock.
                    return class_registry[name]

                return type(str(name), (base, ModelBase), ns)

        return class_registry[name]

    @cached_property
    def queue_cls(self):
        return self._declarative_cls(
            'Queue',
            QueueBase,
            {'__tablename__': self.queue_tablename}
        )

    @cached_property
    def message_cls(self):
        return self._declarative_cls(
            'Message',
            MessageBase,
            {'__tablename__': self.message_tablename}
        )


class Transport(virtual.Transport):
    """The transport class."""

    Channel = Channel

    can_parse_url = True
    default_port = 0
    driver_type = 'sql'
    driver_name = 'sqlalchemy'
    connection_errors = (OperationalError, )

    def driver_version(self):
        import sqlalchemy
        return sqlalchemy.__version__
@@ -0,0 +1,76 @@
"""Kombu transport using SQLAlchemy as the message store."""

from __future__ import annotations

import datetime

from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
                        Sequence, SmallInteger, String, Text)
from sqlalchemy.orm import relationship
from sqlalchemy.schema import MetaData

try:
    from sqlalchemy.orm import declarative_base, declared_attr
except ImportError:
    # TODO: Remove this once we drop support for SQLAlchemy < 1.4.
    from sqlalchemy.ext.declarative import declarative_base, declared_attr

class_registry = {}
metadata = MetaData()
ModelBase = declarative_base(metadata=metadata, class_registry=class_registry)


class Queue:
    """The queue class."""

    __table_args__ = {'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}

    id = Column(Integer, Sequence('queue_id_sequence'), primary_key=True,
                autoincrement=True)
    name = Column(String(200), unique=True)

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f'<Queue({self.name})>'

    @declared_attr
    def messages(cls):
        return relationship('Message', backref='queue', lazy='noload')


class Message:
    """The message class."""

    __table_args__ = (
        Index('ix_kombu_message_timestamp_id', 'timestamp', 'id'),
        {'sqlite_autoincrement': True, 'mysql_engine': 'InnoDB'}
    )

    id = Column(Integer, Sequence('message_id_sequence'),
                primary_key=True, autoincrement=True)
    visible = Column(Boolean, default=True, index=True)
    sent_at = Column('timestamp', DateTime, nullable=True, index=True,
                     onupdate=datetime.datetime.now)
    payload = Column(Text, nullable=False)
    version = Column(SmallInteger, nullable=False, default=1)

    __mapper_args__ = {'version_id_col': version}

    def __init__(self, payload, queue):
        self.payload = payload
        self.queue = queue

    def __str__(self):
        return '<Message: {0.sent_at} {0.payload} {0.queue_id}>'.format(self)

    @declared_attr
    def queue_id(self):
        return Column(
            Integer,
            ForeignKey(
                '%s.id' % class_registry['Queue'].__tablename__,
                name='FK_kombu_message_queue'
            )
        )
@@ -0,0 +1,11 @@
from __future__ import annotations

from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty,
                   Management, Message, NotEquivalentError, QoS, Transport,
                   UndeliverableWarning, binding_key_t, queue_binding_t)

__all__ = (
    'Base64', 'NotEquivalentError', 'UndeliverableWarning', 'BrokerState',
    'QoS', 'Message', 'AbstractChannel', 'Channel', 'Management', 'Transport',
    'Empty', 'binding_key_t', 'queue_binding_t',
)
1039
venv/lib/python3.12/site-packages/kombu/transport/virtual/base.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,164 @@
"""Virtual AMQ Exchange.

Implementations of the standard exchanges defined
by the AMQ protocol (excluding the `headers` exchange).
"""

from __future__ import annotations

import re

from kombu.utils.text import escape_regex


class ExchangeType:
    """Base class for exchanges.

    Implements the specifics for an exchange type.

    Arguments:
    ---------
        channel (ChannelT): AMQ Channel.
    """

    type = None

    def __init__(self, channel):
        self.channel = channel

    def lookup(self, table, exchange, routing_key, default):
        """Lookup all queues matching `routing_key` in `exchange`.

        Returns
        -------
            str: queue name, or 'default' if no queues matched.
        """
        raise NotImplementedError('subclass responsibility')

    def prepare_bind(self, queue, exchange, routing_key, arguments):
        """Prepare queue-binding.

        Returns
        -------
            Tuple[str, Pattern, str]: of `(routing_key, regex, queue)`
                to be stored for bindings to this exchange.
        """
        return routing_key, None, queue

    def equivalent(self, prev, exchange, type,
                   durable, auto_delete, arguments):
        """Return true if `prev` and `exchange` are equivalent."""
        return (type == prev['type'] and
                durable == prev['durable'] and
                auto_delete == prev['auto_delete'] and
                (arguments or {}) == (prev['arguments'] or {}))


class DirectExchange(ExchangeType):
    """Direct exchange.

    The `direct` exchange routes based on exact routing keys.
    """

    type = 'direct'

    def lookup(self, table, exchange, routing_key, default):
        return {
            queue for rkey, _, queue in table
            if rkey == routing_key
        }

    def deliver(self, message, exchange, routing_key, **kwargs):
        _lookup = self.channel._lookup
        _put = self.channel._put
        for queue in _lookup(exchange, routing_key):
            _put(queue, message, **kwargs)


class TopicExchange(ExchangeType):
    """Topic exchange.

    The `topic` exchange routes messages based on words separated by
    dots, using wildcard characters ``*`` (any single word), and ``#``
    (one or more words).
    """

    type = 'topic'

    #: map of wildcard to regex conversions
    wildcards = {'*': r'.*?[^\.]',
                 '#': r'.*?'}

    #: compiled regex cache
    _compiled = {}

    def lookup(self, table, exchange, routing_key, default):
        return {
            queue for rkey, pattern, queue in table
            if self._match(pattern, routing_key)
        }

    def deliver(self, message, exchange, routing_key, **kwargs):
        _lookup = self.channel._lookup
        _put = self.channel._put
        deadletter = self.channel.deadletter_queue
        for queue in [q for q in _lookup(exchange, routing_key)
                      if q and q != deadletter]:
            _put(queue, message, **kwargs)

    def prepare_bind(self, queue, exchange, routing_key, arguments):
        return routing_key, self.key_to_pattern(routing_key), queue

    def key_to_pattern(self, rkey):
        """Get the corresponding regex for any routing key."""
        return '^%s$' % (r'\.'.join(
            self.wildcards.get(word, word)
            for word in escape_regex(rkey, '.#*').split('.')
        ))
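
    # Illustrative (not in the original source): key_to_pattern('stock.#')
    # yields r'^stock\..*?$', so a binding made with that key matches the
    # routing key 'stock.nyse.ibm' but not 'price.stock'.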

    def _match(self, pattern, string):
        """Match regular expression (cached).

        Same as :func:`re.match`, except the regex is compiled and cached,
        then reused on subsequent matches with the same pattern.
        """
        try:
            compiled = self._compiled[pattern]
        except KeyError:
            compiled = self._compiled[pattern] = re.compile(pattern, re.U)
        return compiled.match(string)


class FanoutExchange(ExchangeType):
    """Fanout exchange.

    The `fanout` exchange implements broadcast messaging by delivering
    copies of all messages to all queues bound to the exchange.

    To support fanout the virtual channel needs to store the table
    as shared state.  This requires that the `Channel.supports_fanout`
    attribute is set to true, and the `Channel._queue_bind` and
    `Channel.get_table` methods are implemented.

    See Also
    --------
        the redis backend for an example implementation of these methods.
    """

    type = 'fanout'

    def lookup(self, table, exchange, routing_key, default):
        return {queue for _, _, queue in table}

    def deliver(self, message, exchange, routing_key, **kwargs):
        if self.channel.supports_fanout:
            self.channel._put_fanout(
                exchange, message, routing_key, **kwargs)


#: Map of standard exchange types and corresponding classes.
STANDARD_EXCHANGE_TYPES = {
    'direct': DirectExchange,
    'topic': TopicExchange,
    'fanout': FanoutExchange,
}
223
venv/lib/python3.12/site-packages/kombu/transport/zookeeper.py
Normal file
@@ -0,0 +1,223 @@
# copyright: (c) 2010 - 2013 by Mahendra M.
# license: BSD, see LICENSE for more details.

"""Zookeeper transport module for kombu.

Zookeeper-based transport. This transport uses the Zookeeper queue
implementation built into kazoo.

**References**

- https://zookeeper.apache.org/doc/current/recipes.html#sc_recipes_Queues
- https://kazoo.readthedocs.io/en/latest/api/recipe/queue.html

**Limitations**
This queue does not offer reliable consumption. An entry is removed from
the queue prior to being processed. So if an error occurs, the consumer
has to re-queue the item or it will be lost.

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
* Supports Fanout: No
* Supports Priority: Yes
* Supports TTL: No

Connection String
=================
Connects to a zookeeper node as:

.. code-block::

    zookeeper://SERVER:PORT/VHOST

The <vhost> becomes the base for all the other znodes, so it can be used
like a vhost.
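
For instance (illustrative), with ``VHOST`` set to ``app`` a queue named
``celery`` is stored under the znode path ``/app/celery``.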


Transport Options
=================

"""

from __future__ import annotations

import os
import socket
from queue import Empty

from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.json import dumps, loads

from . import virtual

try:
    import kazoo
    from kazoo.client import KazooClient
    from kazoo.recipe.queue import Queue

    KZ_CONNECTION_ERRORS = (
        kazoo.exceptions.SystemErrorException,
        kazoo.exceptions.ConnectionLossException,
        kazoo.exceptions.MarshallingErrorException,
        kazoo.exceptions.UnimplementedException,
        kazoo.exceptions.OperationTimeoutException,
        kazoo.exceptions.NoAuthException,
        kazoo.exceptions.InvalidACLException,
        kazoo.exceptions.AuthFailedException,
        kazoo.exceptions.SessionExpiredException,
    )

    KZ_CHANNEL_ERRORS = (
        kazoo.exceptions.RuntimeInconsistencyException,
        kazoo.exceptions.DataInconsistencyException,
        kazoo.exceptions.BadArgumentsException,
        kazoo.exceptions.MarshallingErrorException,
        kazoo.exceptions.UnimplementedException,
        kazoo.exceptions.OperationTimeoutException,
        kazoo.exceptions.ApiErrorException,
        kazoo.exceptions.NoNodeException,
        kazoo.exceptions.NoAuthException,
        kazoo.exceptions.NodeExistsException,
        kazoo.exceptions.NoChildrenForEphemeralsException,
        kazoo.exceptions.NotEmptyException,
        kazoo.exceptions.SessionExpiredException,
        kazoo.exceptions.InvalidCallbackException,
        socket.error,
    )
except ImportError:
    kazoo = None
    KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = ()

DEFAULT_PORT = 2181

__author__ = 'Mahendra M <mahendra.m@gmail.com>'


class Channel(virtual.Channel):
    """Zookeeper Channel."""

    _client = None
    _queues = {}

    def __init__(self, connection, **kwargs):
        super().__init__(connection, **kwargs)
        vhost = self.connection.client.virtual_host
        self._vhost = '/{}'.format(vhost.strip('/'))

    def _get_path(self, queue_name):
        return os.path.join(self._vhost, queue_name)

    def _get_queue(self, queue_name):
        queue = self._queues.get(queue_name, None)

        if queue is None:
            queue = Queue(self.client, self._get_path(queue_name))
            self._queues[queue_name] = queue

            # Ensure that the queue is created
            len(queue)

        return queue

    def _put(self, queue, message, **kwargs):
        return self._get_queue(queue).put(
            ensure_bytes(dumps(message)),
            priority=self._get_message_priority(message, reverse=True),
        )

    def _get(self, queue):
        queue = self._get_queue(queue)
        msg = queue.get()

        if msg is None:
            raise Empty()

        return loads(bytes_to_str(msg))

    def _purge(self, queue):
        count = 0
        queue = self._get_queue(queue)

        while True:
            msg = queue.get()
            if msg is None:
                break
            count += 1

        return count

    def _delete(self, queue, *args, **kwargs):
        if self._has_queue(queue):
            self._purge(queue)
            self.client.delete(self._get_path(queue))

    def _size(self, queue):
        queue = self._get_queue(queue)
        return len(queue)

    def _new_queue(self, queue, **kwargs):
        if not self._has_queue(queue):
            queue = self._get_queue(queue)

    def _has_queue(self, queue):
        return self.client.exists(self._get_path(queue)) is not None

    def _open(self):
        conninfo = self.connection.client
        hosts = []
        if conninfo.alt:
            for host_port in conninfo.alt:
                if host_port.startswith('zookeeper://'):
                    host_port = host_port[len('zookeeper://'):]
                if not host_port:
                    continue
                try:
                    host, port = host_port.split(':', 1)
                    host_port = (host, int(port))
                except ValueError:
                    if host_port == conninfo.hostname:
                        host_port = (host_port, conninfo.port or DEFAULT_PORT)
                    else:
                        host_port = (host_port, DEFAULT_PORT)
                hosts.append(host_port)
        host_port = (conninfo.hostname, conninfo.port or DEFAULT_PORT)
        if host_port not in hosts:
            hosts.insert(0, host_port)
        conn_str = ','.join([f'{h}:{p}' for h, p in hosts])
        conn = KazooClient(conn_str)
        conn.start()
        return conn

    @property
    def client(self):
        if self._client is None:
            self._client = self._open()
        return self._client


class Transport(virtual.Transport):
    """Zookeeper Transport."""

    Channel = Channel
    polling_interval = 1
    default_port = DEFAULT_PORT
    connection_errors = (
        virtual.Transport.connection_errors + KZ_CONNECTION_ERRORS
    )
    channel_errors = (
        virtual.Transport.channel_errors + KZ_CHANNEL_ERRORS
    )
    driver_type = 'zookeeper'
    driver_name = 'kazoo'

    def __init__(self, *args, **kwargs):
        if kazoo is None:
            raise ImportError('The kazoo library is not installed')

        super().__init__(*args, **kwargs)

    def driver_version(self):
        return kazoo.__version__