Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

View File

@@ -0,0 +1,72 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""dnspython DNS toolkit"""
__all__ = [
"asyncbackend",
"asyncquery",
"asyncresolver",
"btree",
"btreezone",
"dnssec",
"dnssecalgs",
"dnssectypes",
"e164",
"edns",
"entropy",
"exception",
"flags",
"immutable",
"inet",
"ipv4",
"ipv6",
"message",
"name",
"namedict",
"node",
"opcode",
"query",
"quic",
"rcode",
"rdata",
"rdataclass",
"rdataset",
"rdatatype",
"renderer",
"resolver",
"reversename",
"rrset",
"serial",
"set",
"tokenizer",
"transaction",
"tsig",
"tsigkeyring",
"ttl",
"rdtypes",
"update",
"version",
"versioned",
"wire",
"xfr",
"zone",
"zonetypes",
"zonefile",
]
from dns.version import version as __version__ # noqa

View File

@@ -0,0 +1,100 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use.
class NullContext:
def __init__(self, enter_result=None):
self.enter_result = enter_result
def __enter__(self):
return self.enter_result
def __exit__(self, exc_type, exc_value, traceback):
pass
async def __aenter__(self):
return self.enter_result
async def __aexit__(self, exc_type, exc_value, traceback):
pass
# These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover
def __init__(self, family: int, type: int):
self.family = family
self.type = type
async def close(self):
pass
async def getpeername(self):
raise NotImplementedError
async def getsockname(self):
raise NotImplementedError
async def getpeercert(self, timeout):
raise NotImplementedError
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
class DatagramSocket(Socket): # pragma: no cover
async def sendto(self, what, destination, timeout):
raise NotImplementedError
async def recvfrom(self, size, timeout):
raise NotImplementedError
class StreamSocket(Socket): # pragma: no cover
async def sendall(self, what, timeout):
raise NotImplementedError
async def recv(self, size, timeout):
raise NotImplementedError
class NullTransport:
async def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
class Backend: # pragma: no cover
def name(self) -> str:
return "unknown"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError
def datagram_connection_required(self):
return False
async def sleep(self, interval):
raise NotImplementedError
def get_transport_class(self):
raise NotImplementedError
async def wait_for(self, awaitable, timeout):
raise NotImplementedError

View File

@@ -0,0 +1,276 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""asyncio library query support"""
import asyncio
import socket
import sys
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
_is_win32 = sys.platform == "win32"
def _get_running_loop():
try:
return asyncio.get_running_loop()
except AttributeError: # pragma: no cover
return asyncio.get_event_loop()
class _DatagramProtocol:
def __init__(self):
self.transport = None
self.recvfrom = None
def connection_made(self, transport):
self.transport = transport
def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc)
def connection_lost(self, exc):
if self.recvfrom and not self.recvfrom.done():
if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError("EOF")
except EOFError as e:
self.recvfrom.set_exception(e)
else:
self.recvfrom.set_exception(exc)
def close(self):
if self.transport is not None:
self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
if timeout is not None:
try:
return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError:
raise dns.exception.Timeout(timeout=timeout)
else:
return await awaitable
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
super().__init__(family, socket.SOCK_DGRAM)
self.transport = transport
self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto
self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future()
try:
assert self.protocol.recvfrom is None
self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout)
return done.result()
finally:
self.protocol.recvfrom = None
async def close(self):
self.protocol.close()
async def getpeername(self):
return self.transport.get_extra_info("peername")
async def getsockname(self):
return self.transport.get_extra_info("sockname")
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
super().__init__(af, socket.SOCK_STREAM)
self.reader = reader
self.writer = writer
async def sendall(self, what, timeout):
self.writer.write(what)
return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size), timeout)
async def close(self):
self.writer.close()
async def getpeername(self):
return self.writer.get_extra_info("peername")
async def getsockname(self):
return self.writer.get_extra_info("sockname")
async def getpeercert(self, timeout):
return self.writer.get_extra_info("peercert")
if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream # pyright: ignore
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
if local_port != 0:
raise NotImplementedError(
"the asyncio transport for HTTPX cannot set the local port"
)
async def connect_tcp(
self, host, port, timeout=None, local_address=None, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
with anyio.fail_after(timeout):
stream = await anyio.connect_tcp(
remote_host=address,
remote_port=port,
local_host=local_address,
)
return _CoreAnyIOStream(stream)
except Exception:
pass
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout=None, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await anyio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
def name(self):
return "asyncio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
if _is_win32 and source is None:
# Win32 wants explicit binding before recvfrom(). This is the
# proper fix for [#637].
source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, # pyright: ignore
source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for(
asyncio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
family=af,
proto=proto,
local_addr=source,
server_hostname=server_hostname,
),
timeout,
)
return StreamSocket(af, r, w)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await asyncio.sleep(interval)
def datagram_connection_required(self):
return False
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
return await _maybe_wait_for(awaitable, timeout)

View File

@@ -0,0 +1,154 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
#
# Support for Discovery of Designated Resolvers
import socket
import time
from urllib.parse import urlparse
import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase
# The special name of the local resolver when using DDR
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
#
# Processing is split up into I/O independent and I/O dependent parts to
# make supporting sync and async versions easy.
#
class _SVCBInfo:
def __init__(self, bootstrap_address, port, hostname, nameservers):
self.bootstrap_address = bootstrap_address
self.port = port
self.hostname = hostname
self.nameservers = nameservers
def ddr_check_certificate(self, cert):
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
for name, value in cert["subjectAltName"]:
if name == "IP Address" and value == self.bootstrap_address:
return True
return False
def make_tls_context(self):
ssl = dns.query.ssl
ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
return ctx
def ddr_tls_check_sync(self, lifetime):
ctx = self.make_tls_context()
expiration = time.time() + lifetime
with socket.create_connection(
(self.bootstrap_address, self.port), lifetime
) as s:
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
ts.settimeout(dns.query._remaining(expiration))
ts.do_handshake()
cert = ts.getpeercert()
return self.ddr_check_certificate(cert)
async def ddr_tls_check_async(self, lifetime, backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
ctx = self.make_tls_context()
expiration = time.time() + lifetime
async with await backend.make_socket(
dns.inet.af_for_address(self.bootstrap_address),
socket.SOCK_STREAM,
0,
None,
(self.bootstrap_address, self.port),
lifetime,
ctx,
self.hostname,
) as ts:
cert = await ts.getpeercert(dns.query._remaining(expiration))
return self.ddr_check_certificate(cert)
def _extract_nameservers_from_svcb(answer):
bootstrap_address = answer.nameserver
if not dns.inet.is_address(bootstrap_address):
return []
infos = []
for rr in answer.rrset.processing_order():
nameservers = []
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
if param is None:
continue
alpns = set(param.ids)
host = rr.target.to_text(omit_final_dot=True)
port = None
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
if param is not None:
port = param.port
# For now we ignore address hints and address resolution and always use the
# bootstrap address
if b"h2" in alpns:
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
if param is None or not param.value.endswith(b"{?dns}"):
continue
path = param.value[:-6].decode()
if not path.startswith("/"):
path = "/" + path
if port is None:
port = 443
url = f"https://{host}:{port}{path}"
# check the URL
try:
urlparse(url)
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
except Exception:
# continue processing other ALPN types
pass
if b"dot" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
)
if b"doq" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
)
if len(nameservers) > 0:
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
return infos
def _get_nameservers_sync(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if info.ddr_tls_check_sync(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers
async def _get_nameservers_async(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if await info.ddr_tls_check_async(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers

View File

@@ -0,0 +1,95 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple
def _tuple_from_text(version: str) -> Tuple:
text_parts = version.split(".")
int_parts = []
for text_part in text_parts:
digit_prefix = "".join(
itertools.takewhile(lambda x: x in string.digits, text_part)
)
try:
int_parts.append(int(digit_prefix))
except Exception:
break
return tuple(int_parts)
def _version_check(
requirement: str,
) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
# This shouldn't happen, but it apparently can.
if version is None:
return False
except Exception:
return False
t_version = _tuple_from_text(version)
t_minimum = _tuple_from_text(minimum)
if t_version < t_minimum:
return False
return True
_cache: Dict[str, bool] = {}
def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
feature are available and recent enough.
Returns ``True`` if the feature is available,
and ``False`` if it is not or if metadata is
missing.
"""
value = _cache.get(feature)
if value is not None:
return value
requirements = _requirements.get(feature)
if requirements is None:
# we make a cache entry here for consistency not performance
_cache[feature] = False
return False
ok = True
for requirement in requirements:
if not _version_check(requirement):
ok = False
break
_cache[feature] = ok
return ok
def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
_cache[feature] = enabled
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=45"],
"doh": ["httpcore>=1.0.0", "httpx>=0.28.0", "h2>=4.2.0"],
"doq": ["aioquic>=1.2.0"],
"idna": ["idna>=3.10"],
"trio": ["trio>=0.30"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements
}

View File

@@ -0,0 +1,76 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator requires python >=
# 3.7, and is significantly more storage efficient when making classes
# with slots immutable. It's also faster.
import contextvars
import inspect
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable:
"""Immutable mixin class"""
# We set slots to the empty list to say "we don't have any attributes".
# We do this so that if we're mixed in with a class with __slots__, we
# don't cause a __dict__ to be added which would waste space.
__slots__ = ()
def __setattr__(self, name, value):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
previous = _in__init__.set(args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
_in__init__.reset(previous)
nf.__signature__ = inspect.signature(f) # pyright: ignore
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
# We have to do the __slots__ declaration here too!
__slots__ = ()
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, "__setstate__"):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View File

@@ -0,0 +1,61 @@
import enum
from typing import Any
CERT_NONE = 0
class TLSVersion(enum.IntEnum):
TLSv1_2 = 12
class WantReadException(Exception):
pass
class WantWriteException(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
class SSLContext:
def __init__(self) -> None:
self.minimum_version: Any = TLSVersion.TLSv1_2
self.check_hostname: bool = False
self.verify_mode: int = CERT_NONE
def wrap_socket(self, *args, **kwargs) -> "SSLSocket": # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def set_alpn_protocols(self, *args, **kwargs): # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
class SSLSocket:
def pending(self) -> bool:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def do_handshake(self) -> None:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def settimeout(self, value: Any) -> None:
pass
def getpeercert(self) -> Any:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def create_default_context(*args, **kwargs) -> SSLContext: # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised

View File

@@ -0,0 +1,19 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import os
from typing import Tuple
def convert_verify_to_cafile_and_capath(
verify: bool | str,
) -> Tuple[str | None, str | None]:
cafile: str | None = None
capath: str | None = None
if isinstance(verify, str):
if os.path.isfile(verify):
cafile = verify
elif os.path.isdir(verify):
capath = verify
else:
raise ValueError("invalid verify string")
return cafile, capath

View File

@@ -0,0 +1,255 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""trio async I/O library query support"""
import socket
import trio
import trio.socket # type: ignore
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
if not dns._features.have("trio"):
raise ImportError("trio not found or too old")
def _maybe_timeout(timeout):
if timeout is not None:
return trio.move_on_after(timeout)
else:
return dns._asyncbackend.NullContext()
# for brevity
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, sock):
super().__init__(sock.family, socket.SOCK_DGRAM)
self.socket = sock
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
if destination is None:
return await self.socket.send(what)
else:
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
super().__init__(family, socket.SOCK_STREAM)
self.stream = stream
self.tls = tls
async def sendall(self, what, timeout):
with _maybe_timeout(timeout):
return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
with _maybe_timeout(timeout):
return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.stream.aclose()
async def getpeername(self):
if self.tls:
return self.stream.transport_stream.socket.getpeername()
else:
return self.stream.socket.getpeername()
async def getsockname(self):
if self.tls:
return self.stream.transport_stream.socket.getsockname()
else:
return self.stream.socket.getsockname()
async def getpeercert(self, timeout):
if self.tls:
with _maybe_timeout(timeout):
await self.stream.do_handshake()
return self.stream.getpeercert()
else:
raise NotImplementedError
if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreTrioStream = httpcore._backends.trio.TrioStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
async def connect_tcp(
self, host, port, timeout=None, local_address=None, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = (local_address, self._local_port)
else:
source = None
destination = (address, port)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
sock = await Backend().make_socket(
af, socket.SOCK_STREAM, 0, source, destination, timeout
)
assert isinstance(sock, StreamSocket)
return _CoreTrioStream(sock.stream)
except Exception:
continue
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout=None, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await trio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
def name(self):
return "trio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto)
stream = None
try:
if source:
await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM or destination is not None:
connected = False
with _maybe_timeout(timeout):
assert destination is not None
await s.connect(_lltuple(destination, af))
connected = True
if not connected:
raise dns.exception.Timeout(
timeout=timeout
) # lgtm[py/unreachable-statement]
except Exception: # pragma: no cover
s.close()
raise
if socktype == socket.SOCK_DGRAM:
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s)
tls = False
if ssl_context:
tls = True
try:
stream = trio.SSLStream(
stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover
await stream.aclose()
raise
return StreamSocket(af, stream, tls)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await trio.sleep(interval)
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
with _maybe_timeout(timeout):
return await awaitable
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]

View File

@@ -0,0 +1,101 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception
# pylint: disable=unused-import
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
Backend,
DatagramSocket,
Socket,
StreamSocket,
)
# pylint: enable=unused-import
_default_backend = None
_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass
def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio"
and "asyncio" backends are available.
Raises NotImplementedError if an unknown backend name is specified.
"""
# pylint: disable=import-outside-toplevel,redefined-outer-name
backend = _backends.get(name)
if backend:
return backend
if name == "trio":
import dns._trio_backend
backend = dns._trio_backend.Backend()
elif name == "asyncio":
import dns._asyncio_backend
backend = dns._asyncio_backend.Backend()
else:
raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend
return backend
def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
Returns the name of the library, or raises AsyncLibraryNotFoundError
if the library cannot be determined.
"""
# pylint: disable=import-outside-toplevel
try:
if _no_sniffio:
raise ImportError
import sniffio
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
except ImportError:
import asyncio
try:
asyncio.get_running_loop()
return "asyncio"
except RuntimeError:
raise AsyncLibraryNotFoundError("no async library detected")
def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary."""
if _default_backend:
return _default_backend
return set_default_backend(sniff())
def set_default_backend(name: str) -> Backend:
"""Set the default backend.
It's not normally necessary to call this method, as
``get_default_backend()`` will initialize the backend
appropriately in many cases. If ``sniffio`` is not installed, or
in testing situations, this function allows the backend to be set
explicitly.
"""
global _default_backend
_default_backend = get_backend(name)
return _default_backend

View File

@@ -0,0 +1,953 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Talk to a DNS server."""
import base64
import contextlib
import random
import socket
import struct
import time
import urllib.parse
from typing import Any, Dict, Optional, Tuple, cast
import dns.asyncbackend
import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.quic
import dns.rdatatype
import dns.transaction
import dns.tsig
import dns.xfr
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
HTTPVersion,
NoDOH,
NoDOQ,
UDPMode,
_check_status,
_compute_times,
_matches_destination,
_remaining,
have_doh,
make_ssl_context,
)
try:
import ssl
except ImportError:
import dns._no_ssl as ssl # type: ignore
if have_doh:
import httpx
# for brevity
_lltuple = dns.inet.low_level_address_tuple
def _source_tuple(af, address, port):
# Make a high level source tuple, or return None if address and port
# are both None
if address or port:
if address is None:
if af == socket.AF_INET:
address = "0.0.0.0"
elif af == socket.AF_INET6:
address = "::"
else:
raise NotImplementedError(f"unknown address family {af}")
return (address, port)
else:
return None
def _timeout(expiration, now=None):
if expiration is not None:
if not now:
now = time.time()
return max(expiration - now, 0)
else:
return None
async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: dns.message.Message | bytes,
destination: Any,
expiration: float | None = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
*what*, a ``bytes`` or ``dns.message.Message``, the message to send.
*destination*, a destination tuple appropriate for the address family
of the socket, specifying where to send the query.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
if isinstance(what, dns.message.Message):
what = what.to_wire()
sent_time = time.time()
n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
return (n, sent_time)
async def receive_udp(
sock: dns.asyncbackend.DatagramSocket,
destination: Any | None = None,
expiration: float | None = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
request_mac: bytes | None = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: dns.message.Message | None = None,
) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
"""
wire = b""
while True:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
continue
received_time = time.time()
try:
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# See the comment in query.py for details.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
return (r, received_time, from_address)
async def udp(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: dns.asyncbackend.DatagramSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ignore_errors: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the query. If ``None``, the default, a
socket is created. Note that if a socket is provided, the
*source*, *source_port*, and *backend* are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af)
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if not backend:
backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
if backend.datagram_connection_required():
dtuple = (where, port)
else:
dtuple = None
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
await send_udp(s, wire, destination, expiration) # pyright: ignore
(r, received_time, _) = await receive_udp(
s, # pyright: ignore
destination,
expiration,
ignore_unexpected,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
raise_on_truncation,
ignore_errors,
q,
)
r.time = received_time - begin_time
# We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse
return r
async def udp_with_fallback(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: dns.asyncbackend.DatagramSocket | None = None,
tcp_sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
*udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the UDP query. If ``None``, the default, a
socket is created. Note that if a socket is provided the *source*,
*source_port*, and *backend* are ignored for the UDP query.
*tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
socket to use for the TCP query. If ``None``, the default, a
socket is created. Note that if a socket is provided *where*,
*source*, *source_port*, and *backend* are ignored for the TCP query.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.udp_with_fallback()` for the documentation
of the other parameters, exceptions, and return type of this
method.
"""
try:
response = await udp(
q,
where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
ignore_errors,
)
return (response, False)
except dns.message.Truncated:
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True)
async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: dns.message.Message | bytes,
expiration: float | None = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
See :py:func:`dns.query.send_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if isinstance(what, dns.message.Message):
tcpmsg = what.to_wire(prepend_length=True)
else:
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF.
"""
s = b""
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b"":
raise EOFError("EOF")
count = count - len(n)
s = s + n
return s
async def receive_tcp(
sock: dns.asyncbackend.StreamSocket,
expiration: float | None = None,
one_rr_per_rrset: bool = False,
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
request_mac: bytes | None = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
See :py:func:`dns.query.receive_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
ldata = await _read_exactly(sock, 2, expiration)
(l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration)
received_time = time.time()
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return (r, received_time)
async def tcp(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
socket to use for the query. If ``None``, the default, a socket
is created. Note that if a socket is provided
*where*, *port*, *source*, *source_port*, and *backend* are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
if sock:
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
await sock.getpeername()
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
# These are simple (address, port) pairs, not family-dependent tuples
# you pass to low-level socket code.
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
cm = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration) # pyright: ignore
(r, received_time) = await receive_tcp(
s, # pyright: ignore
expiration,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
async def tls(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 853,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ssl_context: ssl.SSLContext | None = None,
server_hostname: str | None = None,
verify: bool | str = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
to use for the query. If ``None``, the default, a socket is
created. Note that if a socket is provided, it must be a
connected SSL stream socket, and *where*, *port*,
*source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
(begin_time, expiration) = _compute_times(timeout)
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
cm = await backend.make_socket(
af,
socket.SOCK_STREAM,
0,
stuple,
dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time()
response.time = end_time - begin_time
return response
def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
) -> "dns.asyncresolver.Resolver": # pyright: ignore
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
return resolver
async def https(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 443,
source: str | None = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None,
path: str = "/dns-query",
post: bool = True,
verify: bool | str | ssl.SSLContext = True,
bootstrap_address: str | None = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
the query.
Unlike the other dnspython async functions, a backend cannot be provided
in this function because httpx always auto-detects the async backend.
See :py:func:`dns.query.https()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
# we bind url and then override as pyright can't figure out all paths bind.
url = where
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # pyright: ignore
answers = await resolver.resolve_name( # pyright: ignore
parsed.hostname, family # pyright: ignore
)
bootstrap_address = random.choice(list(answers.addresses()))
if client and not isinstance(
client, dns.quic.AsyncQuicConnection
): # pyright: ignore
raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.")
assert client is None or isinstance(client, dns.quic.AsyncQuicConnection)
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
connection=client,
)
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
raise ValueError("client parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
wire = q.to_wire()
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
backend = dns.asyncbackend.get_default_backend()
if source is None:
local_address = None
local_port = 0
else:
local_address = source
local_port = source_port
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
transport = backend.get_transport_class()(
local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
cm = httpx.AsyncClient( # pyright: ignore
http1=h1, http2=h2, verify=verify, transport=transport # type: ignore
)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
headers.update(
{
"content-type": "application/dns-message",
"content-length": str(len(wire)),
}
)
response = await backend.wait_for(
the_client.post( # pyright: ignore
url,
headers=headers,
content=wire,
extensions=extensions,
),
timeout,
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
the_client.get( # pyright: ignore
url,
headers=headers,
params={"dns": twire},
extensions=extensions,
),
timeout,
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content!r}"
)
r = dns.message.from_wire(
response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: float | None = None,
port: int = 443,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: bool | str | ssl.SSLContext = True,
backend: dns.asyncbackend.Backend | None = None,
post: bool = True,
connection: dns.quic.AsyncQuicConnection | None = None,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
assert hostname is not None
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
if connection:
the_connection = connection
else:
the_connection = the_manager.connect( # pyright: ignore
where, port, source, source_port
)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def quic(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 853,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: dns.quic.AsyncQuicConnection | None = None,
verify: bool | str = True,
backend: dns.asyncbackend.Backend | None = None,
hostname: str | None = None,
server_hostname: str | None = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect( # pyright: ignore
where, port, source, source_port
)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: dns.asyncbackend.Socket,
query: dns.message.Message,
serial: int | None,
timeout: float | None,
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
await udp_sock.sendto(wire, None, _timeout(expiration))
else:
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
tcpmsg = struct.pack("!H", len(wire)) + wire
await tcp_sock.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
r: dns.message.Message | None = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and r is not None and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
query: dns.message.Message | None = None,
port: int = 53,
timeout: float | None = None,
lifetime: float | None = None,
source: str | None = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: dns.asyncbackend.Backend | None = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.inbound_xfr()` for the documentation of
the other parameters, exceptions, and return type of this method.
"""
if query is None:
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
(_, expiration) = _compute_times(lifetime)
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
s = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
try:
async for _ in _inbound_xfr( # pyright: ignore
txn_manager,
s,
query,
serial,
timeout,
expiration, # pyright: ignore
):
pass
return
except dns.xfr.UseTCP:
if udp_mode == UDPMode.ONLY:
raise
s = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
async for _ in _inbound_xfr( # pyright: ignore
txn_manager, s, query, serial, timeout, expiration # pyright: ignore
):
pass

View File

@@ -0,0 +1,478 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Asynchronous DNS stub resolver."""
import socket
import time
from typing import Any, Dict, List
import dns._ddr
import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
import dns.reversename
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
# for indentation purposes below
_udp = dns.asyncquery.udp
_tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
async def resolve(
self,
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.resolver.Resolver.resolve()` for the
documentation of the other parameters, exceptions, and return
type of this method.
"""
resolution = dns.resolver._Resolution(
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend:
backend = dns.asyncbackend.get_default_backend()
start = time.time()
while True:
(request, answer) = resolution.next_request()
# Note we need to say "if answer is not None" and not just
# "if answer" because answer implements __len__, and python
# will call that. We want to return if we have an answer
# object, including in cases where its length is 0.
if answer is not None:
# cache hit!
return answer
assert request is not None # needed for type checking
done = False
while not done:
(nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff:
await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
response = await nameserver.async_query(
request,
timeout=timeout,
source=source,
source_port=source_port,
max_size=tcp,
backend=backend,
)
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
(answer, done) = resolution.query_result(response, None)
# Note we need to say "if answer is not None" and not just
# "if answer" because answer implements __len__, and python
# will call that. We want to return if we have an answer
# object, including in cases where its length is 0.
if answer is not None:
return answer
async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
This utilizes the resolve() method to perform a PTR lookup on the
specified IP address.
*ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get
the PTR record for.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
async def resolve_name(
self,
name: dns.name.Name | str,
family: int = socket.AF_UNSPEC,
**kwargs: Any,
) -> dns.resolver.HostAnswers:
"""Use an asynchronous resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
the specified name.
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
*family*, an ``int``, the address family. If socket.AF_UNSPEC
(the default), both A and AAAA records will be retrieved.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs.pop("rdtype", None)
modified_kwargs["rdclass"] = dns.rdataclass.IN
if family == socket.AF_INET:
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
return dns.resolver.HostAnswers.make(v4=v4)
elif family == socket.AF_INET6:
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return dns.resolver.HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
lifetime = modified_kwargs.pop("lifetime", None)
start = time.time()
v6 = await self.resolve(
name,
dns.rdatatype.AAAA,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
# Note that setting name ensures we query the same name
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
name = v6.qname
v4 = await self.resolve(
name,
dns.rdatatype.A,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
answers = dns.resolver.HostAnswers.make(
v6=v6, v4=v4, add_empty=not raise_on_no_answer
)
if not answers:
raise NoAnswer(response=v6.response)
return answers
# pylint: disable=redefined-outer-name
async def canonical_name(self, name: dns.name.Name | str) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
after all CNAME and DNAME renamings have been applied.
*name*, a ``dns.name.Name`` or ``str``, the query name.
This method can raise any exception that ``resolve()`` can
raise, other than ``dns.resolver.NoAnswer`` and
``dns.resolver.NXDOMAIN``.
Returns a ``dns.name.Name``.
"""
try:
answer = await self.resolve(name, raise_on_no_answer=False)
canonical_name = answer.canonical_name
except dns.resolver.NXDOMAIN as e:
canonical_name = e.canonical_name
return canonical_name
async def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
is 5 seconds.
If the SVCB query is successful and results in a non-empty list of nameservers,
then the resolver's nameservers are set to the returned servers in priority
order.
The current implementation does not use any address hints from the SVCB record,
nor does it resolve addresses for the SCVB target name, rather it assumes that
the bootstrap nameserver will always be one of the addresses and uses it.
A future revision to the code may offer fuller support. The code verifies that
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
try:
expiration = time.time() + lifetime
answer = await self.resolve(
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
)
timeout = dns.query._remaining(expiration)
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
pass
default_resolver = None
def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
assert default_resolver is not None
return default_resolver
def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
systems) will be re-read immediately.
"""
global default_resolver
default_resolver = Resolver()
async def resolve(
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
object to make the query.
See :py:func:`dns.asyncresolver.Resolver.resolve` for more
information on the parameters.
"""
return await get_default_resolver().resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
information on the parameters.
"""
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def resolve_name(
name: dns.name.Name | str, family: int = socket.AF_UNSPEC, **kwargs: Any
) -> dns.resolver.HostAnswers:
"""Use a resolver to asynchronously query for address records.
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
information on the parameters.
"""
return await get_default_resolver().resolve_name(name, family, **kwargs)
async def canonical_name(name: dns.name.Name | str) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
information on the parameters and possible exceptions.
"""
return await get_default_resolver().canonical_name(name)
async def try_ddr(timeout: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
return await get_default_resolver().try_ddr(timeout)
async def zone_for_name(
name: dns.name.Name | str,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Resolver | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
information on the parameters and possible exceptions.
"""
if isinstance(name, str):
name = dns.name.from_text(name, dns.name.root)
if resolver is None:
resolver = get_default_resolver()
if not name.is_absolute():
raise NotAbsolute(name)
while True:
try:
answer = await resolver.resolve(
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
except (NXDOMAIN, NoAnswer):
pass
try:
name = name.parent()
except dns.name.NoParent: # pragma: no cover
raise NoRootSOA
async def make_resolver_at(
where: dns.name.Name | str,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Resolver | None = None,
) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
full resolver.
*port*, an ``int``, the port to use. If not specified, the default is 53.
*family*, an ``int``, the address family to use. This parameter is used if
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
the first address returned by ``resolve_name()`` will be used, otherwise the
first address of the specified family will be used.
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames. If not specified, the default resolver will be used.
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
if resolver is None:
resolver = get_default_resolver()
nameservers: List[str | dns.nameserver.Nameserver] = []
if isinstance(where, str) and dns.inet.is_address(where):
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
else:
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = Resolver(configure=False)
res.nameservers = nameservers
return res
async def resolve_at(
where: dns.name.Name | str,
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Resolver | None = None,
) -> dns.resolver.Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
to make a resolver, and then uses it to resolve the query.
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
resolver parameters *where*, *port*, *family*, and *resolver*.
If making more than one query, it is more efficient to call
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
res = await make_resolver_at(where, port, family, resolver)
return await res.resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)

View File

@@ -0,0 +1,850 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""
A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with
copy-on-write node updates, cursors, and optional space optimization for mostly-in-order
insertion.
"""
from collections.abc import MutableMapping, MutableSet
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast
DEFAULT_T = 127
KT = TypeVar("KT") # the type of a key in Element
class Element(Generic[KT]):
"""All items stored in the BTree are Elements."""
def key(self) -> KT:
"""The key for this element; the returned type must implement comparison."""
raise NotImplementedError # pragma: no cover
ET = TypeVar("ET", bound=Element) # the type of a value in a _KV
def _MIN(t: int) -> int:
"""The minimum number of keys in a non-root node for a BTree with the specified
``t``
"""
return t - 1
def _MAX(t: int) -> int:
"""The maximum number of keys in node for a BTree with the specified ``t``"""
return 2 * t - 1
class _Creator:
"""A _Creator class instance is used as a unique id for the BTree which created
a node.
We use a dedicated creator rather than just a BTree reference to avoid circularity
that would complicate GC.
"""
def __str__(self): # pragma: no cover
return f"{id(self):x}"
class _Node(Generic[KT, ET]):
"""A Node in the BTree.
A Node (leaf or internal) of the BTree.
"""
__slots__ = ["t", "creator", "is_leaf", "elts", "children"]
def __init__(self, t: int, creator: _Creator, is_leaf: bool):
assert t >= 3
self.t = t
self.creator = creator
self.is_leaf = is_leaf
self.elts: list[ET] = []
self.children: list[_Node[KT, ET]] = []
def is_maximal(self) -> bool:
"""Does this node have the maximal number of keys?"""
assert len(self.elts) <= _MAX(self.t)
return len(self.elts) == _MAX(self.t)
def is_minimal(self) -> bool:
"""Does this node have the minimal number of keys?"""
assert len(self.elts) >= _MIN(self.t)
return len(self.elts) == _MIN(self.t)
def search_in_node(self, key: KT) -> tuple[int, bool]:
"""Get the index of the ``Element`` matching ``key`` or the index of its
least successor.
Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff.
the key was found.
"""
l = len(self.elts)
if l > 0 and key > self.elts[l - 1].key():
# This is optimizing near in-order insertion.
return l, False
l = 0
i = len(self.elts)
r = i - 1
equal = False
while l <= r:
m = (l + r) // 2
k = self.elts[m].key()
if key == k:
i = m
equal = True
break
elif key < k:
i = m
r = m - 1
else:
l = m + 1
return i, equal
def maybe_cow_child(self, index: int) -> "_Node[KT, ET]":
assert not self.is_leaf
child = self.children[index]
cloned = child.maybe_cow(self.creator)
if cloned:
self.children[index] = cloned
return cloned
else:
return child
def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]:
"""Get the node associated with key and its index, doing
copy-on-write if we have to descend.
Returns a tuple of the node and the index, or the tuple ``(None, 0)``
if the key was not found.
"""
i, equal = self.search_in_node(key)
if equal:
return (self, i)
elif self.is_leaf:
return (None, 0)
else:
child = self.maybe_cow_child(i)
return child._get_node(key)
def get(self, key: KT) -> ET | None:
"""Get the element associated with *key* or return ``None``"""
i, equal = self.search_in_node(key)
if equal:
return self.elts[i]
elif self.is_leaf:
return None
else:
return self.children[i].get(key)
def optimize_in_order_insertion(self, index: int) -> None:
"""Try to minimize the number of Nodes in a BTree where the insertion
is done in-order or close to it, by stealing as much as we can from our
right sibling.
If we don't do this, then an in-order insertion will produce a BTree
where most of the nodes are minimal.
"""
if index == 0:
return
left = self.children[index - 1]
if len(left.elts) == _MAX(self.t):
return
left = self.maybe_cow_child(index - 1)
while len(left.elts) < _MAX(self.t):
if not left.try_right_steal(self, index - 1):
break
def insert_nonfull(self, element: ET, in_order: bool) -> ET | None:
assert not self.is_maximal()
while True:
key = element.key()
i, equal = self.search_in_node(key)
if equal:
# replace
old = self.elts[i]
self.elts[i] = element
return old
elif self.is_leaf:
self.elts.insert(i, element)
return None
else:
child = self.maybe_cow_child(i)
if child.is_maximal():
self.adopt(*child.split())
# Splitting might result in our target moving to us, so
# search again.
continue
oelt = child.insert_nonfull(element, in_order)
if in_order:
self.optimize_in_order_insertion(i)
return oelt
def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]:
"""Split a maximal node into two minimal ones and a central element."""
assert self.is_maximal()
right = self.__class__(self.t, self.creator, self.is_leaf)
right.elts = list(self.elts[_MIN(self.t) + 1 :])
middle = self.elts[_MIN(self.t)]
self.elts = list(self.elts[: _MIN(self.t)])
if not self.is_leaf:
right.children = list(self.children[_MIN(self.t) + 1 :])
self.children = list(self.children[: _MIN(self.t) + 1])
return self, middle, right
def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
"""Try to steal from this Node's left sibling for balancing purposes.
Returns ``True`` if the theft was successful, or ``False`` if not.
"""
if index != 0:
left = parent.children[index - 1]
if not left.is_minimal():
left = parent.maybe_cow_child(index - 1)
elt = parent.elts[index - 1]
parent.elts[index - 1] = left.elts.pop()
self.elts.insert(0, elt)
if not left.is_leaf:
assert not self.is_leaf
child = left.children.pop()
self.children.insert(0, child)
return True
return False
def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
"""Try to steal from this Node's right sibling for balancing purposes.
Returns ``True`` if the theft was successful, or ``False`` if not.
"""
if index + 1 < len(parent.children):
right = parent.children[index + 1]
if not right.is_minimal():
right = parent.maybe_cow_child(index + 1)
elt = parent.elts[index]
parent.elts[index] = right.elts.pop(0)
self.elts.append(elt)
if not right.is_leaf:
assert not self.is_leaf
child = right.children.pop(0)
self.children.append(child)
return True
return False
def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None:
"""Adopt left, middle, and right into our Node (which must not be maximal,
and which must not be a leaf). In the case were we are not the new root,
then the left child must already be in the Node."""
assert not self.is_maximal()
assert not self.is_leaf
key = middle.key()
i, equal = self.search_in_node(key)
assert not equal
self.elts.insert(i, middle)
if len(self.children) == 0:
# We are the new root
self.children = [left, right]
else:
assert self.children[i] == left
self.children.insert(i + 1, right)
def merge(self, parent: "_Node[KT, ET]", index: int) -> None:
"""Merge this node's parent and its right sibling into this node."""
right = parent.children.pop(index + 1)
self.elts.append(parent.elts.pop(index))
self.elts.extend(right.elts)
if not self.is_leaf:
self.children.extend(right.children)
def minimum(self) -> ET:
"""The least element in this subtree."""
if self.is_leaf:
return self.elts[0]
else:
return self.children[0].minimum()
def maximum(self) -> ET:
"""The greatest element in this subtree."""
if self.is_leaf:
return self.elts[-1]
else:
return self.children[-1].maximum()
def balance(self, parent: "_Node[KT, ET]", index: int) -> None:
"""This Node is minimal, and we want to make it non-minimal so we can delete.
We try to steal from our siblings, and if that doesn't work we will merge
with one of them."""
assert not parent.is_leaf
if self.try_left_steal(parent, index):
return
if self.try_right_steal(parent, index):
return
# Stealing didn't work, so both siblings must be minimal.
if index == 0:
# We are the left-most node so merge with our right sibling.
self.merge(parent, index)
else:
# Have our left sibling merge with us. This lets us only have "merge right"
# code.
left = parent.maybe_cow_child(index - 1)
left.merge(parent, index - 1)
def delete(
self, key: KT, parent: Optional["_Node[KT, ET]"], exact: ET | None
) -> ET | None:
"""Delete an element matching *key* if it exists. If *exact* is not ``None``
then it must be an exact match with that element. The Node must not be
minimal unless it is the root."""
assert parent is None or not self.is_minimal()
i, equal = self.search_in_node(key)
original_key = None
if equal:
# Note we use "is" here as we meant "exactly this object".
if exact is not None and self.elts[i] is not exact:
raise ValueError("exact delete did not match existing elt")
if self.is_leaf:
return self.elts.pop(i)
# Note we need to ensure exact is None going forward as we've
# already checked exactness and are about to change our target key
# to the least successor.
exact = None
original_key = key
least_successor = self.children[i + 1].minimum()
key = least_successor.key()
i = i + 1
if self.is_leaf:
# No match
if exact is not None:
raise ValueError("exact delete had no match")
return None
# recursively delete in the appropriate child
child = self.maybe_cow_child(i)
if child.is_minimal():
child.balance(self, i)
# Things may have moved.
i, equal = self.search_in_node(key)
assert not equal
child = self.children[i]
assert not child.is_minimal()
elt = child.delete(key, self, exact)
if original_key is not None:
node, i = self._get_node(original_key)
assert node is not None
assert elt is not None
oelt = node.elts[i]
node.elts[i] = elt
elt = oelt
return elt
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
"""Call *visit* on all of the elements in order."""
for i, elt in enumerate(self.elts):
if not self.is_leaf:
self.children[i].visit_in_order(visit)
visit(elt)
if not self.is_leaf:
self.children[-1].visit_in_order(visit)
def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None:
"""Visit nodes in preorder. This method is only used for testing."""
visit(self)
if not self.is_leaf:
for child in self.children:
child._visit_preorder_by_node(visit)
def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]:
"""Return a clone of this Node if it was not created by *creator*, or ``None``
otherwise (i.e. copy for copy-on-write if we haven't already copied it)."""
if self.creator is not creator:
return self.clone(creator)
else:
return None
def clone(self, creator: _Creator) -> "_Node[KT, ET]":
"""Make a shallow-copy duplicate of this node."""
cloned = self.__class__(self.t, creator, self.is_leaf)
cloned.elts.extend(self.elts)
if not self.is_leaf:
cloned.children.extend(self.children)
return cloned
def __str__(self): # pragma: no cover
if not self.is_leaf:
children = " " + " ".join([f"{id(c):x}" for c in self.children])
else:
children = ""
return f"{id(self):x} {self.creator} {self.elts}{children}"
class Cursor(Generic[KT, ET]):
"""A seekable cursor for a BTree.
If you are going to use a cursor on a mutable BTree, you should use it
in a ``with`` block so that any mutations of the BTree automatically park
the cursor.
"""
def __init__(self, btree: "BTree[KT, ET]"):
self.btree = btree
self.current_node: _Node | None = None
# The current index is the element index within the current node, or
# if there is no current node then it is 0 on the left boundary and 1
# on the right boundary.
self.current_index: int = 0
self.recurse = False
self.increasing = True
self.parents: list[tuple[_Node, int]] = []
self.parked = False
self.parking_key: KT | None = None
self.parking_key_read = False
def _seek_least(self) -> None:
# seek to the least value in the subtree beneath the current index of the
# current node
assert self.current_node is not None
while not self.current_node.is_leaf:
self.parents.append((self.current_node, self.current_index))
self.current_node = self.current_node.children[self.current_index]
assert self.current_node is not None
self.current_index = 0
def _seek_greatest(self) -> None:
# seek to the greatest value in the subtree beneath the current index of the
# current node
assert self.current_node is not None
while not self.current_node.is_leaf:
self.parents.append((self.current_node, self.current_index))
self.current_node = self.current_node.children[self.current_index]
assert self.current_node is not None
self.current_index = len(self.current_node.elts)
def park(self):
"""Park the cursor.
A cursor must be "parked" before mutating the BTree to avoid undefined behavior.
Cursors created in a ``with`` block register with their BTree and will park
automatically. Note that a parked cursor may not observe some changes made when
it is parked; for example a cursor being iterated with next() will not see items
inserted before its current position.
"""
if not self.parked:
self.parked = True
def _maybe_unpark(self):
if self.parked:
if self.parking_key is not None:
# remember our increasing hint, as seeking might change it
increasing = self.increasing
if self.parking_key_read:
# We've already returned the parking key, so we want to be before it
# if decreasing and after it if increasing.
before = not self.increasing
else:
# We haven't returned the parking key, so we've parked right
# after seeking or are on a boundary. Either way, the before
# hint we want is the value of self.increasing.
before = self.increasing
self.seek(self.parking_key, before)
self.increasing = increasing # might have been altered by seek()
self.parked = False
self.parking_key = None
def prev(self) -> ET | None:
"""Get the previous element, or return None if on the left boundary."""
self._maybe_unpark()
self.parking_key = None
if self.current_node is None:
# on a boundary
if self.current_index == 0:
# left boundary, there is no prev
return None
else:
assert self.current_index == 1
# right boundary; seek to the actual boundary
# so we can do a prev()
self.current_node = self.btree.root
self.current_index = len(self.btree.root.elts)
self._seek_greatest()
while True:
if self.recurse:
if not self.increasing:
# We only want to recurse if we are continuing in the decreasing
# direction.
self._seek_greatest()
self.recurse = False
self.increasing = False
self.current_index -= 1
if self.current_index >= 0:
elt = self.current_node.elts[self.current_index]
if not self.current_node.is_leaf:
self.recurse = True
self.parking_key = elt.key()
self.parking_key_read = True
return elt
else:
if len(self.parents) > 0:
self.current_node, self.current_index = self.parents.pop()
else:
self.current_node = None
self.current_index = 0
return None
def next(self) -> ET | None:
"""Get the next element, or return None if on the right boundary."""
self._maybe_unpark()
self.parking_key = None
if self.current_node is None:
# on a boundary
if self.current_index == 1:
# right boundary, there is no next
return None
else:
assert self.current_index == 0
# left boundary; seek to the actual boundary
# so we can do a next()
self.current_node = self.btree.root
self.current_index = 0
self._seek_least()
while True:
if self.recurse:
if self.increasing:
# We only want to recurse if we are continuing in the increasing
# direction.
self._seek_least()
self.recurse = False
self.increasing = True
if self.current_index < len(self.current_node.elts):
elt = self.current_node.elts[self.current_index]
self.current_index += 1
if not self.current_node.is_leaf:
self.recurse = True
self.parking_key = elt.key()
self.parking_key_read = True
return elt
else:
if len(self.parents) > 0:
self.current_node, self.current_index = self.parents.pop()
else:
self.current_node = None
self.current_index = 1
return None
def _adjust_for_before(self, before: bool, i: int) -> None:
if before:
self.current_index = i
else:
self.current_index = i + 1
def seek(self, key: KT, before: bool = True) -> None:
"""Seek to the specified key.
If *before* is ``True`` (the default) then the cursor is positioned just
before *key* if it exists, or before its least successor if it doesn't. A
subsequent next() will retrieve this value. If *before* is ``False``, then
the cursor is positioned just after *key* if it exists, or its greatest
precessessor if it doesn't. A subsequent prev() will return this value.
"""
self.current_node = self.btree.root
assert self.current_node is not None
self.recurse = False
self.parents = []
self.increasing = before
self.parked = False
self.parking_key = key
self.parking_key_read = False
while not self.current_node.is_leaf:
i, equal = self.current_node.search_in_node(key)
if equal:
self._adjust_for_before(before, i)
if before:
self._seek_greatest()
else:
self._seek_least()
return
self.parents.append((self.current_node, i))
self.current_node = self.current_node.children[i]
assert self.current_node is not None
i, equal = self.current_node.search_in_node(key)
if equal:
self._adjust_for_before(before, i)
else:
self.current_index = i
def seek_first(self) -> None:
"""Seek to the left boundary (i.e. just before the least element).
A subsequent next() will return the least element if the BTree isn't empty."""
self.current_node = None
self.current_index = 0
self.recurse = False
self.increasing = True
self.parents = []
self.parked = False
self.parking_key = None
def seek_last(self) -> None:
"""Seek to the right boundary (i.e. just after the greatest element).
A subsequent prev() will return the greatest element if the BTree isn't empty.
"""
self.current_node = None
self.current_index = 1
self.recurse = False
self.increasing = False
self.parents = []
self.parked = False
self.parking_key = None
def __enter__(self):
self.btree.register_cursor(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.btree.deregister_cursor(self)
return False
class Immutable(Exception):
"""The BTree is immutable."""
class BTree(Generic[KT, ET]):
"""An in-memory BTree with copy-on-write and cursors."""
def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None):
"""Create a BTree.
If *original* is not ``None``, then the BTree is shallow-cloned from
*original* using copy-on-write. Otherwise a new BTree with the specified
*t* value is created.
The BTree is not thread-safe.
"""
# We don't use a reference to ourselves as a creator as we don't want
# to prevent GC of old btrees.
self.creator = _Creator()
self._immutable = False
self.t: int
self.root: _Node
self.size: int
self.cursors: set[Cursor] = set()
if original is not None:
if not original._immutable:
raise ValueError("original BTree is not immutable")
self.t = original.t
self.root = original.root
self.size = original.size
else:
if t < 3:
raise ValueError("t must be >= 3")
self.t = t
self.root = _Node(self.t, self.creator, True)
self.size = 0
def make_immutable(self):
"""Make the BTree immutable.
Attempts to alter the BTree after making it immutable will raise an
Immutable exception. This operation cannot be undone.
"""
if not self._immutable:
self._immutable = True
def _check_mutable_and_park(self) -> None:
if self._immutable:
raise Immutable
for cursor in self.cursors:
cursor.park()
# Note that we don't use insert() and delete() but rather insert_element() and
# delete_key() so that BTreeDict can be a proper MutableMapping and supply the
# rest of the standard mapping API.
def insert_element(self, elt: ET, in_order: bool = False) -> ET | None:
"""Insert the element into the BTree.
If *in_order* is ``True``, then extra work will be done to make left siblings
full, which optimizes storage space when the the elements are inserted in-order
or close to it.
Returns the previously existing element at the element's key or ``None``.
"""
self._check_mutable_and_park()
cloned = self.root.maybe_cow(self.creator)
if cloned:
self.root = cloned
if self.root.is_maximal():
old_root = self.root
self.root = _Node(self.t, self.creator, False)
self.root.adopt(*old_root.split())
oelt = self.root.insert_nonfull(elt, in_order)
if oelt is None:
# We did not replace, so something was added.
self.size += 1
return oelt
def get_element(self, key: KT) -> ET | None:
"""Get the element matching *key* from the BTree, or return ``None`` if it
does not exist.
"""
return self.root.get(key)
def _delete(self, key: KT, exact: ET | None) -> ET | None:
self._check_mutable_and_park()
cloned = self.root.maybe_cow(self.creator)
if cloned:
self.root = cloned
elt = self.root.delete(key, None, exact)
if elt is not None:
# We deleted something
self.size -= 1
if len(self.root.elts) == 0:
# The root is now empty. If there is a child, then collapse this root
# level and make the child the new root.
if not self.root.is_leaf:
assert len(self.root.children) == 1
self.root = self.root.children[0]
return elt
def delete_key(self, key: KT) -> ET | None:
"""Delete the element matching *key* from the BTree.
Returns the matching element or ``None`` if it does not exist.
"""
return self._delete(key, None)
def delete_exact(self, element: ET) -> ET | None:
"""Delete *element* from the BTree.
Returns the matching element or ``None`` if it was not in the BTree.
"""
delt = self._delete(element.key(), element)
assert delt is element
return delt
def __len__(self):
return self.size
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
"""Call *visit*(element) on all elements in the tree in sorted order."""
self.root.visit_in_order(visit)
def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None:
self.root._visit_preorder_by_node(visit)
def cursor(self) -> Cursor[KT, ET]:
"""Create a cursor."""
return Cursor(self)
def register_cursor(self, cursor: Cursor) -> None:
"""Register a cursor for the automatic parking service."""
self.cursors.add(cursor)
def deregister_cursor(self, cursor: Cursor) -> None:
"""Deregister a cursor from the automatic parking service."""
self.cursors.discard(cursor)
def __copy__(self):
return self.__class__(original=self)
def __iter__(self):
with self.cursor() as cursor:
while True:
elt = cursor.next()
if elt is None:
break
yield elt.key()
VT = TypeVar("VT") # the type of a value in a BTreeDict
class KV(Element, Generic[KT, VT]):
"""The BTree element type used in a ``BTreeDict``."""
def __init__(self, key: KT, value: VT):
self._key = key
self._value = value
def key(self) -> KT:
return self._key
def value(self) -> VT:
return self._value
def __str__(self): # pragma: no cover
return f"KV({self._key}, {self._value})"
def __repr__(self): # pragma: no cover
return f"KV({self._key}, {self._value})"
class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]):
"""A MutableMapping implemented with a BTree.
Unlike a normal Python dict, the BTreeDict may be mutated while iterating.
"""
def __init__(
self,
*,
t: int = DEFAULT_T,
original: BTree | None = None,
in_order: bool = False,
):
super().__init__(t=t, original=original)
self.in_order = in_order
def __getitem__(self, key: KT) -> VT:
elt = self.get_element(key)
if elt is None:
raise KeyError
else:
return cast(KV, elt).value()
def __setitem__(self, key: KT, value: VT) -> None:
elt = KV(key, value)
self.insert_element(elt, self.in_order)
def __delitem__(self, key: KT) -> None:
if self.delete_key(key) is None:
raise KeyError
class Member(Element, Generic[KT]):
"""The BTree element type used in a ``BTreeSet``."""
def __init__(self, key: KT):
self._key = key
def key(self) -> KT:
return self._key
class BTreeSet(BTree, Generic[KT], MutableSet[KT]):
"""A MutableSet implemented with a BTree.
Unlike a normal Python set, the BTreeSet may be mutated while iterating.
"""
def __init__(
self,
*,
t: int = DEFAULT_T,
original: BTree | None = None,
in_order: bool = False,
):
super().__init__(t=t, original=original)
self.in_order = in_order
def __contains__(self, key: Any) -> bool:
return self.get_element(key) is not None
def add(self, value: KT) -> None:
elt = Member(value)
self.insert_element(elt, self.in_order)
def discard(self, value: KT) -> None:
self.delete_key(value)

View File

@@ -0,0 +1,367 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
# a separate per-version delegation index. These additions let us
#
# 1) Do efficient CoW versioning (useful for future online updates).
# 2) Maintain sort order
# 3) Allow delegations to be found easily
# 4) Handle glue
# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant. The ORIGIN
# flag is set at the origin node, the DELEGATION FLAG is set at delegation
# points, and the GLUE flag is set on nodes beneath delegation points.
import enum
from dataclasses import dataclass
from typing import Callable, MutableMapping, Tuple, cast
import dns.btree
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.versioned
import dns.zone
class NodeFlags(enum.IntFlag):
ORIGIN = 0x01
DELEGATION = 0x02
GLUE = 0x04
class Node(dns.node.Node):
__slots__ = ["flags", "id"]
def __init__(self, flags: NodeFlags | None = None):
super().__init__()
if flags is None:
# We allow optional flags rather than a default
# as pyright doesn't like assigning a literal 0
# to flags.
flags = NodeFlags(0)
self.flags = flags
self.id = 0
def is_delegation(self):
return (self.flags & NodeFlags.DELEGATION) != 0
def is_glue(self):
return (self.flags & NodeFlags.GLUE) != 0
def is_origin(self):
return (self.flags & NodeFlags.ORIGIN) != 0
def is_origin_or_glue(self):
return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0
@dns.immutable.immutable
class ImmutableNode(Node):
def __init__(self, node: Node):
super().__init__()
self.id = node.id
self.rdatasets = tuple( # type: ignore
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
self.flags = node.flags
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable")
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable")
def is_immutable(self) -> bool:
return True
class Delegations(dns.btree.BTreeSet[dns.name.Name]):
def get_delegation(self, name: dns.name.Name) -> Tuple[dns.name.Name | None, bool]:
"""Get the delegation applicable to *name*, if it exists.
If there delegation, then return a tuple consisting of the name of
the delegation point, and a boolean which is `True` if the name is a proper
subdomain of the delegation point, and `False` if it is equal to the delegation
point.
"""
cursor = self.cursor()
cursor.seek(name, before=False)
prev = cursor.prev()
if prev is None:
return None, False
cut = prev.key()
reln, _, _ = name.fullcompare(cut)
is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
if is_subdomain or reln == dns.name.NameRelation.EQUAL:
return cut, is_subdomain
else:
return None, False
def is_glue(self, name: dns.name.Name) -> bool:
"""Is *name* glue, i.e. is it beneath a delegation?"""
cursor = self.cursor()
cursor.seek(name, before=False)
cut, is_subdomain = self.get_delegation(name)
if cut is None:
return False
return is_subdomain
class WritableVersion(dns.zone.WritableVersion):
def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
super().__init__(zone, True)
if not replacement:
assert isinstance(zone, dns.versioned.Zone)
version = zone._versions[-1]
self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
original=version.nodes # type: ignore
)
self.delegations = Delegations(original=version.delegations) # type: ignore
else:
self.delegations = Delegations()
def _is_origin(self, name: dns.name.Name) -> bool:
# Assumes name has already been validated (and thus adjusted to the right
# relativity too)
if self.zone.relativize:
return name == dns.name.empty
else:
return name == self.zone.origin
def _maybe_cow_with_name(
self, name: dns.name.Name
) -> Tuple[dns.node.Node, dns.name.Name]:
(node, name) = super()._maybe_cow_with_name(name)
node = cast(Node, node)
if self._is_origin(name):
node.flags |= NodeFlags.ORIGIN
elif self.delegations.is_glue(name):
node.flags |= NodeFlags.GLUE
return (node, name)
def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
cursor = self.nodes.cursor() # type: ignore
cursor.seek(name, False)
updates = []
while True:
elt = cursor.next()
if elt is None:
break
ename = elt.key()
if not ename.is_subdomain(name):
break
node = cast(dns.node.Node, elt.value())
if ename not in self.changed:
new_node = self.zone.node_factory()
new_node.id = self.id # type: ignore
new_node.rdatasets.extend(node.rdatasets)
self.changed.add(ename)
node = new_node
assert isinstance(node, Node)
if is_glue:
node.flags |= NodeFlags.GLUE
else:
node.flags &= ~NodeFlags.GLUE
# We don't update node here as any insertion could disturb the
# btree and invalidate our cursor. We could use the cursor in a
# with block and avoid this, but it would do a lot of parking and
# unparking so the deferred update mode may still be better.
updates.append((ename, node))
for ename, node in updates:
self.nodes[ename] = node
def delete_node(self, name: dns.name.Name) -> None:
name = self._validate_name(name)
node = self.nodes.get(name)
if node is not None:
if node.is_delegation(): # type: ignore
self.delegations.discard(name)
self.update_glue_flag(name, False)
del self.nodes[name]
self.changed.add(name)
def put_rdataset(
self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
) -> None:
(node, name) = self._maybe_cow_with_name(name)
if (
rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue() # type: ignore
):
node.flags |= NodeFlags.DELEGATION # type: ignore
if name not in self.delegations:
self.delegations.add(name)
self.update_glue_flag(name, True)
node.replace_rdataset(rdataset)
def delete_rdataset(
self,
name: dns.name.Name,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType,
) -> None:
(node, name) = self._maybe_cow_with_name(name)
if rdtype == dns.rdatatype.NS and name in self.delegations: # type: ignore
node.flags &= ~NodeFlags.DELEGATION # type: ignore
self.delegations.discard(name) # type: ignore
self.update_glue_flag(name, False)
node.delete_rdataset(self.zone.rdclass, rdtype, covers)
if len(node) == 0:
del self.nodes[name]
@dataclass(frozen=True)
class Bounds:
name: dns.name.Name
left: dns.name.Name
right: dns.name.Name | None
closest_encloser: dns.name.Name
is_equal: bool
is_delegation: bool
def __str__(self):
if self.is_equal:
op = "="
else:
op = "<"
if self.is_delegation:
zonecut = " zonecut"
else:
zonecut = ""
return (
f"{self.left} {op} {self.name} < {self.right}{zonecut}; "
f"{self.closest_encloser}"
)
@dns.immutable.immutable
class ImmutableVersion(dns.zone.Version):
def __init__(self, version: dns.zone.Version):
if not isinstance(version, WritableVersion):
raise ValueError(
"a dns.btreezone.ImmutableVersion requires a "
"dns.btreezone.WritableVersion"
)
super().__init__(version.zone, True)
self.id = version.id
self.origin = version.origin
for name in version.changed:
node = version.nodes.get(name)
if node:
version.nodes[name] = ImmutableNode(node)
# the cast below is for mypy
self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
self.nodes.make_immutable() # type: ignore
self.delegations = version.delegations
self.delegations.make_immutable()
def bounds(self, name: dns.name.Name | str) -> Bounds:
"""Return the 'bounds' of *name* in its zone.
The bounds information is useful when making an authoritative response, as
it can be used to determine whether the query name is at or beneath a delegation
point. The other data in the ``Bounds`` object is useful for making on-the-fly
DNSSEC signatures.
The left bound of *name* is *name* itself if it is in the zone, or the greatest
predecessor which is in the zone.
The right bound of *name* is the least successor of *name*, or ``None`` if
no name in the zone is greater than *name*.
The closest encloser of *name* is *name* itself, if *name* is in the zone;
otherwise it is the name with the largest number of labels in common with
*name* that is in the zone, either explicitly or by the implied existence
of empty non-terminals.
The bounds *is_equal* field is ``True`` if and only if *name* is equal to
its left bound.
The bounds *is_delegation* field is ``True`` if and only if the left bound is a
delegation point.
"""
assert self.origin is not None
# validate the origin because we may need to relativize
origin = self.zone._validate_name(self.origin)
name = self.zone._validate_name(name)
cut, _ = self.delegations.get_delegation(name)
if cut is not None:
target = cut
is_delegation = True
else:
target = name
is_delegation = False
c = cast(dns.btree.BTreeDict, self.nodes).cursor()
c.seek(target, False)
left = c.prev()
assert left is not None
c.next() # skip over left
while True:
right = c.next()
if right is None or not right.value().is_glue():
break
left_comparison = left.key().fullcompare(name)
if right is not None:
right_key = right.key()
right_comparison = right_key.fullcompare(name)
else:
right_comparison = (
dns.name.NAMERELN_COMMONANCESTOR,
-1,
len(origin),
)
right_key = None
closest_encloser = dns.name.Name(
name[-max(left_comparison[2], right_comparison[2]) :]
)
return Bounds(
name,
left.key(),
right_key,
closest_encloser,
left_comparison[0] == dns.name.NameRelation.EQUAL,
is_delegation,
)
class Zone(dns.versioned.Zone):
node_factory: Callable[[], dns.node.Node] = Node
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast(
Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
dns.btree.BTreeDict[dns.name.Name, Node],
)
writable_version_factory: (
Callable[[dns.zone.Zone, bool], dns.zone.Version] | None
) = WritableVersion
immutable_version_factory: Callable[[dns.zone.Version], dns.zone.Version] | None = (
ImmutableVersion
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,124 @@
from typing import Dict, Tuple, Type
import dns._features
import dns.name
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
# pyright: reportPossiblyUnboundVariable=false
if dns._features.have("dnssec"):
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
from dns.dnssecalgs.rsa import (
PrivateRSAMD5,
PrivateRSASHA1,
PrivateRSASHA1NSEC3SHA1,
PrivateRSASHA256,
PrivateRSASHA512,
)
_have_cryptography = True
else:
_have_cryptography = False
AlgorithmPrefix = bytes | dns.name.Name | None
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
# pylint: disable=possibly-used-before-assignment
algorithms.update(
{
(Algorithm.RSAMD5, None): PrivateRSAMD5,
(Algorithm.DSA, None): PrivateDSA,
(Algorithm.RSASHA1, None): PrivateRSASHA1,
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
(Algorithm.RSASHA256, None): PrivateRSASHA256,
(Algorithm.RSASHA512, None): PrivateRSASHA512,
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
(Algorithm.ED25519, None): PrivateED25519,
(Algorithm.ED448, None): PrivateED448,
}
)
def get_algorithm_cls(
algorithm: int | str, prefix: AlgorithmPrefix = None
) -> Type[GenericPrivateKey]:
"""Get Private Key class from Algorithm.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
algorithm = Algorithm.make(algorithm)
cls = algorithms.get((algorithm, prefix))
if cls:
return cls
raise UnsupportedAlgorithm(
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
)
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
"""Get Private Key class from DNSKEY.
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
prefix: AlgorithmPrefix = None
if dnskey.algorithm == Algorithm.PRIVATEDNS:
prefix, _ = dns.name.from_wire(dnskey.key, 0)
elif dnskey.algorithm == Algorithm.PRIVATEOID:
length = int(dnskey.key[0])
prefix = dnskey.key[0 : length + 1]
return get_algorithm_cls(dnskey.algorithm, prefix)
def register_algorithm_cls(
algorithm: int | str,
algorithm_cls: Type[GenericPrivateKey],
name: dns.name.Name | str | None = None,
oid: bytes | None = None,
) -> None:
"""Register Algorithm Private Key class.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
*algorithm_cls*: A `GenericPrivateKey` class.
*name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
Raises ``ValueError`` if a name or oid is specified incorrectly.
"""
if not issubclass(algorithm_cls, GenericPrivateKey):
raise TypeError("Invalid algorithm class")
algorithm = Algorithm.make(algorithm)
prefix: AlgorithmPrefix = None
if algorithm == Algorithm.PRIVATEDNS:
if name is None:
raise ValueError("Name required for PRIVATEDNS algorithms")
if isinstance(name, str):
name = dns.name.from_text(name)
prefix = name
elif algorithm == Algorithm.PRIVATEOID:
if oid is None:
raise ValueError("OID required for PRIVATEOID algorithms")
prefix = bytes([len(oid)]) + oid
elif name:
raise ValueError("Name only supported for PRIVATEDNS algorithm")
elif oid:
raise ValueError("OID only supported for PRIVATEOID algorithm")
algorithms[(algorithm, prefix)] = algorithm_cls

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
from typing import Any, Type
import dns.rdataclass
import dns.rdatatype
from dns.dnssectypes import Algorithm
from dns.exception import AlgorithmKeyMismatch
from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.dnskeybase import Flag
class GenericPublicKey(ABC):
algorithm: Algorithm
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def verify(self, signature: bytes, data: bytes) -> None:
"""Verify signed DNSSEC data"""
@abstractmethod
def encode_key_bytes(self) -> bytes:
"""Encode key as bytes for DNSKEY"""
@classmethod
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
if key.algorithm != cls.algorithm:
raise AlgorithmKeyMismatch
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
"""Return public key as DNSKEY"""
return DNSKEY(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
algorithm=self.algorithm,
key=self.encode_key_bytes(),
)
@classmethod
@abstractmethod
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
"""Create public key from DNSKEY"""
@classmethod
@abstractmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
@abstractmethod
def to_pem(self) -> bytes:
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
class GenericPrivateKey(ABC):
public_cls: Type[GenericPublicKey]
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign DNSSEC data"""
@abstractmethod
def public_key(self) -> "GenericPublicKey":
"""Return public key instance"""
@classmethod
@abstractmethod
def from_pem(
cls, private_pem: bytes, password: bytes | None = None
) -> "GenericPrivateKey":
"""Create private key from PEM-encoded PKCS#8"""
@abstractmethod
def to_pem(self, password: bytes | None = None) -> bytes:
"""Return private key as PEM-encoded PKCS#8"""

View File

@@ -0,0 +1,68 @@
from typing import Any, Type
from cryptography.hazmat.primitives import serialization
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
from dns.exception import AlgorithmKeyMismatch
class CryptographyPublicKey(GenericPublicKey):
key: Any = None
key_cls: Any = None
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
@classmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
key = serialization.load_pem_public_key(public_pem)
return cls(key=key)
def to_pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
class CryptographyPrivateKey(GenericPrivateKey):
key: Any = None
key_cls: Any = None
public_cls: Type[CryptographyPublicKey] # pyright: ignore
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
def public_key(self) -> "CryptographyPublicKey":
return self.public_cls(key=self.key.public_key())
@classmethod
def from_pem(
cls, private_pem: bytes, password: bytes | None = None
) -> "GenericPrivateKey":
key = serialization.load_pem_private_key(private_pem, password=password)
return cls(key=key)
def to_pem(self, password: bytes | None = None) -> bytes:
encryption_algorithm: serialization.KeySerializationEncryption
if password:
encryption_algorithm = serialization.BestAvailableEncryption(password)
else:
encryption_algorithm = serialization.NoEncryption()
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm,
)

View File

@@ -0,0 +1,108 @@
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dsa, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicDSA(CryptographyPublicKey):
key: dsa.DSAPublicKey
key_cls = dsa.DSAPublicKey
algorithm = Algorithm.DSA
chosen_hash = hashes.SHA1()
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[1:21]
sig_s = signature[21:]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 2536, section 2."""
pn = self.key.public_numbers()
dsa_t = (self.key.key_size // 8 - 64) // 8
if dsa_t > 8:
raise ValueError("unsupported DSA key size")
octets = 64 + dsa_t * 8
res = struct.pack("!B", dsa_t)
res += pn.parameter_numbers.q.to_bytes(20, "big")
res += pn.parameter_numbers.p.to_bytes(octets, "big")
res += pn.parameter_numbers.g.to_bytes(octets, "big")
res += pn.y.to_bytes(octets, "big")
return res
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(t,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
return cls(
key=dsa.DSAPublicNumbers( # type: ignore
int.from_bytes(dsa_y, "big"),
dsa.DSAParameterNumbers(
int.from_bytes(dsa_p, "big"),
int.from_bytes(dsa_q, "big"),
int.from_bytes(dsa_g, "big"),
),
).public_key(default_backend()),
)
class PrivateDSA(CryptographyPrivateKey):
key: dsa.DSAPrivateKey
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024:
raise ValueError("DSA key size overflow")
der_signature = self.key.sign(
data, self.public_cls.chosen_hash # pyright: ignore
)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
octets = 20
signature = (
struct.pack("!B", dsa_t)
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateDSA":
return cls(
key=dsa.generate_private_key(key_size=key_size),
)
class PublicDSANSEC3SHA1(PublicDSA):
algorithm = Algorithm.DSANSEC3SHA1
class PrivateDSANSEC3SHA1(PrivateDSA):
public_cls = PublicDSANSEC3SHA1

View File

@@ -0,0 +1,100 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicECDSA(CryptographyPublicKey):
key: ec.EllipticCurvePublicKey
key_cls = ec.EllipticCurvePublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
curve: ec.EllipticCurve
octets: int
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[0 : self.octets]
sig_s = signature[self.octets :]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 6605, section 4."""
pn = self.key.public_numbers()
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
cls._ensure_algorithm_key_combination(key)
ecdsa_x = key.key[0 : cls.octets]
ecdsa_y = key.key[cls.octets : cls.octets * 2]
return cls(
key=ec.EllipticCurvePublicNumbers(
curve=cls.curve,
x=int.from_bytes(ecdsa_x, "big"),
y=int.from_bytes(ecdsa_y, "big"),
).public_key(default_backend()),
)
class PrivateECDSA(CryptographyPrivateKey):
key: ec.EllipticCurvePrivateKey
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
algorithm = ec.ECDSA(
self.public_cls.chosen_hash, # pyright: ignore
deterministic_signing=deterministic,
)
der_signature = self.key.sign(data, algorithm)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big" # pyright: ignore
) + int.to_bytes(
dsa_s, length=self.public_cls.octets, byteorder="big" # pyright: ignore
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateECDSA":
return cls(
key=ec.generate_private_key(
curve=cls.public_cls.curve, backend=default_backend() # pyright: ignore
),
)
class PublicECDSAP256SHA256(PublicECDSA):
algorithm = Algorithm.ECDSAP256SHA256
chosen_hash = hashes.SHA256()
curve = ec.SECP256R1()
octets = 32
class PrivateECDSAP256SHA256(PrivateECDSA):
public_cls = PublicECDSAP256SHA256
class PublicECDSAP384SHA384(PublicECDSA):
algorithm = Algorithm.ECDSAP384SHA384
chosen_hash = hashes.SHA384()
curve = ec.SECP384R1()
octets = 48
class PrivateECDSAP384SHA384(PrivateECDSA):
public_cls = PublicECDSAP384SHA384

View File

@@ -0,0 +1,70 @@
from typing import Type
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicEDDSA(CryptographyPublicKey):
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 8080, section 3."""
return self.key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
cls._ensure_algorithm_key_combination(key)
return cls(
key=cls.key_cls.from_public_bytes(key.key),
)
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA] # pyright: ignore
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateEDDSA":
return cls(key=cls.key_cls.generate())
class PublicED25519(PublicEDDSA):
key: ed25519.Ed25519PublicKey
key_cls = ed25519.Ed25519PublicKey
algorithm = Algorithm.ED25519
class PrivateED25519(PrivateEDDSA):
key: ed25519.Ed25519PrivateKey
key_cls = ed25519.Ed25519PrivateKey
public_cls = PublicED25519
class PublicED448(PublicEDDSA):
key: ed448.Ed448PublicKey
key_cls = ed448.Ed448PublicKey
algorithm = Algorithm.ED448
class PrivateED448(PrivateEDDSA):
key: ed448.Ed448PrivateKey
key_cls = ed448.Ed448PrivateKey
public_cls = PublicED448

View File

@@ -0,0 +1,126 @@
import math
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicRSA(CryptographyPublicKey):
key: rsa.RSAPublicKey
key_cls = rsa.RSAPublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 3110, section 2."""
pn = self.key.public_numbers()
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
if _exp_len > 255:
exp_header = b"\0" + struct.pack("!H", _exp_len)
else:
exp_header = struct.pack("!B", _exp_len)
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
raise ValueError("unsupported RSA key length")
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(bytes_,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack("!H", keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
return cls(
key=rsa.RSAPublicNumbers(
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
).public_key(default_backend())
)
class PrivateRSA(CryptographyPrivateKey):
key: rsa.RSAPrivateKey
key_cls = rsa.RSAPrivateKey
public_cls = PublicRSA
default_public_exponent = 65537
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(
data, padding.PKCS1v15(), self.public_cls.chosen_hash # pyright: ignore
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateRSA":
return cls(
key=rsa.generate_private_key(
public_exponent=cls.default_public_exponent,
key_size=key_size,
backend=default_backend(),
)
)
class PublicRSAMD5(PublicRSA):
algorithm = Algorithm.RSAMD5
chosen_hash = hashes.MD5()
class PrivateRSAMD5(PrivateRSA):
public_cls = PublicRSAMD5
class PublicRSASHA1(PublicRSA):
algorithm = Algorithm.RSASHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1(PrivateRSA):
public_cls = PublicRSASHA1
class PublicRSASHA1NSEC3SHA1(PublicRSA):
algorithm = Algorithm.RSASHA1NSEC3SHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
public_cls = PublicRSASHA1NSEC3SHA1
class PublicRSASHA256(PublicRSA):
algorithm = Algorithm.RSASHA256
chosen_hash = hashes.SHA256()
class PrivateRSASHA256(PrivateRSA):
public_cls = PublicRSASHA256
class PublicRSASHA512(PublicRSA):
algorithm = Algorithm.RSASHA512
chosen_hash = hashes.SHA512()
class PrivateRSASHA512(PrivateRSA):
public_cls = PublicRSASHA512

View File

@@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255

View File

@@ -0,0 +1,116 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS E.164 helpers."""
from typing import Iterable
import dns.exception
import dns.name
import dns.resolver
#: The public E.164 domain.
public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(
text: str, origin: dns.name.Name | None = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
Non-digits in the text are ignored, i.e. "16505551212",
"+1.650.555.1212" and "1 (650) 555-1212" are all the same.
*text*, a ``str``, is an E.164 number in textual form.
*origin*, a ``dns.name.Name``, the domain in which the number
should be constructed. The default is ``e164.arpa.``.
Returns a ``dns.name.Name``.
"""
parts = [d for d in text if d.isdigit()]
parts.reverse()
return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(
name: dns.name.Name,
origin: dns.name.Name | None = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
number formats within national numbering plans, so all numbers are
emitted as a simple string of digits, prefixed by a '+' (unless
*want_plus_prefix* is ``False``).
*name* is a ``dns.name.Name``, the ENUM domain name.
*origin* is a ``dns.name.Name``, a domain containing the ENUM
domain name. The name is relativized to this domain before being
converted to text. If ``None``, no relativization is done.
*want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of
the returned number.
Returns a ``str``.
"""
if origin is not None:
name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse()
text = b"".join(dlabels)
if want_plus_prefix:
text = b"+" + text
return text.decode()
def query(
number: str,
domains: Iterable[dns.name.Name | str],
resolver: dns.resolver.Resolver | None = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
*number*, a ``str`` is the number to look for.
*domains* is an iterable containing ``dns.name.Name`` values.
*resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If
``None``, the default resolver is used.
"""
if resolver is None:
resolver = dns.resolver.get_default_resolver()
e_nx = dns.resolver.NXDOMAIN()
for domain in domains:
if isinstance(domain, str):
domain = dns.name.from_text(domain)
qname = from_e164(number, domain)
try:
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
e_nx += e
raise e_nx

View File

@@ -0,0 +1,591 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2009-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""EDNS Options"""
import binascii
import math
import socket
import struct
from typing import Any, Dict
import dns.enum
import dns.inet
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum):
"""EDNS option type codes"""
#: NSID
NSID = 3
#: DAU
DAU = 5
#: DHU
DHU = 6
#: N3U
N3U = 7
#: ECS (client-subnet)
ECS = 8
#: EXPIRE
EXPIRE = 9
#: COOKIE
COOKIE = 10
#: KEEPALIVE
KEEPALIVE = 11
#: PADDING
PADDING = 12
#: CHAIN
CHAIN = 13
#: EDE (extended-dns-error)
EDE = 15
#: REPORTCHANNEL
REPORTCHANNEL = 18
@classmethod
def _maximum(cls):
return 65535
class Option:
"""Base class for all EDNS option types."""
def __init__(self, otype: OptionType | str):
"""Initialize an option.
*otype*, a ``dns.edns.OptionType``, is the option type.
"""
self.otype = OptionType.make(otype)
def to_wire(self, file: Any | None = None) -> bytes | None:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
"""
raise NotImplementedError # pragma: no cover
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover
def to_generic(self) -> "GenericOption":
"""Creates a dns.edns.GenericOption equivalent of this rdata.
Returns a ``dns.edns.GenericOption``.
"""
wire = self.to_wire()
assert wire is not None # for mypy
return GenericOption(self.otype, wire)
@classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
*otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length.
Returns a ``dns.edns.Option``.
"""
raise NotImplementedError # pragma: no cover
def _cmp(self, other):
"""Compare an EDNS option with another option of the same type.
Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*.
"""
wire = self.to_wire()
owire = other.to_wire()
if wire == owire:
return 0
if wire > owire:
return 1
return -1
def __eq__(self, other):
if not isinstance(other, Option):
return False
if self.otype != other.otype:
return False
return self._cmp(other) == 0
def __ne__(self, other):
if not isinstance(other, Option):
return True
if self.otype != other.otype:
return True
return self._cmp(other) != 0
def __lt__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) > 0
def __str__(self):
return self.to_text()
class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
This class is used for EDNS option types for which we have no better
implementation.
"""
def __init__(self, otype: OptionType | str, data: bytes | str):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file: Any | None = None) -> bytes | None:
if file:
file.write(self.data)
return None
else:
return self.data
def to_text(self) -> str:
return f"Generic {self.otype}"
def to_generic(self) -> "GenericOption":
return self
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining())
class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address: str, srclen: int | None = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
leftmost number of bits of the address to be used for the
lookup. The default is 24 for IPv4 and 56 for IPv6.
*scopelen*, an ``int``, the scope prefix length. This value
must be 0 in queries, and should be set in responses.
"""
super().__init__(OptionType.ECS)
af = dns.inet.af_for_address(address)
if af == socket.AF_INET6:
self.family = 2
if srclen is None:
srclen = 56
address = dns.rdata.Rdata._as_ipv6_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 128)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128)
elif af == socket.AF_INET:
self.family = 1
if srclen is None:
srclen = 24
address = dns.rdata.Rdata._as_ipv4_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen)
raise ValueError("Bad address family")
assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
addrdata = dns.inet.inet_pton(af, address)
nbytes = int(math.ceil(srclen / 8.0))
# Truncate to srclen and pad to the end of the last octet needed
# See RFC section 6
self.addrdata = addrdata[:nbytes]
nbits = srclen % 8
if nbits != 0:
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
def to_text(self) -> str:
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
@staticmethod
def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
Returns a `dns.edns.ECSOption`.
Examples:
>>> import dns.edns
>>>
>>> # basic example
>>> dns.edns.ECSOption.from_text('1.2.3.4/24')
>>>
>>> # also understands scope
>>> dns.edns.ECSOption.from_text('1.2.3.4/24/32')
>>>
>>> # IPv6
>>> dns.edns.ECSOption.from_text('2001:4b98::1/64/64')
>>>
>>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
"""
optional_prefix = "ECS"
tokens = text.split()
ecs_text = None
if len(tokens) == 1:
ecs_text = tokens[0]
elif len(tokens) == 2:
if tokens[0] != optional_prefix:
raise ValueError(f'could not parse ECS from "{text}"')
ecs_text = tokens[1]
else:
raise ValueError(f'could not parse ECS from "{text}"')
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, tsrclen = ecs_text.split("/")
tscope = "0"
elif n_slashes == 2:
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError(f'could not parse ECS from "{text}"')
try:
scope = int(tscope)
except ValueError:
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
try:
srclen = int(tsrclen)
except ValueError:
raise ValueError(
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
)
return ECSOption(address, srclen, scope)
def to_wire(self, file: Any | None = None) -> bytes | None:
value = (
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen)
if family == 1:
pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2:
pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else:
raise ValueError("unsupported family")
return cls(addr, src, scope)
class EDECode(dns.enum.IntEnum):
"""Extended DNS Error (EDE) codes"""
OTHER = 0
UNSUPPORTED_DNSKEY_ALGORITHM = 1
UNSUPPORTED_DS_DIGEST_TYPE = 2
STALE_ANSWER = 3
FORGED_ANSWER = 4
DNSSEC_INDETERMINATE = 5
DNSSEC_BOGUS = 6
SIGNATURE_EXPIRED = 7
SIGNATURE_NOT_YET_VALID = 8
DNSKEY_MISSING = 9
RRSIGS_MISSING = 10
NO_ZONE_KEY_BIT_SET = 11
NSEC_MISSING = 12
CACHED_ERROR = 13
NOT_READY = 14
BLOCKED = 15
CENSORED = 16
FILTERED = 17
PROHIBITED = 18
STALE_NXDOMAIN_ANSWER = 19
NOT_AUTHORITATIVE = 20
NOT_SUPPORTED = 21
NO_REACHABLE_AUTHORITY = 22
NETWORK_ERROR = 23
INVALID_DATA = 24
@classmethod
def _maximum(cls):
return 65535
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
def __init__(self, code: EDECode | str, text: str | None = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
*text*, a ``str`` or ``None``, specifying additional information about
the error.
"""
super().__init__(OptionType.EDE)
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError("text must be string or None")
self.text = text
def to_text(self) -> str:
output = f"EDE {self.code}"
if self.code in EDECode:
desc = EDECode.to_text(self.code)
desc = " ".join(
word if word in self._preserve_case else word.title()
for word in desc.split("_")
)
output += f" ({desc})"
if self.text is not None:
output += f": {self.text}"
return output
def to_wire(self, file: Any | None = None) -> bytes | None:
value = struct.pack("!H", self.code)
if self.text is not None:
value += self.text.encode("utf8")
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
code = EDECode.make(parser.get_uint16())
text = parser.get_remaining()
if text:
if text[-1] == 0: # text MAY be null-terminated
text = text[:-1]
btext = text.decode("utf8")
else:
btext = None
return cls(code, btext)
class NSIDOption(Option):
def __init__(self, nsid: bytes):
super().__init__(OptionType.NSID)
self.nsid = nsid
def to_wire(self, file: Any = None) -> bytes | None:
if file:
file.write(self.nsid)
return None
else:
return self.nsid
def to_text(self) -> str:
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
# All ASCII printable, so it's probably a string.
value = self.nsid.decode()
else:
value = binascii.hexlify(self.nsid).decode()
return f"NSID {value}"
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_remaining())
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
raise ValueError("client cookie must be 8 bytes")
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
def to_wire(self, file: Any = None) -> bytes | None:
if file:
file.write(self.client)
if len(self.server) > 0:
file.write(self.server)
return None
else:
return self.client + self.server
def to_text(self) -> str:
client = binascii.hexlify(self.client).decode()
if len(self.server) > 0:
server = binascii.hexlify(self.server).decode()
else:
server = ""
return f"COOKIE {client}{server}"
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_bytes(8), parser.get_remaining())
class ReportChannelOption(Option):
# RFC 9567
def __init__(self, agent_domain: dns.name.Name):
super().__init__(OptionType.REPORTCHANNEL)
self.agent_domain = agent_domain
def to_wire(self, file: Any = None) -> bytes | None:
return self.agent_domain.to_wire(file)
def to_text(self) -> str:
return "REPORTCHANNEL " + self.agent_domain.to_text()
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_name())
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
OptionType.COOKIE: CookieOption,
OptionType.REPORTCHANNEL: ReportChannelOption,
}
def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
known.
"""
cls = _type_to_class.get(otype)
if cls is None:
cls = GenericOption
return cls
def option_from_wire_parser(
otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the option length.
Returns an instance of a subclass of ``dns.edns.Option``.
"""
otype = OptionType.make(otype)
cls = get_option_class(otype)
return cls.from_wire_parser(otype, parser)
def option_from_wire(
otype: OptionType | str, wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*wire*, a ``bytes``, is the wire-format message.
*current*, an ``int``, is the offset in *wire* of the beginning
of the rdata.
*olen*, an ``int``, is the length of the wire-format option data
Returns an instance of a subclass of ``dns.edns.Option``.
"""
parser = dns.wire.Parser(wire, current)
with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser)
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
*otype*, an ``int``, is the option type.
"""
_type_to_class[otype] = implementation
### BEGIN generated OptionType constants
NSID = OptionType.NSID
DAU = OptionType.DAU
DHU = OptionType.DHU
N3U = OptionType.N3U
ECS = OptionType.ECS
EXPIRE = OptionType.EXPIRE
COOKIE = OptionType.COOKIE
KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
REPORTCHANNEL = OptionType.REPORTCHANNEL
### END generated OptionType constants

View File

@@ -0,0 +1,130 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2009-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import hashlib
import os
import random
import threading
import time
from typing import Any
class EntropyPool:
# This is an entropy pool for Python implementations that do not
# have a working SystemRandom. I'm not sure there are any, but
# leaving this code doesn't hurt anything as the library code
# is used if present.
def __init__(self, seed: bytes | None = None):
self.pool_index = 0
self.digest: bytearray | None = None
self.next_byte = 0
self.lock = threading.Lock()
self.hash = hashlib.sha1()
self.hash_len = 20
self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None:
self._stir(seed)
self.seeded = True
self.seed_pid = os.getpid()
else:
self.seeded = False
self.seed_pid = 0
def _stir(self, entropy: bytes | bytearray) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
b = c & 0xFF
self.pool[self.pool_index] ^= b
self.pool_index += 1
def stir(self, entropy: bytes | bytearray) -> None:
with self.lock:
self._stir(entropy)
def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid():
try:
seed = os.urandom(16)
except Exception: # pragma: no cover
try:
with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16)
except Exception:
seed = str(time.time()).encode()
self.seeded = True
self.seed_pid = os.getpid()
self.digest = None
seed = bytearray(seed)
self._stir(seed)
def random_8(self) -> int:
with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
self.hash.update(bytes(self.pool))
self.digest = bytearray(self.hash.digest())
self._stir(self.digest)
self.next_byte = 0
value = self.digest[self.next_byte]
self.next_byte += 1
return value
def random_16(self) -> int:
return self.random_8() * 256 + self.random_8()
def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16()
def random_between(self, first: int, last: int) -> int:
size = last - first + 1
if size > 4294967296:
raise ValueError("too big")
if size > 65536:
rand = self.random_32
max = 4294967295
elif size > 256:
rand = self.random_16
max = 65535
else:
rand = self.random_8
max = 255
return first + size * rand() // (max + 1)
pool = EntropyPool()
system_random: Any | None
try:
system_random = random.SystemRandom()
except Exception: # pragma: no cover
system_random = None
def random_16() -> int:
if system_random is not None:
return system_random.randrange(0, 65536)
else:
return pool.random_16()
def between(first: int, last: int) -> int:
if system_random is not None:
return system_random.randrange(first, last + 1)
else:
return pool.random_between(first, last)

View File

@@ -0,0 +1,113 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import enum
from typing import Any, Type, TypeVar
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
class IntEnum(enum.IntEnum):
@classmethod
def _missing_(cls, value):
cls._check_value(value)
val = int.__new__(cls, value) # pyright: ignore
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
val._value_ = value # pyright: ignore
return val
@classmethod
def _check_value(cls, value):
max = cls._maximum()
if not isinstance(value, int):
raise TypeError
if value < 0 or value > max:
name = cls._short_name()
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
text = text.upper()
try:
return cls[text]
except KeyError:
pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :])
cls._check_value(value)
return cls(value)
raise cls._unknown_exception_class()
@classmethod
def to_text(cls: Type[TIntEnum], value: int) -> str:
cls._check_value(value)
try:
text = cls(value).name
except ValueError:
text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
@classmethod
def make(cls: Type[TIntEnum], value: int | str) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert.
Raises a class-specific exception if a ``str`` is provided that
cannot be converted.
Raises ``ValueError`` if the value is out of range.
Returns an enumeration from the calling class corresponding to the
value, if one is defined, or an ``int`` otherwise.
"""
if isinstance(value, str):
return cls.from_text(value)
cls._check_value(value)
return cls(value)
@classmethod
def _maximum(cls):
raise NotImplementedError # pragma: no cover
@classmethod
def _short_name(cls):
return cls.__name__.lower()
@classmethod
def _prefix(cls) -> str:
return ""
@classmethod
def _extra_from_text(cls, text: str) -> Any | None: # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return ValueError

View File

@@ -0,0 +1,169 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNS Exceptions.
Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``.
"""
from typing import Set
class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions.
It supports two basic modes of operation:
a) Old/compatible mode is used if ``__init__`` was called with
empty *kwargs*. In compatible mode all *args* are passed
to the standard Python Exception class as before and all *args* are
printed by the standard ``__str__`` implementation. Class variable
``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()``
if *args* is empty.
b) New/parametrized mode is used if ``__init__`` was called with
non-empty *kwargs*.
In the new mode *args* must be empty and all kwargs must match
those set in class variable ``supp_kwargs``. All kwargs are stored inside
``self.kwargs`` and used in a new ``__str__`` implementation to construct
a formatted message based on the ``fmt`` class variable, a ``string``.
In the simplest case it is enough to override the ``supp_kwargs``
and ``fmt`` class variables to get nice parametrized messages.
"""
msg: str | None = None # non-parametrized message
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt: str | None = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
if kwargs:
# This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self)
else:
self.kwargs = dict() # defined but empty for old mode exceptions
if self.msg is None:
# doc string is better implicit message than empty string
self.msg = self.__doc__
if args:
super().__init__(*args)
else:
super().__init__(self.msg)
def _check_params(self, *args, **kwargs):
"""Old exceptions supported only args and not kwargs.
For sanity we do not allow to mix old and new behavior."""
if args or kwargs:
assert bool(args) != bool(
kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs):
if kwargs:
assert (
set(kwargs.keys()) == self.supp_kwargs
), f"following set of keyword args is required: {self.supp_kwargs}"
return kwargs
def _fmt_kwargs(self, **kwargs):
"""Format kwargs before printing them.
Resulting dictionary has to have keys necessary for str.format call
on fmt class variable.
"""
fmtargs = {}
for kw, data in kwargs.items():
if isinstance(data, list | set):
# convert list of <someobj> to list of str(<someobj>)
fmtargs[kw] = list(map(str, data))
if len(fmtargs[kw]) == 1:
# remove list brackets [] from single-item lists
fmtargs[kw] = fmtargs[kw].pop()
else:
fmtargs[kw] = data
return fmtargs
def __str__(self):
if self.kwargs and self.fmt:
# provide custom message constructed from keyword arguments
fmtargs = self._fmt_kwargs(**self.kwargs)
return self.fmt.format(**fmtargs)
else:
# print *args directly in the same way as old DNSException
return super().__str__()
class FormError(DNSException):
"""DNS message is malformed."""
class SyntaxError(DNSException):
"""Text input is malformed."""
class UnexpectedEnd(SyntaxError):
"""Text input ended unexpectedly."""
class TooBig(DNSException):
"""The DNS message is too big."""
class Timeout(DNSException):
"""The DNS operation timed out."""
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class UnsupportedAlgorithm(DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(DNSException):
"""Denied by DNSSEC policy."""
class ExceptionWrapper:
def __init__(self, exception_class):
self.exception_class = exception_class
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val, self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False

View File

@@ -0,0 +1,123 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Message Flags."""
import enum
from typing import Any
# Standard DNS flags
class Flag(enum.IntFlag):
#: Query Response
QR = 0x8000
#: Authoritative Answer
AA = 0x0400
#: Truncated Response
TC = 0x0200
#: Recursion Desired
RD = 0x0100
#: Recursion Available
RA = 0x0080
#: Authentic Data
AD = 0x0020
#: Checking Disabled
CD = 0x0010
# EDNS flags
class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK
DO = 0x8000
def _from_text(text: str, enum_class: Any) -> int:
flags = 0
tokens = text.split()
for t in tokens:
flags |= enum_class[t.upper()]
return flags
def _to_text(flags: int, enum_class: Any) -> str:
text_flags = []
for k, v in enum_class.__members__.items():
if flags & v != 0:
text_flags.append(k)
return " ".join(text_flags)
def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
Returns an ``int``
"""
return _from_text(text, Flag)
def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
Returns a ``str``.
"""
return _to_text(flags, Flag)
def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
Returns an ``int``
"""
return _from_text(text, EDNSFlag)
def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
Returns a ``str``.
"""
return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants
QR = Flag.QR
AA = Flag.AA
TC = Flag.TC
RD = Flag.RD
RA = Flag.RA
AD = Flag.AD
CD = Flag.CD
### END generated Flag constants
### BEGIN generated EDNSFlag constants
DO = EDNSFlag.DO
### END generated EDNSFlag constants

View File

@@ -0,0 +1,72 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2012-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS GENERATE range conversion."""
from typing import Tuple
import dns.exception
def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an
integer.
*text*, a ``str``, the textual range in ``$GENERATE`` form.
Returns a tuple of three ``int`` values ``(start, stop, step)``.
"""
start = -1
stop = -1
step = 1
cur = ""
state = 0
# state 0 1 2
# x - y / z
if text and text[0] == "-":
raise dns.exception.SyntaxError("Start cannot be a negative number")
for c in text:
if c == "-" and state == 0:
start = int(cur)
cur = ""
state = 1
elif c == "/":
stop = int(cur)
cur = ""
state = 2
elif c.isdigit():
cur += c
else:
raise dns.exception.SyntaxError(f"Could not parse {c}")
if state == 0:
raise dns.exception.SyntaxError("no stop value specified")
elif state == 1:
stop = int(cur)
else:
assert state == 2
step = int(cur)
assert step >= 1
assert start >= 0
if start > stop:
raise dns.exception.SyntaxError("start must be <= stop")
return (start, stop, step)

View File

@@ -0,0 +1,68 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
from typing import Any, Callable
from dns._immutable_ctx import immutable
@immutable
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(
self,
dictionary: Any,
no_copy: bool = False,
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
self._odict = dictionary
else:
self._odict = map_factory()
self._odict.update(dictionary)
self._hash = None
def __getitem__(self, key):
return self._odict.__getitem__(key)
def __hash__(self): # pylint: disable=invalid-hash-returned
if self._hash is None:
h = 0
for key in sorted(self._odict.keys()):
h ^= hash(key)
object.__setattr__(self, "_hash", h)
# this does return an int, but pylint doesn't figure that out
return self._hash
def __len__(self):
return len(self._odict)
def __iter__(self):
return iter(self._odict)
def constify(o: Any) -> Any:
"""
Convert mutable types to immutable types.
"""
if isinstance(o, bytearray):
return bytes(o)
if isinstance(o, tuple):
try:
hash(o)
return o
except Exception:
return tuple(constify(elt) for elt in o)
if isinstance(o, list):
return tuple(constify(elt) for elt in o)
if isinstance(o, dict):
cdict = dict()
for k, v in o.items():
cdict[k] = constify(v)
return Dict(cdict, True)
return o

View File

@@ -0,0 +1,195 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Generic Internet address helper functions."""
import socket
from typing import Any, Tuple
import dns.ipv4
import dns.ipv6
# We assume that AF_INET and AF_INET6 are always defined. We keep
# these here for the benefit of any old code (unlikely though that
# is!).
AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6
def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family.
*text* is a ``str``, the textual address.
Raises ``NotImplementedError`` if the address family specified is not
implemented.
Returns a ``bytes``.
"""
if family == AF_INET:
return dns.ipv4.inet_aton(text)
elif family == AF_INET6:
return dns.ipv6.inet_aton(text, True)
else:
raise NotImplementedError
def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family.
*address* is a ``bytes``, the network address in binary form.
Raises ``NotImplementedError`` if the address family specified is not
implemented.
Returns a ``str``.
"""
if family == AF_INET:
return dns.ipv4.inet_ntoa(address)
elif family == AF_INET6:
return dns.ipv6.inet_ntoa(address)
else:
raise NotImplementedError
def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address.
Raises ``ValueError`` if the address family cannot be determined
from the input.
Returns an ``int``.
"""
try:
dns.ipv4.inet_aton(text)
return AF_INET
except Exception:
try:
dns.ipv6.inet_aton(text, True)
return AF_INET6
except Exception:
raise ValueError
def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address.
Raises ``ValueError`` if the address family cannot be determined
from the input.
Returns a ``bool``.
"""
try:
first = dns.ipv4.inet_aton(text)[0]
return first >= 224 and first <= 239
except Exception:
try:
first = dns.ipv6.inet_aton(text, True)[0]
return first == 255
except Exception:
raise ValueError
def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address.
Returns a ``bool``.
"""
try:
dns.ipv4.inet_aton(text)
return True
except Exception:
try:
dns.ipv6.inet_aton(text, True)
return True
except Exception:
return False
def low_level_address_tuple(high_tuple: Tuple[str, int], af: int | None = None) -> Any:
"""Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
If an *af* other than ``None`` is provided, it is assumed the
address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called.
"""
address, port = high_tuple
if af is None:
af = af_for_address(address)
if af == AF_INET:
return (address, port)
elif af == AF_INET6:
i = address.find("%")
if i < 0:
# no scope, shortcut!
return (address, port, 0, 0)
# try to avoid getaddrinfo()
addrpart = address[:i]
scope = address[i + 1 :]
if scope.isdigit():
return (addrpart, port, 0, int(scope))
try:
return (addrpart, port, 0, socket.if_nametoindex(scope))
except AttributeError: # pragma: no cover (we can't really test this)
ai_flags = socket.AI_NUMERICHOST
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup
else:
raise NotImplementedError(f"unknown address family {af}")
def any_for_af(af):
"""Return the 'any' address for the specified address family."""
if af == socket.AF_INET:
return "0.0.0.0"
elif af == socket.AF_INET6:
return "::"
raise NotImplementedError(f"unknown address family {af}")
def canonicalize(text: str) -> str:
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
canonical text form. IPv6 addresses with scopes are rejected.
*text*, a ``str``, the address in textual form.
Raises ``ValueError`` if the text is not valid.
"""
try:
return dns.ipv6.canonicalize(text)
except Exception:
try:
return dns.ipv4.canonicalize(text)
except Exception:
raise ValueError

View File

@@ -0,0 +1,76 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""IPv4 helper functions."""
import struct
import dns.exception
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form.
Returns a ``str``.
"""
if len(address) != 4:
raise dns.exception.SyntaxError
return f"{address[0]}.{address[1]}.{address[2]}.{address[3]}"
def inet_aton(text: str | bytes) -> bytes:
"""Convert an IPv4 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``.
"""
if not isinstance(text, bytes):
btext = text.encode()
else:
btext = text
parts = btext.split(b".")
if len(parts) != 4:
raise dns.exception.SyntaxError
for part in parts:
if not part.isdigit():
raise dns.exception.SyntaxError
if len(part) > 1 and part[0] == ord("0"):
# No leading zeros
raise dns.exception.SyntaxError
try:
b = [int(part) for part in parts]
return struct.pack("BBBB", *b)
except Exception:
raise dns.exception.SyntaxError
def canonicalize(text: str | bytes) -> str:
"""Verify that *address* is a valid text form IPv4 address and return its
canonical text form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
# Note that inet_aton() only accepts canonial form, but we still run through
# inet_ntoa() to ensure the output is a str.
return inet_ntoa(inet_aton(text))

View File

@@ -0,0 +1,217 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""IPv6 helper functions."""
import binascii
import re
from typing import List
import dns.exception
import dns.ipv4
_leading_zero = re.compile(r"0+([0-9a-f]+)")
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form.
Raises ``ValueError`` if the address isn't 16 bytes long.
Returns a ``str``.
"""
if len(address) != 16:
raise ValueError("IPv6 addresses are 16 bytes long")
hex = binascii.hexlify(address)
chunks = []
i = 0
l = len(hex)
while i < l:
chunk = hex[i : i + 4].decode()
# strip leading zeros. we do this with an re instead of
# with lstrip() because lstrip() didn't support chars until
# python 2.2.2
m = _leading_zero.match(chunk)
if m is not None:
chunk = m.group(1)
chunks.append(chunk)
i += 4
#
# Compress the longest subsequence of 0-value chunks to ::
#
best_start = 0
best_len = 0
start = -1
last_was_zero = False
for i in range(8):
if chunks[i] != "0":
if last_was_zero:
end = i
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
last_was_zero = False
elif not last_was_zero:
start = i
last_was_zero = True
if last_was_zero:
end = 8
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
if best_len > 1:
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
# We have an embedded IPv4 address
if best_len == 6:
prefix = "::"
else:
prefix = "::ffff:"
thex = prefix + dns.ipv4.inet_ntoa(address[12:])
else:
thex = (
":".join(chunks[:best_start])
+ "::"
+ ":".join(chunks[best_start + best_len :])
)
else:
thex = ":".join(chunks)
return thex
_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
_colon_colon_start = re.compile(rb"::.*")
_colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: str | bytes, ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
If ``False``, the default, it is an error for a scope to be present.
Returns a ``bytes``.
"""
#
# Our aim here is not something fast; we just want something that works.
#
if not isinstance(text, bytes):
btext = text.encode()
else:
btext = text
if ignore_scope:
parts = btext.split(b"%")
l = len(parts)
if l == 2:
btext = parts[0]
elif l > 2:
raise dns.exception.SyntaxError
if btext == b"":
raise dns.exception.SyntaxError
elif btext.endswith(b":") and not btext.endswith(b"::"):
raise dns.exception.SyntaxError
elif btext.startswith(b":") and not btext.startswith(b"::"):
raise dns.exception.SyntaxError
elif btext == b"::":
btext = b"0::"
#
# Get rid of the icky dot-quad syntax if we have it.
#
m = _v4_ending.match(btext)
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
btext = (
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:'
#
m = _colon_colon_start.match(btext)
if m is not None:
btext = btext[1:]
else:
m = _colon_colon_end.match(btext)
if m is not None:
btext = btext[:-1]
#
# Now canonicalize into 8 chunks of 4 hex digits each
#
chunks = btext.split(b":")
l = len(chunks)
if l > 8:
raise dns.exception.SyntaxError
seen_empty = False
canonical: List[bytes] = []
for c in chunks:
if c == b"":
if seen_empty:
raise dns.exception.SyntaxError
seen_empty = True
for _ in range(0, 8 - l + 1):
canonical.append(b"0000")
else:
lc = len(c)
if lc > 4:
raise dns.exception.SyntaxError
if lc != 4:
c = (b"0" * (4 - lc)) + c
canonical.append(c)
if l < 8 and not seen_empty:
raise dns.exception.SyntaxError
btext = b"".join(canonical)
#
# Finally we can go to binary.
#
try:
return binascii.unhexlify(btext)
except (binascii.Error, TypeError):
raise dns.exception.SyntaxError
_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
def is_mapped(address: bytes) -> bool:
"""Is the specified address a mapped IPv4 address?
*address*, a ``bytes`` is an IPv6 address in binary form.
Returns a ``bool``.
"""
return address.startswith(_mapped_prefix)
def canonicalize(text: str | bytes) -> str:
"""Verify that *address* is a valid text form IPv6 address and return its
canonical text form. Addresses with scopes are rejected.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
return inet_ntoa(inet_aton(text))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,109 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
# Copyright (C) 2016 Coresec Systems AB
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND CORESEC SYSTEMS AB DISCLAIMS ALL
# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL CORESEC
# SYSTEMS AB BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS name dictionary"""
# pylint seems to be confused about this one!
from collections.abc import MutableMapping # pylint: disable=no-name-in-module
import dns.name
class NameDict(MutableMapping):
"""A dictionary whose keys are dns.name.Name objects.
In addition to being like a regular Python dictionary, this
dictionary can also get the deepest match for a given key.
"""
__slots__ = ["max_depth", "max_depth_items", "__store"]
def __init__(self, *args, **kwargs):
super().__init__()
self.__store = dict()
#: the maximum depth of the keys that have ever been added
self.max_depth = 0
#: the number of items of maximum depth
self.max_depth_items = 0
self.update(dict(*args, **kwargs))
def __update_max_depth(self, key):
if len(key) == self.max_depth:
self.max_depth_items = self.max_depth_items + 1
elif len(key) > self.max_depth:
self.max_depth = len(key)
self.max_depth_items = 1
def __getitem__(self, key):
return self.__store[key]
def __setitem__(self, key, value):
if not isinstance(key, dns.name.Name):
raise ValueError("NameDict key must be a name")
self.__store[key] = value
self.__update_max_depth(key)
def __delitem__(self, key):
self.__store.pop(key)
if len(key) == self.max_depth:
self.max_depth_items = self.max_depth_items - 1
if self.max_depth_items == 0:
self.max_depth = 0
for k in self.__store:
self.__update_max_depth(k)
def __iter__(self):
return iter(self.__store)
def __len__(self):
return len(self.__store)
def has_key(self, key):
return key in self.__store
def get_deepest_match(self, name):
"""Find the deepest match to *name* in the dictionary.
The deepest match is the longest name in the dictionary which is
a superdomain of *name*. Note that *superdomain* includes matching
*name* itself.
*name*, a ``dns.name.Name``, the name to find.
Returns a ``(key, value)`` where *key* is the deepest
``dns.name.Name``, and *value* is the value associated with *key*.
"""
depth = len(name)
if depth > self.max_depth:
depth = self.max_depth
for i in range(-depth, 0):
n = dns.name.Name(name[i:])
if n in self:
return (n, self[n])
v = self[dns.name.empty]
return (dns.name.empty, v)

View File

@@ -0,0 +1,361 @@
from urllib.parse import urlparse
import dns.asyncbackend
import dns.asyncquery
import dns.message
import dns.query
class Nameserver:
def __init__(self):
pass
def __str__(self):
raise NotImplementedError
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
raise NotImplementedError
def answer_nameserver(self) -> str:
raise NotImplementedError
def answer_port(self) -> int:
raise NotImplementedError
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
class AddressAndPortNameserver(Nameserver):
def __init__(self, address: str, port: int):
super().__init__()
self.address = address
self.port = port
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
return False
def __str__(self):
ns_kind = self.kind()
return f"{ns_kind}:{self.address}@{self.port}"
def answer_nameserver(self) -> str:
return self.address
def answer_port(self) -> int:
return self.port
class Do53Nameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 53):
super().__init__(address, port)
def kind(self):
return "Do53"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = dns.query.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = dns.query.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = await dns.asyncquery.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = await dns.asyncquery.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
class DoHNameserver(Nameserver):
def __init__(
self,
url: str,
bootstrap_address: str | None = None,
verify: bool | str = True,
want_get: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.http_version = http_version
def kind(self):
return "DoH"
def is_always_max_size(self) -> bool:
return True
def __str__(self):
return self.url
def answer_nameserver(self) -> str:
return self.url
def answer_port(self) -> int:
port = urlparse(self.url).port
if port is None:
port = 443
return port
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.https(
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.https(
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
class DoTNameserver(AddressAndPortNameserver):
def __init__(
self,
address: str,
port: int = 853,
hostname: str | None = None,
verify: bool | str = True,
):
super().__init__(address, port)
self.hostname = hostname
self.verify = verify
def kind(self):
return "DoT"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
class DoQNameserver(AddressAndPortNameserver):
def __init__(
self,
address: str,
port: int = 853,
verify: bool | str = True,
server_hostname: str | None = None,
):
super().__init__(address, port)
self.verify = verify
self.server_hostname = server_hostname
def kind(self):
return "DoQ"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)

View File

@@ -0,0 +1,358 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS nodes. A node is a set of rdatasets."""
import enum
import io
from typing import Any, Dict
import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset
_cname_types = {
dns.rdatatype.CNAME,
}
# "neutral" types can coexist with a CNAME and thus are not "other data"
_neutral_types = {
dns.rdatatype.NSEC, # RFC 4035 section 2.5
dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible!
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
}
def _matches_type_or_its_signature(rdtypes, rdtype, covers):
return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique
class NodeKind(enum.Enum):
"""Rdatasets in nodes"""
REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1
CNAME = 2
@classmethod
def classify(
cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
) -> "NodeKind":
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
return NodeKind.NEUTRAL
else:
return NodeKind.REGULAR
@classmethod
def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
return cls.classify(rdataset.rdtype, rdataset.covers)
class Node:
"""A Node is a set of rdatasets.
A node is either a CNAME node or an "other data" node. A CNAME
node contains only CNAME, KEY, NSEC, and NSEC3 rdatasets along with their
covering RRSIG rdatasets. An "other data" node contains any
rdataset other than a CNAME or RRSIG(CNAME) rdataset. When
changes are made to a node, the CNAME or "other data" state is
always consistent with the update, i.e. the most recent change
wins. For example, if you have a node which contains a CNAME
rdataset, and then add an MX rdataset to it, then the CNAME
rdataset will be deleted. Likewise if you have a node containing
an MX rdataset and add a CNAME rdataset, the MX rdataset will be
deleted.
"""
__slots__ = ["rdatasets"]
def __init__(self):
# the set of rdatasets, represented as a list.
self.rdatasets = []
def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
"""Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments
to this method are passed on to the rdataset's to_text() method.
*name*, a ``dns.name.Name``, the owner name of the
rdatasets.
Returns a ``str``.
"""
s = io.StringIO()
for rds in self.rdatasets:
if len(rds) > 0:
s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
s.write("\n")
return s.getvalue()[:-1]
def __repr__(self):
return "<DNS node " + str(id(self)) + ">"
def __eq__(self, other):
#
# This is inefficient. Good thing we don't need to do it much.
#
for rd in self.rdatasets:
if rd not in other.rdatasets:
return False
for rd in other.rdatasets:
if rd not in self.rdatasets:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __len__(self):
return len(self.rdatasets)
def __iter__(self):
return iter(self.rdatasets)
def _append_rdataset(self, rdataset):
"""Append rdataset to the node with special handling for CNAME and
other data conditions.
Specifically, if the rdataset being appended has ``NodeKind.CNAME``,
then all rdatasets other than KEY, NSEC, NSEC3, and their covering
RRSIGs are deleted. If the rdataset being appended has
``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted.
"""
# Make having just one rdataset at the node fast.
if len(self.rdatasets) > 0:
kind = NodeKind.classify_rdataset(rdataset)
if kind == NodeKind.CNAME:
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
]
elif kind == NodeKind.REGULAR:
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
]
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
# edit self.rdatasets.
self.rdatasets.append(rdataset)
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the
current node.
*rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
*rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
*covers*, a ``dns.rdatatype.RdataType``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
covers. The library treats the SIG and RRSIG types as if they
were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA).
This makes RRSIGs much easier to work with than if RRSIGs
covering different rdata types were aggregated into a single
RRSIG rdataset.
*create*, a ``bool``. If True, create the rdataset if it is not found.
Raises ``KeyError`` if an rdataset of the desired type and class does
not exist and *create* is not ``True``.
Returns a ``dns.rdataset.Rdataset``.
"""
for rds in self.rdatasets:
if rds.match(rdclass, rdtype, covers):
return rds
if not create:
raise KeyError
rds = dns.rdataset.Rdataset(rdclass, rdtype, covers)
self._append_rdataset(rds)
return rds
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
"""Get an rdataset matching the specified properties in the
current node.
None is returned if an rdataset of the specified type and
class does not exist and *create* is not ``True``.
*rdclass*, an ``int``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset.
*covers*, an ``int``, the covered type. Usually this value is
dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or
dns.rdatatype.RRSIG, then the covers value will be the rdata
type the SIG/RRSIG covers. The library treats the SIG and RRSIG
types as if they were a family of
types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much
easier to work with than if RRSIGs covering different rdata
types were aggregated into a single RRSIG rdataset.
*create*, a ``bool``. If True, create the rdataset if it is not found.
Returns a ``dns.rdataset.Rdataset`` or ``None``.
"""
try:
rds = self.find_rdataset(rdclass, rdtype, covers, create)
except KeyError:
rds = None
return rds
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
"""Delete the rdataset matching the specified properties in the
current node.
If a matching rdataset does not exist, it is not an error.
*rdclass*, an ``int``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset.
*covers*, an ``int``, the covered type.
"""
rds = self.get_rdataset(rdclass, rdtype, covers)
if rds is not None:
self.rdatasets.remove(rds)
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
"""Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*.
Ownership of the *replacement* object is transferred to the node;
in other words, this method does not store a copy of *replacement*
at the node, it stores *replacement* itself.
*replacement*, a ``dns.rdataset.Rdataset``.
Raises ``ValueError`` if *replacement* is not a
``dns.rdataset.Rdataset``.
"""
if not isinstance(replacement, dns.rdataset.Rdataset):
raise ValueError("replacement is not an rdataset")
if isinstance(replacement, dns.rrset.RRset):
# RRsets are not good replacements as the match() method
# is not compatible.
replacement = replacement.to_rdataset()
self.delete_rdataset(
replacement.rdclass, replacement.rdtype, replacement.covers
)
self._append_rdataset(replacement)
def classify(self) -> NodeKind:
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
``NodeKind.CNAME`` node.
A node which contains only "neutral" types, i.e. types allowed to
co-exist with a CNAME, is a ``NodeKind.NEUTRAL`` node. The neutral
types are NSEC, NSEC3, KEY, and their associated RRSIGS. An empty node
is also considered neutral.
A node which contains some rdataset which is not a CNAME, RRSIG(CNAME),
or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are
also commonly referred to as "other data".
"""
for rdataset in self.rdatasets:
kind = NodeKind.classify(rdataset.rdtype, rdataset.covers)
if kind != NodeKind.NEUTRAL:
return kind
return NodeKind.NEUTRAL
def is_immutable(self) -> bool:
return False
@dns.immutable.immutable
class ImmutableNode(Node):
def __init__(self, node):
super().__init__()
self.rdatasets = tuple(
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable")
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable")
def is_immutable(self) -> bool:
return True

View File

@@ -0,0 +1,119 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Opcodes."""
from typing import Type
import dns.enum
import dns.exception
class Opcode(dns.enum.IntEnum):
#: Query
QUERY = 0
#: Inverse Query (historical)
IQUERY = 1
#: Server Status (unspecified and unimplemented anywhere)
STATUS = 2
#: Notify
NOTIFY = 4
#: Dynamic Update
UPDATE = 5
@classmethod
def _maximum(cls):
return 15
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownOpcode
class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown."""
def from_text(text: str) -> Opcode:
"""Convert text into an opcode.
*text*, a ``str``, the textual opcode
Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown.
Returns an ``int``.
"""
return Opcode.from_text(text)
def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags.
Returns an ``int``.
"""
return Opcode((flags & 0x7800) >> 11)
def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message
flags.
*value*, an ``int``, the DNS opcode value.
Returns an ``int``.
"""
return (value << 11) & 0x7800
def to_text(value: Opcode) -> str:
"""Convert an opcode to text.
*value*, an ``int`` the opcode value,
Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown.
Returns a ``str``.
"""
return Opcode.to_text(value)
def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags.
Returns a ``bool``.
"""
return from_flags(flags) == Opcode.UPDATE
### BEGIN generated Opcode constants
QUERY = Opcode.QUERY
IQUERY = Opcode.IQUERY
STATUS = Opcode.STATUS
NOTIFY = Opcode.NOTIFY
UPDATE = Opcode.UPDATE
### END generated Opcode constants

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Any, Dict, List, Tuple
import dns._features
import dns.asyncbackend
if dns._features.have("doq"):
from dns._asyncbackend import NullContext
from dns.quic._asyncio import AsyncioQuicConnection as AsyncioQuicConnection
from dns.quic._asyncio import AsyncioQuicManager
from dns.quic._asyncio import AsyncioQuicStream as AsyncioQuicStream
from dns.quic._common import AsyncQuicConnection # pyright: ignore
from dns.quic._common import AsyncQuicManager as AsyncQuicManager
from dns.quic._sync import SyncQuicConnection # pyright: ignore
from dns.quic._sync import SyncQuicStream # pyright: ignore
from dns.quic._sync import SyncQuicManager as SyncQuicManager
have_quic = True
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
return NullContext(None)
def _asyncio_manager_factory(
context, *args, **kwargs # pylint: disable=unused-argument
):
return AsyncioQuicManager(*args, **kwargs)
# We have a context factory and a manager factory as for trio we need to have
# a nursery.
_async_factories: Dict[str, Tuple[Any, Any]] = {
"asyncio": (null_factory, _asyncio_manager_factory)
}
if dns._features.have("trio"):
import trio
# pylint: disable=ungrouped-imports
from dns.quic._trio import TrioQuicConnection as TrioQuicConnection
from dns.quic._trio import TrioQuicManager
from dns.quic._trio import TrioQuicStream as TrioQuicStream
def _trio_context_factory():
return trio.open_nursery()
def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
else: # pragma: no cover
have_quic = False
class AsyncQuicStream: # type: ignore
pass
class AsyncQuicConnection: # type: ignore
async def make_stream(self) -> Any:
raise NotImplementedError
class SyncQuicStream: # type: ignore
pass
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

View File

@@ -0,0 +1,276 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import asyncio
import socket
import ssl
import struct
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class AsyncioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
async def _wait_for_wake_up(self):
async with self._wake_up:
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount):
return
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
await self.wait_for_end(expiration)
return self._buffer.get_all()
else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class AsyncioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._wake_pending = False
async def _receiver(self):
try:
af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio")
# Note that peer is a low-level address tuple, but make_socket() wants
# a high-level address tuple, so we convert.
self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
)
self._socket_created.set()
async with self._socket:
while not self._done:
(datagram, address) = await self._socket.recvfrom(
QUIC_MAX_DATAGRAM, None
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
await self._wakeup()
except Exception:
pass
finally:
self._done = True
await self._wakeup()
self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer
assert self._socket is not None
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
if self._receiver_task is not None:
self._receiver_task.cancel()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await asyncio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
await self._wakeup()
def run(self):
if self._closed:
return
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())
async def make_stream(self, timeout=None):
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._wakeup()
try:
if self._receiver_task is not None:
await self._receiver_task
except asyncio.CancelledError:
pass
try:
if self._sender_task is not None:
await self._sender_task
except asyncio.CancelledError:
pass
if self._socket is not None:
await self._socket.close()
class AsyncioQuicManager(AsyncQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View File

@@ -0,0 +1,344 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib.parse
from typing import Any
import aioquic.h3.connection # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import dns._tls_util
import dns.inet
QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception):
pass
class Buffer:
def __init__(self):
self._buffer = b""
self._seen_end = False
def put(self, data, is_end):
if self._seen_end:
return
self._buffer += data
if is_end:
self._seen_end = True
def have(self, amount):
if len(self._buffer) >= amount:
return True
if self._seen_end:
raise UnexpectedEOF
return False
def seen_end(self):
return self._seen_end
def get(self, amount):
assert self.have(amount)
data = self._buffer[:amount]
self._buffer = self._buffer[amount:]
return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
self._headers = None
self._trailers = None
def id(self):
return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
else:
expiration = None
return expiration
def _timeout_from_expiration(self, expiration):
if expiration is not None:
timeout = max(expiration - time.time(), 0.0)
else:
timeout = None
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises.
# Subclass must implement send() as sync / async and which takes a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end
except UnexpectedEOF:
return True
def _close(self):
self._connection.close_stream(self._stream_id)
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
class BaseQuicConnection:
def __init__(
self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
):
self._done = False
self._connection = connection
self._address = address
self._port = port
self._closed = False
self._manager = manager
self._streams = {}
if manager is not None and manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
source = "0.0.0.0"
elif self._af == socket.AF_INET6:
source = "::"
else:
raise NotImplementedError
if source:
self._source = (source, source_port)
else:
self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id):
del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
assert self._h3_conn is not None
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
assert self._h3_conn is not None
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
if expiration is None:
expiration = now + 3600 # arbitrary "big" value
interval = max(expiration - now, 0)
if self._closed and closed_is_special:
# lower sleep interval to avoid a race in the closing process
# which can lead to higher latency closing due to sleeping when
# we have events.
interval = min(interval, 0.05)
return (expiration, interval)
def _handle_timer(self, expiration):
now = time.time()
if expiration <= now:
self._connection.handle_timer(now)
class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self, timeout: float | None = None) -> Any:
pass
class BaseQuicManager:
def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=alpn_protocols,
verify_mode=verify_mode,
server_name=server_name,
)
if verify_path is not None:
cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(
verify_path
)
conf.load_verify_locations(cafile=cafile, capath=capath)
self._conf = conf
def _connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
conf = self._conf
if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save # one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
self._connections[(address, port)] = connection
return (connection, True)
def closed(self, address, port):
try:
del self._connections[(address, port)]
except KeyError:
pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
raise NotImplementedError

View File

@@ -0,0 +1,306 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket
import ssl
import struct
import threading
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
BaseQuicConnection,
BaseQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket
class SyncQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
self.wait_for_end(expiration)
with self._lock:
return self._buffer.get_all()
else:
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
self._connection.write(self._stream_id, data, is_end)
def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
with self._wake_up:
self._wake_up.notify()
def close(self):
with self._lock:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
with self._wake_up:
self._wake_up.notify()
return False
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
def _read(self):
count = 0
while count < 10:
count += 1
try:
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer, time.time())
def _drain_wakeup(self):
while True:
try:
self._receive_wakeup.recv(32)
except BlockingIOError:
return
def _worker(self):
try:
with selectors.DefaultSelector() as sel:
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(
self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup
)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
items = sel.select(interval)
for key, _ in items:
key.data()
with self._lock:
self._handle_timer(expiration)
self._handle_events()
with self._lock:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
except Exception:
# Eat all exceptions as we have no way to pass them back to the
# caller currently. It might be nice to fix this in the future.
pass
finally:
with self._lock:
self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up.
self._handshake_complete.set()
def _handle_events(self):
while True:
with self._lock:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
with self._lock:
self._done = True
elif isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(b"", True)
def write(self, stream, data, is_end=False):
with self._lock:
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()
def make_stream(self, timeout=None):
if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock:
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
def close_stream(self, stream_id):
with self._lock:
super().close_stream(stream_id)
def close(self):
with self._lock:
if self._closed:
return
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_wakeup.send(b"\x01")
if self._worker_thread is not None:
self._worker_thread.join()
class SyncQuicManager(BaseQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock()
def connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
with self._lock:
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket, want_token
)
if start:
connection.run()
return connection
def closed(self, address, port):
with self._lock:
super().closed(address, port)
def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
return False

View File

@@ -0,0 +1,250 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import trio
import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class TrioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
async def wait_for(self, amount):
while True:
if self._buffer.have(amount):
return
self._expecting = amount
async with self._wake_up:
await self._wake_up.wait()
self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
if self._connection.is_h3():
await self.wait_for_end()
return self._buffer.get_all()
else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False
async def _worker(self):
try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
if self._send_pending:
# Do not block forever if sends are pending. Even though we
# have a wake-up mechanism if we've already started the blocking
# read, the possibility of context switching in send means that
# more writes can happen while we have no wake up context, so
# we need self._send_pending to avoid (effectively) a "lost wakeup"
# race.
interval = 0.0
with trio.CancelScope(
deadline=trio.current_time() + interval # pyright: ignore
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(datagram, self._peer, time.time())
self._worker_scope = None
self._handle_timer(expiration)
await self._handle_events()
# We clear this now, before sending anything, as sending can cause
# context switches that do more sends. We want to know if that
# happens so we don't block a long time on the recv() above.
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
finally:
self._done = True
self._socket.close()
self._handshake_complete.set()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._socket.close()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await trio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
async def run(self):
if self._closed:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(self._worker)
self._run_done.set()
async def make_stream(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout
async def close(self):
if not self._closed:
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
class TrioQuicManager(AsyncQuicManager):
def __init__(
self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
self._nursery.start_soon(connection.run)
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View File

@@ -0,0 +1,168 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Result Codes."""
from typing import Tuple, Type
import dns.enum
import dns.exception
class Rcode(dns.enum.IntEnum):
#: No error
NOERROR = 0
#: Format error
FORMERR = 1
#: Server failure
SERVFAIL = 2
#: Name does not exist ("Name Error" in RFC 1025 terminology).
NXDOMAIN = 3
#: Not implemented
NOTIMP = 4
#: Refused
REFUSED = 5
#: Name exists.
YXDOMAIN = 6
#: RRset exists.
YXRRSET = 7
#: RRset does not exist.
NXRRSET = 8
#: Not authoritative.
NOTAUTH = 9
#: Name not in zone.
NOTZONE = 10
#: DSO-TYPE Not Implemented
DSOTYPENI = 11
#: Bad EDNS version.
BADVERS = 16
#: TSIG Signature Failure
BADSIG = 16
#: Key not recognized.
BADKEY = 17
#: Signature out of time window.
BADTIME = 18
#: Bad TKEY Mode.
BADMODE = 19
#: Duplicate key name.
BADNAME = 20
#: Algorithm not supported.
BADALG = 21
#: Bad Truncation
BADTRUNC = 22
#: Bad/missing Server Cookie
BADCOOKIE = 23
@classmethod
def _maximum(cls):
return 4095
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownRcode
class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown."""
def from_text(text: str) -> Rcode:
"""Convert text into an rcode.
*text*, a ``str``, the textual rcode or an integer in textual form.
Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
Returns a ``dns.rcode.Rcode``.
"""
return Rcode.from_text(text)
def from_flags(flags: int, ednsflags: int) -> Rcode:
"""Return the rcode value encoded by flags and ednsflags.
*flags*, an ``int``, the DNS flags field.
*ednsflags*, an ``int``, the EDNS flags field.
Raises ``ValueError`` if rcode is < 0 or > 4095
Returns a ``dns.rcode.Rcode``.
"""
value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
return Rcode.make(value)
def to_flags(value: Rcode) -> Tuple[int, int]:
"""Return a (flags, ednsflags) tuple which encodes the rcode.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
Returns an ``(int, int)`` tuple.
"""
if value < 0 or value > 4095:
raise ValueError("rcode must be >= 0 and <= 4095")
v = value & 0xF
ev = (value & 0xFF0) << 20
return (v, ev)
def to_text(value: Rcode, tsig: bool = False) -> str:
"""Convert rcode into text.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
Returns a ``str``.
"""
if tsig and value == Rcode.BADVERS:
return "BADSIG"
return Rcode.to_text(value)
### BEGIN generated Rcode constants
NOERROR = Rcode.NOERROR
FORMERR = Rcode.FORMERR
SERVFAIL = Rcode.SERVFAIL
NXDOMAIN = Rcode.NXDOMAIN
NOTIMP = Rcode.NOTIMP
REFUSED = Rcode.REFUSED
YXDOMAIN = Rcode.YXDOMAIN
YXRRSET = Rcode.YXRRSET
NXRRSET = Rcode.NXRRSET
NOTAUTH = Rcode.NOTAUTH
NOTZONE = Rcode.NOTZONE
DSOTYPENI = Rcode.DSOTYPENI
BADVERS = Rcode.BADVERS
BADSIG = Rcode.BADSIG
BADKEY = Rcode.BADKEY
BADTIME = Rcode.BADTIME
BADMODE = Rcode.BADMODE
BADNAME = Rcode.BADNAME
BADALG = Rcode.BADALG
BADTRUNC = Rcode.BADTRUNC
BADCOOKIE = Rcode.BADCOOKIE
### END generated Rcode constants

View File

@@ -0,0 +1,935 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS rdata."""
import base64
import binascii
import inspect
import io
import ipaddress
import itertools
import random
from importlib import import_module
from typing import Any, Dict, Tuple
import dns.exception
import dns.immutable
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdataclass
import dns.rdatatype
import dns.tokenizer
import dns.ttl
import dns.wire
_chunksize = 32
# We currently allow comparisons for rdata with relative names for backwards
# compatibility, but in the future we will not, as these kinds of comparisons
# can lead to subtle bugs if code is not carefully written.
#
# This switch allows the future behavior to be turned on so code can be
# tested with it.
_allow_relative_comparisons = True
class NoRelativeRdataOrdering(dns.exception.DNSException):
"""An attempt was made to do an ordered comparison of one or more
rdata with relative names. The only reliable way of sorting rdata
is to use non-relativized rdata.
"""
def _wordbreak(data, chunksize=_chunksize, separator=b" "):
"""Break a binary string into chunks of chunksize characters separated by
a space.
"""
if not chunksize:
return data.decode()
return separator.join(
[data[i : i + chunksize] for i in range(0, len(data), chunksize)]
).decode()
# pylint: disable=unused-argument
def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
return _wordbreak(base64.b64encode(data), chunksize, separator)
# pylint: enable=unused-argument
__escaped = b'"\\'
def _escapify(qstring):
"""Escape the characters in a quoted string which need it."""
if isinstance(qstring, str):
qstring = qstring.encode()
if not isinstance(qstring, bytearray):
qstring = bytearray(qstring)
text = ""
for c in qstring:
if c in __escaped:
text += "\\" + chr(c)
elif c >= 0x20 and c < 0x7F:
text += chr(c)
else:
text += f"\\{c:03d}"
return text
def _truncate_bitmap(what):
"""Determine the index of greatest byte that isn't all zeros, and
return the bitmap that contains all the bytes less than that index.
"""
for i in range(len(what) - 1, -1, -1):
if what[i] != 0:
return what[0 : i + 1]
return what[0:1]
# So we don't have to edit all the rdata classes...
_constify = dns.immutable.constify
@dns.immutable.immutable
class Rdata:
"""Base class for all DNS rdata types."""
__slots__ = ["rdclass", "rdtype", "rdcomment"]
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
) -> None:
"""Initialize an rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
*rdtype*, an ``int`` is the rdatatype of the Rdata.
"""
self.rdclass = self._as_rdataclass(rdclass)
self.rdtype = self._as_rdatatype(rdtype)
self.rdcomment = None
def _get_all_slots(self):
return itertools.chain.from_iterable(
getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
)
def __getstate__(self):
# We used to try to do a tuple of all slots here, but it
# doesn't work as self._all_slots isn't available at
# __setstate__() time. Before that we tried to store a tuple
# of __slots__, but that didn't work as it didn't store the
# slots defined by ancestors. This older way didn't fail
# outright, but ended up with partially broken objects, e.g.
# if you unpickled an A RR it wouldn't have rdclass and rdtype
# attributes, and would compare badly.
state = {}
for slot in self._get_all_slots():
state[slot] = getattr(self, slot)
return state
def __setstate__(self, state):
for slot, val in state.items():
object.__setattr__(self, slot, val)
if not hasattr(self, "rdcomment"):
# Pickled rdata from 2.0.x might not have a rdcomment, so add
# it if needed.
object.__setattr__(self, "rdcomment", None)
def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is
returned by the covers() function. If the rdata type is not
SIG or RRSIG, dns.rdatatype.NONE is returned. This is useful when
creating rdatasets, allowing the rdataset to contain only RRSIGs
of a particular type, e.g. RRSIG(NS).
Returns a ``dns.rdatatype.RdataType``.
"""
return dns.rdatatype.NONE
def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any.
Returns an ``int``.
"""
return self.covers() << 16 | self.rdtype
def to_text(
self,
origin: dns.name.Name | None = None,
relativize: bool = True,
**kw: Dict[str, Any],
) -> str:
"""Convert an rdata to text format.
Returns a ``str``.
"""
raise NotImplementedError # pragma: no cover
def _to_wire(
self,
file: Any,
compress: dns.name.CompressType | None = None,
origin: dns.name.Name | None = None,
canonicalize: bool = False,
) -> None:
raise NotImplementedError # pragma: no cover
def to_wire(
self,
file: Any | None = None,
compress: dns.name.CompressType | None = None,
origin: dns.name.Name | None = None,
canonicalize: bool = False,
) -> bytes | None:
"""Convert an rdata to wire format.
Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
"""
if file:
# We call _to_wire() and then return None explicitly instead of
# of just returning the None from _to_wire() as mypy's func-returns-value
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
# a value (it only ever returns None)"
self._to_wire(file, compress, origin, canonicalize)
return None
else:
f = io.BytesIO()
self._to_wire(f, compress, origin, canonicalize)
return f.getvalue()
def to_generic(self, origin: dns.name.Name | None = None) -> "GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
wire = self.to_wire(origin=origin)
assert wire is not None # for type checkers
return GenericRdata(self.rdclass, self.rdtype, wire)
def to_digestable(self, origin: dns.name.Name | None = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form.
Returns a ``bytes``.
"""
wire = self.to_wire(origin=origin, canonicalize=True)
assert wire is not None # for mypy
return wire
def __repr__(self):
covers = self.covers()
if covers == dns.rdatatype.NONE:
ctext = ""
else:
ctext = "(" + dns.rdatatype.to_text(covers) + ")"
return (
"<DNS "
+ dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdata: "
+ str(self)
+ ">"
)
def __str__(self):
return self.to_text()
def _cmp(self, other):
"""Compare an rdata with another rdata of the same rdtype and
rdclass.
For rdata with only absolute names:
Return < 0 if self < other in the DNSSEC ordering, 0 if self
== other, and > 0 if self > other.
For rdata with at least one relative names:
The rdata sorts before any rdata with only absolute names.
When compared with another relative rdata, all names are
made absolute as if they were relative to the root, as the
proper origin is not available. While this creates a stable
ordering, it is NOT guaranteed to be the DNSSEC ordering.
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
"""
# the next two lines are for type checkers, so they are bound
our = b""
their = b""
try:
our = self.to_digestable()
our_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
their_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
their = other.to_digestable(dns.name.root)
their_relative = True
if _allow_relative_comparisons:
if our_relative != their_relative:
# For the purpose of comparison, all rdata with at least one
# relative name is less than an rdata with only absolute names.
if our_relative:
return -1
else:
return 1
elif our_relative or their_relative:
raise NoRelativeRdataOrdering
if our == their:
return 0
elif our > their:
return 1
else:
return -1
def __eq__(self, other):
if not isinstance(other, Rdata):
return False
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return False
our_relative = False
their_relative = False
try:
our = self.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
their = other.to_digestable(dns.name.root)
their_relative = True
if our_relative != their_relative:
return False
return our == their
def __ne__(self, other):
if not isinstance(other, Rdata):
return True
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return True
return not self.__eq__(other)
def __lt__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) > 0
def __hash__(self):
return hash(self.to_digestable(dns.name.root))
@classmethod
def from_text(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
tok: dns.tokenizer.Tokenizer,
origin: dns.name.Name | None = None,
relativize: bool = True,
relativize_to: dns.name.Name | None = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
parser: dns.wire.Parser,
origin: dns.name.Name | None = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
def replace(self, **kwargs: Any) -> "Rdata":
"""
Create a new Rdata instance based on the instance replace was
invoked on. It is possible to pass different parameters to
override the corresponding properties of the base Rdata.
Any field specific to the Rdata type can be replaced, but the
*rdtype* and *rdclass* fields cannot.
Returns an instance of the same Rdata subclass as *self*.
"""
# Get the constructor parameters.
parameters = inspect.signature(self.__init__).parameters # type: ignore
# Ensure that all of the arguments correspond to valid fields.
# Don't allow rdclass or rdtype to be changed, though.
for key in kwargs:
if key == "rdcomment":
continue
if key not in parameters:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{key}'"
)
if key in ("rdclass", "rdtype"):
raise AttributeError(
f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
)
# Construct the parameter list. For each field, use the value in
# kwargs if present, and the current value otherwise.
args = (kwargs.get(key, getattr(self, key)) for key in parameters)
# Create, validate, and return the new object.
rd = self.__class__(*args)
# The comment is not set in the constructor, so give it special
# handling.
rdcomment = kwargs.get("rdcomment", self.rdcomment)
if rdcomment is not None:
object.__setattr__(rd, "rdcomment", rdcomment)
return rd
# Type checking and conversion helpers. These are class methods as
# they don't touch object state and may be useful to others.
@classmethod
def _as_rdataclass(cls, value):
return dns.rdataclass.RdataClass.make(value)
@classmethod
def _as_rdatatype(cls, value):
return dns.rdatatype.RdataType.make(value)
@classmethod
def _as_bytes(
cls,
value: Any,
encode: bool = False,
max_length: int | None = None,
empty_ok: bool = True,
) -> bytes:
if encode and isinstance(value, str):
bvalue = value.encode()
elif isinstance(value, bytearray):
bvalue = bytes(value)
elif isinstance(value, bytes):
bvalue = value
else:
raise ValueError("not bytes")
if max_length is not None and len(bvalue) > max_length:
raise ValueError("too long")
if not empty_ok and len(bvalue) == 0:
raise ValueError("empty bytes not allowed")
return bvalue
@classmethod
def _as_name(cls, value):
# Note that proper name conversion (e.g. with origin and IDNA
# awareness) is expected to be done via from_text. This is just
# a simple thing for people invoking the constructor directly.
if isinstance(value, str):
return dns.name.from_text(value)
elif not isinstance(value, dns.name.Name):
raise ValueError("not a name")
return value
@classmethod
def _as_uint8(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 255:
raise ValueError("not a uint8")
return value
@classmethod
def _as_uint16(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 65535:
raise ValueError("not a uint16")
return value
@classmethod
def _as_uint32(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 4294967295:
raise ValueError("not a uint32")
return value
@classmethod
def _as_uint48(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 281474976710655:
raise ValueError("not a uint48")
return value
@classmethod
def _as_int(cls, value, low=None, high=None):
if not isinstance(value, int):
raise ValueError("not an integer")
if low is not None and value < low:
raise ValueError("value too small")
if high is not None and value > high:
raise ValueError("value too large")
return value
@classmethod
def _as_ipv4_address(cls, value):
if isinstance(value, str):
return dns.ipv4.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value)
elif isinstance(value, ipaddress.IPv4Address):
return dns.ipv4.inet_ntoa(value.packed)
else:
raise ValueError("not an IPv4 address")
@classmethod
def _as_ipv6_address(cls, value):
if isinstance(value, str):
return dns.ipv6.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value)
elif isinstance(value, ipaddress.IPv6Address):
return dns.ipv6.inet_ntoa(value.packed)
else:
raise ValueError("not an IPv6 address")
@classmethod
def _as_bool(cls, value):
if isinstance(value, bool):
return value
else:
raise ValueError("not a boolean")
@classmethod
def _as_ttl(cls, value):
if isinstance(value, int):
return cls._as_int(value, 0, dns.ttl.MAX_TTL)
elif isinstance(value, str):
return dns.ttl.from_text(value)
else:
raise ValueError("not a TTL")
@classmethod
def _as_tuple(cls, value, as_value):
try:
# For user convenience, if value is a singleton of the list
# element type, wrap it in a tuple.
return (as_value(value),)
except Exception:
# Otherwise, check each element of the iterable *value*
# against *as_value*.
return tuple(as_value(v) for v in value)
# Processing order
@classmethod
def _processing_order(cls, iterable):
items = list(iterable)
random.shuffle(items)
return items
@dns.immutable.immutable
class GenericRdata(Rdata):
"""Generic Rdata Class
This class is used for rdata types for which we have no better
implementation. It implements the DNS "unknown RRs" scheme.
"""
__slots__ = ["data"]
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
data: bytes,
) -> None:
super().__init__(rdclass, rdtype)
self.data = data
def to_text(
self,
origin: dns.name.Name | None = None,
relativize: bool = True,
**kw: Dict[str, Any],
) -> str:
return rf"\# {len(self.data)} " + _hexify(self.data, **kw) # pyright: ignore
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
token = tok.get()
if not token.is_identifier() or token.value != r"\#":
raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
length = tok.get_int()
hex = tok.concatenate_remaining_identifiers(True).encode()
data = binascii.unhexlify(hex)
if len(data) != length:
raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(self.data)
def to_generic(self, origin: dns.name.Name | None = None) -> "GenericRdata":
return self
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
{}
)
_module_prefix = "dns.rdtypes"
_dynamic_load_allowed = True
def get_rdata_class(rdclass, rdtype, use_generic=True):
cls = _rdata_classes.get((rdclass, rdtype))
if not cls:
cls = _rdata_classes.get((dns.rdataclass.ANY, rdtype))
if not cls and _dynamic_load_allowed:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace("-", "_")
try:
mod = import_module(
".".join([_module_prefix, rdclass_text, rdtype_text])
)
cls = getattr(mod, rdtype_text)
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
try:
mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
cls = getattr(mod, rdtype_text)
_rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
pass
if not cls and use_generic:
cls = GenericRdata
_rdata_classes[(rdclass, rdtype)] = cls
return cls
def load_all_types(disable_dynamic_load=True):
"""Load all rdata types for which dnspython has a non-generic implementation.
Normally dnspython loads DNS rdatatype implementations on demand, but in some
specialized cases loading all types at an application-controlled time is preferred.
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
and will simply use the ``GenericRdata`` class.
"""
# Load class IN and ANY types.
for rdtype in dns.rdatatype.RdataType:
get_rdata_class(dns.rdataclass.IN, rdtype, False)
# Load the one non-ANY implementation we have in CH. Everything
# else in CH is an ANY type, and we'll discover those on demand but won't
# have to import anything.
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
if disable_dynamic_load:
# Now disable dynamic loading so any subsequent unknown type immediately becomes
# GenericRdata without a load attempt.
global _dynamic_load_allowed
_dynamic_load_allowed = False
def from_text(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
tok: dns.tokenizer.Tokenizer | str,
origin: dns.name.Name | None = None,
relativize: bool = True,
relativize_to: dns.name.Name | None = None,
idna_codec: dns.name.IDNACodec | None = None,
) -> Rdata:
"""Build an rdata object from text format.
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_text() class method is called
with the parameters to this function.
If *tok* is a ``str``, then a tokenizer is created and the string
is used as its input.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
*origin*, a ``dns.name.Name`` (or ``None``), the
origin to use for relative names.
*relativize*, a ``bool``. If true, name will be relativized.
*relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use
when relativizing names. If not set, the *origin* value will be used.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder to use if a tokenizer needs to be created. If
``None``, the default IDNA 2003 encoder/decoder is used. If a
tokenizer is not created, then the codec associated with the tokenizer
is the one that is used.
Returns an instance of the chosen Rdata subclass.
"""
if isinstance(tok, str):
tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
if not isinstance(tok, dns.tokenizer.Tokenizer):
raise ValueError("tok must be a string or a Tokenizer")
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
rdata = None
if cls != GenericRdata:
# peek at first token
token = tok.get()
tok.unget(token)
if token.is_identifier() and token.value == r"\#":
#
# Known type using the generic syntax. Extract the
# wire form from the generic syntax, and then run
# from_wire on it.
#
grdata = GenericRdata.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
rdata = from_wire(
rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
)
#
# If this comparison isn't equal, then there must have been
# compressed names in the wire format, which is an error,
# there being no reasonable context to decompress with.
#
rwire = rdata.to_wire()
if rwire != grdata.data:
raise dns.exception.SyntaxError(
"compressed data in "
"generic syntax form "
"of known rdatatype"
)
if rdata is None:
rdata = cls.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
token = tok.get_eol_as_token()
if token.comment is not None:
object.__setattr__(rdata, "rdcomment", token.comment)
return rdata
def from_wire_parser(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
parser: dns.wire.Parser,
origin: dns.name.Name | None = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the rdata length.
*origin*, a ``dns.name.Name`` (or ``None``). If not ``None``,
then names will be relativized to this origin.
Returns an instance of the chosen Rdata subclass.
"""
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.FormError):
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
wire: bytes,
current: int,
rdlen: int,
origin: dns.name.Name | None = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
*rdclass*, an ``int``, the rdataclass.
*rdtype*, an ``int``, the rdatatype.
*wire*, a ``bytes``, the wire-format message.
*current*, an ``int``, the offset in wire of the beginning of
the rdata.
*rdlen*, an ``int``, the length of the wire-format rdata
*origin*, a ``dns.name.Name`` (or ``None``). If not ``None``,
then names will be relativized to this origin.
Returns an instance of the chosen Rdata subclass.
"""
parser = dns.wire.Parser(wire, current)
with parser.restrict_to(rdlen):
return from_wire_parser(rdclass, rdtype, parser, origin)
class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists."""
supp_kwargs = {"rdclass", "rdtype"}
fmt = (
"The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
+ "already exists."
)
def register_type(
implementation: Any,
rdtype: int,
rdtype_text: str,
is_singleton: bool = False,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
) -> None:
"""Dynamically register a module to handle an rdatatype.
*implementation*, a subclass of ``dns.rdata.Rdata`` implementing the type,
or a module containing such a class named by its text form.
*rdtype*, an ``int``, the rdatatype to register.
*rdtype_text*, a ``str``, the textual form of the rdatatype.
*is_singleton*, a ``bool``, indicating if the type is a singleton (i.e.
RRsets of the type can have only one member.)
*rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if
it applies to all classes.
"""
rdtype = dns.rdatatype.RdataType.make(rdtype)
existing_cls = get_rdata_class(rdclass, rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
if isinstance(implementation, type) and issubclass(implementation, Rdata):
impclass = implementation
else:
impclass = getattr(implementation, rdtype_text.replace("-", "_"))
_rdata_classes[(rdclass, rdtype)] = impclass
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)

View File

@@ -0,0 +1,118 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Rdata Classes."""
import dns.enum
import dns.exception
class RdataClass(dns.enum.IntEnum):
"""DNS Rdata Class"""
RESERVED0 = 0
IN = 1
INTERNET = IN
CH = 3
CHAOS = CH
HS = 4
HESIOD = HS
NONE = 254
ANY = 255
@classmethod
def _maximum(cls):
return 65535
@classmethod
def _short_name(cls):
return "class"
@classmethod
def _prefix(cls):
return "CLASS"
@classmethod
def _unknown_exception_class(cls):
return UnknownRdataclass
_metaclasses = {RdataClass.NONE, RdataClass.ANY}
class UnknownRdataclass(dns.exception.DNSException):
"""A DNS class is unknown."""
def from_text(text: str) -> RdataClass:
"""Convert text into a DNS rdata class value.
The input text can be a defined DNS RR class mnemonic or
instance of the DNS generic class syntax.
For example, "IN" and "CLASS1" will both result in a value of 1.
Raises ``dns.rdatatype.UnknownRdataclass`` if the class is unknown.
Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
Returns a ``dns.rdataclass.RdataClass``.
"""
return RdataClass.from_text(text)
def to_text(value: RdataClass) -> str:
"""Convert a DNS rdata class value to text.
If the value has a known mnemonic, it will be used, otherwise the
DNS generic class syntax will be used.
Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
Returns a ``str``.
"""
return RdataClass.to_text(value)
def is_metaclass(rdclass: RdataClass) -> bool:
"""True if the specified class is a metaclass.
The currently defined metaclasses are ANY and NONE.
*rdclass* is a ``dns.rdataclass.RdataClass``.
"""
if rdclass in _metaclasses:
return True
return False
### BEGIN generated RdataClass constants
RESERVED0 = RdataClass.RESERVED0
IN = RdataClass.IN
INTERNET = RdataClass.INTERNET
CH = RdataClass.CH
CHAOS = RdataClass.CHAOS
HS = RdataClass.HS
HESIOD = RdataClass.HESIOD
NONE = RdataClass.NONE
ANY = RdataClass.ANY
### END generated RdataClass constants

View File

@@ -0,0 +1,508 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
import io
import random
import struct
from typing import Any, Collection, Dict, List, cast
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.renderer
import dns.set
import dns.ttl
# define SimpleSet here for backwards compatibility
SimpleSet = dns.set.Set
class DifferingCovers(dns.exception.DNSException):
"""An attempt was made to add a DNS SIG/RRSIG whose covered type
is not the same as that of the other rdatas in the rdataset."""
class IncompatibleTypes(dns.exception.DNSException):
"""An attempt was made to add DNS RR data of an incompatible type."""
class Rdataset(dns.set.Set):
"""A DNS rdataset."""
__slots__ = ["rdclass", "rdtype", "covers", "ttl"]
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
ttl: int = 0,
):
"""Create a new rdataset of the specified class and type.
*rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
*rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
*covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
*ttl*, an ``int``, the TTL.
"""
super().__init__()
self.rdclass = rdclass
self.rdtype: dns.rdatatype.RdataType = rdtype
self.covers: dns.rdatatype.RdataType = covers
self.ttl = ttl
def _clone(self):
obj = cast(Rdataset, super()._clone())
obj.rdclass = self.rdclass
obj.rdtype = self.rdtype
obj.covers = self.covers
obj.ttl = self.ttl
return obj
def update_ttl(self, ttl: int) -> None:
"""Perform TTL minimization.
Set the TTL of the rdataset to be the lesser of the set's current
TTL or the specified TTL. If the set contains no rdatas, set the TTL
to the specified TTL.
*ttl*, an ``int`` or ``str``.
"""
ttl = dns.ttl.make(ttl)
if len(self) == 0:
self.ttl = ttl
elif ttl < self.ttl:
self.ttl = ttl
# pylint: disable=arguments-differ,arguments-renamed
def add( # pyright: ignore
self, rd: dns.rdata.Rdata, ttl: int | None = None
) -> None:
"""Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then
``self.update_ttl(ttl)`` will be called prior to adding the rdata.
*rd*, a ``dns.rdata.Rdata``, the rdata
*ttl*, an ``int``, the TTL.
Raises ``dns.rdataset.IncompatibleTypes`` if the type and class
do not match the type and class of the rdataset.
Raises ``dns.rdataset.DifferingCovers`` if the type is a signature
type and the covered type does not match that of the rdataset.
"""
#
# If we're adding a signature, do some special handling to
# check that the signature covers the same type as the
# other rdatas in this rdataset. If this is the first rdata
# in the set, initialize the covers field.
#
if self.rdclass != rd.rdclass or self.rdtype != rd.rdtype:
raise IncompatibleTypes
if ttl is not None:
self.update_ttl(ttl)
if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
covers = rd.covers()
if len(self) == 0 and self.covers == dns.rdatatype.NONE:
self.covers = covers
elif self.covers != covers:
raise DifferingCovers
if dns.rdatatype.is_singleton(rd.rdtype) and len(self) > 0:
self.clear()
super().add(rd)
def union_update(self, other):
self.update_ttl(other.ttl)
super().union_update(other)
def intersection_update(self, other):
self.update_ttl(other.ttl)
super().intersection_update(other)
def update(self, other):
"""Add all rdatas in other to self.
*other*, a ``dns.rdataset.Rdataset``, the rdataset from which
to update.
"""
self.update_ttl(other.ttl)
super().update(other)
def _rdata_repr(self):
def maybe_truncate(s):
if len(s) > 100:
return s[:100] + "..."
return s
return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]"
def __repr__(self):
if self.covers == 0:
ctext = ""
else:
ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
return (
"<DNS "
+ dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdataset: "
+ self._rdata_repr()
+ ">"
)
def __str__(self):
return self.to_text()
def __eq__(self, other):
if not isinstance(other, Rdataset):
return False
if (
self.rdclass != other.rdclass
or self.rdtype != other.rdtype
or self.covers != other.covers
):
return False
return super().__eq__(other)
def __ne__(self, other):
return not self.__eq__(other)
def to_text(
self,
name: dns.name.Name | None = None,
origin: dns.name.Name | None = None,
relativize: bool = True,
override_rdclass: dns.rdataclass.RdataClass | None = None,
want_comments: bool = False,
**kw: Dict[str, Any],
) -> str:
"""Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
on how *origin* and *relativize* determine the way names
are emitted.
Any additional keyword arguments are passed on to the rdata
``to_text()`` method.
*name*, a ``dns.name.Name``. If name is not ``None``, emit RRs with
*name* as the owner name.
*origin*, a ``dns.name.Name`` or ``None``, the origin for relative
names.
*relativize*, a ``bool``. If ``True``, names will be relativized
to *origin*.
*override_rdclass*, a ``dns.rdataclass.RdataClass`` or ``None``.
If not ``None``, use this class instead of the Rdataset's class.
*want_comments*, a ``bool``. If ``True``, emit comments for rdata
which have them. The default is ``False``.
"""
if name is not None:
name = name.choose_relativity(origin, relativize)
ntext = str(name)
pad = " "
else:
ntext = ""
pad = ""
s = io.StringIO()
if override_rdclass is not None:
rdclass = override_rdclass
else:
rdclass = self.rdclass
if len(self) == 0:
#
# Empty rdatasets are used for the question section, and in
# some dynamic updates, so we don't need to print out the TTL
# (which is meaningless anyway).
#
s.write(
f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} "
f"{dns.rdatatype.to_text(self.rdtype)}\n"
)
else:
for rd in self:
extra = ""
if want_comments:
if rd.rdcomment:
extra = f" ;{rd.rdcomment}"
s.write(
f"{ntext}{pad}{self.ttl} "
f"{dns.rdataclass.to_text(rdclass)} "
f"{dns.rdatatype.to_text(self.rdtype)} "
f"{rd.to_text(origin=origin, relativize=relativize, **kw)}"
f"{extra}\n"
)
#
# We strip off the final \n for the caller's convenience in printing
#
return s.getvalue()[:-1]
def to_wire(
self,
name: dns.name.Name,
file: Any,
compress: dns.name.CompressType | None = None,
origin: dns.name.Name | None = None,
override_rdclass: dns.rdataclass.RdataClass | None = None,
want_shuffle: bool = True,
) -> int:
"""Convert the rdataset to wire format.
*name*, a ``dns.name.Name`` is the owner name to use.
*file* is the file where the name is emitted (typically a
BytesIO file).
*compress*, a ``dict``, is the compression table to use. If
``None`` (the default), names will not be compressed.
*origin* is a ``dns.name.Name`` or ``None``. If the name is
relative and origin is not ``None``, then *origin* will be appended
to it.
*override_rdclass*, an ``int``, is used as the class instead of the
class of the rdataset. This is useful when rendering rdatasets
associated with dynamic updates.
*want_shuffle*, a ``bool``. If ``True``, then the order of the
Rdatas within the Rdataset will be shuffled before rendering.
Returns an ``int``, the number of records emitted.
"""
if override_rdclass is not None:
rdclass = override_rdclass
want_shuffle = False
else:
rdclass = self.rdclass
if len(self) == 0:
name.to_wire(file, compress, origin)
file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
return 1
else:
l: Rdataset | List[dns.rdata.Rdata]
if want_shuffle:
l = list(self)
random.shuffle(l)
else:
l = self
for rd in l:
name.to_wire(file, compress, origin)
file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
with dns.renderer.prefixed_length(file, 2):
rd.to_wire(file, compress, origin)
return len(self)
def match(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType,
) -> bool:
"""Returns ``True`` if this rdataset matches the specified class,
type, and covers.
"""
if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers:
return True
return False
def processing_order(self) -> List[dns.rdata.Rdata]:
"""Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same preference
shuffled.
For types that do not define a processing order, the rdatas are
simply shuffled.
"""
if len(self) == 0:
return []
else:
return self[0]._processing_order(iter(self)) # pyright: ignore
@dns.immutable.immutable
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset."""
_clone_class = Rdataset
def __init__(self, rdataset: Rdataset):
"""Create an immutable rdataset from the specified rdataset."""
super().__init__(
rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl
)
self.items = dns.immutable.Dict(rdataset.items)
def update_ttl(self, ttl):
raise TypeError("immutable")
def add(self, rd, ttl=None):
raise TypeError("immutable")
def union_update(self, other):
raise TypeError("immutable")
def intersection_update(self, other):
raise TypeError("immutable")
def update(self, other):
raise TypeError("immutable")
def __delitem__(self, i):
raise TypeError("immutable")
# lgtm complains about these not raising ArithmeticError, but there is
# precedent for overrides of these methods in other classes to raise
# TypeError, and it seems like the better exception.
def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def clear(self):
raise TypeError("immutable")
def __copy__(self):
return ImmutableRdataset(super().copy()) # pyright: ignore
def copy(self):
return ImmutableRdataset(super().copy()) # pyright: ignore
def union(self, other):
return ImmutableRdataset(super().union(other)) # pyright: ignore
def intersection(self, other):
return ImmutableRdataset(super().intersection(other)) # pyright: ignore
def difference(self, other):
return ImmutableRdataset(super().difference(other)) # pyright: ignore
def symmetric_difference(self, other):
return ImmutableRdataset(super().symmetric_difference(other)) # pyright: ignore
def from_text_list(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
ttl: int,
text_rdatas: Collection[str],
idna_codec: dns.name.IDNACodec | None = None,
origin: dns.name.Name | None = None,
relativize: bool = True,
relativize_to: dns.name.Name | None = None,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder to use; if ``None``, the default IDNA 2003
encoder/decoder is used.
*origin*, a ``dns.name.Name`` (or ``None``), the
origin to use for relative names.
*relativize*, a ``bool``. If true, name will be relativized.
*relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use
when relativizing names. If not set, the *origin* value will be used.
Returns a ``dns.rdataset.Rdataset`` object.
"""
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(rdclass, rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(
r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
)
r.add(rd)
return r
def from_text(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
ttl: int,
*text_rdatas: Any,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object.
"""
return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified list of rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
if len(rdatas) == 0:
raise ValueError("rdata list must not be empty")
r = None
for rd in rdatas:
if r is None:
r = Rdataset(rd.rdclass, rd.rdtype)
r.update_ttl(ttl)
r.add(rd)
assert r is not None
return r
def from_rdata(ttl: int, *rdatas: Any) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))

View File

@@ -0,0 +1,338 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Rdata Types."""
from typing import Dict
import dns.enum
import dns.exception
class RdataType(dns.enum.IntEnum):
"""DNS Rdata Type"""
TYPE0 = 0
NONE = 0
A = 1
NS = 2
MD = 3
MF = 4
CNAME = 5
SOA = 6
MB = 7
MG = 8
MR = 9
NULL = 10
WKS = 11
PTR = 12
HINFO = 13
MINFO = 14
MX = 15
TXT = 16
RP = 17
AFSDB = 18
X25 = 19
ISDN = 20
RT = 21
NSAP = 22
NSAP_PTR = 23
SIG = 24
KEY = 25
PX = 26
GPOS = 27
AAAA = 28
LOC = 29
NXT = 30
SRV = 33
NAPTR = 35
KX = 36
CERT = 37
A6 = 38
DNAME = 39
OPT = 41
APL = 42
DS = 43
SSHFP = 44
IPSECKEY = 45
RRSIG = 46
NSEC = 47
DNSKEY = 48
DHCID = 49
NSEC3 = 50
NSEC3PARAM = 51
TLSA = 52
SMIMEA = 53
HIP = 55
NINFO = 56
CDS = 59
CDNSKEY = 60
OPENPGPKEY = 61
CSYNC = 62
ZONEMD = 63
SVCB = 64
HTTPS = 65
DSYNC = 66
SPF = 99
UNSPEC = 103
NID = 104
L32 = 105
L64 = 106
LP = 107
EUI48 = 108
EUI64 = 109
TKEY = 249
TSIG = 250
IXFR = 251
AXFR = 252
MAILB = 253
MAILA = 254
ANY = 255
URI = 256
CAA = 257
AVC = 258
AMTRELAY = 260
RESINFO = 261
WALLET = 262
TA = 32768
DLV = 32769
@classmethod
def _maximum(cls):
return 65535
@classmethod
def _short_name(cls):
return "type"
@classmethod
def _prefix(cls):
return "TYPE"
@classmethod
def _extra_from_text(cls, text):
if text.find("-") >= 0:
try:
return cls[text.replace("-", "_")]
except KeyError: # pragma: no cover
pass
return _registered_by_text.get(text)
@classmethod
def _extra_to_text(cls, value, current_text):
if current_text is None:
return _registered_by_value.get(value)
if current_text.find("_") >= 0:
return current_text.replace("_", "-")
return current_text
@classmethod
def _unknown_exception_class(cls):
return UnknownRdatatype
_registered_by_text: Dict[str, RdataType] = {}
_registered_by_value: Dict[RdataType, str] = {}
_metatypes = {RdataType.OPT}
_singletons = {
RdataType.SOA,
RdataType.NXT,
RdataType.DNAME,
RdataType.NSEC,
RdataType.CNAME,
}
class UnknownRdatatype(dns.exception.DNSException):
"""DNS resource record type is unknown."""
def from_text(text: str) -> RdataType:
"""Convert text into a DNS rdata type value.
The input text can be a defined DNS RR type mnemonic or
instance of the DNS generic type syntax.
For example, "NS" and "TYPE2" will both result in a value of 2.
Raises ``dns.rdatatype.UnknownRdatatype`` if the type is unknown.
Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
Returns a ``dns.rdatatype.RdataType``.
"""
return RdataType.from_text(text)
def to_text(value: RdataType) -> str:
"""Convert a DNS rdata type value to text.
If the value has a known mnemonic, it will be used, otherwise the
DNS generic type syntax will be used.
Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
Returns a ``str``.
"""
return RdataType.to_text(value)
def is_metatype(rdtype: RdataType) -> bool:
"""True if the specified type is a metatype.
*rdtype* is a ``dns.rdatatype.RdataType``.
The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA,
MAILB, ANY, and OPT.
Returns a ``bool``.
"""
return (256 > rdtype >= 128) or rdtype in _metatypes
def is_singleton(rdtype: RdataType) -> bool:
"""Is the specified type a singleton type?
Singleton types can only have a single rdata in an rdataset, or a single
RR in an RRset.
The currently defined singleton types are CNAME, DNAME, NSEC, NXT, and
SOA.
*rdtype* is an ``int``.
Returns a ``bool``.
"""
if rdtype in _singletons:
return True
return False
# pylint: disable=redefined-outer-name
def register_type(
rdtype: RdataType, rdtype_text: str, is_singleton: bool = False
) -> None:
"""Dynamically register an rdatatype.
*rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
*rdtype_text*, a ``str``, the textual form of the rdatatype.
*is_singleton*, a ``bool``, indicating if the type is a singleton (i.e.
RRsets of the type can have only one member.)
"""
_registered_by_text[rdtype_text] = rdtype
_registered_by_value[rdtype] = rdtype_text
if is_singleton:
_singletons.add(rdtype)
### BEGIN generated RdataType constants
TYPE0 = RdataType.TYPE0
NONE = RdataType.NONE
A = RdataType.A
NS = RdataType.NS
MD = RdataType.MD
MF = RdataType.MF
CNAME = RdataType.CNAME
SOA = RdataType.SOA
MB = RdataType.MB
MG = RdataType.MG
MR = RdataType.MR
NULL = RdataType.NULL
WKS = RdataType.WKS
PTR = RdataType.PTR
HINFO = RdataType.HINFO
MINFO = RdataType.MINFO
MX = RdataType.MX
TXT = RdataType.TXT
RP = RdataType.RP
AFSDB = RdataType.AFSDB
X25 = RdataType.X25
ISDN = RdataType.ISDN
RT = RdataType.RT
NSAP = RdataType.NSAP
NSAP_PTR = RdataType.NSAP_PTR
SIG = RdataType.SIG
KEY = RdataType.KEY
PX = RdataType.PX
GPOS = RdataType.GPOS
AAAA = RdataType.AAAA
LOC = RdataType.LOC
NXT = RdataType.NXT
SRV = RdataType.SRV
NAPTR = RdataType.NAPTR
KX = RdataType.KX
CERT = RdataType.CERT
A6 = RdataType.A6
DNAME = RdataType.DNAME
OPT = RdataType.OPT
APL = RdataType.APL
DS = RdataType.DS
SSHFP = RdataType.SSHFP
IPSECKEY = RdataType.IPSECKEY
RRSIG = RdataType.RRSIG
NSEC = RdataType.NSEC
DNSKEY = RdataType.DNSKEY
DHCID = RdataType.DHCID
NSEC3 = RdataType.NSEC3
NSEC3PARAM = RdataType.NSEC3PARAM
TLSA = RdataType.TLSA
SMIMEA = RdataType.SMIMEA
HIP = RdataType.HIP
NINFO = RdataType.NINFO
CDS = RdataType.CDS
CDNSKEY = RdataType.CDNSKEY
OPENPGPKEY = RdataType.OPENPGPKEY
CSYNC = RdataType.CSYNC
ZONEMD = RdataType.ZONEMD
SVCB = RdataType.SVCB
HTTPS = RdataType.HTTPS
DSYNC = RdataType.DSYNC
SPF = RdataType.SPF
UNSPEC = RdataType.UNSPEC
NID = RdataType.NID
L32 = RdataType.L32
L64 = RdataType.L64
LP = RdataType.LP
EUI48 = RdataType.EUI48
EUI64 = RdataType.EUI64
TKEY = RdataType.TKEY
TSIG = RdataType.TSIG
IXFR = RdataType.IXFR
AXFR = RdataType.AXFR
MAILB = RdataType.MAILB
MAILA = RdataType.MAILA
ANY = RdataType.ANY
URI = RdataType.URI
CAA = RdataType.CAA
AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY
RESINFO = RdataType.RESINFO
WALLET = RdataType.WALLET
TA = RdataType.TA
DLV = RdataType.DLV
### END generated RdataType constants

View File

@@ -0,0 +1,45 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""AFSDB record"""
# Use the property mechanism to make "subtype" an alias for the
# "preference" attribute, and "hostname" an alias for the "exchange"
# attribute.
#
# This lets us inherit the UncompressedMX implementation but lets
# the caller use appropriate attribute names for the rdata type.
#
# We probably lose some performance vs. a cut-and-paste
# implementation, but this way we don't copy code, and that's
# good.
@property
def subtype(self):
"the AFSDB subtype"
return self.preference
@property
def hostname(self):
"the AFSDB hostname"
return self.exchange

View File

@@ -0,0 +1,89 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.rdtypes.util
class Relay(dns.rdtypes.util.Gateway):
name = "AMTRELAY relay"
@property
def relay(self):
return self.gateway
@dns.immutable.immutable
class AMTRELAY(dns.rdata.Rdata):
"""AMTRELAY record"""
# see: RFC 8777
__slots__ = ["precedence", "discovery_optional", "relay_type", "relay"]
def __init__(
self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay
):
super().__init__(rdclass, rdtype)
relay = Relay(relay_type, relay)
self.precedence = self._as_uint8(precedence)
self.discovery_optional = self._as_bool(discovery_optional)
self.relay_type = relay.type
self.relay = relay.relay
def to_text(self, origin=None, relativize=True, **kw):
relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
return (
f"{self.precedence} {self.discovery_optional:d} {self.relay_type} {relay}"
)
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
precedence = tok.get_uint8()
discovery_optional = tok.get_uint8()
if discovery_optional > 1:
raise dns.exception.SyntaxError("expecting 0 or 1")
discovery_optional = bool(discovery_optional)
relay_type = tok.get_uint8()
if relay_type > 0x7F:
raise dns.exception.SyntaxError("expecting an integer <= 127")
relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to)
return cls(
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
relay_type = self.relay_type | (self.discovery_optional << 7)
header = struct.pack("!BB", self.precedence, relay_type)
file.write(header)
Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(precedence, relay_type) = parser.get_struct("!BB")
discovery_optional = bool(relay_type >> 7)
relay_type &= 0x7F
relay = Relay.from_wire_parser(relay_type, parser, origin)
return cls(
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
)

View File

@@ -0,0 +1,26 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2016 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class AVC(dns.rdtypes.txtbase.TXTBase):
"""AVC record"""
# See: IANA dns parameters for AVC

View File

@@ -0,0 +1,67 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class CAA(dns.rdata.Rdata):
"""CAA (Certification Authority Authorization) record"""
# see: RFC 6844
__slots__ = ["flags", "tag", "value"]
def __init__(self, rdclass, rdtype, flags, tag, value):
super().__init__(rdclass, rdtype)
self.flags = self._as_uint8(flags)
self.tag = self._as_bytes(tag, True, 255)
if not tag.isalnum():
raise ValueError("tag is not alphanumeric")
self.value = self._as_bytes(value)
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.flags} {dns.rdata._escapify(self.tag)} "{dns.rdata._escapify(self.value)}"'
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
flags = tok.get_uint8()
tag = tok.get_string().encode()
value = tok.get_string().encode()
return cls(rdclass, rdtype, flags, tag, value)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!B", self.flags))
l = len(self.tag)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.tag)
file.write(self.value)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
flags = parser.get_uint8()
tag = parser.get_counted_bytes()
value = parser.get_remaining()
return cls(rdclass, rdtype, flags, tag, value)

View File

@@ -0,0 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
REVOKE,
SEP,
ZONE,
)
# pylint: enable=unused-import
@dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""CDNSKEY record"""

View File

@@ -0,0 +1,29 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable
class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record"""
_digest_length_by_type = {
**dns.rdtypes.dsbase.DSBase._digest_length_by_type,
0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049)
}

View File

@@ -0,0 +1,113 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import struct
import dns.dnssectypes
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
_ctype_by_value = {
1: "PKIX",
2: "SPKI",
3: "PGP",
4: "IPKIX",
5: "ISPKI",
6: "IPGP",
7: "ACPKIX",
8: "IACPKIX",
253: "URI",
254: "OID",
}
_ctype_by_name = {
"PKIX": 1,
"SPKI": 2,
"PGP": 3,
"IPKIX": 4,
"ISPKI": 5,
"IPGP": 6,
"ACPKIX": 7,
"IACPKIX": 8,
"URI": 253,
"OID": 254,
}
def _ctype_from_text(what):
v = _ctype_by_name.get(what)
if v is not None:
return v
return int(what)
def _ctype_to_text(what):
v = _ctype_by_value.get(what)
if v is not None:
return v
return str(what)
@dns.immutable.immutable
class CERT(dns.rdata.Rdata):
"""CERT record"""
# see RFC 4398
__slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"]
def __init__(
self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate
):
super().__init__(rdclass, rdtype)
self.certificate_type = self._as_uint16(certificate_type)
self.key_tag = self._as_uint16(key_tag)
self.algorithm = self._as_uint8(algorithm)
self.certificate = self._as_bytes(certificate)
def to_text(self, origin=None, relativize=True, **kw):
certificate_type = _ctype_to_text(self.certificate_type)
algorithm = dns.dnssectypes.Algorithm.to_text(self.algorithm)
certificate = dns.rdata._base64ify(self.certificate, **kw) # pyright: ignore
return f"{certificate_type} {self.key_tag} {algorithm} {certificate}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
certificate_type = _ctype_from_text(tok.get_string())
key_tag = tok.get_uint16()
algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
b64 = tok.concatenate_remaining_identifiers().encode()
certificate = base64.b64decode(b64)
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
prefix = struct.pack(
"!HHB", self.certificate_type, self.key_tag, self.algorithm
)
file.write(prefix)
file.write(self.certificate)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
certificate = parser.get_remaining()
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)

View File

@@ -0,0 +1,28 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable
class CNAME(dns.rdtypes.nsbase.NSBase):
"""CNAME record
Note: although CNAME is officially a singleton type, dnspython allows
non-singleton CNAME rdatasets because such sets have been commonly
used by BIND and other nameservers for load balancing."""

View File

@@ -0,0 +1,68 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011, 2016 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
import dns.rdatatype
import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = "CSYNC"
@dns.immutable.immutable
class CSYNC(dns.rdata.Rdata):
"""CSYNC record"""
__slots__ = ["serial", "flags", "windows"]
def __init__(self, rdclass, rdtype, serial, flags, windows):
super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial)
self.flags = self._as_uint16(flags)
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw):
text = Bitmap(self.windows).to_text()
return f"{self.serial} {self.flags}{text}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
serial = tok.get_uint32()
flags = tok.get_uint16()
bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, serial, flags, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!IH", self.serial, self.flags))
Bitmap(self.windows).to_wire(file)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(serial, flags) = parser.get_struct("!IH")
bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, serial, flags, bitmap)

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable
class DLV(dns.rdtypes.dsbase.DSBase):
"""DLV record"""

View File

@@ -0,0 +1,27 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable
class DNAME(dns.rdtypes.nsbase.UncompressedNS):
"""DNAME record"""
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.target.to_wire(file, None, origin, canonicalize)

View File

@@ -0,0 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
REVOKE,
SEP,
ZONE,
)
# pylint: enable=unused-import
@dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""DNSKEY record"""

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable
class DS(dns.rdtypes.dsbase.DSBase):
"""DS record"""

View File

@@ -0,0 +1,72 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.enum
import dns.exception
import dns.immutable
import dns.rdata
import dns.rdatatype
import dns.rdtypes.util
class UnknownScheme(dns.exception.DNSException):
"""Unknown DSYNC scheme"""
class Scheme(dns.enum.IntEnum):
"""DSYNC SCHEME"""
NOTIFY = 1
@classmethod
def _maximum(cls):
return 255
@classmethod
def _unknown_exception_class(cls):
return UnknownScheme
@dns.immutable.immutable
class DSYNC(dns.rdata.Rdata):
"""DSYNC record"""
# see: draft-ietf-dnsop-generalized-notify
__slots__ = ["rrtype", "scheme", "port", "target"]
def __init__(self, rdclass, rdtype, rrtype, scheme, port, target):
super().__init__(rdclass, rdtype)
self.rrtype = self._as_rdatatype(rrtype)
self.scheme = Scheme.make(scheme)
self.port = self._as_uint16(port)
self.target = self._as_name(target)
def to_text(self, origin=None, relativize=True, **kw):
target = self.target.choose_relativity(origin, relativize)
return (
f"{dns.rdatatype.to_text(self.rrtype)} {Scheme.to_text(self.scheme)} "
f"{self.port} {target}"
)
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
rrtype = dns.rdatatype.from_text(tok.get_string())
scheme = Scheme.make(tok.get_string())
port = tok.get_uint16()
target = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, rrtype, scheme, port, target)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
three_ints = struct.pack("!HBH", self.rrtype, self.scheme, self.port)
file.write(three_ints)
self.target.to_wire(file, None, origin, False)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(rrtype, scheme, port) = parser.get_struct("!HBH")
target = parser.get_name(origin)
return cls(rdclass, rdtype, rrtype, scheme, port, target)

View File

@@ -0,0 +1,30 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2015 Red Hat, Inc.
# Author: Petr Spacek <pspacek@redhat.com>
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record"""
# see: rfc7043.txt
byte_len = 6 # 0123456789ab (in hex)
text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab

View File

@@ -0,0 +1,30 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2015 Red Hat, Inc.
# Author: Petr Spacek <pspacek@redhat.com>
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI64(dns.rdtypes.euibase.EUIBase):
"""EUI64 record"""
# see: rfc7043.txt
byte_len = 8 # 0123456789abcdef (in hex)
text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab-cd-ef

View File

@@ -0,0 +1,126 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
def _validate_float_string(what):
if len(what) == 0:
raise dns.exception.FormError
if what[0] == b"-"[0] or what[0] == b"+"[0]:
what = what[1:]
if what.isdigit():
return
try:
(left, right) = what.split(b".")
except ValueError:
raise dns.exception.FormError
if left == b"" and right == b"":
raise dns.exception.FormError
if not left == b"" and not left.decode().isdigit():
raise dns.exception.FormError
if not right == b"" and not right.decode().isdigit():
raise dns.exception.FormError
@dns.immutable.immutable
class GPOS(dns.rdata.Rdata):
"""GPOS record"""
# see: RFC 1712
__slots__ = ["latitude", "longitude", "altitude"]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude):
super().__init__(rdclass, rdtype)
if isinstance(latitude, float) or isinstance(latitude, int):
latitude = str(latitude)
if isinstance(longitude, float) or isinstance(longitude, int):
longitude = str(longitude)
if isinstance(altitude, float) or isinstance(altitude, int):
altitude = str(altitude)
latitude = self._as_bytes(latitude, True, 255)
longitude = self._as_bytes(longitude, True, 255)
altitude = self._as_bytes(altitude, True, 255)
_validate_float_string(latitude)
_validate_float_string(longitude)
_validate_float_string(altitude)
self.latitude = latitude
self.longitude = longitude
self.altitude = altitude
flat = self.float_latitude
if flat < -90.0 or flat > 90.0:
raise dns.exception.FormError("bad latitude")
flong = self.float_longitude
if flong < -180.0 or flong > 180.0:
raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw):
return (
f"{self.latitude.decode()} {self.longitude.decode()} "
f"{self.altitude.decode()}"
)
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = tok.get_string()
longitude = tok.get_string()
altitude = tok.get_string()
return cls(rdclass, rdtype, latitude, longitude, altitude)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.latitude)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.latitude)
l = len(self.longitude)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.longitude)
l = len(self.altitude)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.altitude)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
latitude = parser.get_counted_bytes()
longitude = parser.get_counted_bytes()
altitude = parser.get_counted_bytes()
return cls(rdclass, rdtype, latitude, longitude, altitude)
@property
def float_latitude(self):
"latitude as a floating point value"
return float(self.latitude)
@property
def float_longitude(self):
"longitude as a floating point value"
return float(self.longitude)
@property
def float_altitude(self):
"altitude as a floating point value"
return float(self.altitude)

View File

@@ -0,0 +1,64 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class HINFO(dns.rdata.Rdata):
"""HINFO record"""
# see: RFC 1035
__slots__ = ["cpu", "os"]
def __init__(self, rdclass, rdtype, cpu, os):
super().__init__(rdclass, rdtype)
self.cpu = self._as_bytes(cpu, True, 255)
self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"'
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
cpu = tok.get_string(max_length=255)
os = tok.get_string(max_length=255)
return cls(rdclass, rdtype, cpu, os)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.cpu)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.cpu)
l = len(self.os)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.os)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
cpu = parser.get_counted_bytes()
os = parser.get_counted_bytes()
return cls(rdclass, rdtype, cpu, os)

View File

@@ -0,0 +1,85 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2010, 2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import binascii
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.rdatatype
@dns.immutable.immutable
class HIP(dns.rdata.Rdata):
"""HIP record"""
# see: RFC 5205
__slots__ = ["hit", "algorithm", "key", "servers"]
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
super().__init__(rdclass, rdtype)
self.hit = self._as_bytes(hit, True, 255)
self.algorithm = self._as_uint8(algorithm)
self.key = self._as_bytes(key, True)
self.servers = self._as_tuple(servers, self._as_name)
def to_text(self, origin=None, relativize=True, **kw):
hit = binascii.hexlify(self.hit).decode()
key = base64.b64encode(self.key).replace(b"\n", b"").decode()
text = ""
servers = []
for server in self.servers:
servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0:
text += " " + " ".join(x.to_unicode() for x in servers)
return f"{self.algorithm} {hit} {key}{text}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
hit = binascii.unhexlify(tok.get_string().encode())
key = base64.b64decode(tok.get_string().encode())
servers = []
for token in tok.get_remaining():
server = tok.as_name(token, origin, relativize, relativize_to)
servers.append(server)
return cls(rdclass, rdtype, hit, algorithm, key, servers)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
lh = len(self.hit)
lk = len(self.key)
file.write(struct.pack("!BBH", lh, self.algorithm, lk))
file.write(self.hit)
file.write(self.key)
for server in self.servers:
server.to_wire(file, None, origin, False)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(lh, algorithm, lk) = parser.get_struct("!BBH")
hit = parser.get_bytes(lh)
key = parser.get_bytes(lk)
servers = []
while parser.remaining() > 0:
server = parser.get_name(origin)
servers.append(server)
return cls(rdclass, rdtype, hit, algorithm, key, servers)

View File

@@ -0,0 +1,78 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class ISDN(dns.rdata.Rdata):
"""ISDN record"""
# see: RFC 1183
__slots__ = ["address", "subaddress"]
def __init__(self, rdclass, rdtype, address, subaddress):
super().__init__(rdclass, rdtype)
self.address = self._as_bytes(address, True, 255)
self.subaddress = self._as_bytes(subaddress, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress:
return (
f'"{dns.rdata._escapify(self.address)}" '
f'"{dns.rdata._escapify(self.subaddress)}"'
)
else:
return f'"{dns.rdata._escapify(self.address)}"'
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string()
tokens = tok.get_remaining(max_tokens=1)
if len(tokens) >= 1:
subaddress = tokens[0].unescape().value
else:
subaddress = ""
return cls(rdclass, rdtype, address, subaddress)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.address)
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.address)
l = len(self.subaddress)
if l > 0:
assert l < 256
file.write(struct.pack("!B", l))
file.write(self.subaddress)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
address = parser.get_counted_bytes()
if parser.remaining() > 0:
subaddress = parser.get_counted_bytes()
else:
subaddress = b""
return cls(rdclass, rdtype, address, subaddress)

View File

@@ -0,0 +1,42 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.ipv4
import dns.rdata
@dns.immutable.immutable
class L32(dns.rdata.Rdata):
"""L32 record"""
# see: rfc6742.txt
__slots__ = ["preference", "locator32"]
def __init__(self, rdclass, rdtype, preference, locator32):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.locator32 = self._as_ipv4_address(locator32)
def to_text(self, origin=None, relativize=True, **kw):
return f"{self.preference} {self.locator32}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!H", self.preference))
file.write(dns.ipv4.inet_aton(self.locator32))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
locator32 = parser.get_remaining()
return cls(rdclass, rdtype, preference, locator32)

View File

@@ -0,0 +1,48 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.rdata
import dns.rdtypes.util
@dns.immutable.immutable
class L64(dns.rdata.Rdata):
"""L64 record"""
# see: rfc6742.txt
__slots__ = ["preference", "locator64"]
def __init__(self, rdclass, rdtype, preference, locator64):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(locator64, bytes):
if len(locator64) != 8:
raise ValueError("invalid locator64")
self.locator64 = dns.rdata._hexify(locator64, 4, b":")
else:
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":")
self.locator64 = locator64
def to_text(self, origin=None, relativize=True, **kw):
return f"{self.preference} {self.locator64}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
locator64 = tok.get_identifier()
return cls(rdclass, rdtype, preference, locator64)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":"))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
locator64 = parser.get_remaining()
return cls(rdclass, rdtype, preference, locator64)

View File

@@ -0,0 +1,347 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.rdata
_pows = tuple(10**i for i in range(0, 11))
# default values are in centimeters
_default_size = 100.0
_default_hprec = 1000000.0
_default_vprec = 1000.0
# for use by from_wire()
_MAX_LATITUDE = 0x80000000 + 90 * 3600000
_MIN_LATITUDE = 0x80000000 - 90 * 3600000
_MAX_LONGITUDE = 0x80000000 + 180 * 3600000
_MIN_LONGITUDE = 0x80000000 - 180 * 3600000
def _exponent_of(what, desc):
if what == 0:
return 0
exp = None
for i, pow in enumerate(_pows):
if what < pow:
exp = i - 1
break
if exp is None or exp < 0:
raise dns.exception.SyntaxError(f"{desc} value out of bounds")
return exp
def _float_to_tuple(what):
if what < 0:
sign = -1
what *= -1
else:
sign = 1
what = round(what * 3600000)
degrees = int(what // 3600000)
what -= degrees * 3600000
minutes = int(what // 60000)
what -= minutes * 60000
seconds = int(what // 1000)
what -= int(seconds * 1000)
what = int(what)
return (degrees, minutes, seconds, what, sign)
def _tuple_to_float(what):
value = float(what[0])
value += float(what[1]) / 60.0
value += float(what[2]) / 3600.0
value += float(what[3]) / 3600000.0
return float(what[4]) * value
def _encode_size(what, desc):
what = int(what)
exponent = _exponent_of(what, desc) & 0xF
base = what // pow(10, exponent) & 0xF
return base * 16 + exponent
def _decode_size(what, desc):
exponent = what & 0x0F
if exponent > 9:
raise dns.exception.FormError(f"bad {desc} exponent")
base = (what & 0xF0) >> 4
if base > 9:
raise dns.exception.FormError(f"bad {desc} base")
return base * pow(10, exponent)
def _check_coordinate_list(value, low, high):
if value[0] < low or value[0] > high:
raise ValueError(f"not in range [{low}, {high}]")
if value[1] < 0 or value[1] > 59:
raise ValueError("bad minutes value")
if value[2] < 0 or value[2] > 59:
raise ValueError("bad seconds value")
if value[3] < 0 or value[3] > 999:
raise ValueError("bad milliseconds value")
if value[4] != 1 and value[4] != -1:
raise ValueError("bad hemisphere value")
@dns.immutable.immutable
class LOC(dns.rdata.Rdata):
"""LOC record"""
# see: RFC 1876
__slots__ = [
"latitude",
"longitude",
"altitude",
"size",
"horizontal_precision",
"vertical_precision",
]
def __init__(
self,
rdclass,
rdtype,
latitude,
longitude,
altitude,
size=_default_size,
hprec=_default_hprec,
vprec=_default_vprec,
):
"""Initialize a LOC record instance.
The parameters I{latitude} and I{longitude} may be either a 4-tuple
of integers specifying (degrees, minutes, seconds, milliseconds),
or they may be floating point values specifying the number of
degrees. The other parameters are floats. Size, horizontal precision,
and vertical precision are specified in centimeters."""
super().__init__(rdclass, rdtype)
if isinstance(latitude, int):
latitude = float(latitude)
if isinstance(latitude, float):
latitude = _float_to_tuple(latitude)
_check_coordinate_list(latitude, -90, 90)
self.latitude = tuple(latitude) # pyright: ignore
if isinstance(longitude, int):
longitude = float(longitude)
if isinstance(longitude, float):
longitude = _float_to_tuple(longitude)
_check_coordinate_list(longitude, -180, 180)
self.longitude = tuple(longitude) # pyright: ignore
self.altitude = float(altitude)
self.size = float(size)
self.horizontal_precision = float(hprec)
self.vertical_precision = float(vprec)
def to_text(self, origin=None, relativize=True, **kw):
if self.latitude[4] > 0:
lat_hemisphere = "N"
else:
lat_hemisphere = "S"
if self.longitude[4] > 0:
long_hemisphere = "E"
else:
long_hemisphere = "W"
text = (
f"{self.latitude[0]} {self.latitude[1]} "
f"{self.latitude[2]}.{self.latitude[3]:03d} {lat_hemisphere} "
f"{self.longitude[0]} {self.longitude[1]} "
f"{self.longitude[2]}.{self.longitude[3]:03d} {long_hemisphere} "
f"{(self.altitude / 100.0):0.2f}m"
)
# do not print default values
if (
self.size != _default_size
or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec
):
text += (
f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m"
f" {self.vertical_precision / 100.0:0.2f}m"
)
return text
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = [0, 0, 0, 0, 1]
longitude = [0, 0, 0, 0, 1]
size = _default_size
hprec = _default_hprec
vprec = _default_vprec
latitude[0] = tok.get_int()
t = tok.get_string()
if t.isdigit():
latitude[1] = int(t)
t = tok.get_string()
if "." in t:
(seconds, milliseconds) = t.split(".")
if not seconds.isdigit():
raise dns.exception.SyntaxError("bad latitude seconds value")
latitude[2] = int(seconds)
l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError("bad latitude milliseconds value")
if l == 1:
m = 100
elif l == 2:
m = 10
else:
m = 1
latitude[3] = m * int(milliseconds)
t = tok.get_string()
elif t.isdigit():
latitude[2] = int(t)
t = tok.get_string()
if t == "S":
latitude[4] = -1
elif t != "N":
raise dns.exception.SyntaxError("bad latitude hemisphere value")
longitude[0] = tok.get_int()
t = tok.get_string()
if t.isdigit():
longitude[1] = int(t)
t = tok.get_string()
if "." in t:
(seconds, milliseconds) = t.split(".")
if not seconds.isdigit():
raise dns.exception.SyntaxError("bad longitude seconds value")
longitude[2] = int(seconds)
l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError("bad longitude milliseconds value")
if l == 1:
m = 100
elif l == 2:
m = 10
else:
m = 1
longitude[3] = m * int(milliseconds)
t = tok.get_string()
elif t.isdigit():
longitude[2] = int(t)
t = tok.get_string()
if t == "W":
longitude[4] = -1
elif t != "E":
raise dns.exception.SyntaxError("bad longitude hemisphere value")
t = tok.get_string()
if t[-1] == "m":
t = t[0:-1]
altitude = float(t) * 100.0 # m -> cm
tokens = tok.get_remaining(max_tokens=3)
if len(tokens) >= 1:
value = tokens[0].unescape().value
if value[-1] == "m":
value = value[0:-1]
size = float(value) * 100.0 # m -> cm
if len(tokens) >= 2:
value = tokens[1].unescape().value
if value[-1] == "m":
value = value[0:-1]
hprec = float(value) * 100.0 # m -> cm
if len(tokens) >= 3:
value = tokens[2].unescape().value
if value[-1] == "m":
value = value[0:-1]
vprec = float(value) * 100.0 # m -> cm
# Try encoding these now so we raise if they are bad
_encode_size(size, "size")
_encode_size(hprec, "horizontal precision")
_encode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
milliseconds = (
self.latitude[0] * 3600000
+ self.latitude[1] * 60000
+ self.latitude[2] * 1000
+ self.latitude[3]
) * self.latitude[4]
latitude = 0x80000000 + milliseconds
milliseconds = (
self.longitude[0] * 3600000
+ self.longitude[1] * 60000
+ self.longitude[2] * 1000
+ self.longitude[3]
) * self.longitude[4]
longitude = 0x80000000 + milliseconds
altitude = int(self.altitude) + 10000000
size = _encode_size(self.size, "size")
hprec = _encode_size(self.horizontal_precision, "horizontal precision")
vprec = _encode_size(self.vertical_precision, "vertical precision")
wire = struct.pack(
"!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude
)
file.write(wire)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(
version,
size,
hprec,
vprec,
latitude,
longitude,
altitude,
) = parser.get_struct("!BBBBIII")
if version != 0:
raise dns.exception.FormError("LOC version not zero")
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
raise dns.exception.FormError("bad latitude")
if latitude > 0x80000000:
latitude = (latitude - 0x80000000) / 3600000
else:
latitude = -1 * (0x80000000 - latitude) / 3600000
if longitude < _MIN_LONGITUDE or longitude > _MAX_LONGITUDE:
raise dns.exception.FormError("bad longitude")
if longitude > 0x80000000:
longitude = (longitude - 0x80000000) / 3600000
else:
longitude = -1 * (0x80000000 - longitude) / 3600000
altitude = float(altitude) - 10000000.0
size = _decode_size(size, "size")
hprec = _decode_size(hprec, "horizontal precision")
vprec = _decode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
@property
def float_latitude(self):
"latitude as a floating point value"
return _tuple_to_float(self.latitude)
@property
def float_longitude(self):
"longitude as a floating point value"
return _tuple_to_float(self.longitude)

View File

@@ -0,0 +1,42 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.rdata
@dns.immutable.immutable
class LP(dns.rdata.Rdata):
"""LP record"""
# see: rfc6742.txt
__slots__ = ["preference", "fqdn"]
def __init__(self, rdclass, rdtype, preference, fqdn):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.fqdn = self._as_name(fqdn)
def to_text(self, origin=None, relativize=True, **kw):
fqdn = self.fqdn.choose_relativity(origin, relativize)
return f"{self.preference} {fqdn}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
fqdn = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, preference, fqdn)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!H", self.preference))
self.fqdn.to_wire(file, compress, origin, canonicalize)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
fqdn = parser.get_name(origin)
return cls(rdclass, rdtype, preference, fqdn)

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable
class MX(dns.rdtypes.mxbase.MXBase):
"""MX record"""

View File

@@ -0,0 +1,48 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.rdata
import dns.rdtypes.util
@dns.immutable.immutable
class NID(dns.rdata.Rdata):
"""NID record"""
# see: rfc6742.txt
__slots__ = ["preference", "nodeid"]
def __init__(self, rdclass, rdtype, preference, nodeid):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(nodeid, bytes):
if len(nodeid) != 8:
raise ValueError("invalid nodeid")
self.nodeid = dns.rdata._hexify(nodeid, 4, b":")
else:
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":")
self.nodeid = nodeid
def to_text(self, origin=None, relativize=True, **kw):
return f"{self.preference} {self.nodeid}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":"))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
nodeid = parser.get_remaining()
return cls(rdclass, rdtype, preference, nodeid)

View File

@@ -0,0 +1,26 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class NINFO(dns.rdtypes.txtbase.TXTBase):
"""NINFO record"""
# see: draft-reid-dnsext-zs-01

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable
class NS(dns.rdtypes.nsbase.NSBase):
"""NS record"""

View File

@@ -0,0 +1,67 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
import dns.rdatatype
import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = "NSEC"
@dns.immutable.immutable
class NSEC(dns.rdata.Rdata):
"""NSEC record"""
__slots__ = ["next", "windows"]
def __init__(self, rdclass, rdtype, next, windows):
super().__init__(rdclass, rdtype)
self.next = self._as_name(next)
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text()
return f"{next}{text}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
next = tok.get_name(origin, relativize, relativize_to)
windows = Bitmap.from_text(tok)
return cls(rdclass, rdtype, next, windows)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
# Note that NSEC downcasing, originally mandated by RFC 4034
# section 6.2 was removed by RFC 6840 section 5.1.
self.next.to_wire(file, None, origin, False)
Bitmap(self.windows).to_wire(file)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
next = parser.get_name(origin)
bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, next, bitmap)

View File

@@ -0,0 +1,120 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import binascii
import struct
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
import dns.rdatatype
import dns.rdtypes.util
b32_hex_to_normal = bytes.maketrans(
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
)
b32_normal_to_hex = bytes.maketrans(
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV"
)
# hash algorithm constants
SHA1 = 1
# flag constants
OPTOUT = 1
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = "NSEC3"
@dns.immutable.immutable
class NSEC3(dns.rdata.Rdata):
"""NSEC3 record"""
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
def __init__(
self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows
):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.flags = self._as_uint8(flags)
self.iterations = self._as_uint16(iterations)
self.salt = self._as_bytes(salt, True, 255)
self.next = self._as_bytes(next, True, 255)
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def _next_text(self):
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
next = next.rstrip("=")
return next
def to_text(self, origin=None, relativize=True, **kw):
next = self._next_text()
if self.salt == b"":
salt = "-"
else:
salt = binascii.hexlify(self.salt).decode()
text = Bitmap(self.windows).to_text()
return f"{self.algorithm} {self.flags} {self.iterations} {salt} {next}{text}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
flags = tok.get_uint8()
iterations = tok.get_uint16()
salt = tok.get_string()
if salt == "-":
salt = b""
else:
salt = binascii.unhexlify(salt.encode("ascii"))
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
if next.endswith(b"="):
raise binascii.Error("Incorrect padding")
if len(next) % 8 != 0:
next += b"=" * (8 - len(next) % 8)
next = base64.b32decode(next)
bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
file.write(self.salt)
l = len(self.next)
file.write(struct.pack("!B", l))
file.write(self.next)
Bitmap(self.windows).to_wire(file)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes()
next = parser.get_counted_bytes()
bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
def next_name(self, origin=None):
return dns.name.from_text(self._next_text(), origin)

View File

@@ -0,0 +1,69 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import binascii
import struct
import dns.exception
import dns.immutable
import dns.rdata
@dns.immutable.immutable
class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record"""
__slots__ = ["algorithm", "flags", "iterations", "salt"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.flags = self._as_uint8(flags)
self.iterations = self._as_uint16(iterations)
self.salt = self._as_bytes(salt, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
if self.salt == b"":
salt = "-"
else:
salt = binascii.hexlify(self.salt).decode()
return f"{self.algorithm} {self.flags} {self.iterations} {salt}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
flags = tok.get_uint8()
iterations = tok.get_uint16()
salt = tok.get_string()
if salt == "-":
salt = ""
else:
salt = binascii.unhexlify(salt.encode())
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
file.write(self.salt)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes()
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)

View File

@@ -0,0 +1,53 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2016 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata):
"""OPENPGPKEY record"""
# see: RFC 7929
def __init__(self, rdclass, rdtype, key):
super().__init__(rdclass, rdtype)
self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw):
return dns.rdata._base64ify(self.key, chunksize=None, **kw) # pyright: ignore
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64)
return cls(rdclass, rdtype, key)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(self.key)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
key = parser.get_remaining()
return cls(rdclass, rdtype, key)

View File

@@ -0,0 +1,77 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.edns
import dns.exception
import dns.immutable
import dns.rdata
# We don't implement from_text, and that's ok.
# pylint: disable=abstract-method
@dns.immutable.immutable
class OPT(dns.rdata.Rdata):
"""OPT record"""
__slots__ = ["options"]
def __init__(self, rdclass, rdtype, options):
"""Initialize an OPT rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata,
which is also the payload size.
*rdtype*, an ``int`` is the rdatatype of the Rdata.
*options*, a tuple of ``bytes``
"""
super().__init__(rdclass, rdtype)
def as_option(option):
if not isinstance(option, dns.edns.Option):
raise ValueError("option is not a dns.edns.option")
return option
self.options = self._as_tuple(options, as_option)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
for opt in self.options:
owire = opt.to_wire()
file.write(struct.pack("!HH", opt.otype, len(owire)))
file.write(owire)
def to_text(self, origin=None, relativize=True, **kw):
return " ".join(opt.to_text() for opt in self.options)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
options = []
while parser.remaining() > 0:
(otype, olen) = parser.get_struct("!HH")
with parser.restrict_to(olen):
opt = dns.edns.option_from_wire_parser(otype, parser)
options.append(opt)
return cls(rdclass, rdtype, options)
@property
def payload(self):
"payload size"
return self.rdclass

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable
class PTR(dns.rdtypes.nsbase.NSBase):
"""PTR record"""

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class RESINFO(dns.rdtypes.txtbase.TXTBase):
"""RESINFO record"""

View File

@@ -0,0 +1,58 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
@dns.immutable.immutable
class RP(dns.rdata.Rdata):
"""RP record"""
# see: RFC 1183
__slots__ = ["mbox", "txt"]
def __init__(self, rdclass, rdtype, mbox, txt):
super().__init__(rdclass, rdtype)
self.mbox = self._as_name(mbox)
self.txt = self._as_name(txt)
def to_text(self, origin=None, relativize=True, **kw):
mbox = self.mbox.choose_relativity(origin, relativize)
txt = self.txt.choose_relativity(origin, relativize)
return f"{str(mbox)} {str(txt)}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mbox = tok.get_name(origin, relativize, relativize_to)
txt = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, mbox, txt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.mbox.to_wire(file, None, origin, canonicalize)
self.txt.to_wire(file, None, origin, canonicalize)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
mbox = parser.get_name(origin)
txt = parser.get_name(origin)
return cls(rdclass, rdtype, mbox, txt)

View File

@@ -0,0 +1,155 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import calendar
import struct
import time
import dns.dnssectypes
import dns.exception
import dns.immutable
import dns.rdata
import dns.rdatatype
class BadSigTime(dns.exception.DNSException):
"""Time in DNS SIG or RRSIG resource record cannot be parsed."""
def sigtime_to_posixtime(what):
if len(what) <= 10 and what.isdigit():
return int(what)
if len(what) != 14:
raise BadSigTime
year = int(what[0:4])
month = int(what[4:6])
day = int(what[6:8])
hour = int(what[8:10])
minute = int(what[10:12])
second = int(what[12:14])
return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0))
def posixtime_to_sigtime(what):
return time.strftime("%Y%m%d%H%M%S", time.gmtime(what))
@dns.immutable.immutable
class RRSIG(dns.rdata.Rdata):
"""RRSIG record"""
__slots__ = [
"type_covered",
"algorithm",
"labels",
"original_ttl",
"expiration",
"inception",
"key_tag",
"signer",
"signature",
]
def __init__(
self,
rdclass,
rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
):
super().__init__(rdclass, rdtype)
self.type_covered = self._as_rdatatype(type_covered)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.labels = self._as_uint8(labels)
self.original_ttl = self._as_ttl(original_ttl)
self.expiration = self._as_uint32(expiration)
self.inception = self._as_uint32(inception)
self.key_tag = self._as_uint16(key_tag)
self.signer = self._as_name(signer)
self.signature = self._as_bytes(signature)
def covers(self):
return self.type_covered
def to_text(self, origin=None, relativize=True, **kw):
return (
f"{dns.rdatatype.to_text(self.type_covered)} "
f"{self.algorithm} {self.labels} {self.original_ttl} "
f"{posixtime_to_sigtime(self.expiration)} "
f"{posixtime_to_sigtime(self.inception)} "
f"{self.key_tag} "
f"{self.signer.choose_relativity(origin, relativize)} "
f"{dns.rdata._base64ify(self.signature, **kw)}" # pyright: ignore
)
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
type_covered = dns.rdatatype.from_text(tok.get_string())
algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
labels = tok.get_int()
original_ttl = tok.get_ttl()
expiration = sigtime_to_posixtime(tok.get_string())
inception = sigtime_to_posixtime(tok.get_string())
key_tag = tok.get_int()
signer = tok.get_name(origin, relativize, relativize_to)
b64 = tok.concatenate_remaining_identifiers().encode()
signature = base64.b64decode(b64)
return cls(
rdclass,
rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack(
"!HBBIIIH",
self.type_covered,
self.algorithm,
self.labels,
self.original_ttl,
self.expiration,
self.inception,
self.key_tag,
)
file.write(header)
self.signer.to_wire(file, None, origin, canonicalize)
file.write(self.signature)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct("!HBBIIIH")
signer = parser.get_name(origin)
signature = parser.get_remaining()
return cls(rdclass, rdtype, *header, signer, signature) # pyright: ignore

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable
class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""RT record"""

View File

@@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.tlsabase
@dns.immutable.immutable
class SMIMEA(dns.rdtypes.tlsabase.TLSABase):
"""SMIMEA record"""

View File

@@ -0,0 +1,78 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
@dns.immutable.immutable
class SOA(dns.rdata.Rdata):
"""SOA record"""
# see: RFC 1035
__slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"]
def __init__(
self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
):
super().__init__(rdclass, rdtype)
self.mname = self._as_name(mname)
self.rname = self._as_name(rname)
self.serial = self._as_uint32(serial)
self.refresh = self._as_ttl(refresh)
self.retry = self._as_ttl(retry)
self.expire = self._as_ttl(expire)
self.minimum = self._as_ttl(minimum)
def to_text(self, origin=None, relativize=True, **kw):
mname = self.mname.choose_relativity(origin, relativize)
rname = self.rname.choose_relativity(origin, relativize)
return f"{mname} {rname} {self.serial} {self.refresh} {self.retry} {self.expire} {self.minimum}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mname = tok.get_name(origin, relativize, relativize_to)
rname = tok.get_name(origin, relativize, relativize_to)
serial = tok.get_uint32()
refresh = tok.get_ttl()
retry = tok.get_ttl()
expire = tok.get_ttl()
minimum = tok.get_ttl()
return cls(
rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.mname.to_wire(file, compress, origin, canonicalize)
self.rname.to_wire(file, compress, origin, canonicalize)
five_ints = struct.pack(
"!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum
)
file.write(five_ints)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
mname = parser.get_name(origin)
rname = parser.get_name(origin)
return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII"))

View File

@@ -0,0 +1,26 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class SPF(dns.rdtypes.txtbase.TXTBase):
"""SPF record"""
# see: RFC 4408

View File

@@ -0,0 +1,67 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import binascii
import struct
import dns.immutable
import dns.rdata
import dns.rdatatype
@dns.immutable.immutable
class SSHFP(dns.rdata.Rdata):
"""SSHFP record"""
# See RFC 4255
__slots__ = ["algorithm", "fp_type", "fingerprint"]
def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.fp_type = self._as_uint8(fp_type)
self.fingerprint = self._as_bytes(fingerprint, True)
def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
chunksize = kw.pop("chunksize", 128)
fingerprint = dns.rdata._hexify(
self.fingerprint, chunksize=chunksize, **kw # pyright: ignore
)
return f"{self.algorithm} {self.fp_type} {fingerprint}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
fp_type = tok.get_uint8()
fingerprint = tok.concatenate_remaining_identifiers().encode()
fingerprint = binascii.unhexlify(fingerprint)
return cls(rdclass, rdtype, algorithm, fp_type, fingerprint)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!BB", self.algorithm, self.fp_type)
file.write(header)
file.write(self.fingerprint)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct("BB")
fingerprint = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], fingerprint)

View File

@@ -0,0 +1,135 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import struct
import dns.exception
import dns.immutable
import dns.rdata
@dns.immutable.immutable
class TKEY(dns.rdata.Rdata):
"""TKEY Record"""
__slots__ = [
"algorithm",
"inception",
"expiration",
"mode",
"error",
"key",
"other",
]
def __init__(
self,
rdclass,
rdtype,
algorithm,
inception,
expiration,
mode,
error,
key,
other=b"",
):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.inception = self._as_uint32(inception)
self.expiration = self._as_uint32(expiration)
self.mode = self._as_uint16(mode)
self.error = self._as_uint16(error)
self.key = self._as_bytes(key)
self.other = self._as_bytes(other)
def to_text(self, origin=None, relativize=True, **kw):
_algorithm = self.algorithm.choose_relativity(origin, relativize)
key = dns.rdata._base64ify(self.key, 0)
other = ""
if len(self.other) > 0:
other = " " + dns.rdata._base64ify(self.other, 0)
return f"{_algorithm} {self.inception} {self.expiration} {self.mode} {self.error} {key}{other}"
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False)
inception = tok.get_uint32()
expiration = tok.get_uint32()
mode = tok.get_uint16()
error = tok.get_uint16()
key_b64 = tok.get_string().encode()
key = base64.b64decode(key_b64)
other_b64 = tok.concatenate_remaining_identifiers(True).encode()
other = base64.b64decode(other_b64)
return cls(
rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, compress, origin)
file.write(
struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error)
)
file.write(struct.pack("!H", len(self.key)))
file.write(self.key)
file.write(struct.pack("!H", len(self.other)))
if len(self.other) > 0:
file.write(self.other)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
algorithm = parser.get_name(origin)
inception, expiration, mode, error = parser.get_struct("!IIHH")
key = parser.get_counted_bytes(2)
other = parser.get_counted_bytes(2)
return cls(
rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
# Constants for the mode field - from RFC 2930:
# 2.5 The Mode Field
#
# The mode field specifies the general scheme for key agreement or
# the purpose of the TKEY DNS message. Servers and resolvers
# supporting this specification MUST implement the Diffie-Hellman key
# agreement mode and the key deletion mode for queries. All other
# modes are OPTIONAL. A server supporting TKEY that receives a TKEY
# request with a mode it does not support returns the BADMODE error.
# The following values of the Mode octet are defined, available, or
# reserved:
#
# Value Description
# ----- -----------
# 0 - reserved, see section 7
# 1 server assignment
# 2 Diffie-Hellman exchange
# 3 GSS-API negotiation
# 4 resolver assignment
# 5 key deletion
# 6-65534 - available, see section 7
# 65535 - reserved, see section 7
SERVER_ASSIGNMENT = 1
DIFFIE_HELLMAN_EXCHANGE = 2
GSSAPI_NEGOTIATION = 3
RESOLVER_ASSIGNMENT = 4
KEY_DELETION = 5

View File

@@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.tlsabase
@dns.immutable.immutable
class TLSA(dns.rdtypes.tlsabase.TLSABase):
"""TLSA record"""

View File

@@ -0,0 +1,160 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import struct
import dns.exception
import dns.immutable
import dns.rcode
import dns.rdata
@dns.immutable.immutable
class TSIG(dns.rdata.Rdata):
"""TSIG record"""
__slots__ = [
"algorithm",
"time_signed",
"fudge",
"mac",
"original_id",
"error",
"other",
]
def __init__(
self,
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
):
"""Initialize a TSIG rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
*rdtype*, an ``int`` is the rdatatype of the Rdata.
*algorithm*, a ``dns.name.Name``.
*time_signed*, an ``int``.
*fudge*, an ``int`.
*mac*, a ``bytes``
*original_id*, an ``int``
*error*, an ``int``
*other*, a ``bytes``
"""
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.time_signed = self._as_uint48(time_signed)
self.fudge = self._as_uint16(fudge)
self.mac = self._as_bytes(mac)
self.original_id = self._as_uint16(original_id)
self.error = dns.rcode.Rcode.make(error)
self.other = self._as_bytes(other)
def to_text(self, origin=None, relativize=True, **kw):
algorithm = self.algorithm.choose_relativity(origin, relativize)
error = dns.rcode.to_text(self.error, True)
text = (
f"{algorithm} {self.time_signed} {self.fudge} "
+ f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} "
+ f"{self.original_id} {error} {len(self.other)}"
)
if self.other:
text += f" {dns.rdata._base64ify(self.other, 0)}"
return text
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False)
time_signed = tok.get_uint48()
fudge = tok.get_uint16()
mac_len = tok.get_uint16()
mac = base64.b64decode(tok.get_string())
if len(mac) != mac_len:
raise SyntaxError("invalid MAC")
original_id = tok.get_uint16()
error = dns.rcode.from_text(tok.get_string())
other_len = tok.get_uint16()
if other_len > 0:
other = base64.b64decode(tok.get_string())
if len(other) != other_len:
raise SyntaxError("invalid other data")
else:
other = b""
return cls(
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, None, origin, False)
file.write(
struct.pack(
"!HIHH",
(self.time_signed >> 32) & 0xFFFF,
self.time_signed & 0xFFFFFFFF,
self.fudge,
len(self.mac),
)
)
file.write(self.mac)
file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other)))
file.write(self.other)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
algorithm = parser.get_name()
time_signed = parser.get_uint48()
fudge = parser.get_uint16()
mac = parser.get_counted_bytes(2)
(original_id, error) = parser.get_struct("!HH")
other = parser.get_counted_bytes(2)
return cls(
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)

View File

@@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class TXT(dns.rdtypes.txtbase.TXTBase):
"""TXT record"""

View File

@@ -0,0 +1,79 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
# Copyright (C) 2015 Red Hat, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import dns.exception
import dns.immutable
import dns.name
import dns.rdata
import dns.rdtypes.util
@dns.immutable.immutable
class URI(dns.rdata.Rdata):
"""URI record"""
# see RFC 7553
__slots__ = ["priority", "weight", "target"]
def __init__(self, rdclass, rdtype, priority, weight, target):
super().__init__(rdclass, rdtype)
self.priority = self._as_uint16(priority)
self.weight = self._as_uint16(weight)
self.target = self._as_bytes(target, True)
if len(self.target) == 0:
raise dns.exception.SyntaxError("URI target cannot be empty")
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.priority} {self.weight} "{self.target.decode()}"'
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
priority = tok.get_uint16()
weight = tok.get_uint16()
target = tok.get().unescape()
if not (target.is_quoted_string() or target.is_identifier()):
raise dns.exception.SyntaxError("URI target must be a string")
return cls(rdclass, rdtype, priority, weight, target.value)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
two_ints = struct.pack("!HH", self.priority, self.weight)
file.write(two_ints)
file.write(self.target)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(priority, weight) = parser.get_struct("!HH")
target = parser.get_remaining()
if len(target) == 0:
raise dns.exception.FormError("URI target may not be empty")
return cls(rdclass, rdtype, priority, weight, target)
def _processing_priority(self):
return self.priority
def _processing_weight(self):
return self.weight
@classmethod
def _processing_order(cls, iterable):
return dns.rdtypes.util.weighted_processing_order(iterable)

View File

@@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class WALLET(dns.rdtypes.txtbase.TXTBase):
"""WALLET record"""

Some files were not shown because too many files have changed in this diff Show More