API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
from uvicorn.config import Config
from uvicorn.main import Server, main, run
__version__ = "0.24.0"
__version__ = "0.37.0"
__all__ = ["main", "run", "Config", "Server"]

View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar
_T = TypeVar("_T")
if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)
else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("asyncio.run() cannot be called from a running event loop")
if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")
if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)

View File

@@ -2,12 +2,15 @@
Some light wrappers around Python's multiprocessing, to deal with cleanly
starting child processes.
"""
from __future__ import annotations
import multiprocessing
import os
import sys
from multiprocessing.context import SpawnProcess
from socket import socket
from typing import Callable, List, Optional
from typing import Callable
from uvicorn.config import Config
@@ -18,7 +21,7 @@ spawn = multiprocessing.get_context("spawn")
def get_subprocess(
config: Config,
target: Callable[..., None],
sockets: List[socket],
sockets: list[socket],
) -> SpawnProcess:
"""
Called in the parent process, to instantiate a new child process instance.
@@ -32,10 +35,10 @@ def get_subprocess(
"""
# We pass across the stdin fileno, and reopen it in the child process.
# This is required for some debugging environments.
stdin_fileno: Optional[int]
try:
stdin_fileno = sys.stdin.fileno()
except OSError:
# The `sys.stdin` can be `None`, see https://docs.python.org/3/library/sys.html#sys.__stdin__.
except (AttributeError, OSError):
stdin_fileno = None
kwargs = {
@@ -51,8 +54,8 @@ def get_subprocess(
def subprocess_started(
config: Config,
target: Callable[..., None],
sockets: List[socket],
stdin_fileno: Optional[int],
sockets: list[socket],
stdin_fileno: int | None,
) -> None:
"""
Called when the child process starts.
@@ -67,10 +70,15 @@ def subprocess_started(
"""
# Re-open stdin.
if stdin_fileno is not None:
sys.stdin = os.fdopen(stdin_fileno)
sys.stdin = os.fdopen(stdin_fileno) # pragma: full coverage
# Logging needs to be setup again for each child.
config.configure_logging()
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
try:
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
except KeyboardInterrupt: # pragma: no cover
# supress the exception to avoid a traceback from subprocess.Popen
# the parent already expects us to end, so no vital information is lost
pass

View File

@@ -28,25 +28,12 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import annotations
import sys
import types
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
MutableMapping,
Optional,
Tuple,
Type,
Union,
)
if sys.version_info >= (3, 8): # pragma: py-lt-38
from typing import Literal, Protocol, TypedDict
else: # pragma: py-gte-38
from typing_extensions import Literal, Protocol, TypedDict
from collections.abc import Awaitable, Iterable, MutableMapping
from typing import Any, Callable, Literal, Optional, Protocol, TypedDict, Union
if sys.version_info >= (3, 11): # pragma: py-lt-311
from typing import NotRequired
@@ -55,15 +42,15 @@ else: # pragma: py-gte-311
# WSGI
Environ = MutableMapping[str, Any]
ExcInfo = Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]]
StartResponse = Callable[[str, Iterable[Tuple[str, str]], Optional[ExcInfo]], None]
ExcInfo = tuple[type[BaseException], BaseException, Optional[types.TracebackType]]
StartResponse = Callable[[str, Iterable[tuple[str, str]], Optional[ExcInfo]], None]
WSGIApp = Callable[[Environ, StartResponse], Union[Iterable[bytes], BaseException]]
# ASGI
class ASGIVersions(TypedDict):
spec_version: str
version: Union[Literal["2.0"], Literal["3.0"]]
version: Literal["2.0"] | Literal["3.0"]
class HTTPScope(TypedDict):
@@ -76,11 +63,11 @@ class HTTPScope(TypedDict):
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
state: NotRequired[Dict[str, Any]]
extensions: NotRequired[Dict[str, Dict[object, object]]]
headers: Iterable[tuple[bytes, bytes]]
client: tuple[str, int] | None
server: tuple[str, int | None] | None
state: NotRequired[dict[str, Any]]
extensions: NotRequired[dict[str, dict[object, object]]]
class WebSocketScope(TypedDict):
@@ -92,18 +79,18 @@ class WebSocketScope(TypedDict):
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
headers: Iterable[tuple[bytes, bytes]]
client: tuple[str, int] | None
server: tuple[str, int | None] | None
subprotocols: Iterable[str]
state: NotRequired[Dict[str, Any]]
extensions: NotRequired[Dict[str, Dict[object, object]]]
state: NotRequired[dict[str, Any]]
extensions: NotRequired[dict[str, dict[object, object]]]
class LifespanScope(TypedDict):
type: Literal["lifespan"]
asgi: ASGIVersions
state: NotRequired[Dict[str, Any]]
state: NotRequired[dict[str, Any]]
WWWScope = Union[HTTPScope, WebSocketScope]
@@ -118,32 +105,32 @@ class HTTPRequestEvent(TypedDict):
class HTTPResponseDebugEvent(TypedDict):
type: Literal["http.response.debug"]
info: Dict[str, object]
info: dict[str, object]
class HTTPResponseStartEvent(TypedDict):
type: Literal["http.response.start"]
status: int
headers: Iterable[Tuple[bytes, bytes]]
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
trailers: NotRequired[bool]
class HTTPResponseBodyEvent(TypedDict):
type: Literal["http.response.body"]
body: bytes
more_body: bool
more_body: NotRequired[bool]
class HTTPResponseTrailersEvent(TypedDict):
type: Literal["http.response.trailers"]
headers: Iterable[Tuple[bytes, bytes]]
headers: Iterable[tuple[bytes, bytes]]
more_trailers: bool
class HTTPServerPushEvent(TypedDict):
type: Literal["http.response.push"]
path: str
headers: Iterable[Tuple[bytes, bytes]]
headers: Iterable[tuple[bytes, bytes]]
class HTTPDisconnectEvent(TypedDict):
@@ -156,43 +143,62 @@ class WebSocketConnectEvent(TypedDict):
class WebSocketAcceptEvent(TypedDict):
type: Literal["websocket.accept"]
subprotocol: Optional[str]
headers: Iterable[Tuple[bytes, bytes]]
subprotocol: NotRequired[str | None]
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
class WebSocketReceiveEvent(TypedDict):
class _WebSocketReceiveEventBytes(TypedDict):
type: Literal["websocket.receive"]
bytes: Optional[bytes]
text: Optional[str]
bytes: bytes
text: NotRequired[None]
class WebSocketSendEvent(TypedDict):
class _WebSocketReceiveEventText(TypedDict):
type: Literal["websocket.receive"]
bytes: NotRequired[None]
text: str
WebSocketReceiveEvent = Union[_WebSocketReceiveEventBytes, _WebSocketReceiveEventText]
class _WebSocketSendEventBytes(TypedDict):
type: Literal["websocket.send"]
bytes: Optional[bytes]
text: Optional[str]
bytes: bytes
text: NotRequired[None]
class _WebSocketSendEventText(TypedDict):
type: Literal["websocket.send"]
bytes: NotRequired[None]
text: str
WebSocketSendEvent = Union[_WebSocketSendEventBytes, _WebSocketSendEventText]
class WebSocketResponseStartEvent(TypedDict):
type: Literal["websocket.http.response.start"]
status: int
headers: Iterable[Tuple[bytes, bytes]]
headers: Iterable[tuple[bytes, bytes]]
class WebSocketResponseBodyEvent(TypedDict):
type: Literal["websocket.http.response.body"]
body: bytes
more_body: bool
more_body: NotRequired[bool]
class WebSocketDisconnectEvent(TypedDict):
type: Literal["websocket.disconnect"]
code: int
reason: NotRequired[str | None]
class WebSocketCloseEvent(TypedDict):
type: Literal["websocket.close"]
code: int
reason: Optional[str]
code: NotRequired[int]
reason: NotRequired[str | None]
class LifespanStartupEvent(TypedDict):
@@ -221,9 +227,7 @@ class LifespanShutdownFailedEvent(TypedDict):
message: str
WebSocketEvent = Union[
WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent
]
WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent]
ASGIReceiveEvent = Union[
@@ -260,16 +264,12 @@ ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]
class ASGI2Protocol(Protocol):
def __init__(self, scope: Scope) -> None:
... # pragma: no cover
def __init__(self, scope: Scope) -> None: ... # pragma: no cover
async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
... # pragma: no cover
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover
ASGI2Application = Type[ASGI2Protocol]
ASGI2Application = type[ASGI2Protocol]
ASGI3Application = Callable[
[
Scope,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import inspect
import json
@@ -7,19 +9,10 @@ import os
import socket
import ssl
import sys
from collections.abc import Awaitable
from configparser import RawConfigParser
from pathlib import Path
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
from typing import IO, Any, Callable, Literal
import click
@@ -32,12 +25,12 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from uvicorn.middleware.wsgi import WSGIMiddleware
HTTPProtocolType = Literal["auto", "h11", "httptools"]
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
LOG_LEVELS: Dict[str, int] = {
LOG_LEVELS: dict[str, int] = {
"critical": logging.CRITICAL,
"error": logging.ERROR,
"warning": logging.WARNING,
@@ -45,33 +38,34 @@ LOG_LEVELS: Dict[str, int] = {
"debug": logging.DEBUG,
"trace": TRACE_LOG_LEVEL,
}
HTTP_PROTOCOLS: Dict[HTTPProtocolType, str] = {
HTTP_PROTOCOLS: dict[str, str] = {
"auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol",
"h11": "uvicorn.protocols.http.h11_impl:H11Protocol",
"httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol",
}
WS_PROTOCOLS: Dict[WSProtocolType, Optional[str]] = {
WS_PROTOCOLS: dict[str, str | None] = {
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
"none": None,
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
}
LIFESPAN: Dict[LifespanType, str] = {
LIFESPAN: dict[str, str] = {
"auto": "uvicorn.lifespan.on:LifespanOn",
"on": "uvicorn.lifespan.on:LifespanOn",
"off": "uvicorn.lifespan.off:LifespanOff",
}
LOOP_SETUPS: Dict[LoopSetupType, Optional[str]] = {
LOOP_FACTORIES: dict[str, str | None] = {
"none": None,
"auto": "uvicorn.loops.auto:auto_loop_setup",
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
"auto": "uvicorn.loops.auto:auto_loop_factory",
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
}
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER
LOGGING_CONFIG: Dict[str, Any] = {
LOGGING_CONFIG: dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
@@ -108,13 +102,13 @@ logger = logging.getLogger("uvicorn.error")
def create_ssl_context(
certfile: Union[str, os.PathLike],
keyfile: Optional[Union[str, os.PathLike]],
password: Optional[str],
certfile: str | os.PathLike[str],
keyfile: str | os.PathLike[str] | None,
password: str | None,
ssl_version: int,
cert_reqs: int,
ca_certs: Optional[Union[str, os.PathLike]],
ciphers: Optional[str],
ca_certs: str | os.PathLike[str] | None,
ciphers: str | None,
) -> ssl.SSLContext:
ctx = ssl.SSLContext(ssl_version)
get_password = (lambda: password) if password else None
@@ -132,22 +126,20 @@ def is_dir(path: Path) -> bool:
if not path.is_absolute():
path = path.resolve()
return path.is_dir()
except OSError:
except OSError: # pragma: full coverage
return False
def resolve_reload_patterns(
patterns_list: List[str], directories_list: List[str]
) -> Tuple[List[str], List[Path]]:
directories: List[Path] = list(set(map(Path, directories_list.copy())))
patterns: List[str] = patterns_list.copy()
def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
directories: list[Path] = list(set(map(Path, directories_list.copy())))
patterns: list[str] = patterns_list.copy()
current_working_directory = Path.cwd()
for pattern in patterns_list:
# Special case for the .* pattern, otherwise this would only match
# hidden directories which is probably undesired
if pattern == ".*":
continue
continue # pragma: py-not-linux
patterns.append(pattern)
if is_dir(Path(pattern)):
directories.append(Path(pattern))
@@ -159,15 +151,13 @@ def resolve_reload_patterns(
directories = list(set(directories))
directories = list(map(Path, directories))
directories = list(map(lambda x: x.resolve(), directories))
directories = list(
{reload_path for reload_path in directories if is_dir(reload_path)}
)
directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
children = []
for j in range(len(directories)):
for k in range(j + 1, len(directories)):
for k in range(j + 1, len(directories)): # pragma: full coverage
if directories[j] in directories[k].parents:
children.append(directories[k]) # pragma: py-darwin
children.append(directories[k])
elif directories[k] in directories[j].parents:
children.append(directories[j])
@@ -176,7 +166,7 @@ def resolve_reload_patterns(
return list(set(patterns)), directories
def _normalize_dirs(dirs: Union[List[str], str, None]) -> List[str]:
def _normalize_dirs(dirs: list[str] | str | None) -> list[str]:
if dirs is None:
return []
if isinstance(dirs, str):
@@ -187,54 +177,55 @@ def _normalize_dirs(dirs: Union[List[str], str, None]) -> List[str]:
class Config:
def __init__(
self,
app: Union["ASGIApplication", Callable, str],
app: ASGIApplication | Callable[..., Any] | str,
host: str = "127.0.0.1",
port: int = 8000,
uds: Optional[str] = None,
fd: Optional[int] = None,
loop: LoopSetupType = "auto",
http: Union[Type[asyncio.Protocol], HTTPProtocolType] = "auto",
ws: Union[Type[asyncio.Protocol], WSProtocolType] = "auto",
uds: str | None = None,
fd: int | None = None,
loop: LoopFactoryType | str = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType | str = "auto",
ws: type[asyncio.Protocol] | WSProtocolType | str = "auto",
ws_max_size: int = 16 * 1024 * 1024,
ws_max_queue: int = 32,
ws_ping_interval: Optional[float] = 20.0,
ws_ping_timeout: Optional[float] = 20.0,
ws_ping_interval: float | None = 20.0,
ws_ping_timeout: float | None = 20.0,
ws_per_message_deflate: bool = True,
lifespan: LifespanType = "auto",
env_file: Optional[Union[str, os.PathLike]] = None,
log_config: Optional[Union[Dict[str, Any], str]] = LOGGING_CONFIG,
log_level: Optional[Union[str, int]] = None,
env_file: str | os.PathLike[str] | None = None,
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
log_level: str | int | None = None,
access_log: bool = True,
use_colors: Optional[bool] = None,
use_colors: bool | None = None,
interface: InterfaceType = "auto",
reload: bool = False,
reload_dirs: Optional[Union[List[str], str]] = None,
reload_dirs: list[str] | str | None = None,
reload_delay: float = 0.25,
reload_includes: Optional[Union[List[str], str]] = None,
reload_excludes: Optional[Union[List[str], str]] = None,
workers: Optional[int] = None,
reload_includes: list[str] | str | None = None,
reload_excludes: list[str] | str | None = None,
workers: int | None = None,
proxy_headers: bool = True,
server_header: bool = True,
date_header: bool = True,
forwarded_allow_ips: Optional[Union[List[str], str]] = None,
forwarded_allow_ips: list[str] | str | None = None,
root_path: str = "",
limit_concurrency: Optional[int] = None,
limit_max_requests: Optional[int] = None,
limit_concurrency: int | None = None,
limit_max_requests: int | None = None,
backlog: int = 2048,
timeout_keep_alive: int = 5,
timeout_notify: int = 30,
timeout_graceful_shutdown: Optional[int] = None,
callback_notify: Optional[Callable[..., Awaitable[None]]] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[Union[str, os.PathLike]] = None,
ssl_keyfile_password: Optional[str] = None,
timeout_graceful_shutdown: int | None = None,
timeout_worker_healthcheck: int = 5,
callback_notify: Callable[..., Awaitable[None]] | None = None,
ssl_keyfile: str | os.PathLike[str] | None = None,
ssl_certfile: str | os.PathLike[str] | None = None,
ssl_keyfile_password: str | None = None,
ssl_version: int = SSL_PROTOCOL_VERSION,
ssl_cert_reqs: int = ssl.CERT_NONE,
ssl_ca_certs: Optional[str] = None,
ssl_ca_certs: str | os.PathLike[str] | None = None,
ssl_ciphers: str = "TLSv1",
headers: Optional[List[Tuple[str, str]]] = None,
headers: list[tuple[str, str]] | None = None,
factory: bool = False,
h11_max_incomplete_event_size: Optional[int] = None,
h11_max_incomplete_event_size: int | None = None,
):
self.app = app
self.host = host
@@ -268,6 +259,7 @@ class Config:
self.timeout_keep_alive = timeout_keep_alive
self.timeout_notify = timeout_notify
self.timeout_graceful_shutdown = timeout_graceful_shutdown
self.timeout_worker_healthcheck = timeout_worker_healthcheck
self.callback_notify = callback_notify
self.ssl_keyfile = ssl_keyfile
self.ssl_certfile = ssl_certfile
@@ -276,25 +268,22 @@ class Config:
self.ssl_cert_reqs = ssl_cert_reqs
self.ssl_ca_certs = ssl_ca_certs
self.ssl_ciphers = ssl_ciphers
self.headers: List[Tuple[str, str]] = headers or []
self.encoded_headers: List[Tuple[bytes, bytes]] = []
self.headers: list[tuple[str, str]] = headers or []
self.encoded_headers: list[tuple[bytes, bytes]] = []
self.factory = factory
self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
self.loaded = False
self.configure_logging()
self.reload_dirs: List[Path] = []
self.reload_dirs_excludes: List[Path] = []
self.reload_includes: List[str] = []
self.reload_excludes: List[str] = []
self.reload_dirs: list[Path] = []
self.reload_dirs_excludes: list[Path] = []
self.reload_includes: list[str] = []
self.reload_excludes: list[str] = []
if (
reload_dirs or reload_includes or reload_excludes
) and not self.should_reload:
if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
logger.warning(
"Current configuration will not reload as not all conditions are met, "
"please refer to documentation."
"Current configuration will not reload as not all conditions are met, please refer to documentation."
)
if self.should_reload:
@@ -302,30 +291,23 @@ class Config:
reload_includes = _normalize_dirs(reload_includes)
reload_excludes = _normalize_dirs(reload_excludes)
self.reload_includes, self.reload_dirs = resolve_reload_patterns(
reload_includes, reload_dirs
)
self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(
reload_excludes, []
)
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
reload_dirs_tmp = self.reload_dirs.copy()
for directory in self.reload_dirs_excludes:
for reload_directory in reload_dirs_tmp:
if (
directory == reload_directory
or directory in reload_directory.parents
):
if directory == reload_directory or directory in reload_directory.parents:
try:
self.reload_dirs.remove(reload_directory)
except ValueError:
except ValueError: # pragma: full coverage
pass
for pattern in self.reload_excludes:
if pattern in self.reload_includes:
self.reload_includes.remove(pattern)
self.reload_includes.remove(pattern) # pragma: full coverage
if not self.reload_dirs:
if reload_dirs:
@@ -334,7 +316,7 @@ class Config:
+ "directories, watching current working directory.",
reload_dirs,
)
self.reload_dirs = [Path(os.getcwd())]
self.reload_dirs = [Path.cwd()]
logger.info(
"Will watch for changes in these directories: %s",
@@ -350,20 +332,18 @@ class Config:
if workers is None and "WEB_CONCURRENCY" in os.environ:
self.workers = int(os.environ["WEB_CONCURRENCY"])
self.forwarded_allow_ips: Union[List[str], str]
self.forwarded_allow_ips: list[str] | str
if forwarded_allow_ips is None:
self.forwarded_allow_ips = os.environ.get(
"FORWARDED_ALLOW_IPS", "127.0.0.1"
)
self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
else:
self.forwarded_allow_ips = forwarded_allow_ips
self.forwarded_allow_ips = forwarded_allow_ips # pragma: full coverage
if self.reload and self.workers > 1:
logger.warning('"workers" flag is ignored when reloading is enabled.')
@property
def asgi_version(self) -> Literal["2.0", "3.0"]:
mapping: Dict[str, Literal["2.0", "3.0"]] = {
mapping: dict[str, Literal["2.0", "3.0"]] = {
"asgi2": "2.0",
"asgi3": "3.0",
"wsgi": "3.0",
@@ -384,18 +364,14 @@ class Config:
if self.log_config is not None:
if isinstance(self.log_config, dict):
if self.use_colors in (True, False):
self.log_config["formatters"]["default"][
"use_colors"
] = self.use_colors
self.log_config["formatters"]["access"][
"use_colors"
] = self.use_colors
self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
logging.config.dictConfig(self.log_config)
elif self.log_config.endswith(".json"):
elif isinstance(self.log_config, str) and self.log_config.endswith(".json"):
with open(self.log_config) as file:
loaded_config = json.load(file)
logging.config.dictConfig(loaded_config)
elif self.log_config.endswith((".yaml", ".yml")):
elif isinstance(self.log_config, str) and self.log_config.endswith((".yaml", ".yml")):
# Install the PyYAML package or the uvicorn[standard] optional
# dependencies to enable this functionality.
import yaml
@@ -406,9 +382,7 @@ class Config:
else:
# See the note about fileConfig() here:
# https://docs.python.org/3/library/logging.config.html#configuration-file-format
logging.config.fileConfig(
self.log_config, disable_existing_loggers=False
)
logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
if self.log_level is not None:
if isinstance(self.log_level, str):
@@ -427,7 +401,7 @@ class Config:
if self.is_ssl:
assert self.ssl_certfile
self.ssl: Optional[ssl.SSLContext] = create_ssl_context(
self.ssl: ssl.SSLContext | None = create_ssl_context(
keyfile=self.ssl_keyfile,
certfile=self.ssl_certfile,
password=self.ssl_keyfile_password,
@@ -439,10 +413,7 @@ class Config:
else:
self.ssl = None
encoded_headers = [
(key.lower().encode("latin1"), value.encode("latin1"))
for key, value in self.headers
]
encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
self.encoded_headers = (
[(b"server", b"uvicorn")] + encoded_headers
if b"server" not in dict(encoded_headers) and self.server_header
@@ -450,14 +421,14 @@ class Config:
)
if isinstance(self.http, str):
http_protocol_class = import_from_string(HTTP_PROTOCOLS[self.http])
self.http_protocol_class: Type[asyncio.Protocol] = http_protocol_class
http_protocol_class = import_from_string(HTTP_PROTOCOLS.get(self.http, self.http))
self.http_protocol_class: type[asyncio.Protocol] = http_protocol_class
else:
self.http_protocol_class = self.http
if isinstance(self.ws, str):
ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws])
self.ws_protocol_class: Optional[Type[asyncio.Protocol]] = ws_protocol_class
ws_protocol_class = import_from_string(WS_PROTOCOLS.get(self.ws, self.ws))
self.ws_protocol_class: type[asyncio.Protocol] | None = ws_protocol_class
else:
self.ws_protocol_class = self.ws
@@ -478,18 +449,17 @@ class Config:
else:
if not self.factory:
logger.warning(
"ASGI app factory detected. Using it, "
"but please consider setting the --factory flag explicitly."
"ASGI app factory detected. Using it, but please consider setting the --factory flag explicitly."
)
if self.interface == "auto":
if inspect.isclass(self.loaded_app):
use_asgi_3 = hasattr(self.loaded_app, "__await__")
elif inspect.isfunction(self.loaded_app):
use_asgi_3 = asyncio.iscoroutinefunction(self.loaded_app)
use_asgi_3 = inspect.iscoroutinefunction(self.loaded_app)
else:
call = getattr(self.loaded_app, "__call__", None)
use_asgi_3 = asyncio.iscoroutinefunction(call)
use_asgi_3 = inspect.iscoroutinefunction(call)
self.interface = "asgi3" if use_asgi_3 else "asgi2"
if self.interface == "wsgi":
@@ -501,19 +471,32 @@ class Config:
if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
if self.proxy_headers:
self.loaded_app = ProxyHeadersMiddleware(
self.loaded_app, trusted_hosts=self.forwarded_allow_ips
)
self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
self.loaded = True
def setup_event_loop(self) -> None:
loop_setup: Optional[Callable] = import_from_string(LOOP_SETUPS[self.loop])
if loop_setup is not None:
loop_setup(use_subprocess=self.use_subprocess)
raise AttributeError(
"The `setup_event_loop` method was replaced by `get_loop_factory` in uvicorn 0.36.0.\n"
"None of those methods are supposed to be used directly. If you are doing it, please let me know here: "
"https://github.com/Kludex/uvicorn/discussions/2706. Thank you, and sorry for the inconvenience."
)
def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
if self.loop in LOOP_FACTORIES:
loop_factory: Callable[..., Any] | None = import_from_string(LOOP_FACTORIES[self.loop])
else:
try:
return import_from_string(self.loop)
except ImportFromStringError as exc:
logger.error("Error loading custom loop setup function. %s" % exc)
sys.exit(1)
if loop_factory is None:
return None
return loop_factory(use_subprocess=self.use_subprocess)
def bind_socket(self) -> socket.socket:
logger_args: List[Union[str, int]]
logger_args: list[str | int]
if self.uds: # pragma: py-win32
path = self.uds
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -521,33 +504,25 @@ class Config:
sock.bind(path)
uds_perms = 0o666
os.chmod(self.uds, uds_perms)
except OSError as exc:
except OSError as exc: # pragma: full coverage
logger.error(exc)
sys.exit(1)
message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
sock_name_format = "%s"
color_message = (
"Uvicorn running on "
+ click.style(sock_name_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
logger_args = [self.uds]
elif self.fd: # pragma: py-win32
sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
fd_name_format = "%s"
color_message = (
"Uvicorn running on "
+ click.style(fd_name_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
logger_args = [sock.getsockname()]
else:
family = socket.AF_INET
addr_format = "%s://%s:%d"
if self.host and ":" in self.host: # pragma: py-win32
if self.host and ":" in self.host: # pragma: full coverage
# It's an IPv6 address.
family = socket.AF_INET6
addr_format = "%s://[%s]:%d"
@@ -556,16 +531,12 @@ class Config:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind((self.host, self.port))
except OSError as exc:
except OSError as exc: # pragma: full coverage
logger.error(exc)
sys.exit(1)
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
color_message = (
"Uvicorn running on "
+ click.style(addr_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
protocol_name = "https" if self.is_ssl else "http"
logger_args = [protocol_name, self.host, sock.getsockname()[1]]
logger.info(message, *logger_args, extra={"color_message": color_message})

View File

@@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any:
module_str, _, attrs_str = import_str.partition(":")
if not module_str or not attrs_str:
message = (
'Import string "{import_str}" must be in format "<module>:<attribute>".'
)
message = 'Import string "{import_str}" must be in format "<module>:<attribute>".'
raise ImportFromStringError(message.format(import_str=import_str))
try:
@@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any:
instance = getattr(instance, attr_str)
except AttributeError:
message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
raise ImportFromStringError(
message.format(attrs_str=attrs_str, module_str=module_str)
)
raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str))
return instance

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict
from __future__ import annotations
from typing import Any
from uvicorn import Config
@@ -6,7 +8,7 @@ from uvicorn import Config
class LifespanOff:
def __init__(self, config: Config) -> None:
self.should_exit = False
self.state: Dict[str, Any] = {}
self.state: dict[str, Any] = {}
async def startup(self) -> None:
pass

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import logging
from asyncio import Queue
from typing import Any, Dict, Union
from typing import Any, Union
from uvicorn import Config
from uvicorn._types import (
@@ -35,12 +37,12 @@ class LifespanOn:
self.logger = logging.getLogger("uvicorn.error")
self.startup_event = asyncio.Event()
self.shutdown_event = asyncio.Event()
self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue()
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
self.error_occured = False
self.startup_failed = False
self.shutdown_failed = False
self.should_exit = False
self.state: Dict[str, Any] = {}
self.state: dict[str, Any] = {}
async def startup(self) -> None:
self.logger.info("Waiting for application startup.")
@@ -48,7 +50,7 @@ class LifespanOn:
loop = asyncio.get_event_loop()
main_lifespan_task = loop.create_task(self.main()) # noqa: F841
# Keep a hard reference to prevent garbage collection
# See https://github.com/encode/uvicorn/pull/972
# See https://github.com/Kludex/uvicorn/pull/972
startup_event: LifespanStartupEvent = {"type": "lifespan.startup"}
await self.receive_queue.put(startup_event)
await self.startup_event.wait()
@@ -67,9 +69,7 @@ class LifespanOn:
await self.receive_queue.put(shutdown_event)
await self.shutdown_event.wait()
if self.shutdown_failed or (
self.error_occured and self.config.lifespan == "on"
):
if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
self.logger.error("Application shutdown failed. Exiting.")
self.should_exit = True
else:
@@ -99,7 +99,7 @@ class LifespanOn:
self.startup_event.set()
self.shutdown_event.set()
async def send(self, message: "LifespanSendMessage") -> None:
async def send(self, message: LifespanSendMessage) -> None:
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
@@ -133,5 +133,5 @@ class LifespanOn:
if message.get("message"):
self.logger.error(message["message"])
async def receive(self) -> "LifespanReceiveMessage":
async def receive(self) -> LifespanReceiveMessage:
return await self.receive_queue.get()

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import http
import logging
import sys
from copy import copy
from typing import Literal, Optional
from typing import Literal
import click
@@ -14,7 +16,7 @@ class ColourizedFormatter(logging.Formatter):
A custom log formatter class that:
* Outputs the LOG_LEVEL with an appropriate color.
* If a log call includes an `extras={"color_message": ...}` it will be used
* If a log call includes an `extra={"color_message": ...}` it will be used
for formatting the output, instead of the plain text message.
"""
@@ -24,17 +26,15 @@ class ColourizedFormatter(logging.Formatter):
logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
logging.CRITICAL: lambda level_name: click.style(
str(level_name), fg="bright_red"
),
logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
}
def __init__(
self,
fmt: Optional[str] = None,
datefmt: Optional[str] = None,
fmt: str | None = None,
datefmt: str | None = None,
style: Literal["%", "{", "$"] = "%",
use_colors: Optional[bool] = None,
use_colors: bool | None = None,
):
if use_colors in (True, False):
self.use_colors = use_colors
@@ -84,7 +84,7 @@ class AccessFormatter(ColourizedFormatter):
status_phrase = http.HTTPStatus(status_code).phrase
except ValueError:
status_phrase = ""
status_and_phrase = "%s %s" % (status_code, status_phrase)
status_and_phrase = f"{status_code} {status_phrase}"
if self.use_colors:
def default(code: int) -> str:
@@ -104,7 +104,7 @@ class AccessFormatter(ColourizedFormatter):
status_code,
) = recordcopy.args # type: ignore[misc]
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
request_line = "%s %s HTTP/%s" % (method, full_path, http_version)
request_line = f"{method} {full_path} HTTP/{http_version}"
if self.use_colors:
request_line = click.style(request_line, bold=True)
recordcopy.__dict__.update(

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
import asyncio
import logging
import sys
logger = logging.getLogger("uvicorn.error")
from collections.abc import Callable
def asyncio_setup(use_subprocess: bool = False) -> None:
if sys.platform == "win32" and use_subprocess:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
if sys.platform == "win32" and not use_subprocess:
return asyncio.ProactorEventLoop
return asyncio.SelectorEventLoop

View File

@@ -1,11 +1,17 @@
def auto_loop_setup(use_subprocess: bool = False) -> None:
from __future__ import annotations
import asyncio
from collections.abc import Callable
def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
try:
import uvloop # noqa
except ImportError: # pragma: no cover
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory
loop_setup(use_subprocess=use_subprocess)
return loop_factory(use_subprocess=use_subprocess)
else: # pragma: no cover
from uvicorn.loops.uvloop import uvloop_setup
from uvicorn.loops.uvloop import uvloop_loop_factory
uvloop_setup(use_subprocess=use_subprocess)
return uvloop_loop_factory(use_subprocess=use_subprocess)

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import uvloop
def uvloop_setup(use_subprocess: bool = False) -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
return uvloop.new_event_loop

View File

@@ -1,41 +1,44 @@
from __future__ import annotations
import asyncio
import logging
import os
import platform
import ssl
import sys
import typing
import warnings
from configparser import RawConfigParser
from typing import IO, Any, Callable, get_args
import click
import uvicorn
from uvicorn._types import ASGIApplication
from uvicorn.config import (
HTTP_PROTOCOLS,
INTERFACES,
LIFESPAN,
LOG_LEVELS,
LOGGING_CONFIG,
LOOP_SETUPS,
SSL_PROTOCOL_VERSION,
WS_PROTOCOLS,
Config,
HTTPProtocolType,
InterfaceType,
LifespanType,
LoopSetupType,
LoopFactoryType,
WSProtocolType,
)
from uvicorn.server import Server, ServerState # noqa: F401 # Used to be defined here.
from uvicorn.server import Server
from uvicorn.supervisors import ChangeReload, Multiprocess
LEVEL_CHOICES = click.Choice(list(LOG_LEVELS.keys()))
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
INTERFACE_CHOICES = click.Choice(INTERFACES)
def _metavar_from_type(_type: Any) -> str:
return f"[{'|'.join(key for key in get_args(_type) if key != 'none')}]"
STARTUP_FAILURE = 3
logger = logging.getLogger("uvicorn.error")
@@ -45,12 +48,11 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
if not value or ctx.resilient_parsing:
return
click.echo(
"Running uvicorn %s with %s %s on %s"
% (
uvicorn.__version__,
platform.python_implementation(),
platform.python_version(),
platform.system(),
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( # noqa: UP032
version=uvicorn.__version__,
py_implementation=platform.python_implementation(),
py_version=platform.python_version(),
system=platform.system(),
)
)
ctx.exit()
@@ -73,16 +75,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
show_default=True,
)
@click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.")
@click.option(
"--fd", type=int, default=None, help="Bind to socket from this file descriptor."
)
@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.")
@click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.")
@click.option(
"--reload-dir",
"reload_dirs",
multiple=True,
help="Set reload directories explicitly, instead of using the current working"
" directory.",
help="Set reload directories explicitly, instead of using the current working directory.",
type=click.Path(exists=True),
)
@click.option(
@@ -107,8 +106,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
type=float,
default=0.25,
show_default=True,
help="Delay between previous and next check if application needs to be."
" Defaults to 0.25s.",
help="Delay between previous and next check if application needs to be. Defaults to 0.25s.",
)
@click.option(
"--workers",
@@ -119,21 +117,24 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
)
@click.option(
"--loop",
type=LOOP_CHOICES,
type=str,
metavar=_metavar_from_type(LoopFactoryType),
default="auto",
help="Event loop implementation.",
help="Event loop factory implementation.",
show_default=True,
)
@click.option(
"--http",
type=HTTP_CHOICES,
type=str,
metavar=_metavar_from_type(HTTPProtocolType),
default="auto",
help="HTTP protocol implementation.",
show_default=True,
)
@click.option(
"--ws",
type=WS_CHOICES,
type=str,
metavar=_metavar_from_type(WSProtocolType),
default="auto",
help="WebSocket protocol implementation.",
show_default=True,
@@ -156,14 +157,14 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--ws-ping-interval",
type=float,
default=20.0,
help="WebSocket ping interval",
help="WebSocket ping interval in seconds.",
show_default=True,
)
@click.option(
"--ws-ping-timeout",
type=float,
default=20.0,
help="WebSocket ping timeout",
help="WebSocket ping timeout in seconds.",
show_default=True,
)
@click.option(
@@ -224,8 +225,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--proxy-headers/--no-proxy-headers",
is_flag=True,
default=True,
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to "
"populate remote address info.",
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For to populate url scheme and remote address info.",
)
@click.option(
"--server-header/--no-server-header",
@@ -243,8 +243,10 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--forwarded-allow-ips",
type=str,
default=None,
help="Comma separated list of IPs to trust with proxy headers. Defaults to"
" the $FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'.",
help="Comma separated list of IP Addresses, IP Networks, or literals "
"(e.g. UNIX Socket path) to trust with proxy headers. Defaults to the "
"$FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'. "
"The literal '*' means trust everything.",
)
@click.option(
"--root-path",
@@ -256,8 +258,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--limit-concurrency",
type=int,
default=None,
help="Maximum number of concurrent connections or tasks to allow, before issuing"
" HTTP 503 responses.",
help="Maximum number of concurrent connections or tasks to allow, before issuing HTTP 503 responses.",
)
@click.option(
"--backlog",
@@ -275,7 +276,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--timeout-keep-alive",
type=int,
default=5,
help="Close Keep-Alive connections if no new data is received within this timeout.",
help="Close Keep-Alive connections if no new data is received within this timeout (in seconds).",
show_default=True,
)
@click.option(
@@ -285,8 +286,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
help="Maximum number of seconds to wait for graceful shutdown.",
)
@click.option(
"--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True
"--timeout-worker-healthcheck",
type=int,
default=5,
help="Maximum number of seconds to wait for a worker to respond to a healthcheck.",
show_default=True,
)
@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True)
@click.option(
"--ssl-certfile",
type=str,
@@ -370,9 +376,9 @@ def main(
port: int,
uds: str,
fd: int,
loop: LoopSetupType,
http: HTTPProtocolType,
ws: WSProtocolType,
loop: LoopFactoryType | str,
http: HTTPProtocolType | str,
ws: WSProtocolType | str,
ws_max_size: int,
ws_max_queue: int,
ws_ping_interval: float,
@@ -381,9 +387,9 @@ def main(
lifespan: LifespanType,
interface: InterfaceType,
reload: bool,
reload_dirs: typing.List[str],
reload_includes: typing.List[str],
reload_excludes: typing.List[str],
reload_dirs: list[str],
reload_includes: list[str],
reload_excludes: list[str],
reload_delay: float,
workers: int,
env_file: str,
@@ -399,7 +405,8 @@ def main(
backlog: int,
limit_max_requests: int,
timeout_keep_alive: int,
timeout_graceful_shutdown: typing.Optional[int],
timeout_graceful_shutdown: int | None,
timeout_worker_healthcheck: int,
ssl_keyfile: str,
ssl_certfile: str,
ssl_keyfile_password: str,
@@ -407,10 +414,10 @@ def main(
ssl_cert_reqs: int,
ssl_ca_certs: str,
ssl_ciphers: str,
headers: typing.List[str],
headers: list[str],
use_colors: bool,
app_dir: str,
h11_max_incomplete_event_size: typing.Optional[int],
h11_max_incomplete_event_size: int | None,
factory: bool,
) -> None:
run(
@@ -449,6 +456,7 @@ def main(
limit_max_requests=limit_max_requests,
timeout_keep_alive=timeout_keep_alive,
timeout_graceful_shutdown=timeout_graceful_shutdown,
timeout_worker_healthcheck=timeout_worker_healthcheck,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
@@ -465,56 +473,55 @@ def main(
def run(
app: typing.Union["ASGIApplication", typing.Callable, str],
app: ASGIApplication | Callable[..., Any] | str,
*,
host: str = "127.0.0.1",
port: int = 8000,
uds: typing.Optional[str] = None,
fd: typing.Optional[int] = None,
loop: LoopSetupType = "auto",
http: typing.Union[typing.Type[asyncio.Protocol], HTTPProtocolType] = "auto",
ws: typing.Union[typing.Type[asyncio.Protocol], WSProtocolType] = "auto",
uds: str | None = None,
fd: int | None = None,
loop: LoopFactoryType | str = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType | str = "auto",
ws: type[asyncio.Protocol] | WSProtocolType | str = "auto",
ws_max_size: int = 16777216,
ws_max_queue: int = 32,
ws_ping_interval: typing.Optional[float] = 20.0,
ws_ping_timeout: typing.Optional[float] = 20.0,
ws_ping_interval: float | None = 20.0,
ws_ping_timeout: float | None = 20.0,
ws_per_message_deflate: bool = True,
lifespan: LifespanType = "auto",
interface: InterfaceType = "auto",
reload: bool = False,
reload_dirs: typing.Optional[typing.Union[typing.List[str], str]] = None,
reload_includes: typing.Optional[typing.Union[typing.List[str], str]] = None,
reload_excludes: typing.Optional[typing.Union[typing.List[str], str]] = None,
reload_dirs: list[str] | str | None = None,
reload_includes: list[str] | str | None = None,
reload_excludes: list[str] | str | None = None,
reload_delay: float = 0.25,
workers: typing.Optional[int] = None,
env_file: typing.Optional[typing.Union[str, os.PathLike]] = None,
log_config: typing.Optional[
typing.Union[typing.Dict[str, typing.Any], str]
] = LOGGING_CONFIG,
log_level: typing.Optional[typing.Union[str, int]] = None,
workers: int | None = None,
env_file: str | os.PathLike[str] | None = None,
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
log_level: str | int | None = None,
access_log: bool = True,
proxy_headers: bool = True,
server_header: bool = True,
date_header: bool = True,
forwarded_allow_ips: typing.Optional[typing.Union[typing.List[str], str]] = None,
forwarded_allow_ips: list[str] | str | None = None,
root_path: str = "",
limit_concurrency: typing.Optional[int] = None,
limit_concurrency: int | None = None,
backlog: int = 2048,
limit_max_requests: typing.Optional[int] = None,
limit_max_requests: int | None = None,
timeout_keep_alive: int = 5,
timeout_graceful_shutdown: typing.Optional[int] = None,
ssl_keyfile: typing.Optional[str] = None,
ssl_certfile: typing.Optional[typing.Union[str, os.PathLike]] = None,
ssl_keyfile_password: typing.Optional[str] = None,
timeout_graceful_shutdown: int | None = None,
timeout_worker_healthcheck: int = 5,
ssl_keyfile: str | os.PathLike[str] | None = None,
ssl_certfile: str | os.PathLike[str] | None = None,
ssl_keyfile_password: str | None = None,
ssl_version: int = SSL_PROTOCOL_VERSION,
ssl_cert_reqs: int = ssl.CERT_NONE,
ssl_ca_certs: typing.Optional[str] = None,
ssl_ca_certs: str | os.PathLike[str] | None = None,
ssl_ciphers: str = "TLSv1",
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
use_colors: typing.Optional[bool] = None,
app_dir: typing.Optional[str] = None,
headers: list[tuple[str, str]] | None = None,
use_colors: bool | None = None,
app_dir: str | None = None,
factory: bool = False,
h11_max_incomplete_event_size: typing.Optional[int] = None,
h11_max_incomplete_event_size: int | None = None,
) -> None:
if app_dir is not None:
sys.path.insert(0, app_dir)
@@ -555,6 +562,7 @@ def run(
limit_max_requests=limit_max_requests,
timeout_keep_alive=timeout_keep_alive,
timeout_graceful_shutdown=timeout_graceful_shutdown,
timeout_worker_healthcheck=timeout_worker_healthcheck,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
@@ -571,26 +579,39 @@ def run(
if (config.reload or config.workers > 1) and not isinstance(app, str):
logger = logging.getLogger("uvicorn.error")
logger.warning(
"You must pass the application as an import string to enable 'reload' or "
"'workers'."
)
logger.warning("You must pass the application as an import string to enable 'reload' or 'workers'.")
sys.exit(1)
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
if config.uds and os.path.exists(config.uds):
os.remove(config.uds) # pragma: py-win32
try:
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
except KeyboardInterrupt:
pass # pragma: full coverage
finally:
if config.uds and os.path.exists(config.uds):
os.remove(config.uds) # pragma: py-win32
if not server.started and not config.should_reload and config.workers == 1:
sys.exit(STARTUP_FAILURE)
def __getattr__(name: str) -> Any:
if name == "ServerState":
warnings.warn(
"uvicorn.main.ServerState is deprecated, use uvicorn.server.ServerState instead.",
DeprecationWarning,
)
from uvicorn.server import ServerState
return ServerState
raise AttributeError(f"module {__name__} has no attribute {name}")
if __name__ == "__main__":
main() # pragma: no cover

View File

@@ -10,8 +10,6 @@ class ASGI2Middleware:
def __init__(self, app: "ASGI2Application"):
self.app = app
async def __call__(
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
instance = self.app(scope)
await instance(receive, send)

View File

@@ -1,84 +1,142 @@
"""
This middleware can be used when a known proxy is fronting the application,
and is trusted to be properly setting the `X-Forwarded-Proto` and
`X-Forwarded-For` headers with the connecting client information.
from __future__ import annotations
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather that the connecting proxy.
import ipaddress
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
"""
from typing import List, Optional, Tuple, Union, cast
from uvicorn._types import (
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
HTTPScope,
Scope,
WebSocketScope,
)
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
class ProxyHeadersMiddleware:
def __init__(
self,
app: "ASGI3Application",
trusted_hosts: Union[List[str], str] = "127.0.0.1",
) -> None:
"""Middleware for handling known proxy headers
This middleware can be used when a known proxy is fronting the application,
and is trusted to be properly setting the `X-Forwarded-Proto` and
`X-Forwarded-For` headers with the connecting client information.
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather that the connecting proxy.
References:
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
"""
def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
self.app = app
if isinstance(trusted_hosts, str):
self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
else:
self.trusted_hosts = set(trusted_hosts)
self.always_trust = "*" in self.trusted_hosts
self.trusted_hosts = _TrustedHosts(trusted_hosts)
def get_trusted_client_host(
self, x_forwarded_for_hosts: List[str]
) -> Optional[str]:
if self.always_trust:
return x_forwarded_for_hosts[0]
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
if scope["type"] == "lifespan":
return await self.app(scope, receive, send)
for host in reversed(x_forwarded_for_hosts):
if host not in self.trusted_hosts:
return host
client_addr = scope.get("client")
client_host = client_addr[0] if client_addr else None
return None
if client_host in self.trusted_hosts:
headers = dict(scope["headers"])
async def __call__(
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
if scope["type"] in ("http", "websocket"):
scope = cast(Union["HTTPScope", "WebSocketScope"], scope)
client_addr: Optional[Tuple[str, int]] = scope.get("client")
client_host = client_addr[0] if client_addr else None
if b"x-forwarded-proto" in headers:
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
if self.always_trust or client_host in self.trusted_hosts:
headers = dict(scope["headers"])
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = (
headers[b"x-forwarded-proto"].decode("latin1").strip()
)
if x_forwarded_proto in {"http", "https", "ws", "wss"}:
if scope["type"] == "websocket":
scope["scheme"] = (
"wss" if x_forwarded_proto == "https" else "ws"
)
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
else:
scope["scheme"] = x_forwarded_proto
if b"x-forwarded-for" in headers:
# Determine the client address from the last trusted IP in the
# X-Forwarded-For header. We've lost the connecting client's port
# information by now, so only include the host.
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
x_forwarded_for_hosts = [
item.strip() for item in x_forwarded_for.split(",")
]
host = self.get_trusted_client_host(x_forwarded_for_hosts)
if b"x-forwarded-for" in headers:
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
if host:
# If the x-forwarded-for header is empty then host is an empty string.
# Only set the client if we actually got something usable.
# See: https://github.com/Kludex/uvicorn/issues/1068
# We've lost the connecting client's port information by now,
# so only include the host.
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]
scope["client"] = (host, port)
return await self.app(scope, receive, send)
def _parse_raw_hosts(value: str) -> list[str]:
return [item.strip() for item in value.split(",")]
class _TrustedHosts:
"""Container for trusted hosts and networks"""
def __init__(self, trusted_hosts: list[str] | str) -> None:
self.always_trust: bool = trusted_hosts in ("*", ["*"])
self.trusted_literals: set[str] = set()
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
# Notes:
# - We separate hosts from literals as there are many ways to write
# an IPv6 Address so we need to compare by object.
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
# it more efficient to do an address lookup in a set than check for
# membership in each network.
# - We still allow literals as it might be possible that we receive a
# something that isn't an IP Address e.g. a unix socket.
if not self.always_trust:
if isinstance(trusted_hosts, str):
trusted_hosts = _parse_raw_hosts(trusted_hosts)
for host in trusted_hosts:
# Note: because we always convert invalid IP types to literals it
# is not possible for the user to know they provided a malformed IP
# type - this may lead to unexpected / difficult to debug behaviour.
if "/" in host:
# Looks like a network
try:
self.trusted_networks.add(ipaddress.ip_network(host))
except ValueError:
# Was not a valid IP Network
self.trusted_literals.add(host)
else:
try:
self.trusted_hosts.add(ipaddress.ip_address(host))
except ValueError:
# Was not a valid IP Address
self.trusted_literals.add(host)
def __contains__(self, host: str | None) -> bool:
if self.always_trust:
return True
if not host:
return False
try:
ip = ipaddress.ip_address(host)
if ip in self.trusted_hosts:
return True
return any(ip in net for net in self.trusted_networks)
except ValueError:
return host in self.trusted_literals
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
"""Extract the client host from x_forwarded_for header
In general this is the first "untrusted" host in the forwarded for list.
"""
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
if self.always_trust:
return x_forwarded_for_hosts[0]
# Note: each proxy appends to the header list so check it in reverse order
for host in reversed(x_forwarded_for_hosts):
if host not in self:
return host
# All hosts are trusted meaning that the client was also a trusted proxy
# See https://github.com/Kludex/uvicorn/issues/1068#issuecomment-855371576
return x_forwarded_for_hosts[0]

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import io
import sys
import warnings
from collections import deque
from typing import Deque, Iterable, Optional, Tuple
from collections.abc import Iterable
from uvicorn._types import (
ASGIReceiveCallable,
@@ -22,16 +24,18 @@ from uvicorn._types import (
)
def build_environ(
scope: "HTTPScope", message: "ASGIReceiveEvent", body: io.BytesIO
) -> Environ:
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
"""
Builds a scope and request message into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": "",
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
"wsgi.version": (1, 0),
@@ -78,8 +82,7 @@ def build_environ(
class _WSGIMiddleware:
def __init__(self, app: WSGIApp, workers: int = 10):
warnings.warn(
"Uvicorn's native WSGI implementation is deprecated, you "
"should switch to a2wsgi (`pip install a2wsgi`).",
"Uvicorn's native WSGI implementation is deprecated, you should switch to a2wsgi (`pip install a2wsgi`).",
DeprecationWarning,
)
self.app = app
@@ -87,9 +90,9 @@ class _WSGIMiddleware:
async def __call__(
self,
scope: "HTTPScope",
receive: "ASGIReceiveCallable",
send: "ASGISendCallable",
scope: HTTPScope,
receive: ASGIReceiveCallable,
send: ASGISendCallable,
) -> None:
assert scope["type"] == "http"
instance = WSGIResponder(self.app, self.executor, scope)
@@ -101,7 +104,7 @@ class WSGIResponder:
self,
app: WSGIApp,
executor: concurrent.futures.ThreadPoolExecutor,
scope: "HTTPScope",
scope: HTTPScope,
):
self.app = app
self.executor = executor
@@ -109,21 +112,19 @@ class WSGIResponder:
self.status = None
self.response_headers = None
self.send_event = asyncio.Event()
self.send_queue: Deque[Optional["ASGISendEvent"]] = deque()
self.send_queue: deque[ASGISendEvent | None] = deque()
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self.response_started = False
self.exc_info: Optional[ExcInfo] = None
self.exc_info: ExcInfo | None = None
async def __call__(
self, receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body = io.BytesIO(message.get("body", b""))
more_body = message.get("more_body", False)
if more_body:
body.seek(0, io.SEEK_END)
while more_body:
body_message: "HTTPRequestEvent" = (
body_message: HTTPRequestEvent = (
await receive() # type: ignore[assignment]
)
body.write(body_message.get("body", b""))
@@ -131,9 +132,7 @@ class WSGIResponder:
body.seek(0)
environ = build_environ(self.scope, message, body)
self.loop = asyncio.get_event_loop()
wsgi = self.loop.run_in_executor(
self.executor, self.wsgi, environ, self.start_response
)
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
sender = self.loop.create_task(self.sender(send))
try:
await asyncio.wait_for(wsgi, None)
@@ -144,7 +143,7 @@ class WSGIResponder:
if self.exc_info is not None:
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: "ASGISendCallable") -> None:
async def sender(self, send: ASGISendCallable) -> None:
while True:
if self.send_queue:
message = self.send_queue.popleft()
@@ -158,18 +157,15 @@ class WSGIResponder:
def start_response(
self,
status: str,
response_headers: Iterable[Tuple[str, str]],
exc_info: Optional[ExcInfo] = None,
response_headers: Iterable[tuple[str, str]],
exc_info: ExcInfo | None = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
self.response_started = True
status_code_str, _ = status.split(" ", 1)
status_code = int(status_code_str)
headers = [
(name.encode("ascii"), value.encode("ascii"))
for name, value in response_headers
]
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
http_response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": status_code,

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import Type
from __future__ import annotations
AutoHTTPProtocol: Type[asyncio.Protocol]
import asyncio
AutoHTTPProtocol: type[asyncio.Protocol]
try:
import httptools # noqa
except ImportError: # pragma: no cover

View File

@@ -1,12 +1,6 @@
import asyncio
from uvicorn._types import (
ASGIReceiveCallable,
ASGISendCallable,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
Scope,
)
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
CLOSE_HEADER = (b"connection", b"close")
@@ -22,7 +16,7 @@ class FlowControl:
self._is_writable_event.set()
async def drain(self) -> None:
await self._is_writable_event.wait()
await self._is_writable_event.wait() # pragma: full coverage
def pause_reading(self) -> None:
if not self.read_paused:
@@ -35,32 +29,26 @@ class FlowControl:
self._transport.resume_reading()
def pause_writing(self) -> None:
if not self.write_paused:
if not self.write_paused: # pragma: full coverage
self.write_paused = True
self._is_writable_event.clear()
def resume_writing(self) -> None:
if self.write_paused:
if self.write_paused: # pragma: full coverage
self.write_paused = False
self._is_writable_event.set()
async def service_unavailable(
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
response_start: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await send(response_start)
response_body: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Service Unavailable",
"more_body": False,
}
await send(response_body)
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"19"),
(b"connection", b"close"),
],
}
)
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})

View File

@@ -1,16 +1,9 @@
from __future__ import annotations
import asyncio
import http
import logging
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
cast,
)
from typing import Any, Callable, Literal, cast
from urllib.parse import unquote
import h11
@@ -27,19 +20,8 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
@@ -50,9 +32,7 @@ def _get_status_phrase(status_code: int) -> bytes:
return b""
STATUS_PHRASES = {
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
}
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
class H11Protocol(asyncio.Protocol):
@@ -60,8 +40,8 @@ class H11Protocol(asyncio.Protocol):
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
@@ -84,7 +64,7 @@ class H11Protocol(asyncio.Protocol):
self.app_state = app_state
# Timeouts
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
self.timeout_keep_alive = config.timeout_keep_alive
# Shared server state
@@ -95,13 +75,13 @@ class H11Protocol(asyncio.Protocol):
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
# Protocol interface
@@ -120,7 +100,7 @@ class H11Protocol(asyncio.Protocol):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self.connections.discard(self)
if self.logger.level <= TRACE_LOG_LEVEL:
@@ -153,7 +133,7 @@ class H11Protocol(asyncio.Protocol):
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None
def _get_upgrade(self) -> Optional[bytes]:
def _get_upgrade(self) -> bytes | None:
connection = []
upgrade = None
for name, value in self.headers:
@@ -167,14 +147,24 @@ class H11Protocol(asyncio.Protocol):
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True
def _unsupported_upgrade_warning(self) -> None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
return True
if upgrade is not None:
self._unsupported_upgrade_warning()
return False
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -205,34 +195,31 @@ class H11Protocol(asyncio.Protocol):
elif isinstance(event, h11.Request):
self.headers = [(key.lower(), value) for key, value in event.headers]
raw_path, _, query_string = event.target.partition(b"?")
path = unquote(raw_path.decode("ascii"))
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope = {
"type": "http",
"asgi": {
"version": self.config.asgi_version,
"spec_version": "2.3",
},
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": event.http_version.decode("ascii"),
"server": self.server,
"client": self.client,
"scheme": self.scheme, # type: ignore[typeddict-item]
"method": event.method.decode("ascii"),
"root_path": self.root_path,
"path": unquote(raw_path.decode("ascii")),
"raw_path": raw_path,
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string,
"headers": self.headers,
"state": self.app_state.copy(),
}
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
if self._should_upgrade():
self.handle_websocket_upgrade(event)
return
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -240,6 +227,14 @@ class H11Protocol(asyncio.Protocol):
else:
app = self.app
# When starting to process a request, disable the keep-alive
# timeout. Normally we disable this when receiving data from
# client and set back when finishing processing its request.
# However, for pipelined requests processing finishes after
# already receiving the next request and thus the timer may
# be set here, which we don't want.
self._unset_keepalive_if_required()
self.cycle = RequestResponseCycle(
scope=self.scope,
conn=self.conn,
@@ -271,9 +266,11 @@ class H11Protocol(asyncio.Protocol):
continue
self.cycle.more_body = False
self.cycle.message_event.set()
if self.conn.their_state == h11.MUST_CLOSE:
break
def handle_websocket_upgrade(self, event: h11.Request) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
@@ -293,7 +290,7 @@ class H11Protocol(asyncio.Protocol):
def send_400_response(self, msg: str) -> None:
reason = STATUS_PHRASES[400]
headers: List[Tuple[bytes, bytes]] = [
headers: list[tuple[bytes, bytes]] = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
@@ -318,9 +315,7 @@ class H11Protocol(asyncio.Protocol):
# Set a short Keep-Alive timeout.
self._unset_keepalive_if_required()
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
# Unpause data reads if needed.
self.flow.resume_reading()
@@ -345,13 +340,13 @@ class H11Protocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -367,14 +362,14 @@ class H11Protocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
conn: h11.Connection,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
default_headers: list[tuple[bytes, bytes]],
message_event: asyncio.Event,
on_response: Callable[..., None],
) -> None:
@@ -403,7 +398,7 @@ class RequestResponseCycle:
self.response_complete = False
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -432,7 +427,7 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
@@ -441,7 +436,7 @@ class RequestResponseCycle:
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
@@ -449,14 +444,14 @@ class RequestResponseCycle:
await self.send(response_body_event)
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -523,12 +518,10 @@ class RequestResponseCycle:
self.transport.close()
self.on_response()
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
headers: List[Tuple[str, str]] = []
event = h11.InformationalResponse(
status_code=100, headers=headers, reason="Continue"
)
headers: list[tuple[str, str]] = []
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
output = self.conn.send(event=event)
self.transport.write(output)
self.waiting_for_100_continue = False
@@ -541,7 +534,7 @@ class RequestResponseCycle:
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}
message: "HTTPRequestEvent" = {
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import http
import logging
@@ -5,18 +7,7 @@ import re
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import (
Any,
Callable,
Deque,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
)
from typing import Any, Callable, Literal, cast
import httptools
@@ -24,31 +15,18 @@ from uvicorn._types import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
def _get_status_line(status_code: int) -> bytes:
@@ -59,9 +37,7 @@ def _get_status_line(status_code: int) -> bytes:
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
STATUS_LINE = {
status_code: _get_status_line(status_code) for status_code in range(100, 600)
}
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
class HttpToolsProtocol(asyncio.Protocol):
@@ -69,8 +45,8 @@ class HttpToolsProtocol(asyncio.Protocol):
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
@@ -82,13 +58,21 @@ class HttpToolsProtocol(asyncio.Protocol):
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.parser = httptools.HttpRequestParser(self)
try:
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
except AttributeError: # pragma: no cover
# httptools < 0.6.3
pass
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state
# Timeouts
self.timeout_keep_alive_task: Optional[TimerHandle] = None
self.timeout_keep_alive_task: TimerHandle | None = None
self.timeout_keep_alive = config.timeout_keep_alive
# Global state
@@ -99,14 +83,14 @@ class HttpToolsProtocol(asyncio.Protocol):
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
@@ -126,7 +110,7 @@ class HttpToolsProtocol(asyncio.Protocol):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self.connections.discard(self)
if self.logger.level <= TRACE_LOG_LEVEL:
@@ -153,7 +137,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None
def _get_upgrade(self) -> Optional[bytes]:
def _get_upgrade(self) -> bytes | None:
connection = []
upgrade = None
for name, value in self.headers:
@@ -163,21 +147,22 @@ class HttpToolsProtocol(asyncio.Protocol):
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None
return None # pragma: full coverage
def _should_upgrade_to_ws(self, upgrade: Optional[bytes]) -> bool:
if upgrade == b"websocket" and self.ws_protocol_class is not None:
return True
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
return False
return True
def _unsupported_upgrade_warning(self) -> None:
self.logger.warning("Unsupported upgrade request.")
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
return self._should_upgrade_to_ws(upgrade)
return upgrade == b"websocket" and self._should_upgrade_to_ws()
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -190,9 +175,10 @@ class HttpToolsProtocol(asyncio.Protocol):
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
upgrade = self._get_upgrade()
if self._should_upgrade_to_ws(upgrade):
if self._should_upgrade():
self.handle_websocket_upgrade()
else:
self._unsupported_upgrade_warning()
def handle_websocket_upgrade(self) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
@@ -217,7 +203,7 @@ class HttpToolsProtocol(asyncio.Protocol):
def send_400_response(self, msg: str) -> None:
content = [STATUS_LINE[400]]
for name, value in self.server_state.default_headers:
content.extend([name, b": ", value, b"\r\n"])
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
content.extend(
[
b"content-type: text/plain; charset=utf-8\r\n",
@@ -269,14 +255,15 @@ class HttpToolsProtocol(asyncio.Protocol):
path = raw_path.decode("ascii")
if "%" in path:
path = urllib.parse.unquote(path)
self.scope["path"] = path
self.scope["raw_path"] = raw_path
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope["path"] = full_path
self.scope["raw_path"] = full_raw_path
self.scope["query_string"] = parsed_url.query or b""
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -309,9 +296,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.pipeline.appendleft((self.cycle, app))
def on_body(self, body: bytes) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
@@ -319,9 +304,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.cycle.message_event.set()
def on_message_complete(self) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()
@@ -333,22 +316,22 @@ class HttpToolsProtocol(asyncio.Protocol):
if self.transport.is_closing():
return
# Set a short Keep-Alive timeout.
self._unset_keepalive_if_required()
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
# Unpause data reads if needed.
self.flow.resume_reading()
# Unblock any pipelined events.
# Unblock any pipelined events. If there are none, arm the
# Keep-Alive timeout instead.
if self.pipeline:
cycle, app = self.pipeline.pop()
task = self.loop.create_task(cycle.run_asgi(app))
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
else:
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
def shutdown(self) -> None:
"""
@@ -363,13 +346,13 @@ class HttpToolsProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -383,13 +366,13 @@ class HttpToolsProtocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
default_headers: list[tuple[bytes, bytes]],
message_event: asyncio.Event,
expect_100_continue: bool,
keep_alive: bool,
@@ -417,11 +400,11 @@ class RequestResponseCycle:
# Response state
self.response_started = False
self.response_complete = False
self.chunked_encoding: Optional[bool] = None
self.chunked_encoding: bool | None = None
self.expected_content_length = 0
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -450,31 +433,28 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"21"),
(b"connection", b"close"),
],
}
)
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -507,7 +487,7 @@ class RequestResponseCycle:
for name, value in headers:
if HEADER_RE.search(name):
raise RuntimeError("Invalid HTTP header name.")
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
if HEADER_VALUE_RE.search(value):
raise RuntimeError("Invalid HTTP header value.")
@@ -522,11 +502,7 @@ class RequestResponseCycle:
self.keep_alive = False
content.extend([name, b": ", value, b"\r\n"])
if (
self.chunked_encoding is None
and self.scope["method"] != "HEAD"
and status_code not in (204, 304)
):
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
# Neither content-length nor transfer-encoding specified
self.chunked_encoding = True
content.append(b"transfer-encoding: chunked\r\n")
@@ -577,7 +553,7 @@ class RequestResponseCycle:
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
@@ -587,15 +563,8 @@ class RequestResponseCycle:
await self.message_event.wait()
self.message_event.clear()
message: "Union[HTTPDisconnectEvent, HTTPRequestEvent]"
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
message = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
}
self.body = b""
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
self.body = b""
return message

View File

@@ -1,11 +1,15 @@
from __future__ import annotations
import asyncio
import urllib.parse
from typing import Optional, Tuple
from uvicorn._types import WWWScope
def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
class ClientDisconnected(OSError): ...
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
try:
@@ -22,7 +26,7 @@ def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
return None
def get_local_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
info = socket_info.getsockname()
@@ -38,17 +42,15 @@ def is_ssl(transport: asyncio.Transport) -> bool:
return bool(transport.get_extra_info("sslcontext"))
def get_client_addr(scope: "WWWScope") -> str:
def get_client_addr(scope: WWWScope) -> str:
client = scope.get("client")
if not client:
return ""
return "%s:%d" % client
def get_path_with_query_string(scope: "WWWScope") -> str:
def get_path_with_query_string(scope: WWWScope) -> str:
path_with_query_string = urllib.parse.quote(scope["path"])
if scope["query_string"]:
path_with_query_string = "{}?{}".format(
path_with_query_string, scope["query_string"].decode("ascii")
)
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
return path_with_query_string

View File

@@ -1,7 +1,9 @@
import asyncio
import typing
from __future__ import annotations
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
import asyncio
from typing import Callable
AutoWebSocketsProtocol: Callable[..., asyncio.Protocol] | None
try:
import websockets # noqa
except ImportError: # pragma: no cover

View File

@@ -1,40 +1,40 @@
from __future__ import annotations
import asyncio
import http
import logging
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from collections.abc import Sequence
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import websockets
import websockets.legacy.handshake
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketConnectEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
@@ -57,20 +57,21 @@ class Server:
class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: List[Tuple[str, str]]
extra_headers: list[tuple[str, str]]
logger: logging.Logger | logging.LoggerAdapter[Any]
def __init__(
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
):
if not config.loaded:
config.load()
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path
self.app_state = app_state
@@ -81,23 +82,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
# Connection events
self.scope: WebSocketScope = None # type: ignore[assignment]
self.scope: WebSocketScope
self.handshake_started_event = asyncio.Event()
self.handshake_completed_event = asyncio.Event()
self.closed_event = asyncio.Event()
self.initial_response: Optional[HTTPResponse] = None
self.initial_response: HTTPResponse | None = None
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.accepted_subprotocol: Subprotocol | None = None
self.ws_server: Server = Server() # type: ignore[assignment]
extensions = []
extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())
@@ -113,8 +114,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
)
self.server_header = None
self.extra_headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in server_state.default_headers
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
]
def connection_made( # type: ignore[override]
@@ -132,16 +132,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
super().connection_made(transport)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self.connections.remove(self)
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
@@ -155,12 +153,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
async def process_request(
self, path: str, headers: Headers
) -> Optional[HTTPResponse]:
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
@@ -171,31 +167,35 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
path_portion, _, query_string = path.partition("?")
websockets.legacy.handshake.check_request(headers)
websockets.legacy.handshake.check_request(request_headers)
subprotocols = []
for header in headers.get_all("Sec-WebSocket-Protocol"):
subprotocols: list[str] = []
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
subprotocols.extend([token.strip() for token in header.split(",")])
asgi_headers = [
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
for name, value in request_headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
self.scope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
"http_version": "1.1",
"scheme": self.scheme,
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(path_portion),
"raw_path": path_portion.encode("ascii"),
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string.encode("ascii"),
"headers": asgi_headers,
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
@@ -204,8 +204,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
return self.initial_response
def process_subprotocol(
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
) -> Subprotocol | None:
"""
We override the standard 'process_subprotocol' behavior here so that
we return whatever subprotocol is sent in the 'accept' message.
@@ -215,8 +215,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
b"content-type: text/plain; charset=utf-8\r\n",
b"HTTP/1.1 500 Internal Server Error\r\ncontent-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
@@ -224,12 +223,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
]
self.transport.write(b"".join(content))
# Allow handler task to terminate cleanly, as websockets doesn't cancel it by
# itself (see https://github.com/encode/uvicorn/issues/920)
# itself (see https://github.com/Kludex/uvicorn/issues/920)
self.handshake_started_event.set()
async def ws_handler( # type: ignore[override]
self, protocol: WebSocketServerProtocol, path: str
) -> Any:
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
"""
This is the main handler function for the 'websockets' implementation
to call into. We just wait for close then return, and instead allow
@@ -244,11 +241,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
termination states.
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except BaseException as exc:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
except ClientDisconnected: # pragma: full coverage
self.closed_event.set()
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
self.transport.close()
except BaseException:
self.closed_event.set()
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_started_event.is_set():
self.send_500_response()
else:
@@ -257,17 +256,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
else:
self.closed_event.set()
if not self.handshake_started_event.is_set():
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without sending handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()
async def asgi_send(self, message: "ASGISendEvent") -> None:
async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if not self.handshake_started_event.is_set():
@@ -275,13 +272,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
message = cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
)
self.initial_response = None
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
@@ -295,53 +290,76 @@ class WebSocketProtocol(WebSocketServerProtocol):
message = cast("WebSocketCloseEvent", message)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
)
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
self.handshake_started_event.set()
self.closed_event.set()
elif message_type == "websocket.http.response.start":
message = cast("WebSocketResponseStartEvent", message)
self.logger.info(
'%s - "WebSocket %s" %d',
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
message["status"],
)
# websockets requires the status to be an enum. look it up.
status = http.HTTPStatus(message["status"])
headers = [
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
]
self.initial_response = (status, headers, b"")
self.handshake_started_event.set()
else:
msg = (
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
"but got '%s'."
"Expected ASGI message 'websocket.accept', 'websocket.close', "
"or 'websocket.http.response.start' but got '%s'."
)
raise RuntimeError(msg % message_type)
elif not self.closed_event.is_set():
elif not self.closed_event.is_set() and self.initial_response is None:
await self.handshake_completed_event.wait()
if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]
try:
if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]
elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
except ConnectionClosed as exc:
raise ClientDisconnected from exc
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast("WebSocketResponseBodyEvent", message)
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' or response already completed."
raise RuntimeError(msg % message_type)
async def asgi_receive(
self,
) -> Union[
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
]:
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
if not self.connect_sent:
self.connect_sent = True
return {"type": "websocket.connect"}
@@ -358,19 +376,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
try:
data = await self.recv()
except ConnectionClosed as exc:
except ConnectionClosed:
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
return {"type": "websocket.disconnect", "code": exc.code}
msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item]
"type": "websocket.receive"
}
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
if isinstance(data, str):
msg["text"] = data
else:
msg["bytes"] = data
return msg
return {"type": "websocket.receive", "text": data}
return {"type": "websocket.receive", "bytes": data}

View File

@@ -0,0 +1,417 @@
from __future__ import annotations
import asyncio
import logging
import sys
from asyncio.transports import BaseTransport, Transport
from http import HTTPStatus
from typing import Any, Literal, cast
from urllib.parse import unquote
from websockets.exceptions import InvalidState
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.frames import Frame, Opcode
from websockets.http11 import Request
from websockets.server import ServerProtocol
from uvicorn._types import (
ASGIReceiveEvent,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState
if sys.version_info >= (3, 11): # pragma: no cover
from typing import assert_never
else: # pragma: no cover
from typing_extensions import assert_never
class WebSocketsSansIOProtocol(asyncio.Protocol):
def __init__(
self,
config: Config,
server_state: ServerState,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load() # pragma: no cover
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
self.app_state = app_state
# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers
# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
# WebSocket state
self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.handshake_initiated = False
self.handshake_complete = False
self.close_sent = False
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None
extensions = []
if self.config.ws_per_message_deflate:
extensions = [
ServerPerMessageDeflateFactory(
server_max_window_bits=12,
client_max_window_bits=12,
compress_settings={"memLevel": 5},
)
]
self.conn = ServerProtocol(
extensions=extensions,
max_size=self.config.ws_max_size,
logger=logging.getLogger("uvicorn.error"),
)
self.read_paused = False
self.writable = asyncio.Event()
self.writable.set()
# Buffers
self.bytes = b""
def connection_made(self, transport: BaseTransport) -> None:
"""Called when a connection is made."""
transport = cast(Transport, transport)
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
def connection_lost(self, exc: Exception | None) -> None:
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
self.handshake_complete = True
if exc is None:
self.transport.close()
def eof_received(self) -> None:
pass
def shutdown(self) -> None:
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
else:
self.send_500_response()
self.transport.close()
def data_received(self, data: bytes) -> None:
self.conn.receive_data(data)
if self.conn.parser_exc is not None: # pragma: no cover
self.handle_parser_exception()
return
self.handle_events()
def handle_events(self) -> None:
for event in self.conn.events_received():
if isinstance(event, Request):
self.handle_connect(event)
if isinstance(event, Frame):
if event.opcode == Opcode.CONT:
self.handle_cont(event) # pragma: no cover
elif event.opcode == Opcode.TEXT:
self.handle_text(event)
elif event.opcode == Opcode.BINARY:
self.handle_bytes(event)
elif event.opcode == Opcode.PING:
self.handle_ping()
elif event.opcode == Opcode.PONG:
pass # pragma: no cover
elif event.opcode == Opcode.CLOSE:
self.handle_close(event)
else:
assert_never(event.opcode) # pragma: no cover
# Event handlers
def handle_connect(self, event: Request) -> None:
self.request = event
self.response = self.conn.accept(event)
self.handshake_initiated = True
if self.response.status_code != 101:
self.handshake_complete = True
self.close_sent = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.transport.close()
return
headers = [
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for key, value in event.headers.raw_items()
]
raw_path, _, query_string = event.path.partition("?")
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
"scheme": self.scheme,
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(raw_path),
"raw_path": raw_path.encode("ascii"),
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
def handle_cont(self, event: Frame) -> None: # pragma: no cover
self.bytes += event.data
if event.fin:
self.send_receive_event_to_app()
def handle_text(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
if event.fin:
self.send_receive_event_to_app()
def handle_bytes(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type = "bytes"
if event.fin:
self.send_receive_event_to_app()
def send_receive_event_to_app(self) -> None:
if self.curr_msg_data_type == "text":
try:
self.queue.put_nowait({"type": "websocket.receive", "text": self.bytes.decode()})
except UnicodeDecodeError: # pragma: no cover
self.logger.exception("Invalid UTF-8 sequence received from client.")
self.conn.send_close(1007)
self.handle_parser_exception()
return
else:
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()
def handle_ping(self) -> None:
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
def handle_close(self, event: Frame) -> None:
if not self.close_sent and not self.transport.is_closing():
assert self.conn.close_rcvd is not None
code = self.conn.close_rcvd.code
reason = self.conn.close_rcvd.reason
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.transport.close()
def handle_parser_exception(self) -> None: # pragma: no cover
assert self.conn.close_sent is not None
code = self.conn.close_sent.code
reason = self.conn.close_sent.reason
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close() # pragma: no cover
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
def send_500_response(self) -> None:
if self.initial_response or self.handshake_complete:
return
response = self.conn.reject(500, "Internal Server Error")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
message_type = message["type"]
if not self.handshake_complete and self.initial_response is None:
if message_type == "websocket.accept":
message = cast(WebSocketAcceptEvent, message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
headers = [
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
for name, value in (self.default_headers + list(message.get("headers", [])))
]
accepted_subprotocol = message.get("subprotocol")
if accepted_subprotocol:
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))
self.response.headers.update(headers)
if not self.transport.is_closing():
self.handshake_complete = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.handshake_complete = True
self.transport.write(b"".join(output))
self.transport.close()
elif message_type == "websocket.http.response.start" and self.initial_response is None:
message = cast(WebSocketResponseStartEvent, message)
if not (100 <= message["status"] < 600):
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in list(message.get("headers", []))
]
self.initial_response = (message["status"], headers, b"")
else:
msg = (
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
elif not self.close_sent and self.initial_response is None:
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
if text_data:
self.conn.send_text(text_data.encode())
elif bytes_data:
self.conn.send_binary(bytes_data)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close" and not self.transport.is_closing():
message = cast(WebSocketCloseEvent, message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
self.conn.send_close(code, reason)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
except InvalidState:
raise ClientDisconnected()
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast(WebSocketResponseBodyEvent, message)
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
response = self.conn.reject(self.initial_response[0], body.decode())
response.headers.update(self.initial_response[1])
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.transport.write(b"".join(output))
self.transport.close()
else: # pragma: no cover
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
async def receive(self) -> ASGIReceiveEvent:
message = await self.queue.get()
if self.read_paused and self.queue.empty():
self.read_paused = False
self.transport.resume_reading()
return message

View File

@@ -1,27 +1,32 @@
from __future__ import annotations
import asyncio
import logging
import typing
from typing import Literal
from typing import Any, Literal, cast
from urllib.parse import unquote
import wsproto
from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import RemoteProtocolError
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketEvent,
WebSocketReceiveEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
@@ -35,14 +40,14 @@ class WSProtocol(asyncio.Protocol):
self,
config: Config,
server_state: ServerState,
app_state: typing.Dict[str, typing.Any],
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
config.load() # pragma: full coverage
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
@@ -55,15 +60,18 @@ class WSProtocol(asyncio.Protocol):
# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: typing.Optional[typing.Tuple[str, int]] = None
self.client: typing.Optional[typing.Tuple[str, int]] = None
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
# WebSocket state
self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue()
self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
self.handshake_complete = False
self.close_sent = False
# Rejection state
self.response_started = False
self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
self.read_paused = False
@@ -89,7 +97,7 @@ class WSProtocol(asyncio.Protocol):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)
@@ -132,13 +140,13 @@ class WSProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.writable.clear()
self.writable.clear() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.writable.set()
self.writable.set() # pragma: full coverage
def shutdown(self) -> None:
if self.handshake_complete:
@@ -149,7 +157,7 @@ class WSProtocol(asyncio.Protocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
# Event handlers
@@ -158,20 +166,24 @@ class WSProtocol(asyncio.Protocol):
headers = [(b"host", event.host.encode())]
headers += [(key.lower(), value) for key, value in event.extra_headers]
raw_path, _, query_string = event.target.partition("?")
self.scope: "WebSocketScope" = {
path = unquote(raw_path)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
"http_version": "1.1",
"scheme": self.scheme,
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(raw_path),
"raw_path": raw_path.encode("ascii"),
"path": full_path,
"raw_path": full_raw_path,
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
@@ -181,11 +193,7 @@ class WSProtocol(asyncio.Protocol):
def handle_text(self, event: events.TextMessage) -> None:
self.text += event.data
if event.message_finished:
msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item]
"type": "websocket.receive",
"text": self.text,
}
self.queue.put_nowait(msg)
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
self.text = ""
if not self.read_paused:
self.read_paused = True
@@ -195,11 +203,7 @@ class WSProtocol(asyncio.Protocol):
self.bytes += event.data
# todo: we may want to guard the size of self.bytes and self.text
if event.message_finished:
msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item]
"type": "websocket.receive",
"bytes": self.bytes,
}
self.queue.put_nowait(msg)
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
self.bytes = b""
if not self.read_paused:
self.read_paused = True
@@ -208,62 +212,58 @@ class WSProtocol(asyncio.Protocol):
def handle_close(self, event: events.CloseConnection) -> None:
if self.conn.state == ConnectionState.REMOTE_CLOSING:
self.transport.write(self.conn.send(event.response()))
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
self.transport.close()
def handle_ping(self, event: events.Ping) -> None:
self.transport.write(self.conn.send(event.response()))
def send_500_response(self) -> None:
headers = [
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
headers: list[tuple[bytes, bytes]] = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
(b"content-length", b"21"),
]
output = self.conn.send(
wsproto.events.RejectConnection(
status_code=500, headers=headers, has_body=True
)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
self.transport.write(output)
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
except ClientDisconnected:
self.transport.close() # pragma: full coverage
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_complete:
self.send_500_response()
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
message_type = message["type"]
if not self.handshake_complete:
if message_type == "websocket.accept":
message = typing.cast("WebSocketAcceptEvent", message)
message = cast(WebSocketAcceptEvent, message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
)
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions: typing.List[Extension] = []
extensions: list[Extension] = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
@@ -281,7 +281,7 @@ class WSProtocol(asyncio.Protocol):
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
@@ -291,50 +291,85 @@ class WSProtocol(asyncio.Protocol):
self.transport.write(output)
self.transport.close()
elif message_type == "websocket.http.response.start":
message = cast(WebSocketResponseStartEvent, message)
# ensure status code is in the valid range
if not (100 <= message["status"] < 600):
msg = "Invalid HTTP status code '%d' in response."
raise RuntimeError(msg % message["status"])
self.logger.info(
'%s - "WebSocket %s" %d',
get_client_addr(self.scope),
get_path_with_query_string(self.scope),
message["status"],
)
self.handshake_complete = True
event = events.RejectConnection(
status_code=message["status"],
headers=list(message["headers"]),
has_body=True,
)
output = self.conn.send(event)
self.transport.write(output)
self.response_started = True
else:
msg = (
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
elif not self.close_sent:
if message_type == "websocket.send":
message = typing.cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(
wsproto.events.Message(data=data) # type: ignore[type-var]
)
if not self.transport.is_closing():
self.transport.write(output)
elif not self.close_sent and not self.response_started:
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
if not self.transport.is_closing():
self.transport.write(output)
elif message_type == "websocket.close":
message = typing.cast("WebSocketCloseEvent", message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
if not self.transport.is_closing():
self.transport.write(output)
elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
elif self.response_started:
if message_type == "websocket.http.response.body":
message = cast("WebSocketResponseBodyEvent", message)
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
output = self.conn.send(reject_data)
self.transport.write(output)
if body_finished:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.close_sent = True
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
async def receive(self) -> "WebSocketEvent":
async def receive(self) -> WebSocketEvent:
message = await self.queue.get()
if self.read_paused and self.queue.empty():
self.read_paused = False

View File

@@ -1,4 +1,7 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import platform
@@ -7,22 +10,24 @@ import socket
import sys
import threading
import time
from collections.abc import Generator, Sequence
from email.utils import formatdate
from types import FrameType
from typing import TYPE_CHECKING, List, Optional, Sequence, Set, Tuple, Union
from typing import TYPE_CHECKING, Union
import click
from uvicorn._compat import asyncio_run
from uvicorn.config import Config
if TYPE_CHECKING:
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@@ -41,9 +46,9 @@ class ServerState:
def __init__(self) -> None:
self.total_requests = 0
self.connections: Set["Protocols"] = set()
self.tasks: Set[asyncio.Task] = set()
self.default_headers: List[Tuple[bytes, bytes]] = []
self.connections: set[Protocols] = set()
self.tasks: set[asyncio.Task[None]] = set()
self.default_headers: list[tuple[bytes, bytes]] = []
class Server:
@@ -56,11 +61,16 @@ class Server:
self.force_exit = False
self.last_notified = 0.0
def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
self._captured_signals: list[int] = []
async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None:
def run(self, sockets: list[socket.socket] | None = None) -> None:
return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():
await self._serve(sockets)
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
process_id = os.getpid()
config = self.config
@@ -69,8 +79,6 @@ class Server:
self.lifespan = config.lifespan_class(config)
self.install_signal_handlers()
message = "Started server process [%d]"
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})
@@ -85,7 +93,7 @@ class Server:
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})
async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None:
async def startup(self, sockets: list[socket.socket] | None = None) -> None:
await self.lifespan.startup()
if self.lifespan.should_exit:
self.should_exit = True
@@ -94,7 +102,7 @@ class Server:
config = self.config
def create_protocol(
_loop: Optional[asyncio.AbstractEventLoop] = None,
_loop: asyncio.AbstractEventLoop | None = None,
) -> asyncio.Protocol:
return config.http_protocol_class( # type: ignore[call-arg]
config=config,
@@ -106,13 +114,13 @@ class Server:
loop = asyncio.get_running_loop()
listeners: Sequence[socket.SocketType]
if sockets is not None:
if sockets is not None: # pragma: full coverage
# Explicitly passed a list of open sockets.
# We use this when the server is run from a Gunicorn worker.
def _share_socket(
sock: socket.SocketType,
) -> socket.SocketType: # pragma py-linux pragma: py-darwin
) -> socket.SocketType: # pragma py-not-win32
# Windows requires the socket be explicitly shared across
# multiple workers (processes).
from socket import fromshare # type: ignore[attr-defined]
@@ -120,23 +128,19 @@ class Server:
sock_data = sock.share(os.getpid()) # type: ignore[attr-defined]
return fromshare(sock_data)
self.servers: List[asyncio.base_events.Server] = []
self.servers: list[asyncio.base_events.Server] = []
for sock in sockets:
is_windows = platform.system() == "Windows"
if config.workers > 1 and is_windows: # pragma: py-not-win32
sock = _share_socket(sock) # type: ignore[assignment]
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
self.servers.append(server)
listeners = sockets
elif config.fd is not None: # pragma: py-win32
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
assert server.sockets is not None # mypy
listeners = server.sockets
self.servers = [server]
@@ -145,7 +149,7 @@ class Server:
# Create a socket using UNIX domain socket.
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
server = await loop.create_unix_server(
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
)
@@ -178,7 +182,7 @@ class Server:
else:
# We're most likely running multiple workers, so a message has already been
# logged by `config.bind_socket()`.
pass
pass # pragma: full coverage
self.started = True
@@ -193,9 +197,7 @@ class Server:
)
elif config.uds is not None: # pragma: py-win32
logger.info(
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
)
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
else:
addr_format = "%s://%s:%d"
@@ -210,11 +212,7 @@ class Server:
protocol_name = "https" if config.ssl else "http"
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
color_message = (
"Uvicorn running on "
+ click.style(addr_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
logger.info(
message,
protocol_name,
@@ -243,31 +241,33 @@ class Server:
else:
date_header = []
self.server_state.default_headers = (
date_header + self.config.encoded_headers
)
self.server_state.default_headers = date_header + self.config.encoded_headers
# Callback to `callback_notify` once every `timeout_notify` seconds.
if self.config.callback_notify is not None:
if current_time - self.last_notified > self.config.timeout_notify:
if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
self.last_notified = current_time
await self.config.callback_notify()
# Determine if we should exit.
if self.should_exit:
return True
if self.config.limit_max_requests is not None:
return self.server_state.total_requests >= self.config.limit_max_requests
max_requests = self.config.limit_max_requests
if max_requests is not None and self.server_state.total_requests >= max_requests:
logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
return True
return False
async def shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None:
async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
logger.info("Shutting down")
# Stop accepting new connections.
for server in self.servers:
server.close()
for sock in sockets or []:
sock.close()
sock.close() # pragma: full coverage
# Request shutdown on all existing connections.
for connection in list(self.server_state.connections):
@@ -286,10 +286,7 @@ class Server:
len(self.server_state.tasks),
)
for t in self.server_state.tasks:
if sys.version_info < (3, 9): # pragma: py-gte-39
t.cancel()
else: # pragma: py-lt-39
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
# Send the lifespan shutdown event, and wait for application shutdown.
if not self.force_exit:
@@ -313,23 +310,29 @@ class Server:
for server in self.servers:
await server.wait_closed()
def install_signal_handlers(self) -> None:
@contextlib.contextmanager
def capture_signals(self) -> Generator[None, None, None]:
# Signals can only be listened to from the main thread.
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
yield
return
loop = asyncio.get_event_loop()
# always use signal.signal, even if loop.add_signal_handler is available
# this allows to restore previous signal handlers later on
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError: # pragma: no cover
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
yield
finally:
for sig, handler in original_handlers.items():
signal.signal(sig, handler)
# If we did gracefully shut down due to a signal, try to
# trigger the expected behaviour now; multiple signals would be
# done LIFO, see https://stackoverflow.com/questions/48434964
for captured_signal in reversed(self._captured_signals):
signal.raise_signal(captured_signal)
def handle_exit(self, sig: int, frame: Optional[FrameType]) -> None:
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
self._captured_signals.append(sig)
if self.should_exit and sig == signal.SIGINT:
self.force_exit = True
self.force_exit = True # pragma: full coverage
else:
self.should_exit = True

View File

@@ -1,21 +1,16 @@
from typing import TYPE_CHECKING, Type
from __future__ import annotations
from typing import TYPE_CHECKING
from uvicorn.supervisors.basereload import BaseReload
from uvicorn.supervisors.multiprocess import Multiprocess
if TYPE_CHECKING:
ChangeReload: Type[BaseReload]
ChangeReload: type[BaseReload]
else:
try:
from uvicorn.supervisors.watchfilesreload import (
WatchFilesReload as ChangeReload,
)
from uvicorn.supervisors.watchfilesreload import WatchFilesReload as ChangeReload
except ImportError: # pragma: no cover
try:
from uvicorn.supervisors.watchgodreload import (
WatchGodReload as ChangeReload,
)
except ImportError:
from uvicorn.supervisors.statreload import StatReload as ChangeReload
from uvicorn.supervisors.statreload import StatReload as ChangeReload
__all__ = ["Multiprocess", "ChangeReload"]

View File

@@ -1,12 +1,15 @@
from __future__ import annotations
import logging
import os
import signal
import sys
import threading
from collections.abc import Iterator
from pathlib import Path
from socket import socket
from types import FrameType
from typing import Callable, Iterator, List, Optional
from typing import Callable
import click
@@ -25,8 +28,8 @@ class BaseReload:
def __init__(
self,
config: Config,
target: Callable[[Optional[List[socket]]], None],
sockets: List[socket],
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
self.config = config
self.target = target
@@ -34,16 +37,16 @@ class BaseReload:
self.should_exit = threading.Event()
self.pid = os.getpid()
self.is_restarting = False
self.reloader_name: Optional[str] = None
self.reloader_name: str | None = None
def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
def signal_handler(self, sig: int, frame: FrameType | None) -> None: # pragma: full coverage
"""
A signal handler that is registered with the parent process.
"""
if sys.platform == "win32" and self.is_restarting:
self.is_restarting = False # pragma: py-not-win32
self.is_restarting = False
else:
self.should_exit.set() # pragma: py-win32
self.should_exit.set()
def run(self) -> None:
self.startup()
@@ -62,10 +65,10 @@ class BaseReload:
if self.should_exit.wait(self.config.reload_delay):
raise StopIteration()
def __iter__(self) -> Iterator[Optional[List[Path]]]:
def __iter__(self) -> Iterator[list[Path] | None]:
return self
def __next__(self) -> Optional[List[Path]]:
def __next__(self) -> list[Path] | None:
return self.should_restart()
def startup(self) -> None:
@@ -79,9 +82,7 @@ class BaseReload:
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
self.process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
self.process.start()
def restart(self) -> None:
@@ -89,13 +90,15 @@ class BaseReload:
self.is_restarting = True
assert self.process.pid is not None
os.kill(self.process.pid, signal.CTRL_C_EVENT)
# This is a workaround to ensure the Ctrl+C event is processed
sys.stdout.write(" ") # This has to be a non-empty string
sys.stdout.flush()
else: # pragma: py-win32
self.process.terminate()
self.process.join()
self.process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
self.process.start()
def shutdown(self) -> None:
@@ -108,13 +111,11 @@ class BaseReload:
for sock in self.sockets:
sock.close()
message = "Stopping reloader process [{}]".format(str(self.pid))
color_message = "Stopping reloader process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
message = f"Stopping reloader process [{str(self.pid)}]"
color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
def should_restart(self) -> Optional[List[Path]]:
def should_restart(self) -> list[Path] | None:
raise NotImplementedError("Reload strategies should override should_restart()")

View File

@@ -1,74 +1,222 @@
from __future__ import annotations
import logging
import os
import signal
import threading
from multiprocessing.context import SpawnProcess
from multiprocessing import Pipe
from socket import socket
from types import FrameType
from typing import Callable, List, Optional
from typing import Any, Callable
import click
from uvicorn._subprocess import get_subprocess
from uvicorn.config import Config
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
SIGNALS = {
getattr(signal, f"SIG{x}"): x
for x in "INT TERM BREAK HUP QUIT TTIN TTOU USR1 USR2 WINCH".split()
if hasattr(signal, f"SIG{x}")
}
logger = logging.getLogger("uvicorn.error")
class Process:
def __init__(
self,
config: Config,
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
self.real_target = target
self.parent_conn, self.child_conn = Pipe()
self.process = get_subprocess(config, self.target, sockets)
def ping(self, timeout: float = 5) -> bool:
self.parent_conn.send(b"ping")
if self.parent_conn.poll(timeout):
self.parent_conn.recv()
return True
return False
def pong(self) -> None:
self.child_conn.recv()
self.child_conn.send(b"pong")
def always_pong(self) -> None:
while True:
self.pong()
def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cover
if os.name == "nt": # pragma: py-not-win32
# Windows doesn't support SIGTERM, so we use SIGBREAK instead.
# And then we raise SIGTERM when SIGBREAK is received.
# https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/signal?view=msvc-170
signal.signal(
signal.SIGBREAK, # type: ignore[attr-defined]
lambda sig, frame: signal.raise_signal(signal.SIGTERM),
)
threading.Thread(target=self.always_pong, daemon=True).start()
return self.real_target(sockets)
def is_alive(self, timeout: float = 5) -> bool:
if not self.process.is_alive():
return False # pragma: full coverage
return self.ping(timeout)
def start(self) -> None:
self.process.start()
def terminate(self) -> None:
if self.process.exitcode is None: # Process is still running
assert self.process.pid is not None
if os.name == "nt": # pragma: py-not-win32
# Windows doesn't support SIGTERM.
# So send SIGBREAK, and then in process raise SIGTERM.
os.kill(self.process.pid, signal.CTRL_BREAK_EVENT) # type: ignore[attr-defined]
else:
os.kill(self.process.pid, signal.SIGTERM)
logger.info(f"Terminated child process [{self.process.pid}]")
self.parent_conn.close()
self.child_conn.close()
def kill(self) -> None:
# In Windows, the method will call `TerminateProcess` to kill the process.
# In Unix, the method will send SIGKILL to the process.
self.process.kill()
def join(self) -> None:
logger.info(f"Waiting for child process [{self.process.pid}]")
self.process.join()
@property
def pid(self) -> int | None:
return self.process.pid
class Multiprocess:
def __init__(
self,
config: Config,
target: Callable[[Optional[List[socket]]], None],
sockets: List[socket],
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
self.config = config
self.target = target
self.sockets = sockets
self.processes: List[SpawnProcess] = []
self.processes_num = config.workers
self.processes: list[Process] = []
self.should_exit = threading.Event()
self.pid = os.getpid()
def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
"""
A signal handler that is registered with the parent process.
"""
self.should_exit.set()
self.signal_queue: list[int] = []
for sig in SIGNALS:
signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig))
def run(self) -> None:
self.startup()
self.should_exit.wait()
self.shutdown()
def startup(self) -> None:
message = "Started parent process [{}]".format(str(self.pid))
color_message = "Started parent process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
logger.info(message, extra={"color_message": color_message})
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
for _idx in range(self.config.workers):
process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
def init_processes(self) -> None:
for _ in range(self.processes_num):
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes.append(process)
def shutdown(self) -> None:
def terminate_all(self) -> None:
for process in self.processes:
process.terminate()
def join_all(self) -> None:
for process in self.processes:
process.join()
message = "Stopping parent process [{}]".format(str(self.pid))
color_message = "Stopping parent process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
def restart_all(self) -> None:
for idx, process in enumerate(self.processes):
process.terminate()
process.join()
new_process = Process(self.config, self.target, self.sockets)
new_process.start()
self.processes[idx] = new_process
def run(self) -> None:
message = f"Started parent process [{os.getpid()}]"
color_message = "Started parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
self.init_processes()
while not self.should_exit.wait(0.5):
self.handle_signals()
self.keep_subprocess_alive()
self.terminate_all()
self.join_all()
message = f"Stopping parent process [{os.getpid()}]"
color_message = "Stopping parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
def keep_subprocess_alive(self) -> None:
if self.should_exit.is_set():
return # parent process is exiting, no need to keep subprocess alive
for idx, process in enumerate(self.processes):
if process.is_alive(timeout=self.config.timeout_worker_healthcheck):
continue
process.kill() # process is hung, kill it
process.join()
if self.should_exit.is_set():
return # pragma: full coverage
logger.info(f"Child process [{process.pid}] died")
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes[idx] = process
def handle_signals(self) -> None:
for sig in tuple(self.signal_queue):
self.signal_queue.remove(sig)
sig_name = SIGNALS[sig]
sig_handler = getattr(self, f"handle_{sig_name.lower()}", None)
if sig_handler is not None:
sig_handler()
else: # pragma: no cover
logger.debug(f"Received signal {sig_name}, but no handler is defined for it.")
def handle_int(self) -> None:
logger.info("Received SIGINT, exiting.")
self.should_exit.set()
def handle_term(self) -> None:
logger.info("Received SIGTERM, exiting.")
self.should_exit.set()
def handle_break(self) -> None: # pragma: py-not-win32
logger.info("Received SIGBREAK, exiting.")
self.should_exit.set()
def handle_hup(self) -> None: # pragma: py-win32
logger.info("Received SIGHUP, restarting processes.")
self.restart_all()
def handle_ttin(self) -> None: # pragma: py-win32
logger.info("Received SIGTTIN, increasing the number of processes.")
self.processes_num += 1
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes.append(process)
def handle_ttou(self) -> None: # pragma: py-win32
logger.info("Received SIGTTOU, decreasing number of processes.")
if self.processes_num <= 1:
logger.info("Already reached one process, cannot decrease the number of processes anymore.")
return
self.processes_num -= 1
process = self.processes.pop()
process.terminate()
process.join()

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import logging
from collections.abc import Iterator
from pathlib import Path
from socket import socket
from typing import Callable, Dict, Iterator, List, Optional
from typing import Callable
from uvicorn.config import Config
from uvicorn.supervisors.basereload import BaseReload
@@ -13,20 +16,17 @@ class StatReload(BaseReload):
def __init__(
self,
config: Config,
target: Callable[[Optional[List[socket]]], None],
sockets: List[socket],
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
super().__init__(config, target, sockets)
self.reloader_name = "StatReload"
self.mtimes: Dict[Path, float] = {}
self.mtimes: dict[Path, float] = {}
if config.reload_excludes or config.reload_includes:
logger.warning(
"--reload-include and --reload-exclude have no effect unless "
"watchfiles is installed."
)
logger.warning("--reload-include and --reload-exclude have no effect unless watchfiles is installed.")
def should_restart(self) -> Optional[List[Path]]:
def should_restart(self) -> list[Path] | None:
self.pause()
for file in self.iter_py_files():

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from pathlib import Path
from socket import socket
from typing import Callable, List, Optional
from typing import Callable
from watchfiles import watch
@@ -11,20 +13,12 @@ from uvicorn.supervisors.basereload import BaseReload
class FileFilter:
def __init__(self, config: Config):
default_includes = ["*.py"]
self.includes = [
default
for default in default_includes
if default not in config.reload_excludes
]
self.includes = [default for default in default_includes if default not in config.reload_excludes]
self.includes.extend(config.reload_includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
self.excludes = [
default
for default in default_excludes
if default not in config.reload_includes
]
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
self.exclude_dirs = []
for e in config.reload_excludes:
p = Path(e)
@@ -37,19 +31,22 @@ class FileFilter:
if is_dir:
self.exclude_dirs.append(p)
else:
self.excludes.append(e)
self.excludes.append(e) # pragma: full coverage
self.excludes = list(set(self.excludes))
def __call__(self, path: Path) -> bool:
for include_pattern in self.includes:
if path.match(include_pattern):
if str(path).endswith(include_pattern):
return True # pragma: full coverage
for exclude_dir in self.exclude_dirs:
if exclude_dir in path.parents:
return False
for exclude_pattern in self.excludes:
if path.match(exclude_pattern):
return False
return False # pragma: full coverage
return True
return False
@@ -59,17 +56,14 @@ class WatchFilesReload(BaseReload):
def __init__(
self,
config: Config,
target: Callable[[Optional[List[socket]]], None],
sockets: List[socket],
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
super().__init__(config, target, sockets)
self.reloader_name = "WatchFiles"
self.reload_dirs = []
for directory in config.reload_dirs:
if Path.cwd() not in directory.parents:
self.reload_dirs.append(directory)
if Path.cwd() not in self.reload_dirs:
self.reload_dirs.append(Path.cwd())
self.reload_dirs.append(directory)
self.watch_filter = FileFilter(config)
self.watcher = watch(
@@ -81,7 +75,7 @@ class WatchFilesReload(BaseReload):
yield_on_timeout=True,
)
def should_restart(self) -> Optional[List[Path]]:
def should_restart(self) -> list[Path] | None:
self.pause()
changes = next(self.watcher)

View File

@@ -1,158 +0,0 @@
import logging
import warnings
from pathlib import Path
from socket import socket
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from watchgod import DefaultWatcher
from uvicorn.config import Config
from uvicorn.supervisors.basereload import BaseReload
if TYPE_CHECKING:
import os
DirEntry = os.DirEntry[str]
logger = logging.getLogger("uvicorn.error")
class CustomWatcher(DefaultWatcher):
def __init__(self, root_path: Path, config: Config):
default_includes = ["*.py"]
self.includes = [
default
for default in default_includes
if default not in config.reload_excludes
]
self.includes.extend(config.reload_includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
self.excludes = [
default
for default in default_excludes
if default not in config.reload_includes
]
self.excludes.extend(config.reload_excludes)
self.excludes = list(set(self.excludes))
self.watched_dirs: Dict[str, bool] = {}
self.watched_files: Dict[str, bool] = {}
self.dirs_includes = set(config.reload_dirs)
self.dirs_excludes = set(config.reload_dirs_excludes)
self.resolved_root = root_path
super().__init__(str(root_path))
def should_watch_file(self, entry: "DirEntry") -> bool:
cached_result = self.watched_files.get(entry.path)
if cached_result is not None:
return cached_result
entry_path = Path(entry)
# cwd is not verified through should_watch_dir, so we need to verify here
if entry_path.parent == Path.cwd() and Path.cwd() not in self.dirs_includes:
self.watched_files[entry.path] = False
return False
for include_pattern in self.includes:
if entry_path.match(include_pattern):
for exclude_pattern in self.excludes:
if entry_path.match(exclude_pattern):
self.watched_files[entry.path] = False
return False
self.watched_files[entry.path] = True
return True
self.watched_files[entry.path] = False
return False
def should_watch_dir(self, entry: "DirEntry") -> bool:
cached_result = self.watched_dirs.get(entry.path)
if cached_result is not None:
return cached_result
entry_path = Path(entry)
if entry_path in self.dirs_excludes:
self.watched_dirs[entry.path] = False
return False
for exclude_pattern in self.excludes:
if entry_path.match(exclude_pattern):
is_watched = False
if entry_path in self.dirs_includes:
is_watched = True
for directory in self.dirs_includes:
if directory in entry_path.parents:
is_watched = True
if is_watched:
logger.debug(
"WatchGodReload detected a new excluded dir '%s' in '%s'; "
"Adding to exclude list.",
entry_path.relative_to(self.resolved_root),
str(self.resolved_root),
)
self.watched_dirs[entry.path] = False
self.dirs_excludes.add(entry_path)
return False
if entry_path in self.dirs_includes:
self.watched_dirs[entry.path] = True
return True
for directory in self.dirs_includes:
if directory in entry_path.parents:
self.watched_dirs[entry.path] = True
return True
for include_pattern in self.includes:
if entry_path.match(include_pattern):
logger.info(
"WatchGodReload detected a new reload dir '%s' in '%s'; "
"Adding to watch list.",
str(entry_path.relative_to(self.resolved_root)),
str(self.resolved_root),
)
self.dirs_includes.add(entry_path)
self.watched_dirs[entry.path] = True
return True
self.watched_dirs[entry.path] = False
return False
class WatchGodReload(BaseReload):
def __init__(
self,
config: Config,
target: Callable[[Optional[List[socket]]], None],
sockets: List[socket],
) -> None:
warnings.warn(
'"watchgod" is deprecated, you should switch '
"to watchfiles (`pip install watchfiles`).",
DeprecationWarning,
)
super().__init__(config, target, sockets)
self.reloader_name = "WatchGod"
self.watchers = []
reload_dirs = []
for directory in config.reload_dirs:
if Path.cwd() not in directory.parents:
reload_dirs.append(directory)
if Path.cwd() not in reload_dirs:
reload_dirs.append(Path.cwd())
for w in reload_dirs:
self.watchers.append(CustomWatcher(w.resolve(), self.config))
def should_restart(self) -> Optional[List[Path]]:
self.pause()
for watcher in self.watchers:
change = watcher.check()
if change != set():
return list({Path(c[1]) for c in change})
return None

View File

@@ -1,14 +1,24 @@
from __future__ import annotations
import asyncio
import logging
import signal
import sys
from typing import Any, Dict
import warnings
from typing import Any
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker
from uvicorn._compat import asyncio_run
from uvicorn.config import Config
from uvicorn.main import Server
from uvicorn.server import Server
warnings.warn(
"The `uvicorn.workers` module is deprecated. Please use `uvicorn-worker` package instead.\n"
"For more details, see https://github.com/Kludex/uvicorn-worker.",
DeprecationWarning,
)
class UvicornWorker(Worker):
@@ -17,10 +27,10 @@ class UvicornWorker(Worker):
rather than a WSGI callable.
"""
CONFIG_KWARGS: Dict[str, Any] = {"loop": "auto", "http": "auto"}
CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"}
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(UvicornWorker, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
logger = logging.getLogger("uvicorn.error")
logger.handlers = self.log.error_log.handlers
@@ -61,14 +71,10 @@ class UvicornWorker(Worker):
self.config = Config(**config_kwargs)
def init_process(self) -> None:
self.config.setup_event_loop()
super(UvicornWorker, self).init_process()
def init_signals(self) -> None:
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# other signals are set up by Server.install_signal_handlers()
# See: https://github.com/encode/uvicorn/issues/894
# See: https://github.com/Kludex/uvicorn/issues/894
for s in self.SIGNALS:
signal.signal(s, signal.SIG_DFL)
@@ -79,7 +85,7 @@ class UvicornWorker(Worker):
def _install_sigquit_handler(self) -> None:
"""Install a SIGQUIT handler on workers.
- https://github.com/encode/uvicorn/issues/1116
- https://github.com/Kludex/uvicorn/issues/1116
- https://github.com/benoitc/gunicorn/issues/2604
"""
@@ -95,7 +101,7 @@ class UvicornWorker(Worker):
sys.exit(Arbiter.WORKER_BOOT_ERROR)
def run(self) -> None:
return asyncio.run(self._serve())
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())
async def callback_notify(self) -> None:
self.notify()