Major fixes and new features

2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions


@@ -0,0 +1,34 @@
from __future__ import absolute_import
__title__ = 'kafka'
from kafka.version import __version__
__author__ = 'Dana Powers'
__license__ = 'Apache License 2.0'
__copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors'
# Set default logging handler to avoid "No handler found" warnings.
import logging
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
logging.getLogger(__name__).addHandler(NullHandler())
from kafka.admin import KafkaAdminClient
from kafka.client_async import KafkaClient
from kafka.consumer import KafkaConsumer
from kafka.consumer.subscription_state import ConsumerRebalanceListener
from kafka.producer import KafkaProducer
from kafka.conn import BrokerConnection
from kafka.serializer import Serializer, Deserializer
from kafka.structs import TopicPartition, OffsetAndMetadata
__all__ = [
'BrokerConnection', 'ConsumerRebalanceListener', 'KafkaAdminClient',
'KafkaClient', 'KafkaConsumer', 'KafkaProducer',
]
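As a rough usage sketch of the public API re-exported here (the broker address, topic, and group id below are placeholders, not values taken from this commit):

from kafka import KafkaConsumer, KafkaProducer

# Produce one message and read it back; assumes a broker reachable on localhost:9092.
producer = KafkaProducer(bootstrap_servers='localhost:9092')
producer.send('example-topic', b'hello')
producer.flush()

consumer = KafkaConsumer('example-topic',
                         bootstrap_servers='localhost:9092',
                         group_id='example-group',
                         auto_offset_reset='earliest',
                         consumer_timeout_ms=5000)
for message in consumer:
    print(message.topic, message.partition, message.offset, message.value)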


@@ -0,0 +1,14 @@
from __future__ import absolute_import
from kafka.admin.config_resource import ConfigResource, ConfigResourceType
from kafka.admin.client import KafkaAdminClient
from kafka.admin.acl_resource import (ACL, ACLFilter, ResourcePattern, ResourcePatternFilter, ACLOperation,
ResourceType, ACLPermissionType, ACLResourcePatternType)
from kafka.admin.new_topic import NewTopic
from kafka.admin.new_partitions import NewPartitions
__all__ = [
'ConfigResource', 'ConfigResourceType', 'KafkaAdminClient', 'NewTopic', 'NewPartitions', 'ACL', 'ACLFilter',
'ResourcePattern', 'ResourcePatternFilter', 'ACLOperation', 'ResourceType', 'ACLPermissionType',
'ACLResourcePatternType'
]


@@ -0,0 +1,244 @@
from __future__ import absolute_import
from kafka.errors import IllegalArgumentError
# enum in stdlib as of py3.4
try:
from enum import IntEnum # pylint: disable=import-error
except ImportError:
# vendored backport module
from kafka.vendor.enum34 import IntEnum
class ResourceType(IntEnum):
"""Type of kafka resource to set ACL for
The ANY value is only valid in a filter context
"""
UNKNOWN = 0
ANY = 1
CLUSTER = 4
DELEGATION_TOKEN = 6
GROUP = 3
TOPIC = 2
TRANSACTIONAL_ID = 5
class ACLOperation(IntEnum):
"""Type of operation
The ANY value is only valid in a filter context
"""
ANY = 1
ALL = 2
READ = 3
WRITE = 4
CREATE = 5
DELETE = 6
ALTER = 7
DESCRIBE = 8
CLUSTER_ACTION = 9
DESCRIBE_CONFIGS = 10
ALTER_CONFIGS = 11
IDEMPOTENT_WRITE = 12
class ACLPermissionType(IntEnum):
"""An enumerated type of permissions
The ANY value is only valid in a filter context
"""
ANY = 1
DENY = 2
ALLOW = 3
class ACLResourcePatternType(IntEnum):
"""An enumerated type of resource patterns
More details on the pattern types and how they work
can be found in KIP-290 (Support for prefixed ACLs)
https://cwiki.apache.org/confluence/display/KAFKA/KIP-290%3A+Support+for+Prefixed+ACLs
"""
ANY = 1
MATCH = 2
LITERAL = 3
PREFIXED = 4
class ACLFilter(object):
"""Represents a filter to use with describing and deleting ACLs
The difference between this class and the ACL class is mainly that
we allow using ANY with the operation, permission, and resource type objects
to fetch ACLs matching any of the properties.
To make a filter matching any principal, set principal to None
"""
def __init__(
self,
principal,
host,
operation,
permission_type,
resource_pattern
):
self.principal = principal
self.host = host
self.operation = operation
self.permission_type = permission_type
self.resource_pattern = resource_pattern
self.validate()
def validate(self):
if not isinstance(self.operation, ACLOperation):
raise IllegalArgumentError("operation must be an ACLOperation object, and cannot be ANY")
if not isinstance(self.permission_type, ACLPermissionType):
raise IllegalArgumentError("permission_type must be an ACLPermissionType object, and cannot be ANY")
if not isinstance(self.resource_pattern, ResourcePatternFilter):
raise IllegalArgumentError("resource_pattern must be a ResourcePatternFilter object")
def __repr__(self):
return "<ACL principal={principal}, resource={resource}, operation={operation}, type={type}, host={host}>".format(
principal=self.principal,
host=self.host,
operation=self.operation.name,
type=self.permission_type.name,
resource=self.resource_pattern
)
def __eq__(self, other):
return all((
self.principal == other.principal,
self.host == other.host,
self.operation == other.operation,
self.permission_type == other.permission_type,
self.resource_pattern == other.resource_pattern
))
def __hash__(self):
return hash((
self.principal,
self.host,
self.operation,
self.permission_type,
self.resource_pattern,
))
class ACL(ACLFilter):
"""Represents a concrete ACL for a specific ResourcePattern
In kafka an ACL is a 4-tuple of (principal, host, operation, permission_type)
that limits who can do what on a specific resource (or since KIP-290 a resource pattern)
Terminology:
Principal -> This is the identifier for the user. Depending on the authorization method used (SSL, SASL etc)
the principal will look different. See http://kafka.apache.org/documentation/#security_authz for details.
The principal must be in the format "User:<name>" or kafka will treat it as invalid. It's possible to use
other principal types than "User" if using a custom authorizer for the cluster.
Host -> This must currently be an IP address. It cannot be a range, and it cannot be a domain name.
It can be set to "*", which is special cased in kafka to mean "any host"
Operation -> Which client operation this ACL refers to. Its meaning depends
on the resource type the ACL refers to. See https://docs.confluent.io/current/kafka/authorization.html#acl-format
for a list of which resource/operation combinations unlock which kafka APIs.
Permission Type -> Whether this ACL is allowing or denying access
Resource Pattern -> This is a representation of the resource or resource pattern that the ACL
refers to. See the ResourcePattern class for details.
"""
def __init__(
self,
principal,
host,
operation,
permission_type,
resource_pattern
):
super(ACL, self).__init__(principal, host, operation, permission_type, resource_pattern)
self.validate()
def validate(self):
if self.operation == ACLOperation.ANY:
raise IllegalArgumentError("operation cannot be ANY")
if self.permission_type == ACLPermissionType.ANY:
raise IllegalArgumentError("permission_type cannot be ANY")
if not isinstance(self.resource_pattern, ResourcePattern):
raise IllegalArgumentError("resource_pattern must be a ResourcePattern object")
class ResourcePatternFilter(object):
def __init__(
self,
resource_type,
resource_name,
pattern_type
):
self.resource_type = resource_type
self.resource_name = resource_name
self.pattern_type = pattern_type
self.validate()
def validate(self):
if not isinstance(self.resource_type, ResourceType):
raise IllegalArgumentError("resource_type must be a ResourceType object")
if not isinstance(self.pattern_type, ACLResourcePatternType):
raise IllegalArgumentError("pattern_type must be an ACLResourcePatternType object")
def __repr__(self):
return "<ResourcePattern type={}, name={}, pattern={}>".format(
self.resource_type.name,
self.resource_name,
self.pattern_type.name
)
def __eq__(self, other):
return all((
self.resource_type == other.resource_type,
self.resource_name == other.resource_name,
self.pattern_type == other.pattern_type,
))
def __hash__(self):
return hash((
self.resource_type,
self.resource_name,
self.pattern_type
))
class ResourcePattern(ResourcePatternFilter):
"""A resource pattern to apply the ACL to
Resource patterns make it possible to specify which resources an ACL describes in a more
flexible way than pointing to a literal topic name, for example.
Since KIP-290 (kafka 2.0) it's possible to set an ACL for a prefixed resource name, which
can cut down considerably on the number of ACLs needed when the number of topics and
consumer groups starts to grow.
The default pattern_type is LITERAL, and it describes a specific resource. This is also how
ACLs worked before the introduction of prefixed ACLs.
"""
def __init__(
self,
resource_type,
resource_name,
pattern_type=ACLResourcePatternType.LITERAL
):
super(ResourcePattern, self).__init__(resource_type, resource_name, pattern_type)
self.validate()
def validate(self):
if self.resource_type == ResourceType.ANY:
raise IllegalArgumentError("resource_type cannot be ANY")
if self.pattern_type in [ACLResourcePatternType.ANY, ACLResourcePatternType.MATCH]:
raise IllegalArgumentError(
"pattern_type cannot be {} on a concrete ResourcePattern".format(self.pattern_type.name)
)
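A short sketch of how these classes fit together (the principal and topic name are hypothetical); objects like these are typically handed to KafkaAdminClient methods such as create_acls()/describe_acls():

from kafka.admin import (ACL, ACLFilter, ACLOperation, ACLPermissionType,
                         ACLResourcePatternType, ResourcePattern,
                         ResourcePatternFilter, ResourceType)

# Concrete ACL: allow User:alice to READ the topic "payments" from any host.
acl = ACL(
    principal='User:alice',
    host='*',
    operation=ACLOperation.READ,
    permission_type=ACLPermissionType.ALLOW,
    resource_pattern=ResourcePattern(ResourceType.TOPIC, 'payments'),
)

# Filter: match ACLs for any principal and host, any operation and permission, on any topic.
topic_filter = ACLFilter(
    principal=None,
    host=None,
    operation=ACLOperation.ANY,
    permission_type=ACLPermissionType.ANY,
    resource_pattern=ResourcePatternFilter(ResourceType.TOPIC, None,
                                           ACLResourcePatternType.ANY),
)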

File diff suppressed because it is too large


@@ -0,0 +1,36 @@
from __future__ import absolute_import
# enum in stdlib as of py3.4
try:
from enum import IntEnum # pylint: disable=import-error
except ImportError:
# vendored backport module
from kafka.vendor.enum34 import IntEnum
class ConfigResourceType(IntEnum):
"""An enumerated type of config resources"""
BROKER = 4
TOPIC = 2
class ConfigResource(object):
"""A class for specifying config resources.
Arguments:
resource_type (ConfigResourceType): the type of kafka resource
name (string): The name of the kafka resource
configs ({key : value}): A map of config keys to values.
"""
def __init__(
self,
resource_type,
name,
configs=None
):
if not isinstance(resource_type, ConfigResourceType):
resource_type = ConfigResourceType[str(resource_type).upper()] # pylint: disable-msg=unsubscriptable-object
self.resource_type = resource_type
self.name = name
self.configs = configs
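A minimal sketch (topic name, broker id, and config key are placeholders); note that a plain string resource type is upper-cased and looked up in ConfigResourceType. Instances like these are what KafkaAdminClient's describe_configs()/alter_configs() expect:

from kafka.admin import ConfigResource, ConfigResourceType

topic_cfg = ConfigResource(ConfigResourceType.TOPIC, 'example-topic',
                           configs={'retention.ms': '86400000'})

broker_cfg = ConfigResource('broker', '1')
assert broker_cfg.resource_type is ConfigResourceType.BROKER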


@@ -0,0 +1,19 @@
from __future__ import absolute_import
class NewPartitions(object):
"""A class for new partition creation on existing topics. Note that the length of new_assignments, if specified,
must be the difference between the new total number of partitions and the existing number of partitions.
Arguments:
total_count (int): the total number of partitions that should exist on the topic
new_assignments ([[int]]): an array of arrays of replica assignments for new partitions.
If not set, the broker assigns replicas using an internal algorithm.
"""
def __init__(
self,
total_count,
new_assignments=None
):
self.total_count = total_count
self.new_assignments = new_assignments
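For example (broker ids are hypothetical), growing a topic from 3 to 5 partitions adds two partitions, so new_assignments, if supplied, must hold exactly two replica lists; these objects are what KafkaAdminClient.create_partitions() consumes:

from kafka.admin import NewPartitions

# Let the broker place replicas for the two new partitions.
simple = NewPartitions(total_count=5)

# Or pin replicas explicitly: one replica list per new partition.
explicit = NewPartitions(total_count=5, new_assignments=[[1, 2], [2, 3]])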


@@ -0,0 +1,34 @@
from __future__ import absolute_import
from kafka.errors import IllegalArgumentError
class NewTopic(object):
""" A class for new topic creation
Arguments:
name (string): name of the topic
num_partitions (int): number of partitions
or -1 if replica_assignments has been specified
replication_factor (int): replication factor or -1 if
replica assignment is specified
replica_assignments (dict of int: [int]): A mapping containing
partition id and replicas to assign to it.
topic_configs (dict of str: str): A mapping of config key
and value for the topic.
"""
def __init__(
self,
name,
num_partitions,
replication_factor,
replica_assignments=None,
topic_configs=None,
):
if not (num_partitions == -1 or replication_factor == -1) ^ (replica_assignments is None):
raise IllegalArgumentError('either num_partitions/replication_factor or replica_assignments must be specified')
self.name = name
self.num_partitions = num_partitions
self.replication_factor = replication_factor
self.replica_assignments = replica_assignments or {}
self.topic_configs = topic_configs or {}
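A quick sketch of the two accepted forms (topic names and broker ids are placeholders): either pass num_partitions/replication_factor, or pass -1 for both and spell out replica_assignments, as required by the check in __init__. Such objects are typically passed to KafkaAdminClient.create_topics():

from kafka.admin import NewTopic

# Let the brokers choose replica placement.
simple = NewTopic(name='events', num_partitions=6, replication_factor=3)

# Pin replicas explicitly: partition 0 on brokers 1,2 and partition 1 on brokers 2,3.
pinned = NewTopic(name='events-pinned', num_partitions=-1, replication_factor=-1,
                  replica_assignments={0: [1, 2], 1: [2, 3]},
                  topic_configs={'cleanup.policy': 'compact'})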

File diff suppressed because it is too large


@@ -0,0 +1,397 @@
from __future__ import absolute_import
import collections
import copy
import logging
import threading
import time
from kafka.vendor import six
from kafka import errors as Errors
from kafka.conn import collect_hosts
from kafka.future import Future
from kafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition
log = logging.getLogger(__name__)
class ClusterMetadata(object):
"""
A class to manage kafka cluster metadata.
This class does not perform any IO. It simply updates internal state
given API responses (MetadataResponse, GroupCoordinatorResponse).
Keyword Arguments:
retry_backoff_ms (int): Milliseconds to backoff when retrying on
errors. Default: 100.
metadata_max_age_ms (int): The period of time in milliseconds after
which we force a refresh of metadata even if we haven't seen any
partition leadership changes to proactively discover any new
brokers or partitions. Default: 300000
bootstrap_servers: 'host[:port]' string (or list of 'host[:port]'
strings) that the client should contact to bootstrap initial
cluster metadata. This does not have to be the full node list.
It just needs to have at least one broker that will respond to a
Metadata API Request. Default port is 9092. If no servers are
specified, will default to localhost:9092.
"""
DEFAULT_CONFIG = {
'retry_backoff_ms': 100,
'metadata_max_age_ms': 300000,
'bootstrap_servers': [],
}
def __init__(self, **configs):
self._brokers = {} # node_id -> BrokerMetadata
self._partitions = {} # topic -> partition -> PartitionMetadata
self._broker_partitions = collections.defaultdict(set) # node_id -> {TopicPartition...}
self._groups = {} # group_name -> node_id
self._last_refresh_ms = 0
self._last_successful_refresh_ms = 0
self._need_update = True
self._future = None
self._listeners = set()
self._lock = threading.Lock()
self.need_all_topic_metadata = False
self.unauthorized_topics = set()
self.internal_topics = set()
self.controller = None
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs[key]
self._bootstrap_brokers = self._generate_bootstrap_brokers()
self._coordinator_brokers = {}
def _generate_bootstrap_brokers(self):
# collect_hosts does not perform DNS, so we should be fine to re-use
bootstrap_hosts = collect_hosts(self.config['bootstrap_servers'])
brokers = {}
for i, (host, port, _) in enumerate(bootstrap_hosts):
node_id = 'bootstrap-%s' % i
brokers[node_id] = BrokerMetadata(node_id, host, port, None)
return brokers
def is_bootstrap(self, node_id):
return node_id in self._bootstrap_brokers
def brokers(self):
"""Get all BrokerMetadata
Returns:
set: {BrokerMetadata, ...}
"""
return set(self._brokers.values()) or set(self._bootstrap_brokers.values())
def broker_metadata(self, broker_id):
"""Get BrokerMetadata
Arguments:
broker_id (int): node_id for a broker to check
Returns:
BrokerMetadata or None if not found
"""
return (
self._brokers.get(broker_id) or
self._bootstrap_brokers.get(broker_id) or
self._coordinator_brokers.get(broker_id)
)
def partitions_for_topic(self, topic):
"""Return set of all partitions for topic (whether available or not)
Arguments:
topic (str): topic to check for partitions
Returns:
set: {partition (int), ...}
"""
if topic not in self._partitions:
return None
return set(self._partitions[topic].keys())
def available_partitions_for_topic(self, topic):
"""Return set of partitions with known leaders
Arguments:
topic (str): topic to check for partitions
Returns:
set: {partition (int), ...}
None if topic not found.
"""
if topic not in self._partitions:
return None
return set([partition for partition, metadata
in six.iteritems(self._partitions[topic])
if metadata.leader != -1])
def leader_for_partition(self, partition):
"""Return node_id of leader, -1 unavailable, None if unknown."""
if partition.topic not in self._partitions:
return None
elif partition.partition not in self._partitions[partition.topic]:
return None
return self._partitions[partition.topic][partition.partition].leader
def partitions_for_broker(self, broker_id):
"""Return TopicPartitions for which the broker is a leader.
Arguments:
broker_id (int): node id for a broker
Returns:
set: {TopicPartition, ...}
None if the broker either has no partitions or does not exist.
"""
return self._broker_partitions.get(broker_id)
def coordinator_for_group(self, group):
"""Return node_id of group coordinator.
Arguments:
group (str): name of consumer group
Returns:
int: node_id for group coordinator
None if the group does not exist.
"""
return self._groups.get(group)
def ttl(self):
"""Milliseconds until metadata should be refreshed"""
now = time.time() * 1000
if self._need_update:
ttl = 0
else:
metadata_age = now - self._last_successful_refresh_ms
ttl = self.config['metadata_max_age_ms'] - metadata_age
retry_age = now - self._last_refresh_ms
next_retry = self.config['retry_backoff_ms'] - retry_age
return max(ttl, next_retry, 0)
def refresh_backoff(self):
"""Return milliseconds to wait before attempting to retry after failure"""
return self.config['retry_backoff_ms']
def request_update(self):
"""Flags metadata for update, return Future()
Actual update must be handled separately. This method will only
change the reported ttl()
Returns:
kafka.future.Future (value will be the cluster object after update)
"""
with self._lock:
self._need_update = True
if not self._future or self._future.is_done:
self._future = Future()
return self._future
def topics(self, exclude_internal_topics=True):
"""Get set of known topics.
Arguments:
exclude_internal_topics (bool): Whether records from internal topics
(such as offsets) should be exposed to the consumer. If set to
True the only way to receive records from an internal topic is
subscribing to it. Default True
Returns:
set: {topic (str), ...}
"""
topics = set(self._partitions.keys())
if exclude_internal_topics:
return topics - self.internal_topics
else:
return topics
def failed_update(self, exception):
"""Update cluster state given a failed MetadataRequest."""
f = None
with self._lock:
if self._future:
f = self._future
self._future = None
if f:
f.failure(exception)
self._last_refresh_ms = time.time() * 1000
def update_metadata(self, metadata):
"""Update cluster state given a MetadataResponse.
Arguments:
metadata (MetadataResponse): broker response to a metadata request
Returns: None
"""
# In the common case where we ask for a single topic and get back an
# error, we should fail the future
if len(metadata.topics) == 1 and metadata.topics[0][0] != 0:
error_code, topic = metadata.topics[0][:2]
error = Errors.for_code(error_code)(topic)
return self.failed_update(error)
if not metadata.brokers:
log.warning("No broker metadata found in MetadataResponse -- ignoring.")
return self.failed_update(Errors.MetadataEmptyBrokerList(metadata))
_new_brokers = {}
for broker in metadata.brokers:
if metadata.API_VERSION == 0:
node_id, host, port = broker
rack = None
else:
node_id, host, port, rack = broker
_new_brokers.update({
node_id: BrokerMetadata(node_id, host, port, rack)
})
if metadata.API_VERSION == 0:
_new_controller = None
else:
_new_controller = _new_brokers.get(metadata.controller_id)
_new_partitions = {}
_new_broker_partitions = collections.defaultdict(set)
_new_unauthorized_topics = set()
_new_internal_topics = set()
for topic_data in metadata.topics:
if metadata.API_VERSION == 0:
error_code, topic, partitions = topic_data
is_internal = False
else:
error_code, topic, is_internal, partitions = topic_data
if is_internal:
_new_internal_topics.add(topic)
error_type = Errors.for_code(error_code)
if error_type is Errors.NoError:
_new_partitions[topic] = {}
for p_error, partition, leader, replicas, isr in partitions:
_new_partitions[topic][partition] = PartitionMetadata(
topic=topic, partition=partition, leader=leader,
replicas=replicas, isr=isr, error=p_error)
if leader != -1:
_new_broker_partitions[leader].add(
TopicPartition(topic, partition))
# Specific topic errors can be ignored if this is a full metadata fetch
elif self.need_all_topic_metadata:
continue
elif error_type is Errors.LeaderNotAvailableError:
log.warning("Topic %s is not available during auto-create"
" initialization", topic)
elif error_type is Errors.UnknownTopicOrPartitionError:
log.error("Topic %s not found in cluster metadata", topic)
elif error_type is Errors.TopicAuthorizationFailedError:
log.error("Topic %s is not authorized for this client", topic)
_new_unauthorized_topics.add(topic)
elif error_type is Errors.InvalidTopicError:
log.error("'%s' is not a valid topic name", topic)
else:
log.error("Error fetching metadata for topic %s: %s",
topic, error_type)
with self._lock:
self._brokers = _new_brokers
self.controller = _new_controller
self._partitions = _new_partitions
self._broker_partitions = _new_broker_partitions
self.unauthorized_topics = _new_unauthorized_topics
self.internal_topics = _new_internal_topics
f = None
if self._future:
f = self._future
self._future = None
self._need_update = False
now = time.time() * 1000
self._last_refresh_ms = now
self._last_successful_refresh_ms = now
if f:
f.success(self)
log.debug("Updated cluster metadata to %s", self)
for listener in self._listeners:
listener(self)
if self.need_all_topic_metadata:
# the listener may change the interested topics,
# which could cause another metadata refresh.
# If we have already fetched all topics, however,
# another fetch should be unnecessary.
self._need_update = False
def add_listener(self, listener):
"""Add a callback function to be called on each metadata update"""
self._listeners.add(listener)
def remove_listener(self, listener):
"""Remove a previously added listener callback"""
self._listeners.remove(listener)
def add_group_coordinator(self, group, response):
"""Update with metadata for a group coordinator
Arguments:
group (str): name of group from GroupCoordinatorRequest
response (GroupCoordinatorResponse): broker response
Returns:
string: coordinator node_id if metadata is updated, None on error
"""
log.debug("Updating coordinator for %s: %s", group, response)
error_type = Errors.for_code(response.error_code)
if error_type is not Errors.NoError:
log.error("GroupCoordinatorResponse error: %s", error_type)
self._groups[group] = -1
return
# Use a coordinator-specific node id so that group requests
# get a dedicated connection
node_id = 'coordinator-{}'.format(response.coordinator_id)
coordinator = BrokerMetadata(
node_id,
response.host,
response.port,
None)
log.info("Group coordinator for %s is %s", group, coordinator)
self._coordinator_brokers[node_id] = coordinator
self._groups[group] = node_id
return node_id
def with_partitions(self, partitions_to_add):
"""Returns a copy of cluster metadata with partitions added"""
new_metadata = ClusterMetadata(**self.config)
new_metadata._brokers = copy.deepcopy(self._brokers)
new_metadata._partitions = copy.deepcopy(self._partitions)
new_metadata._broker_partitions = copy.deepcopy(self._broker_partitions)
new_metadata._groups = copy.deepcopy(self._groups)
new_metadata.internal_topics = copy.deepcopy(self.internal_topics)
new_metadata.unauthorized_topics = copy.deepcopy(self.unauthorized_topics)
for partition in partitions_to_add:
new_metadata._partitions[partition.topic][partition.partition] = partition
if partition.leader is not None and partition.leader != -1:
new_metadata._broker_partitions[partition.leader].add(
TopicPartition(partition.topic, partition.partition))
return new_metadata
def __str__(self):
return 'ClusterMetadata(brokers: %d, topics: %d, groups: %d)' % \
(len(self._brokers), len(self._partitions), len(self._groups))
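Because ClusterMetadata performs no IO, it can be constructed and queried on its own; in normal operation KafkaClient drives the updates. A small sketch (the broker address is a placeholder):

from kafka.cluster import ClusterMetadata

cluster = ClusterMetadata(bootstrap_servers='localhost:9092',
                          metadata_max_age_ms=300000)
print(cluster.brokers())           # bootstrap BrokerMetadata until a real MetadataResponse arrives
print(cluster.ttl())               # 0 while an update is still pending
future = cluster.request_update()  # resolved later by whichever component performs the IO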


@@ -0,0 +1,326 @@
from __future__ import absolute_import
import gzip
import io
import platform
import struct
from kafka.vendor import six
from kafka.vendor.six.moves import range
_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1)
_XERIAL_V1_FORMAT = 'bccccccBii'
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
try:
import snappy
except ImportError:
snappy = None
try:
import zstandard as zstd
except ImportError:
zstd = None
try:
import lz4.frame as lz4
def _lz4_compress(payload, **kwargs):
# Kafka does not support LZ4 dependent blocks
try:
# For lz4>=0.12.0
kwargs.pop('block_linked', None)
return lz4.compress(payload, block_linked=False, **kwargs)
except TypeError:
# For earlier versions of lz4
kwargs.pop('block_mode', None)
return lz4.compress(payload, block_mode=1, **kwargs)
except ImportError:
lz4 = None
try:
import lz4f
except ImportError:
lz4f = None
try:
import lz4framed
except ImportError:
lz4framed = None
try:
import xxhash
except ImportError:
xxhash = None
PYPY = bool(platform.python_implementation() == 'PyPy')
def has_gzip():
return True
def has_snappy():
return snappy is not None
def has_zstd():
return zstd is not None
def has_lz4():
if lz4 is not None:
return True
if lz4f is not None:
return True
if lz4framed is not None:
return True
return False
def gzip_encode(payload, compresslevel=None):
if not compresslevel:
compresslevel = 9
buf = io.BytesIO()
# Gzip context manager introduced in python 2.7
# so old-fashioned way until we decide to not support 2.6
gzipper = gzip.GzipFile(fileobj=buf, mode="w", compresslevel=compresslevel)
try:
gzipper.write(payload)
finally:
gzipper.close()
return buf.getvalue()
def gzip_decode(payload):
buf = io.BytesIO(payload)
# Gzip context manager introduced in python 2.7
# so old-fashioned way until we decide to not support 2.6
gzipper = gzip.GzipFile(fileobj=buf, mode='r')
try:
return gzipper.read()
finally:
gzipper.close()
def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024):
"""Encodes the given data with snappy compression.
If xerial_compatible is set then the stream is encoded in a fashion
compatible with the xerial snappy library.
The block size (xerial_blocksize) controls how frequently the blocking occurs;
32k is the default in the xerial library.
The format winds up being:
+-------------+------------+--------------+------------+--------------+
| Header | Block1 len | Block1 data | Blockn len | Blockn data |
+-------------+------------+--------------+------------+--------------+
| 16 bytes | BE int32 | snappy bytes | BE int32 | snappy bytes |
+-------------+------------+--------------+------------+--------------+
It is important to note that the blocksize is the amount of uncompressed
data presented to snappy at each block, whereas the blocklen is the number
of bytes that will be present in the stream; so the length will always be
<= blocksize.
"""
if not has_snappy():
raise NotImplementedError("Snappy codec is not available")
if not xerial_compatible:
return snappy.compress(payload)
out = io.BytesIO()
for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER):
out.write(struct.pack('!' + fmt, dat))
# Chunk through buffers to avoid creating intermediate slice copies
if PYPY:
# on pypy, snappy.compress() on a sliced buffer consumes the entire
# buffer... likely a python-snappy bug, so just use a slice copy
chunker = lambda payload, i, size: payload[i:size+i]
elif six.PY2:
# Sliced buffer avoids additional copies
# pylint: disable-msg=undefined-variable
chunker = lambda payload, i, size: buffer(payload, i, size)
else:
# snappy.compress does not like raw memoryviews, so we have to convert
# tobytes, which is a copy... oh well. it's the thought that counts.
# pylint: disable-msg=undefined-variable
chunker = lambda payload, i, size: memoryview(payload)[i:size+i].tobytes()
for chunk in (chunker(payload, i, xerial_blocksize)
for i in range(0, len(payload), xerial_blocksize)):
block = snappy.compress(chunk)
block_size = len(block)
out.write(struct.pack('!i', block_size))
out.write(block)
return out.getvalue()
def _detect_xerial_stream(payload):
"""Detects if the data given might have been encoded with the blocking mode
of the xerial snappy library.
This mode writes a magic header of the format:
+--------+--------------+------------+---------+--------+
| Marker | Magic String | Null / Pad | Version | Compat |
+--------+--------------+------------+---------+--------+
| byte | c-string | byte | int32 | int32 |
+--------+--------------+------------+---------+--------+
| -126 | 'SNAPPY' | \0 | | |
+--------+--------------+------------+---------+--------+
The pad appears to be there to ensure that SNAPPY is a valid cstring.
The version is the version of this format as written by xerial;
in the wild this is currently 1, so we only support v1.
Compat claims the minimum supported version that can read a
xerial block stream; presently in the wild this is 1.
"""
if len(payload) > 16:
header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16])
return header == _XERIAL_V1_HEADER
return False
def snappy_decode(payload):
if not has_snappy():
raise NotImplementedError("Snappy codec is not available")
if _detect_xerial_stream(payload):
# TODO ? Should become a fileobj ?
out = io.BytesIO()
byt = payload[16:]
length = len(byt)
cursor = 0
while cursor < length:
block_size = struct.unpack_from('!i', byt[cursor:])[0]
# Skip the block size
cursor += 4
end = cursor + block_size
out.write(snappy.decompress(byt[cursor:end]))
cursor = end
out.seek(0)
return out.read()
else:
return snappy.decompress(payload)
if lz4:
lz4_encode = _lz4_compress # pylint: disable-msg=no-member
elif lz4f:
lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member
elif lz4framed:
lz4_encode = lz4framed.compress # pylint: disable-msg=no-member
else:
lz4_encode = None
def lz4f_decode(payload):
"""Decode payload using interoperable LZ4 framing. Requires Kafka >= 0.10"""
# pylint: disable-msg=no-member
ctx = lz4f.createDecompContext()
data = lz4f.decompressFrame(payload, ctx)
lz4f.freeDecompContext(ctx)
# lz4f python module does not expose how much of the payload was
# actually read if the decompression was only partial.
if data['next'] != 0:
raise RuntimeError('lz4f unable to decompress full payload')
return data['decomp']
if lz4:
lz4_decode = lz4.decompress # pylint: disable-msg=no-member
elif lz4f:
lz4_decode = lz4f_decode
elif lz4framed:
lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member
else:
lz4_decode = None
def lz4_encode_old_kafka(payload):
"""Encode payload for 0.8/0.9 brokers -- requires an incorrect header checksum."""
assert xxhash is not None
data = lz4_encode(payload)
header_size = 7
flg = data[4]
if not isinstance(flg, int):
flg = ord(flg)
content_size_bit = ((flg >> 3) & 1)
if content_size_bit:
# Old kafka does not accept the content-size field
# so we need to discard it and reset the header flag
flg -= 8
data = bytearray(data)
data[4] = flg
data = bytes(data)
payload = data[header_size+8:]
else:
payload = data[header_size:]
# This is the incorrect hc
hc = xxhash.xxh32(data[0:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member
return b''.join([
data[0:header_size-1],
hc,
payload
])
def lz4_decode_old_kafka(payload):
assert xxhash is not None
# Kafka's LZ4 code has a bug in its header checksum implementation
header_size = 7
if isinstance(payload[4], int):
flg = payload[4]
else:
flg = ord(payload[4])
content_size_bit = ((flg >> 3) & 1)
if content_size_bit:
header_size += 8
# This should be the correct hc
hc = xxhash.xxh32(payload[4:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member
munged_payload = b''.join([
payload[0:header_size-1],
hc,
payload[header_size:]
])
return lz4_decode(munged_payload)
def zstd_encode(payload):
if not zstd:
raise NotImplementedError("Zstd codec is not available")
return zstd.ZstdCompressor().compress(payload)
def zstd_decode(payload):
if not zstd:
raise NotImplementedError("Zstd codec is not available")
try:
return zstd.ZstdDecompressor().decompress(payload)
except zstd.ZstdError:
return zstd.ZstdDecompressor().decompress(payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE)
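Gzip support is always available, so a gzip round trip is a safe sanity check of this module; the snappy branch below only runs when python-snappy is installed:

from kafka import codec

payload = b'x' * 1000

blob = codec.gzip_encode(payload, compresslevel=6)
assert codec.gzip_decode(blob) == payload

if codec.has_snappy():
    framed = codec.snappy_encode(payload, xerial_compatible=True)
    assert codec.snappy_decode(framed) == payload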

File diff suppressed because it is too large


@@ -0,0 +1,7 @@
from __future__ import absolute_import
from kafka.consumer.group import KafkaConsumer
__all__ = [
'KafkaConsumer'
]

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,501 @@
from __future__ import absolute_import
import abc
import logging
import re
from kafka.vendor import six
from kafka.errors import IllegalStateError
from kafka.protocol.offset import OffsetResetStrategy
from kafka.structs import OffsetAndMetadata
log = logging.getLogger(__name__)
class SubscriptionState(object):
"""
A class for tracking the topics, partitions, and offsets for the consumer.
A partition is "assigned" either directly with assign_from_user() (manual
assignment) or with assign_from_subscribed() (automatic assignment from
subscription).
Once assigned, the partition is not considered "fetchable" until its initial
position has been set with seek(). Fetchable partitions track a fetch
position which is used to set the offset of the next fetch, and a consumed
position which is the last offset that has been returned to the user. You
can suspend fetching from a partition through pause() without affecting the
fetched/consumed offsets. The partition will remain unfetchable until
resume() is used. You can also query the pause state independently with
is_paused().
Note that pause state as well as fetch/consumed positions are not preserved
when partition assignment is changed whether directly by the user or
through a group rebalance.
This class also maintains a cache of the latest commit position for each of
the assigned partitions. This is updated through committed() and can be used
to set the initial fetch position (e.g. Fetcher._reset_offset() ).
"""
_SUBSCRIPTION_EXCEPTION_MESSAGE = (
"You must choose only one way to configure your consumer:"
" (1) subscribe to specific topics by name,"
" (2) subscribe to topics matching a regex pattern,"
" (3) assign itself specific topic-partitions.")
# Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29
_MAX_NAME_LENGTH = 249
_TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$')
def __init__(self, offset_reset_strategy='earliest'):
"""Initialize a SubscriptionState instance
Keyword Arguments:
offset_reset_strategy: 'earliest' or 'latest', otherwise
exception will be raised when fetching an offset that is no
longer available. Default: 'earliest'
"""
try:
offset_reset_strategy = getattr(OffsetResetStrategy,
offset_reset_strategy.upper())
except AttributeError:
log.warning('Unrecognized offset_reset_strategy, using NONE')
offset_reset_strategy = OffsetResetStrategy.NONE
self._default_offset_reset_strategy = offset_reset_strategy
self.subscription = None # set() or None
self.subscribed_pattern = None # regex str or None
self._group_subscription = set()
self._user_assignment = set()
self.assignment = dict()
self.listener = None
# initialize to true for the consumers to fetch offset upon starting up
self.needs_fetch_committed_offsets = True
def subscribe(self, topics=(), pattern=None, listener=None):
"""Subscribe to a list of topics, or a topic regex pattern.
Partitions will be dynamically assigned via a group coordinator.
Topic subscriptions are not incremental: this list will replace the
current assignment (if there is one).
This method is incompatible with assign_from_user()
Arguments:
topics (list): List of topics for subscription.
pattern (str): Pattern to match available topics. You must provide
either topics or pattern, but not both.
listener (ConsumerRebalanceListener): Optionally include listener
callback, which will be called before and after each rebalance
operation.
As part of group management, the consumer will keep track of the
list of consumers that belong to a particular group and will
trigger a rebalance operation if one of the following events
occurs:
* The number of partitions changes for any of the subscribed topics
* Topic is created or deleted
* An existing member of the consumer group dies
* A new member is added to the consumer group
When any of these events are triggered, the provided listener
will be invoked first to indicate that the consumer's assignment
has been revoked, and then again when the new assignment has
been received. Note that this listener will immediately override
any listener set in a previous call to subscribe. It is
guaranteed, however, that the partitions revoked/assigned
through this interface are from topics subscribed in this call.
"""
if self._user_assignment or (topics and pattern):
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
assert topics or pattern, 'Must provide topics or pattern'
if pattern:
log.info('Subscribing to pattern: /%s/', pattern)
self.subscription = set()
self.subscribed_pattern = re.compile(pattern)
else:
self.change_subscription(topics)
if listener and not isinstance(listener, ConsumerRebalanceListener):
raise TypeError('listener must be a ConsumerRebalanceListener')
self.listener = listener
def _ensure_valid_topic_name(self, topic):
""" Ensures that the topic name is valid according to the kafka source. """
# See Kafka Source:
# https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java
if topic is None:
raise TypeError('All topics must not be None')
if not isinstance(topic, six.string_types):
raise TypeError('All topics must be strings')
if len(topic) == 0:
raise ValueError('All topics must be non-empty strings')
if topic == '.' or topic == '..':
raise ValueError('Topic name cannot be "." or ".."')
if len(topic) > self._MAX_NAME_LENGTH:
raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(self._MAX_NAME_LENGTH, topic))
if not self._TOPIC_LEGAL_CHARS.match(topic):
raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic))
def change_subscription(self, topics):
"""Change the topic subscription.
Arguments:
topics (list of str): topics for subscription
Raises:
IllegalStateError: if assign_from_user has been used already
TypeError: if a topic is None or a non-str
ValueError: if a topic is an empty string or
- a topic name is '.' or '..' or
- a topic name does not consist of ASCII-characters/'-'/'_'/'.'
"""
if self._user_assignment:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
if isinstance(topics, six.string_types):
topics = [topics]
if self.subscription == set(topics):
log.warning("subscription unchanged by change_subscription(%s)",
topics)
return
for t in topics:
self._ensure_valid_topic_name(t)
log.info('Updating subscribed topics to: %s', topics)
self.subscription = set(topics)
self._group_subscription.update(topics)
# Remove any assigned partitions which are no longer subscribed to
for tp in set(self.assignment.keys()):
if tp.topic not in self.subscription:
del self.assignment[tp]
def group_subscribe(self, topics):
"""Add topics to the current group subscription.
This is used by the group leader to ensure that it receives metadata
updates for all topics that any member of the group is subscribed to.
Arguments:
topics (list of str): topics to add to the group subscription
"""
if self._user_assignment:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
self._group_subscription.update(topics)
def reset_group_subscription(self):
"""Reset the group's subscription to only contain topics subscribed by this consumer."""
if self._user_assignment:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
assert self.subscription is not None, 'Subscription required'
self._group_subscription.intersection_update(self.subscription)
def assign_from_user(self, partitions):
"""Manually assign a list of TopicPartitions to this consumer.
This interface does not allow for incremental assignment and will
replace the previous assignment (if there was one).
Manual topic assignment through this method does not use the consumer's
group management functionality. As such, there will be no rebalance
operation triggered when group membership or cluster and topic metadata
change. Note that it is not possible to use both manual partition
assignment with assign() and group assignment with subscribe().
Arguments:
partitions (list of TopicPartition): assignment for this instance.
Raises:
IllegalStateError: if consumer has already called subscribe()
"""
if self.subscription is not None:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
if self._user_assignment != set(partitions):
self._user_assignment = set(partitions)
for partition in partitions:
if partition not in self.assignment:
self._add_assigned_partition(partition)
for tp in set(self.assignment.keys()) - self._user_assignment:
del self.assignment[tp]
self.needs_fetch_committed_offsets = True
def assign_from_subscribed(self, assignments):
"""Update the assignment to the specified partitions
This method is called by the coordinator to dynamically assign
partitions based on the consumer's topic subscription. This is different
from assign_from_user() which directly sets the assignment from a
user-supplied TopicPartition list.
Arguments:
assignments (list of TopicPartition): partitions to assign to this
consumer instance.
"""
if not self.partitions_auto_assigned():
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
for tp in assignments:
if tp.topic not in self.subscription:
raise ValueError("Assigned partition %s for non-subscribed topic." % (tp,))
# after rebalancing, we always reinitialize the assignment state
self.assignment.clear()
for tp in assignments:
self._add_assigned_partition(tp)
self.needs_fetch_committed_offsets = True
log.info("Updated partition assignment: %s", assignments)
def unsubscribe(self):
"""Clear all topic subscriptions and partition assignments"""
self.subscription = None
self._user_assignment.clear()
self.assignment.clear()
self.subscribed_pattern = None
def group_subscription(self):
"""Get the topic subscription for the group.
For the leader, this will include the union of all member subscriptions.
For followers, it is the member's subscription only.
This is used when querying topic metadata to detect metadata changes
that would require rebalancing (the leader fetches metadata for all
topics in the group so that it can do partition assignment).
Returns:
set: topics
"""
return self._group_subscription
def seek(self, partition, offset):
"""Manually specify the fetch offset for a TopicPartition.
Overrides the fetch offsets that the consumer will use on the next
poll(). If this API is invoked for the same partition more than once,
the latest offset will be used on the next poll(). Note that you may
lose data if this API is arbitrarily used in the middle of consumption,
to reset the fetch offsets.
Arguments:
partition (TopicPartition): partition for seek operation
offset (int): message offset in partition
"""
self.assignment[partition].seek(offset)
def assigned_partitions(self):
"""Return set of TopicPartitions in current assignment."""
return set(self.assignment.keys())
def paused_partitions(self):
"""Return current set of paused TopicPartitions."""
return set(partition for partition in self.assignment
if self.is_paused(partition))
def fetchable_partitions(self):
"""Return set of TopicPartitions that should be Fetched."""
fetchable = set()
for partition, state in six.iteritems(self.assignment):
if state.is_fetchable():
fetchable.add(partition)
return fetchable
def partitions_auto_assigned(self):
"""Return True unless user supplied partitions manually."""
return self.subscription is not None
def all_consumed_offsets(self):
"""Returns consumed offsets as {TopicPartition: OffsetAndMetadata}"""
all_consumed = {}
for partition, state in six.iteritems(self.assignment):
if state.has_valid_position:
all_consumed[partition] = OffsetAndMetadata(state.position, '')
return all_consumed
def need_offset_reset(self, partition, offset_reset_strategy=None):
"""Mark partition for offset reset using specified or default strategy.
Arguments:
partition (TopicPartition): partition to mark
offset_reset_strategy (OffsetResetStrategy, optional)
"""
if offset_reset_strategy is None:
offset_reset_strategy = self._default_offset_reset_strategy
self.assignment[partition].await_reset(offset_reset_strategy)
def has_default_offset_reset_policy(self):
"""Return True if default offset reset policy is Earliest or Latest"""
return self._default_offset_reset_strategy != OffsetResetStrategy.NONE
def is_offset_reset_needed(self, partition):
return self.assignment[partition].awaiting_reset
def has_all_fetch_positions(self):
for state in self.assignment.values():
if not state.has_valid_position:
return False
return True
def missing_fetch_positions(self):
missing = set()
for partition, state in six.iteritems(self.assignment):
if not state.has_valid_position:
missing.add(partition)
return missing
def is_assigned(self, partition):
return partition in self.assignment
def is_paused(self, partition):
return partition in self.assignment and self.assignment[partition].paused
def is_fetchable(self, partition):
return partition in self.assignment and self.assignment[partition].is_fetchable()
def pause(self, partition):
self.assignment[partition].pause()
def resume(self, partition):
self.assignment[partition].resume()
def _add_assigned_partition(self, partition):
self.assignment[partition] = TopicPartitionState()
class TopicPartitionState(object):
def __init__(self):
self.committed = None # last committed OffsetAndMetadata
self.has_valid_position = False # whether we have valid position
self.paused = False # whether this partition has been paused by the user
self.awaiting_reset = False # whether we are awaiting reset
self.reset_strategy = None # the reset strategy if awaitingReset is set
self._position = None # offset exposed to the user
self.highwater = None
self.drop_pending_message_set = False
# The last message offset hint available from a message batch with
# magic=2 which includes deleted compacted messages
self.last_offset_from_message_batch = None
def _set_position(self, offset):
assert self.has_valid_position, 'Valid position required'
self._position = offset
def _get_position(self):
return self._position
position = property(_get_position, _set_position, None, "last position")
def await_reset(self, strategy):
self.awaiting_reset = True
self.reset_strategy = strategy
self._position = None
self.last_offset_from_message_batch = None
self.has_valid_position = False
def seek(self, offset):
self._position = offset
self.awaiting_reset = False
self.reset_strategy = None
self.has_valid_position = True
self.drop_pending_message_set = True
self.last_offset_from_message_batch = None
def pause(self):
self.paused = True
def resume(self):
self.paused = False
def is_fetchable(self):
return not self.paused and self.has_valid_position
class ConsumerRebalanceListener(object):
"""
A callback interface that the user can implement to trigger custom actions
when the set of partitions assigned to the consumer changes.
This is applicable when the consumer lets Kafka auto-manage group
membership. If the consumer directly assigns partitions, those
partitions will never be reassigned and this callback is not applicable.
When Kafka is managing the group membership, a partition re-assignment will
be triggered any time the members of the group change or the subscription
of the members changes. This can occur when processes die, new process
instances are added or old instances come back to life after failure.
Rebalances can also be triggered by changes affecting the subscribed
topics (e.g. when the number of partitions is administratively adjusted).
There are many uses for this functionality. One common use is saving offsets
in a custom store. By saving offsets in the on_partitions_revoked() call, we
can ensure that any time partition assignment changes the offset gets saved.
Another use is flushing out any kind of cache of intermediate results the
consumer may be keeping. For example, consider a case where the consumer is
subscribed to a topic containing user page views, and the goal is to count
the number of page views per user for each five-minute window. Let's say
the topic is partitioned by the user id so that all events for a particular
user will go to a single consumer instance. The consumer can keep in memory
a running tally of actions per user and only flush these out to a remote
data store when its cache gets too big. However, if a partition is reassigned
it may want to automatically trigger a flush of this cache before the new
owner takes over consumption.
This callback will execute in the user thread as part of the Consumer.poll()
whenever partition assignment changes.
It is guaranteed that all consumer processes will invoke
on_partitions_revoked() prior to any process invoking
on_partitions_assigned(). So if offsets or other state is saved in the
on_partitions_revoked() call, it should be saved by the time the process
taking over that partition has their on_partitions_assigned() callback
called to load the state.
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def on_partitions_revoked(self, revoked):
"""
A callback method the user can implement to provide handling of offset
commits to a customized store on the start of a rebalance operation.
This method will be called before a rebalance operation starts and
after the consumer stops fetching data. It is recommended that offsets
should be committed in this callback to either Kafka or a custom offset
store to prevent duplicate data.
NOTE: This method is only called before rebalances. It is not called
prior to KafkaConsumer.close()
Arguments:
revoked (list of TopicPartition): the partitions that were assigned
to the consumer on the last rebalance
"""
pass
@abc.abstractmethod
def on_partitions_assigned(self, assigned):
"""
A callback method the user can implement to provide handling of
customized offsets on completion of a successful partition
re-assignment. This method will be called after an offset re-assignment
completes and before the consumer starts fetching data.
It is guaranteed that all the processes in a consumer group will execute
their on_partitions_revoked() callback before any instance executes its
on_partitions_assigned() callback.
Arguments:
assigned (list of TopicPartition): the partitions assigned to the
consumer (may include partitions that were previously assigned)
"""
pass
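A rebalance listener is normally written by subclassing ConsumerRebalanceListener and passing an instance to KafkaConsumer.subscribe(listener=...). A minimal commit-on-revoke sketch (the consumer object and topic name are assumed, not part of this module):

from kafka import ConsumerRebalanceListener

class CommitOnRebalance(ConsumerRebalanceListener):
    def __init__(self, consumer):
        self._consumer = consumer

    def on_partitions_revoked(self, revoked):
        # Commit consumed offsets before ownership of these partitions is lost.
        self._consumer.commit()

    def on_partitions_assigned(self, assigned):
        print('Now assigned:', assigned)

# consumer.subscribe(['example-topic'], listener=CommitOnRebalance(consumer))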


@@ -0,0 +1,56 @@
from __future__ import absolute_import
import abc
import logging
log = logging.getLogger(__name__)
class AbstractPartitionAssignor(object):
"""
Abstract assignor implementation which does some common grunt work (in particular collecting
partition counts which are always needed in assignors).
"""
@abc.abstractproperty
def name(self):
""".name should be a string identifying the assignor"""
pass
@abc.abstractmethod
def assign(self, cluster, members):
"""Perform group assignment given cluster metadata and member subscriptions
Arguments:
cluster (ClusterMetadata): metadata for use in assignment
members (dict of {member_id: MemberMetadata}): decoded metadata for
each member in the group.
Returns:
dict: {member_id: MemberAssignment}
"""
pass
@abc.abstractmethod
def metadata(self, topics):
"""Generate ProtocolMetadata to be submitted via JoinGroupRequest.
Arguments:
topics (set): a member's subscribed topics
Returns:
MemberMetadata struct
"""
pass
@abc.abstractmethod
def on_assignment(self, assignment):
"""Callback that runs on each assignment.
This method can be used to update internal state, if any, of the
partition assignor.
Arguments:
assignment (MemberAssignment): the member's assignment
"""
pass
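A custom assignor subclasses AbstractPartitionAssignor and returns a ConsumerProtocolMemberAssignment per member. The toy policy below simply gives every partition of every subscribed topic to the lexicographically first member; it is only meant to show the required shape, not a sensible strategy:

from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.protocol import (ConsumerProtocolMemberAssignment,
                                        ConsumerProtocolMemberMetadata)

class FirstMemberAssignor(AbstractPartitionAssignor):
    name = 'first-member'
    version = 0

    def assign(self, cluster, members):
        member_ids = sorted(members)
        topics = set()
        for metadata in members.values():
            topics.update(metadata.subscription)
        # (topic, [partitions]) pairs for every subscribed topic with known metadata
        partitions = [(topic, sorted(cluster.partitions_for_topic(topic) or []))
                      for topic in sorted(topics)]
        assignment = {member_id: ConsumerProtocolMemberAssignment(self.version, [], b'')
                      for member_id in member_ids}
        assignment[member_ids[0]] = ConsumerProtocolMemberAssignment(
            self.version, partitions, b'')
        return assignment

    def metadata(self, topics):
        return ConsumerProtocolMemberMetadata(self.version, list(topics), b'')

    def on_assignment(self, assignment):
        pass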


@@ -0,0 +1,77 @@
from __future__ import absolute_import
import collections
import logging
from kafka.vendor import six
from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment
log = logging.getLogger(__name__)
class RangePartitionAssignor(AbstractPartitionAssignor):
"""
The range assignor works on a per-topic basis. For each topic, we lay out
the available partitions in numeric order and the consumers in
lexicographic order. We then divide the number of partitions by the total
number of consumers to determine the number of partitions to assign to each
consumer. If it does not evenly divide, then the first few consumers will
have one extra partition.
For example, suppose there are two consumers C0 and C1, two topics t0 and
t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1,
t0p2, t1p0, t1p1, and t1p2.
The assignment will be:
C0: [t0p0, t0p1, t1p0, t1p1]
C1: [t0p2, t1p2]
"""
name = 'range'
version = 0
@classmethod
def assign(cls, cluster, member_metadata):
consumers_per_topic = collections.defaultdict(list)
for member, metadata in six.iteritems(member_metadata):
for topic in metadata.subscription:
consumers_per_topic[topic].append(member)
# construct {member_id: {topic: [partition, ...]}}
assignment = collections.defaultdict(dict)
for topic, consumers_for_topic in six.iteritems(consumers_per_topic):
partitions = cluster.partitions_for_topic(topic)
if partitions is None:
log.warning('No partition metadata for topic %s', topic)
continue
partitions = sorted(partitions)
consumers_for_topic.sort()
partitions_per_consumer = len(partitions) // len(consumers_for_topic)
consumers_with_extra = len(partitions) % len(consumers_for_topic)
for i, member in enumerate(consumers_for_topic):
start = partitions_per_consumer * i
start += min(i, consumers_with_extra)
length = partitions_per_consumer
if i < consumers_with_extra:
length += 1
assignment[member][topic] = partitions[start:start+length]
protocol_assignment = {}
for member_id in member_metadata:
protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
cls.version,
sorted(assignment[member_id].items()),
b'')
return protocol_assignment
@classmethod
def metadata(cls, topics):
return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'')
@classmethod
def on_assignment(cls, assignment):
pass
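The docstring's example can be re-derived with a few lines of arithmetic that mirror the start/length computation above (consumer names are placeholders):

def range_slices(num_partitions, consumers):
    consumers = sorted(consumers)
    per_consumer, extra = divmod(num_partitions, len(consumers))
    slices = {}
    for i, member in enumerate(consumers):
        start = per_consumer * i + min(i, extra)
        length = per_consumer + (1 if i < extra else 0)
        slices[member] = list(range(start, start + length))
    return slices

# 3 partitions split over 2 consumers -> {'C0': [0, 1], 'C1': [2]} for each topic
print(range_slices(3, ['C1', 'C0']))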


@@ -0,0 +1,96 @@
from __future__ import absolute_import
import collections
import itertools
import logging
from kafka.vendor import six
from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment
from kafka.structs import TopicPartition
log = logging.getLogger(__name__)
class RoundRobinPartitionAssignor(AbstractPartitionAssignor):
"""
The roundrobin assignor lays out all the available partitions and all the
available consumers. It then proceeds to do a roundrobin assignment from
partition to consumer. If the subscriptions of all consumer instances are
identical, then the partitions will be uniformly distributed. (i.e., the
partition ownership counts will be within a delta of exactly one across all
consumers.)
For example, suppose there are two consumers C0 and C1, two topics t0 and
t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1,
t0p2, t1p0, t1p1, and t1p2.
The assignment will be:
C0: [t0p0, t0p2, t1p1]
C1: [t0p1, t1p0, t1p2]
When subscriptions differ across consumer instances, the assignment process
still considers each consumer instance in round robin fashion but skips
over an instance if it is not subscribed to the topic. Unlike the case when
subscriptions are identical, this can result in imbalanced assignments.
For example, suppose we have three consumers C0, C1, C2, and three topics
t0, t1, t2, with unbalanced partitions t0p0, t1p0, t1p1, t2p0, t2p1, t2p2,
where C0 is subscribed to t0; C1 is subscribed to t0, t1; and C2 is
subscribed to t0, t1, t2.
The assignment will be:
C0: [t0p0]
C1: [t1p0]
C2: [t1p1, t2p0, t2p1, t2p2]
"""
name = 'roundrobin'
version = 0
@classmethod
def assign(cls, cluster, member_metadata):
all_topics = set()
for metadata in six.itervalues(member_metadata):
all_topics.update(metadata.subscription)
all_topic_partitions = []
for topic in all_topics:
partitions = cluster.partitions_for_topic(topic)
if partitions is None:
log.warning('No partition metadata for topic %s', topic)
continue
for partition in partitions:
all_topic_partitions.append(TopicPartition(topic, partition))
all_topic_partitions.sort()
# construct {member_id: {topic: [partition, ...]}}
assignment = collections.defaultdict(lambda: collections.defaultdict(list))
member_iter = itertools.cycle(sorted(member_metadata.keys()))
for partition in all_topic_partitions:
member_id = next(member_iter)
# Because we constructed all_topic_partitions from the set of
# member subscribed topics, we should be safe assuming that
# each topic in all_topic_partitions is in at least one member
# subscription; otherwise this could yield an infinite loop
while partition.topic not in member_metadata[member_id].subscription:
member_id = next(member_iter)
assignment[member_id][partition.topic].append(partition.partition)
protocol_assignment = {}
for member_id in member_metadata:
protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
cls.version,
sorted(assignment[member_id].items()),
b'')
return protocol_assignment
@classmethod
def metadata(cls, topics):
return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'')
@classmethod
def on_assignment(cls, assignment):
pass
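As a quick illustration of the assignment logic above, here is a minimal sketch (not part of the committed file) that drives RoundRobinPartitionAssignor.assign directly. StubCluster is an assumption standing in for kafka.cluster.ClusterMetadata and exposes only the partitions_for_topic method the assignor calls; it assumes the vendored kafka package is importable.

from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata

class StubCluster(object):
    """Hypothetical stand-in for ClusterMetadata, used only for this sketch."""
    def __init__(self, topic_partitions):
        self._tp = topic_partitions
    def partitions_for_topic(self, topic):
        return self._tp.get(topic)

cluster = StubCluster({'t0': {0, 1, 2}, 't1': {0, 1, 2}})
members = {
    'C0': ConsumerProtocolMemberMetadata(0, ['t0', 't1'], b''),
    'C1': ConsumerProtocolMemberMetadata(0, ['t0', 't1'], b''),
}
assignment = RoundRobinPartitionAssignor.assign(cluster, members)
for member_id in sorted(assignment):
    # expected, per the docstring: C0 -> [('t0', [0, 2]), ('t1', [1])]; C1 -> [('t0', [1]), ('t1', [0, 2])]
    print(member_id, assignment[member_id].assignment)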

View File

@@ -0,0 +1,149 @@
import logging
from collections import defaultdict, namedtuple
from copy import deepcopy
from kafka.vendor import six
log = logging.getLogger(__name__)
ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"])
"""
Represents a pair of Kafka consumer ids involved in a partition reassignment.
Each ConsumerPair corresponds to a particular partition or topic and indicates that the partition, or some
partition of that topic, was moved from the source consumer to the destination consumer
during the rebalance. This class helps in determining whether a partition reassignment results in cycles in
the generated graph of consumer pairs.
"""
def is_sublist(source, target):
"""Checks if one list is a sublist of another.
Arguments:
source: the list in which to search for the occurrence of target.
target: the list to search for as a sublist of source
Returns:
true if target is in source; false otherwise
"""
for index in (i for i, e in enumerate(source) if e == target[0]):
if tuple(source[index: index + len(target)]) == target:
return True
return False
class PartitionMovements:
"""
This class maintains some data structures to simplify lookup of partition movements among consumers.
At each point of time during a partition rebalance it keeps track of partition movements
corresponding to each topic, and also possible movement (in form a ConsumerPair object) for each partition.
"""
def __init__(self):
self.partition_movements_by_topic = defaultdict(
lambda: defaultdict(set)
)
self.partition_movements = {}
def move_partition(self, partition, old_consumer, new_consumer):
pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer)
if partition in self.partition_movements:
# this partition has previously moved
existing_pair = self._remove_movement_record_of_partition(partition)
assert existing_pair.dst_member_id == old_consumer
if existing_pair.src_member_id != new_consumer:
# the partition is not moving back to its previous consumer
self._add_partition_movement_record(
partition, ConsumerPair(src_member_id=existing_pair.src_member_id, dst_member_id=new_consumer)
)
else:
self._add_partition_movement_record(partition, pair)
def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
if partition.topic not in self.partition_movements_by_topic:
return partition
if partition in self.partition_movements:
# this partition has previously moved
assert old_consumer == self.partition_movements[partition].dst_member_id
old_consumer = self.partition_movements[partition].src_member_id
reverse_pair = ConsumerPair(src_member_id=new_consumer, dst_member_id=old_consumer)
if reverse_pair not in self.partition_movements_by_topic[partition.topic]:
return partition
return next(iter(self.partition_movements_by_topic[partition.topic][reverse_pair]))
def are_sticky(self):
for topic, movements in six.iteritems(self.partition_movements_by_topic):
movement_pairs = set(movements.keys())
if self._has_cycles(movement_pairs):
log.error(
"Stickiness is violated for topic {}\n"
"Partition movements for this topic occurred among the following consumer pairs:\n"
"{}".format(topic, movement_pairs)
)
return False
return True
def _remove_movement_record_of_partition(self, partition):
pair = self.partition_movements[partition]
del self.partition_movements[partition]
self.partition_movements_by_topic[partition.topic][pair].remove(partition)
if not self.partition_movements_by_topic[partition.topic][pair]:
del self.partition_movements_by_topic[partition.topic][pair]
if not self.partition_movements_by_topic[partition.topic]:
del self.partition_movements_by_topic[partition.topic]
return pair
def _add_partition_movement_record(self, partition, pair):
self.partition_movements[partition] = pair
self.partition_movements_by_topic[partition.topic][pair].add(partition)
def _has_cycles(self, consumer_pairs):
cycles = set()
for pair in consumer_pairs:
reduced_pairs = deepcopy(consumer_pairs)
reduced_pairs.remove(pair)
path = [pair.src_member_id]
if self._is_linked(pair.dst_member_id, pair.src_member_id, reduced_pairs, path) and not self._is_subcycle(
path, cycles
):
cycles.add(tuple(path))
log.error("A cycle of length {} was found: {}".format(len(path) - 1, path))
# for now we want to make sure there are no partition movements of the same topic between a pair of consumers.
# the odds of finding a cycle among more than two consumers seem low enough (according to various randomized
# tests with the given sticky algorithm) that it is not worth the added complexity of handling those cases.
for cycle in cycles:
if len(cycle) == 3: # indicates a cycle of length 2
return True
return False
@staticmethod
def _is_subcycle(cycle, cycles):
super_cycle = deepcopy(cycle)
super_cycle = super_cycle[:-1]
super_cycle.extend(cycle)
for found_cycle in cycles:
if len(found_cycle) == len(cycle) and is_sublist(super_cycle, found_cycle):
return True
return False
def _is_linked(self, src, dst, pairs, current_path):
if src == dst:
return False
if not pairs:
return False
if ConsumerPair(src, dst) in pairs:
current_path.append(src)
current_path.append(dst)
return True
for pair in pairs:
if pair.src_member_id == src:
reduced_set = deepcopy(pairs)
reduced_set.remove(pair)
current_path.append(pair.src_member_id)
return self._is_linked(pair.dst_member_id, dst, reduced_set, current_path)
return False
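A small sketch (not part of the committed file) of the movement bookkeeping above: moving a partition away and then back collapses the record, so are_sticky() remains true. The consumer ids are illustrative.

from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements
from kafka.structs import TopicPartition

movements = PartitionMovements()
tp = TopicPartition('t0', 0)
movements.move_partition(tp, 'C0', 'C1')  # t0p0 moves from C0 to C1
movements.move_partition(tp, 'C1', 'C0')  # ...and back: the movement record collapses to nothing
print(movements.are_sticky())             # True -- no reciprocal movement pairs remain for topic t0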

View File

@@ -0,0 +1,63 @@
class SortedSet:
def __init__(self, iterable=None, key=None):
self._key = key if key is not None else lambda x: x
self._set = set(iterable) if iterable is not None else set()
self._cached_last = None
self._cached_first = None
def first(self):
if self._cached_first is not None:
return self._cached_first
first = None
for element in self._set:
if first is None or self._key(first) > self._key(element):
first = element
self._cached_first = first
return first
def last(self):
if self._cached_last is not None:
return self._cached_last
last = None
for element in self._set:
if last is None or self._key(last) < self._key(element):
last = element
self._cached_last = last
return last
def pop_last(self):
value = self.last()
self._set.remove(value)
self._cached_last = None
return value
def add(self, value):
if self._cached_last is not None and self._key(value) > self._key(self._cached_last):
self._cached_last = value
if self._cached_first is not None and self._key(value) < self._key(self._cached_first):
self._cached_first = value
return self._set.add(value)
def remove(self, value):
if self._cached_last is not None and self._cached_last == value:
self._cached_last = None
if self._cached_first is not None and self._cached_first == value:
self._cached_first = None
return self._set.remove(value)
def __contains__(self, value):
return value in self._set
def __iter__(self):
return iter(sorted(self._set, key=self._key))
def _bool(self):
return len(self._set) != 0
__nonzero__ = _bool
__bool__ = _bool
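A brief usage sketch of the SortedSet helper above (illustrative, assuming the vendored package is importable):

from kafka.coordinator.assignors.sticky.sorted_set import SortedSet

s = SortedSet(iterable=[3, 1, 2], key=lambda x: x)
print(s.first(), s.last())  # 1 3
s.add(5)
print(s.pop_last())         # 5 -- the cached last element is refreshed by add()
print(list(s))              # [1, 2, 3] -- iteration is sorted by the key function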

View File

@@ -0,0 +1,681 @@
import logging
from collections import defaultdict, namedtuple
from copy import deepcopy
from kafka.cluster import ClusterMetadata
from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements
from kafka.coordinator.assignors.sticky.sorted_set import SortedSet
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment
from kafka.coordinator.protocol import Schema
from kafka.protocol.struct import Struct
from kafka.protocol.types import String, Array, Int32
from kafka.structs import TopicPartition
from kafka.vendor import six
log = logging.getLogger(__name__)
ConsumerGenerationPair = namedtuple("ConsumerGenerationPair", ["consumer", "generation"])
def has_identical_list_elements(list_):
"""Checks if all lists in the collection have the same members
Arguments:
list_: collection of lists
Returns:
true if all lists in the collection have the same members; false otherwise
"""
if not list_:
return True
for i in range(1, len(list_)):
if list_[i] != list_[i - 1]:
return False
return True
def subscriptions_comparator_key(element):
return len(element[1]), element[0]
def partitions_comparator_key(element):
return len(element[1]), element[0].topic, element[0].partition
def remove_if_present(collection, element):
try:
collection.remove(element)
except (ValueError, KeyError):
pass
StickyAssignorMemberMetadataV1 = namedtuple("StickyAssignorMemberMetadataV1",
["subscription", "partitions", "generation"])
class StickyAssignorUserDataV1(Struct):
"""
Used for preserving consumer's previously assigned partitions
list and sending it as user data to the leader during a rebalance
"""
SCHEMA = Schema(
("previous_assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("generation", Int32)
)
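# Illustrative round trip (not in the original source): the leader-side metadata() method below builds
# this struct from the member's last assignment, e.g.
#   data = StickyAssignorUserDataV1([('t0', [0, 1])], 5)
#   StickyAssignorUserDataV1.decode(data.encode()).generation  # -> 5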
class StickyAssignmentExecutor:
def __init__(self, cluster, members):
self.members = members
# a mapping between consumers and their assigned partitions that is updated during assignment procedure
self.current_assignment = defaultdict(list)
# an assignment from a previous generation
self.previous_assignment = {}
# a mapping between partitions and their assigned consumers
self.current_partition_consumer = {}
# a flag indicating that there were no previous assignments performed ever
self.is_fresh_assignment = False
# a mapping of all topic partitions to all consumers that can be assigned to them
self.partition_to_all_potential_consumers = {}
# a mapping of all consumers to all potential topic partitions that can be assigned to them
self.consumer_to_all_potential_partitions = {}
# an ascending sorted set of consumers based on how many topic partitions are already assigned to them
self.sorted_current_subscriptions = SortedSet()
# an ascending sorted list of topic partitions based on how many consumers can potentially use them
self.sorted_partitions = []
# all partitions that need to be assigned
self.unassigned_partitions = []
# a flag indicating that a certain partition cannot remain assigned to its current consumer because the consumer
# is no longer subscribed to its topic
self.revocation_required = False
self.partition_movements = PartitionMovements()
self._initialize(cluster)
def perform_initial_assignment(self):
self._populate_sorted_partitions()
self._populate_partitions_to_reassign()
def balance(self):
self._initialize_current_subscriptions()
initializing = len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) == 0
# assign all unassigned partitions
for partition in self.unassigned_partitions:
# skip if there is no potential consumer for the partition
if not self.partition_to_all_potential_consumers[partition]:
continue
self._assign_partition(partition)
# narrow down the reassignment scope to only those partitions that can actually be reassigned
fixed_partitions = set()
for partition in six.iterkeys(self.partition_to_all_potential_consumers):
if not self._can_partition_participate_in_reassignment(partition):
fixed_partitions.add(partition)
for fixed_partition in fixed_partitions:
remove_if_present(self.sorted_partitions, fixed_partition)
remove_if_present(self.unassigned_partitions, fixed_partition)
# narrow down the reassignment scope to only those consumers that are subject to reassignment
fixed_assignments = {}
for consumer in six.iterkeys(self.consumer_to_all_potential_partitions):
if not self._can_consumer_participate_in_reassignment(consumer):
self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer)
fixed_assignments[consumer] = self.current_assignment[consumer]
del self.current_assignment[consumer]
# create a deep copy of the current assignment so we can revert to it
# if we do not get a more balanced assignment later
prebalance_assignment = deepcopy(self.current_assignment)
prebalance_partition_consumers = deepcopy(self.current_partition_consumer)
# if we don't already need to revoke something due to subscription changes,
# first try to balance by only moving newly added partitions
if not self.revocation_required:
self._perform_reassignments(self.unassigned_partitions)
reassignment_performed = self._perform_reassignments(self.sorted_partitions)
# if we are not preserving existing assignments and we have made changes to the current assignment
# make sure we are getting a more balanced assignment; otherwise, revert to previous assignment
if (
not initializing
and reassignment_performed
and self._get_balance_score(self.current_assignment) >= self._get_balance_score(prebalance_assignment)
):
self.current_assignment = prebalance_assignment
self.current_partition_consumer.clear()
self.current_partition_consumer.update(prebalance_partition_consumers)
# add the fixed assignments (those that could not change) back
for consumer, partitions in six.iteritems(fixed_assignments):
self.current_assignment[consumer] = partitions
self._add_consumer_to_current_subscriptions_and_maintain_order(consumer)
def get_final_assignment(self, member_id):
assignment = defaultdict(list)
for topic_partition in self.current_assignment[member_id]:
assignment[topic_partition.topic].append(topic_partition.partition)
assignment = {k: sorted(v) for k, v in six.iteritems(assignment)}
return six.viewitems(assignment)
def _initialize(self, cluster):
self._init_current_assignments(self.members)
for topic in cluster.topics():
partitions = cluster.partitions_for_topic(topic)
if partitions is None:
log.warning("No partition metadata for topic %s", topic)
continue
for p in partitions:
partition = TopicPartition(topic=topic, partition=p)
self.partition_to_all_potential_consumers[partition] = []
for consumer_id, member_metadata in six.iteritems(self.members):
self.consumer_to_all_potential_partitions[consumer_id] = []
for topic in member_metadata.subscription:
if cluster.partitions_for_topic(topic) is None:
log.warning("No partition metadata for topic {}".format(topic))
continue
for p in cluster.partitions_for_topic(topic):
partition = TopicPartition(topic=topic, partition=p)
self.consumer_to_all_potential_partitions[consumer_id].append(partition)
self.partition_to_all_potential_consumers[partition].append(consumer_id)
if consumer_id not in self.current_assignment:
self.current_assignment[consumer_id] = []
def _init_current_assignments(self, members):
# we need to process subscriptions' user data with each consumer's reported generation in mind
# higher generations overwrite lower generations in case of a conflict
# note that a conflict could exist only if user data is for different generations
# for each partition we create a map of its consumers by generation
sorted_partition_consumers_by_generation = {}
for consumer, member_metadata in six.iteritems(members):
for partitions in member_metadata.partitions:
if partitions in sorted_partition_consumers_by_generation:
consumers = sorted_partition_consumers_by_generation[partitions]
if member_metadata.generation and member_metadata.generation in consumers:
# same partition is assigned to two consumers during the same rebalance.
# log a warning and skip this record
log.warning(
"Partition {} is assigned to multiple consumers "
"following sticky assignment generation {}.".format(partitions, member_metadata.generation)
)
else:
consumers[member_metadata.generation] = consumer
else:
sorted_consumers = {member_metadata.generation: consumer}
sorted_partition_consumers_by_generation[partitions] = sorted_consumers
# previous_assignment holds the prior ConsumerGenerationPair (before current) of each partition
# current and previous consumers are the last two consumers of each partition in the above sorted map
for partitions, consumers in six.iteritems(sorted_partition_consumers_by_generation):
generations = sorted(consumers.keys(), reverse=True)
self.current_assignment[consumers[generations[0]]].append(partitions)
# now update previous assignment if any
if len(generations) > 1:
self.previous_assignment[partitions] = ConsumerGenerationPair(
consumer=consumers[generations[1]], generation=generations[1]
)
self.is_fresh_assignment = len(self.current_assignment) == 0
for consumer_id, partitions in six.iteritems(self.current_assignment):
for partition in partitions:
self.current_partition_consumer[partition] = consumer_id
def _are_subscriptions_identical(self):
"""
Returns:
true, if all partitions have the same set of potential consumers and all consumers have the same
set of potential partitions; false otherwise
"""
if not has_identical_list_elements(list(six.itervalues(self.partition_to_all_potential_consumers))):
return False
return has_identical_list_elements(list(six.itervalues(self.consumer_to_all_potential_partitions)))
def _populate_sorted_partitions(self):
# set of topic partitions with their respective potential consumers
all_partitions = set((tp, tuple(consumers))
for tp, consumers in six.iteritems(self.partition_to_all_potential_consumers))
partitions_sorted_by_num_of_potential_consumers = sorted(all_partitions, key=partitions_comparator_key)
self.sorted_partitions = []
if not self.is_fresh_assignment and self._are_subscriptions_identical():
# if this is a reassignment and the subscriptions are identical (all consumers can consume from all topics),
# then we simply list partitions in a round robin fashion (from consumers with
# most assigned partitions to those with least)
assignments = deepcopy(self.current_assignment)
for consumer_id, partitions in six.iteritems(assignments):
to_remove = []
for partition in partitions:
if partition not in self.partition_to_all_potential_consumers:
to_remove.append(partition)
for partition in to_remove:
partitions.remove(partition)
sorted_consumers = SortedSet(
iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(assignments)],
key=subscriptions_comparator_key,
)
# at this point, sorted_consumers contains an ascending-sorted list of consumers based on
# how many valid partitions are currently assigned to them
while sorted_consumers:
# take the consumer with the most partitions
consumer, _ = sorted_consumers.pop_last()
# currently assigned partitions to this consumer
remaining_partitions = assignments[consumer]
# from partitions that had a different consumer before,
# keep only those that are assigned to this consumer now
previous_partitions = set(six.iterkeys(self.previous_assignment)).intersection(set(remaining_partitions))
if previous_partitions:
# if there is a partition of this consumer that was assigned to another consumer before
# mark it as a good option for reassignment
partition = previous_partitions.pop()
remaining_partitions.remove(partition)
self.sorted_partitions.append(partition)
sorted_consumers.add((consumer, tuple(assignments[consumer])))
elif remaining_partitions:
# otherwise, mark any other one of the current partitions as a reassignment candidate
self.sorted_partitions.append(remaining_partitions.pop())
sorted_consumers.add((consumer, tuple(assignments[consumer])))
while partitions_sorted_by_num_of_potential_consumers:
partition = partitions_sorted_by_num_of_potential_consumers.pop(0)[0]
if partition not in self.sorted_partitions:
self.sorted_partitions.append(partition)
else:
while partitions_sorted_by_num_of_potential_consumers:
self.sorted_partitions.append(partitions_sorted_by_num_of_potential_consumers.pop(0)[0])
def _populate_partitions_to_reassign(self):
self.unassigned_partitions = deepcopy(self.sorted_partitions)
assignments_to_remove = []
for consumer_id, partitions in six.iteritems(self.current_assignment):
if consumer_id not in self.members:
# if a consumer that existed before (and had some partition assignments) is now removed,
# remove it from current_assignment
for partition in partitions:
del self.current_partition_consumer[partition]
assignments_to_remove.append(consumer_id)
else:
# otherwise (the consumer still exists)
partitions_to_remove = []
for partition in partitions:
if partition not in self.partition_to_all_potential_consumers:
# if this topic partition of this consumer no longer exists
# remove it from current_assignment of the consumer
partitions_to_remove.append(partition)
elif partition.topic not in self.members[consumer_id].subscription:
# if this partition cannot remain assigned to its current consumer because the consumer
# is no longer subscribed to its topic remove it from current_assignment of the consumer
partitions_to_remove.append(partition)
self.revocation_required = True
else:
# otherwise, remove the topic partition from those that need to be assigned only if
# its current consumer is still subscribed to its topic (because it is already assigned
# and we would want to preserve that assignment as much as possible)
self.unassigned_partitions.remove(partition)
for partition in partitions_to_remove:
self.current_assignment[consumer_id].remove(partition)
del self.current_partition_consumer[partition]
for consumer_id in assignments_to_remove:
del self.current_assignment[consumer_id]
def _initialize_current_subscriptions(self):
self.sorted_current_subscriptions = SortedSet(
iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(self.current_assignment)],
key=subscriptions_comparator_key,
)
def _get_consumer_with_least_subscriptions(self):
return self.sorted_current_subscriptions.first()[0]
def _get_consumer_with_most_subscriptions(self):
return self.sorted_current_subscriptions.last()[0]
def _remove_consumer_from_current_subscriptions_and_maintain_order(self, consumer):
self.sorted_current_subscriptions.remove((consumer, tuple(self.current_assignment[consumer])))
def _add_consumer_to_current_subscriptions_and_maintain_order(self, consumer):
self.sorted_current_subscriptions.add((consumer, tuple(self.current_assignment[consumer])))
def _is_balanced(self):
"""Determines if the current assignment is a balanced one"""
if (
len(self.current_assignment[self._get_consumer_with_least_subscriptions()])
>= len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) - 1
):
# if minimum and maximum numbers of partitions assigned to consumers differ by at most one return true
return True
# create a mapping from partitions to the consumer assigned to them
all_assigned_partitions = {}
for consumer_id, consumer_partitions in six.iteritems(self.current_assignment):
for partition in consumer_partitions:
if partition in all_assigned_partitions:
log.error("{} is assigned to more than one consumer.".format(partition))
all_assigned_partitions[partition] = consumer_id
# for each consumer that does not have all the topic partitions it can get
# make sure that none of the topic partitions it could have, but did not get, can be moved to it
# (because that would break the balance)
for consumer, _ in self.sorted_current_subscriptions:
consumer_partition_count = len(self.current_assignment[consumer])
# skip if this consumer already has all the topic partitions it can get
if consumer_partition_count == len(self.consumer_to_all_potential_partitions[consumer]):
continue
# otherwise make sure it cannot get any more
for partition in self.consumer_to_all_potential_partitions[consumer]:
if partition not in self.current_assignment[consumer]:
other_consumer = all_assigned_partitions[partition]
other_consumer_partition_count = len(self.current_assignment[other_consumer])
if consumer_partition_count < other_consumer_partition_count:
return False
return True
def _assign_partition(self, partition):
for consumer, _ in self.sorted_current_subscriptions:
if partition in self.consumer_to_all_potential_partitions[consumer]:
self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer)
self.current_assignment[consumer].append(partition)
self.current_partition_consumer[partition] = consumer
self._add_consumer_to_current_subscriptions_and_maintain_order(consumer)
break
def _can_partition_participate_in_reassignment(self, partition):
return len(self.partition_to_all_potential_consumers[partition]) >= 2
def _can_consumer_participate_in_reassignment(self, consumer):
current_partitions = self.current_assignment[consumer]
current_assignment_size = len(current_partitions)
max_assignment_size = len(self.consumer_to_all_potential_partitions[consumer])
if current_assignment_size > max_assignment_size:
log.error("The consumer {} is assigned more partitions than the maximum possible.".format(consumer))
if current_assignment_size < max_assignment_size:
# if a consumer is not assigned all its potential partitions it is subject to reassignment
return True
for partition in current_partitions:
# if any of the partitions assigned to a consumer is subject to reassignment the consumer itself
# is subject to reassignment
if self._can_partition_participate_in_reassignment(partition):
return True
return False
def _perform_reassignments(self, reassignable_partitions):
reassignment_performed = False
# repeat reassignment until no partition can be moved to improve the balance
while True:
modified = False
# reassign all reassignable partitions until the full list is processed or a balance is achieved
# (starting from the partition with least potential consumers and if needed)
for partition in reassignable_partitions:
if self._is_balanced():
break
# the partition must have at least two potential consumers
if len(self.partition_to_all_potential_consumers[partition]) <= 1:
log.error("Expected more than one potential consumer for partition {}".format(partition))
# the partition must have a current consumer
consumer = self.current_partition_consumer.get(partition)
if consumer is None:
log.error("Expected partition {} to be assigned to a consumer".format(partition))
if (
partition in self.previous_assignment
and len(self.current_assignment[consumer])
> len(self.current_assignment[self.previous_assignment[partition].consumer]) + 1
):
self._reassign_partition_to_consumer(
partition, self.previous_assignment[partition].consumer,
)
reassignment_performed = True
modified = True
continue
# check if a better-suited consumer exist for the partition; if so, reassign it
for other_consumer in self.partition_to_all_potential_consumers[partition]:
if len(self.current_assignment[consumer]) > len(self.current_assignment[other_consumer]) + 1:
self._reassign_partition(partition)
reassignment_performed = True
modified = True
break
if not modified:
break
return reassignment_performed
def _reassign_partition(self, partition):
new_consumer = None
for another_consumer, _ in self.sorted_current_subscriptions:
if partition in self.consumer_to_all_potential_partitions[another_consumer]:
new_consumer = another_consumer
break
assert new_consumer is not None
self._reassign_partition_to_consumer(partition, new_consumer)
def _reassign_partition_to_consumer(self, partition, new_consumer):
consumer = self.current_partition_consumer[partition]
# find the correct partition movement considering the stickiness requirement
partition_to_be_moved = self.partition_movements.get_partition_to_be_moved(partition, consumer, new_consumer)
self._move_partition(partition_to_be_moved, new_consumer)
def _move_partition(self, partition, new_consumer):
old_consumer = self.current_partition_consumer[partition]
self._remove_consumer_from_current_subscriptions_and_maintain_order(old_consumer)
self._remove_consumer_from_current_subscriptions_and_maintain_order(new_consumer)
self.partition_movements.move_partition(partition, old_consumer, new_consumer)
self.current_assignment[old_consumer].remove(partition)
self.current_assignment[new_consumer].append(partition)
self.current_partition_consumer[partition] = new_consumer
self._add_consumer_to_current_subscriptions_and_maintain_order(new_consumer)
self._add_consumer_to_current_subscriptions_and_maintain_order(old_consumer)
@staticmethod
def _get_balance_score(assignment):
"""Calculates a balance score of a give assignment
as the sum of assigned partitions size difference of all consumer pairs.
A perfectly balanced assignment (with all consumers getting the same number of partitions)
has a balance score of 0. Lower balance score indicates a more balanced assignment.
Arguments:
assignment (dict): {consumer: list of assigned topic partitions}
Returns:
the balance score of the assignment
"""
score = 0
consumer_to_assignment = {}
for consumer_id, partitions in six.iteritems(assignment):
consumer_to_assignment[consumer_id] = len(partitions)
consumers_to_explore = set(consumer_to_assignment.keys())
for consumer_id in consumer_to_assignment.keys():
if consumer_id in consumers_to_explore:
consumers_to_explore.remove(consumer_id)
for other_consumer_id in consumers_to_explore:
score += abs(consumer_to_assignment[consumer_id] - consumer_to_assignment[other_consumer_id])
return score
class StickyPartitionAssignor(AbstractPartitionAssignor):
"""
https://cwiki.apache.org/confluence/display/KAFKA/KIP-54+-+Sticky+Partition+Assignment+Strategy
The sticky assignor serves two purposes. First, it guarantees an assignment that is as balanced as possible, meaning either:
- the numbers of topic partitions assigned to consumers differ by at most one; or
- each consumer that has 2+ fewer topic partitions than some other consumer cannot get any of those topic partitions transferred to it.
Second, it preserves as many existing assignments as possible when a reassignment occurs.
This helps in saving some of the overhead processing when topic partitions move from one consumer to another.
Starting fresh it would work by distributing the partitions over consumers as evenly as possible.
Even though this may sound similar to how round robin assignor works, the second example below shows that it is not.
During a reassignment it would perform the reassignment in such a way that in the new assignment
- topic partitions are still distributed as evenly as possible, and
- topic partitions stay with their previously assigned consumers as much as possible.
The first goal above takes precedence over the second one.
Example 1.
Suppose there are three consumers C0, C1, C2,
four topics t0, t1, t2, t3, and each topic has 2 partitions,
resulting in partitions t0p0, t0p1, t1p0, t1p1, t2p0, t2p1, t3p0, t3p1.
Each consumer is subscribed to all four topics.
The assignment with both sticky and round robin assignors will be:
- C0: [t0p0, t1p1, t3p0]
- C1: [t0p1, t2p0, t3p1]
- C2: [t1p0, t2p1]
Now, let's assume C1 is removed and a reassignment is about to happen. The round robin assignor would produce:
- C0: [t0p0, t1p0, t2p0, t3p0]
- C2: [t0p1, t1p1, t2p1, t3p1]
while the sticky assignor would result in:
- C0 [t0p0, t1p1, t3p0, t2p0]
- C2 [t1p0, t2p1, t0p1, t3p1]
preserving all the previous assignments (unlike the round robin assignor).
Example 2.
There are three consumers C0, C1, C2,
and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively.
Therefore, the partitions are t0p0, t1p0, t1p1, t2p0, t2p1, t2p2.
C0 is subscribed to t0;
C1 is subscribed to t0, t1;
and C2 is subscribed to t0, t1, t2.
The round robin assignor would come up with the following assignment:
- C0 [t0p0]
- C1 [t1p0]
- C2 [t1p1, t2p0, t2p1, t2p2]
which is not as balanced as the assignment suggested by sticky assignor:
- C0 [t0p0]
- C1 [t1p0, t1p1]
- C2 [t2p0, t2p1, t2p2]
Now, if consumer C0 is removed, these two assignors would produce the following assignments.
Round Robin (preserves 3 partition assignments):
- C1 [t0p0, t1p1]
- C2 [t1p0, t2p0, t2p1, t2p2]
Sticky (preserves 5 partition assignments):
- C1 [t1p0, t1p1, t0p0]
- C2 [t2p0, t2p1, t2p2]
"""
DEFAULT_GENERATION_ID = -1
name = "sticky"
version = 0
member_assignment = None
generation = DEFAULT_GENERATION_ID
_latest_partition_movements = None
@classmethod
def assign(cls, cluster, members):
"""Performs group assignment given cluster metadata and member subscriptions
Arguments:
cluster (ClusterMetadata): cluster metadata
members (dict of {member_id: MemberMetadata}): decoded metadata for each member in the group.
Returns:
dict: {member_id: MemberAssignment}
"""
members_metadata = {}
for consumer, member_metadata in six.iteritems(members):
members_metadata[consumer] = cls.parse_member_metadata(member_metadata)
executor = StickyAssignmentExecutor(cluster, members_metadata)
executor.perform_initial_assignment()
executor.balance()
cls._latest_partition_movements = executor.partition_movements
assignment = {}
for member_id in members:
assignment[member_id] = ConsumerProtocolMemberAssignment(
cls.version, sorted(executor.get_final_assignment(member_id)), b''
)
return assignment
@classmethod
def parse_member_metadata(cls, metadata):
"""
Parses member metadata into a python object.
This implementation only serializes and deserializes the StickyAssignorMemberMetadataV1 user data,
since no StickyAssignor written in Python was ever deployed in the wild with version V0, meaning that
there is no need to support backward compatibility with V0.
Arguments:
metadata (MemberMetadata): decoded metadata for a member of the group.
Returns:
parsed metadata (StickyAssignorMemberMetadataV1)
"""
user_data = metadata.user_data
if not user_data:
return StickyAssignorMemberMetadataV1(
partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription
)
try:
decoded_user_data = StickyAssignorUserDataV1.decode(user_data)
except Exception as e:
# ignore the consumer's previous assignment if it cannot be parsed
log.error("Could not parse member data", e) # pylint: disable=logging-too-many-args
return StickyAssignorMemberMetadataV1(
partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription
)
member_partitions = []
for topic, partitions in decoded_user_data.previous_assignment: # pylint: disable=no-member
member_partitions.extend([TopicPartition(topic, partition) for partition in partitions])
return StickyAssignorMemberMetadataV1(
# pylint: disable=no-member
partitions=member_partitions, generation=decoded_user_data.generation, subscription=metadata.subscription
)
@classmethod
def metadata(cls, topics):
if cls.member_assignment is None:
log.debug("No member assignment available")
user_data = b''
else:
log.debug("Member assignment is available, generating the metadata: generation {}".format(cls.generation))
partitions_by_topic = defaultdict(list)
for topic_partition in cls.member_assignment: # pylint: disable=not-an-iterable
partitions_by_topic[topic_partition.topic].append(topic_partition.partition)
data = StickyAssignorUserDataV1(six.iteritems(partitions_by_topic), cls.generation)
user_data = data.encode()
return ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data)
@classmethod
def on_assignment(cls, assignment):
"""Callback that runs on each assignment. Updates assignor's state.
Arguments:
assignment: MemberAssignment
"""
log.debug("On assignment: assignment={}".format(assignment))
cls.member_assignment = assignment.partitions()
@classmethod
def on_generation_assignment(cls, generation):
"""Callback that runs on each assignment. Updates assignor's generation id.
Arguments:
generation: generation id
"""
log.debug("On generation assignment: generation={}".format(generation))
cls.generation = generation
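To tie Example 2 in the docstring above to the code, here is a minimal sketch (not part of the committed file) that runs StickyPartitionAssignor.assign against a stub cluster. StubCluster is an assumption that mimics only the two ClusterMetadata methods the executor reads; the consumer ids and topics mirror Example 2.

from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata

class StubCluster(object):
    """Hypothetical stand-in for ClusterMetadata, used only for this sketch."""
    def __init__(self, topic_partitions):
        self._tp = topic_partitions
    def topics(self):
        return set(self._tp)
    def partitions_for_topic(self, topic):
        return self._tp.get(topic)

cluster = StubCluster({'t0': {0}, 't1': {0, 1}, 't2': {0, 1, 2}})
members = {
    'C0': ConsumerProtocolMemberMetadata(0, ['t0'], b''),
    'C1': ConsumerProtocolMemberMetadata(0, ['t0', 't1'], b''),
    'C2': ConsumerProtocolMemberMetadata(0, ['t0', 't1', 't2'], b''),
}
assignment = StickyPartitionAssignor.assign(cluster, members)
for member_id in sorted(assignment):
    # expected, per Example 2: C0 -> t0p0; C1 -> t1p0, t1p1; C2 -> t2p0, t2p1, t2p2
    print(member_id, assignment[member_id].assignment)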

File diff suppressed because it is too large

View File

@@ -0,0 +1,833 @@
from __future__ import absolute_import, division
import collections
import copy
import functools
import logging
import time
from kafka.vendor import six
from kafka.coordinator.base import BaseCoordinator, Generation
from kafka.coordinator.assignors.range import RangePartitionAssignor
from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor
from kafka.coordinator.protocol import ConsumerProtocol
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics import AnonMeasurable
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest
from kafka.structs import OffsetAndMetadata, TopicPartition
from kafka.util import WeakMethod
log = logging.getLogger(__name__)
class ConsumerCoordinator(BaseCoordinator):
"""This class manages the coordination process with the consumer coordinator."""
DEFAULT_CONFIG = {
'group_id': 'kafka-python-default-group',
'enable_auto_commit': True,
'auto_commit_interval_ms': 5000,
'default_offset_commit_callback': None,
'assignors': (RangePartitionAssignor, RoundRobinPartitionAssignor, StickyPartitionAssignor),
'session_timeout_ms': 10000,
'heartbeat_interval_ms': 3000,
'max_poll_interval_ms': 300000,
'retry_backoff_ms': 100,
'api_version': (0, 10, 1),
'exclude_internal_topics': True,
'metric_group_prefix': 'consumer'
}
def __init__(self, client, subscription, metrics, **configs):
"""Initialize the coordination manager.
Keyword Arguments:
group_id (str): name of the consumer group to join for dynamic
partition assignment (if enabled), and to use for fetching and
committing offsets. Default: 'kafka-python-default-group'
enable_auto_commit (bool): If true the consumer's offset will be
periodically committed in the background. Default: True.
auto_commit_interval_ms (int): milliseconds between automatic
offset commits, if enable_auto_commit is True. Default: 5000.
default_offset_commit_callback (callable): called as
callback(offsets, exception) response will be either an Exception
or None. This callback can be used to trigger custom actions when
a commit request completes.
assignors (list): List of objects to use to distribute partition
ownership amongst consumer instances when group management is
used. Default: [RangePartitionAssignor, RoundRobinPartitionAssignor, StickyPartitionAssignor]
heartbeat_interval_ms (int): The expected time in milliseconds
between heartbeats to the consumer coordinator when using
Kafka's group management feature. Heartbeats are used to ensure
that the consumer's session stays active and to facilitate
rebalancing when new consumers join or leave the group. The
value must be set lower than session_timeout_ms, but typically
should be set no higher than 1/3 of that value. It can be
adjusted even lower to control the expected time for normal
rebalances. Default: 3000
session_timeout_ms (int): The timeout used to detect failures when
using Kafka's group management facilities. Default: 10000
retry_backoff_ms (int): Milliseconds to backoff when retrying on
errors. Default: 100.
exclude_internal_topics (bool): Whether records from internal topics
(such as offsets) should be exposed to the consumer. If set to
True the only way to receive records from an internal topic is
subscribing to it. Requires 0.10+. Default: True
"""
super(ConsumerCoordinator, self).__init__(client, metrics, **configs)
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs[key]
self._subscription = subscription
self._is_leader = False
self._joined_subscription = set()
self._metadata_snapshot = self._build_metadata_snapshot(subscription, client.cluster)
self._assignment_snapshot = None
self._cluster = client.cluster
self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000
self.next_auto_commit_deadline = None
self.completed_offset_commits = collections.deque()
if self.config['default_offset_commit_callback'] is None:
self.config['default_offset_commit_callback'] = self._default_offset_commit_callback
if self.config['group_id'] is not None:
if self.config['api_version'] >= (0, 9):
if not self.config['assignors']:
raise Errors.KafkaConfigurationError('Coordinator requires assignors')
if self.config['api_version'] < (0, 10, 1):
if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']:
raise Errors.KafkaConfigurationError("Broker version %s does not support "
"different values for max_poll_interval_ms "
"and session_timeout_ms")
if self.config['enable_auto_commit']:
if self.config['api_version'] < (0, 8, 1):
log.warning('Broker version (%s) does not support offset'
' commits; disabling auto-commit.',
self.config['api_version'])
self.config['enable_auto_commit'] = False
elif self.config['group_id'] is None:
log.warning('group_id is None: disabling auto-commit.')
self.config['enable_auto_commit'] = False
else:
self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
self.consumer_sensors = ConsumerCoordinatorMetrics(
metrics, self.config['metric_group_prefix'], self._subscription)
self._cluster.request_update()
self._cluster.add_listener(WeakMethod(self._handle_metadata_update))
def __del__(self):
if hasattr(self, '_cluster') and self._cluster:
self._cluster.remove_listener(WeakMethod(self._handle_metadata_update))
super(ConsumerCoordinator, self).__del__()
def protocol_type(self):
return ConsumerProtocol.PROTOCOL_TYPE
def group_protocols(self):
"""Returns list of preferred (protocols, metadata)"""
if self._subscription.subscription is None:
raise Errors.IllegalStateError('Consumer has not subscribed to topics')
# dpkp note: I really dislike this.
# why? because we are using this strange method group_protocols,
# which is seemingly innocuous, to set internal state (_joined_subscription)
# that is later used to check whether metadata has changed since we joined a group
# but there is no guarantee that this method, group_protocols, will get called
# in the correct sequence or that it will only be called when we want it to be.
# So this really should be moved elsewhere, but I don't have the energy to
# work that out right now. If you read this at some later date after the mutable
# state has bitten you... I'm sorry! It mimics the java client, and that's the
# best I've got for now.
self._joined_subscription = set(self._subscription.subscription)
metadata_list = []
for assignor in self.config['assignors']:
metadata = assignor.metadata(self._joined_subscription)
group_protocol = (assignor.name, metadata)
metadata_list.append(group_protocol)
return metadata_list
def _handle_metadata_update(self, cluster):
# if we encounter any unauthorized topics, raise an exception
if cluster.unauthorized_topics:
raise Errors.TopicAuthorizationFailedError(cluster.unauthorized_topics)
if self._subscription.subscribed_pattern:
topics = []
for topic in cluster.topics(self.config['exclude_internal_topics']):
if self._subscription.subscribed_pattern.match(topic):
topics.append(topic)
if set(topics) != self._subscription.subscription:
self._subscription.change_subscription(topics)
self._client.set_topics(self._subscription.group_subscription())
# check if there are any changes to the metadata which should trigger
# a rebalance
if self._subscription.partitions_auto_assigned():
metadata_snapshot = self._build_metadata_snapshot(self._subscription, cluster)
if self._metadata_snapshot != metadata_snapshot:
self._metadata_snapshot = metadata_snapshot
# If we haven't got group coordinator support,
# just assign all partitions locally
if self._auto_assign_all_partitions():
self._subscription.assign_from_subscribed([
TopicPartition(topic, partition)
for topic in self._subscription.subscription
for partition in self._metadata_snapshot[topic]
])
def _auto_assign_all_partitions(self):
# For users that use "subscribe" without group support,
# we will simply assign all partitions to this consumer
if self.config['api_version'] < (0, 9):
return True
elif self.config['group_id'] is None:
return True
else:
return False
def _build_metadata_snapshot(self, subscription, cluster):
metadata_snapshot = {}
for topic in subscription.group_subscription():
partitions = cluster.partitions_for_topic(topic) or []
metadata_snapshot[topic] = set(partitions)
return metadata_snapshot
def _lookup_assignor(self, name):
for assignor in self.config['assignors']:
if assignor.name == name:
return assignor
return None
def _on_join_complete(self, generation, member_id, protocol,
member_assignment_bytes):
# only the leader is responsible for monitoring for metadata changes
# (i.e. partition changes)
if not self._is_leader:
self._assignment_snapshot = None
assignor = self._lookup_assignor(protocol)
assert assignor, 'Coordinator selected invalid assignment protocol: %s' % (protocol,)
assignment = ConsumerProtocol.ASSIGNMENT.decode(member_assignment_bytes)
# set the flag to refresh last committed offsets
self._subscription.needs_fetch_committed_offsets = True
# update partition assignment
try:
self._subscription.assign_from_subscribed(assignment.partitions())
except ValueError as e:
log.warning("%s. Probably due to a deleted topic. Requesting Re-join" % e)
self.request_rejoin()
# give the assignor a chance to update internal state
# based on the received assignment
assignor.on_assignment(assignment)
if assignor.name == 'sticky':
assignor.on_generation_assignment(generation)
# reschedule the auto commit starting from now
self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
assigned = set(self._subscription.assigned_partitions())
log.info("Setting newly assigned partitions %s for group %s",
assigned, self.group_id)
# execute the user's callback after rebalance
if self._subscription.listener:
try:
self._subscription.listener.on_partitions_assigned(assigned)
except Exception:
log.exception("User provided listener %s for group %s"
" failed on partition assignment: %s",
self._subscription.listener, self.group_id,
assigned)
def poll(self):
"""
Poll for coordinator events. Only applicable if group_id is set, and
broker version supports GroupCoordinators. This ensures that the
coordinator is known, and if using automatic partition assignment,
ensures that the consumer has joined the group. This also handles
periodic offset commits if they are enabled.
"""
if self.group_id is None:
return
self._invoke_completed_offset_commit_callbacks()
self.ensure_coordinator_ready()
if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned():
if self.need_rejoin():
# due to a race condition between the initial metadata fetch and the
# initial rebalance, we need to ensure that the metadata is fresh
# before joining initially, and then request the metadata update. If
# metadata update arrives while the rebalance is still pending (for
# example, when the join group is still inflight), then we will lose
# track of the fact that we need to rebalance again to reflect the
# change to the topic subscription. Without ensuring that the
# metadata is fresh, any metadata update that changes the topic
# subscriptions and arrives while a rebalance is in progress will
# essentially be ignored. See KAFKA-3949 for the complete
# description of the problem.
if self._subscription.subscribed_pattern:
metadata_update = self._client.cluster.request_update()
self._client.poll(future=metadata_update)
self.ensure_active_group()
self.poll_heartbeat()
self._maybe_auto_commit_offsets_async()
def time_to_next_poll(self):
"""Return seconds (float) remaining until :meth:`.poll` should be called again"""
if not self.config['enable_auto_commit']:
return self.time_to_next_heartbeat()
if time.time() > self.next_auto_commit_deadline:
return 0
return min(self.next_auto_commit_deadline - time.time(),
self.time_to_next_heartbeat())
def _perform_assignment(self, leader_id, assignment_strategy, members):
assignor = self._lookup_assignor(assignment_strategy)
assert assignor, 'Invalid assignment protocol: %s' % (assignment_strategy,)
member_metadata = {}
all_subscribed_topics = set()
for member_id, metadata_bytes in members:
metadata = ConsumerProtocol.METADATA.decode(metadata_bytes)
member_metadata[member_id] = metadata
all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member
# the leader will begin watching for changes to any of the topics
# the group is interested in, which ensures that all metadata changes
# will eventually be seen
# Because assignment typically happens within response callbacks,
# we cannot block on metadata updates here (no recursion into poll())
self._subscription.group_subscribe(all_subscribed_topics)
self._client.set_topics(self._subscription.group_subscription())
# keep track of the metadata used for assignment so that we can check
# after rebalance completion whether anything has changed
self._cluster.request_update()
self._is_leader = True
self._assignment_snapshot = self._metadata_snapshot
log.debug("Performing assignment for group %s using strategy %s"
" with subscriptions %s", self.group_id, assignor.name,
member_metadata)
assignments = assignor.assign(self._cluster, member_metadata)
log.debug("Finished assignment for group %s: %s", self.group_id, assignments)
group_assignment = {}
for member_id, assignment in six.iteritems(assignments):
group_assignment[member_id] = assignment
return group_assignment
def _on_join_prepare(self, generation, member_id):
# commit offsets prior to rebalance if auto-commit enabled
self._maybe_auto_commit_offsets_sync()
# execute the user's callback before rebalance
log.info("Revoking previously assigned partitions %s for group %s",
self._subscription.assigned_partitions(), self.group_id)
if self._subscription.listener:
try:
revoked = set(self._subscription.assigned_partitions())
self._subscription.listener.on_partitions_revoked(revoked)
except Exception:
log.exception("User provided subscription listener %s"
" for group %s failed on_partitions_revoked",
self._subscription.listener, self.group_id)
self._is_leader = False
self._subscription.reset_group_subscription()
def need_rejoin(self):
"""Check whether the group should be rejoined
Returns:
bool: True if consumer should rejoin group, False otherwise
"""
if not self._subscription.partitions_auto_assigned():
return False
if self._auto_assign_all_partitions():
return False
# we need to rejoin if we performed the assignment and metadata has changed
if (self._assignment_snapshot is not None
and self._assignment_snapshot != self._metadata_snapshot):
return True
# we need to join if our subscription has changed since the last join
if (self._joined_subscription is not None
and self._joined_subscription != self._subscription.subscription):
return True
return super(ConsumerCoordinator, self).need_rejoin()
def refresh_committed_offsets_if_needed(self):
"""Fetch committed offsets for assigned partitions."""
if self._subscription.needs_fetch_committed_offsets:
offsets = self.fetch_committed_offsets(self._subscription.assigned_partitions())
for partition, offset in six.iteritems(offsets):
# verify assignment is still active
if self._subscription.is_assigned(partition):
self._subscription.assignment[partition].committed = offset
self._subscription.needs_fetch_committed_offsets = False
def fetch_committed_offsets(self, partitions):
"""Fetch the current committed offsets for specified partitions
Arguments:
partitions (list of TopicPartition): partitions to fetch
Returns:
dict: {TopicPartition: OffsetAndMetadata}
"""
if not partitions:
return {}
while True:
self.ensure_coordinator_ready()
# contact coordinator to fetch committed offsets
future = self._send_offset_fetch_request(partitions)
self._client.poll(future=future)
if future.succeeded():
return future.value
if not future.retriable():
raise future.exception # pylint: disable-msg=raising-bad-type
time.sleep(self.config['retry_backoff_ms'] / 1000)
def close(self, autocommit=True):
"""Close the coordinator, leave the current group,
and reset local generation / member_id.
Keyword Arguments:
autocommit (bool): If auto-commit is configured for this consumer,
this optional flag causes the consumer to attempt to commit any
pending consumed offsets prior to close. Default: True
"""
try:
if autocommit:
self._maybe_auto_commit_offsets_sync()
finally:
super(ConsumerCoordinator, self).close()
def _invoke_completed_offset_commit_callbacks(self):
while self.completed_offset_commits:
callback, offsets, exception = self.completed_offset_commits.popleft()
callback(offsets, exception)
def commit_offsets_async(self, offsets, callback=None):
"""Commit specific offsets asynchronously.
Arguments:
offsets (dict {TopicPartition: OffsetAndMetadata}): what to commit
callback (callable, optional): called as callback(offsets, response)
response will be either an Exception or a OffsetCommitResponse
struct. This callback can be used to trigger custom actions when
a commit request completes.
Returns:
kafka.future.Future
"""
self._invoke_completed_offset_commit_callbacks()
if not self.coordinator_unknown():
future = self._do_commit_offsets_async(offsets, callback)
else:
# we don't know the current coordinator, so try to find it and then
# send the commit or fail (we don't want recursive retries which can
# cause offset commits to arrive out of order). Note that there may
# be multiple offset commits chained to the same coordinator lookup
# request. This is fine because the listeners will be invoked in the
# same order that they were added. Note also that BaseCoordinator
# prevents multiple concurrent coordinator lookup requests.
future = self.lookup_coordinator()
future.add_callback(lambda r: functools.partial(self._do_commit_offsets_async, offsets, callback)())
if callback:
future.add_errback(lambda e: self.completed_offset_commits.appendleft((callback, offsets, e)))
# ensure the commit has a chance to be transmitted (without blocking on
# its completion). Note that commits are treated as heartbeats by the
# coordinator, so there is no need to explicitly allow heartbeats
# through delayed task execution.
self._client.poll(timeout_ms=0) # no wakeup if we add that feature
return future
def _do_commit_offsets_async(self, offsets, callback=None):
assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API'
assert all(map(lambda k: isinstance(k, TopicPartition), offsets))
assert all(map(lambda v: isinstance(v, OffsetAndMetadata),
offsets.values()))
if callback is None:
callback = self.config['default_offset_commit_callback']
self._subscription.needs_fetch_committed_offsets = True
future = self._send_offset_commit_request(offsets)
future.add_both(lambda res: self.completed_offset_commits.appendleft((callback, offsets, res)))
return future
def commit_offsets_sync(self, offsets):
"""Commit specific offsets synchronously.
This method will retry until the commit completes successfully or an
unrecoverable error is encountered.
Arguments:
offsets (dict {TopicPartition: OffsetAndMetadata}): what to commit
Raises error on failure
"""
assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API'
assert all(map(lambda k: isinstance(k, TopicPartition), offsets))
assert all(map(lambda v: isinstance(v, OffsetAndMetadata),
offsets.values()))
self._invoke_completed_offset_commit_callbacks()
if not offsets:
return
while True:
self.ensure_coordinator_ready()
future = self._send_offset_commit_request(offsets)
self._client.poll(future=future)
if future.succeeded():
return future.value
if not future.retriable():
raise future.exception # pylint: disable-msg=raising-bad-type
time.sleep(self.config['retry_backoff_ms'] / 1000)
def _maybe_auto_commit_offsets_sync(self):
if self.config['enable_auto_commit']:
try:
self.commit_offsets_sync(self._subscription.all_consumed_offsets())
# The three main group membership errors are known and should not
# require a stacktrace -- just a warning
except (Errors.UnknownMemberIdError,
Errors.IllegalGenerationError,
Errors.RebalanceInProgressError):
log.warning("Offset commit failed: group membership out of date"
" This is likely to cause duplicate message"
" delivery.")
except Exception:
log.exception("Offset commit failed: This is likely to cause"
" duplicate message delivery")
def _send_offset_commit_request(self, offsets):
"""Commit offsets for the specified list of topics and partitions.
This is a non-blocking call which returns a request future that can be
polled in the case of a synchronous commit or ignored in the
asynchronous case.
Arguments:
offsets (dict of {TopicPartition: OffsetAndMetadata}): what should
be committed
Returns:
Future: indicating whether the commit was successful or not
"""
assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API'
assert all(map(lambda k: isinstance(k, TopicPartition), offsets))
assert all(map(lambda v: isinstance(v, OffsetAndMetadata),
offsets.values()))
if not offsets:
log.debug('No offsets to commit')
return Future().success(None)
node_id = self.coordinator()
if node_id is None:
return Future().failure(Errors.GroupCoordinatorNotAvailableError)
# create the offset commit request
offset_data = collections.defaultdict(dict)
for tp, offset in six.iteritems(offsets):
offset_data[tp.topic][tp.partition] = offset
if self._subscription.partitions_auto_assigned():
generation = self.generation()
else:
generation = Generation.NO_GENERATION
# if the generation is None, we are not part of an active group
# (and we expect to be). The only thing we can do is fail the commit
# and let the user rejoin the group in poll()
if self.config['api_version'] >= (0, 9) and generation is None:
return Future().failure(Errors.CommitFailedError())
if self.config['api_version'] >= (0, 9):
request = OffsetCommitRequest[2](
self.group_id,
generation.generation_id,
generation.member_id,
OffsetCommitRequest[2].DEFAULT_RETENTION_TIME,
[(
topic, [(
partition,
offset.offset,
offset.metadata
) for partition, offset in six.iteritems(partitions)]
) for topic, partitions in six.iteritems(offset_data)]
)
elif self.config['api_version'] >= (0, 8, 2):
request = OffsetCommitRequest[1](
self.group_id, -1, '',
[(
topic, [(
partition,
offset.offset,
-1,
offset.metadata
) for partition, offset in six.iteritems(partitions)]
) for topic, partitions in six.iteritems(offset_data)]
)
elif self.config['api_version'] >= (0, 8, 1):
request = OffsetCommitRequest[0](
self.group_id,
[(
topic, [(
partition,
offset.offset,
offset.metadata
) for partition, offset in six.iteritems(partitions)]
) for topic, partitions in six.iteritems(offset_data)]
)
log.debug("Sending offset-commit request with %s for group %s to %s",
offsets, self.group_id, node_id)
future = Future()
_f = self._client.send(node_id, request)
_f.add_callback(self._handle_offset_commit_response, offsets, future, time.time())
_f.add_errback(self._failed_request, node_id, request, future)
return future
def _handle_offset_commit_response(self, offsets, future, send_time, response):
# TODO look at adding request_latency_ms to response (like java kafka)
self.consumer_sensors.commit_latency.record((time.time() - send_time) * 1000)
unauthorized_topics = set()
for topic, partitions in response.topics:
for partition, error_code in partitions:
tp = TopicPartition(topic, partition)
offset = offsets[tp]
error_type = Errors.for_code(error_code)
if error_type is Errors.NoError:
log.debug("Group %s committed offset %s for partition %s",
self.group_id, offset, tp)
if self._subscription.is_assigned(tp):
self._subscription.assignment[tp].committed = offset
elif error_type is Errors.GroupAuthorizationFailedError:
log.error("Not authorized to commit offsets for group %s",
self.group_id)
future.failure(error_type(self.group_id))
return
elif error_type is Errors.TopicAuthorizationFailedError:
unauthorized_topics.add(topic)
elif error_type in (Errors.OffsetMetadataTooLargeError,
Errors.InvalidCommitOffsetSizeError):
# raise the error to the user
log.debug("OffsetCommit for group %s failed on partition %s"
" %s", self.group_id, tp, error_type.__name__)
future.failure(error_type())
return
elif error_type is Errors.GroupLoadInProgressError:
# just retry
log.debug("OffsetCommit for group %s failed: %s",
self.group_id, error_type.__name__)
future.failure(error_type(self.group_id))
return
elif error_type in (Errors.GroupCoordinatorNotAvailableError,
Errors.NotCoordinatorForGroupError,
Errors.RequestTimedOutError):
log.debug("OffsetCommit for group %s failed: %s",
self.group_id, error_type.__name__)
self.coordinator_dead(error_type())
future.failure(error_type(self.group_id))
return
elif error_type in (Errors.UnknownMemberIdError,
Errors.IllegalGenerationError,
Errors.RebalanceInProgressError):
# need to re-join group
error = error_type(self.group_id)
log.debug("OffsetCommit for group %s failed: %s",
self.group_id, error)
self.reset_generation()
future.failure(Errors.CommitFailedError())
return
else:
log.error("Group %s failed to commit partition %s at offset"
" %s: %s", self.group_id, tp, offset,
error_type.__name__)
future.failure(error_type())
return
if unauthorized_topics:
log.error("Not authorized to commit to topics %s for group %s",
unauthorized_topics, self.group_id)
future.failure(Errors.TopicAuthorizationFailedError(unauthorized_topics))
else:
future.success(None)
def _send_offset_fetch_request(self, partitions):
"""Fetch the committed offsets for a set of partitions.
This is a non-blocking call. The returned future can be polled to get
the actual offsets returned from the broker.
Arguments:
partitions (list of TopicPartition): the partitions to fetch
Returns:
Future: resolves to dict of offsets: {TopicPartition: OffsetAndMetadata}
"""
assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API'
assert all(map(lambda k: isinstance(k, TopicPartition), partitions))
if not partitions:
return Future().success({})
node_id = self.coordinator()
if node_id is None:
return Future().failure(Errors.GroupCoordinatorNotAvailableError)
# Verify node is ready
if not self._client.ready(node_id):
log.debug("Node %s not ready -- failing offset fetch request",
node_id)
return Future().failure(Errors.NodeNotReadyError)
log.debug("Group %s fetching committed offsets for partitions: %s",
self.group_id, partitions)
# construct the request
topic_partitions = collections.defaultdict(set)
for tp in partitions:
topic_partitions[tp.topic].add(tp.partition)
if self.config['api_version'] >= (0, 8, 2):
request = OffsetFetchRequest[1](
self.group_id,
list(topic_partitions.items())
)
else:
request = OffsetFetchRequest[0](
self.group_id,
list(topic_partitions.items())
)
# send the request with a callback
future = Future()
_f = self._client.send(node_id, request)
_f.add_callback(self._handle_offset_fetch_response, future)
_f.add_errback(self._failed_request, node_id, request, future)
return future
def _handle_offset_fetch_response(self, future, response):
offsets = {}
for topic, partitions in response.topics:
for partition, offset, metadata, error_code in partitions:
tp = TopicPartition(topic, partition)
error_type = Errors.for_code(error_code)
if error_type is not Errors.NoError:
error = error_type()
log.debug("Group %s failed to fetch offset for partition"
" %s: %s", self.group_id, tp, error)
if error_type is Errors.GroupLoadInProgressError:
# just retry
future.failure(error)
elif error_type is Errors.NotCoordinatorForGroupError:
# re-discover the coordinator and retry
self.coordinator_dead(error_type())
future.failure(error)
elif error_type is Errors.UnknownTopicOrPartitionError:
log.warning("OffsetFetchRequest -- unknown topic %s"
" (have you committed any offsets yet?)",
topic)
continue
else:
log.error("Unknown error fetching offsets for %s: %s",
tp, error)
future.failure(error)
return
elif offset >= 0:
# record the position with the offset
# (-1 indicates no committed offset to fetch)
offsets[tp] = OffsetAndMetadata(offset, metadata)
else:
log.debug("Group %s has no committed offset for partition"
" %s", self.group_id, tp)
future.success(offsets)
def _default_offset_commit_callback(self, offsets, exception):
if exception is not None:
log.error("Offset commit failed: %s", exception)
def _commit_offsets_async_on_complete(self, offsets, exception):
if exception is not None:
log.warning("Auto offset commit failed for group %s: %s",
self.group_id, exception)
if getattr(exception, 'retriable', False):
self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline)
else:
log.debug("Completed autocommit of offsets %s for group %s",
offsets, self.group_id)
def _maybe_auto_commit_offsets_async(self):
if self.config['enable_auto_commit']:
if self.coordinator_unknown():
self.next_auto_commit_deadline = time.time() + self.config['retry_backoff_ms'] / 1000
elif time.time() > self.next_auto_commit_deadline:
self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
self.commit_offsets_async(self._subscription.all_consumed_offsets(),
self._commit_offsets_async_on_complete)
class ConsumerCoordinatorMetrics(object):
def __init__(self, metrics, metric_group_prefix, subscription):
self.metrics = metrics
self.metric_group_name = '%s-coordinator-metrics' % (metric_group_prefix,)
self.commit_latency = metrics.sensor('commit-latency')
self.commit_latency.add(metrics.metric_name(
'commit-latency-avg', self.metric_group_name,
'The average time taken for a commit request'), Avg())
self.commit_latency.add(metrics.metric_name(
'commit-latency-max', self.metric_group_name,
'The max time taken for a commit request'), Max())
self.commit_latency.add(metrics.metric_name(
'commit-rate', self.metric_group_name,
'The number of commit calls per second'), Rate(sampled_stat=Count()))
num_parts = AnonMeasurable(lambda config, now:
len(subscription.assigned_partitions()))
metrics.add_metric(metrics.metric_name(
'assigned-partitions', self.metric_group_name,
'The number of partitions currently assigned to this consumer'),
num_parts)
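A minimal sketch of the offsets mapping these commit paths validate; `coordinator` stands in for an already-constructed ConsumerCoordinator, and commit_offsets_async is the public wrapper referenced in _maybe_auto_commit_offsets_async above.

# Illustrative sketch only -- `coordinator` is assumed to be an existing
# ConsumerCoordinator instance.
from kafka.structs import TopicPartition, OffsetAndMetadata

offsets = {TopicPartition('my-topic', 0): OffsetAndMetadata(42, 'checkpoint-1')}
coordinator.commit_offsets_sync(offsets)   # blocks; retries until success or a fatal error
coordinator.commit_offsets_async(
    offsets,
    callback=lambda offs, res: None)       # res is the commit response or an exception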

View File

@@ -0,0 +1,68 @@
from __future__ import absolute_import, division
import copy
import time
class Heartbeat(object):
DEFAULT_CONFIG = {
'group_id': None,
'heartbeat_interval_ms': 3000,
'session_timeout_ms': 10000,
'max_poll_interval_ms': 300000,
'retry_backoff_ms': 100,
}
def __init__(self, **configs):
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs[key]
if self.config['group_id'] is not None:
assert (self.config['heartbeat_interval_ms']
<= self.config['session_timeout_ms']), (
'Heartbeat interval must be lower than the session timeout')
self.last_send = -1 * float('inf')
self.last_receive = -1 * float('inf')
self.last_poll = -1 * float('inf')
self.last_reset = time.time()
self.heartbeat_failed = None
def poll(self):
self.last_poll = time.time()
def sent_heartbeat(self):
self.last_send = time.time()
self.heartbeat_failed = False
def fail_heartbeat(self):
self.heartbeat_failed = True
def received_heartbeat(self):
self.last_receive = time.time()
def time_to_next_heartbeat(self):
"""Returns seconds (float) remaining before next heartbeat should be sent"""
time_since_last_heartbeat = time.time() - max(self.last_send, self.last_reset)
if self.heartbeat_failed:
delay_to_next_heartbeat = self.config['retry_backoff_ms'] / 1000
else:
delay_to_next_heartbeat = self.config['heartbeat_interval_ms'] / 1000
return max(0, delay_to_next_heartbeat - time_since_last_heartbeat)
def should_heartbeat(self):
return self.time_to_next_heartbeat() == 0
def session_timeout_expired(self):
last_recv = max(self.last_receive, self.last_reset)
return (time.time() - last_recv) > (self.config['session_timeout_ms'] / 1000)
def reset_timeouts(self):
self.last_reset = time.time()
self.last_poll = time.time()
self.heartbeat_failed = False
def poll_timeout_expired(self):
return (time.time() - self.last_poll) > (self.config['max_poll_interval_ms'] / 1000)
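A minimal sketch of how the bookkeeping above is meant to be driven; the call order here is illustrative rather than the coordinator's exact loop.

# Illustrative sketch only.
hb = Heartbeat(group_id='my-group',
               heartbeat_interval_ms=3000,
               session_timeout_ms=10000)
hb.poll()                             # the application called consumer.poll()
if hb.should_heartbeat():             # heartbeat_interval_ms elapsed since last send/reset
    hb.sent_heartbeat()
    # ... when the broker replies:
    hb.received_heartbeat()
print(hb.time_to_next_heartbeat())    # seconds until the next heartbeat is due
print(hb.session_timeout_expired())   # True only if no reply within session_timeout_ms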

View File

@@ -0,0 +1,33 @@
from __future__ import absolute_import
from kafka.protocol.struct import Struct
from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String
from kafka.structs import TopicPartition
class ConsumerProtocolMemberMetadata(Struct):
SCHEMA = Schema(
('version', Int16),
('subscription', Array(String('utf-8'))),
('user_data', Bytes))
class ConsumerProtocolMemberAssignment(Struct):
SCHEMA = Schema(
('version', Int16),
('assignment', Array(
('topic', String('utf-8')),
('partitions', Array(Int32)))),
('user_data', Bytes))
def partitions(self):
return [TopicPartition(topic, partition)
for topic, partitions in self.assignment # pylint: disable-msg=no-member
for partition in partitions]
class ConsumerProtocol(object):
PROTOCOL_TYPE = 'consumer'
ASSIGNMENT_STRATEGIES = ('range', 'roundrobin')
METADATA = ConsumerProtocolMemberMetadata
ASSIGNMENT = ConsumerProtocolMemberAssignment
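A small sketch of flattening an assignment into TopicPartition objects, assuming the Struct base class accepts fields positionally in schema order, as it is used elsewhere in the protocol code.

# Illustrative sketch only.
assignment = ConsumerProtocolMemberAssignment(
    0,                          # version
    [('my-topic', [0, 1, 2])],  # [(topic, [partitions]), ...]
    b'')                        # user_data
print(assignment.partitions())
# [TopicPartition(topic='my-topic', partition=0), ..., partition=2)]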

View File

@@ -0,0 +1,538 @@
from __future__ import absolute_import
import inspect
import sys
class KafkaError(RuntimeError):
retriable = False
# whether metadata should be refreshed on error
invalid_metadata = False
def __str__(self):
if not self.args:
return self.__class__.__name__
return '{0}: {1}'.format(self.__class__.__name__,
super(KafkaError, self).__str__())
class IllegalStateError(KafkaError):
pass
class IllegalArgumentError(KafkaError):
pass
class NoBrokersAvailable(KafkaError):
retriable = True
invalid_metadata = True
class NodeNotReadyError(KafkaError):
retriable = True
class KafkaProtocolError(KafkaError):
retriable = True
class CorrelationIdError(KafkaProtocolError):
retriable = True
class Cancelled(KafkaError):
retriable = True
class TooManyInFlightRequests(KafkaError):
retriable = True
class StaleMetadata(KafkaError):
retriable = True
invalid_metadata = True
class MetadataEmptyBrokerList(KafkaError):
retriable = True
class UnrecognizedBrokerVersion(KafkaError):
pass
class IncompatibleBrokerVersion(KafkaError):
pass
class CommitFailedError(KafkaError):
def __init__(self, *args, **kwargs):
super(CommitFailedError, self).__init__(
"""Commit cannot be completed since the group has already
rebalanced and assigned the partitions to another member.
This means that the time between subsequent calls to poll()
was longer than the configured max_poll_interval_ms, which
typically implies that the poll loop is spending too much
            time processing messages. You can address this either by
increasing the rebalance timeout with max_poll_interval_ms,
or by reducing the maximum size of batches returned in poll()
with max_poll_records.
""", *args, **kwargs)
class AuthenticationMethodNotSupported(KafkaError):
pass
class AuthenticationFailedError(KafkaError):
retriable = False
class BrokerResponseError(KafkaError):
errno = None
message = None
description = None
def __str__(self):
"""Add errno to standard KafkaError str"""
return '[Error {0}] {1}'.format(
self.errno,
super(BrokerResponseError, self).__str__())
class NoError(BrokerResponseError):
errno = 0
message = 'NO_ERROR'
description = 'No error--it worked!'
class UnknownError(BrokerResponseError):
errno = -1
message = 'UNKNOWN'
description = 'An unexpected server error.'
class OffsetOutOfRangeError(BrokerResponseError):
errno = 1
message = 'OFFSET_OUT_OF_RANGE'
description = ('The requested offset is outside the range of offsets'
' maintained by the server for the given topic/partition.')
class CorruptRecordException(BrokerResponseError):
errno = 2
message = 'CORRUPT_MESSAGE'
description = ('This message has failed its CRC checksum, exceeds the'
' valid size, or is otherwise corrupt.')
# Backward compatibility
InvalidMessageError = CorruptRecordException
class UnknownTopicOrPartitionError(BrokerResponseError):
errno = 3
message = 'UNKNOWN_TOPIC_OR_PARTITION'
description = ('This request is for a topic or partition that does not'
' exist on this broker.')
retriable = True
invalid_metadata = True
class InvalidFetchRequestError(BrokerResponseError):
errno = 4
message = 'INVALID_FETCH_SIZE'
description = 'The message has a negative size.'
class LeaderNotAvailableError(BrokerResponseError):
errno = 5
message = 'LEADER_NOT_AVAILABLE'
description = ('This error is thrown if we are in the middle of a'
' leadership election and there is currently no leader for'
' this partition and hence it is unavailable for writes.')
retriable = True
invalid_metadata = True
class NotLeaderForPartitionError(BrokerResponseError):
errno = 6
message = 'NOT_LEADER_FOR_PARTITION'
description = ('This error is thrown if the client attempts to send'
' messages to a replica that is not the leader for some'
                   " partition. It indicates that the client's metadata is out"
' of date.')
retriable = True
invalid_metadata = True
class RequestTimedOutError(BrokerResponseError):
errno = 7
message = 'REQUEST_TIMED_OUT'
description = ('This error is thrown if the request exceeds the'
' user-specified time limit in the request.')
retriable = True
class BrokerNotAvailableError(BrokerResponseError):
errno = 8
message = 'BROKER_NOT_AVAILABLE'
description = ('This is not a client facing error and is used mostly by'
' tools when a broker is not alive.')
class ReplicaNotAvailableError(BrokerResponseError):
errno = 9
message = 'REPLICA_NOT_AVAILABLE'
    description = ('If a replica is expected on a broker, but is not present'
                   ' (this can be safely ignored).')
class MessageSizeTooLargeError(BrokerResponseError):
errno = 10
message = 'MESSAGE_SIZE_TOO_LARGE'
description = ('The server has a configurable maximum message size to avoid'
' unbounded memory allocation. This error is thrown if the'
                   ' client attempts to produce a message larger than this'
' maximum.')
class StaleControllerEpochError(BrokerResponseError):
errno = 11
message = 'STALE_CONTROLLER_EPOCH'
description = 'Internal error code for broker-to-broker communication.'
class OffsetMetadataTooLargeError(BrokerResponseError):
errno = 12
message = 'OFFSET_METADATA_TOO_LARGE'
    description = ('If you specify a string larger than the configured maximum for'
' offset metadata.')
# TODO is this deprecated? https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes
class StaleLeaderEpochCodeError(BrokerResponseError):
errno = 13
message = 'STALE_LEADER_EPOCH_CODE'
class GroupLoadInProgressError(BrokerResponseError):
errno = 14
message = 'OFFSETS_LOAD_IN_PROGRESS'
description = ('The broker returns this error code for an offset fetch'
' request if it is still loading offsets (after a leader'
' change for that offsets topic partition), or in response'
' to group membership requests (such as heartbeats) when'
' group metadata is being loaded by the coordinator.')
retriable = True
class GroupCoordinatorNotAvailableError(BrokerResponseError):
errno = 15
message = 'CONSUMER_COORDINATOR_NOT_AVAILABLE'
description = ('The broker returns this error code for group coordinator'
' requests, offset commits, and most group management'
' requests if the offsets topic has not yet been created, or'
' if the group coordinator is not active.')
retriable = True
class NotCoordinatorForGroupError(BrokerResponseError):
errno = 16
message = 'NOT_COORDINATOR_FOR_CONSUMER'
description = ('The broker returns this error code if it receives an offset'
' fetch or commit request for a group that it is not a'
' coordinator for.')
retriable = True
class InvalidTopicError(BrokerResponseError):
errno = 17
message = 'INVALID_TOPIC'
description = ('For a request which attempts to access an invalid topic'
' (e.g. one which has an illegal name), or if an attempt'
' is made to write to an internal topic (such as the'
' consumer offsets topic).')
class RecordListTooLargeError(BrokerResponseError):
errno = 18
message = 'RECORD_LIST_TOO_LARGE'
description = ('If a message batch in a produce request exceeds the maximum'
' configured segment size.')
class NotEnoughReplicasError(BrokerResponseError):
errno = 19
message = 'NOT_ENOUGH_REPLICAS'
description = ('Returned from a produce request when the number of in-sync'
' replicas is lower than the configured minimum and'
' requiredAcks is -1.')
retriable = True
class NotEnoughReplicasAfterAppendError(BrokerResponseError):
errno = 20
message = 'NOT_ENOUGH_REPLICAS_AFTER_APPEND'
description = ('Returned from a produce request when the message was'
' written to the log, but with fewer in-sync replicas than'
' required.')
retriable = True
class InvalidRequiredAcksError(BrokerResponseError):
errno = 21
message = 'INVALID_REQUIRED_ACKS'
description = ('Returned from a produce request if the requested'
' requiredAcks is invalid (anything other than -1, 1, or 0).')
class IllegalGenerationError(BrokerResponseError):
errno = 22
message = 'ILLEGAL_GENERATION'
description = ('Returned from group membership requests (such as heartbeats)'
' when the generation id provided in the request is not the'
' current generation.')
class InconsistentGroupProtocolError(BrokerResponseError):
errno = 23
message = 'INCONSISTENT_GROUP_PROTOCOL'
description = ('Returned in join group when the member provides a protocol'
' type or set of protocols which is not compatible with the'
' current group.')
class InvalidGroupIdError(BrokerResponseError):
errno = 24
message = 'INVALID_GROUP_ID'
description = 'Returned in join group when the groupId is empty or null.'
class UnknownMemberIdError(BrokerResponseError):
errno = 25
message = 'UNKNOWN_MEMBER_ID'
description = ('Returned from group requests (offset commits/fetches,'
' heartbeats, etc) when the memberId is not in the current'
' generation.')
class InvalidSessionTimeoutError(BrokerResponseError):
errno = 26
message = 'INVALID_SESSION_TIMEOUT'
    description = ('Returned in join group when the requested session timeout is'
                   ' outside of the allowed range on the broker.')
class RebalanceInProgressError(BrokerResponseError):
errno = 27
message = 'REBALANCE_IN_PROGRESS'
description = ('Returned in heartbeat requests when the coordinator has'
' begun rebalancing the group. This indicates to the client'
' that it should rejoin the group.')
class InvalidCommitOffsetSizeError(BrokerResponseError):
errno = 28
message = 'INVALID_COMMIT_OFFSET_SIZE'
description = ('This error indicates that an offset commit was rejected'
' because of oversize metadata.')
class TopicAuthorizationFailedError(BrokerResponseError):
errno = 29
message = 'TOPIC_AUTHORIZATION_FAILED'
description = ('Returned by the broker when the client is not authorized to'
' access the requested topic.')
class GroupAuthorizationFailedError(BrokerResponseError):
errno = 30
message = 'GROUP_AUTHORIZATION_FAILED'
description = ('Returned by the broker when the client is not authorized to'
' access a particular groupId.')
class ClusterAuthorizationFailedError(BrokerResponseError):
errno = 31
message = 'CLUSTER_AUTHORIZATION_FAILED'
description = ('Returned by the broker when the client is not authorized to'
' use an inter-broker or administrative API.')
class InvalidTimestampError(BrokerResponseError):
errno = 32
message = 'INVALID_TIMESTAMP'
description = 'The timestamp of the message is out of acceptable range.'
class UnsupportedSaslMechanismError(BrokerResponseError):
errno = 33
message = 'UNSUPPORTED_SASL_MECHANISM'
description = 'The broker does not support the requested SASL mechanism.'
class IllegalSaslStateError(BrokerResponseError):
errno = 34
message = 'ILLEGAL_SASL_STATE'
description = 'Request is not valid given the current SASL state.'
class UnsupportedVersionError(BrokerResponseError):
errno = 35
message = 'UNSUPPORTED_VERSION'
    description = 'The version of the API is not supported.'
class TopicAlreadyExistsError(BrokerResponseError):
errno = 36
message = 'TOPIC_ALREADY_EXISTS'
description = 'Topic with this name already exists.'
class InvalidPartitionsError(BrokerResponseError):
errno = 37
message = 'INVALID_PARTITIONS'
description = 'Number of partitions is invalid.'
class InvalidReplicationFactorError(BrokerResponseError):
errno = 38
message = 'INVALID_REPLICATION_FACTOR'
description = 'Replication-factor is invalid.'
class InvalidReplicationAssignmentError(BrokerResponseError):
errno = 39
message = 'INVALID_REPLICATION_ASSIGNMENT'
description = 'Replication assignment is invalid.'
class InvalidConfigurationError(BrokerResponseError):
errno = 40
message = 'INVALID_CONFIG'
description = 'Configuration is invalid.'
class NotControllerError(BrokerResponseError):
errno = 41
message = 'NOT_CONTROLLER'
description = 'This is not the correct controller for this cluster.'
retriable = True
class InvalidRequestError(BrokerResponseError):
errno = 42
message = 'INVALID_REQUEST'
description = ('This most likely occurs because of a request being'
' malformed by the client library or the message was'
' sent to an incompatible broker. See the broker logs'
' for more details.')
class UnsupportedForMessageFormatError(BrokerResponseError):
errno = 43
message = 'UNSUPPORTED_FOR_MESSAGE_FORMAT'
description = ('The message format version on the broker does not'
' support this request.')
class PolicyViolationError(BrokerResponseError):
errno = 44
message = 'POLICY_VIOLATION'
description = 'Request parameters do not satisfy the configured policy.'
class SecurityDisabledError(BrokerResponseError):
errno = 54
message = 'SECURITY_DISABLED'
description = 'Security features are disabled.'
class NonEmptyGroupError(BrokerResponseError):
errno = 68
message = 'NON_EMPTY_GROUP'
description = 'The group is not empty.'
class GroupIdNotFoundError(BrokerResponseError):
errno = 69
message = 'GROUP_ID_NOT_FOUND'
description = 'The group id does not exist.'
class KafkaUnavailableError(KafkaError):
pass
class KafkaTimeoutError(KafkaError):
pass
class FailedPayloadsError(KafkaError):
def __init__(self, payload, *args):
super(FailedPayloadsError, self).__init__(*args)
self.payload = payload
class KafkaConnectionError(KafkaError):
retriable = True
invalid_metadata = True
class ProtocolError(KafkaError):
pass
class UnsupportedCodecError(KafkaError):
pass
class KafkaConfigurationError(KafkaError):
pass
class QuotaViolationError(KafkaError):
pass
class AsyncProducerQueueFull(KafkaError):
def __init__(self, failed_msgs, *args):
super(AsyncProducerQueueFull, self).__init__(*args)
self.failed_msgs = failed_msgs
def _iter_broker_errors():
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and issubclass(obj, BrokerResponseError) and obj != BrokerResponseError:
yield obj
kafka_errors = dict([(x.errno, x) for x in _iter_broker_errors()])
def for_code(error_code):
return kafka_errors.get(error_code, UnknownError)
def check_error(response):
if isinstance(response, Exception):
raise response
if response.error:
error_class = kafka_errors.get(response.error, UnknownError)
raise error_class(response)
RETRY_BACKOFF_ERROR_TYPES = (
KafkaUnavailableError, LeaderNotAvailableError,
KafkaConnectionError, FailedPayloadsError
)
RETRY_REFRESH_ERROR_TYPES = (
NotLeaderForPartitionError, UnknownTopicOrPartitionError,
LeaderNotAvailableError, KafkaConnectionError
)
RETRY_ERROR_TYPES = RETRY_BACKOFF_ERROR_TYPES + RETRY_REFRESH_ERROR_TYPES
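A short sketch of how the module-level helpers above turn a broker error code into an exception class whose retry semantics can then be inspected.

# Illustrative sketch only.
error_type = for_code(3)                  # -> UnknownTopicOrPartitionError
print(error_type.message)                 # 'UNKNOWN_TOPIC_OR_PARTITION'
print(error_type.retriable)               # True  -- the caller may retry
print(error_type.invalid_metadata)        # True  -- refresh metadata first
print(isinstance(error_type(), RETRY_ERROR_TYPES))  # True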

View File

@@ -0,0 +1,83 @@
from __future__ import absolute_import
import functools
import logging
log = logging.getLogger(__name__)
class Future(object):
error_on_callbacks = False # and errbacks
def __init__(self):
self.is_done = False
self.value = None
self.exception = None
self._callbacks = []
self._errbacks = []
def succeeded(self):
return self.is_done and not bool(self.exception)
def failed(self):
return self.is_done and bool(self.exception)
def retriable(self):
try:
return self.exception.retriable
except AttributeError:
return False
def success(self, value):
assert not self.is_done, 'Future is already complete'
self.value = value
self.is_done = True
if self._callbacks:
self._call_backs('callback', self._callbacks, self.value)
return self
def failure(self, e):
assert not self.is_done, 'Future is already complete'
self.exception = e if type(e) is not type else e()
assert isinstance(self.exception, BaseException), (
'future failed without an exception')
self.is_done = True
self._call_backs('errback', self._errbacks, self.exception)
return self
def add_callback(self, f, *args, **kwargs):
if args or kwargs:
f = functools.partial(f, *args, **kwargs)
if self.is_done and not self.exception:
self._call_backs('callback', [f], self.value)
else:
self._callbacks.append(f)
return self
def add_errback(self, f, *args, **kwargs):
if args or kwargs:
f = functools.partial(f, *args, **kwargs)
if self.is_done and self.exception:
self._call_backs('errback', [f], self.exception)
else:
self._errbacks.append(f)
return self
def add_both(self, f, *args, **kwargs):
self.add_callback(f, *args, **kwargs)
self.add_errback(f, *args, **kwargs)
return self
def chain(self, future):
self.add_callback(future.success)
self.add_errback(future.failure)
return self
def _call_backs(self, back_type, backs, value):
for f in backs:
try:
f(value)
except Exception as e:
log.exception('Error processing %s', back_type)
if self.error_on_callbacks:
raise e
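A minimal sketch of the callback/errback flow; Cancelled is imported only to provide a retriable exception instance.

# Illustrative sketch only.
from kafka.errors import Cancelled

f = Future()
f.add_callback(lambda value: print('ok:', value))
f.add_errback(lambda exc: print('failed:', exc))
f.success(42)                     # prints "ok: 42"; f.succeeded() is now True

g = Future()
g.chain(Future().add_errback(lambda exc: print('chained:', exc)))
g.failure(Cancelled())            # propagates to the chained future; g.retriable() is True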

View File

@@ -0,0 +1,15 @@
from __future__ import absolute_import
from kafka.metrics.compound_stat import NamedMeasurable
from kafka.metrics.dict_reporter import DictReporter
from kafka.metrics.kafka_metric import KafkaMetric
from kafka.metrics.measurable import AnonMeasurable
from kafka.metrics.metric_config import MetricConfig
from kafka.metrics.metric_name import MetricName
from kafka.metrics.metrics import Metrics
from kafka.metrics.quota import Quota
__all__ = [
'AnonMeasurable', 'DictReporter', 'KafkaMetric', 'MetricConfig',
'MetricName', 'Metrics', 'NamedMeasurable', 'Quota'
]

View File

@@ -0,0 +1,34 @@
from __future__ import absolute_import
import abc
from kafka.metrics.stat import AbstractStat
class AbstractCompoundStat(AbstractStat):
"""
A compound stat is a stat where a single measurement and associated
    data structure feeds many metrics. The canonical example is a
    histogram, which has many associated percentiles.
"""
__metaclass__ = abc.ABCMeta
def stats(self):
"""
Return list of NamedMeasurable
"""
raise NotImplementedError
class NamedMeasurable(object):
def __init__(self, metric_name, measurable_stat):
self._name = metric_name
self._stat = measurable_stat
@property
def name(self):
return self._name
@property
def stat(self):
return self._stat

View File

@@ -0,0 +1,83 @@
from __future__ import absolute_import
import logging
import threading
from kafka.metrics.metrics_reporter import AbstractMetricsReporter
logger = logging.getLogger(__name__)
class DictReporter(AbstractMetricsReporter):
"""A basic dictionary based metrics reporter.
Store all metrics in a two level dictionary of category > name > metric.
"""
def __init__(self, prefix=''):
self._lock = threading.Lock()
self._prefix = prefix if prefix else '' # never allow None
self._store = {}
def snapshot(self):
"""
Return a nested dictionary snapshot of all metrics and their
values at this time. Example:
{
'category': {
'metric1_name': 42.0,
'metric2_name': 'foo'
}
}
"""
return dict((category, dict((name, metric.value())
for name, metric in list(metrics.items())))
for category, metrics in
list(self._store.items()))
def init(self, metrics):
for metric in metrics:
self.metric_change(metric)
def metric_change(self, metric):
with self._lock:
category = self.get_category(metric)
if category not in self._store:
self._store[category] = {}
self._store[category][metric.metric_name.name] = metric
def metric_removal(self, metric):
with self._lock:
category = self.get_category(metric)
metrics = self._store.get(category, {})
removed = metrics.pop(metric.metric_name.name, None)
if not metrics:
self._store.pop(category, None)
return removed
def get_category(self, metric):
"""
Return a string category for the metric.
The category is made up of this reporter's prefix and the
metric's group and tags.
Examples:
prefix = 'foo', group = 'bar', tags = {'a': 1, 'b': 2}
returns: 'foo.bar.a=1,b=2'
prefix = 'foo', group = 'bar', tags = None
returns: 'foo.bar'
prefix = None, group = 'bar', tags = None
returns: 'bar'
"""
tags = ','.join('%s=%s' % (k, v) for k, v in
sorted(metric.metric_name.tags.items()))
return '.'.join(x for x in
[self._prefix, metric.metric_name.group, tags] if x)
def configure(self, configs):
pass
def close(self):
pass
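A minimal sketch wiring a DictReporter into the Metrics registry defined later in this change; Sensor.add is used here the same way the MetricName docstring shows it.

# Illustrative sketch only.
from kafka.metrics import DictReporter, Metrics
from kafka.metrics.stats import Avg

reporter = DictReporter(prefix='kafka')
metrics = Metrics(reporters=[reporter])
sensor = metrics.sensor('message-sizes')
sensor.add(metrics.metric_name('message-size-avg', 'producer-metrics'), Avg())
sensor.record(512)
print(reporter.snapshot())
# e.g. {'kafka.producer-metrics': {'message-size-avg': 512.0},
#       'kafka.kafka-metrics-count': {'count': 2.0}}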

View File

@@ -0,0 +1,36 @@
from __future__ import absolute_import
import time
class KafkaMetric(object):
# NOTE java constructor takes a lock instance
def __init__(self, metric_name, measurable, config):
if not metric_name:
raise ValueError('metric_name must be non-empty')
if not measurable:
raise ValueError('measurable must be non-empty')
self._metric_name = metric_name
self._measurable = measurable
self._config = config
@property
def metric_name(self):
return self._metric_name
@property
def measurable(self):
return self._measurable
@property
def config(self):
return self._config
@config.setter
def config(self, config):
self._config = config
def value(self, time_ms=None):
if time_ms is None:
time_ms = time.time() * 1000
return self.measurable.measure(self.config, time_ms)

View File

@@ -0,0 +1,29 @@
from __future__ import absolute_import
import abc
class AbstractMeasurable(object):
"""A measurable quantity that can be registered as a metric"""
@abc.abstractmethod
def measure(self, config, now):
"""
Measure this quantity and return the result
Arguments:
config (MetricConfig): The configuration for this metric
now (int): The POSIX time in milliseconds the measurement
is being taken
Returns:
The measured value
"""
raise NotImplementedError
class AnonMeasurable(AbstractMeasurable):
def __init__(self, measure_fn):
self._measure_fn = measure_fn
def measure(self, config, now):
return float(self._measure_fn(config, now))

View File

@@ -0,0 +1,16 @@
from __future__ import absolute_import
import abc
from kafka.metrics.measurable import AbstractMeasurable
from kafka.metrics.stat import AbstractStat
class AbstractMeasurableStat(AbstractStat, AbstractMeasurable):
"""
An AbstractMeasurableStat is an AbstractStat that is also
an AbstractMeasurable (i.e. can produce a single floating point value).
This is the interface used for most of the simple statistics such
as Avg, Max, Count, etc.
"""
__metaclass__ = abc.ABCMeta

View File

@@ -0,0 +1,33 @@
from __future__ import absolute_import
import sys
class MetricConfig(object):
"""Configuration values for metrics"""
def __init__(self, quota=None, samples=2, event_window=sys.maxsize,
time_window_ms=30 * 1000, tags=None):
"""
Arguments:
quota (Quota, optional): Upper or lower bound of a value.
samples (int, optional): Max number of samples kept per metric.
event_window (int, optional): Max number of values per sample.
time_window_ms (int, optional): Max age of an individual sample.
tags (dict of {str: str}, optional): Tags for each metric.
"""
self.quota = quota
self._samples = samples
self.event_window = event_window
self.time_window_ms = time_window_ms
# tags should be OrderedDict (not supported in py26)
self.tags = tags if tags else {}
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, value):
if value < 1:
raise ValueError('The number of samples must be at least 1.')
self._samples = value

View File

@@ -0,0 +1,106 @@
from __future__ import absolute_import
import copy
class MetricName(object):
"""
This class encapsulates a metric's name, logical group and its
related attributes (tags).
    The group and tags parameters can be used to create unique metric names,
e.g. domainName:type=group,key1=val1,key2=val2
Usage looks something like this:
# set up metrics:
metric_tags = {'client-id': 'producer-1', 'topic': 'topic'}
metric_config = MetricConfig(tags=metric_tags)
# metrics is the global repository of metrics and sensors
metrics = Metrics(metric_config)
sensor = metrics.sensor('message-sizes')
metric_name = metrics.metric_name('message-size-avg',
'producer-metrics',
'average message size')
sensor.add(metric_name, Avg())
        metric_name = metrics.metric_name('message-size-max',
                                          'producer-metrics')
        sensor.add(metric_name, Max())
tags = {'client-id': 'my-client', 'topic': 'my-topic'}
metric_name = metrics.metric_name('message-size-min',
'producer-metrics',
'message minimum size', tags)
sensor.add(metric_name, Min())
# as messages are sent we record the sizes
sensor.record(message_size)
"""
def __init__(self, name, group, description=None, tags=None):
"""
Arguments:
name (str): The name of the metric.
group (str): The logical group name of the metrics to which this
metric belongs.
description (str, optional): A human-readable description to
include in the metric.
tags (dict, optional): Additional key/val attributes of the metric.
"""
if not (name and group):
raise ValueError('name and group must be non-empty.')
if tags is not None and not isinstance(tags, dict):
raise ValueError('tags must be a dict if present.')
self._name = name
self._group = group
self._description = description
self._tags = copy.copy(tags)
self._hash = 0
@property
def name(self):
return self._name
@property
def group(self):
return self._group
@property
def description(self):
return self._description
@property
def tags(self):
return copy.copy(self._tags)
def __hash__(self):
if self._hash != 0:
return self._hash
prime = 31
result = 1
result = prime * result + hash(self.group)
result = prime * result + hash(self.name)
tags_hash = hash(frozenset(self.tags.items())) if self.tags else 0
result = prime * result + tags_hash
self._hash = result
return result
def __eq__(self, other):
if self is other:
return True
if other is None:
return False
return (type(self) == type(other) and
self.group == other.group and
self.name == other.name and
self.tags == other.tags)
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
return 'MetricName(name=%s, group=%s, description=%s, tags=%s)' % (
self.name, self.group, self.description, self.tags)

View File

@@ -0,0 +1,261 @@
from __future__ import absolute_import
import logging
import sys
import time
import threading
from kafka.metrics import AnonMeasurable, KafkaMetric, MetricConfig, MetricName
from kafka.metrics.stats import Sensor
logger = logging.getLogger(__name__)
class Metrics(object):
"""
A registry of sensors and metrics.
A metric is a named, numerical measurement. A sensor is a handle to
record numerical measurements as they occur. Each Sensor has zero or
more associated metrics. For example a Sensor might represent message
sizes and we might associate with this sensor a metric for the average,
maximum, or other statistics computed off the sequence of message sizes
that are recorded by the sensor.
Usage looks something like this:
# set up metrics:
metrics = Metrics() # the global repository of metrics and sensors
sensor = metrics.sensor('message-sizes')
metric_name = MetricName('message-size-avg', 'producer-metrics')
sensor.add(metric_name, Avg())
metric_name = MetricName('message-size-max', 'producer-metrics')
sensor.add(metric_name, Max())
# as messages are sent we record the sizes
sensor.record(message_size);
"""
def __init__(self, default_config=None, reporters=None,
enable_expiration=False):
"""
Create a metrics repository with a default config, given metric
reporters and the ability to expire eligible sensors
Arguments:
default_config (MetricConfig, optional): The default config
reporters (list of AbstractMetricsReporter, optional):
The metrics reporters
enable_expiration (bool, optional): true if the metrics instance
can garbage collect inactive sensors, false otherwise
"""
self._lock = threading.RLock()
self._config = default_config or MetricConfig()
self._sensors = {}
self._metrics = {}
self._children_sensors = {}
self._reporters = reporters or []
for reporter in self._reporters:
reporter.init([])
if enable_expiration:
def expire_loop():
while True:
# delay 30 seconds
time.sleep(30)
self.ExpireSensorTask.run(self)
metrics_scheduler = threading.Thread(target=expire_loop)
# Creating a daemon thread to not block shutdown
metrics_scheduler.daemon = True
metrics_scheduler.start()
self.add_metric(self.metric_name('count', 'kafka-metrics-count',
'total number of registered metrics'),
AnonMeasurable(lambda config, now: len(self._metrics)))
@property
def config(self):
return self._config
@property
def metrics(self):
"""
Get all the metrics currently maintained and indexed by metricName
"""
return self._metrics
def metric_name(self, name, group, description='', tags=None):
"""
Create a MetricName with the given name, group, description and tags,
plus default tags specified in the metric configuration.
Tag in tags takes precedence if the same tag key is specified in
the default metric configuration.
Arguments:
name (str): The name of the metric
group (str): logical group name of the metrics to which this
metric belongs
description (str, optional): A human-readable description to
include in the metric
            tags (dict, optional): additional key/value attributes of
the metric
"""
combined_tags = dict(self.config.tags)
combined_tags.update(tags or {})
return MetricName(name, group, description, combined_tags)
def get_sensor(self, name):
"""
Get the sensor with the given name if it exists
Arguments:
name (str): The name of the sensor
Returns:
Sensor: The sensor or None if no such sensor exists
"""
if not name:
raise ValueError('name must be non-empty')
return self._sensors.get(name, None)
def sensor(self, name, config=None,
inactive_sensor_expiration_time_seconds=sys.maxsize,
parents=None):
"""
Get or create a sensor with the given unique name and zero or
more parent sensors. All parent sensors will receive every value
recorded with this sensor.
Arguments:
name (str): The name of the sensor
config (MetricConfig, optional): A default configuration to use
for this sensor for metrics that don't have their own config
inactive_sensor_expiration_time_seconds (int, optional):
                If no value is recorded on the Sensor for this duration of
time, it is eligible for removal
parents (list of Sensor): The parent sensors
Returns:
Sensor: The sensor that is created
"""
sensor = self.get_sensor(name)
if sensor:
return sensor
with self._lock:
sensor = self.get_sensor(name)
if not sensor:
sensor = Sensor(self, name, parents, config or self.config,
inactive_sensor_expiration_time_seconds)
self._sensors[name] = sensor
if parents:
for parent in parents:
children = self._children_sensors.get(parent)
if not children:
children = []
self._children_sensors[parent] = children
children.append(sensor)
logger.debug('Added sensor with name %s', name)
return sensor
def remove_sensor(self, name):
"""
Remove a sensor (if it exists), associated metrics and its children.
Arguments:
name (str): The name of the sensor to be removed
"""
sensor = self._sensors.get(name)
if sensor:
child_sensors = None
with sensor._lock:
with self._lock:
val = self._sensors.pop(name, None)
if val and val == sensor:
for metric in sensor.metrics:
self.remove_metric(metric.metric_name)
logger.debug('Removed sensor with name %s', name)
child_sensors = self._children_sensors.pop(sensor, None)
if child_sensors:
for child_sensor in child_sensors:
self.remove_sensor(child_sensor.name)
def add_metric(self, metric_name, measurable, config=None):
"""
Add a metric to monitor an object that implements measurable.
This metric won't be associated with any sensor.
This is a way to expose existing values as metrics.
Arguments:
            metric_name (MetricName): The name of the metric
measurable (AbstractMeasurable): The measurable that will be
measured by this metric
config (MetricConfig, optional): The configuration to use when
measuring this measurable
"""
        # NOTE there was a lock here, but I don't think it's needed
metric = KafkaMetric(metric_name, measurable, config or self.config)
self.register_metric(metric)
def remove_metric(self, metric_name):
"""
Remove a metric if it exists and return it. Return None otherwise.
If a metric is removed, `metric_removal` will be invoked
for each reporter.
Arguments:
metric_name (MetricName): The name of the metric
Returns:
KafkaMetric: the removed `KafkaMetric` or None if no such
metric exists
"""
with self._lock:
metric = self._metrics.pop(metric_name, None)
if metric:
for reporter in self._reporters:
reporter.metric_removal(metric)
return metric
def add_reporter(self, reporter):
"""Add a MetricReporter"""
with self._lock:
reporter.init(list(self.metrics.values()))
self._reporters.append(reporter)
def register_metric(self, metric):
with self._lock:
if metric.metric_name in self.metrics:
raise ValueError('A metric named "%s" already exists, cannot'
' register another one.' % (metric.metric_name,))
self.metrics[metric.metric_name] = metric
for reporter in self._reporters:
reporter.metric_change(metric)
class ExpireSensorTask(object):
"""
This iterates over every Sensor and triggers a remove_sensor
if it has expired. Package private for testing
"""
@staticmethod
def run(metrics):
items = list(metrics._sensors.items())
for name, sensor in items:
# remove_sensor also locks the sensor object. This is fine
# because synchronized is reentrant. There is however a minor
# race condition here. Assume we have a parent sensor P and
# child sensor C. Calling record on C would cause a record on
# P as well. So expiration time for P == expiration time for C.
# If the record on P happens via C just after P is removed,
# that will cause C to also get removed. Since the expiration
# time is typically high it is not expected to be a significant
# concern and thus not necessary to optimize
with sensor._lock:
if sensor.has_expired():
logger.debug('Removing expired sensor %s', name)
metrics.remove_sensor(name)
def close(self):
"""Close this metrics repository."""
for reporter in self._reporters:
reporter.close()
self._metrics.clear()

View File

@@ -0,0 +1,57 @@
from __future__ import absolute_import
import abc
class AbstractMetricsReporter(object):
"""
An abstract class to allow things to listen as new metrics
are created so they can be reported.
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def init(self, metrics):
"""
This is called when the reporter is first registered
to initially register all existing metrics
Arguments:
metrics (list of KafkaMetric): All currently existing metrics
"""
raise NotImplementedError
@abc.abstractmethod
def metric_change(self, metric):
"""
This is called whenever a metric is updated or added
Arguments:
metric (KafkaMetric)
"""
raise NotImplementedError
@abc.abstractmethod
def metric_removal(self, metric):
"""
This is called whenever a metric is removed
Arguments:
metric (KafkaMetric)
"""
raise NotImplementedError
@abc.abstractmethod
def configure(self, configs):
"""
Configure this class with the given key-value pairs
Arguments:
configs (dict of {str, ?})
"""
raise NotImplementedError
@abc.abstractmethod
def close(self):
"""Called when the metrics repository is closed."""
raise NotImplementedError

View File

@@ -0,0 +1,42 @@
from __future__ import absolute_import
class Quota(object):
"""An upper or lower bound for metrics"""
def __init__(self, bound, is_upper):
self._bound = bound
self._upper = is_upper
@staticmethod
def upper_bound(upper_bound):
return Quota(upper_bound, True)
@staticmethod
def lower_bound(lower_bound):
return Quota(lower_bound, False)
def is_upper_bound(self):
return self._upper
@property
def bound(self):
return self._bound
def is_acceptable(self, value):
return ((self.is_upper_bound() and value <= self.bound) or
(not self.is_upper_bound() and value >= self.bound))
def __hash__(self):
prime = 31
        result = prime + hash(self.bound)
        return prime * result + int(self.is_upper_bound())  # hash() must return an int
def __eq__(self, other):
if self is other:
return True
return (type(self) == type(other) and
self.bound == other.bound and
self.is_upper_bound() == other.is_upper_bound())
def __ne__(self, other):
return not self.__eq__(other)

View File

@@ -0,0 +1,23 @@
from __future__ import absolute_import
import abc
class AbstractStat(object):
"""
An AbstractStat is a quantity such as average, max, etc that is computed
off the stream of updates to a sensor
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def record(self, config, value, time_ms):
"""
Record the given value
Arguments:
config (MetricConfig): The configuration to use for this metric
value (float): The value to record
timeMs (int): The POSIX time in milliseconds this value occurred
"""
raise NotImplementedError

View File

@@ -0,0 +1,17 @@
from __future__ import absolute_import
from kafka.metrics.stats.avg import Avg
from kafka.metrics.stats.count import Count
from kafka.metrics.stats.histogram import Histogram
from kafka.metrics.stats.max_stat import Max
from kafka.metrics.stats.min_stat import Min
from kafka.metrics.stats.percentile import Percentile
from kafka.metrics.stats.percentiles import Percentiles
from kafka.metrics.stats.rate import Rate
from kafka.metrics.stats.sensor import Sensor
from kafka.metrics.stats.total import Total
__all__ = [
'Avg', 'Count', 'Histogram', 'Max', 'Min', 'Percentile', 'Percentiles',
'Rate', 'Sensor', 'Total'
]

View File

@@ -0,0 +1,24 @@
from __future__ import absolute_import
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class Avg(AbstractSampledStat):
"""
An AbstractSampledStat that maintains a simple average over its samples.
"""
def __init__(self):
super(Avg, self).__init__(0.0)
def update(self, sample, config, value, now):
sample.value += value
def combine(self, samples, config, now):
total_sum = 0
total_count = 0
for sample in samples:
total_sum += sample.value
total_count += sample.event_count
if not total_count:
return 0
return float(total_sum) / total_count

View File

@@ -0,0 +1,17 @@
from __future__ import absolute_import
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class Count(AbstractSampledStat):
"""
An AbstractSampledStat that maintains a simple count of what it has seen.
"""
def __init__(self):
super(Count, self).__init__(0.0)
def update(self, sample, config, value, now):
sample.value += 1.0
def combine(self, samples, config, now):
return float(sum(sample.value for sample in samples))

View File

@@ -0,0 +1,95 @@
from __future__ import absolute_import
import math
class Histogram(object):
def __init__(self, bin_scheme):
self._hist = [0.0] * bin_scheme.bins
self._count = 0.0
self._bin_scheme = bin_scheme
def record(self, value):
self._hist[self._bin_scheme.to_bin(value)] += 1.0
self._count += 1.0
def value(self, quantile):
if self._count == 0.0:
return float('NaN')
_sum = 0.0
quant = float(quantile)
for i, value in enumerate(self._hist[:-1]):
_sum += value
if _sum / self._count > quant:
return self._bin_scheme.from_bin(i)
return float('inf')
@property
def counts(self):
return self._hist
def clear(self):
        for i in range(len(self._hist)):
self._hist[i] = 0.0
self._count = 0
def __str__(self):
values = ['%.10f:%.0f' % (self._bin_scheme.from_bin(i), value) for
i, value in enumerate(self._hist[:-1])]
values.append('%s:%s' % (float('inf'), self._hist[-1]))
return '{%s}' % ','.join(values)
class ConstantBinScheme(object):
def __init__(self, bins, min_val, max_val):
if bins < 2:
raise ValueError('Must have at least 2 bins.')
self._min = float(min_val)
self._max = float(max_val)
self._bins = int(bins)
self._bucket_width = (max_val - min_val) / (bins - 2)
@property
def bins(self):
return self._bins
def from_bin(self, b):
if b == 0:
return float('-inf')
elif b == self._bins - 1:
return float('inf')
else:
return self._min + (b - 1) * self._bucket_width
def to_bin(self, x):
if x < self._min:
return 0
elif x > self._max:
return self._bins - 1
else:
return int(((x - self._min) / self._bucket_width) + 1)
class LinearBinScheme(object):
def __init__(self, num_bins, max_val):
self._bins = num_bins
self._max = max_val
self._scale = max_val / (num_bins * (num_bins - 1) / 2)
@property
def bins(self):
return self._bins
def from_bin(self, b):
if b == self._bins - 1:
return float('inf')
else:
unscaled = (b * (b + 1.0)) / 2.0
return unscaled * self._scale
def to_bin(self, x):
if x < 0.0:
raise ValueError('Values less than 0.0 not accepted.')
elif x > self._max:
return self._bins - 1
else:
scaled = x / self._scale
return int(-0.5 + math.sqrt(2.0 * scaled + 0.25))

View File

@@ -0,0 +1,17 @@
from __future__ import absolute_import
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class Max(AbstractSampledStat):
"""An AbstractSampledStat that gives the max over its samples."""
def __init__(self):
super(Max, self).__init__(float('-inf'))
def update(self, sample, config, value, now):
sample.value = max(sample.value, value)
def combine(self, samples, config, now):
if not samples:
return float('-inf')
return float(max(sample.value for sample in samples))

View File

@@ -0,0 +1,19 @@
from __future__ import absolute_import
import sys
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class Min(AbstractSampledStat):
"""An AbstractSampledStat that gives the min over its samples."""
def __init__(self):
super(Min, self).__init__(float(sys.maxsize))
def update(self, sample, config, value, now):
sample.value = min(sample.value, value)
def combine(self, samples, config, now):
if not samples:
return float(sys.maxsize)
return float(min(sample.value for sample in samples))

View File

@@ -0,0 +1,15 @@
from __future__ import absolute_import
class Percentile(object):
def __init__(self, metric_name, percentile):
self._metric_name = metric_name
self._percentile = float(percentile)
@property
def name(self):
return self._metric_name
@property
def percentile(self):
return self._percentile

View File

@@ -0,0 +1,74 @@
from __future__ import absolute_import
from kafka.metrics import AnonMeasurable, NamedMeasurable
from kafka.metrics.compound_stat import AbstractCompoundStat
from kafka.metrics.stats import Histogram
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class BucketSizing(object):
CONSTANT = 0
LINEAR = 1
class Percentiles(AbstractSampledStat, AbstractCompoundStat):
"""A compound stat that reports one or more percentiles"""
def __init__(self, size_in_bytes, bucketing, max_val, min_val=0.0,
percentiles=None):
super(Percentiles, self).__init__(0.0)
self._percentiles = percentiles or []
self._buckets = int(size_in_bytes / 4)
if bucketing == BucketSizing.CONSTANT:
self._bin_scheme = Histogram.ConstantBinScheme(self._buckets,
min_val, max_val)
elif bucketing == BucketSizing.LINEAR:
if min_val != 0.0:
raise ValueError('Linear bucket sizing requires min_val'
' to be 0.0.')
            self._bin_scheme = Histogram.LinearBinScheme(self._buckets, max_val)
else:
            raise ValueError('Unknown bucket type: %s' % (bucketing,))
def stats(self):
measurables = []
def make_measure_fn(pct):
return lambda config, now: self.value(config, now,
pct / 100.0)
for percentile in self._percentiles:
measure_fn = make_measure_fn(percentile.percentile)
stat = NamedMeasurable(percentile.name, AnonMeasurable(measure_fn))
measurables.append(stat)
return measurables
def value(self, config, now, quantile):
self.purge_obsolete_samples(config, now)
count = sum(sample.event_count for sample in self._samples)
if count == 0.0:
return float('NaN')
sum_val = 0.0
quant = float(quantile)
for b in range(self._buckets):
for sample in self._samples:
assert type(sample) is self.HistogramSample
hist = sample.histogram.counts
sum_val += hist[b]
if sum_val / count > quant:
return self._bin_scheme.from_bin(b)
return float('inf')
def combine(self, samples, config, now):
return self.value(config, now, 0.5)
def new_sample(self, time_ms):
return Percentiles.HistogramSample(self._bin_scheme, time_ms)
def update(self, sample, config, value, time_ms):
assert type(sample) is self.HistogramSample
sample.histogram.record(value)
class HistogramSample(AbstractSampledStat.Sample):
def __init__(self, scheme, now):
super(Percentiles.HistogramSample, self).__init__(0.0, now)
self.histogram = Histogram(scheme)

View File

@@ -0,0 +1,117 @@
from __future__ import absolute_import
from kafka.metrics.measurable_stat import AbstractMeasurableStat
from kafka.metrics.stats.sampled_stat import AbstractSampledStat
class TimeUnit(object):
_names = {
'nanosecond': 0,
'microsecond': 1,
'millisecond': 2,
'second': 3,
'minute': 4,
'hour': 5,
'day': 6,
}
NANOSECONDS = _names['nanosecond']
MICROSECONDS = _names['microsecond']
MILLISECONDS = _names['millisecond']
SECONDS = _names['second']
MINUTES = _names['minute']
HOURS = _names['hour']
DAYS = _names['day']
@staticmethod
def get_name(time_unit):
        # _names maps name -> constant, so do a reverse lookup
        return {v: k for k, v in TimeUnit._names.items()}[time_unit]
class Rate(AbstractMeasurableStat):
"""
The rate of the given quantity. By default this is the total observed
over a set of samples from a sampled statistic divided by the elapsed
time over the sample windows. Alternative AbstractSampledStat
implementations can be provided, however, to record the rate of
occurrences (e.g. the count of values measured over the time interval)
or other such values.
"""
def __init__(self, time_unit=TimeUnit.SECONDS, sampled_stat=None):
self._stat = sampled_stat or SampledTotal()
self._unit = time_unit
def unit_name(self):
return TimeUnit.get_name(self._unit)
def record(self, config, value, time_ms):
self._stat.record(config, value, time_ms)
def measure(self, config, now):
value = self._stat.measure(config, now)
return float(value) / self.convert(self.window_size(config, now))
def window_size(self, config, now):
# purge old samples before we compute the window size
self._stat.purge_obsolete_samples(config, now)
"""
Here we check the total amount of time elapsed since the oldest
        non-obsolete window. This gives the total window_size of the batch,
        which is the time used for Rate computation. However, there is
        an issue if we do not have sufficient data: for example, if only
1 second has elapsed in a 30 second window, the measured rate
will be very high. Hence we assume that the elapsed time is
always N-1 complete windows plus whatever fraction of the final
window is complete.
Note that we could simply count the amount of time elapsed in
the current window and add n-1 windows to get the total time,
but this approach does not account for sleeps. AbstractSampledStat
        only creates samples whenever record is called; if no record is
        called for a period of time, that time is not accounted for in
window_size and produces incorrect results.
"""
total_elapsed_time_ms = now - self._stat.oldest(now).last_window_ms
# Check how many full windows of data we have currently retained
num_full_windows = int(total_elapsed_time_ms / config.time_window_ms)
min_full_windows = config.samples - 1
# If the available windows are less than the minimum required,
# add the difference to the totalElapsedTime
if num_full_windows < min_full_windows:
total_elapsed_time_ms += ((min_full_windows - num_full_windows) *
config.time_window_ms)
return total_elapsed_time_ms
def convert(self, time_ms):
if self._unit == TimeUnit.NANOSECONDS:
return time_ms * 1000.0 * 1000.0
elif self._unit == TimeUnit.MICROSECONDS:
return time_ms * 1000.0
elif self._unit == TimeUnit.MILLISECONDS:
return time_ms
elif self._unit == TimeUnit.SECONDS:
return time_ms / 1000.0
elif self._unit == TimeUnit.MINUTES:
return time_ms / (60.0 * 1000.0)
elif self._unit == TimeUnit.HOURS:
return time_ms / (60.0 * 60.0 * 1000.0)
elif self._unit == TimeUnit.DAYS:
return time_ms / (24.0 * 60.0 * 60.0 * 1000.0)
else:
raise ValueError('Unknown unit: %s' % (self._unit,))
class SampledTotal(AbstractSampledStat):
def __init__(self, initial_value=None):
if initial_value is not None:
raise ValueError('initial_value cannot be set on SampledTotal')
super(SampledTotal, self).__init__(0.0)
def update(self, sample, config, value, time_ms):
sample.value += value
def combine(self, samples, config, now):
return float(sum(sample.value for sample in samples))
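
A minimal sketch of driving Rate directly, assuming kafka.metrics.MetricConfig accepts the samples and time_window_ms keyword arguments used elsewhere in this commit; the values below are illustrative.

import time

from kafka.metrics import MetricConfig
from kafka.metrics.stats import Rate

# Two 30-second windows (assumed MetricConfig signature; samples and
# time_window_ms are the fields read by window_size() above).
config = MetricConfig(samples=2, time_window_ms=30000)
rate = Rate()  # SampledTotal per window, reported per second by default

start_ms = time.time() * 1000
for i in range(10):
    rate.record(config, 100, start_ms + i * 1000)  # 100 units each second

# measure() sums the windowed totals and divides by the elapsed window size
# (padded to N-1 full windows as described in window_size() above).
print(rate.measure(config, start_ms + 10 * 1000))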

View File

@@ -0,0 +1,101 @@
from __future__ import absolute_import
import abc
from kafka.metrics.measurable_stat import AbstractMeasurableStat
class AbstractSampledStat(AbstractMeasurableStat):
"""
An AbstractSampledStat records a single scalar value measured over
one or more samples. Each sample is recorded over a configurable
window. The window can be defined by number of events or elapsed
time (or both, if both are given the window is complete when
*either* the event count or elapsed time criterion is met).
All the samples are combined to produce the measurement. When a
window is complete the oldest sample is cleared and recycled to
begin recording the next sample.
Subclasses of this class define different statistics measured
using this basic pattern.
"""
__metaclass__ = abc.ABCMeta
def __init__(self, initial_value):
self._initial_value = initial_value
self._samples = []
self._current = 0
@abc.abstractmethod
def update(self, sample, config, value, time_ms):
raise NotImplementedError
@abc.abstractmethod
def combine(self, samples, config, now):
raise NotImplementedError
def record(self, config, value, time_ms):
sample = self.current(time_ms)
if sample.is_complete(time_ms, config):
sample = self._advance(config, time_ms)
self.update(sample, config, float(value), time_ms)
sample.event_count += 1
def new_sample(self, time_ms):
return self.Sample(self._initial_value, time_ms)
def measure(self, config, now):
self.purge_obsolete_samples(config, now)
return float(self.combine(self._samples, config, now))
def current(self, time_ms):
if not self._samples:
self._samples.append(self.new_sample(time_ms))
return self._samples[self._current]
def oldest(self, now):
if not self._samples:
self._samples.append(self.new_sample(now))
oldest = self._samples[0]
for sample in self._samples[1:]:
if sample.last_window_ms < oldest.last_window_ms:
oldest = sample
return oldest
def purge_obsolete_samples(self, config, now):
"""
Timeout any windows that have expired in the absence of any events
"""
expire_age = config.samples * config.time_window_ms
for sample in self._samples:
if now - sample.last_window_ms >= expire_age:
sample.reset(now)
def _advance(self, config, time_ms):
self._current = (self._current + 1) % config.samples
if self._current >= len(self._samples):
sample = self.new_sample(time_ms)
self._samples.append(sample)
return sample
else:
sample = self.current(time_ms)
sample.reset(time_ms)
return sample
class Sample(object):
def __init__(self, initial_value, now):
self.initial_value = initial_value
self.event_count = 0
self.last_window_ms = now
self.value = initial_value
def reset(self, now):
self.event_count = 0
self.last_window_ms = now
self.value = self.initial_value
def is_complete(self, time_ms, config):
return (time_ms - self.last_window_ms >= config.time_window_ms or
self.event_count >= config.event_window)
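
A minimal sketch of the subclassing pattern described in the docstring above: a windowed maximum that fills in update() and combine(). kafka-python ships its own Max stat; this standalone version is only for illustration.

from kafka.metrics.stats.sampled_stat import AbstractSampledStat


class WindowedMax(AbstractSampledStat):
    """Maximum value seen across the currently retained samples."""

    def __init__(self):
        # Each new or reset sample starts at -inf so any recorded value wins
        super(WindowedMax, self).__init__(float('-inf'))

    def update(self, sample, config, value, time_ms):
        sample.value = max(sample.value, value)

    def combine(self, samples, config, now):
        # Ignore recycled samples that have not seen an event yet
        values = [s.value for s in samples if s.event_count > 0]
        return float(max(values)) if values else float('-inf')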

View File

@@ -0,0 +1,134 @@
from __future__ import absolute_import
import threading
import time
from kafka.errors import QuotaViolationError
from kafka.metrics import KafkaMetric
class Sensor(object):
"""
A sensor applies a continuous sequence of numerical values
to a set of associated metrics. For example a sensor on
message size would record a sequence of message sizes using
the `record(double)` api and would maintain a set
of metrics about request sizes such as the average or max.
"""
def __init__(self, registry, name, parents, config,
inactive_sensor_expiration_time_seconds):
if not name:
raise ValueError('name must be non-empty')
self._lock = threading.RLock()
self._registry = registry
self._name = name
self._parents = parents or []
self._metrics = []
self._stats = []
self._config = config
self._inactive_sensor_expiration_time_ms = (
inactive_sensor_expiration_time_seconds * 1000)
self._last_record_time = time.time() * 1000
self._check_forest(set())
def _check_forest(self, sensors):
"""Validate that this sensor doesn't end up referencing itself."""
if self in sensors:
raise ValueError('Circular dependency in sensors: %s is its own'
' parent.' % (self.name,))
sensors.add(self)
for parent in self._parents:
parent._check_forest(sensors)
@property
def name(self):
"""
The name this sensor is registered with.
This name will be unique among all registered sensors.
"""
return self._name
@property
def metrics(self):
return tuple(self._metrics)
def record(self, value=1.0, time_ms=None):
"""
Record a value at a known time.
Arguments:
value (double): The value we are recording
time_ms (int): A POSIX timestamp in milliseconds.
Default: The time when record() is evaluated (now)
Raises:
QuotaViolationError: if recording this value moves a
metric beyond its configured maximum or minimum bound
"""
if time_ms is None:
time_ms = time.time() * 1000
self._last_record_time = time_ms
with self._lock: # XXX high volume, might be performance issue
# increment all the stats
for stat in self._stats:
stat.record(self._config, value, time_ms)
self._check_quotas(time_ms)
for parent in self._parents:
parent.record(value, time_ms)
def _check_quotas(self, time_ms):
"""
Check if we have violated our quota for any metric that
has a configured quota
"""
for metric in self._metrics:
if metric.config and metric.config.quota:
value = metric.value(time_ms)
if not metric.config.quota.is_acceptable(value):
raise QuotaViolationError("'%s' violated quota. Actual: "
"%d, Threshold: %d" %
(metric.metric_name,
value,
metric.config.quota.bound))
def add_compound(self, compound_stat, config=None):
"""
Register a compound statistic with this sensor which
yields multiple measurable quantities (like a histogram)
Arguments:
compound_stat (AbstractCompoundStat): The stat to register
config (MetricConfig): The configuration for this stat.
If None then the stat will use the default configuration
for this sensor.
"""
if not compound_stat:
raise ValueError('compound stat must be non-empty')
self._stats.append(compound_stat)
for named_measurable in compound_stat.stats():
metric = KafkaMetric(named_measurable.name, named_measurable.stat,
config or self._config)
self._registry.register_metric(metric)
self._metrics.append(metric)
def add(self, metric_name, stat, config=None):
"""
Register a metric with this sensor
Arguments:
metric_name (MetricName): The name of the metric
stat (AbstractMeasurableStat): The statistic to keep
config (MetricConfig): A special configuration for this metric.
If None use the sensor default configuration.
"""
with self._lock:
metric = KafkaMetric(metric_name, stat, config or self._config)
self._registry.register_metric(metric)
self._metrics.append(metric)
self._stats.append(stat)
def has_expired(self):
"""
Return True if the Sensor is eligible for removal due to inactivity.
"""
return ((time.time() * 1000 - self._last_record_time) >
self._inactive_sensor_expiration_time_ms)
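
A minimal sketch of wiring a Sensor through the Metrics registry, mirroring the bufferpool-wait-time sensor created later in this commit; the metric and group names are illustrative, and Metrics() is assumed to be constructible with its defaults.

from kafka.metrics import Metrics
from kafka.metrics.stats import Rate

metrics = Metrics()  # assumed: default config, no reporters
sensor = metrics.sensor('request-rate')
sensor.add(metrics.metric_name('requests-per-second', 'demo-metrics',
                               'Requests recorded per second.'),
           Rate())

sensor.record()   # value defaults to 1.0
sensor.record()

# Every metric registered through a sensor is readable from the registry
for name, metric in metrics.metrics.items():
    print(name, metric.value())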

View File

@@ -0,0 +1,15 @@
from __future__ import absolute_import
from kafka.metrics.measurable_stat import AbstractMeasurableStat
class Total(AbstractMeasurableStat):
"""An un-windowed cumulative total maintained over all time."""
def __init__(self, value=0.0):
self._total = value
def record(self, config, value, now):
self._total += value
def measure(self, config, now):
return float(self._total)

View File

@@ -0,0 +1,3 @@
from __future__ import absolute_import
from kafka.oauth.abstract import AbstractTokenProvider

View File

@@ -0,0 +1,42 @@
from __future__ import absolute_import
import abc
# This statement is compatible with both Python 2.7 & 3+
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})
class AbstractTokenProvider(ABC):
"""
A Token Provider must be used for the SASL OAuthBearer protocol.
The implementation should ensure token reuse so that multiple
calls at connect time do not create multiple tokens. The implementation
should also periodically refresh the token in order to guarantee
that each call returns an unexpired token. A timeout error should
be returned after a short period of inactivity so that the
broker can log debugging info and retry.
Token Providers MUST implement the token() method
"""
def __init__(self, **config):
pass
@abc.abstractmethod
def token(self):
"""
Returns a (str) ID/Access Token to be sent to the Kafka
client.
"""
pass
def extensions(self):
"""
This is an OPTIONAL method that may be implemented.
Returns a map of key-value pairs that can
be sent with the SASL/OAUTHBEARER initial client request. If
not implemented, the values are ignored. This feature is only available
in Kafka >= 2.1.0.
"""
return {}
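
A minimal sketch of a concrete provider, assuming the token string was obtained out of band; a production implementation should cache and refresh it as the docstring above requires. It would be handed to a client via the sasl_oauth_token_provider configuration described later in this commit.

from kafka.oauth.abstract import AbstractTokenProvider


class StaticTokenProvider(AbstractTokenProvider):
    """Returns a pre-fetched OAuth bearer token (illustration only)."""

    def __init__(self, access_token, **config):
        super(StaticTokenProvider, self).__init__(**config)
        self._access_token = access_token

    def token(self):
        # A real provider would refresh the token before it expires
        return self._access_token

# e.g. KafkaProducer(..., security_protocol='SASL_SSL',
#                    sasl_mechanism='OAUTHBEARER',
#                    sasl_oauth_token_provider=StaticTokenProvider('my-token'))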

View File

@@ -0,0 +1,8 @@
from __future__ import absolute_import
from kafka.partitioner.default import DefaultPartitioner, murmur2
__all__ = [
'DefaultPartitioner', 'murmur2'
]

View File

@@ -0,0 +1,102 @@
from __future__ import absolute_import
import random
from kafka.vendor import six
class DefaultPartitioner(object):
"""Default partitioner.
Hashes key to partition using murmur2 hashing (from java client)
If key is None, selects partition randomly from available,
or from all partitions if none are currently available
"""
@classmethod
def __call__(cls, key, all_partitions, available):
"""
Get the partition corresponding to key
:param key: partitioning key
:param all_partitions: list of all partitions sorted by partition ID
:param available: list of available partitions in no particular order
:return: one of the values from all_partitions or available
"""
if key is None:
if available:
return random.choice(available)
return random.choice(all_partitions)
idx = murmur2(key)
idx &= 0x7fffffff
idx %= len(all_partitions)
return all_partitions[idx]
# https://github.com/apache/kafka/blob/0.8.2/clients/src/main/java/org/apache/kafka/common/utils/Utils.java#L244
def murmur2(data):
"""Pure-python Murmur2 implementation.
Based on java client, see org.apache.kafka.common.utils.Utils.murmur2
Args:
data (bytes): opaque bytes
Returns: MurmurHash2 of data
"""
# Python2 bytes is really a str, causing the bitwise operations below to fail
# so convert to bytearray.
if six.PY2:
data = bytearray(bytes(data))
length = len(data)
seed = 0x9747b28c
# 'm' and 'r' are mixing constants generated offline.
# They're not really 'magic', they just happen to work well.
m = 0x5bd1e995
r = 24
# Initialize the hash to a 'random' value (seed xor length)
h = seed ^ length
length4 = length // 4
for i in range(length4):
i4 = i * 4
k = ((data[i4 + 0] & 0xff) +
((data[i4 + 1] & 0xff) << 8) +
((data[i4 + 2] & 0xff) << 16) +
((data[i4 + 3] & 0xff) << 24))
k &= 0xffffffff
k *= m
k &= 0xffffffff
k ^= (k % 0x100000000) >> r # k ^= k >>> r
k &= 0xffffffff
k *= m
k &= 0xffffffff
h *= m
h &= 0xffffffff
h ^= k
h &= 0xffffffff
# Handle the last few bytes of the input array
extra_bytes = length % 4
if extra_bytes >= 3:
h ^= (data[(length & ~3) + 2] & 0xff) << 16
h &= 0xffffffff
if extra_bytes >= 2:
h ^= (data[(length & ~3) + 1] & 0xff) << 8
h &= 0xffffffff
if extra_bytes >= 1:
h ^= (data[length & ~3] & 0xff)
h &= 0xffffffff
h *= m
h &= 0xffffffff
h ^= (h % 0x100000000) >> 13 # h >>> 13;
h &= 0xffffffff
h *= m
h &= 0xffffffff
h ^= (h % 0x100000000) >> 15 # h >>> 15;
h &= 0xffffffff
return h
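
A small usage sketch: computing a partition for an already-serialized key the same way the producer does; the partition ids are illustrative.

from kafka.partitioner.default import DefaultPartitioner, murmur2

all_partitions = [0, 1, 2, 3]   # all partition ids for the topic, sorted
available = [0, 1, 3]           # partitions that currently have a leader

partitioner = DefaultPartitioner()
partition = partitioner(b'user-42', all_partitions, available)

# Equivalent by hand: positive murmur2 hash modulo the partition count
# (the assertion holds here because all_partitions[i] == i).
assert partition == (murmur2(b'user-42') & 0x7fffffff) % len(all_partitions)

# With key=None the choice is random, preferring available partitions
print(partitioner(None, all_partitions, available))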

View File

@@ -0,0 +1,7 @@
from __future__ import absolute_import
from kafka.producer.kafka import KafkaProducer
__all__ = [
'KafkaProducer'
]

View File

@@ -0,0 +1,115 @@
from __future__ import absolute_import, division
import collections
import io
import threading
import time
from kafka.metrics.stats import Rate
import kafka.errors as Errors
class SimpleBufferPool(object):
"""A simple pool of BytesIO objects with a weak memory ceiling."""
def __init__(self, memory, poolable_size, metrics=None, metric_group_prefix='producer-metrics'):
"""Create a new buffer pool.
Arguments:
memory (int): maximum memory that this buffer pool can allocate
poolable_size (int): memory size per buffer to cache in the free
list rather than deallocating
"""
self._poolable_size = poolable_size
self._lock = threading.RLock()
buffers = int(memory / poolable_size) if poolable_size else 0
self._free = collections.deque([io.BytesIO() for _ in range(buffers)])
self._waiters = collections.deque()
self.wait_time = None
if metrics:
self.wait_time = metrics.sensor('bufferpool-wait-time')
self.wait_time.add(metrics.metric_name(
'bufferpool-wait-ratio', metric_group_prefix,
'The fraction of time an appender waits for space allocation.'),
Rate())
def allocate(self, size, max_time_to_block_ms):
"""
Allocate a buffer of the given size. This method blocks if there is not
enough memory and the buffer pool is configured with blocking mode.
Arguments:
size (int): The buffer size to allocate in bytes [ignored]
max_time_to_block_ms (int): The maximum time in milliseconds to
block for buffer memory to be available
Returns:
io.BytesIO
"""
with self._lock:
# check if we have a free buffer of the right size pooled
if self._free:
return self._free.popleft()
elif self._poolable_size == 0:
return io.BytesIO()
else:
# we are out of buffers and will have to block
buf = None
more_memory = threading.Condition(self._lock)
self._waiters.append(more_memory)
# loop over and over until we have a buffer or have reserved
# enough memory to allocate one
while buf is None:
start_wait = time.time()
more_memory.wait(max_time_to_block_ms / 1000.0)
end_wait = time.time()
if self.wait_time:
self.wait_time.record(end_wait - start_wait)
if self._free:
buf = self._free.popleft()
else:
self._waiters.remove(more_memory)
raise Errors.KafkaTimeoutError(
"Failed to allocate memory within the configured"
" max blocking time")
# remove the condition for this thread to let the next thread
# in line start getting memory
removed = self._waiters.popleft()
assert removed is more_memory, 'Wrong condition'
# signal any additional waiters if there is more memory left
# over for them
if self._free and self._waiters:
self._waiters[0].notify()
# unlock and return the buffer
return buf
def deallocate(self, buf):
"""
Return buffers to the pool. If they are of the poolable size add them
to the free list, otherwise just mark the memory as free.
Arguments:
buf (io.BytesIO): The buffer to return
"""
with self._lock:
# BytesIO.truncate here makes the pool somewhat pointless
# but we stick with the BufferPool API until migrating to
# bytearray / memoryview. The buffer we return must not
# expose any prior data on read().
buf.truncate(0)
self._free.append(buf)
if self._waiters:
self._waiters[0].notify()
def queued(self):
"""The number of threads blocked waiting on memory."""
with self._lock:
return len(self._waiters)
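
A minimal sketch of the pool's allocate/deallocate cycle using the same sizes the producer defaults to (32MB total, 16KB per pooled buffer).

from kafka.producer.buffer import SimpleBufferPool

pool = SimpleBufferPool(memory=33554432, poolable_size=16384)

buf = pool.allocate(size=16384, max_time_to_block_ms=1000)
buf.write(b'serialized record batch bytes')

# Returning the buffer truncates it and wakes any thread blocked in allocate()
pool.deallocate(buf)
print(pool.queued())  # threads currently blocked waiting for memory -> 0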

View File

@@ -0,0 +1,71 @@
from __future__ import absolute_import
import collections
import threading
from kafka import errors as Errors
from kafka.future import Future
class FutureProduceResult(Future):
def __init__(self, topic_partition):
super(FutureProduceResult, self).__init__()
self.topic_partition = topic_partition
self._latch = threading.Event()
def success(self, value):
ret = super(FutureProduceResult, self).success(value)
self._latch.set()
return ret
def failure(self, error):
ret = super(FutureProduceResult, self).failure(error)
self._latch.set()
return ret
def wait(self, timeout=None):
# wait() on python2.6 returns None instead of the flag value
return self._latch.wait(timeout) or self._latch.is_set()
class FutureRecordMetadata(Future):
def __init__(self, produce_future, relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size):
super(FutureRecordMetadata, self).__init__()
self._produce_future = produce_future
# packing args as a tuple is a minor speed optimization
self.args = (relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size)
produce_future.add_callback(self._produce_success)
produce_future.add_errback(self.failure)
def _produce_success(self, offset_and_timestamp):
offset, produce_timestamp_ms, log_start_offset = offset_and_timestamp
# Unpacking from args tuple is minor speed optimization
(relative_offset, timestamp_ms, checksum,
serialized_key_size, serialized_value_size, serialized_header_size) = self.args
# None is when Broker does not support the API (<0.10) and
# -1 is when the broker is configured for CREATE_TIME timestamps
if produce_timestamp_ms is not None and produce_timestamp_ms != -1:
timestamp_ms = produce_timestamp_ms
if offset != -1 and relative_offset is not None:
offset += relative_offset
tp = self._produce_future.topic_partition
metadata = RecordMetadata(tp[0], tp[1], tp, offset, timestamp_ms, log_start_offset,
checksum, serialized_key_size,
serialized_value_size, serialized_header_size)
self.success(metadata)
def get(self, timeout=None):
if not self.is_done and not self._produce_future.wait(timeout):
raise Errors.KafkaTimeoutError(
"Timeout after waiting for %s secs." % (timeout,))
assert self.is_done
if self.failed():
raise self.exception # pylint: disable-msg=raising-bad-type
return self.value
RecordMetadata = collections.namedtuple(
'RecordMetadata', ['topic', 'partition', 'topic_partition', 'offset', 'timestamp', 'log_start_offset',
'checksum', 'serialized_key_size', 'serialized_value_size', 'serialized_header_size'])
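
FutureRecordMetadata is what KafkaProducer.send() (defined in the next file) hands back to callers; a small sketch of consuming it, with a placeholder bootstrap address and topic.

from kafka import KafkaProducer
from kafka.errors import KafkaTimeoutError

producer = KafkaProducer(bootstrap_servers='localhost:9092')  # placeholder
future = producer.send('demo-topic', b'payload')  # FutureRecordMetadata

try:
    # Blocks on the underlying FutureProduceResult latch until the broker
    # acknowledges the batch or the timeout elapses.
    metadata = future.get(timeout=10)
    print(metadata.topic, metadata.partition, metadata.offset)
except KafkaTimeoutError:
    print('record was not acknowledged within 10 seconds')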

View File

@@ -0,0 +1,749 @@
from __future__ import absolute_import
import atexit
import copy
import logging
import socket
import threading
import time
import weakref
from kafka.vendor import six
import kafka.errors as Errors
from kafka.client_async import KafkaClient, selectors
from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd
from kafka.metrics import MetricConfig, Metrics
from kafka.partitioner.default import DefaultPartitioner
from kafka.producer.future import FutureRecordMetadata, FutureProduceResult
from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator
from kafka.producer.sender import Sender
from kafka.record.default_records import DefaultRecordBatchBuilder
from kafka.record.legacy_records import LegacyRecordBatchBuilder
from kafka.serializer import Serializer
from kafka.structs import TopicPartition
log = logging.getLogger(__name__)
PRODUCER_CLIENT_ID_SEQUENCE = AtomicInteger()
class KafkaProducer(object):
"""A Kafka client that publishes records to the Kafka cluster.
The producer is thread safe and sharing a single producer instance across
threads will generally be faster than having multiple instances.
The producer consists of a pool of buffer space that holds records that
haven't yet been transmitted to the server as well as a background I/O
thread that is responsible for turning these records into requests and
transmitting them to the cluster.
:meth:`~kafka.KafkaProducer.send` is asynchronous. When called it adds the
record to a buffer of pending record sends and immediately returns. This
allows the producer to batch together individual records for efficiency.
The 'acks' config controls the criteria under which requests are considered
complete. The "all" setting will result in blocking on the full commit of
the record, the slowest but most durable setting.
If the request fails, the producer can automatically retry, unless
'retries' is configured to 0. Enabling retries also opens up the
possibility of duplicates (see the documentation on message
delivery semantics for details:
https://kafka.apache.org/documentation.html#semantics
).
The producer maintains buffers of unsent records for each partition. These
buffers are of a size specified by the 'batch_size' config. Making this
larger can result in more batching, but requires more memory (since we will
generally have one of these buffers for each active partition).
By default a buffer is available to send immediately even if there is
additional unused space in the buffer. However if you want to reduce the
number of requests you can set 'linger_ms' to something greater than 0.
This will instruct the producer to wait up to that number of milliseconds
before sending a request in the hope that more records will arrive to fill up
the same batch. This is analogous to Nagle's algorithm in TCP. Note that
records that arrive close together in time will generally batch together
even with linger_ms=0 so under heavy load batching will occur regardless of
the linger configuration; however setting this to something larger than 0
can lead to fewer, more efficient requests when not under maximal load at
the cost of a small amount of latency.
The buffer_memory controls the total amount of memory available to the
producer for buffering. If records are sent faster than they can be
transmitted to the server then this buffer space will be exhausted. When
the buffer space is exhausted additional send calls will block.
The key_serializer and value_serializer instruct how to turn the key and
value objects the user provides into bytes.
Keyword Arguments:
bootstrap_servers: 'host[:port]' string (or list of 'host[:port]'
strings) that the producer should contact to bootstrap initial
cluster metadata. This does not have to be the full node list.
It just needs to have at least one broker that will respond to a
Metadata API Request. Default port is 9092. If no servers are
specified, will default to localhost:9092.
client_id (str): a name for this client. This string is passed in
each request to servers and can be used to identify specific
server-side log entries that correspond to this client.
Default: 'kafka-python-producer-#' (appended with a unique number
per instance)
key_serializer (callable): used to convert user-supplied keys to bytes
If not None, called as f(key), should return bytes. Default: None.
value_serializer (callable): used to convert user-supplied message
values to bytes. If not None, called as f(value), should return
bytes. Default: None.
acks (0, 1, 'all'): The number of acknowledgments the producer requires
the leader to have received before considering a request complete.
This controls the durability of records that are sent. The
following settings are common:
0: Producer will not wait for any acknowledgment from the server.
The message will immediately be added to the socket
buffer and considered sent. No guarantee can be made that the
server has received the record in this case, and the retries
configuration will not take effect (as the client won't
generally know of any failures). The offset given back for each
record will always be set to -1.
1: Wait for leader to write the record to its local log only.
Broker will respond without awaiting full acknowledgement from
all followers. In this case should the leader fail immediately
after acknowledging the record but before the followers have
replicated it then the record will be lost.
all: Wait for the full set of in-sync replicas to write the record.
This guarantees that the record will not be lost as long as at
least one in-sync replica remains alive. This is the strongest
available guarantee.
If unset, defaults to acks=1.
compression_type (str): The compression type for all data generated by
the producer. Valid values are 'gzip', 'snappy', 'lz4', 'zstd' or None.
Compression is of full batches of data, so the efficacy of batching
will also impact the compression ratio (more batching means better
compression). Default: None.
retries (int): Setting a value greater than zero will cause the client
to resend any record whose send fails with a potentially transient
error. Note that this retry is no different than if the client
resent the record upon receiving the error. Allowing retries
without setting max_in_flight_requests_per_connection to 1 will
potentially change the ordering of records because if two batches
are sent to a single partition, and the first fails and is retried
but the second succeeds, then the records in the second batch may
appear first.
Default: 0.
batch_size (int): Requests sent to brokers will contain multiple
batches, one for each partition with data available to be sent.
A small batch size will make batching less common and may reduce
throughput (a batch size of zero will disable batching entirely).
Default: 16384
linger_ms (int): The producer groups together any records that arrive
in between request transmissions into a single batched request.
Normally this occurs only under load when records arrive faster
than they can be sent out. However in some circumstances the client
may want to reduce the number of requests even under moderate load.
This setting accomplishes this by adding a small amount of
artificial delay; that is, rather than immediately sending out a
record the producer will wait for up to the given delay to allow
other records to be sent so that the sends can be batched together.
This can be thought of as analogous to Nagle's algorithm in TCP.
This setting gives the upper bound on the delay for batching: once
we get batch_size worth of records for a partition it will be sent
immediately regardless of this setting, however if we have fewer
than this many bytes accumulated for this partition we will
'linger' for the specified time waiting for more records to show
up. This setting defaults to 0 (i.e. no delay). Setting linger_ms=5
would have the effect of reducing the number of requests sent but
would add up to 5ms of latency to records sent in the absence of
load. Default: 0.
partitioner (callable): Callable used to determine which partition
each message is assigned to. Called (after key serialization):
partitioner(key_bytes, all_partitions, available_partitions).
The default partitioner implementation hashes each non-None key
using the same murmur2 algorithm as the java client so that
messages with the same key are assigned to the same partition.
When a key is None, the message is delivered to a random partition
(filtered to partitions with available leaders only, if possible).
buffer_memory (int): The total bytes of memory the producer should use
to buffer records waiting to be sent to the server. If records are
sent faster than they can be delivered to the server the producer
will block up to max_block_ms, raising an exception on timeout.
In the current implementation, this setting is an approximation.
Default: 33554432 (32MB)
connections_max_idle_ms: Close idle connections after the number of
milliseconds specified by this config. The broker closes idle
connections after connections.max.idle.ms, so this avoids hitting
unexpected socket disconnected errors on the client.
Default: 540000
max_block_ms (int): Number of milliseconds to block during
:meth:`~kafka.KafkaProducer.send` and
:meth:`~kafka.KafkaProducer.partitions_for`. These methods can be
blocked either because the buffer is full or metadata unavailable.
Blocking in the user-supplied serializers or partitioner will not be
counted against this timeout. Default: 60000.
max_request_size (int): The maximum size of a request. This is also
effectively a cap on the maximum record size. Note that the server
has its own cap on record size which may be different from this.
This setting will limit the number of record batches the producer
will send in a single request to avoid sending huge requests.
Default: 1048576.
metadata_max_age_ms (int): The period of time in milliseconds after
which we force a refresh of metadata even if we haven't seen any
partition leadership changes to proactively discover any new
brokers or partitions. Default: 300000
retry_backoff_ms (int): Milliseconds to backoff when retrying on
errors. Default: 100.
request_timeout_ms (int): Client request timeout in milliseconds.
Default: 30000.
receive_buffer_bytes (int): The size of the TCP receive buffer
(SO_RCVBUF) to use when reading data. Default: None (relies on
system defaults). Java client defaults to 32768.
send_buffer_bytes (int): The size of the TCP send buffer
(SO_SNDBUF) to use when sending data. Default: None (relies on
system defaults). Java client defaults to 131072.
socket_options (list): List of tuple-arguments to socket.setsockopt
to apply to broker connection sockets. Default:
[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
reconnect_backoff_ms (int): The amount of time in milliseconds to
wait before attempting to reconnect to a given host.
Default: 50.
reconnect_backoff_max_ms (int): The maximum amount of time in
milliseconds to backoff/wait when reconnecting to a broker that has
repeatedly failed to connect. If provided, the backoff per host
will increase exponentially for each consecutive connection
failure, up to this maximum. Once the maximum is reached,
reconnection attempts will continue periodically with this fixed
rate. To avoid connection storms, a randomization factor of 0.2
will be applied to the backoff resulting in a random range between
20% below and 20% above the computed value. Default: 1000.
max_in_flight_requests_per_connection (int): Requests are pipelined
to kafka brokers up to this number of maximum requests per
broker connection. Note that if this setting is set to be greater
than 1 and there are failed sends, there is a risk of message
re-ordering due to retries (i.e., if retries are enabled).
Default: 5.
security_protocol (str): Protocol used to communicate with brokers.
Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL.
Default: PLAINTEXT.
ssl_context (ssl.SSLContext): pre-configured SSLContext for wrapping
socket connections. If provided, all other ssl_* configurations
will be ignored. Default: None.
ssl_check_hostname (bool): flag to configure whether ssl handshake
should verify that the certificate matches the brokers hostname.
default: true.
ssl_cafile (str): optional filename of ca file to use in certificate
verification. default: none.
ssl_certfile (str): optional filename of file in pem format containing
the client certificate, as well as any ca certificates needed to
establish the certificate's authenticity. default: none.
ssl_keyfile (str): optional filename containing the client private key.
default: none.
ssl_password (str): optional password to be used when loading the
certificate chain. default: none.
ssl_crlfile (str): optional filename containing the CRL to check for
certificate expiration. By default, no CRL check is done. When
providing a file, only the leaf certificate will be checked against
this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+.
default: none.
ssl_ciphers (str): optionally set the available ciphers for ssl
connections. It should be a string in the OpenSSL cipher list
format. If no cipher can be selected (because compile-time options
or other configuration forbids use of all the specified ciphers),
an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers
api_version (tuple): Specify which Kafka API version to use. If set to
None, the client will attempt to infer the broker version by probing
various APIs. Example: (0, 10, 2). Default: None
api_version_auto_timeout_ms (int): number of milliseconds to wait for
broker api version checks in the constructor before raising a
timeout exception. Only applies if api_version is set to None.
metric_reporters (list): A list of classes to use as metrics reporters.
Implementing the AbstractMetricsReporter interface allows plugging
in classes that will be notified of new metric creation. Default: []
metrics_num_samples (int): The number of samples maintained to compute
metrics. Default: 2
metrics_sample_window_ms (int): The maximum age in milliseconds of
samples used to compute metrics. Default: 30000
selector (selectors.BaseSelector): Provide a specific selector
implementation to use for I/O multiplexing.
Default: selectors.DefaultSelector
sasl_mechanism (str): Authentication mechanism when security_protocol
is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are:
PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512.
sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication.
Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms.
sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication.
Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms.
sasl_kerberos_service_name (str): Service name to include in GSSAPI
sasl mechanism handshake. Default: 'kafka'
sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
Note:
Configuration parameters are described in more detail at
https://kafka.apache.org/0100/configuration.html#producerconfigs
"""
DEFAULT_CONFIG = {
'bootstrap_servers': 'localhost',
'client_id': None,
'key_serializer': None,
'value_serializer': None,
'acks': 1,
'bootstrap_topics_filter': set(),
'compression_type': None,
'retries': 0,
'batch_size': 16384,
'linger_ms': 0,
'partitioner': DefaultPartitioner(),
'buffer_memory': 33554432,
'connections_max_idle_ms': 9 * 60 * 1000,
'max_block_ms': 60000,
'max_request_size': 1048576,
'metadata_max_age_ms': 300000,
'retry_backoff_ms': 100,
'request_timeout_ms': 30000,
'receive_buffer_bytes': None,
'send_buffer_bytes': None,
'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)],
'sock_chunk_bytes': 4096, # undocumented experimental option
'sock_chunk_buffer_count': 1000, # undocumented experimental option
'reconnect_backoff_ms': 50,
'reconnect_backoff_max_ms': 1000,
'max_in_flight_requests_per_connection': 5,
'security_protocol': 'PLAINTEXT',
'ssl_context': None,
'ssl_check_hostname': True,
'ssl_cafile': None,
'ssl_certfile': None,
'ssl_keyfile': None,
'ssl_crlfile': None,
'ssl_password': None,
'ssl_ciphers': None,
'api_version': None,
'api_version_auto_timeout_ms': 2000,
'metric_reporters': [],
'metrics_num_samples': 2,
'metrics_sample_window_ms': 30000,
'selector': selectors.DefaultSelector,
'sasl_mechanism': None,
'sasl_plain_username': None,
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None
}
_COMPRESSORS = {
'gzip': (has_gzip, LegacyRecordBatchBuilder.CODEC_GZIP),
'snappy': (has_snappy, LegacyRecordBatchBuilder.CODEC_SNAPPY),
'lz4': (has_lz4, LegacyRecordBatchBuilder.CODEC_LZ4),
'zstd': (has_zstd, DefaultRecordBatchBuilder.CODEC_ZSTD),
None: (lambda: True, LegacyRecordBatchBuilder.CODEC_NONE),
}
def __init__(self, **configs):
log.debug("Starting the Kafka producer") # trace
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
# Only check for extra config keys in top-level class
assert not configs, 'Unrecognized configs: %s' % (configs,)
if self.config['client_id'] is None:
self.config['client_id'] = 'kafka-python-producer-%s' % \
(PRODUCER_CLIENT_ID_SEQUENCE.increment(),)
if self.config['acks'] == 'all':
self.config['acks'] = -1
# api_version was previously a str. accept old format for now
if isinstance(self.config['api_version'], str):
deprecated = self.config['api_version']
if deprecated == 'auto':
self.config['api_version'] = None
else:
self.config['api_version'] = tuple(map(int, deprecated.split('.')))
log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated',
str(self.config['api_version']), deprecated)
# Configure metrics
metrics_tags = {'client-id': self.config['client_id']}
metric_config = MetricConfig(samples=self.config['metrics_num_samples'],
time_window_ms=self.config['metrics_sample_window_ms'],
tags=metrics_tags)
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)
client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)
# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
self.config['api_version'] = client.config['api_version']
if self.config['compression_type'] == 'lz4':
assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers'
if self.config['compression_type'] == 'zstd':
assert self.config['api_version'] >= (2, 1, 0), 'Zstd Requires >= Kafka 2.1.0 Brokers'
# Check compression_type for library support
ct = self.config['compression_type']
if ct not in self._COMPRESSORS:
raise ValueError("Not supported codec: {}".format(ct))
else:
checker, compression_attrs = self._COMPRESSORS[ct]
assert checker(), "Libraries for {} compression codec not found".format(ct)
self.config['compression_attrs'] = compression_attrs
message_version = self._max_usable_produce_magic()
self._accumulator = RecordAccumulator(message_version=message_version, metrics=self._metrics, **self.config)
self._metadata = client.cluster
guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1)
self._sender = Sender(client, self._metadata,
self._accumulator, self._metrics,
guarantee_message_order=guarantee_message_order,
**self.config)
self._sender.daemon = True
self._sender.start()
self._closed = False
self._cleanup = self._cleanup_factory()
atexit.register(self._cleanup)
log.debug("Kafka producer started")
def bootstrap_connected(self):
"""Return True if the bootstrap is connected."""
return self._sender.bootstrap_connected()
def _cleanup_factory(self):
"""Build a cleanup clojure that doesn't increase our ref count"""
_self = weakref.proxy(self)
def wrapper():
try:
_self.close(timeout=0)
except (ReferenceError, AttributeError):
pass
return wrapper
def _unregister_cleanup(self):
if getattr(self, '_cleanup', None):
if hasattr(atexit, 'unregister'):
atexit.unregister(self._cleanup) # pylint: disable=no-member
# py2 requires removing from private attribute...
else:
# ValueError on list.remove() if the exithandler no longer exists
# but that is fine here
try:
atexit._exithandlers.remove( # pylint: disable=no-member
(self._cleanup, (), {}))
except ValueError:
pass
self._cleanup = None
def __del__(self):
# Disable logger during destruction to avoid touching dangling references
class NullLogger(object):
def __getattr__(self, name):
return lambda *args: None
global log
log = NullLogger()
self.close()
def close(self, timeout=None):
"""Close this producer.
Arguments:
timeout (float, optional): timeout in seconds to wait for completion.
"""
# drop our atexit handler now to avoid leaks
self._unregister_cleanup()
if not hasattr(self, '_closed') or self._closed:
log.info('Kafka producer closed')
return
if timeout is None:
# threading.TIMEOUT_MAX is available in Python3.3+
timeout = getattr(threading, 'TIMEOUT_MAX', float('inf'))
if getattr(threading, 'TIMEOUT_MAX', False):
assert 0 <= timeout <= getattr(threading, 'TIMEOUT_MAX')
else:
assert timeout >= 0
log.info("Closing the Kafka producer with %s secs timeout.", timeout)
invoked_from_callback = bool(threading.current_thread() is self._sender)
if timeout > 0:
if invoked_from_callback:
log.warning("Overriding close timeout %s secs to 0 in order to"
" prevent useless blocking due to self-join. This"
" means you have incorrectly invoked close with a"
" non-zero timeout from the producer call-back.",
timeout)
else:
# Try to close gracefully.
if self._sender is not None:
self._sender.initiate_close()
self._sender.join(timeout)
if self._sender is not None and self._sender.is_alive():
log.info("Proceeding to force close the producer since pending"
" requests could not be completed within timeout %s.",
timeout)
self._sender.force_close()
self._metrics.close()
try:
self.config['key_serializer'].close()
except AttributeError:
pass
try:
self.config['value_serializer'].close()
except AttributeError:
pass
self._closed = True
log.debug("The Kafka producer has closed.")
def partitions_for(self, topic):
"""Returns set of all known partitions for the topic."""
max_wait = self.config['max_block_ms'] / 1000.0
return self._wait_on_metadata(topic, max_wait)
def _max_usable_produce_magic(self):
if self.config['api_version'] >= (0, 11):
return 2
elif self.config['api_version'] >= (0, 10):
return 1
else:
return 0
def _estimate_size_in_bytes(self, key, value, headers=[]):
magic = self._max_usable_produce_magic()
if magic == 2:
return DefaultRecordBatchBuilder.estimate_size_in_bytes(
key, value, headers)
else:
return LegacyRecordBatchBuilder.estimate_size_in_bytes(
magic, self.config['compression_type'], key, value)
def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None):
"""Publish a message to a topic.
Arguments:
topic (str): topic where the message will be published
value (optional): message value. Must be type bytes, or be
serializable to bytes via configured value_serializer. If value
is None, key is required and message acts as a 'delete'.
See kafka compaction documentation for more details:
https://kafka.apache.org/documentation.html#compaction
(compaction requires kafka >= 0.8.1)
partition (int, optional): optionally specify a partition. If not
set, the partition will be selected using the configured
'partitioner'.
key (optional): a key to associate with the message. Can be used to
determine which partition to send the message to. If partition
is None (and producer's partitioner config is left as default),
then messages with the same key will be delivered to the same
partition (but if key is None, partition is chosen randomly).
Must be type bytes, or be serializable to bytes via configured
key_serializer.
headers (optional): a list of header key value pairs. List items
are tuples of str key and bytes value.
timestamp_ms (int, optional): epoch milliseconds (from Jan 1 1970 UTC)
to use as the message timestamp. Defaults to current time.
Returns:
FutureRecordMetadata: resolves to RecordMetadata
Raises:
KafkaTimeoutError: if unable to fetch topic metadata, or unable
to obtain memory buffer prior to configured max_block_ms
"""
assert value is not None or self.config['api_version'] >= (0, 8, 1), (
'Null messages require kafka >= 0.8.1')
assert not (value is None and key is None), 'Need at least one: key or value'
key_bytes = value_bytes = None
try:
self._wait_on_metadata(topic, self.config['max_block_ms'] / 1000.0)
key_bytes = self._serialize(
self.config['key_serializer'],
topic, key)
value_bytes = self._serialize(
self.config['value_serializer'],
topic, value)
assert type(key_bytes) in (bytes, bytearray, memoryview, type(None))
assert type(value_bytes) in (bytes, bytearray, memoryview, type(None))
partition = self._partition(topic, partition, key, value,
key_bytes, value_bytes)
if headers is None:
headers = []
assert type(headers) == list
assert all(type(item) == tuple and len(item) == 2 and type(item[0]) == str and type(item[1]) == bytes for item in headers)
message_size = self._estimate_size_in_bytes(key_bytes, value_bytes, headers)
self._ensure_valid_record_size(message_size)
tp = TopicPartition(topic, partition)
log.debug("Sending (key=%r value=%r headers=%r) to %s", key, value, headers, tp)
result = self._accumulator.append(tp, timestamp_ms,
key_bytes, value_bytes, headers,
self.config['max_block_ms'],
estimated_size=message_size)
future, batch_is_full, new_batch_created = result
if batch_is_full or new_batch_created:
log.debug("Waking up the sender since %s is either full or"
" getting a new batch", tp)
self._sender.wakeup()
return future
# handling exceptions and record the errors;
# for API exceptions return them in the future,
# for other exceptions raise directly
except Errors.BrokerResponseError as e:
log.debug("Exception occurred during message send: %s", e)
return FutureRecordMetadata(
FutureProduceResult(TopicPartition(topic, partition)),
-1, None, None,
len(key_bytes) if key_bytes is not None else -1,
len(value_bytes) if value_bytes is not None else -1,
sum(len(h_key.encode("utf-8")) + len(h_value) for h_key, h_value in headers) if headers else -1,
).failure(e)
def flush(self, timeout=None):
"""
Invoking this method makes all buffered records immediately available
to send (even if linger_ms is greater than 0) and blocks on the
completion of the requests associated with these records. The
post-condition of :meth:`~kafka.KafkaProducer.flush` is that any
previously sent record will have completed
(e.g. Future.is_done == True). A request is considered completed when
either it is successfully acknowledged according to the 'acks'
configuration for the producer, or it results in an error.
Other threads can continue sending messages while one thread is blocked
waiting for a flush call to complete; however, no guarantee is made
about the completion of messages sent after the flush call begins.
Arguments:
timeout (float, optional): timeout in seconds to wait for completion.
Raises:
KafkaTimeoutError: failure to flush buffered records within the
provided timeout
"""
log.debug("Flushing accumulated records in producer.") # trace
self._accumulator.begin_flush()
self._sender.wakeup()
self._accumulator.await_flush_completion(timeout=timeout)
def _ensure_valid_record_size(self, size):
"""Validate that the record size isn't too large."""
if size > self.config['max_request_size']:
raise Errors.MessageSizeTooLargeError(
"The message is %d bytes when serialized which is larger than"
" the maximum request size you have configured with the"
" max_request_size configuration" % (size,))
if size > self.config['buffer_memory']:
raise Errors.MessageSizeTooLargeError(
"The message is %d bytes when serialized which is larger than"
" the total memory buffer you have configured with the"
" buffer_memory configuration." % (size,))
def _wait_on_metadata(self, topic, max_wait):
"""
Wait for cluster metadata including partitions for the given topic to
be available.
Arguments:
topic (str): topic we want metadata for
max_wait (float): maximum time in secs for waiting on the metadata
Returns:
set: partition ids for the topic
Raises:
KafkaTimeoutError: if partitions for topic were not obtained before
specified max_wait timeout
"""
# add topic to metadata topic list if it is not there already.
self._sender.add_topic(topic)
begin = time.time()
elapsed = 0.0
metadata_event = None
while True:
partitions = self._metadata.partitions_for_topic(topic)
if partitions is not None:
return partitions
if not metadata_event:
metadata_event = threading.Event()
log.debug("Requesting metadata update for topic %s", topic)
metadata_event.clear()
future = self._metadata.request_update()
future.add_both(lambda e, *args: e.set(), metadata_event)
self._sender.wakeup()
metadata_event.wait(max_wait - elapsed)
elapsed = time.time() - begin
if not metadata_event.is_set():
raise Errors.KafkaTimeoutError(
"Failed to update metadata after %.1f secs." % (max_wait,))
elif topic in self._metadata.unauthorized_topics:
raise Errors.TopicAuthorizationFailedError(topic)
else:
log.debug("_wait_on_metadata woke after %s secs.", elapsed)
def _serialize(self, f, topic, data):
if not f:
return data
if isinstance(f, Serializer):
return f.serialize(topic, data)
return f(data)
def _partition(self, topic, partition, key, value,
serialized_key, serialized_value):
if partition is not None:
assert partition >= 0
assert partition in self._metadata.partitions_for_topic(topic), 'Unrecognized partition'
return partition
all_partitions = sorted(self._metadata.partitions_for_topic(topic))
available = list(self._metadata.available_partitions_for_topic(topic))
return self.config['partitioner'](serialized_key,
all_partitions,
available)
def metrics(self, raw=False):
"""Get metrics on producer performance.
This is ported from the Java Producer, for details see:
https://kafka.apache.org/documentation/#producer_monitoring
Warning:
This is an unstable interface. It may change in future
releases without warning.
"""
if raw:
return self._metrics.metrics.copy()
metrics = {}
for k, v in six.iteritems(self._metrics.metrics.copy()):
if k.group not in metrics:
metrics[k.group] = {}
if k.name not in metrics[k.group]:
metrics[k.group][k.name] = {}
metrics[k.group][k.name] = v.value()
return metrics
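
A configuration-focused sketch of the producer described above; the bootstrap address, topic name and tuning values are illustrative only.

import json

from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers=['localhost:9092'],
    key_serializer=lambda k: k.encode('utf-8'),
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
    acks='all',      # block completion on the full ISR (most durable)
    linger_ms=5,     # trade up to 5ms of latency for larger batches
    retries=3,       # retry transient broker errors
)

for i in range(10):
    producer.send('demo-topic', key='user-%d' % i, value={'clicks': i})

producer.flush()     # wait for every buffered record to be acknowledged
producer.close()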

View File

@@ -0,0 +1,590 @@
from __future__ import absolute_import
import collections
import copy
import logging
import threading
import time
import kafka.errors as Errors
from kafka.producer.buffer import SimpleBufferPool
from kafka.producer.future import FutureRecordMetadata, FutureProduceResult
from kafka.record.memory_records import MemoryRecordsBuilder
from kafka.structs import TopicPartition
log = logging.getLogger(__name__)
class AtomicInteger(object):
def __init__(self, val=0):
self._lock = threading.Lock()
self._val = val
def increment(self):
with self._lock:
self._val += 1
return self._val
def decrement(self):
with self._lock:
self._val -= 1
return self._val
def get(self):
return self._val
class ProducerBatch(object):
def __init__(self, tp, records, buffer):
self.max_record_size = 0
now = time.time()
self.created = now
self.drained = None
self.attempts = 0
self.last_attempt = now
self.last_append = now
self.records = records
self.topic_partition = tp
self.produce_future = FutureProduceResult(tp)
self._retry = False
self._buffer = buffer # We only save it, we don't write to it
@property
def record_count(self):
return self.records.next_offset()
def try_append(self, timestamp_ms, key, value, headers):
metadata = self.records.append(timestamp_ms, key, value, headers)
if metadata is None:
return None
self.max_record_size = max(self.max_record_size, metadata.size)
self.last_append = time.time()
future = FutureRecordMetadata(self.produce_future, metadata.offset,
metadata.timestamp, metadata.crc,
len(key) if key is not None else -1,
len(value) if value is not None else -1,
sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1)
return future
def done(self, base_offset=None, timestamp_ms=None, exception=None, log_start_offset=None, global_error=None):
level = logging.DEBUG if exception is None else logging.WARNING
log.log(level, "Produced messages to topic-partition %s with base offset"
" %s log start offset %s and error %s.", self.topic_partition, base_offset,
log_start_offset, global_error) # trace
if self.produce_future.is_done:
log.warning('Batch is already closed -- ignoring batch.done()')
return
elif exception is None:
self.produce_future.success((base_offset, timestamp_ms, log_start_offset))
else:
self.produce_future.failure(exception)
def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full):
"""Expire batches if metadata is not available
A batch whose metadata is not available should be expired if one
of the following is true:
* the batch is not in retry AND request timeout has elapsed after
it is ready (full or linger.ms has reached).
* the batch is in retry AND request timeout has elapsed after the
backoff period ended.
"""
now = time.time()
since_append = now - self.last_append
since_ready = now - (self.created + linger_ms / 1000.0)
since_backoff = now - (self.last_attempt + retry_backoff_ms / 1000.0)
timeout = request_timeout_ms / 1000.0
error = None
if not self.in_retry() and is_full and timeout < since_append:
error = "%d seconds have passed since last append" % (since_append,)
elif not self.in_retry() and timeout < since_ready:
error = "%d seconds have passed since batch creation plus linger time" % (since_ready,)
elif self.in_retry() and timeout < since_backoff:
error = "%d seconds have passed since last attempt plus backoff time" % (since_backoff,)
if error:
self.records.close()
self.done(-1, None, Errors.KafkaTimeoutError(
"Batch for %s containing %s record(s) expired: %s" % (
self.topic_partition, self.records.next_offset(), error)))
return True
return False
def in_retry(self):
return self._retry
def set_retry(self):
self._retry = True
def buffer(self):
return self._buffer
def __str__(self):
return 'ProducerBatch(topic_partition=%s, record_count=%d)' % (
self.topic_partition, self.records.next_offset())
class RecordAccumulator(object):
"""
This class maintains a dequeue per TopicPartition that accumulates messages
into MessageSets to be sent to the server.
The accumulator attempts to bound memory use, and append calls will block
when that memory is exhausted.
Keyword Arguments:
batch_size (int): Requests sent to brokers will contain multiple
batches, one for each partition with data available to be sent.
A small batch size will make batching less common and may reduce
throughput (a batch size of zero will disable batching entirely).
Default: 16384
buffer_memory (int): The total bytes of memory the producer should use
to buffer records waiting to be sent to the server. If records are
sent faster than they can be delivered to the server the producer
will block up to max_block_ms, raising an exception on timeout.
In the current implementation, this setting is an approximation.
Default: 33554432 (32MB)
compression_attrs (int): The compression type for all data generated by
the producer. Valid values are gzip(1), snappy(2), lz4(3), or
none(0).
Compression is of full batches of data, so the efficacy of batching
will also impact the compression ratio (more batching means better
compression). Default: None.
linger_ms (int): An artificial delay time to add before declaring a
messageset (that isn't full) ready for sending. This allows
time for more records to arrive. Setting a non-zero linger_ms
will trade off some latency for potentially better throughput
due to more batching (and hence fewer, larger requests).
Default: 0
retry_backoff_ms (int): An artificial delay time to retry the
produce request upon receiving an error. This avoids exhausting
all retries in a short period of time. Default: 100
"""
DEFAULT_CONFIG = {
'buffer_memory': 33554432,
'batch_size': 16384,
'compression_attrs': 0,
'linger_ms': 0,
'retry_backoff_ms': 100,
'message_version': 0,
'metrics': None,
'metric_group_prefix': 'producer-metrics',
}
def __init__(self, **configs):
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
self._closed = False
self._flushes_in_progress = AtomicInteger()
self._appends_in_progress = AtomicInteger()
self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch]
self._tp_locks = {None: threading.Lock()} # TopicPartition: Lock, plus a lock to add entries
self._free = SimpleBufferPool(self.config['buffer_memory'],
self.config['batch_size'],
metrics=self.config['metrics'],
metric_group_prefix=self.config['metric_group_prefix'])
self._incomplete = IncompleteProducerBatches()
# The following variables should only be accessed by the sender thread,
# so we don't need to protect them w/ locking.
self.muted = set()
self._drain_index = 0
def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms,
estimated_size=0):
"""Add a record to the accumulator, return the append result.
The append result will contain the future metadata, and flag for
whether the appended batch is full or a new batch is created
Arguments:
tp (TopicPartition): The topic/partition to which this record is
being sent
timestamp_ms (int): The timestamp of the record (epoch ms)
key (bytes): The key for the record
value (bytes): The value for the record
headers (List[Tuple[str, bytes]]): The header fields for the record
max_time_to_block_ms (int): The maximum time in milliseconds to
block for buffer memory to be available
Returns:
tuple: (future, batch_is_full, new_batch_created)
"""
assert isinstance(tp, TopicPartition), 'not TopicPartition'
assert not self._closed, 'RecordAccumulator is closed'
# We keep track of the number of appending thread to make sure we do
# not miss batches in abortIncompleteBatches().
self._appends_in_progress.increment()
try:
if tp not in self._tp_locks:
with self._tp_locks[None]:
if tp not in self._tp_locks:
self._tp_locks[tp] = threading.Lock()
with self._tp_locks[tp]:
# check if we have an in-progress batch
dq = self._batches[tp]
if dq:
last = dq[-1]
future = last.try_append(timestamp_ms, key, value, headers)
if future is not None:
batch_is_full = len(dq) > 1 or last.records.is_full()
return future, batch_is_full, False
size = max(self.config['batch_size'], estimated_size)
log.debug("Allocating a new %d byte message buffer for %s", size, tp) # trace
buf = self._free.allocate(size, max_time_to_block_ms)
with self._tp_locks[tp]:
# Need to check if producer is closed again after grabbing the
# dequeue lock.
assert not self._closed, 'RecordAccumulator is closed'
if dq:
last = dq[-1]
future = last.try_append(timestamp_ms, key, value, headers)
if future is not None:
# Somebody else found us a batch, return the one we
# waited for! Hopefully this doesn't happen often...
self._free.deallocate(buf)
batch_is_full = len(dq) > 1 or last.records.is_full()
return future, batch_is_full, False
records = MemoryRecordsBuilder(
self.config['message_version'],
self.config['compression_attrs'],
self.config['batch_size']
)
batch = ProducerBatch(tp, records, buf)
future = batch.try_append(timestamp_ms, key, value, headers)
if not future:
                    raise Exception('Failed to append record to a newly created batch')
dq.append(batch)
self._incomplete.add(batch)
batch_is_full = len(dq) > 1 or batch.records.is_full()
return future, batch_is_full, True
finally:
self._appends_in_progress.decrement()
def abort_expired_batches(self, request_timeout_ms, cluster):
"""Abort the batches that have been sitting in RecordAccumulator for
more than the configured request_timeout due to metadata being
unavailable.
Arguments:
request_timeout_ms (int): milliseconds to timeout
cluster (ClusterMetadata): current metadata for kafka cluster
Returns:
list of ProducerBatch that were expired
"""
expired_batches = []
to_remove = []
count = 0
for tp in list(self._batches.keys()):
assert tp in self._tp_locks, 'TopicPartition not in locks dict'
            # We only check whether the batch should be expired if the
            # partition does not have a batch in flight. This avoids expiring
            # later batches while an earlier batch is still in progress. The
            # protection only takes effect when the user sets
            # max.in.flight.requests.per.connection=1. Otherwise the
            # expiration order is not guaranteed.
if tp in self.muted:
continue
with self._tp_locks[tp]:
# iterate over the batches and expire them if they have stayed
# in accumulator for more than request_timeout_ms
dq = self._batches[tp]
for batch in dq:
is_full = bool(bool(batch != dq[-1]) or batch.records.is_full())
# check if the batch is expired
if batch.maybe_expire(request_timeout_ms,
self.config['retry_backoff_ms'],
self.config['linger_ms'],
is_full):
expired_batches.append(batch)
to_remove.append(batch)
count += 1
self.deallocate(batch)
else:
# Stop at the first batch that has not expired.
break
# Python does not allow us to mutate the dq during iteration
# Assuming expired batches are infrequent, this is better than
# creating a new copy of the deque for iteration on every loop
if to_remove:
for batch in to_remove:
dq.remove(batch)
to_remove = []
if expired_batches:
log.warning("Expired %d batches in accumulator", count) # trace
return expired_batches
def reenqueue(self, batch):
"""Re-enqueue the given record batch in the accumulator to retry."""
now = time.time()
batch.attempts += 1
batch.last_attempt = now
batch.last_append = now
batch.set_retry()
assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict'
assert batch.topic_partition in self._batches, 'TopicPartition not in batches'
dq = self._batches[batch.topic_partition]
with self._tp_locks[batch.topic_partition]:
dq.appendleft(batch)
def ready(self, cluster):
"""
        Get a list of nodes whose partitions are ready to be sent, and the
        earliest time at which any non-sendable partition will be ready.
        Also return a flag indicating whether there are any unknown leaders
        for the accumulated partition batches.
A destination node is ready to send if:
* There is at least one partition that is not backing off its send
* and those partitions are not muted (to prevent reordering if
max_in_flight_requests_per_connection is set to 1)
* and any of the following are true:
* The record set is full
* The record set has sat in the accumulator for at least linger_ms
milliseconds
* The accumulator is out of memory and threads are blocking waiting
for data (in this case all partitions are immediately considered
ready).
* The accumulator has been closed
Arguments:
cluster (ClusterMetadata):
Returns:
tuple:
ready_nodes (set): node_ids that have ready batches
next_ready_check (float): secs until next ready after backoff
unknown_leaders_exist (bool): True if metadata refresh needed
"""
ready_nodes = set()
next_ready_check = 9999999.99
unknown_leaders_exist = False
now = time.time()
exhausted = bool(self._free.queued() > 0)
# several threads are accessing self._batches -- to simplify
# concurrent access, we iterate over a snapshot of partitions
# and lock each partition separately as needed
partitions = list(self._batches.keys())
for tp in partitions:
leader = cluster.leader_for_partition(tp)
if leader is None or leader == -1:
unknown_leaders_exist = True
continue
elif leader in ready_nodes:
continue
elif tp in self.muted:
continue
with self._tp_locks[tp]:
dq = self._batches[tp]
if not dq:
continue
batch = dq[0]
retry_backoff = self.config['retry_backoff_ms'] / 1000.0
linger = self.config['linger_ms'] / 1000.0
backing_off = bool(batch.attempts > 0 and
batch.last_attempt + retry_backoff > now)
waited_time = now - batch.last_attempt
time_to_wait = retry_backoff if backing_off else linger
time_left = max(time_to_wait - waited_time, 0)
full = bool(len(dq) > 1 or batch.records.is_full())
expired = bool(waited_time >= time_to_wait)
sendable = (full or expired or exhausted or self._closed or
self._flush_in_progress())
if sendable and not backing_off:
ready_nodes.add(leader)
else:
# Note that this results in a conservative estimate since
# an un-sendable partition may have a leader that will
# later be found to have sendable data. However, this is
# good enough since we'll just wake up and then sleep again
# for the remaining time.
next_ready_check = min(time_left, next_ready_check)
return ready_nodes, next_ready_check, unknown_leaders_exist
def has_unsent(self):
"""Return whether there is any unsent record in the accumulator."""
for tp in list(self._batches.keys()):
with self._tp_locks[tp]:
dq = self._batches[tp]
if len(dq):
return True
return False
def drain(self, cluster, nodes, max_size):
"""
Drain all the data for the given nodes and collate them into a list of
batches that will fit within the specified size on a per-node basis.
This method attempts to avoid choosing the same topic-node repeatedly.
Arguments:
cluster (ClusterMetadata): The current cluster metadata
nodes (list): list of node_ids to drain
max_size (int): maximum number of bytes to drain
Returns:
dict: {node_id: list of ProducerBatch} with total size less than the
requested max_size.
"""
if not nodes:
return {}
now = time.time()
batches = {}
for node_id in nodes:
size = 0
partitions = list(cluster.partitions_for_broker(node_id))
ready = []
# to make starvation less likely this loop doesn't start at 0
self._drain_index %= len(partitions)
start = self._drain_index
while True:
tp = partitions[self._drain_index]
if tp in self._batches and tp not in self.muted:
with self._tp_locks[tp]:
dq = self._batches[tp]
if dq:
first = dq[0]
backoff = (
bool(first.attempts > 0) and
bool(first.last_attempt +
self.config['retry_backoff_ms'] / 1000.0
> now)
)
# Only drain the batch if it is not during backoff
if not backoff:
if (size + first.records.size_in_bytes() > max_size
and len(ready) > 0):
# there is a rare case that a single batch
# size is larger than the request size due
# to compression; in this case we will
# still eventually send this batch in a
# single request
break
else:
batch = dq.popleft()
batch.records.close()
size += batch.records.size_in_bytes()
ready.append(batch)
batch.drained = now
self._drain_index += 1
self._drain_index %= len(partitions)
if start == self._drain_index:
break
batches[node_id] = ready
return batches
def deallocate(self, batch):
"""Deallocate the record batch."""
self._incomplete.remove(batch)
self._free.deallocate(batch.buffer())
def _flush_in_progress(self):
"""Are there any threads currently waiting on a flush?"""
return self._flushes_in_progress.get() > 0
def begin_flush(self):
"""
Initiate the flushing of data from the accumulator...this makes all
requests immediately ready
"""
self._flushes_in_progress.increment()
def await_flush_completion(self, timeout=None):
"""
Mark all partitions as ready to send and block until the send is complete
"""
try:
for batch in self._incomplete.all():
log.debug('Waiting on produce to %s',
batch.produce_future.topic_partition)
if not batch.produce_future.wait(timeout=timeout):
raise Errors.KafkaTimeoutError('Timeout waiting for future')
if not batch.produce_future.is_done:
raise Errors.UnknownError('Future not done')
if batch.produce_future.failed():
log.warning(batch.produce_future.exception)
finally:
self._flushes_in_progress.decrement()
def abort_incomplete_batches(self):
"""
This function is only called when sender is closed forcefully. It will fail all the
incomplete batches and return.
"""
        # We need to keep aborting incomplete batches until no thread is still
        # trying to append, in order to:
        # 1. Avoid losing batches.
        # 2. Free up memory in case appending threads are blocked on a full buffer.
# This is a tight loop but should be able to get through very quickly.
while True:
self._abort_batches()
if not self._appends_in_progress.get():
break
        # After this point, no thread will append any messages because they
        # will see the closed flag set. We need to do one last abort after no
        # thread is appending, in case the last appending thread added a new
        # batch just before seeing the flag.
self._abort_batches()
self._batches.clear()
def _abort_batches(self):
"""Go through incomplete batches and abort them."""
error = Errors.IllegalStateError("Producer is closed forcefully.")
for batch in self._incomplete.all():
tp = batch.topic_partition
# Close the batch before aborting
with self._tp_locks[tp]:
batch.records.close()
batch.done(exception=error)
self.deallocate(batch)
def close(self):
"""Close this accumulator and force all the record buffers to be drained."""
self._closed = True
class IncompleteProducerBatches(object):
"""A threadsafe helper class to hold ProducerBatches that haven't been ack'd yet"""
def __init__(self):
self._incomplete = set()
self._lock = threading.Lock()
def add(self, batch):
with self._lock:
return self._incomplete.add(batch)
def remove(self, batch):
with self._lock:
return self._incomplete.remove(batch)
def all(self):
with self._lock:
return list(self._incomplete)

View File

@@ -0,0 +1,517 @@
from __future__ import absolute_import, division
import collections
import copy
import logging
import threading
import time
from kafka.vendor import six
from kafka import errors as Errors
from kafka.metrics.measurable import AnonMeasurable
from kafka.metrics.stats import Avg, Max, Rate
from kafka.protocol.produce import ProduceRequest
from kafka.structs import TopicPartition
from kafka.version import __version__
log = logging.getLogger(__name__)
class Sender(threading.Thread):
"""
The background thread that handles the sending of produce requests to the
Kafka cluster. This thread makes metadata requests to renew its view of the
cluster and then sends produce requests to the appropriate nodes.
"""
DEFAULT_CONFIG = {
'max_request_size': 1048576,
'acks': 1,
'retries': 0,
'request_timeout_ms': 30000,
'guarantee_message_order': False,
'client_id': 'kafka-python-' + __version__,
'api_version': (0, 8, 0),
}
def __init__(self, client, metadata, accumulator, metrics, **configs):
super(Sender, self).__init__()
self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
if key in configs:
self.config[key] = configs.pop(key)
self.name = self.config['client_id'] + '-network-thread'
self._client = client
self._accumulator = accumulator
self._metadata = client.cluster
self._running = True
self._force_close = False
self._topics_to_add = set()
self._sensors = SenderMetrics(metrics, self._client, self._metadata)
def run(self):
"""The main run loop for the sender thread."""
log.debug("Starting Kafka producer I/O thread.")
# main loop, runs until close is called
while self._running:
try:
self.run_once()
except Exception:
log.exception("Uncaught error in kafka producer I/O thread")
log.debug("Beginning shutdown of Kafka producer I/O thread, sending"
" remaining records.")
# okay we stopped accepting requests but there may still be
# requests in the accumulator or waiting for acknowledgment,
# wait until these are completed.
while (not self._force_close
and (self._accumulator.has_unsent()
or self._client.in_flight_request_count() > 0)):
try:
self.run_once()
except Exception:
log.exception("Uncaught error in kafka producer I/O thread")
if self._force_close:
# We need to fail all the incomplete batches and wake up the
# threads waiting on the futures.
self._accumulator.abort_incomplete_batches()
try:
self._client.close()
except Exception:
log.exception("Failed to close network client")
log.debug("Shutdown of Kafka producer I/O thread has completed.")
def run_once(self):
"""Run a single iteration of sending."""
while self._topics_to_add:
self._client.add_topic(self._topics_to_add.pop())
# get the list of partitions with data ready to send
result = self._accumulator.ready(self._metadata)
ready_nodes, next_ready_check_delay, unknown_leaders_exist = result
# if there are any partitions whose leaders are not known yet, force
# metadata update
if unknown_leaders_exist:
log.debug('Unknown leaders exist, requesting metadata update')
self._metadata.request_update()
# remove any nodes we aren't ready to send to
not_ready_timeout = float('inf')
for node in list(ready_nodes):
if not self._client.is_ready(node):
log.debug('Node %s not ready; delaying produce of accumulated batch', node)
self._client.maybe_connect(node, wakeup=False)
ready_nodes.remove(node)
not_ready_timeout = min(not_ready_timeout,
self._client.connection_delay(node))
# create produce requests
batches_by_node = self._accumulator.drain(
self._metadata, ready_nodes, self.config['max_request_size'])
if self.config['guarantee_message_order']:
# Mute all the partitions drained
for batch_list in six.itervalues(batches_by_node):
for batch in batch_list:
self._accumulator.muted.add(batch.topic_partition)
expired_batches = self._accumulator.abort_expired_batches(
self.config['request_timeout_ms'], self._metadata)
for expired_batch in expired_batches:
self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count)
self._sensors.update_produce_request_metrics(batches_by_node)
requests = self._create_produce_requests(batches_by_node)
# If we have any nodes that are ready to send + have sendable data,
# poll with 0 timeout so this can immediately loop and try sending more
# data. Otherwise, the timeout is determined by nodes that have
# partitions with data that isn't yet sendable (e.g. lingering, backing
# off). Note that this specifically does not include nodes with
# sendable data that aren't ready to send since they would cause busy
# looping.
poll_timeout_ms = min(next_ready_check_delay * 1000, not_ready_timeout)
if ready_nodes:
log.debug("Nodes with data ready to send: %s", ready_nodes) # trace
log.debug("Created %d produce requests: %s", len(requests), requests) # trace
poll_timeout_ms = 0
for node_id, request in six.iteritems(requests):
batches = batches_by_node[node_id]
log.debug('Sending Produce Request: %r', request)
(self._client.send(node_id, request, wakeup=False)
.add_callback(
self._handle_produce_response, node_id, time.time(), batches)
.add_errback(
self._failed_produce, batches, node_id))
# if some partitions are already ready to be sent, the select time
# would be 0; otherwise if some partition already has some data
# accumulated but not ready yet, the select time will be the time
# difference between now and its linger expiry time; otherwise the
# select time will be the time difference between now and the
# metadata expiry time
self._client.poll(timeout_ms=poll_timeout_ms)
def initiate_close(self):
"""Start closing the sender (won't complete until all data is sent)."""
self._running = False
self._accumulator.close()
self.wakeup()
def force_close(self):
"""Closes the sender without sending out any pending messages."""
self._force_close = True
self.initiate_close()
def add_topic(self, topic):
        # This is generally called from a separate thread, so it needs to be a
        # thread-safe operation. We assume that checking set membership across
        # threads is ok because self._client._topics should never remove
        # topics for a producer instance, only add them.
if topic not in self._client._topics:
self._topics_to_add.add(topic)
self.wakeup()
def _failed_produce(self, batches, node_id, error):
log.debug("Error sending produce request to node %d: %s", node_id, error) # trace
for batch in batches:
self._complete_batch(batch, error, -1, None)
def _handle_produce_response(self, node_id, send_time, batches, response):
"""Handle a produce response."""
# if we have a response, parse it
log.debug('Parsing produce response: %r', response)
if response:
batches_by_partition = dict([(batch.topic_partition, batch)
for batch in batches])
for topic, partitions in response.topics:
for partition_info in partitions:
global_error = None
log_start_offset = None
if response.API_VERSION < 2:
partition, error_code, offset = partition_info
ts = None
elif 2 <= response.API_VERSION <= 4:
partition, error_code, offset, ts = partition_info
elif 5 <= response.API_VERSION <= 7:
partition, error_code, offset, ts, log_start_offset = partition_info
else:
# the ignored parameter is record_error of type list[(batch_index: int, error_message: str)]
partition, error_code, offset, ts, log_start_offset, _, global_error = partition_info
tp = TopicPartition(topic, partition)
error = Errors.for_code(error_code)
batch = batches_by_partition[tp]
self._complete_batch(batch, error, offset, ts, log_start_offset, global_error)
if response.API_VERSION > 0:
self._sensors.record_throttle_time(response.throttle_time_ms, node=node_id)
else:
# this is the acks = 0 case, just complete all requests
for batch in batches:
self._complete_batch(batch, None, -1, None)
def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_start_offset=None, global_error=None):
"""Complete or retry the given batch of records.
Arguments:
batch (RecordBatch): The record batch
error (Exception): The error (or None if none)
base_offset (int): The base offset assigned to the records if successful
timestamp_ms (int, optional): The timestamp returned by the broker for this batch
log_start_offset (int): The start offset of the log at the time this produce response was created
global_error (str): The summarising error message
"""
# Standardize no-error to None
if error is Errors.NoError:
error = None
if error is not None and self._can_retry(batch, error):
# retry
log.warning("Got error produce response on topic-partition %s,"
" retrying (%d attempts left). Error: %s",
batch.topic_partition,
self.config['retries'] - batch.attempts - 1,
global_error or error)
self._accumulator.reenqueue(batch)
self._sensors.record_retries(batch.topic_partition.topic, batch.record_count)
else:
if error is Errors.TopicAuthorizationFailedError:
error = error(batch.topic_partition.topic)
# tell the user the result of their request
batch.done(base_offset, timestamp_ms, error, log_start_offset, global_error)
self._accumulator.deallocate(batch)
if error is not None:
self._sensors.record_errors(batch.topic_partition.topic, batch.record_count)
if getattr(error, 'invalid_metadata', False):
self._metadata.request_update()
# Unmute the completed partition.
if self.config['guarantee_message_order']:
self._accumulator.muted.remove(batch.topic_partition)
def _can_retry(self, batch, error):
"""
We can retry a send if the error is transient and the number of
attempts taken is fewer than the maximum allowed
"""
return (batch.attempts < self.config['retries']
and getattr(error, 'retriable', False))
def _create_produce_requests(self, collated):
"""
Transfer the record batches into a list of produce requests on a
per-node basis.
Arguments:
collated: {node_id: [RecordBatch]}
Returns:
dict: {node_id: ProduceRequest} (version depends on api_version)
"""
requests = {}
for node_id, batches in six.iteritems(collated):
requests[node_id] = self._produce_request(
node_id, self.config['acks'],
self.config['request_timeout_ms'], batches)
return requests
def _produce_request(self, node_id, acks, timeout, batches):
"""Create a produce request from the given record batches.
Returns:
ProduceRequest (version depends on api_version)
"""
produce_records_by_partition = collections.defaultdict(dict)
for batch in batches:
topic = batch.topic_partition.topic
partition = batch.topic_partition.partition
buf = batch.records.buffer()
produce_records_by_partition[topic][partition] = buf
kwargs = {}
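        # Descriptive note: the broker api_version negotiated at connection
        # time is mapped to the newest ProduceRequest version this client
        # implements, e.g. (0, 10, x) -> v2, (0, 11, x) -> v3 (which adds the
        # transactional_id field), (1, 1, x) -> v5, and (2, 1, x) or newer -> v7.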
if self.config['api_version'] >= (2, 1):
version = 7
elif self.config['api_version'] >= (2, 0):
version = 6
elif self.config['api_version'] >= (1, 1):
version = 5
elif self.config['api_version'] >= (1, 0):
version = 4
elif self.config['api_version'] >= (0, 11):
version = 3
kwargs = dict(transactional_id=None)
elif self.config['api_version'] >= (0, 10):
version = 2
elif self.config['api_version'] == (0, 9):
version = 1
else:
version = 0
return ProduceRequest[version](
required_acks=acks,
timeout=timeout,
topics=[(topic, list(partition_info.items()))
for topic, partition_info
in six.iteritems(produce_records_by_partition)],
**kwargs
)
def wakeup(self):
"""Wake up the selector associated with this send thread."""
self._client.wakeup()
def bootstrap_connected(self):
return self._client.bootstrap_connected()
class SenderMetrics(object):
def __init__(self, metrics, client, metadata):
self.metrics = metrics
self._client = client
self._metadata = metadata
sensor_name = 'batch-size'
self.batch_size_sensor = self.metrics.sensor(sensor_name)
self.add_metric('batch-size-avg', Avg(),
sensor_name=sensor_name,
description='The average number of bytes sent per partition per-request.')
self.add_metric('batch-size-max', Max(),
sensor_name=sensor_name,
description='The max number of bytes sent per partition per-request.')
sensor_name = 'compression-rate'
self.compression_rate_sensor = self.metrics.sensor(sensor_name)
self.add_metric('compression-rate-avg', Avg(),
sensor_name=sensor_name,
description='The average compression rate of record batches.')
sensor_name = 'queue-time'
self.queue_time_sensor = self.metrics.sensor(sensor_name)
self.add_metric('record-queue-time-avg', Avg(),
sensor_name=sensor_name,
description='The average time in ms record batches spent in the record accumulator.')
self.add_metric('record-queue-time-max', Max(),
sensor_name=sensor_name,
description='The maximum time in ms record batches spent in the record accumulator.')
sensor_name = 'produce-throttle-time'
self.produce_throttle_time_sensor = self.metrics.sensor(sensor_name)
self.add_metric('produce-throttle-time-avg', Avg(),
sensor_name=sensor_name,
description='The average throttle time in ms')
self.add_metric('produce-throttle-time-max', Max(),
sensor_name=sensor_name,
description='The maximum throttle time in ms')
sensor_name = 'records-per-request'
self.records_per_request_sensor = self.metrics.sensor(sensor_name)
self.add_metric('record-send-rate', Rate(),
sensor_name=sensor_name,
description='The average number of records sent per second.')
self.add_metric('records-per-request-avg', Avg(),
sensor_name=sensor_name,
description='The average number of records per request.')
sensor_name = 'bytes'
self.byte_rate_sensor = self.metrics.sensor(sensor_name)
self.add_metric('byte-rate', Rate(),
sensor_name=sensor_name,
description='The average number of bytes sent per second.')
sensor_name = 'record-retries'
self.retry_sensor = self.metrics.sensor(sensor_name)
self.add_metric('record-retry-rate', Rate(),
sensor_name=sensor_name,
description='The average per-second number of retried record sends')
sensor_name = 'errors'
self.error_sensor = self.metrics.sensor(sensor_name)
self.add_metric('record-error-rate', Rate(),
sensor_name=sensor_name,
description='The average per-second number of record sends that resulted in errors')
sensor_name = 'record-size-max'
self.max_record_size_sensor = self.metrics.sensor(sensor_name)
self.add_metric('record-size-max', Max(),
sensor_name=sensor_name,
description='The maximum record size across all batches')
self.add_metric('record-size-avg', Avg(),
sensor_name=sensor_name,
description='The average maximum record size per batch')
self.add_metric('requests-in-flight',
AnonMeasurable(lambda *_: self._client.in_flight_request_count()),
description='The current number of in-flight requests awaiting a response.')
self.add_metric('metadata-age',
AnonMeasurable(lambda _, now: (now - self._metadata._last_successful_refresh_ms) / 1000),
description='The age in seconds of the current producer metadata being used.')
def add_metric(self, metric_name, measurable, group_name='producer-metrics',
description=None, tags=None,
sensor_name=None):
m = self.metrics
metric = m.metric_name(metric_name, group_name, description, tags)
if sensor_name:
sensor = m.sensor(sensor_name)
sensor.add(metric, measurable)
else:
m.add_metric(metric, measurable)
def maybe_register_topic_metrics(self, topic):
def sensor_name(name):
return 'topic.{0}.{1}'.format(topic, name)
# if one sensor of the metrics has been registered for the topic,
# then all other sensors should have been registered; and vice versa
if not self.metrics.get_sensor(sensor_name('records-per-batch')):
self.add_metric('record-send-rate', Rate(),
sensor_name=sensor_name('records-per-batch'),
group_name='producer-topic-metrics.' + topic,
                            description='Records sent per second for topic ' + topic)
self.add_metric('byte-rate', Rate(),
sensor_name=sensor_name('bytes'),
group_name='producer-topic-metrics.' + topic,
description='Bytes per second for topic ' + topic)
self.add_metric('compression-rate', Avg(),
sensor_name=sensor_name('compression-rate'),
group_name='producer-topic-metrics.' + topic,
description='Average Compression ratio for topic ' + topic)
self.add_metric('record-retry-rate', Rate(),
sensor_name=sensor_name('record-retries'),
group_name='producer-topic-metrics.' + topic,
description='Record retries per second for topic ' + topic)
self.add_metric('record-error-rate', Rate(),
sensor_name=sensor_name('record-errors'),
group_name='producer-topic-metrics.' + topic,
description='Record errors per second for topic ' + topic)
def update_produce_request_metrics(self, batches_map):
for node_batch in batches_map.values():
records = 0
total_bytes = 0
for batch in node_batch:
# register all per-topic metrics at once
topic = batch.topic_partition.topic
self.maybe_register_topic_metrics(topic)
# per-topic record send rate
topic_records_count = self.metrics.get_sensor(
'topic.' + topic + '.records-per-batch')
topic_records_count.record(batch.record_count)
# per-topic bytes send rate
topic_byte_rate = self.metrics.get_sensor(
'topic.' + topic + '.bytes')
topic_byte_rate.record(batch.records.size_in_bytes())
# per-topic compression rate
topic_compression_rate = self.metrics.get_sensor(
'topic.' + topic + '.compression-rate')
topic_compression_rate.record(batch.records.compression_rate())
# global metrics
self.batch_size_sensor.record(batch.records.size_in_bytes())
if batch.drained:
self.queue_time_sensor.record(batch.drained - batch.created)
self.compression_rate_sensor.record(batch.records.compression_rate())
self.max_record_size_sensor.record(batch.max_record_size)
records += batch.record_count
total_bytes += batch.records.size_in_bytes()
self.records_per_request_sensor.record(records)
self.byte_rate_sensor.record(total_bytes)
def record_retries(self, topic, count):
self.retry_sensor.record(count)
sensor = self.metrics.get_sensor('topic.' + topic + '.record-retries')
if sensor:
sensor.record(count)
def record_errors(self, topic, count):
self.error_sensor.record(count)
sensor = self.metrics.get_sensor('topic.' + topic + '.record-errors')
if sensor:
sensor.record(count)
def record_throttle_time(self, throttle_time_ms, node=None):
self.produce_throttle_time_sensor.record(throttle_time_ms)

View File

@@ -0,0 +1,46 @@
from __future__ import absolute_import
API_KEYS = {
0: 'Produce',
1: 'Fetch',
2: 'ListOffsets',
3: 'Metadata',
4: 'LeaderAndIsr',
5: 'StopReplica',
6: 'UpdateMetadata',
7: 'ControlledShutdown',
8: 'OffsetCommit',
9: 'OffsetFetch',
10: 'FindCoordinator',
11: 'JoinGroup',
12: 'Heartbeat',
13: 'LeaveGroup',
14: 'SyncGroup',
15: 'DescribeGroups',
16: 'ListGroups',
17: 'SaslHandshake',
18: 'ApiVersions',
19: 'CreateTopics',
20: 'DeleteTopics',
21: 'DeleteRecords',
22: 'InitProducerId',
23: 'OffsetForLeaderEpoch',
24: 'AddPartitionsToTxn',
25: 'AddOffsetsToTxn',
26: 'EndTxn',
27: 'WriteTxnMarkers',
28: 'TxnOffsetCommit',
29: 'DescribeAcls',
30: 'CreateAcls',
31: 'DeleteAcls',
32: 'DescribeConfigs',
33: 'AlterConfigs',
36: 'SaslAuthenticate',
37: 'CreatePartitions',
38: 'CreateDelegationToken',
39: 'RenewDelegationToken',
40: 'ExpireDelegationToken',
41: 'DescribeDelegationToken',
42: 'DeleteGroups',
}
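# Illustrative usage (not used elsewhere in this module): API_KEYS simply maps
# the numeric api_key carried on the wire to a human-readable name, e.g.
#
#     API_KEYS.get(18, 'Unknown')   # -> 'ApiVersions'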

View File

@@ -0,0 +1,19 @@
from __future__ import absolute_import
import abc
class AbstractType(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def encode(cls, value): # pylint: disable=no-self-argument
pass
@abc.abstractmethod
def decode(cls, data): # pylint: disable=no-self-argument
pass
@classmethod
def repr(cls, value):
return repr(value)
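# Illustrative sketch (hypothetical class, not part of this module; assumes the
# struct module is imported): concrete protocol types implement encode/decode
# as classmethods, roughly like
#
#     class Int8Example(AbstractType):
#         @classmethod
#         def encode(cls, value):
#             return struct.pack('>b', value)
#
#         @classmethod
#         def decode(cls, data):
#             return struct.unpack('>b', data.read(1))[0]
#
# which mirrors how the fixed-width integer types in kafka.protocol.types are
# layered on top of this base class.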

View File

@@ -0,0 +1,925 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String
class ApiVersionResponse_v0(Response):
API_KEY = 18
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('api_versions', Array(
('api_key', Int16),
('min_version', Int16),
('max_version', Int16)))
)
class ApiVersionResponse_v1(Response):
API_KEY = 18
API_VERSION = 1
SCHEMA = Schema(
('error_code', Int16),
('api_versions', Array(
('api_key', Int16),
('min_version', Int16),
('max_version', Int16))),
('throttle_time_ms', Int32)
)
class ApiVersionResponse_v2(Response):
API_KEY = 18
API_VERSION = 2
SCHEMA = ApiVersionResponse_v1.SCHEMA
class ApiVersionRequest_v0(Request):
API_KEY = 18
API_VERSION = 0
RESPONSE_TYPE = ApiVersionResponse_v0
SCHEMA = Schema()
class ApiVersionRequest_v1(Request):
API_KEY = 18
API_VERSION = 1
RESPONSE_TYPE = ApiVersionResponse_v1
SCHEMA = ApiVersionRequest_v0.SCHEMA
class ApiVersionRequest_v2(Request):
API_KEY = 18
API_VERSION = 2
RESPONSE_TYPE = ApiVersionResponse_v1
SCHEMA = ApiVersionRequest_v0.SCHEMA
ApiVersionRequest = [
ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2,
]
ApiVersionResponse = [
ApiVersionResponse_v0, ApiVersionResponse_v1, ApiVersionResponse_v2,
]
class CreateTopicsResponse_v0(Response):
API_KEY = 19
API_VERSION = 0
SCHEMA = Schema(
('topic_errors', Array(
('topic', String('utf-8')),
('error_code', Int16)))
)
class CreateTopicsResponse_v1(Response):
API_KEY = 19
API_VERSION = 1
SCHEMA = Schema(
('topic_errors', Array(
('topic', String('utf-8')),
('error_code', Int16),
('error_message', String('utf-8'))))
)
class CreateTopicsResponse_v2(Response):
API_KEY = 19
API_VERSION = 2
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topic_errors', Array(
('topic', String('utf-8')),
('error_code', Int16),
('error_message', String('utf-8'))))
)
class CreateTopicsResponse_v3(Response):
API_KEY = 19
API_VERSION = 3
SCHEMA = CreateTopicsResponse_v2.SCHEMA
class CreateTopicsRequest_v0(Request):
API_KEY = 19
API_VERSION = 0
RESPONSE_TYPE = CreateTopicsResponse_v0
SCHEMA = Schema(
('create_topic_requests', Array(
('topic', String('utf-8')),
('num_partitions', Int32),
('replication_factor', Int16),
('replica_assignment', Array(
('partition_id', Int32),
('replicas', Array(Int32)))),
('configs', Array(
('config_key', String('utf-8')),
('config_value', String('utf-8')))))),
('timeout', Int32)
)
class CreateTopicsRequest_v1(Request):
API_KEY = 19
API_VERSION = 1
RESPONSE_TYPE = CreateTopicsResponse_v1
SCHEMA = Schema(
('create_topic_requests', Array(
('topic', String('utf-8')),
('num_partitions', Int32),
('replication_factor', Int16),
('replica_assignment', Array(
('partition_id', Int32),
('replicas', Array(Int32)))),
('configs', Array(
('config_key', String('utf-8')),
('config_value', String('utf-8')))))),
('timeout', Int32),
('validate_only', Boolean)
)
class CreateTopicsRequest_v2(Request):
API_KEY = 19
API_VERSION = 2
RESPONSE_TYPE = CreateTopicsResponse_v2
SCHEMA = CreateTopicsRequest_v1.SCHEMA
class CreateTopicsRequest_v3(Request):
API_KEY = 19
API_VERSION = 3
RESPONSE_TYPE = CreateTopicsResponse_v3
SCHEMA = CreateTopicsRequest_v1.SCHEMA
CreateTopicsRequest = [
CreateTopicsRequest_v0, CreateTopicsRequest_v1,
CreateTopicsRequest_v2, CreateTopicsRequest_v3,
]
CreateTopicsResponse = [
CreateTopicsResponse_v0, CreateTopicsResponse_v1,
CreateTopicsResponse_v2, CreateTopicsResponse_v3,
]
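# Illustrative sketch: KafkaAdminClient builds these requests roughly as
#
#     request = CreateTopicsRequest[1](
#         create_topic_requests=[('example-topic', 3, 1, [], [])],
#         timeout=30000,
#         validate_only=False)
#
# where 'example-topic', 3 partitions and replication factor 1 are placeholder
# values; the admin client picks the request version based on the broker.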
class DeleteTopicsResponse_v0(Response):
API_KEY = 20
API_VERSION = 0
SCHEMA = Schema(
('topic_error_codes', Array(
('topic', String('utf-8')),
('error_code', Int16)))
)
class DeleteTopicsResponse_v1(Response):
API_KEY = 20
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topic_error_codes', Array(
('topic', String('utf-8')),
('error_code', Int16)))
)
class DeleteTopicsResponse_v2(Response):
API_KEY = 20
API_VERSION = 2
SCHEMA = DeleteTopicsResponse_v1.SCHEMA
class DeleteTopicsResponse_v3(Response):
API_KEY = 20
API_VERSION = 3
SCHEMA = DeleteTopicsResponse_v1.SCHEMA
class DeleteTopicsRequest_v0(Request):
API_KEY = 20
API_VERSION = 0
RESPONSE_TYPE = DeleteTopicsResponse_v0
SCHEMA = Schema(
('topics', Array(String('utf-8'))),
('timeout', Int32)
)
class DeleteTopicsRequest_v1(Request):
API_KEY = 20
API_VERSION = 1
RESPONSE_TYPE = DeleteTopicsResponse_v1
SCHEMA = DeleteTopicsRequest_v0.SCHEMA
class DeleteTopicsRequest_v2(Request):
API_KEY = 20
API_VERSION = 2
RESPONSE_TYPE = DeleteTopicsResponse_v2
SCHEMA = DeleteTopicsRequest_v0.SCHEMA
class DeleteTopicsRequest_v3(Request):
API_KEY = 20
API_VERSION = 3
RESPONSE_TYPE = DeleteTopicsResponse_v3
SCHEMA = DeleteTopicsRequest_v0.SCHEMA
DeleteTopicsRequest = [
DeleteTopicsRequest_v0, DeleteTopicsRequest_v1,
DeleteTopicsRequest_v2, DeleteTopicsRequest_v3,
]
DeleteTopicsResponse = [
DeleteTopicsResponse_v0, DeleteTopicsResponse_v1,
DeleteTopicsResponse_v2, DeleteTopicsResponse_v3,
]
class ListGroupsResponse_v0(Response):
API_KEY = 16
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('groups', Array(
('group', String('utf-8')),
('protocol_type', String('utf-8'))))
)
class ListGroupsResponse_v1(Response):
API_KEY = 16
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('groups', Array(
('group', String('utf-8')),
('protocol_type', String('utf-8'))))
)
class ListGroupsResponse_v2(Response):
API_KEY = 16
API_VERSION = 2
SCHEMA = ListGroupsResponse_v1.SCHEMA
class ListGroupsRequest_v0(Request):
API_KEY = 16
API_VERSION = 0
RESPONSE_TYPE = ListGroupsResponse_v0
SCHEMA = Schema()
class ListGroupsRequest_v1(Request):
API_KEY = 16
API_VERSION = 1
RESPONSE_TYPE = ListGroupsResponse_v1
SCHEMA = ListGroupsRequest_v0.SCHEMA
class ListGroupsRequest_v2(Request):
API_KEY = 16
    API_VERSION = 2
RESPONSE_TYPE = ListGroupsResponse_v2
SCHEMA = ListGroupsRequest_v0.SCHEMA
ListGroupsRequest = [
ListGroupsRequest_v0, ListGroupsRequest_v1,
ListGroupsRequest_v2,
]
ListGroupsResponse = [
ListGroupsResponse_v0, ListGroupsResponse_v1,
ListGroupsResponse_v2,
]
class DescribeGroupsResponse_v0(Response):
API_KEY = 15
API_VERSION = 0
SCHEMA = Schema(
('groups', Array(
('error_code', Int16),
('group', String('utf-8')),
('state', String('utf-8')),
('protocol_type', String('utf-8')),
('protocol', String('utf-8')),
('members', Array(
('member_id', String('utf-8')),
('client_id', String('utf-8')),
('client_host', String('utf-8')),
('member_metadata', Bytes),
('member_assignment', Bytes)))))
)
class DescribeGroupsResponse_v1(Response):
API_KEY = 15
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('groups', Array(
('error_code', Int16),
('group', String('utf-8')),
('state', String('utf-8')),
('protocol_type', String('utf-8')),
('protocol', String('utf-8')),
('members', Array(
('member_id', String('utf-8')),
('client_id', String('utf-8')),
('client_host', String('utf-8')),
('member_metadata', Bytes),
('member_assignment', Bytes)))))
)
class DescribeGroupsResponse_v2(Response):
API_KEY = 15
API_VERSION = 2
SCHEMA = DescribeGroupsResponse_v1.SCHEMA
class DescribeGroupsResponse_v3(Response):
API_KEY = 15
API_VERSION = 3
SCHEMA = Schema(
('throttle_time_ms', Int32),
('groups', Array(
('error_code', Int16),
('group', String('utf-8')),
('state', String('utf-8')),
('protocol_type', String('utf-8')),
('protocol', String('utf-8')),
('members', Array(
('member_id', String('utf-8')),
('client_id', String('utf-8')),
('client_host', String('utf-8')),
('member_metadata', Bytes),
('member_assignment', Bytes)))),
('authorized_operations', Int32))
)
class DescribeGroupsRequest_v0(Request):
API_KEY = 15
API_VERSION = 0
RESPONSE_TYPE = DescribeGroupsResponse_v0
SCHEMA = Schema(
('groups', Array(String('utf-8')))
)
class DescribeGroupsRequest_v1(Request):
API_KEY = 15
API_VERSION = 1
RESPONSE_TYPE = DescribeGroupsResponse_v1
SCHEMA = DescribeGroupsRequest_v0.SCHEMA
class DescribeGroupsRequest_v2(Request):
API_KEY = 15
API_VERSION = 2
RESPONSE_TYPE = DescribeGroupsResponse_v2
SCHEMA = DescribeGroupsRequest_v0.SCHEMA
class DescribeGroupsRequest_v3(Request):
API_KEY = 15
API_VERSION = 3
    RESPONSE_TYPE = DescribeGroupsResponse_v3
SCHEMA = Schema(
('groups', Array(String('utf-8'))),
('include_authorized_operations', Boolean)
)
DescribeGroupsRequest = [
DescribeGroupsRequest_v0, DescribeGroupsRequest_v1,
DescribeGroupsRequest_v2, DescribeGroupsRequest_v3,
]
DescribeGroupsResponse = [
DescribeGroupsResponse_v0, DescribeGroupsResponse_v1,
DescribeGroupsResponse_v2, DescribeGroupsResponse_v3,
]
class SaslHandShakeResponse_v0(Response):
API_KEY = 17
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('enabled_mechanisms', Array(String('utf-8')))
)
class SaslHandShakeResponse_v1(Response):
API_KEY = 17
API_VERSION = 1
SCHEMA = SaslHandShakeResponse_v0.SCHEMA
class SaslHandShakeRequest_v0(Request):
API_KEY = 17
API_VERSION = 0
RESPONSE_TYPE = SaslHandShakeResponse_v0
SCHEMA = Schema(
('mechanism', String('utf-8'))
)
class SaslHandShakeRequest_v1(Request):
API_KEY = 17
API_VERSION = 1
RESPONSE_TYPE = SaslHandShakeResponse_v1
SCHEMA = SaslHandShakeRequest_v0.SCHEMA
SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1]
SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1]
class DescribeAclsResponse_v0(Response):
API_KEY = 29
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('error_message', String('utf-8')),
('resources', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('acls', Array(
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))))
)
class DescribeAclsResponse_v1(Response):
API_KEY = 29
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('error_message', String('utf-8')),
('resources', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('resource_pattern_type', Int8),
('acls', Array(
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))))
)
class DescribeAclsResponse_v2(Response):
API_KEY = 29
API_VERSION = 2
SCHEMA = DescribeAclsResponse_v1.SCHEMA
class DescribeAclsRequest_v0(Request):
API_KEY = 29
API_VERSION = 0
RESPONSE_TYPE = DescribeAclsResponse_v0
SCHEMA = Schema(
('resource_type', Int8),
('resource_name', String('utf-8')),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)
)
class DescribeAclsRequest_v1(Request):
API_KEY = 29
API_VERSION = 1
RESPONSE_TYPE = DescribeAclsResponse_v1
SCHEMA = Schema(
('resource_type', Int8),
('resource_name', String('utf-8')),
('resource_pattern_type_filter', Int8),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)
)
class DescribeAclsRequest_v2(Request):
"""
Enable flexible version
"""
API_KEY = 29
API_VERSION = 2
RESPONSE_TYPE = DescribeAclsResponse_v2
SCHEMA = DescribeAclsRequest_v1.SCHEMA
DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1]
DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1]
class CreateAclsResponse_v0(Response):
API_KEY = 30
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('creation_responses', Array(
('error_code', Int16),
('error_message', String('utf-8'))))
)
class CreateAclsResponse_v1(Response):
API_KEY = 30
API_VERSION = 1
SCHEMA = CreateAclsResponse_v0.SCHEMA
class CreateAclsRequest_v0(Request):
API_KEY = 30
API_VERSION = 0
RESPONSE_TYPE = CreateAclsResponse_v0
SCHEMA = Schema(
('creations', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))
)
class CreateAclsRequest_v1(Request):
API_KEY = 30
API_VERSION = 1
RESPONSE_TYPE = CreateAclsResponse_v1
SCHEMA = Schema(
('creations', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('resource_pattern_type', Int8),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))
)
CreateAclsRequest = [CreateAclsRequest_v0, CreateAclsRequest_v1]
CreateAclsResponse = [CreateAclsResponse_v0, CreateAclsResponse_v1]
class DeleteAclsResponse_v0(Response):
API_KEY = 31
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('filter_responses', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('matching_acls', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8')),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))))
)
class DeleteAclsResponse_v1(Response):
API_KEY = 31
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('filter_responses', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('matching_acls', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8')),
('resource_pattern_type', Int8),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))))
)
class DeleteAclsRequest_v0(Request):
API_KEY = 31
API_VERSION = 0
RESPONSE_TYPE = DeleteAclsResponse_v0
SCHEMA = Schema(
('filters', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))
)
class DeleteAclsRequest_v1(Request):
API_KEY = 31
API_VERSION = 1
RESPONSE_TYPE = DeleteAclsResponse_v1
SCHEMA = Schema(
('filters', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('resource_pattern_type_filter', Int8),
('principal', String('utf-8')),
('host', String('utf-8')),
('operation', Int8),
('permission_type', Int8)))
)
DeleteAclsRequest = [DeleteAclsRequest_v0, DeleteAclsRequest_v1]
DeleteAclsResponse = [DeleteAclsResponse_v0, DeleteAclsResponse_v1]
class AlterConfigsResponse_v0(Response):
API_KEY = 33
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('resources', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8'))))
)
class AlterConfigsResponse_v1(Response):
API_KEY = 33
API_VERSION = 1
SCHEMA = AlterConfigsResponse_v0.SCHEMA
class AlterConfigsRequest_v0(Request):
API_KEY = 33
API_VERSION = 0
RESPONSE_TYPE = AlterConfigsResponse_v0
SCHEMA = Schema(
('resources', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_entries', Array(
('config_name', String('utf-8')),
('config_value', String('utf-8')))))),
('validate_only', Boolean)
)
class AlterConfigsRequest_v1(Request):
API_KEY = 33
API_VERSION = 1
RESPONSE_TYPE = AlterConfigsResponse_v1
SCHEMA = AlterConfigsRequest_v0.SCHEMA
AlterConfigsRequest = [AlterConfigsRequest_v0, AlterConfigsRequest_v1]
AlterConfigsResponse = [AlterConfigsResponse_v0, AlterConfigsResponse_v1]
class DescribeConfigsResponse_v0(Response):
API_KEY = 32
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('resources', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_entries', Array(
('config_names', String('utf-8')),
('config_value', String('utf-8')),
('read_only', Boolean),
('is_default', Boolean),
('is_sensitive', Boolean)))))
)
class DescribeConfigsResponse_v1(Response):
API_KEY = 32
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('resources', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_entries', Array(
('config_names', String('utf-8')),
('config_value', String('utf-8')),
('read_only', Boolean),
('is_default', Boolean),
('is_sensitive', Boolean),
('config_synonyms', Array(
('config_name', String('utf-8')),
('config_value', String('utf-8')),
('config_source', Int8)))))))
)
class DescribeConfigsResponse_v2(Response):
API_KEY = 32
API_VERSION = 2
SCHEMA = Schema(
('throttle_time_ms', Int32),
('resources', Array(
('error_code', Int16),
('error_message', String('utf-8')),
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_entries', Array(
('config_names', String('utf-8')),
('config_value', String('utf-8')),
('read_only', Boolean),
('config_source', Int8),
('is_sensitive', Boolean),
('config_synonyms', Array(
('config_name', String('utf-8')),
('config_value', String('utf-8')),
('config_source', Int8)))))))
)
class DescribeConfigsRequest_v0(Request):
API_KEY = 32
API_VERSION = 0
RESPONSE_TYPE = DescribeConfigsResponse_v0
SCHEMA = Schema(
('resources', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_names', Array(String('utf-8')))))
)
class DescribeConfigsRequest_v1(Request):
API_KEY = 32
API_VERSION = 1
RESPONSE_TYPE = DescribeConfigsResponse_v1
SCHEMA = Schema(
('resources', Array(
('resource_type', Int8),
('resource_name', String('utf-8')),
('config_names', Array(String('utf-8'))))),
('include_synonyms', Boolean)
)
class DescribeConfigsRequest_v2(Request):
API_KEY = 32
API_VERSION = 2
RESPONSE_TYPE = DescribeConfigsResponse_v2
SCHEMA = DescribeConfigsRequest_v1.SCHEMA
DescribeConfigsRequest = [
DescribeConfigsRequest_v0, DescribeConfigsRequest_v1,
DescribeConfigsRequest_v2,
]
DescribeConfigsResponse = [
DescribeConfigsResponse_v0, DescribeConfigsResponse_v1,
DescribeConfigsResponse_v2,
]
class SaslAuthenticateResponse_v0(Response):
API_KEY = 36
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('error_message', String('utf-8')),
('sasl_auth_bytes', Bytes)
)
class SaslAuthenticateResponse_v1(Response):
API_KEY = 36
API_VERSION = 1
SCHEMA = Schema(
('error_code', Int16),
('error_message', String('utf-8')),
('sasl_auth_bytes', Bytes),
('session_lifetime_ms', Int64)
)
class SaslAuthenticateRequest_v0(Request):
API_KEY = 36
API_VERSION = 0
RESPONSE_TYPE = SaslAuthenticateResponse_v0
SCHEMA = Schema(
('sasl_auth_bytes', Bytes)
)
class SaslAuthenticateRequest_v1(Request):
API_KEY = 36
API_VERSION = 1
RESPONSE_TYPE = SaslAuthenticateResponse_v1
SCHEMA = SaslAuthenticateRequest_v0.SCHEMA
SaslAuthenticateRequest = [
SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1,
]
SaslAuthenticateResponse = [
SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1,
]
class CreatePartitionsResponse_v0(Response):
API_KEY = 37
API_VERSION = 0
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topic_errors', Array(
('topic', String('utf-8')),
('error_code', Int16),
('error_message', String('utf-8'))))
)
class CreatePartitionsResponse_v1(Response):
API_KEY = 37
API_VERSION = 1
SCHEMA = CreatePartitionsResponse_v0.SCHEMA
class CreatePartitionsRequest_v0(Request):
API_KEY = 37
API_VERSION = 0
RESPONSE_TYPE = CreatePartitionsResponse_v0
SCHEMA = Schema(
('topic_partitions', Array(
('topic', String('utf-8')),
('new_partitions', Schema(
('count', Int32),
('assignment', Array(Array(Int32))))))),
('timeout', Int32),
('validate_only', Boolean)
)
class CreatePartitionsRequest_v1(Request):
API_KEY = 37
API_VERSION = 1
SCHEMA = CreatePartitionsRequest_v0.SCHEMA
RESPONSE_TYPE = CreatePartitionsResponse_v1
CreatePartitionsRequest = [
CreatePartitionsRequest_v0, CreatePartitionsRequest_v1,
]
CreatePartitionsResponse = [
CreatePartitionsResponse_v0, CreatePartitionsResponse_v1,
]
class DeleteGroupsResponse_v0(Response):
API_KEY = 42
API_VERSION = 0
SCHEMA = Schema(
("throttle_time_ms", Int32),
("results", Array(
("group_id", String("utf-8")),
("error_code", Int16)))
)
class DeleteGroupsResponse_v1(Response):
API_KEY = 42
API_VERSION = 1
SCHEMA = DeleteGroupsResponse_v0.SCHEMA
class DeleteGroupsRequest_v0(Request):
API_KEY = 42
API_VERSION = 0
RESPONSE_TYPE = DeleteGroupsResponse_v0
SCHEMA = Schema(
("groups_names", Array(String("utf-8")))
)
class DeleteGroupsRequest_v1(Request):
API_KEY = 42
API_VERSION = 1
RESPONSE_TYPE = DeleteGroupsResponse_v1
SCHEMA = DeleteGroupsRequest_v0.SCHEMA
DeleteGroupsRequest = [
DeleteGroupsRequest_v0, DeleteGroupsRequest_v1
]
DeleteGroupsResponse = [
DeleteGroupsResponse_v0, DeleteGroupsResponse_v1
]

View File

@@ -0,0 +1,97 @@
from __future__ import absolute_import
import abc
from kafka.protocol.struct import Struct
from kafka.protocol.types import Int16, Int32, String, Schema, Array
class RequestHeader(Struct):
SCHEMA = Schema(
('api_key', Int16),
('api_version', Int16),
('correlation_id', Int32),
('client_id', String('utf-8'))
)
def __init__(self, request, correlation_id=0, client_id='kafka-python'):
super(RequestHeader, self).__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id
)
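# Illustrative sketch: the network layer wraps every outgoing request in a
# header, roughly
#
#     header = RequestHeader(request, correlation_id=123,
#                            client_id='example-client')
#
# where 'request' is any Request subclass instance and the correlation_id /
# client_id values above are placeholders.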
class Request(Struct):
__metaclass__ = abc.ABCMeta
@abc.abstractproperty
def API_KEY(self):
"""Integer identifier for api request"""
pass
@abc.abstractproperty
def API_VERSION(self):
"""Integer of api request version"""
pass
@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""
pass
@abc.abstractproperty
def RESPONSE_TYPE(self):
"""The Response class associated with the api request"""
pass
def expect_response(self):
"""Override this method if an api request does not always generate a response"""
return True
def to_object(self):
return _to_object(self.SCHEMA, self)
class Response(Struct):
__metaclass__ = abc.ABCMeta
@abc.abstractproperty
def API_KEY(self):
"""Integer identifier for api request/response"""
pass
@abc.abstractproperty
def API_VERSION(self):
"""Integer of api request/response version"""
pass
@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the response structure"""
pass
def to_object(self):
return _to_object(self.SCHEMA, self)
def _to_object(schema, data):
obj = {}
for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)):
if isinstance(data, Struct):
val = data.get_item(name)
else:
val = data[idx]
if isinstance(_type, Schema):
obj[name] = _to_object(_type, val)
elif isinstance(_type, Array):
if isinstance(_type.array_of, (Array, Schema)):
obj[name] = [
_to_object(_type.array_of, x)
for x in val
]
else:
obj[name] = val
else:
obj[name] = val
return obj
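# Illustrative usage: to_object() flattens a decoded request/response into
# plain dicts and lists keyed by schema field names, e.g. for an ApiVersions
# response (defined in kafka.protocol.admin):
#
#     resp = ApiVersionResponse_v0.decode(raw_bytes)   # raw_bytes: placeholder
#     resp.to_object()
#     # -> {'error_code': 0, 'api_versions': [{'api_key': ..., ...}, ...]}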

View File

@@ -0,0 +1,255 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String
class OffsetCommitResponse_v0(Response):
API_KEY = 8
API_VERSION = 0
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16)))))
)
class OffsetCommitResponse_v1(Response):
API_KEY = 8
API_VERSION = 1
SCHEMA = OffsetCommitResponse_v0.SCHEMA
class OffsetCommitResponse_v2(Response):
API_KEY = 8
API_VERSION = 2
SCHEMA = OffsetCommitResponse_v1.SCHEMA
class OffsetCommitResponse_v3(Response):
API_KEY = 8
API_VERSION = 3
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16)))))
)
class OffsetCommitRequest_v0(Request):
API_KEY = 8
API_VERSION = 0 # Zookeeper-backed storage
RESPONSE_TYPE = OffsetCommitResponse_v0
SCHEMA = Schema(
('consumer_group', String('utf-8')),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('metadata', String('utf-8'))))))
)
class OffsetCommitRequest_v1(Request):
API_KEY = 8
API_VERSION = 1 # Kafka-backed storage
RESPONSE_TYPE = OffsetCommitResponse_v1
SCHEMA = Schema(
('consumer_group', String('utf-8')),
('consumer_group_generation_id', Int32),
('consumer_id', String('utf-8')),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('timestamp', Int64),
('metadata', String('utf-8'))))))
)
class OffsetCommitRequest_v2(Request):
API_KEY = 8
API_VERSION = 2 # added retention_time, dropped timestamp
RESPONSE_TYPE = OffsetCommitResponse_v2
SCHEMA = Schema(
('consumer_group', String('utf-8')),
('consumer_group_generation_id', Int32),
('consumer_id', String('utf-8')),
('retention_time', Int64),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('metadata', String('utf-8'))))))
)
DEFAULT_GENERATION_ID = -1
DEFAULT_RETENTION_TIME = -1
class OffsetCommitRequest_v3(Request):
API_KEY = 8
API_VERSION = 3
RESPONSE_TYPE = OffsetCommitResponse_v3
SCHEMA = OffsetCommitRequest_v2.SCHEMA
OffsetCommitRequest = [
OffsetCommitRequest_v0, OffsetCommitRequest_v1,
OffsetCommitRequest_v2, OffsetCommitRequest_v3
]
OffsetCommitResponse = [
OffsetCommitResponse_v0, OffsetCommitResponse_v1,
OffsetCommitResponse_v2, OffsetCommitResponse_v3
]
class OffsetFetchResponse_v0(Response):
API_KEY = 9
API_VERSION = 0
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('metadata', String('utf-8')),
('error_code', Int16)))))
)
class OffsetFetchResponse_v1(Response):
API_KEY = 9
API_VERSION = 1
SCHEMA = OffsetFetchResponse_v0.SCHEMA
class OffsetFetchResponse_v2(Response):
# Added in KIP-88
API_KEY = 9
API_VERSION = 2
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('metadata', String('utf-8')),
('error_code', Int16))))),
('error_code', Int16)
)
class OffsetFetchResponse_v3(Response):
API_KEY = 9
API_VERSION = 3
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('metadata', String('utf-8')),
('error_code', Int16))))),
('error_code', Int16)
)
class OffsetFetchRequest_v0(Request):
API_KEY = 9
API_VERSION = 0 # zookeeper-backed storage
RESPONSE_TYPE = OffsetFetchResponse_v0
SCHEMA = Schema(
('consumer_group', String('utf-8')),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(Int32))))
)
class OffsetFetchRequest_v1(Request):
API_KEY = 9
API_VERSION = 1 # kafka-backed storage
RESPONSE_TYPE = OffsetFetchResponse_v1
SCHEMA = OffsetFetchRequest_v0.SCHEMA
class OffsetFetchRequest_v2(Request):
# KIP-88: Allows passing null topics to return offsets for all partitions
# that the consumer group has a stored offset for, even if no consumer in
# the group is currently consuming that partition.
API_KEY = 9
API_VERSION = 2
RESPONSE_TYPE = OffsetFetchResponse_v2
SCHEMA = OffsetFetchRequest_v1.SCHEMA
class OffsetFetchRequest_v3(Request):
API_KEY = 9
API_VERSION = 3
RESPONSE_TYPE = OffsetFetchResponse_v3
SCHEMA = OffsetFetchRequest_v2.SCHEMA
OffsetFetchRequest = [
OffsetFetchRequest_v0, OffsetFetchRequest_v1,
OffsetFetchRequest_v2, OffsetFetchRequest_v3,
]
OffsetFetchResponse = [
OffsetFetchResponse_v0, OffsetFetchResponse_v1,
OffsetFetchResponse_v2, OffsetFetchResponse_v3,
]
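# Illustrative sketch: the consumer coordinator picks a version from the list
# above based on the broker, e.g. fetching committed offsets for one partition:
#
#     request = OffsetFetchRequest[1]('example-group', [('example-topic', [0])])
#
# where the group name, topic and partition number are placeholder values.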
class GroupCoordinatorResponse_v0(Response):
API_KEY = 10
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('coordinator_id', Int32),
('host', String('utf-8')),
('port', Int32)
)
class GroupCoordinatorResponse_v1(Response):
API_KEY = 10
API_VERSION = 1
SCHEMA = Schema(
('error_code', Int16),
('error_message', String('utf-8')),
('coordinator_id', Int32),
('host', String('utf-8')),
('port', Int32)
)
class GroupCoordinatorRequest_v0(Request):
API_KEY = 10
API_VERSION = 0
RESPONSE_TYPE = GroupCoordinatorResponse_v0
SCHEMA = Schema(
('consumer_group', String('utf-8'))
)
class GroupCoordinatorRequest_v1(Request):
API_KEY = 10
API_VERSION = 1
RESPONSE_TYPE = GroupCoordinatorResponse_v1
SCHEMA = Schema(
('coordinator_key', String('utf-8')),
('coordinator_type', Int8)
)
GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1]
GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1]

View File

@@ -0,0 +1,386 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes
class FetchResponse_v0(Response):
API_KEY = 1
API_VERSION = 0
SCHEMA = Schema(
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('message_set', Bytes)))))
)
class FetchResponse_v1(Response):
API_KEY = 1
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('message_set', Bytes)))))
)
class FetchResponse_v2(Response):
API_KEY = 1
API_VERSION = 2
SCHEMA = FetchResponse_v1.SCHEMA # message format changed internally
class FetchResponse_v3(Response):
API_KEY = 1
API_VERSION = 3
SCHEMA = FetchResponse_v2.SCHEMA
class FetchResponse_v4(Response):
API_KEY = 1
API_VERSION = 4
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('last_stable_offset', Int64),
('aborted_transactions', Array(
('producer_id', Int64),
('first_offset', Int64))),
('message_set', Bytes)))))
)
class FetchResponse_v5(Response):
API_KEY = 1
API_VERSION = 5
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('last_stable_offset', Int64),
('log_start_offset', Int64),
('aborted_transactions', Array(
('producer_id', Int64),
('first_offset', Int64))),
('message_set', Bytes)))))
)
class FetchResponse_v6(Response):
"""
Same as FetchResponse_v5. The version number is bumped up to indicate that the client supports KafkaStorageException.
The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5
"""
API_KEY = 1
API_VERSION = 6
SCHEMA = FetchResponse_v5.SCHEMA
class FetchResponse_v7(Response):
"""
Add error_code and session_id to response
"""
API_KEY = 1
API_VERSION = 7
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('session_id', Int32),
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('last_stable_offset', Int64),
('log_start_offset', Int64),
('aborted_transactions', Array(
('producer_id', Int64),
('first_offset', Int64))),
('message_set', Bytes)))))
)
class FetchResponse_v8(Response):
API_KEY = 1
API_VERSION = 8
SCHEMA = FetchResponse_v7.SCHEMA
class FetchResponse_v9(Response):
API_KEY = 1
API_VERSION = 9
SCHEMA = FetchResponse_v7.SCHEMA
class FetchResponse_v10(Response):
API_KEY = 1
API_VERSION = 10
SCHEMA = FetchResponse_v7.SCHEMA
class FetchResponse_v11(Response):
API_KEY = 1
API_VERSION = 11
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('session_id', Int32),
('topics', Array(
('topics', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('highwater_offset', Int64),
('last_stable_offset', Int64),
('log_start_offset', Int64),
('aborted_transactions', Array(
('producer_id', Int64),
('first_offset', Int64))),
('preferred_read_replica', Int32),
('message_set', Bytes)))))
)
class FetchRequest_v0(Request):
API_KEY = 1
API_VERSION = 0
RESPONSE_TYPE = FetchResponse_v0
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('max_bytes', Int32)))))
)
class FetchRequest_v1(Request):
API_KEY = 1
API_VERSION = 1
RESPONSE_TYPE = FetchResponse_v1
SCHEMA = FetchRequest_v0.SCHEMA
class FetchRequest_v2(Request):
API_KEY = 1
API_VERSION = 2
RESPONSE_TYPE = FetchResponse_v2
SCHEMA = FetchRequest_v1.SCHEMA
class FetchRequest_v3(Request):
API_KEY = 1
API_VERSION = 3
RESPONSE_TYPE = FetchResponse_v3
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32), # This new field is only difference from FR_v2
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('max_bytes', Int32)))))
)
class FetchRequest_v4(Request):
# Adds isolation_level field
API_KEY = 1
API_VERSION = 4
RESPONSE_TYPE = FetchResponse_v4
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32),
('isolation_level', Int8),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('offset', Int64),
('max_bytes', Int32)))))
)
class FetchRequest_v5(Request):
# This may only be used in broker-broker api calls
API_KEY = 1
API_VERSION = 5
RESPONSE_TYPE = FetchResponse_v5
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32),
('isolation_level', Int8),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('fetch_offset', Int64),
('log_start_offset', Int64),
('max_bytes', Int32)))))
)
class FetchRequest_v6(Request):
"""
The body of FETCH_REQUEST_V6 is the same as FETCH_REQUEST_V5.
The version number is bumped up to indicate that the client supports KafkaStorageException.
The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5
"""
API_KEY = 1
API_VERSION = 6
RESPONSE_TYPE = FetchResponse_v6
SCHEMA = FetchRequest_v5.SCHEMA
class FetchRequest_v7(Request):
"""
Add incremental fetch requests
"""
API_KEY = 1
API_VERSION = 7
RESPONSE_TYPE = FetchResponse_v7
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32),
('isolation_level', Int8),
('session_id', Int32),
('session_epoch', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('fetch_offset', Int64),
('log_start_offset', Int64),
('max_bytes', Int32))))),
('forgotten_topics_data', Array(
('topic', String),
('partitions', Array(Int32))
)),
)
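# Usage sketch (illustrative only; topic, partition and offset values are
# made-up examples). Per KIP-227 a client opens a fetch session by sending
# session_id=0 with its full partition list; the broker assigns a real
# session_id, later requests bump session_epoch and list only changed
# partitions, and dropped partitions go into forgotten_topics_data.
_full_fetch = FetchRequest_v7(
    -1,         # replica_id: -1 for a regular consumer
    500,        # max_wait_time (ms)
    1,          # min_bytes
    52428800,   # max_bytes
    0,          # isolation_level: 0 = READ_UNCOMMITTED
    0,          # session_id: 0 asks the broker to create a new fetch session
    0,          # session_epoch
    [('example-topic', [(0, 1234, 0, 1048576)])],  # (partition, fetch_offset, log_start_offset, max_bytes)
    [],         # forgotten_topics_data: nothing to forget on a full fetch
)
_wire_body = _full_fetch.encode()  # the RequestHeader is prepended by the parser layer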
class FetchRequest_v8(Request):
"""
    The version number is bumped to indicate that on quota violation brokers send out responses before throttling.
"""
API_KEY = 1
API_VERSION = 8
RESPONSE_TYPE = FetchResponse_v8
SCHEMA = FetchRequest_v7.SCHEMA
class FetchRequest_v9(Request):
"""
adds the current leader epoch (see KIP-320)
"""
API_KEY = 1
API_VERSION = 9
RESPONSE_TYPE = FetchResponse_v9
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32),
('isolation_level', Int8),
('session_id', Int32),
('session_epoch', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('current_leader_epoch', Int32),
('fetch_offset', Int64),
('log_start_offset', Int64),
('max_bytes', Int32))))),
('forgotten_topics_data', Array(
('topic', String),
('partitions', Array(Int32)),
)),
)
class FetchRequest_v10(Request):
"""
    The version number is bumped to indicate ZStandard compression capability (see KIP-110).
"""
API_KEY = 1
API_VERSION = 10
RESPONSE_TYPE = FetchResponse_v10
SCHEMA = FetchRequest_v9.SCHEMA
class FetchRequest_v11(Request):
"""
    Adds rack_id to support reading from follower replicas (see KIP-392).
"""
API_KEY = 1
API_VERSION = 11
RESPONSE_TYPE = FetchResponse_v11
SCHEMA = Schema(
('replica_id', Int32),
('max_wait_time', Int32),
('min_bytes', Int32),
('max_bytes', Int32),
('isolation_level', Int8),
('session_id', Int32),
('session_epoch', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('current_leader_epoch', Int32),
('fetch_offset', Int64),
('log_start_offset', Int64),
('max_bytes', Int32))))),
('forgotten_topics_data', Array(
('topic', String),
('partitions', Array(Int32))
)),
('rack_id', String('utf-8')),
)
FetchRequest = [
FetchRequest_v0, FetchRequest_v1, FetchRequest_v2,
FetchRequest_v3, FetchRequest_v4, FetchRequest_v5,
FetchRequest_v6, FetchRequest_v7, FetchRequest_v8,
FetchRequest_v9, FetchRequest_v10, FetchRequest_v11,
]
FetchResponse = [
FetchResponse_v0, FetchResponse_v1, FetchResponse_v2,
FetchResponse_v3, FetchResponse_v4, FetchResponse_v5,
FetchResponse_v6, FetchResponse_v7, FetchResponse_v8,
FetchResponse_v9, FetchResponse_v10, FetchResponse_v11,
]
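# Usage sketch (illustrative only; values are made-up examples): v11 adds a
# per-partition current_leader_epoch (-1 when unknown) and a trailing rack_id
# so a consumer can advertise its rack and, per KIP-392, be served by a nearby
# follower replica.
_rack_aware_fetch = FetchRequest_v11(
    -1, 500, 1, 52428800, 0,   # replica_id, max_wait_time, min_bytes, max_bytes, isolation_level
    0, 0,                      # session_id, session_epoch
    [('example-topic', [(0, -1, 1234, 0, 1048576)])],  # (partition, current_leader_epoch, fetch_offset, log_start_offset, max_bytes)
    [],                        # forgotten_topics_data
    'example-rack-a',          # rack_id
)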

View File

@@ -0,0 +1,30 @@
class KafkaBytes(bytearray):
def __init__(self, size):
super(KafkaBytes, self).__init__(size)
self._idx = 0
def read(self, nbytes=None):
if nbytes is None:
nbytes = len(self) - self._idx
start = self._idx
self._idx += nbytes
if self._idx > len(self):
self._idx = len(self)
return bytes(self[start:self._idx])
def write(self, data):
start = self._idx
self._idx += len(data)
self[start:self._idx] = data
def seek(self, idx):
self._idx = idx
def tell(self):
return self._idx
def __str__(self):
return 'KafkaBytes(%d)' % len(self)
def __repr__(self):
return str(self)
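# Usage sketch (illustrative only): KafkaBytes is a fixed-size bytearray with a
# file-like cursor, so the parser can fill it incrementally from the socket and
# then rewind and decode it in place.
_buf = KafkaBytes(8)
_buf.write(b'\x00\x00\x00\x04')  # e.g. a 4-byte big-endian size prefix
_buf.write(b'ping')
_buf.seek(0)
assert _buf.read(4) == b'\x00\x00\x00\x04'
assert _buf.read() == b'ping'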

View File

@@ -0,0 +1,230 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.struct import Struct
from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String
class JoinGroupResponse_v0(Response):
API_KEY = 11
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('generation_id', Int32),
('group_protocol', String('utf-8')),
('leader_id', String('utf-8')),
('member_id', String('utf-8')),
('members', Array(
('member_id', String('utf-8')),
('member_metadata', Bytes)))
)
class JoinGroupResponse_v1(Response):
API_KEY = 11
API_VERSION = 1
SCHEMA = JoinGroupResponse_v0.SCHEMA
class JoinGroupResponse_v2(Response):
API_KEY = 11
API_VERSION = 2
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('generation_id', Int32),
('group_protocol', String('utf-8')),
('leader_id', String('utf-8')),
('member_id', String('utf-8')),
('members', Array(
('member_id', String('utf-8')),
('member_metadata', Bytes)))
)
class JoinGroupRequest_v0(Request):
API_KEY = 11
API_VERSION = 0
RESPONSE_TYPE = JoinGroupResponse_v0
SCHEMA = Schema(
('group', String('utf-8')),
('session_timeout', Int32),
('member_id', String('utf-8')),
('protocol_type', String('utf-8')),
('group_protocols', Array(
('protocol_name', String('utf-8')),
('protocol_metadata', Bytes)))
)
UNKNOWN_MEMBER_ID = ''
class JoinGroupRequest_v1(Request):
API_KEY = 11
API_VERSION = 1
RESPONSE_TYPE = JoinGroupResponse_v1
SCHEMA = Schema(
('group', String('utf-8')),
('session_timeout', Int32),
('rebalance_timeout', Int32),
('member_id', String('utf-8')),
('protocol_type', String('utf-8')),
('group_protocols', Array(
('protocol_name', String('utf-8')),
('protocol_metadata', Bytes)))
)
UNKNOWN_MEMBER_ID = ''
class JoinGroupRequest_v2(Request):
API_KEY = 11
API_VERSION = 2
RESPONSE_TYPE = JoinGroupResponse_v2
SCHEMA = JoinGroupRequest_v1.SCHEMA
UNKNOWN_MEMBER_ID = ''
JoinGroupRequest = [
JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2
]
JoinGroupResponse = [
JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2
]
class ProtocolMetadata(Struct):
SCHEMA = Schema(
('version', Int16),
('subscription', Array(String('utf-8'))), # topics list
('user_data', Bytes)
)
class SyncGroupResponse_v0(Response):
API_KEY = 14
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16),
('member_assignment', Bytes)
)
class SyncGroupResponse_v1(Response):
API_KEY = 14
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16),
('member_assignment', Bytes)
)
class SyncGroupRequest_v0(Request):
API_KEY = 14
API_VERSION = 0
RESPONSE_TYPE = SyncGroupResponse_v0
SCHEMA = Schema(
('group', String('utf-8')),
('generation_id', Int32),
('member_id', String('utf-8')),
('group_assignment', Array(
('member_id', String('utf-8')),
('member_metadata', Bytes)))
)
class SyncGroupRequest_v1(Request):
API_KEY = 14
API_VERSION = 1
RESPONSE_TYPE = SyncGroupResponse_v1
SCHEMA = SyncGroupRequest_v0.SCHEMA
SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1]
SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1]
class MemberAssignment(Struct):
SCHEMA = Schema(
('version', Int16),
('assignment', Array(
('topic', String('utf-8')),
('partitions', Array(Int32)))),
('user_data', Bytes)
)
class HeartbeatResponse_v0(Response):
API_KEY = 12
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16)
)
class HeartbeatResponse_v1(Response):
API_KEY = 12
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16)
)
class HeartbeatRequest_v0(Request):
API_KEY = 12
API_VERSION = 0
RESPONSE_TYPE = HeartbeatResponse_v0
SCHEMA = Schema(
('group', String('utf-8')),
('generation_id', Int32),
('member_id', String('utf-8'))
)
class HeartbeatRequest_v1(Request):
API_KEY = 12
API_VERSION = 1
RESPONSE_TYPE = HeartbeatResponse_v1
SCHEMA = HeartbeatRequest_v0.SCHEMA
HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1]
HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1]
class LeaveGroupResponse_v0(Response):
API_KEY = 13
API_VERSION = 0
SCHEMA = Schema(
('error_code', Int16)
)
class LeaveGroupResponse_v1(Response):
API_KEY = 13
API_VERSION = 1
SCHEMA = Schema(
('throttle_time_ms', Int32),
('error_code', Int16)
)
class LeaveGroupRequest_v0(Request):
API_KEY = 13
API_VERSION = 0
RESPONSE_TYPE = LeaveGroupResponse_v0
SCHEMA = Schema(
('group', String('utf-8')),
('member_id', String('utf-8'))
)
class LeaveGroupRequest_v1(Request):
API_KEY = 13
API_VERSION = 1
RESPONSE_TYPE = LeaveGroupResponse_v1
SCHEMA = LeaveGroupRequest_v0.SCHEMA
LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1]
LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1]
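# Usage sketch (illustrative only; topic names are made-up examples):
# ProtocolMetadata is what each member sends as JoinGroup protocol_metadata
# (its subscription), and MemberAssignment is what the leader distributes via
# SyncGroup. Both are plain Structs, so they round-trip through encode()/decode().
_subscription = ProtocolMetadata(0, ['example-topic'], b'')
_assignment = MemberAssignment(0, [('example-topic', [0, 1, 2])], b'')
assert MemberAssignment.decode(_assignment.encode()).assignment == [('example-topic', [0, 1, 2])]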

View File

@@ -0,0 +1,216 @@
from __future__ import absolute_import
import io
import time
from kafka.codec import (has_gzip, has_snappy, has_lz4, has_zstd,
gzip_decode, snappy_decode, zstd_decode,
lz4_decode, lz4_decode_old_kafka)
from kafka.protocol.frame import KafkaBytes
from kafka.protocol.struct import Struct
from kafka.protocol.types import (
Int8, Int32, Int64, Bytes, Schema, AbstractType
)
from kafka.util import crc32, WeakMethod
class Message(Struct):
SCHEMAS = [
Schema(
('crc', Int32),
('magic', Int8),
('attributes', Int8),
('key', Bytes),
('value', Bytes)),
Schema(
('crc', Int32),
('magic', Int8),
('attributes', Int8),
('timestamp', Int64),
('key', Bytes),
('value', Bytes)),
]
SCHEMA = SCHEMAS[1]
CODEC_MASK = 0x07
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_ZSTD = 0x04
TIMESTAMP_TYPE_MASK = 0x08
HEADER_SIZE = 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2)
def __init__(self, value, key=None, magic=0, attributes=0, crc=0,
timestamp=None):
assert value is None or isinstance(value, bytes), 'value must be bytes'
assert key is None or isinstance(key, bytes), 'key must be bytes'
assert magic > 0 or timestamp is None, 'timestamp not supported in v0'
# Default timestamp to now for v1 messages
if magic > 0 and timestamp is None:
timestamp = int(time.time() * 1000)
self.timestamp = timestamp
self.crc = crc
self._validated_crc = None
self.magic = magic
self.attributes = attributes
self.key = key
self.value = value
self.encode = WeakMethod(self._encode_self)
@property
def timestamp_type(self):
"""0 for CreateTime; 1 for LogAppendTime; None if unsupported.
Value is determined by broker; produced messages should always set to 0
Requires Kafka >= 0.10 / message version >= 1
"""
if self.magic == 0:
return None
elif self.attributes & self.TIMESTAMP_TYPE_MASK:
return 1
else:
return 0
def _encode_self(self, recalc_crc=True):
version = self.magic
if version == 1:
fields = (self.crc, self.magic, self.attributes, self.timestamp, self.key, self.value)
elif version == 0:
fields = (self.crc, self.magic, self.attributes, self.key, self.value)
else:
raise ValueError('Unrecognized message version: %s' % (version,))
message = Message.SCHEMAS[version].encode(fields)
if not recalc_crc:
return message
self.crc = crc32(message[4:])
crc_field = self.SCHEMAS[version].fields[0]
return crc_field.encode(self.crc) + message[4:]
@classmethod
def decode(cls, data):
_validated_crc = None
if isinstance(data, bytes):
_validated_crc = crc32(data[4:])
data = io.BytesIO(data)
# Partial decode required to determine message version
base_fields = cls.SCHEMAS[0].fields[0:3]
crc, magic, attributes = [field.decode(data) for field in base_fields]
remaining = cls.SCHEMAS[magic].fields[3:]
fields = [field.decode(data) for field in remaining]
if magic == 1:
timestamp = fields[0]
else:
timestamp = None
msg = cls(fields[-1], key=fields[-2],
magic=magic, attributes=attributes, crc=crc,
timestamp=timestamp)
msg._validated_crc = _validated_crc
return msg
def validate_crc(self):
if self._validated_crc is None:
raw_msg = self._encode_self(recalc_crc=False)
self._validated_crc = crc32(raw_msg[4:])
if self.crc == self._validated_crc:
return True
return False
def is_compressed(self):
return self.attributes & self.CODEC_MASK != 0
def decompress(self):
codec = self.attributes & self.CODEC_MASK
assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4, self.CODEC_ZSTD)
if codec == self.CODEC_GZIP:
assert has_gzip(), 'Gzip decompression unsupported'
raw_bytes = gzip_decode(self.value)
elif codec == self.CODEC_SNAPPY:
assert has_snappy(), 'Snappy decompression unsupported'
raw_bytes = snappy_decode(self.value)
elif codec == self.CODEC_LZ4:
assert has_lz4(), 'LZ4 decompression unsupported'
if self.magic == 0:
raw_bytes = lz4_decode_old_kafka(self.value)
else:
raw_bytes = lz4_decode(self.value)
elif codec == self.CODEC_ZSTD:
assert has_zstd(), "ZSTD decompression unsupported"
raw_bytes = zstd_decode(self.value)
else:
raise Exception('This should be impossible')
return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
def __hash__(self):
return hash(self._encode_self(recalc_crc=False))
class PartialMessage(bytes):
def __repr__(self):
return 'PartialMessage(%s)' % (self,)
class MessageSet(AbstractType):
ITEM = Schema(
('offset', Int64),
('message', Bytes)
)
HEADER_SIZE = 12 # offset + message_size
@classmethod
def encode(cls, items, prepend_size=True):
# RecordAccumulator encodes messagesets internally
if isinstance(items, (io.BytesIO, KafkaBytes)):
size = Int32.decode(items)
if prepend_size:
# rewind and return all the bytes
items.seek(items.tell() - 4)
size += 4
return items.read(size)
encoded_values = []
for (offset, message) in items:
encoded_values.append(Int64.encode(offset))
encoded_values.append(Bytes.encode(message))
encoded = b''.join(encoded_values)
if prepend_size:
return Bytes.encode(encoded)
else:
return encoded
@classmethod
def decode(cls, data, bytes_to_read=None):
"""Compressed messages should pass in bytes_to_read (via message size)
otherwise, we decode from data as Int32
"""
if isinstance(data, bytes):
data = io.BytesIO(data)
if bytes_to_read is None:
bytes_to_read = Int32.decode(data)
# if FetchRequest max_bytes is smaller than the available message set
# the server returns partial data for the final message
# So create an internal buffer to avoid over-reading
raw = io.BytesIO(data.read(bytes_to_read))
items = []
while bytes_to_read:
try:
offset = Int64.decode(raw)
msg_bytes = Bytes.decode(raw)
bytes_to_read -= 8 + 4 + len(msg_bytes)
items.append((offset, len(msg_bytes), Message.decode(msg_bytes)))
except ValueError:
# PartialMessage to signal that max_bytes may be too small
items.append((None, None, PartialMessage()))
break
return items
@classmethod
def repr(cls, messages):
if isinstance(messages, (KafkaBytes, io.BytesIO)):
offset = messages.tell()
decoded = cls.decode(messages)
messages.seek(offset)
messages = decoded
return str([cls.ITEM.repr(m) for m in messages])
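# Usage sketch (illustrative only; key/value are made-up examples): a v1
# Message carries a CRC over everything after the crc field itself, which
# validate_crc() recomputes, and MessageSet frames (offset, message_bytes)
# pairs for the legacy (pre-v2) wire format.
_msg = Message(b'example-value', key=b'example-key', magic=1)
assert Message.decode(_msg.encode()).validate_crc()
_framed = MessageSet.encode([(0, _msg.encode())], prepend_size=False)
assert MessageSet.decode(_framed, bytes_to_read=len(_framed))[0][0] == 0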

View File

@@ -0,0 +1,200 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Boolean, Int16, Int32, Schema, String
class MetadataResponse_v0(Response):
API_KEY = 3
API_VERSION = 0
SCHEMA = Schema(
('brokers', Array(
('node_id', Int32),
('host', String('utf-8')),
('port', Int32))),
('topics', Array(
('error_code', Int16),
('topic', String('utf-8')),
('partitions', Array(
('error_code', Int16),
('partition', Int32),
('leader', Int32),
('replicas', Array(Int32)),
('isr', Array(Int32))))))
)
class MetadataResponse_v1(Response):
API_KEY = 3
API_VERSION = 1
SCHEMA = Schema(
('brokers', Array(
('node_id', Int32),
('host', String('utf-8')),
('port', Int32),
('rack', String('utf-8')))),
('controller_id', Int32),
('topics', Array(
('error_code', Int16),
('topic', String('utf-8')),
('is_internal', Boolean),
('partitions', Array(
('error_code', Int16),
('partition', Int32),
('leader', Int32),
('replicas', Array(Int32)),
('isr', Array(Int32))))))
)
class MetadataResponse_v2(Response):
API_KEY = 3
API_VERSION = 2
SCHEMA = Schema(
('brokers', Array(
('node_id', Int32),
('host', String('utf-8')),
('port', Int32),
('rack', String('utf-8')))),
('cluster_id', String('utf-8')), # <-- Added cluster_id field in v2
('controller_id', Int32),
('topics', Array(
('error_code', Int16),
('topic', String('utf-8')),
('is_internal', Boolean),
('partitions', Array(
('error_code', Int16),
('partition', Int32),
('leader', Int32),
('replicas', Array(Int32)),
('isr', Array(Int32))))))
)
class MetadataResponse_v3(Response):
API_KEY = 3
API_VERSION = 3
SCHEMA = Schema(
('throttle_time_ms', Int32),
('brokers', Array(
('node_id', Int32),
('host', String('utf-8')),
('port', Int32),
('rack', String('utf-8')))),
('cluster_id', String('utf-8')),
('controller_id', Int32),
('topics', Array(
('error_code', Int16),
('topic', String('utf-8')),
('is_internal', Boolean),
('partitions', Array(
('error_code', Int16),
('partition', Int32),
('leader', Int32),
('replicas', Array(Int32)),
('isr', Array(Int32))))))
)
class MetadataResponse_v4(Response):
API_KEY = 3
API_VERSION = 4
SCHEMA = MetadataResponse_v3.SCHEMA
class MetadataResponse_v5(Response):
API_KEY = 3
API_VERSION = 5
SCHEMA = Schema(
('throttle_time_ms', Int32),
('brokers', Array(
('node_id', Int32),
('host', String('utf-8')),
('port', Int32),
('rack', String('utf-8')))),
('cluster_id', String('utf-8')),
('controller_id', Int32),
('topics', Array(
('error_code', Int16),
('topic', String('utf-8')),
('is_internal', Boolean),
('partitions', Array(
('error_code', Int16),
('partition', Int32),
('leader', Int32),
('replicas', Array(Int32)),
('isr', Array(Int32)),
('offline_replicas', Array(Int32))))))
)
class MetadataRequest_v0(Request):
API_KEY = 3
API_VERSION = 0
RESPONSE_TYPE = MetadataResponse_v0
SCHEMA = Schema(
('topics', Array(String('utf-8')))
)
ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics
class MetadataRequest_v1(Request):
API_KEY = 3
API_VERSION = 1
RESPONSE_TYPE = MetadataResponse_v1
SCHEMA = MetadataRequest_v0.SCHEMA
ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics
NO_TOPICS = None # Empty array (len 0) for topics returns no topics
class MetadataRequest_v2(Request):
API_KEY = 3
API_VERSION = 2
RESPONSE_TYPE = MetadataResponse_v2
SCHEMA = MetadataRequest_v1.SCHEMA
ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics
NO_TOPICS = None # Empty array (len 0) for topics returns no topics
class MetadataRequest_v3(Request):
API_KEY = 3
API_VERSION = 3
RESPONSE_TYPE = MetadataResponse_v3
SCHEMA = MetadataRequest_v1.SCHEMA
ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics
NO_TOPICS = None # Empty array (len 0) for topics returns no topics
class MetadataRequest_v4(Request):
API_KEY = 3
API_VERSION = 4
RESPONSE_TYPE = MetadataResponse_v4
SCHEMA = Schema(
('topics', Array(String('utf-8'))),
('allow_auto_topic_creation', Boolean)
)
ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics
NO_TOPICS = None # Empty array (len 0) for topics returns no topics
class MetadataRequest_v5(Request):
"""
The v5 metadata request is the same as v4.
An additional field for offline_replicas has been added to the v5 metadata response
"""
API_KEY = 3
API_VERSION = 5
RESPONSE_TYPE = MetadataResponse_v5
SCHEMA = MetadataRequest_v4.SCHEMA
ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics
NO_TOPICS = None # Empty array (len 0) for topics returns no topics
MetadataRequest = [
MetadataRequest_v0, MetadataRequest_v1, MetadataRequest_v2,
MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5
]
MetadataResponse = [
MetadataResponse_v0, MetadataResponse_v1, MetadataResponse_v2,
MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5
]
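# Usage sketch (illustrative only; the topic name is a made-up example): the
# ALL_TOPICS / NO_TOPICS comments above describe the wire-level distinction for
# v1+ between a null topics array (encoded length -1, "all topics") and an
# empty array (length 0, "no topics"); passing None to the encoder produces the
# null array, while v0 treats the empty array as "all topics".
_all_topics_req = MetadataRequest_v1(None)    # null array -> all topics
_no_topics_req = MetadataRequest_v1([])       # empty array -> no topics
_some_topics_req = MetadataRequest_v1(['example-topic'])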

View File

@@ -0,0 +1,194 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String
UNKNOWN_OFFSET = -1
class OffsetResetStrategy(object):
LATEST = -1
EARLIEST = -2
NONE = 0
class OffsetResponse_v0(Response):
API_KEY = 2
API_VERSION = 0
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offsets', Array(Int64))))))
)
class OffsetResponse_v1(Response):
API_KEY = 2
API_VERSION = 1
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('timestamp', Int64),
('offset', Int64)))))
)
class OffsetResponse_v2(Response):
API_KEY = 2
API_VERSION = 2
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('timestamp', Int64),
('offset', Int64)))))
)
class OffsetResponse_v3(Response):
"""
on quota violation, brokers send out responses before throttling
"""
API_KEY = 2
API_VERSION = 3
SCHEMA = OffsetResponse_v2.SCHEMA
class OffsetResponse_v4(Response):
"""
Add leader_epoch to response
"""
API_KEY = 2
API_VERSION = 4
SCHEMA = Schema(
('throttle_time_ms', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('timestamp', Int64),
('offset', Int64),
('leader_epoch', Int32)))))
)
class OffsetResponse_v5(Response):
"""
adds a new error code, OFFSET_NOT_AVAILABLE
"""
API_KEY = 2
API_VERSION = 5
SCHEMA = OffsetResponse_v4.SCHEMA
class OffsetRequest_v0(Request):
API_KEY = 2
API_VERSION = 0
RESPONSE_TYPE = OffsetResponse_v0
SCHEMA = Schema(
('replica_id', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('timestamp', Int64),
('max_offsets', Int32)))))
)
DEFAULTS = {
'replica_id': -1
}
class OffsetRequest_v1(Request):
API_KEY = 2
API_VERSION = 1
RESPONSE_TYPE = OffsetResponse_v1
SCHEMA = Schema(
('replica_id', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('timestamp', Int64)))))
)
DEFAULTS = {
'replica_id': -1
}
class OffsetRequest_v2(Request):
API_KEY = 2
API_VERSION = 2
RESPONSE_TYPE = OffsetResponse_v2
SCHEMA = Schema(
('replica_id', Int32),
('isolation_level', Int8), # <- added isolation_level
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('timestamp', Int64)))))
)
DEFAULTS = {
'replica_id': -1
}
class OffsetRequest_v3(Request):
API_KEY = 2
API_VERSION = 3
RESPONSE_TYPE = OffsetResponse_v3
SCHEMA = OffsetRequest_v2.SCHEMA
DEFAULTS = {
'replica_id': -1
}
class OffsetRequest_v4(Request):
"""
Add current_leader_epoch to request
"""
API_KEY = 2
API_VERSION = 4
RESPONSE_TYPE = OffsetResponse_v4
SCHEMA = Schema(
('replica_id', Int32),
('isolation_level', Int8), # <- added isolation_level
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('current_leader_epoch', Int64),
('timestamp', Int64)))))
)
DEFAULTS = {
'replica_id': -1
}
class OffsetRequest_v5(Request):
API_KEY = 2
API_VERSION = 5
RESPONSE_TYPE = OffsetResponse_v5
SCHEMA = OffsetRequest_v4.SCHEMA
DEFAULTS = {
'replica_id': -1
}
OffsetRequest = [
OffsetRequest_v0, OffsetRequest_v1, OffsetRequest_v2,
OffsetRequest_v3, OffsetRequest_v4, OffsetRequest_v5,
]
OffsetResponse = [
OffsetResponse_v0, OffsetResponse_v1, OffsetResponse_v2,
OffsetResponse_v3, OffsetResponse_v4, OffsetResponse_v5,
]
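# Usage sketch (illustrative only; the topic name is a made-up example): in a
# ListOffsets (Offset) request the per-partition "timestamp" is either an
# epoch-millis value or one of the OffsetResetStrategy sentinels above:
# -1 (LATEST) for the log end offset, -2 (EARLIEST) for the log start offset.
_latest_offsets = OffsetRequest_v1(
    -1,  # replica_id
    [('example-topic', [(0, OffsetResetStrategy.LATEST)])]  # (partition, timestamp)
)
_earliest_offsets = OffsetRequest_v1(
    -1,
    [('example-topic', [(0, OffsetResetStrategy.EARLIEST)])]
)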

View File

@@ -0,0 +1,183 @@
from __future__ import absolute_import
import collections
import logging
import kafka.errors as Errors
from kafka.protocol.api import RequestHeader
from kafka.protocol.commit import GroupCoordinatorResponse
from kafka.protocol.frame import KafkaBytes
from kafka.protocol.types import Int32
from kafka.version import __version__
log = logging.getLogger(__name__)
class KafkaProtocol(object):
"""Manage the kafka network protocol
Use an instance of KafkaProtocol to manage bytes send/recv'd
from a network socket to a broker.
Arguments:
client_id (str): identifier string to be included in each request
api_version (tuple): Optional tuple to specify api_version to use.
Currently only used to check for 0.8.2 protocol quirks, but
may be used for more in the future.
"""
def __init__(self, client_id=None, api_version=None):
if client_id is None:
client_id = self._gen_client_id()
self._client_id = client_id
self._api_version = api_version
self._correlation_id = 0
self._header = KafkaBytes(4)
self._rbuffer = None
self._receiving = False
self.in_flight_requests = collections.deque()
self.bytes_to_send = []
def _next_correlation_id(self):
self._correlation_id = (self._correlation_id + 1) % 2**31
return self._correlation_id
def _gen_client_id(self):
return 'kafka-python' + __version__
def send_request(self, request, correlation_id=None):
"""Encode and queue a kafka api request for sending.
Arguments:
request (object): An un-encoded kafka request.
correlation_id (int, optional): Optionally specify an ID to
correlate requests with responses. If not provided, an ID will
be generated automatically.
Returns:
correlation_id
"""
log.debug('Sending request %s', request)
if correlation_id is None:
correlation_id = self._next_correlation_id()
header = RequestHeader(request,
correlation_id=correlation_id,
client_id=self._client_id)
message = b''.join([header.encode(), request.encode()])
size = Int32.encode(len(message))
data = size + message
self.bytes_to_send.append(data)
if request.expect_response():
ifr = (correlation_id, request)
self.in_flight_requests.append(ifr)
return correlation_id
def send_bytes(self):
"""Retrieve all pending bytes to send on the network"""
data = b''.join(self.bytes_to_send)
self.bytes_to_send = []
return data
def receive_bytes(self, data):
"""Process bytes received from the network.
Arguments:
data (bytes): any length bytes received from a network connection
to a kafka broker.
Returns:
responses (list of (correlation_id, response)): any/all completed
responses, decoded from bytes to python objects.
Raises:
KafkaProtocolError: if the bytes received could not be decoded.
CorrelationIdError: if the response does not match the request
correlation id.
"""
i = 0
n = len(data)
responses = []
while i < n:
            # While not yet "receiving", we are still reading the 4-byte size prefix of the payload
if not self._receiving:
bytes_to_read = min(4 - self._header.tell(), n - i)
self._header.write(data[i:i+bytes_to_read])
i += bytes_to_read
if self._header.tell() == 4:
self._header.seek(0)
nbytes = Int32.decode(self._header)
# reset buffer and switch state to receiving payload bytes
self._rbuffer = KafkaBytes(nbytes)
self._receiving = True
elif self._header.tell() > 4:
raise Errors.KafkaError('this should not happen - are you threading?')
if self._receiving:
total_bytes = len(self._rbuffer)
staged_bytes = self._rbuffer.tell()
bytes_to_read = min(total_bytes - staged_bytes, n - i)
self._rbuffer.write(data[i:i+bytes_to_read])
i += bytes_to_read
staged_bytes = self._rbuffer.tell()
if staged_bytes > total_bytes:
raise Errors.KafkaError('Receive buffer has more bytes than expected?')
if staged_bytes != total_bytes:
break
self._receiving = False
self._rbuffer.seek(0)
resp = self._process_response(self._rbuffer)
responses.append(resp)
self._reset_buffer()
return responses
def _process_response(self, read_buffer):
recv_correlation_id = Int32.decode(read_buffer)
log.debug('Received correlation id: %d', recv_correlation_id)
if not self.in_flight_requests:
raise Errors.CorrelationIdError(
'No in-flight-request found for server response'
' with correlation ID %d'
% (recv_correlation_id,))
(correlation_id, request) = self.in_flight_requests.popleft()
# 0.8.2 quirk
if (recv_correlation_id == 0 and
correlation_id != 0 and
request.RESPONSE_TYPE is GroupCoordinatorResponse[0] and
(self._api_version == (0, 8, 2) or self._api_version is None)):
log.warning('Kafka 0.8.2 quirk -- GroupCoordinatorResponse'
' Correlation ID does not match request. This'
' should go away once at least one topic has been'
' initialized on the broker.')
elif correlation_id != recv_correlation_id:
# return or raise?
raise Errors.CorrelationIdError(
'Correlation IDs do not match: sent %d, recv %d'
% (correlation_id, recv_correlation_id))
# decode response
log.debug('Processing response %s', request.RESPONSE_TYPE.__name__)
try:
response = request.RESPONSE_TYPE.decode(read_buffer)
except ValueError:
read_buffer.seek(0)
buf = read_buffer.read()
log.error('Response %d [ResponseType: %s Request: %s]:'
' Unable to decode %d-byte buffer: %r',
correlation_id, request.RESPONSE_TYPE,
request, len(buf), buf)
raise Errors.KafkaProtocolError('Unable to decode response')
return (correlation_id, response)
def _reset_buffer(self):
self._receiving = False
self._header.seek(0)
self._rbuffer = None
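# Usage sketch (illustrative only; the client id and request are made-up
# examples): the typical flow is send_request() to queue an encoded request,
# send_bytes() to drain the bytes destined for the broker socket, then
# receive_bytes() with whatever the socket returns to collect decoded
# (correlation_id, response) pairs.
from kafka.protocol.metadata import MetadataRequest
_protocol = KafkaProtocol(client_id='example-client')
_cid = _protocol.send_request(MetadataRequest[0]([]))  # v0 metadata request, no topics
_outgoing = _protocol.send_bytes()                     # write these to the socket
# later: _responses = _protocol.receive_bytes(data_read_from_socket)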

View File

@@ -0,0 +1,35 @@
from __future__ import absolute_import
try:
import copyreg # pylint: disable=import-error
except ImportError:
import copy_reg as copyreg # pylint: disable=import-error
import types
def _pickle_method(method):
try:
func_name = method.__func__.__name__
obj = method.__self__
cls = method.__self__.__class__
except AttributeError:
func_name = method.im_func.__name__
obj = method.im_self
cls = method.im_class
return _unpickle_method, (func_name, obj, cls)
def _unpickle_method(func_name, obj, cls):
for cls in cls.mro():
try:
func = cls.__dict__[func_name]
except KeyError:
pass
else:
break
return func.__get__(obj, cls)
# https://bytes.com/topic/python/answers/552476-why-cant-you-pickle-instancemethods
copyreg.pickle(types.MethodType, _pickle_method, _unpickle_method)
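# Usage sketch (illustrative only; the class below is a made-up example): with
# the copyreg hook above registered, bound methods round-trip through pickle,
# which is what this module exists to enable.
import pickle
class _Example(object):
    def ping(self):
        return 'pong'
assert pickle.loads(pickle.dumps(_Example().ping))() == 'pong'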

View File

@@ -0,0 +1,232 @@
from __future__ import absolute_import
from kafka.protocol.api import Request, Response
from kafka.protocol.types import Int16, Int32, Int64, String, Array, Schema, Bytes
class ProduceResponse_v0(Response):
API_KEY = 0
API_VERSION = 0
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offset', Int64)))))
)
class ProduceResponse_v1(Response):
API_KEY = 0
API_VERSION = 1
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offset', Int64))))),
('throttle_time_ms', Int32)
)
class ProduceResponse_v2(Response):
API_KEY = 0
API_VERSION = 2
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offset', Int64),
('timestamp', Int64))))),
('throttle_time_ms', Int32)
)
class ProduceResponse_v3(Response):
API_KEY = 0
API_VERSION = 3
SCHEMA = ProduceResponse_v2.SCHEMA
class ProduceResponse_v4(Response):
"""
The version number is bumped up to indicate that the client supports KafkaStorageException.
The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3
"""
API_KEY = 0
API_VERSION = 4
SCHEMA = ProduceResponse_v3.SCHEMA
class ProduceResponse_v5(Response):
API_KEY = 0
API_VERSION = 5
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offset', Int64),
('timestamp', Int64),
('log_start_offset', Int64))))),
('throttle_time_ms', Int32)
)
class ProduceResponse_v6(Response):
"""
The version number is bumped to indicate that on quota violation brokers send out responses before throttling.
"""
API_KEY = 0
API_VERSION = 6
SCHEMA = ProduceResponse_v5.SCHEMA
class ProduceResponse_v7(Response):
"""
    V7 is bumped to indicate ZStandard compression capability (see KIP-110).
"""
API_KEY = 0
API_VERSION = 7
SCHEMA = ProduceResponse_v6.SCHEMA
class ProduceResponse_v8(Response):
"""
    V8 adds two new per-partition fields: a record_errors list (batch index plus
    per-batch error message) and an error_message (see KIP-467).
"""
API_KEY = 0
API_VERSION = 8
SCHEMA = Schema(
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('error_code', Int16),
('offset', Int64),
('timestamp', Int64),
('log_start_offset', Int64)),
('record_errors', (Array(
('batch_index', Int32),
('batch_index_error_message', String('utf-8'))
))),
('error_message', String('utf-8'))
))),
('throttle_time_ms', Int32)
)
class ProduceRequest(Request):
API_KEY = 0
def expect_response(self):
if self.required_acks == 0: # pylint: disable=no-member
return False
return True
class ProduceRequest_v0(ProduceRequest):
API_VERSION = 0
RESPONSE_TYPE = ProduceResponse_v0
SCHEMA = Schema(
('required_acks', Int16),
('timeout', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('messages', Bytes)))))
)
class ProduceRequest_v1(ProduceRequest):
API_VERSION = 1
RESPONSE_TYPE = ProduceResponse_v1
SCHEMA = ProduceRequest_v0.SCHEMA
class ProduceRequest_v2(ProduceRequest):
API_VERSION = 2
RESPONSE_TYPE = ProduceResponse_v2
SCHEMA = ProduceRequest_v1.SCHEMA
class ProduceRequest_v3(ProduceRequest):
API_VERSION = 3
RESPONSE_TYPE = ProduceResponse_v3
SCHEMA = Schema(
('transactional_id', String('utf-8')),
('required_acks', Int16),
('timeout', Int32),
('topics', Array(
('topic', String('utf-8')),
('partitions', Array(
('partition', Int32),
('messages', Bytes)))))
)
class ProduceRequest_v4(ProduceRequest):
"""
The version number is bumped up to indicate that the client supports KafkaStorageException.
The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3
"""
API_VERSION = 4
RESPONSE_TYPE = ProduceResponse_v4
SCHEMA = ProduceRequest_v3.SCHEMA
class ProduceRequest_v5(ProduceRequest):
"""
Same as v4. The version number is bumped since the v5 response includes an additional
partition level field: the log_start_offset.
"""
API_VERSION = 5
RESPONSE_TYPE = ProduceResponse_v5
SCHEMA = ProduceRequest_v4.SCHEMA
class ProduceRequest_v6(ProduceRequest):
"""
The version number is bumped to indicate that on quota violation brokers send out responses before throttling.
"""
API_VERSION = 6
RESPONSE_TYPE = ProduceResponse_v6
SCHEMA = ProduceRequest_v5.SCHEMA
class ProduceRequest_v7(ProduceRequest):
"""
V7 bumped up to indicate ZStandard capability. (see KIP-110)
"""
API_VERSION = 7
RESPONSE_TYPE = ProduceResponse_v7
SCHEMA = ProduceRequest_v6.SCHEMA
class ProduceRequest_v8(ProduceRequest):
"""
    V8 adds two new fields, record_errors and error_message, to the PartitionResponse
    (see KIP-467).
"""
API_VERSION = 8
RESPONSE_TYPE = ProduceResponse_v8
SCHEMA = ProduceRequest_v7.SCHEMA
ProduceRequest = [
ProduceRequest_v0, ProduceRequest_v1, ProduceRequest_v2,
ProduceRequest_v3, ProduceRequest_v4, ProduceRequest_v5,
ProduceRequest_v6, ProduceRequest_v7, ProduceRequest_v8,
]
ProduceResponse = [
ProduceResponse_v0, ProduceResponse_v1, ProduceResponse_v2,
ProduceResponse_v3, ProduceResponse_v4, ProduceResponse_v5,
ProduceResponse_v6, ProduceResponse_v7, ProduceResponse_v8,
]
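# Usage sketch (illustrative only; topic and message bytes are made-up
# examples): with required_acks=0 the broker sends no response at all, which is
# why ProduceRequest.expect_response() above returns False for acks=0 and the
# parser never records an in-flight request for it.
_fire_and_forget = ProduceRequest_v0(
    0,     # required_acks: 0 -> no broker response
    1000,  # timeout (ms)
    [('example-topic', [(0, b'')])]  # (partition, message-set bytes)
)
assert _fire_and_forget.expect_response() is False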

View File

@@ -0,0 +1,72 @@
from __future__ import absolute_import
from io import BytesIO
from kafka.protocol.abstract import AbstractType
from kafka.protocol.types import Schema
from kafka.util import WeakMethod
class Struct(AbstractType):
SCHEMA = Schema()
def __init__(self, *args, **kwargs):
if len(args) == len(self.SCHEMA.fields):
for i, name in enumerate(self.SCHEMA.names):
self.__dict__[name] = args[i]
elif len(args) > 0:
raise ValueError('Args must be empty or mirror schema')
else:
for name in self.SCHEMA.names:
self.__dict__[name] = kwargs.pop(name, None)
if kwargs:
raise ValueError('Keyword(s) not in schema %s: %s'
% (list(self.SCHEMA.names),
', '.join(kwargs.keys())))
# overloading encode() to support both class and instance
# Without WeakMethod() this creates circular ref, which
# causes instances to "leak" to garbage
self.encode = WeakMethod(self._encode_self)
@classmethod
def encode(cls, item): # pylint: disable=E0202
bits = []
for i, field in enumerate(cls.SCHEMA.fields):
bits.append(field.encode(item[i]))
return b''.join(bits)
def _encode_self(self):
return self.SCHEMA.encode(
[self.__dict__[name] for name in self.SCHEMA.names]
)
@classmethod
def decode(cls, data):
if isinstance(data, bytes):
data = BytesIO(data)
return cls(*[field.decode(data) for field in cls.SCHEMA.fields])
def get_item(self, name):
if name not in self.SCHEMA.names:
raise KeyError("%s is not in the schema" % name)
return self.__dict__[name]
def __repr__(self):
key_vals = []
for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields):
key_vals.append('%s=%s' % (name, field.repr(self.__dict__[name])))
return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'
def __hash__(self):
return hash(self.encode())
def __eq__(self, other):
if self.SCHEMA != other.SCHEMA:
return False
for attr in self.SCHEMA.names:
if self.__dict__[attr] != other.__dict__[attr]:
return False
return True
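# Usage sketch (illustrative only; the schema below is a made-up example): any
# Struct subclass gets positional construction, encode() and decode() for free
# from its SCHEMA, and equality compares field-by-field.
from kafka.protocol.types import Int32, String
class _ExamplePair(Struct):
    SCHEMA = Schema(
        ('name', String('utf-8')),
        ('count', Int32)
    )
_pair = _ExamplePair('example', 3)
assert _ExamplePair.decode(_pair.encode()) == _pair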

View File

@@ -0,0 +1,198 @@
from __future__ import absolute_import
import struct
from struct import error
from kafka.protocol.abstract import AbstractType
def _pack(f, value):
try:
return f(value)
except error as e:
raise ValueError("Error encountered when attempting to convert value: "
"{!r} to struct format: '{}', hit error: {}"
.format(value, f, e))
def _unpack(f, data):
try:
(value,) = f(data)
return value
except error as e:
raise ValueError("Error encountered when attempting to convert value: "
"{!r} to struct format: '{}', hit error: {}"
.format(data, f, e))
class Int8(AbstractType):
_pack = struct.Struct('>b').pack
_unpack = struct.Struct('>b').unpack
@classmethod
def encode(cls, value):
return _pack(cls._pack, value)
@classmethod
def decode(cls, data):
return _unpack(cls._unpack, data.read(1))
class Int16(AbstractType):
_pack = struct.Struct('>h').pack
_unpack = struct.Struct('>h').unpack
@classmethod
def encode(cls, value):
return _pack(cls._pack, value)
@classmethod
def decode(cls, data):
return _unpack(cls._unpack, data.read(2))
class Int32(AbstractType):
_pack = struct.Struct('>i').pack
_unpack = struct.Struct('>i').unpack
@classmethod
def encode(cls, value):
return _pack(cls._pack, value)
@classmethod
def decode(cls, data):
return _unpack(cls._unpack, data.read(4))
class Int64(AbstractType):
_pack = struct.Struct('>q').pack
_unpack = struct.Struct('>q').unpack
@classmethod
def encode(cls, value):
return _pack(cls._pack, value)
@classmethod
def decode(cls, data):
return _unpack(cls._unpack, data.read(8))
class String(AbstractType):
def __init__(self, encoding='utf-8'):
self.encoding = encoding
def encode(self, value):
if value is None:
return Int16.encode(-1)
value = str(value).encode(self.encoding)
return Int16.encode(len(value)) + value
def decode(self, data):
length = Int16.decode(data)
if length < 0:
return None
value = data.read(length)
if len(value) != length:
raise ValueError('Buffer underrun decoding string')
return value.decode(self.encoding)
class Bytes(AbstractType):
@classmethod
def encode(cls, value):
if value is None:
return Int32.encode(-1)
else:
return Int32.encode(len(value)) + value
@classmethod
def decode(cls, data):
length = Int32.decode(data)
if length < 0:
return None
value = data.read(length)
if len(value) != length:
raise ValueError('Buffer underrun decoding Bytes')
return value
@classmethod
def repr(cls, value):
return repr(value[:100] + b'...' if value is not None and len(value) > 100 else value)
class Boolean(AbstractType):
_pack = struct.Struct('>?').pack
_unpack = struct.Struct('>?').unpack
@classmethod
def encode(cls, value):
return _pack(cls._pack, value)
@classmethod
def decode(cls, data):
return _unpack(cls._unpack, data.read(1))
class Schema(AbstractType):
def __init__(self, *fields):
if fields:
self.names, self.fields = zip(*fields)
else:
self.names, self.fields = (), ()
def encode(self, item):
if len(item) != len(self.fields):
raise ValueError('Item field count does not match Schema')
return b''.join([
field.encode(item[i])
for i, field in enumerate(self.fields)
])
def decode(self, data):
return tuple([field.decode(data) for field in self.fields])
def __len__(self):
return len(self.fields)
def repr(self, value):
key_vals = []
try:
for i in range(len(self)):
try:
field_val = getattr(value, self.names[i])
except AttributeError:
field_val = value[i]
key_vals.append('%s=%s' % (self.names[i], self.fields[i].repr(field_val)))
return '(' + ', '.join(key_vals) + ')'
except Exception:
return repr(value)
class Array(AbstractType):
def __init__(self, *array_of):
if len(array_of) > 1:
self.array_of = Schema(*array_of)
elif len(array_of) == 1 and (isinstance(array_of[0], AbstractType) or
issubclass(array_of[0], AbstractType)):
self.array_of = array_of[0]
else:
raise ValueError('Array instantiated with no array_of type')
def encode(self, items):
if items is None:
return Int32.encode(-1)
return b''.join(
[Int32.encode(len(items))] +
[self.array_of.encode(item) for item in items]
)
def decode(self, data):
length = Int32.decode(data)
if length == -1:
return None
return [self.array_of.decode(data) for _ in range(length)]
def repr(self, list_of_items):
if list_of_items is None:
return 'NULL'
return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']'
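# Usage sketch (illustrative only; the topic name is a made-up example): the
# wire distinction that the metadata ALL_TOPICS / NO_TOPICS comments rely on
# lives here -- a null array encodes as length -1 and decodes back to None,
# while an empty array encodes as length 0.
import io
_topics_array = Array(String('utf-8'))
assert _topics_array.encode(None) == Int32.encode(-1)
assert _topics_array.decode(io.BytesIO(_topics_array.encode(['example-topic']))) == ['example-topic']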

View File

@@ -0,0 +1,3 @@
from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder
__all__ = ["MemoryRecords", "MemoryRecordsBuilder"]

View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python
#
# Taken from https://cloud.google.com/appengine/docs/standard/python/refdocs/\
# modules/google/appengine/api/files/crc32c?hl=ru
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implementation of CRC-32C checksumming as in rfc3720 section B.4.
See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for details on CRC-32C
This code is a manual python translation of c code generated by
pycrc 0.7.1 (https://pycrc.org/). Command line used:
'./pycrc.py --model=crc-32c --generate c --algorithm=table-driven'
"""
import array
CRC_TABLE = (
0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4,
0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb,
0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24,
0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b,
0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54,
0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b,
0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35,
0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5,
0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45,
0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a,
0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595,
0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48,
0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687,
0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198,
0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38,
0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8,
0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096,
0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789,
0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46,
0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9,
0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36,
0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829,
0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93,
0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043,
0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3,
0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc,
0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033,
0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652,
0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d,
0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982,
0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622,
0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2,
0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530,
0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f,
0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0,
0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f,
0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90,
0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f,
0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1,
0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321,
0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81,
0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,
0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
)
CRC_INIT = 0
_MASK = 0xFFFFFFFF
def crc_update(crc, data):
"""Update CRC-32C checksum with data.
Args:
crc: 32-bit checksum to update as long.
data: byte array, string or iterable over bytes.
Returns:
32-bit updated CRC-32C as long.
"""
if type(data) != array.array or data.itemsize != 1:
buf = array.array("B", data)
else:
buf = data
crc = crc ^ _MASK
for b in buf:
table_index = (crc ^ b) & 0xff
crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
return crc ^ _MASK
def crc_finalize(crc):
"""Finalize CRC-32C checksum.
This function should be called as last step of crc calculation.
Args:
crc: 32-bit checksum as long.
Returns:
finalized 32-bit checksum as long
"""
return crc & _MASK
def crc(data):
"""Compute CRC-32C checksum of the data.
Args:
data: byte array, string or iterable over bytes.
Returns:
32-bit CRC-32C checksum of data as long.
"""
return crc_finalize(crc_update(CRC_INIT, data))
if __name__ == "__main__":
import sys
# TODO remove the pylint disable once pylint fixes
# https://github.com/PyCQA/pylint/issues/2571
data = sys.stdin.read() # pylint: disable=assignment-from-no-return
print(hex(crc(data)))
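# Sanity-check sketch (illustrative only): the standard CRC-32C check value for
# the ASCII digits "123456789" is 0xE3069283, which this table-driven
# implementation reproduces.
assert crc(b"123456789") == 0xE3069283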

View File

@@ -0,0 +1,124 @@
from __future__ import absolute_import
import abc
class ABCRecord(object):
__metaclass__ = abc.ABCMeta
__slots__ = ()
@abc.abstractproperty
def offset(self):
""" Absolute offset of record
"""
@abc.abstractproperty
def timestamp(self):
""" Epoch milliseconds
"""
@abc.abstractproperty
def timestamp_type(self):
""" CREATE_TIME(0) or APPEND_TIME(1)
"""
@abc.abstractproperty
def key(self):
""" Bytes key or None
"""
@abc.abstractproperty
def value(self):
""" Bytes value or None
"""
@abc.abstractproperty
def checksum(self):
""" Prior to v2 format CRC was contained in every message. This will
be the checksum for v0 and v1 and None for v2 and above.
"""
@abc.abstractproperty
def headers(self):
""" If supported by version list of key-value tuples, or empty list if
not supported by format.
"""
class ABCRecordBatchBuilder(object):
__metaclass__ = abc.ABCMeta
__slots__ = ()
@abc.abstractmethod
def append(self, offset, timestamp, key, value, headers=None):
""" Writes record to internal buffer.
Arguments:
offset (int): Relative offset of record, starting from 0
timestamp (int or None): Timestamp in milliseconds since beginning
of the epoch (midnight Jan 1, 1970 (UTC)). If omitted, will be
set to current time.
key (bytes or None): Key of the record
value (bytes or None): Value of the record
headers (List[Tuple[str, bytes]]): Headers of the record. Header
keys can not be ``None``.
Returns:
(bytes, int): Checksum of the written record (or None for v2 and
above) and size of the written record.
"""
@abc.abstractmethod
def size_in_bytes(self, offset, timestamp, key, value, headers):
""" Return the expected size change on buffer (uncompressed) if we add
this message. This will account for varint size changes and give a
reliable size.
"""
@abc.abstractmethod
def build(self):
""" Close for append, compress if needed, write size and header and
return a ready to send buffer object.
Return:
bytearray: finished batch, ready to send.
"""
class ABCRecordBatch(object):
""" For v2 incapsulates a RecordBatch, for v0/v1 a single (maybe
compressed) message.
"""
__metaclass__ = abc.ABCMeta
__slots__ = ()
@abc.abstractmethod
def __iter__(self):
""" Return iterator over records (ABCRecord instances). Will decompress
if needed.
"""
class ABCRecords(object):
__metaclass__ = abc.ABCMeta
__slots__ = ()
@abc.abstractmethod
def __init__(self, buffer):
""" Initialize with bytes-like object conforming to the buffer
interface (ie. bytes, bytearray, memoryview etc.).
"""
@abc.abstractmethod
def size_in_bytes(self):
""" Returns the size of inner buffer.
"""
@abc.abstractmethod
def next_batch(self):
""" Return next batch of records (ABCRecordBatch instances).
"""
@abc.abstractmethod
def has_next(self):
""" True if there are more batches to read, False otherwise.
"""

View File

@@ -0,0 +1,630 @@
# See:
# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\
# apache/kafka/common/record/DefaultRecordBatch.java
# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\
# apache/kafka/common/record/DefaultRecord.java
# RecordBatch and Record implementation for magic 2 and above.
# The schema is given below:
# RecordBatch =>
# BaseOffset => Int64
# Length => Int32
# PartitionLeaderEpoch => Int32
# Magic => Int8
# CRC => Uint32
# Attributes => Int16
# LastOffsetDelta => Int32 // also serves as LastSequenceDelta
# FirstTimestamp => Int64
# MaxTimestamp => Int64
# ProducerId => Int64
# ProducerEpoch => Int16
# BaseSequence => Int32
# Records => [Record]
# Record =>
# Length => Varint
# Attributes => Int8
# TimestampDelta => Varlong
# OffsetDelta => Varint
# Key => Bytes
# Value => Bytes
# Headers => [HeaderKey HeaderValue]
# HeaderKey => String
# HeaderValue => Bytes
# Note that when compression is enabled (see attributes below), the compressed
# record data is serialized directly following the count of the number of
# records. (ie Records => [Record], but without length bytes)
# The CRC covers the data from the attributes to the end of the batch (i.e. all
# the bytes that follow the CRC). It is located after the magic byte, which
# means that clients must parse the magic byte before deciding how to interpret
# the bytes between the batch length and the magic byte. The partition leader
# epoch field is not included in the CRC computation to avoid the need to
# recompute the CRC when this field is assigned for every batch that is
# received by the broker. The CRC-32C (Castagnoli) polynomial is used for the
# computation.
# The current RecordBatch attributes are given below:
#
# * Unused (6-15)
# * Control (5)
# * Transactional (4)
# * Timestamp Type (3)
# * Compression Type (0-2)
import struct
import time
from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder
from kafka.record.util import (
decode_varint, encode_varint, calc_crc32c, size_of_varint
)
from kafka.errors import CorruptRecordException, UnsupportedCodecError
from kafka.codec import (
gzip_encode, snappy_encode, lz4_encode, zstd_encode,
gzip_decode, snappy_decode, lz4_decode, zstd_decode
)
import kafka.codec as codecs
class DefaultRecordBase(object):
__slots__ = ()
HEADER_STRUCT = struct.Struct(
">q" # BaseOffset => Int64
"i" # Length => Int32
"i" # PartitionLeaderEpoch => Int32
"b" # Magic => Int8
"I" # CRC => Uint32
"h" # Attributes => Int16
"i" # LastOffsetDelta => Int32 // also serves as LastSequenceDelta
"q" # FirstTimestamp => Int64
"q" # MaxTimestamp => Int64
"q" # ProducerId => Int64
"h" # ProducerEpoch => Int16
"i" # BaseSequence => Int32
"i" # Records count => Int32
)
# Byte offset in HEADER_STRUCT of attributes field. Used to calculate CRC
ATTRIBUTES_OFFSET = struct.calcsize(">qiibI")
CRC_OFFSET = struct.calcsize(">qiib")
AFTER_LEN_OFFSET = struct.calcsize(">qi")
CODEC_MASK = 0x07
CODEC_NONE = 0x00
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_ZSTD = 0x04
TIMESTAMP_TYPE_MASK = 0x08
TRANSACTIONAL_MASK = 0x10
CONTROL_MASK = 0x20
LOG_APPEND_TIME = 1
CREATE_TIME = 0
def _assert_has_codec(self, compression_type):
if compression_type == self.CODEC_GZIP:
checker, name = codecs.has_gzip, "gzip"
elif compression_type == self.CODEC_SNAPPY:
checker, name = codecs.has_snappy, "snappy"
elif compression_type == self.CODEC_LZ4:
checker, name = codecs.has_lz4, "lz4"
elif compression_type == self.CODEC_ZSTD:
checker, name = codecs.has_zstd, "zstd"
else:
# Guard against an unknown codec id instead of failing with an
# UnboundLocalError on `checker` below.
raise UnsupportedCodecError("Unknown compression codec {}".format(compression_type))
if not checker():
raise UnsupportedCodecError(
"Libraries for {} compression codec not found".format(name))
class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch):
__slots__ = ("_buffer", "_header_data", "_pos", "_num_records",
"_next_record_index", "_decompressed")
def __init__(self, buffer):
self._buffer = bytearray(buffer)
self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer)
self._pos = self.HEADER_STRUCT.size
self._num_records = self._header_data[12]
self._next_record_index = 0
self._decompressed = False
@property
def base_offset(self):
return self._header_data[0]
@property
def magic(self):
return self._header_data[3]
@property
def crc(self):
return self._header_data[4]
@property
def attributes(self):
return self._header_data[5]
@property
def last_offset_delta(self):
return self._header_data[6]
@property
def compression_type(self):
return self.attributes & self.CODEC_MASK
@property
def timestamp_type(self):
return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK))
@property
def is_transactional(self):
return bool(self.attributes & self.TRANSACTIONAL_MASK)
@property
def is_control_batch(self):
return bool(self.attributes & self.CONTROL_MASK)
@property
def first_timestamp(self):
return self._header_data[7]
@property
def max_timestamp(self):
return self._header_data[8]
def _maybe_uncompress(self):
if not self._decompressed:
compression_type = self.compression_type
if compression_type != self.CODEC_NONE:
self._assert_has_codec(compression_type)
data = memoryview(self._buffer)[self._pos:]
if compression_type == self.CODEC_GZIP:
uncompressed = gzip_decode(data)
if compression_type == self.CODEC_SNAPPY:
uncompressed = snappy_decode(data.tobytes())
if compression_type == self.CODEC_LZ4:
uncompressed = lz4_decode(data.tobytes())
if compression_type == self.CODEC_ZSTD:
uncompressed = zstd_decode(data.tobytes())
self._buffer = bytearray(uncompressed)
self._pos = 0
self._decompressed = True
def _read_msg(
self,
decode_varint=decode_varint):
# Record =>
# Length => Varint
# Attributes => Int8
# TimestampDelta => Varlong
# OffsetDelta => Varint
# Key => Bytes
# Value => Bytes
# Headers => [HeaderKey HeaderValue]
# HeaderKey => String
# HeaderValue => Bytes
buffer = self._buffer
pos = self._pos
length, pos = decode_varint(buffer, pos)
start_pos = pos
_, pos = decode_varint(buffer, pos) # attrs can be skipped for now
ts_delta, pos = decode_varint(buffer, pos)
if self.timestamp_type == self.LOG_APPEND_TIME:
timestamp = self.max_timestamp
else:
timestamp = self.first_timestamp + ts_delta
offset_delta, pos = decode_varint(buffer, pos)
offset = self.base_offset + offset_delta
key_len, pos = decode_varint(buffer, pos)
if key_len >= 0:
key = bytes(buffer[pos: pos + key_len])
pos += key_len
else:
key = None
value_len, pos = decode_varint(buffer, pos)
if value_len >= 0:
value = bytes(buffer[pos: pos + value_len])
pos += value_len
else:
value = None
header_count, pos = decode_varint(buffer, pos)
if header_count < 0:
raise CorruptRecordException("Found invalid number of record "
"headers {}".format(header_count))
headers = []
while header_count:
# Header key is of type String, which can't be None
h_key_len, pos = decode_varint(buffer, pos)
if h_key_len < 0:
raise CorruptRecordException(
"Invalid negative header key size {}".format(h_key_len))
h_key = buffer[pos: pos + h_key_len].decode("utf-8")
pos += h_key_len
# Value is of type NULLABLE_BYTES, so it can be None
h_value_len, pos = decode_varint(buffer, pos)
if h_value_len >= 0:
h_value = bytes(buffer[pos: pos + h_value_len])
pos += h_value_len
else:
h_value = None
headers.append((h_key, h_value))
header_count -= 1
# validate whether we have read all header bytes in the current record
if pos - start_pos != length:
raise CorruptRecordException(
"Invalid record size: expected to read {} bytes in record "
"payload, but instead read {}".format(length, pos - start_pos))
self._pos = pos
return DefaultRecord(
offset, timestamp, self.timestamp_type, key, value, headers)
def __iter__(self):
self._maybe_uncompress()
return self
def __next__(self):
if self._next_record_index >= self._num_records:
if self._pos != len(self._buffer):
raise CorruptRecordException(
"{} unconsumed bytes after all records consumed".format(
len(self._buffer) - self._pos))
raise StopIteration
try:
msg = self._read_msg()
except (ValueError, IndexError) as err:
raise CorruptRecordException(
"Found invalid record structure: {!r}".format(err))
else:
self._next_record_index += 1
return msg
next = __next__
def validate_crc(self):
assert self._decompressed is False, \
"Validate should be called before iteration"
crc = self.crc
data_view = memoryview(self._buffer)[self.ATTRIBUTES_OFFSET:]
verify_crc = calc_crc32c(data_view.tobytes())
return crc == verify_crc
class DefaultRecord(ABCRecord):
__slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value",
"_headers")
def __init__(self, offset, timestamp, timestamp_type, key, value, headers):
self._offset = offset
self._timestamp = timestamp
self._timestamp_type = timestamp_type
self._key = key
self._value = value
self._headers = headers
@property
def offset(self):
return self._offset
@property
def timestamp(self):
""" Epoch milliseconds
"""
return self._timestamp
@property
def timestamp_type(self):
""" CREATE_TIME(0) or APPEND_TIME(1)
"""
return self._timestamp_type
@property
def key(self):
""" Bytes key or None
"""
return self._key
@property
def value(self):
""" Bytes value or None
"""
return self._value
@property
def headers(self):
return self._headers
@property
def checksum(self):
return None
def __repr__(self):
return (
"DefaultRecord(offset={!r}, timestamp={!r}, timestamp_type={!r},"
" key={!r}, value={!r}, headers={!r})".format(
self._offset, self._timestamp, self._timestamp_type,
self._key, self._value, self._headers)
)
class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder):
# excluding key, value and headers:
# 5 bytes length + 10 bytes timestamp + 5 bytes offset + 1 byte attributes
MAX_RECORD_OVERHEAD = 21
__slots__ = ("_magic", "_compression_type", "_batch_size", "_is_transactional",
"_producer_id", "_producer_epoch", "_base_sequence",
"_first_timestamp", "_max_timestamp", "_last_offset", "_num_records",
"_buffer")
def __init__(
self, magic, compression_type, is_transactional,
producer_id, producer_epoch, base_sequence, batch_size):
assert magic >= 2
self._magic = magic
self._compression_type = compression_type & self.CODEC_MASK
self._batch_size = batch_size
self._is_transactional = bool(is_transactional)
# KIP-98 fields for EOS
self._producer_id = producer_id
self._producer_epoch = producer_epoch
self._base_sequence = base_sequence
self._first_timestamp = None
self._max_timestamp = None
self._last_offset = 0
self._num_records = 0
self._buffer = bytearray(self.HEADER_STRUCT.size)
def _get_attributes(self, include_compression_type=True):
attrs = 0
if include_compression_type:
attrs |= self._compression_type
# Timestamp Type is set by Broker
if self._is_transactional:
attrs |= self.TRANSACTIONAL_MASK
# Control batches are only created by Broker
return attrs
def append(self, offset, timestamp, key, value, headers,
# Cache for LOAD_FAST opcodes
encode_varint=encode_varint, size_of_varint=size_of_varint,
get_type=type, type_int=int, time_time=time.time,
byte_like=(bytes, bytearray, memoryview),
bytearray_type=bytearray, len_func=len, zero_len_varint=1
):
""" Write message to messageset buffer with MsgVersion 2
"""
# Check types
if get_type(offset) != type_int:
raise TypeError(offset)
if timestamp is None:
timestamp = type_int(time_time() * 1000)
elif get_type(timestamp) != type_int:
raise TypeError(timestamp)
if not (key is None or get_type(key) in byte_like):
raise TypeError(
"Not supported type for key: {}".format(type(key)))
if not (value is None or get_type(value) in byte_like):
raise TypeError(
"Not supported type for value: {}".format(type(value)))
# We will always add the first message, so those will be set
if self._first_timestamp is None:
self._first_timestamp = timestamp
self._max_timestamp = timestamp
timestamp_delta = 0
first_message = 1
else:
timestamp_delta = timestamp - self._first_timestamp
first_message = 0
# We can't write record right away to out buffer, we need to
# precompute the length as first value...
message_buffer = bytearray_type(b"\x00") # Attributes
write_byte = message_buffer.append
write = message_buffer.extend
encode_varint(timestamp_delta, write_byte)
# Base offset is always 0 on Produce
encode_varint(offset, write_byte)
if key is not None:
encode_varint(len_func(key), write_byte)
write(key)
else:
write_byte(zero_len_varint)
if value is not None:
encode_varint(len_func(value), write_byte)
write(value)
else:
write_byte(zero_len_varint)
encode_varint(len_func(headers), write_byte)
for h_key, h_value in headers:
h_key = h_key.encode("utf-8")
encode_varint(len_func(h_key), write_byte)
write(h_key)
if h_value is not None:
encode_varint(len_func(h_value), write_byte)
write(h_value)
else:
write_byte(zero_len_varint)
message_len = len_func(message_buffer)
main_buffer = self._buffer
required_size = message_len + size_of_varint(message_len)
# Check if we can write this message
if (required_size + len_func(main_buffer) > self._batch_size and
not first_message):
return None
# Those should be updated after the length check
if self._max_timestamp < timestamp:
self._max_timestamp = timestamp
self._num_records += 1
self._last_offset = offset
encode_varint(message_len, main_buffer.append)
main_buffer.extend(message_buffer)
return DefaultRecordMetadata(offset, required_size, timestamp)
def write_header(self, use_compression_type=True):
batch_len = len(self._buffer)
self.HEADER_STRUCT.pack_into(
self._buffer, 0,
0, # BaseOffset, set by broker
batch_len - self.AFTER_LEN_OFFSET, # Size from here to end
0, # PartitionLeaderEpoch, set by broker
self._magic,
0, # CRC will be set below, as we need a filled buffer for it
self._get_attributes(use_compression_type),
self._last_offset,
self._first_timestamp,
self._max_timestamp,
self._producer_id,
self._producer_epoch,
self._base_sequence,
self._num_records
)
crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET:])
struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc)
def _maybe_compress(self):
if self._compression_type != self.CODEC_NONE:
self._assert_has_codec(self._compression_type)
header_size = self.HEADER_STRUCT.size
data = bytes(self._buffer[header_size:])
if self._compression_type == self.CODEC_GZIP:
compressed = gzip_encode(data)
elif self._compression_type == self.CODEC_SNAPPY:
compressed = snappy_encode(data)
elif self._compression_type == self.CODEC_LZ4:
compressed = lz4_encode(data)
elif self._compression_type == self.CODEC_ZSTD:
compressed = zstd_encode(data)
compressed_size = len(compressed)
if len(data) <= compressed_size:
# We did not get any benefit from compression, let's send
# uncompressed
return False
else:
# Trim bytearray to the required size
needed_size = header_size + compressed_size
del self._buffer[needed_size:]
self._buffer[header_size:needed_size] = compressed
return True
return False
def build(self):
send_compressed = self._maybe_compress()
self.write_header(send_compressed)
return self._buffer
def size(self):
""" Return current size of data written to buffer
"""
return len(self._buffer)
def size_in_bytes(self, offset, timestamp, key, value, headers):
if self._first_timestamp is not None:
timestamp_delta = timestamp - self._first_timestamp
else:
timestamp_delta = 0
size_of_body = (
1 + # Attrs
size_of_varint(offset) +
size_of_varint(timestamp_delta) +
self.size_of(key, value, headers)
)
return size_of_body + size_of_varint(size_of_body)
@classmethod
def size_of(cls, key, value, headers):
size = 0
# Key size
if key is None:
size += 1
else:
key_len = len(key)
size += size_of_varint(key_len) + key_len
# Value size
if value is None:
size += 1
else:
value_len = len(value)
size += size_of_varint(value_len) + value_len
# Header size
size += size_of_varint(len(headers))
for h_key, h_value in headers:
h_key_len = len(h_key.encode("utf-8"))
size += size_of_varint(h_key_len) + h_key_len
if h_value is None:
size += 1
else:
h_value_len = len(h_value)
size += size_of_varint(h_value_len) + h_value_len
return size
@classmethod
def estimate_size_in_bytes(cls, key, value, headers):
""" Get the upper bound estimate on the size of record
"""
return (
cls.HEADER_STRUCT.size + cls.MAX_RECORD_OVERHEAD +
cls.size_of(key, value, headers)
)
class DefaultRecordMetadata(object):
__slots__ = ("_size", "_timestamp", "_offset")
def __init__(self, offset, size, timestamp):
self._offset = offset
self._size = size
self._timestamp = timestamp
@property
def offset(self):
return self._offset
@property
def crc(self):
return None
@property
def size(self):
return self._size
@property
def timestamp(self):
return self._timestamp
def __repr__(self):
return (
"DefaultRecordMetadata(offset={!r}, size={!r}, timestamp={!r})"
.format(self._offset, self._size, self._timestamp)
)
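# --- Illustrative sketch (not part of the original module) ---
# A minimal round trip through the v2 builder and reader defined above.  The
# key/value/header contents are made up for the demonstration, and the block
# is guarded so it only runs when this file is executed directly.
if __name__ == "__main__":
    builder = DefaultRecordBatchBuilder(
        magic=2, compression_type=0, is_transactional=False,
        producer_id=-1, producer_epoch=-1, base_sequence=-1, batch_size=1024)
    # Offsets are relative on produce; the broker assigns the real base offset.
    builder.append(0, timestamp=None, key=b"key", value=b"value",
                   headers=[("trace-id", b"abc")])
    batch_bytes = bytes(builder.build())

    batch = DefaultRecordBatch(batch_bytes)
    assert batch.validate_crc()  # must be checked before iterating
    for record in batch:
        print(record.offset, record.key, record.value, record.headers)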

View File

@@ -0,0 +1,548 @@
# See:
# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\
# apache/kafka/common/record/LegacyRecord.java
# Builder and reader implementation for V0 and V1 record versions. As of Kafka
# 0.11.0.0 those were replaced with V2, thus the Legacy naming.
# The schema is given below (see
# https://kafka.apache.org/protocol#protocol_message_sets for more details):
# MessageSet => [Offset MessageSize Message]
# Offset => int64
# MessageSize => int32
# v0
# Message => Crc MagicByte Attributes Key Value
# Crc => int32
# MagicByte => int8
# Attributes => int8
# Key => bytes
# Value => bytes
# v1 (supported since 0.10.0)
# Message => Crc MagicByte Attributes Key Value
# Crc => int32
# MagicByte => int8
# Attributes => int8
# Timestamp => int64
# Key => bytes
# Value => bytes
# The message attribute bits are given below:
# * Unused (4-7)
# * Timestamp Type (3) (added in V1)
# * Compression Type (0-2)
# Note that when compression is enabled (see attributes above), the whole
# MessageSet is compressed and placed into a single wrapper message as the `value`.
# Only the parent message is marked with `compression` bits in attributes.
# The CRC covers the data from the Magic byte to the end of the message.
import struct
import time
from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder
from kafka.record.util import calc_crc32
from kafka.codec import (
gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka,
gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka,
)
import kafka.codec as codecs
from kafka.errors import CorruptRecordException, UnsupportedCodecError
class LegacyRecordBase(object):
__slots__ = ()
HEADER_STRUCT_V0 = struct.Struct(
">q" # BaseOffset => Int64
"i" # Length => Int32
"I" # CRC => Int32
"b" # Magic => Int8
"b" # Attributes => Int8
)
HEADER_STRUCT_V1 = struct.Struct(
">q" # BaseOffset => Int64
"i" # Length => Int32
"I" # CRC => Int32
"b" # Magic => Int8
"b" # Attributes => Int8
"q" # timestamp => Int64
)
LOG_OVERHEAD = CRC_OFFSET = struct.calcsize(
">q" # Offset
"i" # Size
)
MAGIC_OFFSET = LOG_OVERHEAD + struct.calcsize(
">I" # CRC
)
# Those are used for fast size calculations
RECORD_OVERHEAD_V0 = struct.calcsize(
">I" # CRC
"b" # magic
"b" # attributes
"i" # Key length
"i" # Value length
)
RECORD_OVERHEAD_V1 = struct.calcsize(
">I" # CRC
"b" # magic
"b" # attributes
"q" # timestamp
"i" # Key length
"i" # Value length
)
KEY_OFFSET_V0 = HEADER_STRUCT_V0.size
KEY_OFFSET_V1 = HEADER_STRUCT_V1.size
KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i") # Bytes length is Int32
CODEC_MASK = 0x07
CODEC_NONE = 0x00
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
TIMESTAMP_TYPE_MASK = 0x08
LOG_APPEND_TIME = 1
CREATE_TIME = 0
NO_TIMESTAMP = -1
def _assert_has_codec(self, compression_type):
if compression_type == self.CODEC_GZIP:
checker, name = codecs.has_gzip, "gzip"
elif compression_type == self.CODEC_SNAPPY:
checker, name = codecs.has_snappy, "snappy"
elif compression_type == self.CODEC_LZ4:
checker, name = codecs.has_lz4, "lz4"
else:
# Guard against an unknown codec id instead of failing with an
# UnboundLocalError on `checker` below.
raise UnsupportedCodecError("Unknown compression codec {}".format(compression_type))
if not checker():
raise UnsupportedCodecError(
"Libraries for {} compression codec not found".format(name))
class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase):
__slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp",
"_attributes", "_decompressed")
def __init__(self, buffer, magic):
self._buffer = memoryview(buffer)
self._magic = magic
offset, length, crc, magic_, attrs, timestamp = self._read_header(0)
assert length == len(buffer) - self.LOG_OVERHEAD
assert magic == magic_
self._offset = offset
self._crc = crc
self._timestamp = timestamp
self._attributes = attrs
self._decompressed = False
@property
def timestamp_type(self):
"""0 for CreateTime; 1 for LogAppendTime; None if unsupported.
Value is determined by broker; produced messages should always be set to 0.
Requires Kafka >= 0.10 / message version >= 1
"""
if self._magic == 0:
return None
elif self._attributes & self.TIMESTAMP_TYPE_MASK:
return 1
else:
return 0
@property
def compression_type(self):
return self._attributes & self.CODEC_MASK
def validate_crc(self):
crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:])
return self._crc == crc
def _decompress(self, key_offset):
# Copy of `_read_key_value`, but uses memoryview
pos = key_offset
key_size = struct.unpack_from(">i", self._buffer, pos)[0]
pos += self.KEY_LENGTH
if key_size != -1:
pos += key_size
value_size = struct.unpack_from(">i", self._buffer, pos)[0]
pos += self.VALUE_LENGTH
if value_size == -1:
raise CorruptRecordException("Value of compressed message is None")
else:
data = self._buffer[pos:pos + value_size]
compression_type = self.compression_type
self._assert_has_codec(compression_type)
if compression_type == self.CODEC_GZIP:
uncompressed = gzip_decode(data)
elif compression_type == self.CODEC_SNAPPY:
uncompressed = snappy_decode(data.tobytes())
elif compression_type == self.CODEC_LZ4:
if self._magic == 0:
uncompressed = lz4_decode_old_kafka(data.tobytes())
else:
uncompressed = lz4_decode(data.tobytes())
return uncompressed
def _read_header(self, pos):
if self._magic == 0:
offset, length, crc, magic_read, attrs = \
self.HEADER_STRUCT_V0.unpack_from(self._buffer, pos)
timestamp = None
else:
offset, length, crc, magic_read, attrs, timestamp = \
self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos)
return offset, length, crc, magic_read, attrs, timestamp
def _read_all_headers(self):
pos = 0
msgs = []
buffer_len = len(self._buffer)
while pos < buffer_len:
header = self._read_header(pos)
msgs.append((header, pos))
pos += self.LOG_OVERHEAD + header[1] # length
return msgs
def _read_key_value(self, pos):
key_size = struct.unpack_from(">i", self._buffer, pos)[0]
pos += self.KEY_LENGTH
if key_size == -1:
key = None
else:
key = self._buffer[pos:pos + key_size].tobytes()
pos += key_size
value_size = struct.unpack_from(">i", self._buffer, pos)[0]
pos += self.VALUE_LENGTH
if value_size == -1:
value = None
else:
value = self._buffer[pos:pos + value_size].tobytes()
return key, value
def __iter__(self):
if self._magic == 1:
key_offset = self.KEY_OFFSET_V1
else:
key_offset = self.KEY_OFFSET_V0
timestamp_type = self.timestamp_type
if self.compression_type:
# In case we will call iter again
if not self._decompressed:
self._buffer = memoryview(self._decompress(key_offset))
self._decompressed = True
# If relative offset is used, we need to decompress the entire
# message first to compute the absolute offset.
headers = self._read_all_headers()
if self._magic > 0:
msg_header, _ = headers[-1]
absolute_base_offset = self._offset - msg_header[0]
else:
absolute_base_offset = -1
for header, msg_pos in headers:
offset, _, crc, _, attrs, timestamp = header
# There should only ever be a single layer of compression
assert not attrs & self.CODEC_MASK, (
'MessageSet at offset %d appears double-compressed. This '
'should not happen -- check your producers!' % (offset,))
# When magic value is greater than 0, the timestamp
# of a compressed message depends on the
# timestamp type of the wrapper message:
if timestamp_type == self.LOG_APPEND_TIME:
timestamp = self._timestamp
if absolute_base_offset >= 0:
offset += absolute_base_offset
key, value = self._read_key_value(msg_pos + key_offset)
yield LegacyRecord(
offset, timestamp, timestamp_type,
key, value, crc)
else:
key, value = self._read_key_value(key_offset)
yield LegacyRecord(
self._offset, self._timestamp, timestamp_type,
key, value, self._crc)
class LegacyRecord(ABCRecord):
__slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value",
"_crc")
def __init__(self, offset, timestamp, timestamp_type, key, value, crc):
self._offset = offset
self._timestamp = timestamp
self._timestamp_type = timestamp_type
self._key = key
self._value = value
self._crc = crc
@property
def offset(self):
return self._offset
@property
def timestamp(self):
""" Epoch milliseconds
"""
return self._timestamp
@property
def timestamp_type(self):
""" CREATE_TIME(0) or APPEND_TIME(1)
"""
return self._timestamp_type
@property
def key(self):
""" Bytes key or None
"""
return self._key
@property
def value(self):
""" Bytes value or None
"""
return self._value
@property
def headers(self):
return []
@property
def checksum(self):
return self._crc
def __repr__(self):
return (
"LegacyRecord(offset={!r}, timestamp={!r}, timestamp_type={!r},"
" key={!r}, value={!r}, crc={!r})".format(
self._offset, self._timestamp, self._timestamp_type,
self._key, self._value, self._crc)
)
class LegacyRecordBatchBuilder(ABCRecordBatchBuilder, LegacyRecordBase):
__slots__ = ("_magic", "_compression_type", "_batch_size", "_buffer")
def __init__(self, magic, compression_type, batch_size):
self._magic = magic
self._compression_type = compression_type
self._batch_size = batch_size
self._buffer = bytearray()
def append(self, offset, timestamp, key, value, headers=None):
""" Append message to batch.
"""
assert not headers, "Headers not supported in v0/v1"
# Check types
if type(offset) != int:
raise TypeError(offset)
if self._magic == 0:
timestamp = self.NO_TIMESTAMP
elif timestamp is None:
timestamp = int(time.time() * 1000)
elif type(timestamp) != int:
raise TypeError(
"`timestamp` should be int, but {} provided".format(
type(timestamp)))
if not (key is None or
isinstance(key, (bytes, bytearray, memoryview))):
raise TypeError(
"Not supported type for key: {}".format(type(key)))
if not (value is None or
isinstance(value, (bytes, bytearray, memoryview))):
raise TypeError(
"Not supported type for value: {}".format(type(value)))
# Check if we have room for another message
pos = len(self._buffer)
size = self.size_in_bytes(offset, timestamp, key, value)
# We always allow at least one record to be appended
if offset != 0 and pos + size >= self._batch_size:
return None
# Allocate proper buffer length
self._buffer.extend(bytearray(size))
# Encode message
crc = self._encode_msg(pos, offset, timestamp, key, value)
return LegacyRecordMetadata(offset, crc, size, timestamp)
def _encode_msg(self, start_pos, offset, timestamp, key, value,
attributes=0):
""" Encode msg data into the `msg_buffer`, which should be allocated
to at least the size of this message.
"""
magic = self._magic
buf = self._buffer
pos = start_pos
# Write key and value
pos += self.KEY_OFFSET_V0 if magic == 0 else self.KEY_OFFSET_V1
if key is None:
struct.pack_into(">i", buf, pos, -1)
pos += self.KEY_LENGTH
else:
key_size = len(key)
struct.pack_into(">i", buf, pos, key_size)
pos += self.KEY_LENGTH
buf[pos: pos + key_size] = key
pos += key_size
if value is None:
struct.pack_into(">i", buf, pos, -1)
pos += self.VALUE_LENGTH
else:
value_size = len(value)
struct.pack_into(">i", buf, pos, value_size)
pos += self.VALUE_LENGTH
buf[pos: pos + value_size] = value
pos += value_size
length = (pos - start_pos) - self.LOG_OVERHEAD
# Write msg header. Note, that Crc will be updated later
if magic == 0:
self.HEADER_STRUCT_V0.pack_into(
buf, start_pos,
offset, length, 0, magic, attributes)
else:
self.HEADER_STRUCT_V1.pack_into(
buf, start_pos,
offset, length, 0, magic, attributes, timestamp)
# Calculate CRC for msg
crc_data = memoryview(buf)[start_pos + self.MAGIC_OFFSET:]
crc = calc_crc32(crc_data)
struct.pack_into(">I", buf, start_pos + self.CRC_OFFSET, crc)
return crc
def _maybe_compress(self):
if self._compression_type:
self._assert_has_codec(self._compression_type)
data = bytes(self._buffer)
if self._compression_type == self.CODEC_GZIP:
compressed = gzip_encode(data)
elif self._compression_type == self.CODEC_SNAPPY:
compressed = snappy_encode(data)
elif self._compression_type == self.CODEC_LZ4:
if self._magic == 0:
compressed = lz4_encode_old_kafka(data)
else:
compressed = lz4_encode(data)
size = self.size_in_bytes(
0, timestamp=0, key=None, value=compressed)
# We will try to reuse the same buffer if we have enough space
if size > len(self._buffer):
self._buffer = bytearray(size)
else:
del self._buffer[size:]
self._encode_msg(
start_pos=0,
offset=0, timestamp=0, key=None, value=compressed,
attributes=self._compression_type)
return True
return False
def build(self):
"""Compress batch to be ready for send"""
self._maybe_compress()
return self._buffer
def size(self):
""" Return current size of data written to buffer
"""
return len(self._buffer)
# Size calculations. Just copied Java's implementation
def size_in_bytes(self, offset, timestamp, key, value, headers=None):
""" Actual size of message to add
"""
assert not headers, "Headers not supported in v0/v1"
magic = self._magic
return self.LOG_OVERHEAD + self.record_size(magic, key, value)
@classmethod
def record_size(cls, magic, key, value):
message_size = cls.record_overhead(magic)
if key is not None:
message_size += len(key)
if value is not None:
message_size += len(value)
return message_size
@classmethod
def record_overhead(cls, magic):
assert magic in [0, 1], "Not supported magic"
if magic == 0:
return cls.RECORD_OVERHEAD_V0
else:
return cls.RECORD_OVERHEAD_V1
@classmethod
def estimate_size_in_bytes(cls, magic, compression_type, key, value):
""" Upper bound estimate of record size.
"""
assert magic in [0, 1], "Not supported magic"
# In case of compression we may need another overhead for inner msg
if compression_type:
return (
cls.LOG_OVERHEAD + cls.record_overhead(magic) +
cls.record_size(magic, key, value)
)
return cls.LOG_OVERHEAD + cls.record_size(magic, key, value)
class LegacyRecordMetadata(object):
__slots__ = ("_crc", "_size", "_timestamp", "_offset")
def __init__(self, offset, crc, size, timestamp):
self._offset = offset
self._crc = crc
self._size = size
self._timestamp = timestamp
@property
def offset(self):
return self._offset
@property
def crc(self):
return self._crc
@property
def size(self):
return self._size
@property
def timestamp(self):
return self._timestamp
def __repr__(self):
return (
"LegacyRecordMetadata(offset={!r}, crc={!r}, size={!r},"
" timestamp={!r})".format(
self._offset, self._crc, self._size, self._timestamp)
)
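# --- Illustrative sketch (not part of the original module) ---
# A minimal v1 (magic=1) round trip through the legacy builder and reader
# defined above; the key/value contents are made up, and the block is guarded
# so it only runs when this file is executed directly.
if __name__ == "__main__":
    builder = LegacyRecordBatchBuilder(magic=1, compression_type=0,
                                       batch_size=1024)
    meta = builder.append(0, timestamp=None, key=b"key", value=b"value")
    print("appended:", meta)
    batch_bytes = bytes(builder.build())

    batch = LegacyRecordBatch(batch_bytes, magic=1)
    assert batch.validate_crc()
    for record in batch:
        print(record.offset, record.timestamp, record.key, record.value)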

View File

@@ -0,0 +1,187 @@
# This class takes advantage of the fact that all formats v0, v1 and v2 of
# message storage have the same byte offsets for the Length and Magic fields.
# Let's look closely at the leading bytes that all versions share:
#
# V0 and V1 (Offset is MessageSet part, other bytes are Message ones):
# Offset => Int64
# BytesLength => Int32
# CRC => Int32
# Magic => Int8
# ...
#
# V2:
# BaseOffset => Int64
# Length => Int32
# PartitionLeaderEpoch => Int32
# Magic => Int8
# ...
#
# So we can iterate over batches just by knowing offsets of Length. Magic is
# used to construct the correct class for Batch itself.
from __future__ import division
import struct
from kafka.errors import CorruptRecordException
from kafka.record.abc import ABCRecords
from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder
from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder
class MemoryRecords(ABCRecords):
LENGTH_OFFSET = struct.calcsize(">q")
LOG_OVERHEAD = struct.calcsize(">qi")
MAGIC_OFFSET = struct.calcsize(">qii")
# Minimum space requirements for Record V0
MIN_SLICE = LOG_OVERHEAD + LegacyRecordBatch.RECORD_OVERHEAD_V0
__slots__ = ("_buffer", "_pos", "_next_slice", "_remaining_bytes")
def __init__(self, bytes_data):
self._buffer = bytes_data
self._pos = 0
# We keep one slice ahead so `has_next` will return very fast
self._next_slice = None
self._remaining_bytes = None
self._cache_next()
def size_in_bytes(self):
return len(self._buffer)
def valid_bytes(self):
# We need to read the whole buffer to get the valid_bytes.
# NOTE: in Fetcher we do the call after iteration, so should be fast
if self._remaining_bytes is None:
next_slice = self._next_slice
pos = self._pos
while self._remaining_bytes is None:
self._cache_next()
# Reset previous iterator position
self._next_slice = next_slice
self._pos = pos
return len(self._buffer) - self._remaining_bytes
# NOTE: we cache offsets here as kwargs for a bit more speed, as cPython
# will use LOAD_FAST opcode in this case
def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD):
buffer = self._buffer
buffer_len = len(buffer)
pos = self._pos
remaining = buffer_len - pos
if remaining < log_overhead:
# Will be re-checked in Fetcher for remaining bytes.
self._remaining_bytes = remaining
self._next_slice = None
return
length, = struct.unpack_from(
">i", buffer, pos + len_offset)
slice_end = pos + log_overhead + length
if slice_end > buffer_len:
# Will be re-checked in Fetcher for remaining bytes
self._remaining_bytes = remaining
self._next_slice = None
return
self._next_slice = memoryview(buffer)[pos: slice_end]
self._pos = slice_end
def has_next(self):
return self._next_slice is not None
# NOTE: same cache for LOAD_FAST as above
def next_batch(self, _min_slice=MIN_SLICE,
_magic_offset=MAGIC_OFFSET):
next_slice = self._next_slice
if next_slice is None:
return None
if len(next_slice) < _min_slice:
raise CorruptRecordException(
"Record size is less than the minimum record overhead "
"({})".format(_min_slice - self.LOG_OVERHEAD))
self._cache_next()
magic, = struct.unpack_from(">b", next_slice, _magic_offset)
if magic <= 1:
return LegacyRecordBatch(next_slice, magic)
else:
return DefaultRecordBatch(next_slice)
class MemoryRecordsBuilder(object):
__slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed",
"_bytes_written")
def __init__(self, magic, compression_type, batch_size):
assert magic in [0, 1, 2], "Not supported magic"
assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type"
if magic >= 2:
self._builder = DefaultRecordBatchBuilder(
magic=magic, compression_type=compression_type,
is_transactional=False, producer_id=-1, producer_epoch=-1,
base_sequence=-1, batch_size=batch_size)
else:
self._builder = LegacyRecordBatchBuilder(
magic=magic, compression_type=compression_type,
batch_size=batch_size)
self._batch_size = batch_size
self._buffer = None
self._next_offset = 0
self._closed = False
self._bytes_written = 0
def append(self, timestamp, key, value, headers=[]):
""" Append a message to the buffer.
Returns: RecordMetadata or None if unable to append
"""
if self._closed:
return None
offset = self._next_offset
metadata = self._builder.append(offset, timestamp, key, value, headers)
# Return of None means there's no space to add a new message
if metadata is None:
return None
self._next_offset += 1
return metadata
def close(self):
# This method may be called multiple times on the same batch
# i.e., on retries
# we need to make sure we only close it out once
# otherwise compressed messages may be double-compressed
# see Issue 718
if not self._closed:
self._bytes_written = self._builder.size()
self._buffer = bytes(self._builder.build())
self._builder = None
self._closed = True
def size_in_bytes(self):
if not self._closed:
return self._builder.size()
else:
return len(self._buffer)
def compression_rate(self):
assert self._closed
return self.size_in_bytes() / self._bytes_written
def is_full(self):
if self._closed:
return True
else:
return self._builder.size() >= self._batch_size
def next_offset(self):
return self._next_offset
def buffer(self):
assert self._closed
return self._buffer
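# --- Illustrative sketch (not part of the original module) ---
# Building a batch with MemoryRecordsBuilder and parsing it back with
# MemoryRecords; the payloads are made up, and the block is guarded so it only
# runs when this file is executed directly.
if __name__ == "__main__":
    builder = MemoryRecordsBuilder(magic=2, compression_type=0, batch_size=1024)
    for value in (b"first", b"second"):
        builder.append(timestamp=None, key=None, value=value)
    builder.close()

    records = MemoryRecords(builder.buffer())
    while records.has_next():
        batch = records.next_batch()
        for record in batch:
            print(record.offset, record.value)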

View File

@@ -0,0 +1,135 @@
import binascii
from kafka.record._crc32c import crc as crc32c_py
try:
from crc32c import crc32c as crc32c_c
except ImportError:
crc32c_c = None
def encode_varint(value, write):
""" Encode an integer to a varint presentation. See
https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints
on how those can be produced.
Arguments:
value (int): Value to encode
write (function): Called for each byte that needs to be written
Returns:
int: Number of bytes written
"""
value = (value << 1) ^ (value >> 63)
if value <= 0x7f: # 1 byte
write(value)
return 1
if value <= 0x3fff: # 2 bytes
write(0x80 | (value & 0x7f))
write(value >> 7)
return 2
if value <= 0x1fffff: # 3 bytes
write(0x80 | (value & 0x7f))
write(0x80 | ((value >> 7) & 0x7f))
write(value >> 14)
return 3
if value <= 0xfffffff: # 4 bytes
write(0x80 | (value & 0x7f))
write(0x80 | ((value >> 7) & 0x7f))
write(0x80 | ((value >> 14) & 0x7f))
write(value >> 21)
return 4
if value <= 0x7ffffffff: # 5 bytes
write(0x80 | (value & 0x7f))
write(0x80 | ((value >> 7) & 0x7f))
write(0x80 | ((value >> 14) & 0x7f))
write(0x80 | ((value >> 21) & 0x7f))
write(value >> 28)
return 5
else:
# Return to general algorithm
bits = value & 0x7f
value >>= 7
i = 0
while value:
write(0x80 | bits)
bits = value & 0x7f
value >>= 7
i += 1
write(bits)
# The loop above wrote i continuation bytes; the final write(bits) adds one more.
return i + 1
def size_of_varint(value):
""" Number of bytes needed to encode an integer in variable-length format.
"""
value = (value << 1) ^ (value >> 63)
if value <= 0x7f:
return 1
if value <= 0x3fff:
return 2
if value <= 0x1fffff:
return 3
if value <= 0xfffffff:
return 4
if value <= 0x7ffffffff:
return 5
if value <= 0x3ffffffffff:
return 6
if value <= 0x1ffffffffffff:
return 7
if value <= 0xffffffffffffff:
return 8
if value <= 0x7fffffffffffffff:
return 9
return 10
def decode_varint(buffer, pos=0):
""" Decode an integer from a varint presentation. See
https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints
on how those can be produced.
Arguments:
buffer (bytearray): buffer to read from.
pos (int): optional position to read from
Returns:
(int, int): Decoded int value and next read position
"""
result = buffer[pos]
if not (result & 0x81):
return (result >> 1), pos + 1
if not (result & 0x80):
return (result >> 1) ^ (~0), pos + 1
result &= 0x7f
pos += 1
shift = 7
while 1:
b = buffer[pos]
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
return ((result >> 1) ^ -(result & 1), pos)
shift += 7
if shift >= 64:
raise ValueError("Out of int64 range")
_crc32c = crc32c_py
if crc32c_c is not None:
_crc32c = crc32c_c
def calc_crc32c(memview, _crc32c=_crc32c):
""" Calculate CRC-32C (Castagnoli) checksum over a memoryview of data
"""
return _crc32c(memview)
def calc_crc32(memview):
""" Calculate simple CRC-32 checksum over a memoryview of data
"""
crc = binascii.crc32(memview) & 0xffffffff
return crc
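# --- Illustrative sketch (not part of the original module) ---
# Zig-zag varint round trip using the helpers above; the sample values are
# arbitrary, and the block is guarded so it only runs when this file is
# executed directly.
if __name__ == "__main__":
    for original in (0, 1, -1, 150, -300, 2 ** 40):
        buf = bytearray()
        encode_varint(original, buf.append)
        assert len(buf) == size_of_varint(original)
        decoded, next_pos = decode_varint(buf, 0)
        assert decoded == original and next_pos == len(buf)
        print(original, "->", binascii.hexlify(bytes(buf)))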

View File

@@ -0,0 +1,81 @@
from __future__ import absolute_import
import base64
import hashlib
import hmac
import uuid
from kafka.vendor import six
if six.PY2:
def xor_bytes(left, right):
return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
else:
def xor_bytes(left, right):
return bytes(lb ^ rb for lb, rb in zip(left, right))
class ScramClient:
MECHANISMS = {
'SCRAM-SHA-256': hashlib.sha256,
'SCRAM-SHA-512': hashlib.sha512
}
def __init__(self, user, password, mechanism):
self.nonce = str(uuid.uuid4()).replace('-', '')
self.auth_message = ''
self.salted_password = None
self.user = user
self.password = password.encode('utf-8')
self.hashfunc = self.MECHANISMS[mechanism]
self.hashname = ''.join(mechanism.lower().split('-')[1:3])
self.stored_key = None
self.client_key = None
self.client_signature = None
self.client_proof = None
self.server_key = None
self.server_signature = None
def first_message(self):
client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
self.auth_message += client_first_bare
return 'n,,' + client_first_bare
def process_server_first_message(self, server_first_message):
self.auth_message += ',' + server_first_message
params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
server_nonce = params['r']
if not server_nonce.startswith(self.nonce):
raise ValueError("Server nonce, did not start with client nonce!")
self.nonce = server_nonce
self.auth_message += ',c=biws,r=' + self.nonce
salt = base64.b64decode(params['s'].encode('utf-8'))
iterations = int(params['i'])
self.create_salted_password(salt, iterations)
self.client_key = self.hmac(self.salted_password, b'Client Key')
self.stored_key = self.hashfunc(self.client_key).digest()
self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8'))
self.client_proof = xor_bytes(self.client_key, self.client_signature)
self.server_key = self.hmac(self.salted_password, b'Server Key')
self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8'))
def hmac(self, key, msg):
return hmac.new(key, msg, digestmod=self.hashfunc).digest()
def create_salted_password(self, salt, iterations):
self.salted_password = hashlib.pbkdf2_hmac(
self.hashname, self.password, salt, iterations
)
def final_message(self):
return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8'))
def process_server_final_message(self, server_final_message):
params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
if self.server_signature != base64.b64decode(params['v'].encode('utf-8')):
raise ValueError("Server sent wrong signature!")

View File

@@ -0,0 +1,3 @@
from __future__ import absolute_import
from kafka.serializer.abstract import Serializer, Deserializer

View File

@@ -0,0 +1,31 @@
from __future__ import absolute_import
import abc
class Serializer(object):
__meta__ = abc.ABCMeta
def __init__(self, **config):
pass
@abc.abstractmethod
def serialize(self, topic, value):
pass
def close(self):
pass
class Deserializer(object):
__meta__ = abc.ABCMeta
def __init__(self, **config):
pass
@abc.abstractmethod
def deserialize(self, topic, bytes_):
pass
def close(self):
pass
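# --- Illustrative sketch (not part of the original module) ---
# A minimal pair of concrete implementations showing how the abstract
# interfaces above are meant to be subclassed.  The JSON encoding and the
# class names are made up for the example, not part of kafka-python.
import json


class JsonSerializer(Serializer):
    def serialize(self, topic, value):
        # `topic` is accepted for interface compatibility but not used here.
        return json.dumps(value).encode('utf-8')


class JsonDeserializer(Deserializer):
    def deserialize(self, topic, bytes_):
        if bytes_ is None:
            return None
        return json.loads(bytes_.decode('utf-8'))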

View File

@@ -0,0 +1,87 @@
""" Other useful structs """
from __future__ import absolute_import
from collections import namedtuple
"""A topic and partition tuple
Keyword Arguments:
topic (str): A topic name
partition (int): A partition id
"""
TopicPartition = namedtuple("TopicPartition",
["topic", "partition"])
"""A Kafka broker metadata used by admin tools.
Keyword Arguments:
nodeId (int): The Kafka broker id.
host (str): The Kafka broker hostname.
port (int): The Kafka broker port.
rack (str): The rack of the broker, which is used in rack-aware
partition assignment for fault tolerance.
Examples: `RACK1`, `us-east-1d`. Default: None
"""
BrokerMetadata = namedtuple("BrokerMetadata",
["nodeId", "host", "port", "rack"])
"""A topic partition metadata describing the state in the MetadataResponse.
Keyword Arguments:
topic (str): The topic name of the partition this metadata relates to.
partition (int): The id of the partition this metadata relates to.
leader (int): The id of the broker that is the leader for the partition.
replicas (List[int]): The ids of all brokers that contain replicas of the
partition.
isr (List[int]): The ids of all brokers that contain in-sync replicas of
the partition.
error (KafkaError): A KafkaError object associated with the request for
this partition metadata.
"""
PartitionMetadata = namedtuple("PartitionMetadata",
["topic", "partition", "leader", "replicas", "isr", "error"])
"""The Kafka offset commit API
The Kafka offset commit API allows users to provide additional metadata
(in the form of a string) when an offset is committed. This can be useful
(for example) to store information about which node made the commit,
what time the commit was made, etc.
Keyword Arguments:
offset (int): The offset to be committed
metadata (str): Non-null metadata
"""
OffsetAndMetadata = namedtuple("OffsetAndMetadata",
# TODO add leaderEpoch: OffsetAndMetadata(offset, leaderEpoch, metadata)
["offset", "metadata"])
"""An offset and timestamp tuple
Keyword Arguments:
offset (int): An offset
timestamp (int): The timestamp associated with the offset
"""
OffsetAndTimestamp = namedtuple("OffsetAndTimestamp",
["offset", "timestamp"])
MemberInformation = namedtuple("MemberInformation",
["member_id", "client_id", "client_host", "member_metadata", "member_assignment"])
GroupInformation = namedtuple("GroupInformation",
["error_code", "group", "state", "protocol_type", "protocol", "members", "authorized_operations"])
"""Define retry policy for async producer
Keyword Arguments:
limit (int): Number of retries. limit >= 0, 0 means no retries
backoff_ms (int): Milliseconds to back off between retries.
retry_on_timeouts (bool): Whether requests that time out should be retried.
"""
RetryOptions = namedtuple("RetryOptions",
["limit", "backoff_ms", "retry_on_timeouts"])

View File

@@ -0,0 +1,66 @@
from __future__ import absolute_import
import binascii
import weakref
from kafka.vendor import six
if six.PY3:
MAX_INT = 2 ** 31
TO_SIGNED = 2 ** 32
def crc32(data):
crc = binascii.crc32(data)
# py2 and py3 behave a little differently
# CRC is encoded as a signed int in kafka protocol
# so we'll convert the py3 unsigned result to signed
if crc >= MAX_INT:
crc -= TO_SIGNED
return crc
else:
from binascii import crc32
class WeakMethod(object):
"""
Callable that weakly references a method and the object it is bound to. It
is based on https://stackoverflow.com/a/24287465.
Arguments:
object_dot_method: A bound instance method (i.e. 'object.method').
"""
def __init__(self, object_dot_method):
try:
self.target = weakref.ref(object_dot_method.__self__)
except AttributeError:
self.target = weakref.ref(object_dot_method.im_self)
self._target_id = id(self.target())
try:
self.method = weakref.ref(object_dot_method.__func__)
except AttributeError:
self.method = weakref.ref(object_dot_method.im_func)
self._method_id = id(self.method())
def __call__(self, *args, **kwargs):
"""
Calls the method on target with args and kwargs.
"""
return self.method()(self.target(), *args, **kwargs)
def __hash__(self):
return hash(self.target) ^ hash(self.method)
def __eq__(self, other):
if not isinstance(other, WeakMethod):
return False
return self._target_id == other._target_id and self._method_id == other._method_id
class Dict(dict):
"""Utility class to support passing weakrefs to dicts
See: https://docs.python.org/2/library/weakref.html
"""
pass
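# --- Illustrative sketch (not part of the original module) ---
# WeakMethod lets a long-lived registry hold a callback to a bound method
# without keeping its owner alive (avoiding reference cycles).  The Greeter
# class below is made up for the demonstration.
if __name__ == "__main__":
    class Greeter(object):
        def hello(self, name):
            return 'hello ' + name

    g = Greeter()
    cb = WeakMethod(g.hello)
    print(cb('world'))           # -> 'hello world' while `g` is alive
    del g                        # on CPython the instance is collected here
    print(cb.target() is None)   # -> True; calling cb now would fail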

View File

@@ -0,0 +1,841 @@
# pylint: skip-file
# vendored from:
# https://bitbucket.org/stoneleaf/enum34/src/58c4cd7174ca35f164304c8a6f0a4d47b779c2a7/enum/__init__.py?at=1.1.6
"""Python Enumerations"""
import sys as _sys
__all__ = ['Enum', 'IntEnum', 'unique']
version = 1, 1, 6
pyver = float('%s.%s' % _sys.version_info[:2])
try:
any
except NameError:
def any(iterable):
for element in iterable:
if element:
return True
return False
try:
from collections import OrderedDict
except ImportError:
OrderedDict = None
try:
basestring
except NameError:
# In Python 2 basestring is the ancestor of both str and unicode
# in Python 3 it's just str, but was missing in 3.1
basestring = str
try:
unicode
except NameError:
# In Python 3 unicode no longer exists (it's just str)
unicode = str
class _RouteClassAttributeToGetattr(object):
"""Route attribute access on a class to __getattr__.
This is a descriptor, used to define attributes that act differently when
accessed through an instance and through a class. Instance access remains
normal, but access to an attribute through a class will be routed to the
class's __getattr__ method; this is done by raising AttributeError.
"""
def __init__(self, fget=None):
self.fget = fget
def __get__(self, instance, ownerclass=None):
if instance is None:
raise AttributeError()
return self.fget(instance)
def __set__(self, instance, value):
raise AttributeError("can't set attribute")
def __delete__(self, instance):
raise AttributeError("can't delete attribute")
def _is_descriptor(obj):
"""Returns True if obj is a descriptor, False otherwise."""
return (
hasattr(obj, '__get__') or
hasattr(obj, '__set__') or
hasattr(obj, '__delete__'))
def _is_dunder(name):
"""Returns True if a __dunder__ name, False otherwise."""
return (name[:2] == name[-2:] == '__' and
name[2:3] != '_' and
name[-3:-2] != '_' and
len(name) > 4)
def _is_sunder(name):
"""Returns True if a _sunder_ name, False otherwise."""
return (name[0] == name[-1] == '_' and
name[1:2] != '_' and
name[-2:-1] != '_' and
len(name) > 2)
def _make_class_unpicklable(cls):
"""Make the given class un-picklable."""
def _break_on_call_reduce(self, protocol=None):
raise TypeError('%r cannot be pickled' % self)
cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = '<unknown>'
class _EnumDict(dict):
"""Track enum member order and ensure member names are not reused.
EnumMeta will use the names found in self._member_names as the
enumeration member names.
"""
def __init__(self):
super(_EnumDict, self).__init__()
self._member_names = []
def __setitem__(self, key, value):
"""Changes anything not dundered or not a descriptor.
If a descriptor is added with the same name as an enum member, the name
is removed from _member_names (this may leave a hole in the numerical
sequence of values).
If an enum member name is used twice, an error is raised; duplicate
values are not checked for.
Single underscore (sunder) names are reserved.
Note: in 3.x __order__ is simply discarded as a not necessary piece
leftover from 2.x
"""
if pyver >= 3.0 and key in ('_order_', '__order__'):
return
elif key == '__order__':
key = '_order_'
if _is_sunder(key):
if key != '_order_':
raise ValueError('_names_ are reserved for future Enum use')
elif _is_dunder(key):
pass
elif key in self._member_names:
# descriptor overwriting an enum?
raise TypeError('Attempted to reuse key: %r' % key)
elif not _is_descriptor(value):
if key in self:
# enum overwriting a descriptor?
raise TypeError('Key already defined as: %r' % self[key])
self._member_names.append(key)
super(_EnumDict, self).__setitem__(key, value)
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course until
# EnumMeta finishes running the first time the Enum class doesn't exist. This
# is also why there are checks in EnumMeta like `if Enum is not None`
Enum = None
class EnumMeta(type):
"""Metaclass for Enum"""
@classmethod
def __prepare__(metacls, cls, bases):
return _EnumDict()
def __new__(metacls, cls, bases, classdict):
# an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting
# class will fail).
if type(classdict) is dict:
original_dict = classdict
classdict = _EnumDict()
for k, v in original_dict.items():
classdict[k] = v
member_type, first_enum = metacls._get_mixins_(bases)
__new__, save_new, use_args = metacls._find_new_(classdict, member_type,
first_enum)
# save enum items into separate mapping so they don't get baked into
# the new class
members = dict((k, classdict[k]) for k in classdict._member_names)
for name in classdict._member_names:
del classdict[name]
# py2 support for definition order
_order_ = classdict.get('_order_')
if _order_ is None:
if pyver < 3.0:
try:
_order_ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])]
except TypeError:
_order_ = [name for name in sorted(members.keys())]
else:
_order_ = classdict._member_names
else:
del classdict['_order_']
if pyver < 3.0:
_order_ = _order_.replace(',', ' ').split()
aliases = [name for name in members if name not in _order_]
_order_ += aliases
# check for illegal enum names (any others?)
invalid_names = set(members) & set(['mro'])
if invalid_names:
raise ValueError('Invalid enum member name(s): %s' % (
', '.join(invalid_names), ))
# save attributes from super classes so we know if we can take
# the shortcut of storing members in the class dict
base_attributes = set([a for b in bases for a in b.__dict__])
# create our new Enum type
enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict)
enum_class._member_names_ = [] # names in random order
if OrderedDict is not None:
enum_class._member_map_ = OrderedDict()
else:
enum_class._member_map_ = {} # name->value map
enum_class._member_type_ = member_type
# Reverse value->name map for hashable values.
enum_class._value2member_map_ = {}
# instantiate them, checking for duplicates as we go
# we instantiate first instead of checking for duplicates first in case
# a custom __new__ is doing something funky with the values -- such as
# auto-numbering ;)
if __new__ is None:
__new__ = enum_class.__new__
for member_name in _order_:
value = members[member_name]
if not isinstance(value, tuple):
args = (value, )
else:
args = value
if member_type is tuple: # special case for tuple enums
args = (args, ) # wrap it one more time
if not use_args or not args:
enum_member = __new__(enum_class)
if not hasattr(enum_member, '_value_'):
enum_member._value_ = value
else:
enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, '_value_'):
enum_member._value_ = member_type(*args)
value = enum_member._value_
enum_member._name_ = member_name
enum_member.__objclass__ = enum_class
enum_member.__init__(*args)
# If another member with the same value was already defined, the
# new member becomes an alias to the existing one.
for name, canonical_member in enum_class._member_map_.items():
if canonical_member.value == enum_member._value_:
enum_member = canonical_member
break
else:
# Aliases don't appear in member names (only in __members__).
enum_class._member_names_.append(member_name)
# performance boost for any member that would not shadow
# a DynamicClassAttribute (aka _RouteClassAttributeToGetattr)
if member_name not in base_attributes:
setattr(enum_class, member_name, enum_member)
# now add to _member_map_
enum_class._member_map_[member_name] = enum_member
try:
# This may fail if value is not hashable. We can't add the value
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_[value] = enum_member
except TypeError:
pass
# If a custom type is mixed into the Enum, and it does not know how
# to pickle itself, pickle.dumps will succeed but pickle.loads will
# fail. Rather than have the error show up later and possibly far
# from the source, sabotage the pickle protocol for this class so
# that pickle.dumps also fails.
#
# However, if the new class implements its own __reduce_ex__, do not
# sabotage -- it's on them to make sure it works correctly. We use
# __reduce_ex__ instead of any of the others as it is preferred by
# pickle over __reduce__, and it handles all pickle protocols.
unpicklable = False
if '__reduce_ex__' not in classdict:
if member_type is not object:
methods = ('__getnewargs_ex__', '__getnewargs__',
'__reduce_ex__', '__reduce__')
if not any(m in member_type.__dict__ for m in methods):
_make_class_unpicklable(enum_class)
unpicklable = True
# double check that repr and friends are not the mixin's or various
# things break (such as pickle)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
class_method = getattr(enum_class, name)
obj_method = getattr(member_type, name, None)
enum_method = getattr(first_enum, name, None)
if name not in classdict and class_method is not enum_method:
if name == '__reduce_ex__' and unpicklable:
continue
setattr(enum_class, name, enum_method)
# method resolution and ints are not playing nice
# Python versions less than 2.6 use __cmp__
if pyver < 2.6:
if issubclass(enum_class, int):
setattr(enum_class, '__cmp__', getattr(int, '__cmp__'))
elif pyver < 3.0:
if issubclass(enum_class, int):
for method in (
'__le__',
'__lt__',
'__gt__',
'__ge__',
'__eq__',
'__ne__',
'__hash__',
):
setattr(enum_class, method, getattr(int, method))
# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
# if the user defined their own __new__, save it before it gets
# clobbered in case they subclass later
if save_new:
setattr(enum_class, '__member_new__', enum_class.__dict__['__new__'])
setattr(enum_class, '__new__', Enum.__dict__['__new__'])
return enum_class
def __bool__(cls):
"""
classes/types should always be True.
"""
return True
def __call__(cls, value, names=None, module=None, type=None, start=1):
"""Either returns an existing member, or creates a new enum class.
This method is used both when an enum class is given a value to match
to an enumeration member (i.e. Color(3)) and for the functional API
(i.e. Color = Enum('Color', names='red green blue')).
When used for the functional API: `module`, if set, will be stored in
the new class' __module__ attribute; `type`, if set, will be mixed in
as the first base class.
Note: if `module` is not set this routine will attempt to discover the
calling module by walking the frame stack; if this is unsuccessful
the resulting class will not be pickleable.
"""
if names is None: # simple value lookup
return cls.__new__(cls, value)
# otherwise, functional API: we're creating a new Enum type
return cls._create_(value, names, module=module, type=type, start=start)
def __contains__(cls, member):
return isinstance(member, cls) and member.name in cls._member_map_
def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
# (see issue19025).
if attr in cls._member_map_:
raise AttributeError(
"%s: cannot delete Enum member." % cls.__name__)
super(EnumMeta, cls).__delattr__(attr)
def __dir__(self):
return (['__class__', '__doc__', '__members__', '__module__'] +
self._member_names_)
@property
def __members__(cls):
"""Returns a mapping of member name->value.
This mapping lists all enum members, including aliases. Note that this
is a copy of the internal mapping.
"""
return cls._member_map_.copy()
def __getattr__(cls, name):
"""Return the enum member matching `name`
We use __getattr__ instead of descriptors or inserting into the enum
class' __dict__ in order to support `name` and `value` being both
properties for enum members (which live in the class' __dict__) and
enum members themselves.
"""
if _is_dunder(name):
raise AttributeError(name)
try:
return cls._member_map_[name]
except KeyError:
raise AttributeError(name)
def __getitem__(cls, name):
return cls._member_map_[name]
def __iter__(cls):
return (cls._member_map_[name] for name in cls._member_names_)
def __reversed__(cls):
return (cls._member_map_[name] for name in reversed(cls._member_names_))
def __len__(cls):
return len(cls._member_names_)
__nonzero__ = __bool__
def __repr__(cls):
return "<enum %r>" % cls.__name__
def __setattr__(cls, name, value):
"""Block attempts to reassign Enum members.
A simple assignment to the class namespace only changes one of the
several possible ways to get an Enum member from the Enum class,
resulting in an inconsistent Enumeration.
"""
member_map = cls.__dict__.get('_member_map_', {})
if name in member_map:
raise AttributeError('Cannot reassign members.')
super(EnumMeta, cls).__setattr__(name, value)
def _create_(cls, class_name, names=None, module=None, type=None, start=1):
"""Convenience method to create a new Enum class.
`names` can be:
* A string containing member names, separated either with spaces or
commas. Values are auto-numbered from 1.
* An iterable of member names. Values are auto-numbered from 1.
* An iterable of (member name, value) pairs.
* A mapping of member name -> value.
"""
if pyver < 3.0:
# if class_name is unicode, attempt a conversion to ASCII
if isinstance(class_name, unicode):
try:
class_name = class_name.encode('ascii')
except UnicodeEncodeError:
raise TypeError('%r is not representable in ASCII' % class_name)
metacls = cls.__class__
if type is None:
bases = (cls, )
else:
bases = (type, cls)
classdict = metacls.__prepare__(class_name, bases)
_order_ = []
# special processing needed for names?
if isinstance(names, basestring):
names = names.replace(',', ' ').split()
if isinstance(names, (tuple, list)) and isinstance(names[0], basestring):
names = [(e, i+start) for (i, e) in enumerate(names)]
# Here, names is either an iterable of (name, value) or a mapping.
item = None # in case names is empty
for item in names:
if isinstance(item, basestring):
member_name, member_value = item, names[item]
else:
member_name, member_value = item
classdict[member_name] = member_value
_order_.append(member_name)
# only set _order_ in classdict if name/value was not from a mapping
if not isinstance(item, basestring):
classdict['_order_'] = ' '.join(_order_)
enum_class = metacls.__new__(metacls, class_name, bases, classdict)
# TODO: replace the frame hack if a blessed way to know the calling
# module is ever developed
if module is None:
try:
module = _sys._getframe(2).f_globals['__name__']
except (AttributeError, ValueError):
pass
if module is None:
_make_class_unpicklable(enum_class)
else:
enum_class.__module__ = module
return enum_class
@staticmethod
def _get_mixins_(bases):
"""Returns the type for creating enum members, and the first inherited
enum class.
bases: the tuple of bases that was given to __new__
"""
if not bases or Enum is None:
return object, Enum
# double check that we are not subclassing a class with existing
# enumeration members; while we're at it, see if any other data
# type has been mixed in so we can use the correct __new__
member_type = first_enum = None
for base in bases:
if (base is not Enum and
issubclass(base, Enum) and
base._member_names_):
raise TypeError("Cannot extend enumerations")
# base is now the last base in bases
if not issubclass(base, Enum):
raise TypeError("new enumerations must be created as "
"`ClassName([mixin_type,] enum_type)`")
# get correct mix-in type (either mix-in type of Enum subclass, or
# first base if last base is Enum)
if not issubclass(bases[0], Enum):
member_type = bases[0] # first data type
first_enum = bases[-1] # enum type
else:
for base in bases[0].__mro__:
# most common: (IntEnum, int, Enum, object)
# possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
# <class 'int'>, <Enum 'Enum'>,
# <class 'object'>)
if issubclass(base, Enum):
if first_enum is None:
first_enum = base
else:
if member_type is None:
member_type = base
return member_type, first_enum
if pyver < 3.0:
@staticmethod
def _find_new_(classdict, member_type, first_enum):
"""Returns the __new__ to be used for creating the enum members.
classdict: the class dictionary given to __new__
member_type: the data type whose __new__ will be used by default
first_enum: enumeration to check for an overriding __new__
"""
            # now find the correct __new__, checking to see if one was defined
# by the user; also check earlier enum classes in case a __new__ was
# saved as __member_new__
__new__ = classdict.get('__new__', None)
if __new__:
return None, True, True # __new__, save_new, use_args
N__new__ = getattr(None, '__new__')
O__new__ = getattr(object, '__new__')
if Enum is None:
E__new__ = N__new__
else:
E__new__ = Enum.__dict__['__new__']
# check all possibles for __member_new__ before falling back to
# __new__
for method in ('__member_new__', '__new__'):
for possible in (member_type, first_enum):
try:
target = possible.__dict__[method]
except (AttributeError, KeyError):
target = getattr(possible, method, None)
if target not in [
None,
N__new__,
O__new__,
E__new__,
]:
if method == '__member_new__':
classdict['__new__'] = target
return None, False, True
if isinstance(target, staticmethod):
target = target.__get__(member_type)
__new__ = target
break
if __new__ is not None:
break
else:
__new__ = object.__new__
# if a non-object.__new__ is used then whatever value/tuple was
# assigned to the enum member name will be passed to __new__ and to the
# new enum member's __init__
if __new__ is object.__new__:
use_args = False
else:
use_args = True
return __new__, False, use_args
else:
@staticmethod
def _find_new_(classdict, member_type, first_enum):
"""Returns the __new__ to be used for creating the enum members.
classdict: the class dictionary given to __new__
member_type: the data type whose __new__ will be used by default
first_enum: enumeration to check for an overriding __new__
"""
            # now find the correct __new__, checking to see if one was defined
# by the user; also check earlier enum classes in case a __new__ was
# saved as __member_new__
__new__ = classdict.get('__new__', None)
# should __new__ be saved as __member_new__ later?
save_new = __new__ is not None
if __new__ is None:
# check all possibles for __member_new__ before falling back to
# __new__
for method in ('__member_new__', '__new__'):
for possible in (member_type, first_enum):
target = getattr(possible, method, None)
if target not in (
None,
None.__new__,
object.__new__,
Enum.__new__,
):
__new__ = target
break
if __new__ is not None:
break
else:
__new__ = object.__new__
# if a non-object.__new__ is used then whatever value/tuple was
# assigned to the enum member name will be passed to __new__ and to the
# new enum member's __init__
if __new__ is object.__new__:
use_args = False
else:
use_args = True
return __new__, save_new, use_args
########################################################
# In order to support Python 2 and 3 with a single
# codebase we have to create the Enum methods separately
# and then use the `type(name, bases, dict)` method to
# create the class.
########################################################
temp_enum_dict = {}
temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n"
def __new__(cls, value):
# all enum instances are actually created during class construction
# without calling this method; this method is called by the metaclass'
# __call__ (i.e. Color(3) ), and by pickle
if type(value) is cls:
# For lookups like Color(Color.red)
value = value.value
#return value
# by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values)
try:
if value in cls._value2member_map_:
return cls._value2member_map_[value]
except TypeError:
# not there, now do long search -- O(n) behavior
for member in cls._member_map_.values():
if member.value == value:
return member
raise ValueError("%s is not a valid %s" % (value, cls.__name__))
temp_enum_dict['__new__'] = __new__
del __new__
def __repr__(self):
return "<%s.%s: %r>" % (
self.__class__.__name__, self._name_, self._value_)
temp_enum_dict['__repr__'] = __repr__
del __repr__
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self._name_)
temp_enum_dict['__str__'] = __str__
del __str__
if pyver >= 3.0:
def __dir__(self):
added_behavior = [
m
for cls in self.__class__.mro()
for m in cls.__dict__
if m[0] != '_' and m not in self._member_map_
]
return (['__class__', '__doc__', '__module__', ] + added_behavior)
temp_enum_dict['__dir__'] = __dir__
del __dir__
def __format__(self, format_spec):
# mixed-in Enums should use the mixed-in type's __format__, otherwise
# we can get strange results with the Enum name showing up instead of
# the value
# pure Enum branch
if self._member_type_ is object:
cls = str
val = str(self)
# mix-in branch
else:
cls = self._member_type_
val = self.value
return cls.__format__(val, format_spec)
temp_enum_dict['__format__'] = __format__
del __format__
####################################
# Python versions before 2.6 use __cmp__
if pyver < 2.6:
def __cmp__(self, other):
if type(other) is self.__class__:
if self is other:
return 0
return -1
return NotImplemented
raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__))
temp_enum_dict['__cmp__'] = __cmp__
del __cmp__
else:
def __le__(self, other):
raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__))
temp_enum_dict['__le__'] = __le__
del __le__
def __lt__(self, other):
raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__))
temp_enum_dict['__lt__'] = __lt__
del __lt__
def __ge__(self, other):
raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__))
temp_enum_dict['__ge__'] = __ge__
del __ge__
def __gt__(self, other):
raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__))
temp_enum_dict['__gt__'] = __gt__
del __gt__
def __eq__(self, other):
if type(other) is self.__class__:
return self is other
return NotImplemented
temp_enum_dict['__eq__'] = __eq__
del __eq__
def __ne__(self, other):
if type(other) is self.__class__:
return self is not other
return NotImplemented
temp_enum_dict['__ne__'] = __ne__
del __ne__
def __hash__(self):
return hash(self._name_)
temp_enum_dict['__hash__'] = __hash__
del __hash__
def __reduce_ex__(self, proto):
return self.__class__, (self._value_, )
temp_enum_dict['__reduce_ex__'] = __reduce_ex__
del __reduce_ex__
# _RouteClassAttributeToGetattr is used to provide access to the `name`
# and `value` properties of enum members while keeping some measure of
# protection from modification, while still allowing for an enumeration
# to have members named `name` and `value`. This works because enumeration
# members are not set directly on the enum class -- __getattr__ is
# used to look them up.
@_RouteClassAttributeToGetattr
def name(self):
return self._name_
temp_enum_dict['name'] = name
del name
@_RouteClassAttributeToGetattr
def value(self):
return self._value_
temp_enum_dict['value'] = value
del value
@classmethod
def _convert(cls, name, module, filter, source=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
# convert all constants from source (or module) that pass filter() to
# a new Enum called name, and export the enum and its members back to
# module;
# also, replace the __reduce_ex__ method so unpickling works in
# previous Python versions
module_globals = vars(_sys.modules[module])
if source:
source = vars(source)
else:
source = module_globals
members = dict((name, value) for name, value in source.items() if filter(name))
cls = cls(name, members, module=module)
cls.__reduce_ex__ = _reduce_ex_by_name
module_globals.update(cls.__members__)
module_globals[name] = cls
return cls
temp_enum_dict['_convert'] = _convert
del _convert
Enum = EnumMeta('Enum', (object, ), temp_enum_dict)
del temp_enum_dict
# Enum has now been created
###########################
class IntEnum(int, Enum):
"""Enum where members are also (and must be) ints"""
def _reduce_ex_by_name(self, proto):
return self.name
def unique(enumeration):
"""Class decorator that ensures only unique members exist in an enumeration."""
duplicates = []
for name, member in enumeration.__members__.items():
if name != member.name:
duplicates.append((name, member.name))
if duplicates:
duplicate_names = ', '.join(
["%s -> %s" % (alias, name) for (alias, name) in duplicates]
)
raise ValueError('duplicate names found in %r: %s' %
(enumeration, duplicate_names)
)
return enumeration
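# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the vendored enum34 module: a quick check of
# how the backported Enum/IntEnum/unique behave.  The names Suit, HttpStatus
# and Priority below are made up for the demo; kafka-python itself only relies
# on IntEnum.
if __name__ == '__main__':
    @unique
    class Suit(Enum):
        SPADES = 1
        HEARTS = 2
        DIAMONDS = 3
        CLUBS = 4

    assert Suit(2) is Suit.HEARTS        # by-value lookup via EnumMeta.__call__
    assert Suit['CLUBS'] is Suit.CLUBS   # by-name lookup via __getitem__
    assert Suit.SPADES.name == 'SPADES' and Suit.SPADES.value == 1

    # Functional API backed by EnumMeta._create_: names are auto-numbered from 1.
    HttpStatus = Enum('HttpStatus', 'OK NOT_FOUND ERROR')
    assert HttpStatus.NOT_FOUND.value == 2

    # IntEnum members behave like plain ints for comparison and sorting.
    class Priority(IntEnum):
        LOW = 1
        HIGH = 2
    assert Priority.HIGH == 2
    assert sorted([Priority.HIGH, Priority.LOW])[0] is Priority.LOW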

View File

@@ -0,0 +1,637 @@
# pylint: skip-file
# vendored from https://github.com/berkerpeksag/selectors34
# at commit ff61b82168d2cc9c4922ae08e2a8bf94aab61ea2 (unreleased, ~1.2)
#
# Original author: Charles-Francois Natali (c.f.natali[at]gmail.com)
# Maintainer: Berker Peksag (berker.peksag[at]gmail.com)
# Also see https://pypi.python.org/pypi/selectors34
"""Selectors module.
This module allows high-level and efficient I/O multiplexing, built upon the
`select` module primitives.
The following code is adapted from trollius.selectors.
"""
from __future__ import absolute_import
from abc import ABCMeta, abstractmethod
from collections import namedtuple
try:
    from collections.abc import Mapping  # moved here in Python 3.3, removed from collections in 3.10
except ImportError:  # Python 2.7
    from collections import Mapping
from errno import EINTR
import math
import select
import sys
from kafka.vendor import six
def _wrap_error(exc, mapping, key):
if key not in mapping:
return
new_err_cls = mapping[key]
new_err = new_err_cls(*exc.args)
# raise a new exception with the original traceback
if hasattr(exc, '__traceback__'):
traceback = exc.__traceback__
else:
traceback = sys.exc_info()[2]
six.reraise(new_err_cls, new_err, traceback)
# generic events, that must be mapped to implementation-specific ones
EVENT_READ = (1 << 0)
EVENT_WRITE = (1 << 1)
def _fileobj_to_fd(fileobj):
"""Return a file descriptor from a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
corresponding file descriptor
Raises:
ValueError if the object is invalid
"""
if isinstance(fileobj, six.integer_types):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: "
"{0!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {0}".format(fd))
return fd
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
"""Object used to associate a file object to its backing file descriptor,
selected event mask and attached data."""
class _SelectorMapping(Mapping):
"""Mapping of file objects to selector keys."""
def __init__(self, selector):
self._selector = selector
def __len__(self):
return len(self._selector._fd_to_key)
def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
# Using six.add_metaclass() decorator instead of six.with_metaclass() because
# the latter leaks temporary_class to garbage with gc disabled
@six.add_metaclass(ABCMeta)
class BaseSelector(object):
"""Selector abstract base class.
A selector supports registering file objects to be monitored for specific
I/O events.
A file object is a file descriptor or any object with a `fileno()` method.
An arbitrary object can be attached to the file object, which can be used
for example to store context information, a callback, etc.
A selector can use various implementations (select(), poll(), epoll()...)
depending on the platform. The default `Selector` class uses the most
efficient implementation on the current platform.
"""
@abstractmethod
def register(self, fileobj, events, data=None):
"""Register a file object.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
ValueError if events is invalid
KeyError if fileobj is already registered
OSError if fileobj is closed or otherwise is unacceptable to
the underlying system call (if a system call is made)
Note:
OSError may or may not be raised
"""
raise NotImplementedError
@abstractmethod
def unregister(self, fileobj):
"""Unregister a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
SelectorKey instance
Raises:
KeyError if fileobj is not registered
Note:
If fileobj is registered but has since been closed this does
*not* raise OSError (even if the wrapped syscall does)
"""
raise NotImplementedError
def modify(self, fileobj, events, data=None):
"""Change a registered file object monitored events or attached data.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
Anything that unregister() or register() raises
"""
self.unregister(fileobj)
return self.register(fileobj, events, data)
@abstractmethod
def select(self, timeout=None):
"""Perform the actual selection, until some monitored file objects are
ready or a timeout expires.
Parameters:
timeout -- if timeout > 0, this specifies the maximum wait time, in
seconds
if timeout <= 0, the select() call won't block, and will
report the currently ready file objects
if timeout is None, select() will block until a monitored
file object becomes ready
Returns:
list of (key, events) for ready file objects
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
"""
raise NotImplementedError
def close(self):
"""Close the selector.
This must be called to make sure that any underlying resource is freed.
"""
pass
def get_key(self, fileobj):
"""Return the key associated to a registered file object.
Returns:
SelectorKey for this file object
"""
mapping = self.get_map()
if mapping is None:
raise RuntimeError('Selector is closed')
try:
return mapping[fileobj]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
@abstractmethod
def get_map(self):
"""Return a mapping of file objects to selector keys."""
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class _BaseSelectorImpl(BaseSelector):
"""Base selector implementation."""
def __init__(self):
# this maps file descriptors to keys
self._fd_to_key = {}
# read-only mapping returned by get_map()
self._map = _SelectorMapping(self)
def _fileobj_lookup(self, fileobj):
"""Return a file descriptor from a file object.
This wraps _fileobj_to_fd() to do an exhaustive search in case
the object is invalid but we still have it in our map. This
is used by unregister() so we can unregister an object that
was previously registered even if it is closed. It is also
used by _SelectorMapping.
"""
try:
return _fileobj_to_fd(fileobj)
except ValueError:
# Do an exhaustive search.
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
return key.fd
# Raise ValueError after all.
raise
def register(self, fileobj, events, data=None):
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
raise ValueError("Invalid events: {0!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{0!r} (FD {1}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
return key
def unregister(self, fileobj):
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
# TODO: Subclasses can probably optimize this even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
elif data != key.data:
# Use a shortcut to update the data.
key = key._replace(data=data)
self._fd_to_key[key.fd] = key
return key
def close(self):
self._fd_to_key.clear()
self._map = None
def get_map(self):
return self._map
def _key_from_fd(self, fd):
"""Return the key associated to a given file descriptor.
Parameters:
fd -- file descriptor
Returns:
corresponding key, or None if not found
"""
try:
return self._fd_to_key[fd]
except KeyError:
return None
class SelectSelector(_BaseSelectorImpl):
"""Select-based selector."""
def __init__(self):
super(SelectSelector, self).__init__()
self._readers = set()
self._writers = set()
def register(self, fileobj, events, data=None):
key = super(SelectSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
self._writers.add(key.fd)
return key
def unregister(self, fileobj):
key = super(SelectSelector, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
return key
if sys.platform == 'win32':
def _select(self, r, w, _, timeout=None):
r, w, x = select.select(r, w, w, timeout)
return r, w + x, []
else:
_select = staticmethod(select.select)
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
ready = []
try:
r, w, _ = self._select(self._readers, self._writers, [], timeout)
except select.error as exc:
if exc.args[0] == EINTR:
return ready
else:
raise
r = set(r)
w = set(w)
for fd in r | w:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'poll'):
class PollSelector(_BaseSelectorImpl):
"""Poll-based selector."""
def __init__(self):
super(PollSelector, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super(PollSelector, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
self._poll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super(PollSelector, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = int(math.ceil(timeout * 1e3))
ready = []
try:
fd_event_list = self._poll.poll(timeout)
except select.error as exc:
if exc.args[0] == EINTR:
return ready
else:
raise
for fd, event in fd_event_list:
events = 0
if event & ~select.POLLIN:
events |= EVENT_WRITE
if event & ~select.POLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'epoll'):
class EpollSelector(_BaseSelectorImpl):
"""Epoll-based selector."""
def __init__(self):
super(EpollSelector, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super(EpollSelector, self).register(fileobj, events, data)
epoll_events = 0
if events & EVENT_READ:
epoll_events |= select.EPOLLIN
if events & EVENT_WRITE:
epoll_events |= select.EPOLLOUT
self._epoll.register(key.fd, epoll_events)
return key
def unregister(self, fileobj):
key = super(EpollSelector, self).unregister(fileobj)
try:
self._epoll.unregister(key.fd)
except IOError:
# This can happen if the FD was closed since it
# was registered.
pass
return key
def select(self, timeout=None):
if timeout is None:
timeout = -1
elif timeout <= 0:
timeout = 0
else:
# epoll_wait() has a resolution of 1 millisecond, round away
# from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3) * 1e-3
# epoll_wait() expects `maxevents` to be greater than zero;
# we want to make sure that `select()` can be called when no
# FD is registered.
max_ev = max(len(self._fd_to_key), 1)
ready = []
try:
fd_event_list = self._epoll.poll(timeout, max_ev)
except IOError as exc:
if exc.errno == EINTR:
return ready
else:
raise
for fd, event in fd_event_list:
events = 0
if event & ~select.EPOLLIN:
events |= EVENT_WRITE
if event & ~select.EPOLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._epoll.close()
super(EpollSelector, self).close()
if hasattr(select, 'devpoll'):
class DevpollSelector(_BaseSelectorImpl):
"""Solaris /dev/poll selector."""
def __init__(self):
super(DevpollSelector, self).__init__()
self._devpoll = select.devpoll()
def fileno(self):
return self._devpoll.fileno()
def register(self, fileobj, events, data=None):
key = super(DevpollSelector, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
self._devpoll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super(DevpollSelector, self).unregister(fileobj)
self._devpoll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# devpoll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
ready = []
try:
fd_event_list = self._devpoll.poll(timeout)
except OSError as exc:
if exc.errno == EINTR:
return ready
else:
raise
for fd, event in fd_event_list:
events = 0
if event & ~select.POLLIN:
events |= EVENT_WRITE
if event & ~select.POLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._devpoll.close()
super(DevpollSelector, self).close()
if hasattr(select, 'kqueue'):
class KqueueSelector(_BaseSelectorImpl):
"""Kqueue-based selector."""
def __init__(self):
super(KqueueSelector, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super(KqueueSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
if events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
return key
def unregister(self, fileobj):
key = super(KqueueSelector, self).unregister(fileobj)
if key.events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# This can happen if the FD was closed since it
# was registered.
pass
if key.events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# See comment above.
pass
return key
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
max_ev = len(self._fd_to_key)
ready = []
try:
kev_list = self._kqueue.control(None, max_ev, timeout)
except OSError as exc:
if exc.errno == EINTR:
return ready
else:
raise
for kev in kev_list:
fd = kev.ident
flag = kev.filter
events = 0
if flag == select.KQ_FILTER_READ:
events |= EVENT_READ
if flag == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._kqueue.close()
super(KqueueSelector, self).close()
# Choose the best implementation, roughly:
# epoll|kqueue|devpoll > poll > select.
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
if 'KqueueSelector' in globals():
DefaultSelector = KqueueSelector
elif 'EpollSelector' in globals():
DefaultSelector = EpollSelector
elif 'DevpollSelector' in globals():
DefaultSelector = DevpollSelector
elif 'PollSelector' in globals():
DefaultSelector = PollSelector
else:
DefaultSelector = SelectSelector
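# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the vendored selectors34 module: the usual
# register/select/unregister cycle against whatever DefaultSelector resolved to
# on this platform.  socket.socketpair() is used only as a convenient pair of
# connected sockets for the demo (natively available on POSIX / Python 3.5+).
if __name__ == '__main__':
    import socket

    left, right = socket.socketpair()
    selector = DefaultSelector()
    try:
        # Watch `right` for readability and attach arbitrary data to its key.
        key = selector.register(right, EVENT_READ, data='demo')
        left.sendall(b'ping')
        for ready_key, events in selector.select(timeout=1.0):
            assert ready_key is key and events & EVENT_READ
            assert ready_key.data == 'demo'
            assert ready_key.fileobj.recv(4) == b'ping'
    finally:
        selector.close()
        left.close()
        right.close()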

View File

@@ -0,0 +1,897 @@
# pylint: skip-file
# Copyright (c) 2010-2017 Benjamin Peterson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utilities for writing code that runs on Python 2 and 3"""
from __future__ import absolute_import
import functools
import itertools
import operator
import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.11.0"
# Useful for very coarse version differentiation.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4)
if PY3:
string_types = str,
integer_types = int,
class_types = type,
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
binary_type = str
if sys.platform.startswith("java"):
# Jython always uses 32 bits.
MAXSIZE = int((1 << 31) - 1)
else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
# Don't del it here, cause with gc disabled this "leaks" to garbage.
# Note: This is a kafka-python customization, details at:
# https://github.com/dpkp/kafka-python/pull/979#discussion_r100403389
# del X
def _add_doc(func, doc):
"""Add documentation to a function."""
func.__doc__ = doc
def _import_module(name):
"""Import module, returning the module after the last dot."""
__import__(name)
return sys.modules[name]
class _LazyDescr(object):
def __init__(self, name):
self.name = name
def __get__(self, obj, tp):
result = self._resolve()
setattr(obj, self.name, result) # Invokes __set__.
try:
# This is a bit ugly, but it avoids running this again by
# removing this descriptor.
delattr(obj.__class__, self.name)
except AttributeError:
pass
return result
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
if new is None:
new = name
self.mod = new
else:
self.mod = old
def _resolve(self):
return _import_module(self.mod)
def __getattr__(self, attr):
_module = self._resolve()
value = getattr(_module, attr)
setattr(self, attr, value)
return value
class _LazyModule(types.ModuleType):
def __init__(self, name):
super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__
def __dir__(self):
attrs = ["__doc__", "__name__"]
attrs += [attr.name for attr in self._moved_attributes]
return attrs
# Subclasses should override this
_moved_attributes = []
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
if new_mod is None:
new_mod = name
self.mod = new_mod
if new_attr is None:
if old_attr is None:
new_attr = name
else:
new_attr = old_attr
self.attr = new_attr
else:
self.mod = old_mod
if old_attr is None:
old_attr = name
self.attr = old_attr
def _resolve(self):
module = _import_module(self.mod)
return getattr(module, self.attr)
class _SixMetaPathImporter(object):
"""
A meta path importer to import six.moves and its submodules.
This class implements a PEP302 finder and loader. It should be compatible
with Python 2.5 and all existing versions of Python3
"""
def __init__(self, six_module_name):
self.name = six_module_name
self.known_modules = {}
def _add_module(self, mod, *fullnames):
for fullname in fullnames:
self.known_modules[self.name + "." + fullname] = mod
def _get_module(self, fullname):
return self.known_modules[self.name + "." + fullname]
def find_module(self, fullname, path=None):
if fullname in self.known_modules:
return self
return None
def __get_module(self, fullname):
try:
return self.known_modules[fullname]
except KeyError:
raise ImportError("This loader does not know module " + fullname)
def load_module(self, fullname):
try:
# in case of a reload
return sys.modules[fullname]
except KeyError:
pass
mod = self.__get_module(fullname)
if isinstance(mod, MovedModule):
mod = mod._resolve()
else:
mod.__loader__ = self
sys.modules[fullname] = mod
return mod
def is_package(self, fullname):
"""
Return true, if the named module is a package.
We need this method to get correct spec objects with
Python 3.4 (see PEP451)
"""
return hasattr(self.__get_module(fullname), "__path__")
def get_code(self, fullname):
"""Return None
Required, if is_package is implemented"""
self.__get_module(fullname) # eventually raises ImportError
return None
get_source = get_code # same as get_code
_importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule):
"""Lazy loading of moved objects"""
__path__ = [] # mark as package
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("intern", "__builtin__", "sys"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("getoutput", "commands", "subprocess"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
MovedAttribute("StringIO", "StringIO", "io"),
MovedAttribute("UserDict", "UserDict", "collections"),
MovedAttribute("UserList", "UserList", "collections"),
MovedAttribute("UserString", "UserString", "collections"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"),
MovedModule("copyreg", "copy_reg"),
MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"),
MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
MovedModule("http_cookies", "Cookie", "http.cookies"),
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
MovedModule("cPickle", "cPickle", "pickle"),
MovedModule("queue", "Queue"),
MovedModule("reprlib", "repr"),
MovedModule("socketserver", "SocketServer"),
MovedModule("_thread", "thread", "_thread"),
MovedModule("tkinter", "Tkinter"),
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser",
"tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
"tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
]
# Add windows specific modules.
if sys.platform == "win32":
_moved_attributes += [
MovedModule("winreg", "_winreg"),
]
for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
if isinstance(attr, MovedModule):
_importer._add_module(attr, "moves." + attr.name)
del attr
_MovedItems._moved_attributes = _moved_attributes
moves = _MovedItems(__name__ + ".moves")
_importer._add_module(moves, "moves")
class Module_six_moves_urllib_parse(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_parse"""
_urllib_parse_moved_attributes = [
MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
MovedAttribute("urljoin", "urlparse", "urllib.parse"),
MovedAttribute("urlparse", "urlparse", "urllib.parse"),
MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
MovedAttribute("quote", "urllib", "urllib.parse"),
MovedAttribute("quote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote", "urllib", "urllib.parse"),
MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
MovedAttribute("urlencode", "urllib", "urllib.parse"),
MovedAttribute("splitquery", "urllib", "urllib.parse"),
MovedAttribute("splittag", "urllib", "urllib.parse"),
MovedAttribute("splituser", "urllib", "urllib.parse"),
MovedAttribute("splitvalue", "urllib", "urllib.parse"),
MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
MovedAttribute("uses_params", "urlparse", "urllib.parse"),
MovedAttribute("uses_query", "urlparse", "urllib.parse"),
MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
]
for attr in _urllib_parse_moved_attributes:
setattr(Module_six_moves_urllib_parse, attr.name, attr)
del attr
Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse", "moves.urllib.parse")
class Module_six_moves_urllib_error(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_error"""
_urllib_error_moved_attributes = [
MovedAttribute("URLError", "urllib2", "urllib.error"),
MovedAttribute("HTTPError", "urllib2", "urllib.error"),
MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
]
for attr in _urllib_error_moved_attributes:
setattr(Module_six_moves_urllib_error, attr.name, attr)
del attr
Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error", "moves.urllib.error")
class Module_six_moves_urllib_request(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_request"""
_urllib_request_moved_attributes = [
MovedAttribute("urlopen", "urllib2", "urllib.request"),
MovedAttribute("install_opener", "urllib2", "urllib.request"),
MovedAttribute("build_opener", "urllib2", "urllib.request"),
MovedAttribute("pathname2url", "urllib", "urllib.request"),
MovedAttribute("url2pathname", "urllib", "urllib.request"),
MovedAttribute("getproxies", "urllib", "urllib.request"),
MovedAttribute("Request", "urllib2", "urllib.request"),
MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
MovedAttribute("FileHandler", "urllib2", "urllib.request"),
MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
MovedAttribute("urlretrieve", "urllib", "urllib.request"),
MovedAttribute("urlcleanup", "urllib", "urllib.request"),
MovedAttribute("URLopener", "urllib", "urllib.request"),
MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
for attr in _urllib_request_moved_attributes:
setattr(Module_six_moves_urllib_request, attr.name, attr)
del attr
Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request", "moves.urllib.request")
class Module_six_moves_urllib_response(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_response"""
_urllib_response_moved_attributes = [
MovedAttribute("addbase", "urllib", "urllib.response"),
MovedAttribute("addclosehook", "urllib", "urllib.response"),
MovedAttribute("addinfo", "urllib", "urllib.response"),
MovedAttribute("addinfourl", "urllib", "urllib.response"),
]
for attr in _urllib_response_moved_attributes:
setattr(Module_six_moves_urllib_response, attr.name, attr)
del attr
Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response", "moves.urllib.response")
class Module_six_moves_urllib_robotparser(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_robotparser"""
_urllib_robotparser_moved_attributes = [
MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
]
for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser", "moves.urllib.robotparser")
class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error")
request = _importer._get_module("moves.urllib_request")
response = _importer._get_module("moves.urllib_response")
robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser']
_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
"moves.urllib")
def add_move(move):
"""Add an item to six.moves."""
setattr(_MovedItems, move.name, move)
def remove_move(name):
"""Remove item from six.moves."""
try:
delattr(_MovedItems, name)
except AttributeError:
try:
del moves.__dict__[name]
except KeyError:
raise AttributeError("no such move, %r" % (name,))
if PY3:
_meth_func = "__func__"
_meth_self = "__self__"
_func_closure = "__closure__"
_func_code = "__code__"
_func_defaults = "__defaults__"
_func_globals = "__globals__"
else:
_meth_func = "im_func"
_meth_self = "im_self"
_func_closure = "func_closure"
_func_code = "func_code"
_func_defaults = "func_defaults"
_func_globals = "func_globals"
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
try:
callable = callable
except NameError:
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
if PY3:
def get_unbound_function(unbound):
return unbound
create_bound_method = types.MethodType
def create_unbound_method(func, cls):
return func
Iterator = object
else:
def get_unbound_function(unbound):
return unbound.im_func
def create_bound_method(func, obj):
return types.MethodType(func, obj, obj.__class__)
def create_unbound_method(func, cls):
return types.MethodType(func, None, cls)
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
"""Get the function out of a possibly unbound function""")
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_closure = operator.attrgetter(_func_closure)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
get_function_globals = operator.attrgetter(_func_globals)
if PY3:
def iterkeys(d, **kw):
return iter(d.keys(**kw))
def itervalues(d, **kw):
return iter(d.values(**kw))
def iteritems(d, **kw):
return iter(d.items(**kw))
def iterlists(d, **kw):
return iter(d.lists(**kw))
viewkeys = operator.methodcaller("keys")
viewvalues = operator.methodcaller("values")
viewitems = operator.methodcaller("items")
else:
def iterkeys(d, **kw):
return d.iterkeys(**kw)
def itervalues(d, **kw):
return d.itervalues(**kw)
def iteritems(d, **kw):
return d.iteritems(**kw)
def iterlists(d, **kw):
return d.iterlists(**kw)
viewkeys = operator.methodcaller("viewkeys")
viewvalues = operator.methodcaller("viewvalues")
viewitems = operator.methodcaller("viewitems")
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
"Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
"Return an iterator over the (key, [values]) pairs of a dictionary.")
if PY3:
def b(s):
return s.encode("latin-1")
def u(s):
return s
unichr = chr
import struct
int2byte = struct.Struct(">B").pack
del struct
byte2int = operator.itemgetter(0)
indexbytes = operator.getitem
iterbytes = iter
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
_assertCountEqual = "assertCountEqual"
if sys.version_info[1] <= 1:
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
else:
_assertRaisesRegex = "assertRaisesRegex"
_assertRegex = "assertRegex"
else:
def b(s):
return s
# Workaround for standalone backslash
def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
unichr = unichr
int2byte = chr
def byte2int(bs):
return ord(bs[0])
def indexbytes(buf, i):
return ord(buf[i])
iterbytes = functools.partial(itertools.imap, ord)
import StringIO
StringIO = BytesIO = StringIO.StringIO
_assertCountEqual = "assertItemsEqual"
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")
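# Illustrative sketch, not part of the vendored six module: the helpers above
# give byte/text literals and byte-level indexing that behave the same way on
# Python 2 and Python 3.
if __name__ == '__main__':
    raw = b("kafka")                  # binary_type on both major versions
    text = u("kafka")                 # text_type on both major versions
    assert isinstance(raw, binary_type) and isinstance(text, text_type)
    assert byte2int(raw) == 107       # ord('k')
    assert int2byte(107) == b("k")
    assert indexbytes(raw, 1) == 97   # ord('a')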
def assertCountEqual(self, *args, **kwargs):
return getattr(self, _assertCountEqual)(*args, **kwargs)
def assertRaisesRegex(self, *args, **kwargs):
return getattr(self, _assertRaisesRegex)(*args, **kwargs)
def assertRegex(self, *args, **kwargs):
return getattr(self, _assertRegex)(*args, **kwargs)
if PY3:
exec_ = getattr(moves.builtins, "exec")
def reraise(tp, value, tb=None):
try:
if value is None:
value = tp()
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
finally:
value = None
tb = None
else:
def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace."""
if _globs_ is None:
frame = sys._getframe(1)
_globs_ = frame.f_globals
if _locs_ is None:
_locs_ = frame.f_locals
del frame
elif _locs_ is None:
_locs_ = _globs_
exec("""exec _code_ in _globs_, _locs_""")
exec_("""def reraise(tp, value, tb=None):
try:
raise tp, value, tb
finally:
tb = None
""")
if sys.version_info[:2] == (3, 2):
exec_("""def raise_from(value, from_value):
try:
if from_value is None:
raise value
raise value from from_value
finally:
value = None
""")
elif sys.version_info[:2] > (3, 2):
exec_("""def raise_from(value, from_value):
try:
raise value from from_value
finally:
value = None
""")
else:
def raise_from(value, from_value):
raise value
print_ = getattr(moves.builtins, "print", None)
if print_ is None:
def print_(*args, **kwargs):
"""The new-style print function for Python 2.4 and 2.5."""
fp = kwargs.pop("file", sys.stdout)
if fp is None:
return
def write(data):
if not isinstance(data, basestring):
data = str(data)
# If the file has an encoding, encode unicode with it.
if (isinstance(fp, file) and
isinstance(data, unicode) and
fp.encoding is not None):
errors = getattr(fp, "errors", None)
if errors is None:
errors = "strict"
data = data.encode(fp.encoding, errors)
fp.write(data)
want_unicode = False
sep = kwargs.pop("sep", None)
if sep is not None:
if isinstance(sep, unicode):
want_unicode = True
elif not isinstance(sep, str):
raise TypeError("sep must be None or a string")
end = kwargs.pop("end", None)
if end is not None:
if isinstance(end, unicode):
want_unicode = True
elif not isinstance(end, str):
raise TypeError("end must be None or a string")
if kwargs:
raise TypeError("invalid keyword arguments to print()")
if not want_unicode:
for arg in args:
if isinstance(arg, unicode):
want_unicode = True
break
if want_unicode:
newline = unicode("\n")
space = unicode(" ")
else:
newline = "\n"
space = " "
if sep is None:
sep = space
if end is None:
end = newline
for i, arg in enumerate(args):
if i:
write(sep)
write(arg)
write(end)
if sys.version_info[:2] < (3, 3):
_print = print_
def print_(*args, **kwargs):
fp = kwargs.get("file", sys.stdout)
flush = kwargs.pop("flush", False)
_print(*args, **kwargs)
if flush and fp is not None:
fp.flush()
_add_doc(reraise, """Reraise an exception.""")
if sys.version_info[0:2] < (3, 4):
def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
def wrapper(f):
f = functools.wraps(wrapped, assigned, updated)(f)
f.__wrapped__ = wrapped
return f
return wrapper
else:
wraps = functools.wraps
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
def python_2_unicode_compatible(klass):
"""
A decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class.
"""
if PY2:
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
return klass
# Complete the moves implementation.
# This code is at the end of this module to speed up module loading.
# Turn this module into a package.
__path__ = [] # required for PEP 302 and PEP 451
__package__ = __name__ # see PEP 366 @ReservedAssignment
if globals().get("__spec__") is not None:
__spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable
# Remove other six meta path importers, since they cause problems. This can
# happen if six is removed from sys.modules and then reloaded. (Setuptools does
# this for some reason.)
if sys.meta_path:
for i, importer in enumerate(sys.meta_path):
# Here's some real nastiness: Another "instance" of the six module might
# be floating around. Therefore, we can't use isinstance() to check for
# the six meta path importer, since the other six instance will have
# inserted an importer with different class.
if (type(importer).__name__ == "_SixMetaPathImporter" and
importer.name == __name__):
del sys.meta_path[i]
break
del i, importer
# Finally, add the importer to the meta path import hook.
sys.meta_path.append(_importer)
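# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the vendored six module: a quick
# demonstration of add_metaclass() (used with ABCMeta by the vendored
# selectors34 module) and of the dict-iteration shims.  Base and Impl below are
# made-up names for the demo.
if __name__ == '__main__':
    from abc import ABCMeta, abstractmethod

    @add_metaclass(ABCMeta)
    class Base(object):
        @abstractmethod
        def run(self):
            raise NotImplementedError

    class Impl(Base):
        def run(self):
            return 'ran'

    assert Impl().run() == 'ran'
    try:
        Base()
        raise AssertionError('ABCMeta should have blocked instantiation')
    except TypeError:
        pass

    # iteritems() hides the items()/iteritems() split between Python 2 and 3.
    assert sorted(iteritems({'a': 1, 'b': 2})) == [('a', 1), ('b', 2)]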

View File

@@ -0,0 +1,58 @@
# pylint: skip-file
# vendored from https://github.com/mhils/backports.socketpair
from __future__ import absolute_import
import sys
import socket
import errno
_LOCALHOST = '127.0.0.1'
_LOCALHOST_V6 = '::1'
if not hasattr(socket, "socketpair"):
# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
if family == socket.AF_INET:
host = _LOCALHOST
elif family == socket.AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError("Only AF_INET and AF_INET6 socket address families "
"are supported")
if type != socket.SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")
# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket.socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen(min(socket.SOMAXCONN, 128))
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto)
try:
csock.setblocking(False)
if sys.version_info >= (3, 0):
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
else:
try:
csock.connect((addr, port))
except socket.error as e:
if e.errno != errno.WSAEWOULDBLOCK:
raise
csock.setblocking(True)
ssock, _ = lsock.accept()
except Exception:
csock.close()
raise
finally:
lsock.close()
return (ssock, csock)
socket.socketpair = socketpair
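# Illustrative sketch, not part of the vendored backport: once this module has
# been imported, socket.socketpair() exists even on interpreters that lack it
# natively (e.g. Python 2 on Windows) and yields a connected TCP pair.
if __name__ == '__main__':
    left, right = socket.socketpair()
    try:
        left.sendall(b'ping')
        assert right.recv(4) == b'ping'
        right.sendall(b'pong')
        assert left.recv(4) == b'pong'
    finally:
        left.close()
        right.close()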

View File

@@ -0,0 +1 @@
__version__ = '2.0.2'