This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.main import Server, main, run
|
||||
|
||||
__version__ = "0.24.0"
|
||||
__version__ = "0.37.0"
|
||||
__all__ = ["main", "run", "Config", "Server"]
|
||||
|
||||
84
venv/lib/python3.12/site-packages/uvicorn/_compat.py
Normal file
84
venv/lib/python3.12/site-packages/uvicorn/_compat.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
asyncio_run = asyncio.run
|
||||
elif sys.version_info >= (3, 11):
|
||||
|
||||
def asyncio_run(
|
||||
main: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
debug: bool = False,
|
||||
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
|
||||
) -> _T:
|
||||
# asyncio.run from Python 3.12
|
||||
# https://docs.python.org/3/license.html#psf-license
|
||||
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
|
||||
return runner.run(main)
|
||||
|
||||
else:
|
||||
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
|
||||
# https://docs.python.org/3/license.html#psf-license
|
||||
def asyncio_run(
|
||||
main: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
debug: bool = False,
|
||||
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
|
||||
) -> _T:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("asyncio.run() cannot be called from a running event loop")
|
||||
|
||||
if not asyncio.iscoroutine(main):
|
||||
raise ValueError(f"a coroutine was expected, got {main!r}")
|
||||
|
||||
if loop_factory is None:
|
||||
loop = asyncio.new_event_loop()
|
||||
else:
|
||||
loop = loop_factory()
|
||||
try:
|
||||
if loop_factory is None:
|
||||
asyncio.set_event_loop(loop)
|
||||
if debug is not None:
|
||||
loop.set_debug(debug)
|
||||
return loop.run_until_complete(main)
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.run_until_complete(loop.shutdown_default_executor())
|
||||
finally:
|
||||
if loop_factory is None:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
to_cancel = asyncio.all_tasks(loop)
|
||||
if not to_cancel:
|
||||
return
|
||||
|
||||
for task in to_cancel:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
|
||||
|
||||
for task in to_cancel:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during asyncio.run() shutdown",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
@@ -2,12 +2,15 @@
|
||||
Some light wrappers around Python's multiprocessing, to deal with cleanly
|
||||
starting child processes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from socket import socket
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
@@ -18,7 +21,7 @@ spawn = multiprocessing.get_context("spawn")
|
||||
def get_subprocess(
|
||||
config: Config,
|
||||
target: Callable[..., None],
|
||||
sockets: List[socket],
|
||||
sockets: list[socket],
|
||||
) -> SpawnProcess:
|
||||
"""
|
||||
Called in the parent process, to instantiate a new child process instance.
|
||||
@@ -32,10 +35,10 @@ def get_subprocess(
|
||||
"""
|
||||
# We pass across the stdin fileno, and reopen it in the child process.
|
||||
# This is required for some debugging environments.
|
||||
stdin_fileno: Optional[int]
|
||||
try:
|
||||
stdin_fileno = sys.stdin.fileno()
|
||||
except OSError:
|
||||
# The `sys.stdin` can be `None`, see https://docs.python.org/3/library/sys.html#sys.__stdin__.
|
||||
except (AttributeError, OSError):
|
||||
stdin_fileno = None
|
||||
|
||||
kwargs = {
|
||||
@@ -51,8 +54,8 @@ def get_subprocess(
|
||||
def subprocess_started(
|
||||
config: Config,
|
||||
target: Callable[..., None],
|
||||
sockets: List[socket],
|
||||
stdin_fileno: Optional[int],
|
||||
sockets: list[socket],
|
||||
stdin_fileno: int | None,
|
||||
) -> None:
|
||||
"""
|
||||
Called when the child process starts.
|
||||
@@ -67,10 +70,15 @@ def subprocess_started(
|
||||
"""
|
||||
# Re-open stdin.
|
||||
if stdin_fileno is not None:
|
||||
sys.stdin = os.fdopen(stdin_fileno)
|
||||
sys.stdin = os.fdopen(stdin_fileno) # pragma: full coverage
|
||||
|
||||
# Logging needs to be setup again for each child.
|
||||
config.configure_logging()
|
||||
|
||||
# Now we can call into `Server.run(sockets=sockets)`
|
||||
target(sockets=sockets)
|
||||
try:
|
||||
# Now we can call into `Server.run(sockets=sockets)`
|
||||
target(sockets=sockets)
|
||||
except KeyboardInterrupt: # pragma: no cover
|
||||
# supress the exception to avoid a traceback from subprocess.Popen
|
||||
# the parent already expects us to end, so no vital information is lost
|
||||
pass
|
||||
|
||||
@@ -28,25 +28,12 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: py-lt-38
|
||||
from typing import Literal, Protocol, TypedDict
|
||||
else: # pragma: py-gte-38
|
||||
from typing_extensions import Literal, Protocol, TypedDict
|
||||
from collections.abc import Awaitable, Iterable, MutableMapping
|
||||
from typing import Any, Callable, Literal, Optional, Protocol, TypedDict, Union
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
from typing import NotRequired
|
||||
@@ -55,15 +42,15 @@ else: # pragma: py-gte-311
|
||||
|
||||
# WSGI
|
||||
Environ = MutableMapping[str, Any]
|
||||
ExcInfo = Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]]
|
||||
StartResponse = Callable[[str, Iterable[Tuple[str, str]], Optional[ExcInfo]], None]
|
||||
ExcInfo = tuple[type[BaseException], BaseException, Optional[types.TracebackType]]
|
||||
StartResponse = Callable[[str, Iterable[tuple[str, str]], Optional[ExcInfo]], None]
|
||||
WSGIApp = Callable[[Environ, StartResponse], Union[Iterable[bytes], BaseException]]
|
||||
|
||||
|
||||
# ASGI
|
||||
class ASGIVersions(TypedDict):
|
||||
spec_version: str
|
||||
version: Union[Literal["2.0"], Literal["3.0"]]
|
||||
version: Literal["2.0"] | Literal["3.0"]
|
||||
|
||||
|
||||
class HTTPScope(TypedDict):
|
||||
@@ -76,11 +63,11 @@ class HTTPScope(TypedDict):
|
||||
raw_path: bytes
|
||||
query_string: bytes
|
||||
root_path: str
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
client: Optional[Tuple[str, int]]
|
||||
server: Optional[Tuple[str, Optional[int]]]
|
||||
state: NotRequired[Dict[str, Any]]
|
||||
extensions: NotRequired[Dict[str, Dict[object, object]]]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
client: tuple[str, int] | None
|
||||
server: tuple[str, int | None] | None
|
||||
state: NotRequired[dict[str, Any]]
|
||||
extensions: NotRequired[dict[str, dict[object, object]]]
|
||||
|
||||
|
||||
class WebSocketScope(TypedDict):
|
||||
@@ -92,18 +79,18 @@ class WebSocketScope(TypedDict):
|
||||
raw_path: bytes
|
||||
query_string: bytes
|
||||
root_path: str
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
client: Optional[Tuple[str, int]]
|
||||
server: Optional[Tuple[str, Optional[int]]]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
client: tuple[str, int] | None
|
||||
server: tuple[str, int | None] | None
|
||||
subprotocols: Iterable[str]
|
||||
state: NotRequired[Dict[str, Any]]
|
||||
extensions: NotRequired[Dict[str, Dict[object, object]]]
|
||||
state: NotRequired[dict[str, Any]]
|
||||
extensions: NotRequired[dict[str, dict[object, object]]]
|
||||
|
||||
|
||||
class LifespanScope(TypedDict):
|
||||
type: Literal["lifespan"]
|
||||
asgi: ASGIVersions
|
||||
state: NotRequired[Dict[str, Any]]
|
||||
state: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
WWWScope = Union[HTTPScope, WebSocketScope]
|
||||
@@ -118,32 +105,32 @@ class HTTPRequestEvent(TypedDict):
|
||||
|
||||
class HTTPResponseDebugEvent(TypedDict):
|
||||
type: Literal["http.response.debug"]
|
||||
info: Dict[str, object]
|
||||
info: dict[str, object]
|
||||
|
||||
|
||||
class HTTPResponseStartEvent(TypedDict):
|
||||
type: Literal["http.response.start"]
|
||||
status: int
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
|
||||
trailers: NotRequired[bool]
|
||||
|
||||
|
||||
class HTTPResponseBodyEvent(TypedDict):
|
||||
type: Literal["http.response.body"]
|
||||
body: bytes
|
||||
more_body: bool
|
||||
more_body: NotRequired[bool]
|
||||
|
||||
|
||||
class HTTPResponseTrailersEvent(TypedDict):
|
||||
type: Literal["http.response.trailers"]
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
more_trailers: bool
|
||||
|
||||
|
||||
class HTTPServerPushEvent(TypedDict):
|
||||
type: Literal["http.response.push"]
|
||||
path: str
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class HTTPDisconnectEvent(TypedDict):
|
||||
@@ -156,43 +143,62 @@ class WebSocketConnectEvent(TypedDict):
|
||||
|
||||
class WebSocketAcceptEvent(TypedDict):
|
||||
type: Literal["websocket.accept"]
|
||||
subprotocol: Optional[str]
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
subprotocol: NotRequired[str | None]
|
||||
headers: NotRequired[Iterable[tuple[bytes, bytes]]]
|
||||
|
||||
|
||||
class WebSocketReceiveEvent(TypedDict):
|
||||
class _WebSocketReceiveEventBytes(TypedDict):
|
||||
type: Literal["websocket.receive"]
|
||||
bytes: Optional[bytes]
|
||||
text: Optional[str]
|
||||
bytes: bytes
|
||||
text: NotRequired[None]
|
||||
|
||||
|
||||
class WebSocketSendEvent(TypedDict):
|
||||
class _WebSocketReceiveEventText(TypedDict):
|
||||
type: Literal["websocket.receive"]
|
||||
bytes: NotRequired[None]
|
||||
text: str
|
||||
|
||||
|
||||
WebSocketReceiveEvent = Union[_WebSocketReceiveEventBytes, _WebSocketReceiveEventText]
|
||||
|
||||
|
||||
class _WebSocketSendEventBytes(TypedDict):
|
||||
type: Literal["websocket.send"]
|
||||
bytes: Optional[bytes]
|
||||
text: Optional[str]
|
||||
bytes: bytes
|
||||
text: NotRequired[None]
|
||||
|
||||
|
||||
class _WebSocketSendEventText(TypedDict):
|
||||
type: Literal["websocket.send"]
|
||||
bytes: NotRequired[None]
|
||||
text: str
|
||||
|
||||
|
||||
WebSocketSendEvent = Union[_WebSocketSendEventBytes, _WebSocketSendEventText]
|
||||
|
||||
|
||||
class WebSocketResponseStartEvent(TypedDict):
|
||||
type: Literal["websocket.http.response.start"]
|
||||
status: int
|
||||
headers: Iterable[Tuple[bytes, bytes]]
|
||||
headers: Iterable[tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class WebSocketResponseBodyEvent(TypedDict):
|
||||
type: Literal["websocket.http.response.body"]
|
||||
body: bytes
|
||||
more_body: bool
|
||||
more_body: NotRequired[bool]
|
||||
|
||||
|
||||
class WebSocketDisconnectEvent(TypedDict):
|
||||
type: Literal["websocket.disconnect"]
|
||||
code: int
|
||||
reason: NotRequired[str | None]
|
||||
|
||||
|
||||
class WebSocketCloseEvent(TypedDict):
|
||||
type: Literal["websocket.close"]
|
||||
code: int
|
||||
reason: Optional[str]
|
||||
code: NotRequired[int]
|
||||
reason: NotRequired[str | None]
|
||||
|
||||
|
||||
class LifespanStartupEvent(TypedDict):
|
||||
@@ -221,9 +227,7 @@ class LifespanShutdownFailedEvent(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
WebSocketEvent = Union[
|
||||
WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent
|
||||
]
|
||||
WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent]
|
||||
|
||||
|
||||
ASGIReceiveEvent = Union[
|
||||
@@ -260,16 +264,12 @@ ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]
|
||||
|
||||
|
||||
class ASGI2Protocol(Protocol):
|
||||
def __init__(self, scope: Scope) -> None:
|
||||
... # pragma: no cover
|
||||
def __init__(self, scope: Scope) -> None: ... # pragma: no cover
|
||||
|
||||
async def __call__(
|
||||
self, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
) -> None:
|
||||
... # pragma: no cover
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover
|
||||
|
||||
|
||||
ASGI2Application = Type[ASGI2Protocol]
|
||||
ASGI2Application = type[ASGI2Protocol]
|
||||
ASGI3Application = Callable[
|
||||
[
|
||||
Scope,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
@@ -7,19 +9,10 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
from collections.abc import Awaitable
|
||||
from configparser import RawConfigParser
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import IO, Any, Callable, Literal
|
||||
|
||||
import click
|
||||
|
||||
@@ -32,12 +25,12 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
from uvicorn.middleware.wsgi import WSGIMiddleware
|
||||
|
||||
HTTPProtocolType = Literal["auto", "h11", "httptools"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
|
||||
LifespanType = Literal["auto", "on", "off"]
|
||||
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
LOG_LEVELS: Dict[str, int] = {
|
||||
LOG_LEVELS: dict[str, int] = {
|
||||
"critical": logging.CRITICAL,
|
||||
"error": logging.ERROR,
|
||||
"warning": logging.WARNING,
|
||||
@@ -45,33 +38,34 @@ LOG_LEVELS: Dict[str, int] = {
|
||||
"debug": logging.DEBUG,
|
||||
"trace": TRACE_LOG_LEVEL,
|
||||
}
|
||||
HTTP_PROTOCOLS: Dict[HTTPProtocolType, str] = {
|
||||
HTTP_PROTOCOLS: dict[str, str] = {
|
||||
"auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol",
|
||||
"h11": "uvicorn.protocols.http.h11_impl:H11Protocol",
|
||||
"httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol",
|
||||
}
|
||||
WS_PROTOCOLS: Dict[WSProtocolType, Optional[str]] = {
|
||||
WS_PROTOCOLS: dict[str, str | None] = {
|
||||
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
|
||||
"none": None,
|
||||
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
|
||||
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
|
||||
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
|
||||
}
|
||||
LIFESPAN: Dict[LifespanType, str] = {
|
||||
LIFESPAN: dict[str, str] = {
|
||||
"auto": "uvicorn.lifespan.on:LifespanOn",
|
||||
"on": "uvicorn.lifespan.on:LifespanOn",
|
||||
"off": "uvicorn.lifespan.off:LifespanOff",
|
||||
}
|
||||
LOOP_SETUPS: Dict[LoopSetupType, Optional[str]] = {
|
||||
LOOP_FACTORIES: dict[str, str | None] = {
|
||||
"none": None,
|
||||
"auto": "uvicorn.loops.auto:auto_loop_setup",
|
||||
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
|
||||
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
|
||||
"auto": "uvicorn.loops.auto:auto_loop_factory",
|
||||
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
|
||||
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
|
||||
}
|
||||
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
|
||||
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER
|
||||
|
||||
LOGGING_CONFIG: Dict[str, Any] = {
|
||||
LOGGING_CONFIG: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
@@ -108,13 +102,13 @@ logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
def create_ssl_context(
|
||||
certfile: Union[str, os.PathLike],
|
||||
keyfile: Optional[Union[str, os.PathLike]],
|
||||
password: Optional[str],
|
||||
certfile: str | os.PathLike[str],
|
||||
keyfile: str | os.PathLike[str] | None,
|
||||
password: str | None,
|
||||
ssl_version: int,
|
||||
cert_reqs: int,
|
||||
ca_certs: Optional[Union[str, os.PathLike]],
|
||||
ciphers: Optional[str],
|
||||
ca_certs: str | os.PathLike[str] | None,
|
||||
ciphers: str | None,
|
||||
) -> ssl.SSLContext:
|
||||
ctx = ssl.SSLContext(ssl_version)
|
||||
get_password = (lambda: password) if password else None
|
||||
@@ -132,22 +126,20 @@ def is_dir(path: Path) -> bool:
|
||||
if not path.is_absolute():
|
||||
path = path.resolve()
|
||||
return path.is_dir()
|
||||
except OSError:
|
||||
except OSError: # pragma: full coverage
|
||||
return False
|
||||
|
||||
|
||||
def resolve_reload_patterns(
|
||||
patterns_list: List[str], directories_list: List[str]
|
||||
) -> Tuple[List[str], List[Path]]:
|
||||
directories: List[Path] = list(set(map(Path, directories_list.copy())))
|
||||
patterns: List[str] = patterns_list.copy()
|
||||
def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
|
||||
directories: list[Path] = list(set(map(Path, directories_list.copy())))
|
||||
patterns: list[str] = patterns_list.copy()
|
||||
|
||||
current_working_directory = Path.cwd()
|
||||
for pattern in patterns_list:
|
||||
# Special case for the .* pattern, otherwise this would only match
|
||||
# hidden directories which is probably undesired
|
||||
if pattern == ".*":
|
||||
continue
|
||||
continue # pragma: py-not-linux
|
||||
patterns.append(pattern)
|
||||
if is_dir(Path(pattern)):
|
||||
directories.append(Path(pattern))
|
||||
@@ -159,15 +151,13 @@ def resolve_reload_patterns(
|
||||
directories = list(set(directories))
|
||||
directories = list(map(Path, directories))
|
||||
directories = list(map(lambda x: x.resolve(), directories))
|
||||
directories = list(
|
||||
{reload_path for reload_path in directories if is_dir(reload_path)}
|
||||
)
|
||||
directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
|
||||
|
||||
children = []
|
||||
for j in range(len(directories)):
|
||||
for k in range(j + 1, len(directories)):
|
||||
for k in range(j + 1, len(directories)): # pragma: full coverage
|
||||
if directories[j] in directories[k].parents:
|
||||
children.append(directories[k]) # pragma: py-darwin
|
||||
children.append(directories[k])
|
||||
elif directories[k] in directories[j].parents:
|
||||
children.append(directories[j])
|
||||
|
||||
@@ -176,7 +166,7 @@ def resolve_reload_patterns(
|
||||
return list(set(patterns)), directories
|
||||
|
||||
|
||||
def _normalize_dirs(dirs: Union[List[str], str, None]) -> List[str]:
|
||||
def _normalize_dirs(dirs: list[str] | str | None) -> list[str]:
|
||||
if dirs is None:
|
||||
return []
|
||||
if isinstance(dirs, str):
|
||||
@@ -187,54 +177,55 @@ def _normalize_dirs(dirs: Union[List[str], str, None]) -> List[str]:
|
||||
class Config:
|
||||
def __init__(
|
||||
self,
|
||||
app: Union["ASGIApplication", Callable, str],
|
||||
app: ASGIApplication | Callable[..., Any] | str,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8000,
|
||||
uds: Optional[str] = None,
|
||||
fd: Optional[int] = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
http: Union[Type[asyncio.Protocol], HTTPProtocolType] = "auto",
|
||||
ws: Union[Type[asyncio.Protocol], WSProtocolType] = "auto",
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopFactoryType | str = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType | str = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType | str = "auto",
|
||||
ws_max_size: int = 16 * 1024 * 1024,
|
||||
ws_max_queue: int = 32,
|
||||
ws_ping_interval: Optional[float] = 20.0,
|
||||
ws_ping_timeout: Optional[float] = 20.0,
|
||||
ws_ping_interval: float | None = 20.0,
|
||||
ws_ping_timeout: float | None = 20.0,
|
||||
ws_per_message_deflate: bool = True,
|
||||
lifespan: LifespanType = "auto",
|
||||
env_file: Optional[Union[str, os.PathLike]] = None,
|
||||
log_config: Optional[Union[Dict[str, Any], str]] = LOGGING_CONFIG,
|
||||
log_level: Optional[Union[str, int]] = None,
|
||||
env_file: str | os.PathLike[str] | None = None,
|
||||
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
|
||||
log_level: str | int | None = None,
|
||||
access_log: bool = True,
|
||||
use_colors: Optional[bool] = None,
|
||||
use_colors: bool | None = None,
|
||||
interface: InterfaceType = "auto",
|
||||
reload: bool = False,
|
||||
reload_dirs: Optional[Union[List[str], str]] = None,
|
||||
reload_dirs: list[str] | str | None = None,
|
||||
reload_delay: float = 0.25,
|
||||
reload_includes: Optional[Union[List[str], str]] = None,
|
||||
reload_excludes: Optional[Union[List[str], str]] = None,
|
||||
workers: Optional[int] = None,
|
||||
reload_includes: list[str] | str | None = None,
|
||||
reload_excludes: list[str] | str | None = None,
|
||||
workers: int | None = None,
|
||||
proxy_headers: bool = True,
|
||||
server_header: bool = True,
|
||||
date_header: bool = True,
|
||||
forwarded_allow_ips: Optional[Union[List[str], str]] = None,
|
||||
forwarded_allow_ips: list[str] | str | None = None,
|
||||
root_path: str = "",
|
||||
limit_concurrency: Optional[int] = None,
|
||||
limit_max_requests: Optional[int] = None,
|
||||
limit_concurrency: int | None = None,
|
||||
limit_max_requests: int | None = None,
|
||||
backlog: int = 2048,
|
||||
timeout_keep_alive: int = 5,
|
||||
timeout_notify: int = 30,
|
||||
timeout_graceful_shutdown: Optional[int] = None,
|
||||
callback_notify: Optional[Callable[..., Awaitable[None]]] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[Union[str, os.PathLike]] = None,
|
||||
ssl_keyfile_password: Optional[str] = None,
|
||||
timeout_graceful_shutdown: int | None = None,
|
||||
timeout_worker_healthcheck: int = 5,
|
||||
callback_notify: Callable[..., Awaitable[None]] | None = None,
|
||||
ssl_keyfile: str | os.PathLike[str] | None = None,
|
||||
ssl_certfile: str | os.PathLike[str] | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
ssl_version: int = SSL_PROTOCOL_VERSION,
|
||||
ssl_cert_reqs: int = ssl.CERT_NONE,
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
ssl_ca_certs: str | os.PathLike[str] | None = None,
|
||||
ssl_ciphers: str = "TLSv1",
|
||||
headers: Optional[List[Tuple[str, str]]] = None,
|
||||
headers: list[tuple[str, str]] | None = None,
|
||||
factory: bool = False,
|
||||
h11_max_incomplete_event_size: Optional[int] = None,
|
||||
h11_max_incomplete_event_size: int | None = None,
|
||||
):
|
||||
self.app = app
|
||||
self.host = host
|
||||
@@ -268,6 +259,7 @@ class Config:
|
||||
self.timeout_keep_alive = timeout_keep_alive
|
||||
self.timeout_notify = timeout_notify
|
||||
self.timeout_graceful_shutdown = timeout_graceful_shutdown
|
||||
self.timeout_worker_healthcheck = timeout_worker_healthcheck
|
||||
self.callback_notify = callback_notify
|
||||
self.ssl_keyfile = ssl_keyfile
|
||||
self.ssl_certfile = ssl_certfile
|
||||
@@ -276,25 +268,22 @@ class Config:
|
||||
self.ssl_cert_reqs = ssl_cert_reqs
|
||||
self.ssl_ca_certs = ssl_ca_certs
|
||||
self.ssl_ciphers = ssl_ciphers
|
||||
self.headers: List[Tuple[str, str]] = headers or []
|
||||
self.encoded_headers: List[Tuple[bytes, bytes]] = []
|
||||
self.headers: list[tuple[str, str]] = headers or []
|
||||
self.encoded_headers: list[tuple[bytes, bytes]] = []
|
||||
self.factory = factory
|
||||
self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
|
||||
|
||||
self.loaded = False
|
||||
self.configure_logging()
|
||||
|
||||
self.reload_dirs: List[Path] = []
|
||||
self.reload_dirs_excludes: List[Path] = []
|
||||
self.reload_includes: List[str] = []
|
||||
self.reload_excludes: List[str] = []
|
||||
self.reload_dirs: list[Path] = []
|
||||
self.reload_dirs_excludes: list[Path] = []
|
||||
self.reload_includes: list[str] = []
|
||||
self.reload_excludes: list[str] = []
|
||||
|
||||
if (
|
||||
reload_dirs or reload_includes or reload_excludes
|
||||
) and not self.should_reload:
|
||||
if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
|
||||
logger.warning(
|
||||
"Current configuration will not reload as not all conditions are met, "
|
||||
"please refer to documentation."
|
||||
"Current configuration will not reload as not all conditions are met, please refer to documentation."
|
||||
)
|
||||
|
||||
if self.should_reload:
|
||||
@@ -302,30 +291,23 @@ class Config:
|
||||
reload_includes = _normalize_dirs(reload_includes)
|
||||
reload_excludes = _normalize_dirs(reload_excludes)
|
||||
|
||||
self.reload_includes, self.reload_dirs = resolve_reload_patterns(
|
||||
reload_includes, reload_dirs
|
||||
)
|
||||
self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
|
||||
|
||||
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(
|
||||
reload_excludes, []
|
||||
)
|
||||
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
|
||||
|
||||
reload_dirs_tmp = self.reload_dirs.copy()
|
||||
|
||||
for directory in self.reload_dirs_excludes:
|
||||
for reload_directory in reload_dirs_tmp:
|
||||
if (
|
||||
directory == reload_directory
|
||||
or directory in reload_directory.parents
|
||||
):
|
||||
if directory == reload_directory or directory in reload_directory.parents:
|
||||
try:
|
||||
self.reload_dirs.remove(reload_directory)
|
||||
except ValueError:
|
||||
except ValueError: # pragma: full coverage
|
||||
pass
|
||||
|
||||
for pattern in self.reload_excludes:
|
||||
if pattern in self.reload_includes:
|
||||
self.reload_includes.remove(pattern)
|
||||
self.reload_includes.remove(pattern) # pragma: full coverage
|
||||
|
||||
if not self.reload_dirs:
|
||||
if reload_dirs:
|
||||
@@ -334,7 +316,7 @@ class Config:
|
||||
+ "directories, watching current working directory.",
|
||||
reload_dirs,
|
||||
)
|
||||
self.reload_dirs = [Path(os.getcwd())]
|
||||
self.reload_dirs = [Path.cwd()]
|
||||
|
||||
logger.info(
|
||||
"Will watch for changes in these directories: %s",
|
||||
@@ -350,20 +332,18 @@ class Config:
|
||||
if workers is None and "WEB_CONCURRENCY" in os.environ:
|
||||
self.workers = int(os.environ["WEB_CONCURRENCY"])
|
||||
|
||||
self.forwarded_allow_ips: Union[List[str], str]
|
||||
self.forwarded_allow_ips: list[str] | str
|
||||
if forwarded_allow_ips is None:
|
||||
self.forwarded_allow_ips = os.environ.get(
|
||||
"FORWARDED_ALLOW_IPS", "127.0.0.1"
|
||||
)
|
||||
self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
|
||||
else:
|
||||
self.forwarded_allow_ips = forwarded_allow_ips
|
||||
self.forwarded_allow_ips = forwarded_allow_ips # pragma: full coverage
|
||||
|
||||
if self.reload and self.workers > 1:
|
||||
logger.warning('"workers" flag is ignored when reloading is enabled.')
|
||||
|
||||
@property
|
||||
def asgi_version(self) -> Literal["2.0", "3.0"]:
|
||||
mapping: Dict[str, Literal["2.0", "3.0"]] = {
|
||||
mapping: dict[str, Literal["2.0", "3.0"]] = {
|
||||
"asgi2": "2.0",
|
||||
"asgi3": "3.0",
|
||||
"wsgi": "3.0",
|
||||
@@ -384,18 +364,14 @@ class Config:
|
||||
if self.log_config is not None:
|
||||
if isinstance(self.log_config, dict):
|
||||
if self.use_colors in (True, False):
|
||||
self.log_config["formatters"]["default"][
|
||||
"use_colors"
|
||||
] = self.use_colors
|
||||
self.log_config["formatters"]["access"][
|
||||
"use_colors"
|
||||
] = self.use_colors
|
||||
self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
|
||||
self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
|
||||
logging.config.dictConfig(self.log_config)
|
||||
elif self.log_config.endswith(".json"):
|
||||
elif isinstance(self.log_config, str) and self.log_config.endswith(".json"):
|
||||
with open(self.log_config) as file:
|
||||
loaded_config = json.load(file)
|
||||
logging.config.dictConfig(loaded_config)
|
||||
elif self.log_config.endswith((".yaml", ".yml")):
|
||||
elif isinstance(self.log_config, str) and self.log_config.endswith((".yaml", ".yml")):
|
||||
# Install the PyYAML package or the uvicorn[standard] optional
|
||||
# dependencies to enable this functionality.
|
||||
import yaml
|
||||
@@ -406,9 +382,7 @@ class Config:
|
||||
else:
|
||||
# See the note about fileConfig() here:
|
||||
# https://docs.python.org/3/library/logging.config.html#configuration-file-format
|
||||
logging.config.fileConfig(
|
||||
self.log_config, disable_existing_loggers=False
|
||||
)
|
||||
logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
|
||||
|
||||
if self.log_level is not None:
|
||||
if isinstance(self.log_level, str):
|
||||
@@ -427,7 +401,7 @@ class Config:
|
||||
|
||||
if self.is_ssl:
|
||||
assert self.ssl_certfile
|
||||
self.ssl: Optional[ssl.SSLContext] = create_ssl_context(
|
||||
self.ssl: ssl.SSLContext | None = create_ssl_context(
|
||||
keyfile=self.ssl_keyfile,
|
||||
certfile=self.ssl_certfile,
|
||||
password=self.ssl_keyfile_password,
|
||||
@@ -439,10 +413,7 @@ class Config:
|
||||
else:
|
||||
self.ssl = None
|
||||
|
||||
encoded_headers = [
|
||||
(key.lower().encode("latin1"), value.encode("latin1"))
|
||||
for key, value in self.headers
|
||||
]
|
||||
encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
|
||||
self.encoded_headers = (
|
||||
[(b"server", b"uvicorn")] + encoded_headers
|
||||
if b"server" not in dict(encoded_headers) and self.server_header
|
||||
@@ -450,14 +421,14 @@ class Config:
|
||||
)
|
||||
|
||||
if isinstance(self.http, str):
|
||||
http_protocol_class = import_from_string(HTTP_PROTOCOLS[self.http])
|
||||
self.http_protocol_class: Type[asyncio.Protocol] = http_protocol_class
|
||||
http_protocol_class = import_from_string(HTTP_PROTOCOLS.get(self.http, self.http))
|
||||
self.http_protocol_class: type[asyncio.Protocol] = http_protocol_class
|
||||
else:
|
||||
self.http_protocol_class = self.http
|
||||
|
||||
if isinstance(self.ws, str):
|
||||
ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws])
|
||||
self.ws_protocol_class: Optional[Type[asyncio.Protocol]] = ws_protocol_class
|
||||
ws_protocol_class = import_from_string(WS_PROTOCOLS.get(self.ws, self.ws))
|
||||
self.ws_protocol_class: type[asyncio.Protocol] | None = ws_protocol_class
|
||||
else:
|
||||
self.ws_protocol_class = self.ws
|
||||
|
||||
@@ -478,18 +449,17 @@ class Config:
|
||||
else:
|
||||
if not self.factory:
|
||||
logger.warning(
|
||||
"ASGI app factory detected. Using it, "
|
||||
"but please consider setting the --factory flag explicitly."
|
||||
"ASGI app factory detected. Using it, but please consider setting the --factory flag explicitly."
|
||||
)
|
||||
|
||||
if self.interface == "auto":
|
||||
if inspect.isclass(self.loaded_app):
|
||||
use_asgi_3 = hasattr(self.loaded_app, "__await__")
|
||||
elif inspect.isfunction(self.loaded_app):
|
||||
use_asgi_3 = asyncio.iscoroutinefunction(self.loaded_app)
|
||||
use_asgi_3 = inspect.iscoroutinefunction(self.loaded_app)
|
||||
else:
|
||||
call = getattr(self.loaded_app, "__call__", None)
|
||||
use_asgi_3 = asyncio.iscoroutinefunction(call)
|
||||
use_asgi_3 = inspect.iscoroutinefunction(call)
|
||||
self.interface = "asgi3" if use_asgi_3 else "asgi2"
|
||||
|
||||
if self.interface == "wsgi":
|
||||
@@ -501,19 +471,32 @@ class Config:
|
||||
if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
|
||||
self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
|
||||
if self.proxy_headers:
|
||||
self.loaded_app = ProxyHeadersMiddleware(
|
||||
self.loaded_app, trusted_hosts=self.forwarded_allow_ips
|
||||
)
|
||||
self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
|
||||
|
||||
self.loaded = True
|
||||
|
||||
def setup_event_loop(self) -> None:
|
||||
loop_setup: Optional[Callable] = import_from_string(LOOP_SETUPS[self.loop])
|
||||
if loop_setup is not None:
|
||||
loop_setup(use_subprocess=self.use_subprocess)
|
||||
raise AttributeError(
|
||||
"The `setup_event_loop` method was replaced by `get_loop_factory` in uvicorn 0.36.0.\n"
|
||||
"None of those methods are supposed to be used directly. If you are doing it, please let me know here: "
|
||||
"https://github.com/Kludex/uvicorn/discussions/2706. Thank you, and sorry for the inconvenience."
|
||||
)
|
||||
|
||||
def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
|
||||
if self.loop in LOOP_FACTORIES:
|
||||
loop_factory: Callable[..., Any] | None = import_from_string(LOOP_FACTORIES[self.loop])
|
||||
else:
|
||||
try:
|
||||
return import_from_string(self.loop)
|
||||
except ImportFromStringError as exc:
|
||||
logger.error("Error loading custom loop setup function. %s" % exc)
|
||||
sys.exit(1)
|
||||
if loop_factory is None:
|
||||
return None
|
||||
return loop_factory(use_subprocess=self.use_subprocess)
|
||||
|
||||
def bind_socket(self) -> socket.socket:
|
||||
logger_args: List[Union[str, int]]
|
||||
logger_args: list[str | int]
|
||||
if self.uds: # pragma: py-win32
|
||||
path = self.uds
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
@@ -521,33 +504,25 @@ class Config:
|
||||
sock.bind(path)
|
||||
uds_perms = 0o666
|
||||
os.chmod(self.uds, uds_perms)
|
||||
except OSError as exc:
|
||||
except OSError as exc: # pragma: full coverage
|
||||
logger.error(exc)
|
||||
sys.exit(1)
|
||||
|
||||
message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
|
||||
sock_name_format = "%s"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(sock_name_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [self.uds]
|
||||
elif self.fd: # pragma: py-win32
|
||||
sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
|
||||
fd_name_format = "%s"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(fd_name_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [sock.getsockname()]
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
addr_format = "%s://%s:%d"
|
||||
|
||||
if self.host and ":" in self.host: # pragma: py-win32
|
||||
if self.host and ":" in self.host: # pragma: full coverage
|
||||
# It's an IPv6 address.
|
||||
family = socket.AF_INET6
|
||||
addr_format = "%s://[%s]:%d"
|
||||
@@ -556,16 +531,12 @@ class Config:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
try:
|
||||
sock.bind((self.host, self.port))
|
||||
except OSError as exc:
|
||||
except OSError as exc: # pragma: full coverage
|
||||
logger.error(exc)
|
||||
sys.exit(1)
|
||||
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(addr_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
protocol_name = "https" if self.is_ssl else "http"
|
||||
logger_args = [protocol_name, self.host, sock.getsockname()[1]]
|
||||
logger.info(message, *logger_args, extra={"color_message": color_message})
|
||||
|
||||
@@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any:
|
||||
|
||||
module_str, _, attrs_str = import_str.partition(":")
|
||||
if not module_str or not attrs_str:
|
||||
message = (
|
||||
'Import string "{import_str}" must be in format "<module>:<attribute>".'
|
||||
)
|
||||
message = 'Import string "{import_str}" must be in format "<module>:<attribute>".'
|
||||
raise ImportFromStringError(message.format(import_str=import_str))
|
||||
|
||||
try:
|
||||
@@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any:
|
||||
instance = getattr(instance, attr_str)
|
||||
except AttributeError:
|
||||
message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
|
||||
raise ImportFromStringError(
|
||||
message.format(attrs_str=attrs_str, module_str=module_str)
|
||||
)
|
||||
raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str))
|
||||
|
||||
return instance
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from uvicorn import Config
|
||||
|
||||
@@ -6,7 +8,7 @@ from uvicorn import Config
|
||||
class LifespanOff:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.should_exit = False
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn._types import (
|
||||
@@ -35,12 +37,12 @@ class LifespanOn:
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.startup_event = asyncio.Event()
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue()
|
||||
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
|
||||
self.error_occured = False
|
||||
self.startup_failed = False
|
||||
self.shutdown_failed = False
|
||||
self.should_exit = False
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
self.logger.info("Waiting for application startup.")
|
||||
@@ -48,7 +50,7 @@ class LifespanOn:
|
||||
loop = asyncio.get_event_loop()
|
||||
main_lifespan_task = loop.create_task(self.main()) # noqa: F841
|
||||
# Keep a hard reference to prevent garbage collection
|
||||
# See https://github.com/encode/uvicorn/pull/972
|
||||
# See https://github.com/Kludex/uvicorn/pull/972
|
||||
startup_event: LifespanStartupEvent = {"type": "lifespan.startup"}
|
||||
await self.receive_queue.put(startup_event)
|
||||
await self.startup_event.wait()
|
||||
@@ -67,9 +69,7 @@ class LifespanOn:
|
||||
await self.receive_queue.put(shutdown_event)
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
if self.shutdown_failed or (
|
||||
self.error_occured and self.config.lifespan == "on"
|
||||
):
|
||||
if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
|
||||
self.logger.error("Application shutdown failed. Exiting.")
|
||||
self.should_exit = True
|
||||
else:
|
||||
@@ -99,7 +99,7 @@ class LifespanOn:
|
||||
self.startup_event.set()
|
||||
self.shutdown_event.set()
|
||||
|
||||
async def send(self, message: "LifespanSendMessage") -> None:
|
||||
async def send(self, message: LifespanSendMessage) -> None:
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
@@ -133,5 +133,5 @@ class LifespanOn:
|
||||
if message.get("message"):
|
||||
self.logger.error(message["message"])
|
||||
|
||||
async def receive(self) -> "LifespanReceiveMessage":
|
||||
async def receive(self) -> LifespanReceiveMessage:
|
||||
return await self.receive_queue.get()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import logging
|
||||
import sys
|
||||
from copy import copy
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
|
||||
@@ -14,7 +16,7 @@ class ColourizedFormatter(logging.Formatter):
|
||||
A custom log formatter class that:
|
||||
|
||||
* Outputs the LOG_LEVEL with an appropriate color.
|
||||
* If a log call includes an `extras={"color_message": ...}` it will be used
|
||||
* If a log call includes an `extra={"color_message": ...}` it will be used
|
||||
for formatting the output, instead of the plain text message.
|
||||
"""
|
||||
|
||||
@@ -24,17 +26,15 @@ class ColourizedFormatter(logging.Formatter):
|
||||
logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
|
||||
logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
|
||||
logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
|
||||
logging.CRITICAL: lambda level_name: click.style(
|
||||
str(level_name), fg="bright_red"
|
||||
),
|
||||
logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fmt: Optional[str] = None,
|
||||
datefmt: Optional[str] = None,
|
||||
fmt: str | None = None,
|
||||
datefmt: str | None = None,
|
||||
style: Literal["%", "{", "$"] = "%",
|
||||
use_colors: Optional[bool] = None,
|
||||
use_colors: bool | None = None,
|
||||
):
|
||||
if use_colors in (True, False):
|
||||
self.use_colors = use_colors
|
||||
@@ -84,7 +84,7 @@ class AccessFormatter(ColourizedFormatter):
|
||||
status_phrase = http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
status_phrase = ""
|
||||
status_and_phrase = "%s %s" % (status_code, status_phrase)
|
||||
status_and_phrase = f"{status_code} {status_phrase}"
|
||||
if self.use_colors:
|
||||
|
||||
def default(code: int) -> str:
|
||||
@@ -104,7 +104,7 @@ class AccessFormatter(ColourizedFormatter):
|
||||
status_code,
|
||||
) = recordcopy.args # type: ignore[misc]
|
||||
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
|
||||
request_line = "%s %s HTTP/%s" % (method, full_path, http_version)
|
||||
request_line = f"{method} {full_path} HTTP/{http_version}"
|
||||
if self.use_colors:
|
||||
request_line = click.style(request_line, bold=True)
|
||||
recordcopy.__dict__.update(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def asyncio_setup(use_subprocess: bool = False) -> None:
|
||||
if sys.platform == "win32" and use_subprocess:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
if sys.platform == "win32" and not use_subprocess:
|
||||
return asyncio.ProactorEventLoop
|
||||
return asyncio.SelectorEventLoop
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
def auto_loop_setup(use_subprocess: bool = False) -> None:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
try:
|
||||
import uvloop # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
|
||||
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory
|
||||
|
||||
loop_setup(use_subprocess=use_subprocess)
|
||||
return loop_factory(use_subprocess=use_subprocess)
|
||||
else: # pragma: no cover
|
||||
from uvicorn.loops.uvloop import uvloop_setup
|
||||
from uvicorn.loops.uvloop import uvloop_loop_factory
|
||||
|
||||
uvloop_setup(use_subprocess=use_subprocess)
|
||||
return uvloop_loop_factory(use_subprocess=use_subprocess)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
|
||||
import uvloop
|
||||
|
||||
|
||||
def uvloop_setup(use_subprocess: bool = False) -> None:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
return uvloop.new_event_loop
|
||||
|
||||
@@ -1,41 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from configparser import RawConfigParser
|
||||
from typing import IO, Any, Callable, get_args
|
||||
|
||||
import click
|
||||
|
||||
import uvicorn
|
||||
from uvicorn._types import ASGIApplication
|
||||
from uvicorn.config import (
|
||||
HTTP_PROTOCOLS,
|
||||
INTERFACES,
|
||||
LIFESPAN,
|
||||
LOG_LEVELS,
|
||||
LOGGING_CONFIG,
|
||||
LOOP_SETUPS,
|
||||
SSL_PROTOCOL_VERSION,
|
||||
WS_PROTOCOLS,
|
||||
Config,
|
||||
HTTPProtocolType,
|
||||
InterfaceType,
|
||||
LifespanType,
|
||||
LoopSetupType,
|
||||
LoopFactoryType,
|
||||
WSProtocolType,
|
||||
)
|
||||
from uvicorn.server import Server, ServerState # noqa: F401 # Used to be defined here.
|
||||
from uvicorn.server import Server
|
||||
from uvicorn.supervisors import ChangeReload, Multiprocess
|
||||
|
||||
LEVEL_CHOICES = click.Choice(list(LOG_LEVELS.keys()))
|
||||
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
|
||||
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
|
||||
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
|
||||
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
|
||||
INTERFACE_CHOICES = click.Choice(INTERFACES)
|
||||
|
||||
|
||||
def _metavar_from_type(_type: Any) -> str:
|
||||
return f"[{'|'.join(key for key in get_args(_type) if key != 'none')}]"
|
||||
|
||||
|
||||
STARTUP_FAILURE = 3
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
@@ -45,12 +48,11 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
click.echo(
|
||||
"Running uvicorn %s with %s %s on %s"
|
||||
% (
|
||||
uvicorn.__version__,
|
||||
platform.python_implementation(),
|
||||
platform.python_version(),
|
||||
platform.system(),
|
||||
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( # noqa: UP032
|
||||
version=uvicorn.__version__,
|
||||
py_implementation=platform.python_implementation(),
|
||||
py_version=platform.python_version(),
|
||||
system=platform.system(),
|
||||
)
|
||||
)
|
||||
ctx.exit()
|
||||
@@ -73,16 +75,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.")
|
||||
@click.option(
|
||||
"--fd", type=int, default=None, help="Bind to socket from this file descriptor."
|
||||
)
|
||||
@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.")
|
||||
@click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.")
|
||||
@click.option(
|
||||
"--reload-dir",
|
||||
"reload_dirs",
|
||||
multiple=True,
|
||||
help="Set reload directories explicitly, instead of using the current working"
|
||||
" directory.",
|
||||
help="Set reload directories explicitly, instead of using the current working directory.",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.option(
|
||||
@@ -107,8 +106,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
type=float,
|
||||
default=0.25,
|
||||
show_default=True,
|
||||
help="Delay between previous and next check if application needs to be."
|
||||
" Defaults to 0.25s.",
|
||||
help="Delay between previous and next check if application needs to be. Defaults to 0.25s.",
|
||||
)
|
||||
@click.option(
|
||||
"--workers",
|
||||
@@ -119,21 +117,24 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
)
|
||||
@click.option(
|
||||
"--loop",
|
||||
type=LOOP_CHOICES,
|
||||
type=str,
|
||||
metavar=_metavar_from_type(LoopFactoryType),
|
||||
default="auto",
|
||||
help="Event loop implementation.",
|
||||
help="Event loop factory implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--http",
|
||||
type=HTTP_CHOICES,
|
||||
type=str,
|
||||
metavar=_metavar_from_type(HTTPProtocolType),
|
||||
default="auto",
|
||||
help="HTTP protocol implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws",
|
||||
type=WS_CHOICES,
|
||||
type=str,
|
||||
metavar=_metavar_from_type(WSProtocolType),
|
||||
default="auto",
|
||||
help="WebSocket protocol implementation.",
|
||||
show_default=True,
|
||||
@@ -156,14 +157,14 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--ws-ping-interval",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="WebSocket ping interval",
|
||||
help="WebSocket ping interval in seconds.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ws-ping-timeout",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="WebSocket ping timeout",
|
||||
help="WebSocket ping timeout in seconds.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
@@ -224,8 +225,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--proxy-headers/--no-proxy-headers",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to "
|
||||
"populate remote address info.",
|
||||
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For to populate url scheme and remote address info.",
|
||||
)
|
||||
@click.option(
|
||||
"--server-header/--no-server-header",
|
||||
@@ -243,8 +243,10 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--forwarded-allow-ips",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma separated list of IPs to trust with proxy headers. Defaults to"
|
||||
" the $FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'.",
|
||||
help="Comma separated list of IP Addresses, IP Networks, or literals "
|
||||
"(e.g. UNIX Socket path) to trust with proxy headers. Defaults to the "
|
||||
"$FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'. "
|
||||
"The literal '*' means trust everything.",
|
||||
)
|
||||
@click.option(
|
||||
"--root-path",
|
||||
@@ -256,8 +258,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--limit-concurrency",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of concurrent connections or tasks to allow, before issuing"
|
||||
" HTTP 503 responses.",
|
||||
help="Maximum number of concurrent connections or tasks to allow, before issuing HTTP 503 responses.",
|
||||
)
|
||||
@click.option(
|
||||
"--backlog",
|
||||
@@ -275,7 +276,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--timeout-keep-alive",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Close Keep-Alive connections if no new data is received within this timeout.",
|
||||
help="Close Keep-Alive connections if no new data is received within this timeout (in seconds).",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
@@ -285,8 +286,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
help="Maximum number of seconds to wait for graceful shutdown.",
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True
|
||||
"--timeout-worker-healthcheck",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of seconds to wait for a worker to respond to a healthcheck.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True)
|
||||
@click.option(
|
||||
"--ssl-certfile",
|
||||
type=str,
|
||||
@@ -370,9 +376,9 @@ def main(
|
||||
port: int,
|
||||
uds: str,
|
||||
fd: int,
|
||||
loop: LoopSetupType,
|
||||
http: HTTPProtocolType,
|
||||
ws: WSProtocolType,
|
||||
loop: LoopFactoryType | str,
|
||||
http: HTTPProtocolType | str,
|
||||
ws: WSProtocolType | str,
|
||||
ws_max_size: int,
|
||||
ws_max_queue: int,
|
||||
ws_ping_interval: float,
|
||||
@@ -381,9 +387,9 @@ def main(
|
||||
lifespan: LifespanType,
|
||||
interface: InterfaceType,
|
||||
reload: bool,
|
||||
reload_dirs: typing.List[str],
|
||||
reload_includes: typing.List[str],
|
||||
reload_excludes: typing.List[str],
|
||||
reload_dirs: list[str],
|
||||
reload_includes: list[str],
|
||||
reload_excludes: list[str],
|
||||
reload_delay: float,
|
||||
workers: int,
|
||||
env_file: str,
|
||||
@@ -399,7 +405,8 @@ def main(
|
||||
backlog: int,
|
||||
limit_max_requests: int,
|
||||
timeout_keep_alive: int,
|
||||
timeout_graceful_shutdown: typing.Optional[int],
|
||||
timeout_graceful_shutdown: int | None,
|
||||
timeout_worker_healthcheck: int,
|
||||
ssl_keyfile: str,
|
||||
ssl_certfile: str,
|
||||
ssl_keyfile_password: str,
|
||||
@@ -407,10 +414,10 @@ def main(
|
||||
ssl_cert_reqs: int,
|
||||
ssl_ca_certs: str,
|
||||
ssl_ciphers: str,
|
||||
headers: typing.List[str],
|
||||
headers: list[str],
|
||||
use_colors: bool,
|
||||
app_dir: str,
|
||||
h11_max_incomplete_event_size: typing.Optional[int],
|
||||
h11_max_incomplete_event_size: int | None,
|
||||
factory: bool,
|
||||
) -> None:
|
||||
run(
|
||||
@@ -449,6 +456,7 @@ def main(
|
||||
limit_max_requests=limit_max_requests,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
||||
timeout_worker_healthcheck=timeout_worker_healthcheck,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile_password=ssl_keyfile_password,
|
||||
@@ -465,56 +473,55 @@ def main(
|
||||
|
||||
|
||||
def run(
|
||||
app: typing.Union["ASGIApplication", typing.Callable, str],
|
||||
app: ASGIApplication | Callable[..., Any] | str,
|
||||
*,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8000,
|
||||
uds: typing.Optional[str] = None,
|
||||
fd: typing.Optional[int] = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
http: typing.Union[typing.Type[asyncio.Protocol], HTTPProtocolType] = "auto",
|
||||
ws: typing.Union[typing.Type[asyncio.Protocol], WSProtocolType] = "auto",
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopFactoryType | str = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType | str = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType | str = "auto",
|
||||
ws_max_size: int = 16777216,
|
||||
ws_max_queue: int = 32,
|
||||
ws_ping_interval: typing.Optional[float] = 20.0,
|
||||
ws_ping_timeout: typing.Optional[float] = 20.0,
|
||||
ws_ping_interval: float | None = 20.0,
|
||||
ws_ping_timeout: float | None = 20.0,
|
||||
ws_per_message_deflate: bool = True,
|
||||
lifespan: LifespanType = "auto",
|
||||
interface: InterfaceType = "auto",
|
||||
reload: bool = False,
|
||||
reload_dirs: typing.Optional[typing.Union[typing.List[str], str]] = None,
|
||||
reload_includes: typing.Optional[typing.Union[typing.List[str], str]] = None,
|
||||
reload_excludes: typing.Optional[typing.Union[typing.List[str], str]] = None,
|
||||
reload_dirs: list[str] | str | None = None,
|
||||
reload_includes: list[str] | str | None = None,
|
||||
reload_excludes: list[str] | str | None = None,
|
||||
reload_delay: float = 0.25,
|
||||
workers: typing.Optional[int] = None,
|
||||
env_file: typing.Optional[typing.Union[str, os.PathLike]] = None,
|
||||
log_config: typing.Optional[
|
||||
typing.Union[typing.Dict[str, typing.Any], str]
|
||||
] = LOGGING_CONFIG,
|
||||
log_level: typing.Optional[typing.Union[str, int]] = None,
|
||||
workers: int | None = None,
|
||||
env_file: str | os.PathLike[str] | None = None,
|
||||
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
|
||||
log_level: str | int | None = None,
|
||||
access_log: bool = True,
|
||||
proxy_headers: bool = True,
|
||||
server_header: bool = True,
|
||||
date_header: bool = True,
|
||||
forwarded_allow_ips: typing.Optional[typing.Union[typing.List[str], str]] = None,
|
||||
forwarded_allow_ips: list[str] | str | None = None,
|
||||
root_path: str = "",
|
||||
limit_concurrency: typing.Optional[int] = None,
|
||||
limit_concurrency: int | None = None,
|
||||
backlog: int = 2048,
|
||||
limit_max_requests: typing.Optional[int] = None,
|
||||
limit_max_requests: int | None = None,
|
||||
timeout_keep_alive: int = 5,
|
||||
timeout_graceful_shutdown: typing.Optional[int] = None,
|
||||
ssl_keyfile: typing.Optional[str] = None,
|
||||
ssl_certfile: typing.Optional[typing.Union[str, os.PathLike]] = None,
|
||||
ssl_keyfile_password: typing.Optional[str] = None,
|
||||
timeout_graceful_shutdown: int | None = None,
|
||||
timeout_worker_healthcheck: int = 5,
|
||||
ssl_keyfile: str | os.PathLike[str] | None = None,
|
||||
ssl_certfile: str | os.PathLike[str] | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
ssl_version: int = SSL_PROTOCOL_VERSION,
|
||||
ssl_cert_reqs: int = ssl.CERT_NONE,
|
||||
ssl_ca_certs: typing.Optional[str] = None,
|
||||
ssl_ca_certs: str | os.PathLike[str] | None = None,
|
||||
ssl_ciphers: str = "TLSv1",
|
||||
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
|
||||
use_colors: typing.Optional[bool] = None,
|
||||
app_dir: typing.Optional[str] = None,
|
||||
headers: list[tuple[str, str]] | None = None,
|
||||
use_colors: bool | None = None,
|
||||
app_dir: str | None = None,
|
||||
factory: bool = False,
|
||||
h11_max_incomplete_event_size: typing.Optional[int] = None,
|
||||
h11_max_incomplete_event_size: int | None = None,
|
||||
) -> None:
|
||||
if app_dir is not None:
|
||||
sys.path.insert(0, app_dir)
|
||||
@@ -555,6 +562,7 @@ def run(
|
||||
limit_max_requests=limit_max_requests,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
||||
timeout_worker_healthcheck=timeout_worker_healthcheck,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile_password=ssl_keyfile_password,
|
||||
@@ -571,26 +579,39 @@ def run(
|
||||
|
||||
if (config.reload or config.workers > 1) and not isinstance(app, str):
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.warning(
|
||||
"You must pass the application as an import string to enable 'reload' or "
|
||||
"'workers'."
|
||||
)
|
||||
logger.warning("You must pass the application as an import string to enable 'reload' or 'workers'.")
|
||||
sys.exit(1)
|
||||
|
||||
if config.should_reload:
|
||||
sock = config.bind_socket()
|
||||
ChangeReload(config, target=server.run, sockets=[sock]).run()
|
||||
elif config.workers > 1:
|
||||
sock = config.bind_socket()
|
||||
Multiprocess(config, target=server.run, sockets=[sock]).run()
|
||||
else:
|
||||
server.run()
|
||||
if config.uds and os.path.exists(config.uds):
|
||||
os.remove(config.uds) # pragma: py-win32
|
||||
try:
|
||||
if config.should_reload:
|
||||
sock = config.bind_socket()
|
||||
ChangeReload(config, target=server.run, sockets=[sock]).run()
|
||||
elif config.workers > 1:
|
||||
sock = config.bind_socket()
|
||||
Multiprocess(config, target=server.run, sockets=[sock]).run()
|
||||
else:
|
||||
server.run()
|
||||
except KeyboardInterrupt:
|
||||
pass # pragma: full coverage
|
||||
finally:
|
||||
if config.uds and os.path.exists(config.uds):
|
||||
os.remove(config.uds) # pragma: py-win32
|
||||
|
||||
if not server.started and not config.should_reload and config.workers == 1:
|
||||
sys.exit(STARTUP_FAILURE)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "ServerState":
|
||||
warnings.warn(
|
||||
"uvicorn.main.ServerState is deprecated, use uvicorn.server.ServerState instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
return ServerState
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pragma: no cover
|
||||
|
||||
@@ -10,8 +10,6 @@ class ASGI2Middleware:
|
||||
def __init__(self, app: "ASGI2Application"):
|
||||
self.app = app
|
||||
|
||||
async def __call__(
|
||||
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
||||
|
||||
@@ -1,84 +1,142 @@
|
||||
"""
|
||||
This middleware can be used when a known proxy is fronting the application,
|
||||
and is trusted to be properly setting the `X-Forwarded-Proto` and
|
||||
`X-Forwarded-For` headers with the connecting client information.
|
||||
from __future__ import annotations
|
||||
|
||||
Modifies the `client` and `scheme` information so that they reference
|
||||
the connecting client, rather that the connecting proxy.
|
||||
import ipaddress
|
||||
|
||||
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
|
||||
"""
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
HTTPScope,
|
||||
Scope,
|
||||
WebSocketScope,
|
||||
)
|
||||
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
|
||||
class ProxyHeadersMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: "ASGI3Application",
|
||||
trusted_hosts: Union[List[str], str] = "127.0.0.1",
|
||||
) -> None:
|
||||
"""Middleware for handling known proxy headers
|
||||
|
||||
This middleware can be used when a known proxy is fronting the application,
|
||||
and is trusted to be properly setting the `X-Forwarded-Proto` and
|
||||
`X-Forwarded-For` headers with the connecting client information.
|
||||
|
||||
Modifies the `client` and `scheme` information so that they reference
|
||||
the connecting client, rather that the connecting proxy.
|
||||
|
||||
References:
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
|
||||
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
|
||||
self.app = app
|
||||
if isinstance(trusted_hosts, str):
|
||||
self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
|
||||
else:
|
||||
self.trusted_hosts = set(trusted_hosts)
|
||||
self.always_trust = "*" in self.trusted_hosts
|
||||
self.trusted_hosts = _TrustedHosts(trusted_hosts)
|
||||
|
||||
def get_trusted_client_host(
|
||||
self, x_forwarded_for_hosts: List[str]
|
||||
) -> Optional[str]:
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
if scope["type"] == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
for host in reversed(x_forwarded_for_hosts):
|
||||
if host not in self.trusted_hosts:
|
||||
return host
|
||||
client_addr = scope.get("client")
|
||||
client_host = client_addr[0] if client_addr else None
|
||||
|
||||
return None
|
||||
if client_host in self.trusted_hosts:
|
||||
headers = dict(scope["headers"])
|
||||
|
||||
async def __call__(
|
||||
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
scope = cast(Union["HTTPScope", "WebSocketScope"], scope)
|
||||
client_addr: Optional[Tuple[str, int]] = scope.get("client")
|
||||
client_host = client_addr[0] if client_addr else None
|
||||
if b"x-forwarded-proto" in headers:
|
||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
|
||||
if self.always_trust or client_host in self.trusted_hosts:
|
||||
headers = dict(scope["headers"])
|
||||
|
||||
if b"x-forwarded-proto" in headers:
|
||||
# Determine if the incoming request was http or https based on
|
||||
# the X-Forwarded-Proto header.
|
||||
x_forwarded_proto = (
|
||||
headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
)
|
||||
if x_forwarded_proto in {"http", "https", "ws", "wss"}:
|
||||
if scope["type"] == "websocket":
|
||||
scope["scheme"] = (
|
||||
"wss" if x_forwarded_proto == "https" else "ws"
|
||||
)
|
||||
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
|
||||
else:
|
||||
scope["scheme"] = x_forwarded_proto
|
||||
|
||||
if b"x-forwarded-for" in headers:
|
||||
# Determine the client address from the last trusted IP in the
|
||||
# X-Forwarded-For header. We've lost the connecting client's port
|
||||
# information by now, so only include the host.
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
x_forwarded_for_hosts = [
|
||||
item.strip() for item in x_forwarded_for.split(",")
|
||||
]
|
||||
host = self.get_trusted_client_host(x_forwarded_for_hosts)
|
||||
if b"x-forwarded-for" in headers:
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
|
||||
|
||||
if host:
|
||||
# If the x-forwarded-for header is empty then host is an empty string.
|
||||
# Only set the client if we actually got something usable.
|
||||
# See: https://github.com/Kludex/uvicorn/issues/1068
|
||||
|
||||
# We've lost the connecting client's port information by now,
|
||||
# so only include the host.
|
||||
port = 0
|
||||
scope["client"] = (host, port) # type: ignore[arg-type]
|
||||
scope["client"] = (host, port)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _parse_raw_hosts(value: str) -> list[str]:
|
||||
return [item.strip() for item in value.split(",")]
|
||||
|
||||
|
||||
class _TrustedHosts:
|
||||
"""Container for trusted hosts and networks"""
|
||||
|
||||
def __init__(self, trusted_hosts: list[str] | str) -> None:
|
||||
self.always_trust: bool = trusted_hosts in ("*", ["*"])
|
||||
|
||||
self.trusted_literals: set[str] = set()
|
||||
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
|
||||
|
||||
# Notes:
|
||||
# - We separate hosts from literals as there are many ways to write
|
||||
# an IPv6 Address so we need to compare by object.
|
||||
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
|
||||
# it more efficient to do an address lookup in a set than check for
|
||||
# membership in each network.
|
||||
# - We still allow literals as it might be possible that we receive a
|
||||
# something that isn't an IP Address e.g. a unix socket.
|
||||
|
||||
if not self.always_trust:
|
||||
if isinstance(trusted_hosts, str):
|
||||
trusted_hosts = _parse_raw_hosts(trusted_hosts)
|
||||
|
||||
for host in trusted_hosts:
|
||||
# Note: because we always convert invalid IP types to literals it
|
||||
# is not possible for the user to know they provided a malformed IP
|
||||
# type - this may lead to unexpected / difficult to debug behaviour.
|
||||
|
||||
if "/" in host:
|
||||
# Looks like a network
|
||||
try:
|
||||
self.trusted_networks.add(ipaddress.ip_network(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Network
|
||||
self.trusted_literals.add(host)
|
||||
else:
|
||||
try:
|
||||
self.trusted_hosts.add(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
# Was not a valid IP Address
|
||||
self.trusted_literals.add(host)
|
||||
|
||||
def __contains__(self, host: str | None) -> bool:
|
||||
if self.always_trust:
|
||||
return True
|
||||
|
||||
if not host:
|
||||
return False
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if ip in self.trusted_hosts:
|
||||
return True
|
||||
return any(ip in net for net in self.trusted_networks)
|
||||
|
||||
except ValueError:
|
||||
return host in self.trusted_literals
|
||||
|
||||
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
|
||||
"""Extract the client host from x_forwarded_for header
|
||||
|
||||
In general this is the first "untrusted" host in the forwarded for list.
|
||||
"""
|
||||
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
|
||||
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
|
||||
# Note: each proxy appends to the header list so check it in reverse order
|
||||
for host in reversed(x_forwarded_for_hosts):
|
||||
if host not in self:
|
||||
return host
|
||||
|
||||
# All hosts are trusted meaning that the client was also a trusted proxy
|
||||
# See https://github.com/Kludex/uvicorn/issues/1068#issuecomment-855371576
|
||||
return x_forwarded_for_hosts[0]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import sys
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Deque, Iterable, Optional, Tuple
|
||||
from collections.abc import Iterable
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveCallable,
|
||||
@@ -22,16 +24,18 @@ from uvicorn._types import (
|
||||
)
|
||||
|
||||
|
||||
def build_environ(
|
||||
scope: "HTTPScope", message: "ASGIReceiveEvent", body: io.BytesIO
|
||||
) -> Environ:
|
||||
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
|
||||
"""
|
||||
Builds a scope and request message into a WSGI environ object.
|
||||
"""
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": "",
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
|
||||
"wsgi.version": (1, 0),
|
||||
@@ -78,8 +82,7 @@ def build_environ(
|
||||
class _WSGIMiddleware:
|
||||
def __init__(self, app: WSGIApp, workers: int = 10):
|
||||
warnings.warn(
|
||||
"Uvicorn's native WSGI implementation is deprecated, you "
|
||||
"should switch to a2wsgi (`pip install a2wsgi`).",
|
||||
"Uvicorn's native WSGI implementation is deprecated, you should switch to a2wsgi (`pip install a2wsgi`).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.app = app
|
||||
@@ -87,9 +90,9 @@ class _WSGIMiddleware:
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
receive: "ASGIReceiveCallable",
|
||||
send: "ASGISendCallable",
|
||||
scope: HTTPScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
assert scope["type"] == "http"
|
||||
instance = WSGIResponder(self.app, self.executor, scope)
|
||||
@@ -101,7 +104,7 @@ class WSGIResponder:
|
||||
self,
|
||||
app: WSGIApp,
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
):
|
||||
self.app = app
|
||||
self.executor = executor
|
||||
@@ -109,21 +112,19 @@ class WSGIResponder:
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.send_event = asyncio.Event()
|
||||
self.send_queue: Deque[Optional["ASGISendEvent"]] = deque()
|
||||
self.send_queue: deque[ASGISendEvent | None] = deque()
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self.response_started = False
|
||||
self.exc_info: Optional[ExcInfo] = None
|
||||
self.exc_info: ExcInfo | None = None
|
||||
|
||||
async def __call__(
|
||||
self, receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
|
||||
body = io.BytesIO(message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
if more_body:
|
||||
body.seek(0, io.SEEK_END)
|
||||
while more_body:
|
||||
body_message: "HTTPRequestEvent" = (
|
||||
body_message: HTTPRequestEvent = (
|
||||
await receive() # type: ignore[assignment]
|
||||
)
|
||||
body.write(body_message.get("body", b""))
|
||||
@@ -131,9 +132,7 @@ class WSGIResponder:
|
||||
body.seek(0)
|
||||
environ = build_environ(self.scope, message, body)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
wsgi = self.loop.run_in_executor(
|
||||
self.executor, self.wsgi, environ, self.start_response
|
||||
)
|
||||
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
|
||||
sender = self.loop.create_task(self.sender(send))
|
||||
try:
|
||||
await asyncio.wait_for(wsgi, None)
|
||||
@@ -144,7 +143,7 @@ class WSGIResponder:
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: "ASGISendCallable") -> None:
|
||||
async def sender(self, send: ASGISendCallable) -> None:
|
||||
while True:
|
||||
if self.send_queue:
|
||||
message = self.send_queue.popleft()
|
||||
@@ -158,18 +157,15 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: Iterable[Tuple[str, str]],
|
||||
exc_info: Optional[ExcInfo] = None,
|
||||
response_headers: Iterable[tuple[str, str]],
|
||||
exc_info: ExcInfo | None = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_str, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_str)
|
||||
headers = [
|
||||
(name.encode("ascii"), value.encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
|
||||
http_response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Type
|
||||
from __future__ import annotations
|
||||
|
||||
AutoHTTPProtocol: Type[asyncio.Protocol]
|
||||
import asyncio
|
||||
|
||||
AutoHTTPProtocol: type[asyncio.Protocol]
|
||||
try:
|
||||
import httptools # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
Scope,
|
||||
)
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
CLOSE_HEADER = (b"connection", b"close")
|
||||
|
||||
@@ -22,7 +16,7 @@ class FlowControl:
|
||||
self._is_writable_event.set()
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._is_writable_event.wait()
|
||||
await self._is_writable_event.wait() # pragma: full coverage
|
||||
|
||||
def pause_reading(self) -> None:
|
||||
if not self.read_paused:
|
||||
@@ -35,32 +29,26 @@ class FlowControl:
|
||||
self._transport.resume_reading()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if not self.write_paused:
|
||||
if not self.write_paused: # pragma: full coverage
|
||||
self.write_paused = True
|
||||
self._is_writable_event.clear()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self.write_paused:
|
||||
if self.write_paused: # pragma: full coverage
|
||||
self.write_paused = False
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
async def service_unavailable(
|
||||
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
response_start: "HTTPResponseStartEvent" = {
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
await send(response_start)
|
||||
|
||||
response_body: "HTTPResponseBodyEvent" = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Service Unavailable",
|
||||
"more_body": False,
|
||||
}
|
||||
await send(response_body)
|
||||
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"19"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, Callable, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
@@ -27,19 +20,8 @@ from uvicorn._types import (
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import (
|
||||
CLOSE_HEADER,
|
||||
HIGH_WATER_LIMIT,
|
||||
FlowControl,
|
||||
service_unavailable,
|
||||
)
|
||||
from uvicorn.protocols.utils import (
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
@@ -50,9 +32,7 @@ def _get_status_phrase(status_code: int) -> bytes:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {
|
||||
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
|
||||
}
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
@@ -60,8 +40,8 @@ class H11Protocol(asyncio.Protocol):
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: Dict[str, Any],
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
@@ -84,7 +64,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Shared server state
|
||||
@@ -95,13 +75,13 @@ class H11Protocol(asyncio.Protocol):
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: Optional[Tuple[str, int]] = None
|
||||
self.client: Optional[Tuple[str, int]] = None
|
||||
self.scheme: Optional[Literal["http", "https"]] = None
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
@@ -120,7 +100,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
@@ -153,7 +133,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> Optional[bytes]:
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
@@ -167,14 +147,24 @@ class H11Protocol(asyncio.Protocol):
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
if self.config.ws == "auto":
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
if upgrade == b"websocket" and self._should_upgrade_to_ws():
|
||||
return True
|
||||
if upgrade is not None:
|
||||
self._unsupported_upgrade_warning()
|
||||
return False
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
@@ -205,34 +195,31 @@ class H11Protocol(asyncio.Protocol):
|
||||
elif isinstance(event, h11.Request):
|
||||
self.headers = [(key.lower(), value) for key, value in event.headers]
|
||||
raw_path, _, query_string = event.target.partition(b"?")
|
||||
path = unquote(raw_path.decode("ascii"))
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"asgi": {
|
||||
"version": self.config.asgi_version,
|
||||
"spec_version": "2.3",
|
||||
},
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": event.http_version.decode("ascii"),
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"method": event.method.decode("ascii"),
|
||||
"root_path": self.root_path,
|
||||
"path": unquote(raw_path.decode("ascii")),
|
||||
"raw_path": raw_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
|
||||
upgrade = self._get_upgrade()
|
||||
if upgrade == b"websocket" and self._should_upgrade_to_ws():
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade(event)
|
||||
return
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency
|
||||
or len(self.tasks) >= self.limit_concurrency
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
@@ -240,6 +227,14 @@ class H11Protocol(asyncio.Protocol):
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
# When starting to process a request, disable the keep-alive
|
||||
# timeout. Normally we disable this when receiving data from
|
||||
# client and set back when finishing processing its request.
|
||||
# However, for pipelined requests processing finishes after
|
||||
# already receiving the next request and thus the timer may
|
||||
# be set here, which we don't want.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
conn=self.conn,
|
||||
@@ -271,9 +266,11 @@ class H11Protocol(asyncio.Protocol):
|
||||
continue
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
if self.conn.their_state == h11.MUST_CLOSE:
|
||||
break
|
||||
|
||||
def handle_websocket_upgrade(self, event: h11.Request) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
@@ -293,7 +290,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
reason = STATUS_PHRASES[400]
|
||||
headers: List[Tuple[bytes, bytes]] = [
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
]
|
||||
@@ -318,9 +315,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
@@ -345,13 +340,13 @@ class H11Protocol(asyncio.Protocol):
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing()
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing()
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
@@ -367,14 +362,14 @@ class H11Protocol(asyncio.Protocol):
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
conn: h11.Connection,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: List[Tuple[bytes, bytes]],
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
on_response: Callable[..., None],
|
||||
) -> None:
|
||||
@@ -403,7 +398,7 @@ class RequestResponseCycle:
|
||||
self.response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: "ASGI3Application") -> None:
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
@@ -432,7 +427,7 @@ class RequestResponseCycle:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
response_start_event: "HTTPResponseStartEvent" = {
|
||||
response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
@@ -441,7 +436,7 @@ class RequestResponseCycle:
|
||||
],
|
||||
}
|
||||
await self.send(response_start_event)
|
||||
response_body_event: "HTTPResponseBodyEvent" = {
|
||||
response_body_event: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Internal Server Error",
|
||||
"more_body": False,
|
||||
@@ -449,14 +444,14 @@ class RequestResponseCycle:
|
||||
await self.send(response_body_event)
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: "ASGISendEvent") -> None:
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain()
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
@@ -523,12 +518,10 @@ class RequestResponseCycle:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
async def receive(self) -> "ASGIReceiveEvent":
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
headers: List[Tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(
|
||||
status_code=100, headers=headers, reason="Continue"
|
||||
)
|
||||
headers: list[tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
|
||||
output = self.conn.send(event=event)
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
@@ -541,7 +534,7 @@ class RequestResponseCycle:
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: "HTTPRequestEvent" = {
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": self.body,
|
||||
"more_body": self.more_body,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
@@ -5,18 +7,7 @@ import re
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, Callable, Literal, cast
|
||||
|
||||
import httptools
|
||||
|
||||
@@ -24,31 +15,18 @@ from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPDisconnectEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import (
|
||||
CLOSE_HEADER,
|
||||
HIGH_WATER_LIMIT,
|
||||
FlowControl,
|
||||
service_unavailable,
|
||||
)
|
||||
from uvicorn.protocols.utils import (
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
|
||||
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
|
||||
|
||||
|
||||
def _get_status_line(status_code: int) -> bytes:
|
||||
@@ -59,9 +37,7 @@ def _get_status_line(status_code: int) -> bytes:
|
||||
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
|
||||
|
||||
|
||||
STATUS_LINE = {
|
||||
status_code: _get_status_line(status_code) for status_code in range(100, 600)
|
||||
}
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
@@ -69,8 +45,8 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: Dict[str, Any],
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
@@ -82,13 +58,21 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.parser = httptools.HttpRequestParser(self)
|
||||
|
||||
try:
|
||||
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
|
||||
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
|
||||
except AttributeError: # pragma: no cover
|
||||
# httptools < 0.6.3
|
||||
pass
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: Optional[TimerHandle] = None
|
||||
self.timeout_keep_alive_task: TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
@@ -99,14 +83,14 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: Optional[Tuple[str, int]] = None
|
||||
self.client: Optional[Tuple[str, int]] = None
|
||||
self.scheme: Optional[Literal["http", "https"]] = None
|
||||
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
@@ -126,7 +110,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
@@ -153,7 +137,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> Optional[bytes]:
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
@@ -163,21 +147,22 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None
|
||||
return None # pragma: full coverage
|
||||
|
||||
def _should_upgrade_to_ws(self, upgrade: Optional[bytes]) -> bool:
|
||||
if upgrade == b"websocket" and self.ws_protocol_class is not None:
|
||||
return True
|
||||
if self.config.ws == "auto":
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
self.logger.warning("Unsupported upgrade request.")
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
return False
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
return self._should_upgrade_to_ws(upgrade)
|
||||
return upgrade == b"websocket" and self._should_upgrade_to_ws()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
@@ -190,9 +175,10 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
except httptools.HttpParserUpgrade:
|
||||
upgrade = self._get_upgrade()
|
||||
if self._should_upgrade_to_ws(upgrade):
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade()
|
||||
else:
|
||||
self._unsupported_upgrade_warning()
|
||||
|
||||
def handle_websocket_upgrade(self) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
@@ -217,7 +203,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
content = [STATUS_LINE[400]]
|
||||
for name, value in self.server_state.default_headers:
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
|
||||
content.extend(
|
||||
[
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
@@ -269,14 +255,15 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
path = raw_path.decode("ascii")
|
||||
if "%" in path:
|
||||
path = urllib.parse.unquote(path)
|
||||
self.scope["path"] = path
|
||||
self.scope["raw_path"] = raw_path
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope["path"] = full_path
|
||||
self.scope["raw_path"] = full_raw_path
|
||||
self.scope["query_string"] = parsed_url.query or b""
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency
|
||||
or len(self.tasks) >= self.limit_concurrency
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
@@ -309,9 +296,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.pipeline.appendleft((self.cycle, app))
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
if (
|
||||
self.parser.should_upgrade() and self._should_upgrade()
|
||||
) or self.cycle.response_complete:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.body += body
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
@@ -319,9 +304,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
if (
|
||||
self.parser.should_upgrade() and self._should_upgrade()
|
||||
) or self.cycle.response_complete:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
@@ -333,22 +316,22 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events.
|
||||
# Unblock any pipelined events. If there are none, arm the
|
||||
# Keep-Alive timeout instead.
|
||||
if self.pipeline:
|
||||
cycle, app = self.pipeline.pop()
|
||||
task = self.loop.create_task(cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
@@ -363,13 +346,13 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing()
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing()
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
@@ -383,13 +366,13 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: List[Tuple[bytes, bytes]],
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_100_continue: bool,
|
||||
keep_alive: bool,
|
||||
@@ -417,11 +400,11 @@ class RequestResponseCycle:
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.chunked_encoding: Optional[bool] = None
|
||||
self.chunked_encoding: bool | None = None
|
||||
self.expected_content_length = 0
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: "ASGI3Application") -> None:
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
@@ -450,31 +433,28 @@ class RequestResponseCycle:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
response_start_event: "HTTPResponseStartEvent" = {
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
await self.send(response_start_event)
|
||||
response_body_event: "HTTPResponseBodyEvent" = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Internal Server Error",
|
||||
"more_body": False,
|
||||
}
|
||||
await self.send(response_body_event)
|
||||
await self.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"21"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: "ASGISendEvent") -> None:
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain()
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
@@ -507,7 +487,7 @@ class RequestResponseCycle:
|
||||
|
||||
for name, value in headers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name.")
|
||||
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value.")
|
||||
|
||||
@@ -522,11 +502,7 @@ class RequestResponseCycle:
|
||||
self.keep_alive = False
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if (
|
||||
self.chunked_encoding is None
|
||||
and self.scope["method"] != "HEAD"
|
||||
and status_code not in (204, 304)
|
||||
):
|
||||
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self.chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
@@ -577,7 +553,7 @@ class RequestResponseCycle:
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> "ASGIReceiveEvent":
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
self.waiting_for_100_continue = False
|
||||
@@ -587,15 +563,8 @@ class RequestResponseCycle:
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
message: "Union[HTTPDisconnectEvent, HTTPRequestEvent]"
|
||||
if self.disconnected or self.response_complete:
|
||||
message = {"type": "http.disconnect"}
|
||||
else:
|
||||
message = {
|
||||
"type": "http.request",
|
||||
"body": self.body,
|
||||
"more_body": self.more_body,
|
||||
}
|
||||
self.body = b""
|
||||
|
||||
return {"type": "http.disconnect"}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
|
||||
self.body = b""
|
||||
return message
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from uvicorn._types import WWWScope
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
|
||||
class ClientDisconnected(OSError): ...
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
try:
|
||||
@@ -22,7 +26,7 @@ def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
|
||||
return None
|
||||
|
||||
|
||||
def get_local_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]:
|
||||
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
info = socket_info.getsockname()
|
||||
@@ -38,17 +42,15 @@ def is_ssl(transport: asyncio.Transport) -> bool:
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
|
||||
def get_client_addr(scope: "WWWScope") -> str:
|
||||
def get_client_addr(scope: WWWScope) -> str:
|
||||
client = scope.get("client")
|
||||
if not client:
|
||||
return ""
|
||||
return "%s:%d" % client
|
||||
|
||||
|
||||
def get_path_with_query_string(scope: "WWWScope") -> str:
|
||||
def get_path_with_query_string(scope: WWWScope) -> str:
|
||||
path_with_query_string = urllib.parse.quote(scope["path"])
|
||||
if scope["query_string"]:
|
||||
path_with_query_string = "{}?{}".format(
|
||||
path_with_query_string, scope["query_string"].decode("ascii")
|
||||
)
|
||||
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
|
||||
return path_with_query_string
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
AutoWebSocketsProtocol: Callable[..., asyncio.Protocol] | None
|
||||
try:
|
||||
import websockets # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import websockets
|
||||
import websockets.legacy.handshake
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.extensions.base import ServerExtensionFactory
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.legacy.server import HTTPResponse
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
@@ -57,20 +57,21 @@ class Server:
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
extra_headers: List[Tuple[str, str]]
|
||||
extra_headers: list[tuple[str, str]]
|
||||
logger: logging.Logger | logging.LoggerAdapter[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: Dict[str, Any],
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
):
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
@@ -81,23 +82,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: Optional[Tuple[str, int]] = None
|
||||
self.client: Optional[Tuple[str, int]] = None
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# Connection events
|
||||
self.scope: WebSocketScope = None # type: ignore[assignment]
|
||||
self.scope: WebSocketScope
|
||||
self.handshake_started_event = asyncio.Event()
|
||||
self.handshake_completed_event = asyncio.Event()
|
||||
self.closed_event = asyncio.Event()
|
||||
self.initial_response: Optional[HTTPResponse] = None
|
||||
self.initial_response: HTTPResponse | None = None
|
||||
self.connect_sent = False
|
||||
self.lost_connection_before_handshake = False
|
||||
self.accepted_subprotocol: Optional[Subprotocol] = None
|
||||
self.accepted_subprotocol: Subprotocol | None = None
|
||||
|
||||
self.ws_server: Server = Server() # type: ignore[assignment]
|
||||
|
||||
extensions = []
|
||||
extensions: list[ServerExtensionFactory] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(ServerPerMessageDeflateFactory())
|
||||
|
||||
@@ -113,8 +114,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
)
|
||||
self.server_header = None
|
||||
self.extra_headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in server_state.default_headers
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
|
||||
]
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
@@ -132,16 +132,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
|
||||
super().connection_made(transport)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.lost_connection_before_handshake = (
|
||||
not self.handshake_completed_event.is_set()
|
||||
)
|
||||
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
|
||||
self.handshake_completed_event.set()
|
||||
super().connection_lost(exc)
|
||||
if exc is None:
|
||||
@@ -155,12 +153,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task) -> None:
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def process_request(
|
||||
self, path: str, headers: Headers
|
||||
) -> Optional[HTTPResponse]:
|
||||
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
|
||||
"""
|
||||
This hook is called to determine if the websocket should return
|
||||
an HTTP response and close.
|
||||
@@ -171,31 +167,35 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
path_portion, _, query_string = path.partition("?")
|
||||
|
||||
websockets.legacy.handshake.check_request(headers)
|
||||
websockets.legacy.handshake.check_request(request_headers)
|
||||
|
||||
subprotocols = []
|
||||
for header in headers.get_all("Sec-WebSocket-Protocol"):
|
||||
subprotocols: list[str] = []
|
||||
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
|
||||
subprotocols.extend([token.strip() for token in header.split(",")])
|
||||
|
||||
asgi_headers = [
|
||||
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for name, value in headers.raw_items()
|
||||
for name, value in request_headers.raw_items()
|
||||
]
|
||||
path = unquote(path_portion)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
|
||||
|
||||
self.scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": unquote(path_portion),
|
||||
"raw_path": path_portion.encode("ascii"),
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": asgi_headers,
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
@@ -204,8 +204,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
return self.initial_response
|
||||
|
||||
def process_subprotocol(
|
||||
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
|
||||
) -> Optional[Subprotocol]:
|
||||
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
We override the standard 'process_subprotocol' behavior here so that
|
||||
we return whatever subprotocol is sent in the 'accept' message.
|
||||
@@ -215,8 +215,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
def send_500_response(self) -> None:
|
||||
msg = b"Internal Server Error"
|
||||
content = [
|
||||
b"HTTP/1.1 500 Internal Server Error\r\n"
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"HTTP/1.1 500 Internal Server Error\r\ncontent-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
@@ -224,12 +223,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
]
|
||||
self.transport.write(b"".join(content))
|
||||
# Allow handler task to terminate cleanly, as websockets doesn't cancel it by
|
||||
# itself (see https://github.com/encode/uvicorn/issues/920)
|
||||
# itself (see https://github.com/Kludex/uvicorn/issues/920)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
async def ws_handler( # type: ignore[override]
|
||||
self, protocol: WebSocketServerProtocol, path: str
|
||||
) -> Any:
|
||||
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
|
||||
"""
|
||||
This is the main handler function for the 'websockets' implementation
|
||||
to call into. We just wait for close then return, and instead allow
|
||||
@@ -244,11 +241,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
termination states.
|
||||
"""
|
||||
try:
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
|
||||
except BaseException as exc:
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected: # pragma: full coverage
|
||||
self.closed_event.set()
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
self.transport.close()
|
||||
except BaseException:
|
||||
self.closed_event.set()
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.send_500_response()
|
||||
else:
|
||||
@@ -257,17 +256,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
else:
|
||||
self.closed_event.set()
|
||||
if not self.handshake_started_event.is_set():
|
||||
msg = "ASGI callable returned without sending handshake."
|
||||
self.logger.error(msg)
|
||||
self.logger.error("ASGI callable returned without sending handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: "ASGISendEvent") -> None:
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_started_event.is_set():
|
||||
@@ -275,13 +272,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
message = cast("WebSocketAcceptEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = None
|
||||
self.accepted_subprotocol = cast(
|
||||
Optional[Subprotocol], message.get("subprotocol")
|
||||
)
|
||||
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
|
||||
if "headers" in message:
|
||||
self.extra_headers.extend(
|
||||
# ASGI spec requires bytes
|
||||
@@ -295,53 +290,76 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
|
||||
self.handshake_started_event.set()
|
||||
self.closed_event.set()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast("WebSocketResponseStartEvent", message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
# websockets requires the status to be an enum. look it up.
|
||||
status = http.HTTPStatus(message["status"])
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
|
||||
]
|
||||
self.initial_response = (status, headers, b"")
|
||||
self.handshake_started_event.set()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
|
||||
"but got '%s'."
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close', "
|
||||
"or 'websocket.http.response.start' but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.closed_event.is_set():
|
||||
elif not self.closed_event.is_set() and self.initial_response is None:
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
if message_type == "websocket.send":
|
||||
message = cast("WebSocketSendEvent", message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast("WebSocketSendEvent", message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.send' or 'websocket.close',"
|
||||
" but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' or response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def asgi_receive(
|
||||
self,
|
||||
) -> Union[
|
||||
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
|
||||
]:
|
||||
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
self.connect_sent = True
|
||||
return {"type": "websocket.connect"}
|
||||
@@ -358,19 +376,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
|
||||
try:
|
||||
data = await self.recv()
|
||||
except ConnectionClosed as exc:
|
||||
except ConnectionClosed:
|
||||
self.closed_event.set()
|
||||
if self.ws_server.closing:
|
||||
return {"type": "websocket.disconnect", "code": 1012}
|
||||
return {"type": "websocket.disconnect", "code": exc.code}
|
||||
|
||||
msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item]
|
||||
"type": "websocket.receive"
|
||||
}
|
||||
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
|
||||
|
||||
if isinstance(data, str):
|
||||
msg["text"] = data
|
||||
else:
|
||||
msg["bytes"] = data
|
||||
|
||||
return msg
|
||||
return {"type": "websocket.receive", "text": data}
|
||||
return {"type": "websocket.receive", "bytes": data}
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from asyncio.transports import BaseTransport, Transport
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
from websockets.exceptions import InvalidState
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.http11 import Request
|
||||
from websockets.server import ServerProtocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: no cover
|
||||
from typing import assert_never
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import assert_never
|
||||
|
||||
|
||||
class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: no cover
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
|
||||
self.handshake_initiated = False
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None
|
||||
|
||||
extensions = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions = [
|
||||
ServerPerMessageDeflateFactory(
|
||||
server_max_window_bits=12,
|
||||
client_max_window_bits=12,
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
self.conn = ServerProtocol(
|
||||
extensions=extensions,
|
||||
max_size=self.config.ws_max_size,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when a connection is made."""
|
||||
transport = cast(Transport, transport)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
self.conn.send_close(1012)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.conn.receive_data(data)
|
||||
if self.conn.parser_exc is not None: # pragma: no cover
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events_received():
|
||||
if isinstance(event, Request):
|
||||
self.handle_connect(event)
|
||||
if isinstance(event, Frame):
|
||||
if event.opcode == Opcode.CONT:
|
||||
self.handle_cont(event) # pragma: no cover
|
||||
elif event.opcode == Opcode.TEXT:
|
||||
self.handle_text(event)
|
||||
elif event.opcode == Opcode.BINARY:
|
||||
self.handle_bytes(event)
|
||||
elif event.opcode == Opcode.PING:
|
||||
self.handle_ping()
|
||||
elif event.opcode == Opcode.PONG:
|
||||
pass # pragma: no cover
|
||||
elif event.opcode == Opcode.CLOSE:
|
||||
self.handle_close(event)
|
||||
else:
|
||||
assert_never(event.opcode) # pragma: no cover
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: Request) -> None:
|
||||
self.request = event
|
||||
self.response = self.conn.accept(event)
|
||||
self.handshake_initiated = True
|
||||
if self.response.status_code != 101:
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
return
|
||||
|
||||
headers = [
|
||||
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for key, value in event.headers.raw_items()
|
||||
]
|
||||
raw_path, _, query_string = event.path.partition("?")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": unquote(raw_path),
|
||||
"raw_path": raw_path.encode("ascii"),
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_cont(self, event: Frame) -> None: # pragma: no cover
|
||||
self.bytes += event.data
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_text(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_bytes(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.curr_msg_data_type = "bytes"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def send_receive_event_to_app(self) -> None:
|
||||
if self.curr_msg_data_type == "text":
|
||||
try:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.bytes.decode()})
|
||||
except UnicodeDecodeError: # pragma: no cover
|
||||
self.logger.exception("Invalid UTF-8 sequence received from client.")
|
||||
self.conn.send_close(1007)
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
else:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_ping(self) -> None:
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
def handle_close(self, event: Frame) -> None:
|
||||
if not self.close_sent and not self.transport.is_closing():
|
||||
assert self.conn.close_rcvd is not None
|
||||
code = self.conn.close_rcvd.code
|
||||
reason = self.conn.close_rcvd.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
|
||||
def handle_parser_exception(self) -> None: # pragma: no cover
|
||||
assert self.conn.close_sent is not None
|
||||
code = self.conn.close_sent.code
|
||||
reason = self.conn.close_sent.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send)
|
||||
except ClientDisconnected:
|
||||
self.transport.close() # pragma: no cover
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.initial_response or self.handshake_complete:
|
||||
return
|
||||
response = self.conn.reject(500, "Internal Server Error")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete and self.initial_response is None:
|
||||
if message_type == "websocket.accept":
|
||||
message = cast(WebSocketAcceptEvent, message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
|
||||
for name, value in (self.default_headers + list(message.get("headers", [])))
|
||||
]
|
||||
accepted_subprotocol = message.get("subprotocol")
|
||||
if accepted_subprotocol:
|
||||
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))
|
||||
self.response.headers.update(headers)
|
||||
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.handshake_complete = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
elif message_type == "websocket.http.response.start" and self.initial_response is None:
|
||||
message = cast(WebSocketResponseStartEvent, message)
|
||||
if not (100 <= message["status"] < 600):
|
||||
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in list(message.get("headers", []))
|
||||
]
|
||||
self.initial_response = (message["status"], headers, b"")
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and self.initial_response is None:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast(WebSocketSendEvent, message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
if text_data:
|
||||
self.conn.send_text(text_data.encode())
|
||||
elif bytes_data:
|
||||
self.conn.send_binary(bytes_data)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close" and not self.transport.is_closing():
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
self.conn.send_close(code, reason)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except InvalidState:
|
||||
raise ClientDisconnected()
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast(WebSocketResponseBodyEvent, message)
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
response = self.conn.reject(self.initial_response[0], body.decode())
|
||||
response.headers.update(self.initial_response[1])
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
else: # pragma: no cover
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
||||
@@ -1,27 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import wsproto
|
||||
from wsproto import ConnectionType, events
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.extensions import Extension, PerMessageDeflate
|
||||
from wsproto.utilities import RemoteProtocolError
|
||||
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
@@ -35,14 +40,14 @@ class WSProtocol(asyncio.Protocol):
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: typing.Dict[str, typing.Any],
|
||||
_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
config.load() # pragma: full coverage
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
@@ -55,15 +60,18 @@ class WSProtocol(asyncio.Protocol):
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: typing.Optional[typing.Tuple[str, int]] = None
|
||||
self.client: typing.Optional[typing.Tuple[str, int]] = None
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue()
|
||||
self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
|
||||
# Rejection state
|
||||
self.response_started = False
|
||||
|
||||
self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
|
||||
|
||||
self.read_paused = False
|
||||
@@ -89,7 +97,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
@@ -132,13 +140,13 @@ class WSProtocol(asyncio.Protocol):
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.writable.clear()
|
||||
self.writable.clear() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.writable.set()
|
||||
self.writable.set() # pragma: full coverage
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.handshake_complete:
|
||||
@@ -149,7 +157,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task) -> None:
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
# Event handlers
|
||||
@@ -158,20 +166,24 @@ class WSProtocol(asyncio.Protocol):
|
||||
headers = [(b"host", event.host.encode())]
|
||||
headers += [(key.lower(), value) for key, value in event.extra_headers]
|
||||
raw_path, _, query_string = event.target.partition("?")
|
||||
self.scope: "WebSocketScope" = {
|
||||
path = unquote(raw_path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": unquote(raw_path),
|
||||
"raw_path": raw_path.encode("ascii"),
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
@@ -181,11 +193,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
def handle_text(self, event: events.TextMessage) -> None:
|
||||
self.text += event.data
|
||||
if event.message_finished:
|
||||
msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item]
|
||||
"type": "websocket.receive",
|
||||
"text": self.text,
|
||||
}
|
||||
self.queue.put_nowait(msg)
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
|
||||
self.text = ""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
@@ -195,11 +203,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.bytes += event.data
|
||||
# todo: we may want to guard the size of self.bytes and self.text
|
||||
if event.message_finished:
|
||||
msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item]
|
||||
"type": "websocket.receive",
|
||||
"bytes": self.bytes,
|
||||
}
|
||||
self.queue.put_nowait(msg)
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
self.bytes = b""
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
@@ -208,62 +212,58 @@ class WSProtocol(asyncio.Protocol):
|
||||
def handle_close(self, event: events.CloseConnection) -> None:
|
||||
if self.conn.state == ConnectionState.REMOTE_CLOSING:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
|
||||
self.transport.close()
|
||||
|
||||
def handle_ping(self, event: events.Ping) -> None:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
headers = [
|
||||
if self.response_started or self.handshake_complete:
|
||||
return # we cannot send responses anymore
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
(b"content-length", b"21"),
|
||||
]
|
||||
output = self.conn.send(
|
||||
wsproto.events.RejectConnection(
|
||||
status_code=500, headers=headers, has_body=True
|
||||
)
|
||||
)
|
||||
output += self.conn.send(
|
||||
wsproto.events.RejectData(data=b"Internal Server Error")
|
||||
)
|
||||
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
||||
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
||||
self.transport.write(output)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send)
|
||||
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected:
|
||||
self.transport.close() # pragma: full coverage
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_complete:
|
||||
self.send_500_response()
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
msg = "ASGI callable returned without completing handshake."
|
||||
self.logger.error(msg)
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
async def send(self, message: "ASGISendEvent") -> None:
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete:
|
||||
if message_type == "websocket.accept":
|
||||
message = typing.cast("WebSocketAcceptEvent", message)
|
||||
message = cast(WebSocketAcceptEvent, message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
subprotocol = message.get("subprotocol")
|
||||
extra_headers = self.default_headers + list(message.get("headers", []))
|
||||
extensions: typing.List[Extension] = []
|
||||
extensions: list[Extension] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(PerMessageDeflate())
|
||||
if not self.transport.is_closing():
|
||||
@@ -281,7 +281,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.handshake_complete = True
|
||||
@@ -291,50 +291,85 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast(WebSocketResponseStartEvent, message)
|
||||
# ensure status code is in the valid range
|
||||
if not (100 <= message["status"] < 600):
|
||||
msg = "Invalid HTTP status code '%d' in response."
|
||||
raise RuntimeError(msg % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
self.handshake_complete = True
|
||||
event = events.RejectConnection(
|
||||
status_code=message["status"],
|
||||
headers=list(message["headers"]),
|
||||
has_body=True,
|
||||
)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.response_started = True
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent:
|
||||
if message_type == "websocket.send":
|
||||
message = typing.cast("WebSocketSendEvent", message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
output = self.conn.send(
|
||||
wsproto.events.Message(data=data) # type: ignore[type-var]
|
||||
)
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
elif not self.close_sent and not self.response_started:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast(WebSocketSendEvent, message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = typing.cast("WebSocketCloseEvent", message)
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
output = self.conn.send(
|
||||
wsproto.events.CloseConnection(code=code, reason=reason)
|
||||
)
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
elif message_type == "websocket.close":
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
elif self.response_started:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
self.transport.write(output)
|
||||
|
||||
if body_finished:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.send' or 'websocket.close',"
|
||||
" but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> "WebSocketEvent":
|
||||
async def receive(self) -> WebSocketEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@@ -7,22 +10,24 @@ import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator, Sequence
|
||||
from email.utils import formatdate
|
||||
from types import FrameType
|
||||
from typing import TYPE_CHECKING, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._compat import asyncio_run
|
||||
from uvicorn.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
|
||||
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
@@ -41,9 +46,9 @@ class ServerState:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.total_requests = 0
|
||||
self.connections: Set["Protocols"] = set()
|
||||
self.tasks: Set[asyncio.Task] = set()
|
||||
self.default_headers: List[Tuple[bytes, bytes]] = []
|
||||
self.connections: set[Protocols] = set()
|
||||
self.tasks: set[asyncio.Task[None]] = set()
|
||||
self.default_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
|
||||
class Server:
|
||||
@@ -56,11 +61,16 @@ class Server:
|
||||
self.force_exit = False
|
||||
self.last_notified = 0.0
|
||||
|
||||
def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
|
||||
self.config.setup_event_loop()
|
||||
return asyncio.run(self.serve(sockets=sockets))
|
||||
self._captured_signals: list[int] = []
|
||||
|
||||
async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None:
|
||||
def run(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())
|
||||
|
||||
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
with self.capture_signals():
|
||||
await self._serve(sockets)
|
||||
|
||||
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
process_id = os.getpid()
|
||||
|
||||
config = self.config
|
||||
@@ -69,8 +79,6 @@ class Server:
|
||||
|
||||
self.lifespan = config.lifespan_class(config)
|
||||
|
||||
self.install_signal_handlers()
|
||||
|
||||
message = "Started server process [%d]"
|
||||
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
@@ -85,7 +93,7 @@ class Server:
|
||||
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
|
||||
logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None:
|
||||
async def startup(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
await self.lifespan.startup()
|
||||
if self.lifespan.should_exit:
|
||||
self.should_exit = True
|
||||
@@ -94,7 +102,7 @@ class Server:
|
||||
config = self.config
|
||||
|
||||
def create_protocol(
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> asyncio.Protocol:
|
||||
return config.http_protocol_class( # type: ignore[call-arg]
|
||||
config=config,
|
||||
@@ -106,13 +114,13 @@ class Server:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
listeners: Sequence[socket.SocketType]
|
||||
if sockets is not None:
|
||||
if sockets is not None: # pragma: full coverage
|
||||
# Explicitly passed a list of open sockets.
|
||||
# We use this when the server is run from a Gunicorn worker.
|
||||
|
||||
def _share_socket(
|
||||
sock: socket.SocketType,
|
||||
) -> socket.SocketType: # pragma py-linux pragma: py-darwin
|
||||
) -> socket.SocketType: # pragma py-not-win32
|
||||
# Windows requires the socket be explicitly shared across
|
||||
# multiple workers (processes).
|
||||
from socket import fromshare # type: ignore[attr-defined]
|
||||
@@ -120,23 +128,19 @@ class Server:
|
||||
sock_data = sock.share(os.getpid()) # type: ignore[attr-defined]
|
||||
return fromshare(sock_data)
|
||||
|
||||
self.servers: List[asyncio.base_events.Server] = []
|
||||
self.servers: list[asyncio.base_events.Server] = []
|
||||
for sock in sockets:
|
||||
is_windows = platform.system() == "Windows"
|
||||
if config.workers > 1 and is_windows: # pragma: py-not-win32
|
||||
sock = _share_socket(sock) # type: ignore[assignment]
|
||||
server = await loop.create_server(
|
||||
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
self.servers.append(server)
|
||||
listeners = sockets
|
||||
|
||||
elif config.fd is not None: # pragma: py-win32
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server = await loop.create_server(
|
||||
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
@@ -145,7 +149,7 @@ class Server:
|
||||
# Create a socket using UNIX domain socket.
|
||||
uds_perms = 0o666
|
||||
if os.path.exists(config.uds):
|
||||
uds_perms = os.stat(config.uds).st_mode
|
||||
uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
|
||||
server = await loop.create_unix_server(
|
||||
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
@@ -178,7 +182,7 @@ class Server:
|
||||
else:
|
||||
# We're most likely running multiple workers, so a message has already been
|
||||
# logged by `config.bind_socket()`.
|
||||
pass
|
||||
pass # pragma: full coverage
|
||||
|
||||
self.started = True
|
||||
|
||||
@@ -193,9 +197,7 @@ class Server:
|
||||
)
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
logger.info(
|
||||
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
|
||||
)
|
||||
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
|
||||
|
||||
else:
|
||||
addr_format = "%s://%s:%d"
|
||||
@@ -210,11 +212,7 @@ class Server:
|
||||
|
||||
protocol_name = "https" if config.ssl else "http"
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(addr_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger.info(
|
||||
message,
|
||||
protocol_name,
|
||||
@@ -243,31 +241,33 @@ class Server:
|
||||
else:
|
||||
date_header = []
|
||||
|
||||
self.server_state.default_headers = (
|
||||
date_header + self.config.encoded_headers
|
||||
)
|
||||
self.server_state.default_headers = date_header + self.config.encoded_headers
|
||||
|
||||
# Callback to `callback_notify` once every `timeout_notify` seconds.
|
||||
if self.config.callback_notify is not None:
|
||||
if current_time - self.last_notified > self.config.timeout_notify:
|
||||
if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
|
||||
self.last_notified = current_time
|
||||
await self.config.callback_notify()
|
||||
|
||||
# Determine if we should exit.
|
||||
if self.should_exit:
|
||||
return True
|
||||
if self.config.limit_max_requests is not None:
|
||||
return self.server_state.total_requests >= self.config.limit_max_requests
|
||||
|
||||
max_requests = self.config.limit_max_requests
|
||||
if max_requests is not None and self.server_state.total_requests >= max_requests:
|
||||
logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None:
|
||||
async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
logger.info("Shutting down")
|
||||
|
||||
# Stop accepting new connections.
|
||||
for server in self.servers:
|
||||
server.close()
|
||||
for sock in sockets or []:
|
||||
sock.close()
|
||||
sock.close() # pragma: full coverage
|
||||
|
||||
# Request shutdown on all existing connections.
|
||||
for connection in list(self.server_state.connections):
|
||||
@@ -286,10 +286,7 @@ class Server:
|
||||
len(self.server_state.tasks),
|
||||
)
|
||||
for t in self.server_state.tasks:
|
||||
if sys.version_info < (3, 9): # pragma: py-gte-39
|
||||
t.cancel()
|
||||
else: # pragma: py-lt-39
|
||||
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
|
||||
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
|
||||
|
||||
# Send the lifespan shutdown event, and wait for application shutdown.
|
||||
if not self.force_exit:
|
||||
@@ -313,23 +310,29 @@ class Server:
|
||||
for server in self.servers:
|
||||
await server.wait_closed()
|
||||
|
||||
def install_signal_handlers(self) -> None:
|
||||
@contextlib.contextmanager
|
||||
def capture_signals(self) -> Generator[None, None, None]:
|
||||
# Signals can only be listened to from the main thread.
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# Signals can only be listened to from the main thread.
|
||||
yield
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# always use signal.signal, even if loop.add_signal_handler is available
|
||||
# this allows to restore previous signal handlers later on
|
||||
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
|
||||
try:
|
||||
for sig in HANDLED_SIGNALS:
|
||||
loop.add_signal_handler(sig, self.handle_exit, sig, None)
|
||||
except NotImplementedError: # pragma: no cover
|
||||
# Windows
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.handle_exit)
|
||||
yield
|
||||
finally:
|
||||
for sig, handler in original_handlers.items():
|
||||
signal.signal(sig, handler)
|
||||
# If we did gracefully shut down due to a signal, try to
|
||||
# trigger the expected behaviour now; multiple signals would be
|
||||
# done LIFO, see https://stackoverflow.com/questions/48434964
|
||||
for captured_signal in reversed(self._captured_signals):
|
||||
signal.raise_signal(captured_signal)
|
||||
|
||||
def handle_exit(self, sig: int, frame: Optional[FrameType]) -> None:
|
||||
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
|
||||
self._captured_signals.append(sig)
|
||||
if self.should_exit and sig == signal.SIGINT:
|
||||
self.force_exit = True
|
||||
self.force_exit = True # pragma: full coverage
|
||||
else:
|
||||
self.should_exit = True
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
from typing import TYPE_CHECKING, Type
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
from uvicorn.supervisors.multiprocess import Multiprocess
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ChangeReload: Type[BaseReload]
|
||||
ChangeReload: type[BaseReload]
|
||||
else:
|
||||
try:
|
||||
from uvicorn.supervisors.watchfilesreload import (
|
||||
WatchFilesReload as ChangeReload,
|
||||
)
|
||||
from uvicorn.supervisors.watchfilesreload import WatchFilesReload as ChangeReload
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
from uvicorn.supervisors.watchgodreload import (
|
||||
WatchGodReload as ChangeReload,
|
||||
)
|
||||
except ImportError:
|
||||
from uvicorn.supervisors.statreload import StatReload as ChangeReload
|
||||
from uvicorn.supervisors.statreload import StatReload as ChangeReload
|
||||
|
||||
__all__ = ["Multiprocess", "ChangeReload"]
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from types import FrameType
|
||||
from typing import Callable, Iterator, List, Optional
|
||||
from typing import Callable
|
||||
|
||||
import click
|
||||
|
||||
@@ -25,8 +28,8 @@ class BaseReload:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[Optional[List[socket]]], None],
|
||||
sockets: List[socket],
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.target = target
|
||||
@@ -34,16 +37,16 @@ class BaseReload:
|
||||
self.should_exit = threading.Event()
|
||||
self.pid = os.getpid()
|
||||
self.is_restarting = False
|
||||
self.reloader_name: Optional[str] = None
|
||||
self.reloader_name: str | None = None
|
||||
|
||||
def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
|
||||
def signal_handler(self, sig: int, frame: FrameType | None) -> None: # pragma: full coverage
|
||||
"""
|
||||
A signal handler that is registered with the parent process.
|
||||
"""
|
||||
if sys.platform == "win32" and self.is_restarting:
|
||||
self.is_restarting = False # pragma: py-not-win32
|
||||
self.is_restarting = False
|
||||
else:
|
||||
self.should_exit.set() # pragma: py-win32
|
||||
self.should_exit.set()
|
||||
|
||||
def run(self) -> None:
|
||||
self.startup()
|
||||
@@ -62,10 +65,10 @@ class BaseReload:
|
||||
if self.should_exit.wait(self.config.reload_delay):
|
||||
raise StopIteration()
|
||||
|
||||
def __iter__(self) -> Iterator[Optional[List[Path]]]:
|
||||
def __iter__(self) -> Iterator[list[Path] | None]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> Optional[List[Path]]:
|
||||
def __next__(self) -> list[Path] | None:
|
||||
return self.should_restart()
|
||||
|
||||
def startup(self) -> None:
|
||||
@@ -79,9 +82,7 @@ class BaseReload:
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
self.process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def restart(self) -> None:
|
||||
@@ -89,13 +90,15 @@ class BaseReload:
|
||||
self.is_restarting = True
|
||||
assert self.process.pid is not None
|
||||
os.kill(self.process.pid, signal.CTRL_C_EVENT)
|
||||
|
||||
# This is a workaround to ensure the Ctrl+C event is processed
|
||||
sys.stdout.write(" ") # This has to be a non-empty string
|
||||
sys.stdout.flush()
|
||||
else: # pragma: py-win32
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
|
||||
self.process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@@ -108,13 +111,11 @@ class BaseReload:
|
||||
for sock in self.sockets:
|
||||
sock.close()
|
||||
|
||||
message = "Stopping reloader process [{}]".format(str(self.pid))
|
||||
color_message = "Stopping reloader process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
message = f"Stopping reloader process [{str(self.pid)}]"
|
||||
color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
def should_restart(self) -> Optional[List[Path]]:
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
raise NotImplementedError("Reload strategies should override should_restart()")
|
||||
|
||||
|
||||
|
||||
@@ -1,74 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from multiprocessing import Pipe
|
||||
from socket import socket
|
||||
from types import FrameType
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._subprocess import get_subprocess
|
||||
from uvicorn.config import Config
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
SIGNALS = {
|
||||
getattr(signal, f"SIG{x}"): x
|
||||
for x in "INT TERM BREAK HUP QUIT TTIN TTOU USR1 USR2 WINCH".split()
|
||||
if hasattr(signal, f"SIG{x}")
|
||||
}
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class Process:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.real_target = target
|
||||
|
||||
self.parent_conn, self.child_conn = Pipe()
|
||||
self.process = get_subprocess(config, self.target, sockets)
|
||||
|
||||
def ping(self, timeout: float = 5) -> bool:
|
||||
self.parent_conn.send(b"ping")
|
||||
if self.parent_conn.poll(timeout):
|
||||
self.parent_conn.recv()
|
||||
return True
|
||||
return False
|
||||
|
||||
def pong(self) -> None:
|
||||
self.child_conn.recv()
|
||||
self.child_conn.send(b"pong")
|
||||
|
||||
def always_pong(self) -> None:
|
||||
while True:
|
||||
self.pong()
|
||||
|
||||
def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cover
|
||||
if os.name == "nt": # pragma: py-not-win32
|
||||
# Windows doesn't support SIGTERM, so we use SIGBREAK instead.
|
||||
# And then we raise SIGTERM when SIGBREAK is received.
|
||||
# https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/signal?view=msvc-170
|
||||
signal.signal(
|
||||
signal.SIGBREAK, # type: ignore[attr-defined]
|
||||
lambda sig, frame: signal.raise_signal(signal.SIGTERM),
|
||||
)
|
||||
|
||||
threading.Thread(target=self.always_pong, daemon=True).start()
|
||||
return self.real_target(sockets)
|
||||
|
||||
def is_alive(self, timeout: float = 5) -> bool:
|
||||
if not self.process.is_alive():
|
||||
return False # pragma: full coverage
|
||||
|
||||
return self.ping(timeout)
|
||||
|
||||
def start(self) -> None:
|
||||
self.process.start()
|
||||
|
||||
def terminate(self) -> None:
|
||||
if self.process.exitcode is None: # Process is still running
|
||||
assert self.process.pid is not None
|
||||
if os.name == "nt": # pragma: py-not-win32
|
||||
# Windows doesn't support SIGTERM.
|
||||
# So send SIGBREAK, and then in process raise SIGTERM.
|
||||
os.kill(self.process.pid, signal.CTRL_BREAK_EVENT) # type: ignore[attr-defined]
|
||||
else:
|
||||
os.kill(self.process.pid, signal.SIGTERM)
|
||||
logger.info(f"Terminated child process [{self.process.pid}]")
|
||||
|
||||
self.parent_conn.close()
|
||||
self.child_conn.close()
|
||||
|
||||
def kill(self) -> None:
|
||||
# In Windows, the method will call `TerminateProcess` to kill the process.
|
||||
# In Unix, the method will send SIGKILL to the process.
|
||||
self.process.kill()
|
||||
|
||||
def join(self) -> None:
|
||||
logger.info(f"Waiting for child process [{self.process.pid}]")
|
||||
self.process.join()
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
return self.process.pid
|
||||
|
||||
|
||||
class Multiprocess:
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[Optional[List[socket]]], None],
|
||||
sockets: List[socket],
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.target = target
|
||||
self.sockets = sockets
|
||||
self.processes: List[SpawnProcess] = []
|
||||
|
||||
self.processes_num = config.workers
|
||||
self.processes: list[Process] = []
|
||||
|
||||
self.should_exit = threading.Event()
|
||||
self.pid = os.getpid()
|
||||
|
||||
def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
|
||||
"""
|
||||
A signal handler that is registered with the parent process.
|
||||
"""
|
||||
self.should_exit.set()
|
||||
self.signal_queue: list[int] = []
|
||||
for sig in SIGNALS:
|
||||
signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig))
|
||||
|
||||
def run(self) -> None:
|
||||
self.startup()
|
||||
self.should_exit.wait()
|
||||
self.shutdown()
|
||||
|
||||
def startup(self) -> None:
|
||||
message = "Started parent process [{}]".format(str(self.pid))
|
||||
color_message = "Started parent process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
for _idx in range(self.config.workers):
|
||||
process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
def init_processes(self) -> None:
|
||||
for _ in range(self.processes_num):
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
def terminate_all(self) -> None:
|
||||
for process in self.processes:
|
||||
process.terminate()
|
||||
|
||||
def join_all(self) -> None:
|
||||
for process in self.processes:
|
||||
process.join()
|
||||
|
||||
message = "Stopping parent process [{}]".format(str(self.pid))
|
||||
color_message = "Stopping parent process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
def restart_all(self) -> None:
|
||||
for idx, process in enumerate(self.processes):
|
||||
process.terminate()
|
||||
process.join()
|
||||
new_process = Process(self.config, self.target, self.sockets)
|
||||
new_process.start()
|
||||
self.processes[idx] = new_process
|
||||
|
||||
def run(self) -> None:
|
||||
message = f"Started parent process [{os.getpid()}]"
|
||||
color_message = "Started parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
self.init_processes()
|
||||
|
||||
while not self.should_exit.wait(0.5):
|
||||
self.handle_signals()
|
||||
self.keep_subprocess_alive()
|
||||
|
||||
self.terminate_all()
|
||||
self.join_all()
|
||||
|
||||
message = f"Stopping parent process [{os.getpid()}]"
|
||||
color_message = "Stopping parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
def keep_subprocess_alive(self) -> None:
|
||||
if self.should_exit.is_set():
|
||||
return # parent process is exiting, no need to keep subprocess alive
|
||||
|
||||
for idx, process in enumerate(self.processes):
|
||||
if process.is_alive(timeout=self.config.timeout_worker_healthcheck):
|
||||
continue
|
||||
|
||||
process.kill() # process is hung, kill it
|
||||
process.join()
|
||||
|
||||
if self.should_exit.is_set():
|
||||
return # pragma: full coverage
|
||||
|
||||
logger.info(f"Child process [{process.pid}] died")
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes[idx] = process
|
||||
|
||||
def handle_signals(self) -> None:
|
||||
for sig in tuple(self.signal_queue):
|
||||
self.signal_queue.remove(sig)
|
||||
sig_name = SIGNALS[sig]
|
||||
sig_handler = getattr(self, f"handle_{sig_name.lower()}", None)
|
||||
if sig_handler is not None:
|
||||
sig_handler()
|
||||
else: # pragma: no cover
|
||||
logger.debug(f"Received signal {sig_name}, but no handler is defined for it.")
|
||||
|
||||
def handle_int(self) -> None:
|
||||
logger.info("Received SIGINT, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_term(self) -> None:
|
||||
logger.info("Received SIGTERM, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_break(self) -> None: # pragma: py-not-win32
|
||||
logger.info("Received SIGBREAK, exiting.")
|
||||
self.should_exit.set()
|
||||
|
||||
def handle_hup(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGHUP, restarting processes.")
|
||||
self.restart_all()
|
||||
|
||||
def handle_ttin(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGTTIN, increasing the number of processes.")
|
||||
self.processes_num += 1
|
||||
process = Process(self.config, self.target, self.sockets)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
def handle_ttou(self) -> None: # pragma: py-win32
|
||||
logger.info("Received SIGTTOU, decreasing number of processes.")
|
||||
if self.processes_num <= 1:
|
||||
logger.info("Already reached one process, cannot decrease the number of processes anymore.")
|
||||
return
|
||||
self.processes_num -= 1
|
||||
process = self.processes.pop()
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from typing import Callable, Dict, Iterator, List, Optional
|
||||
from typing import Callable
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
@@ -13,20 +16,17 @@ class StatReload(BaseReload):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[Optional[List[socket]]], None],
|
||||
sockets: List[socket],
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
super().__init__(config, target, sockets)
|
||||
self.reloader_name = "StatReload"
|
||||
self.mtimes: Dict[Path, float] = {}
|
||||
self.mtimes: dict[Path, float] = {}
|
||||
|
||||
if config.reload_excludes or config.reload_includes:
|
||||
logger.warning(
|
||||
"--reload-include and --reload-exclude have no effect unless "
|
||||
"watchfiles is installed."
|
||||
)
|
||||
logger.warning("--reload-include and --reload-exclude have no effect unless watchfiles is installed.")
|
||||
|
||||
def should_restart(self) -> Optional[List[Path]]:
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
self.pause()
|
||||
|
||||
for file in self.iter_py_files():
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable
|
||||
|
||||
from watchfiles import watch
|
||||
|
||||
@@ -11,20 +13,12 @@ from uvicorn.supervisors.basereload import BaseReload
|
||||
class FileFilter:
|
||||
def __init__(self, config: Config):
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [
|
||||
default
|
||||
for default in default_includes
|
||||
if default not in config.reload_excludes
|
||||
]
|
||||
self.includes = [default for default in default_includes if default not in config.reload_excludes]
|
||||
self.includes.extend(config.reload_includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
|
||||
self.excludes = [
|
||||
default
|
||||
for default in default_excludes
|
||||
if default not in config.reload_includes
|
||||
]
|
||||
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
|
||||
self.exclude_dirs = []
|
||||
for e in config.reload_excludes:
|
||||
p = Path(e)
|
||||
@@ -37,19 +31,22 @@ class FileFilter:
|
||||
if is_dir:
|
||||
self.exclude_dirs.append(p)
|
||||
else:
|
||||
self.excludes.append(e)
|
||||
self.excludes.append(e) # pragma: full coverage
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
def __call__(self, path: Path) -> bool:
|
||||
for include_pattern in self.includes:
|
||||
if path.match(include_pattern):
|
||||
if str(path).endswith(include_pattern):
|
||||
return True # pragma: full coverage
|
||||
|
||||
for exclude_dir in self.exclude_dirs:
|
||||
if exclude_dir in path.parents:
|
||||
return False
|
||||
|
||||
for exclude_pattern in self.excludes:
|
||||
if path.match(exclude_pattern):
|
||||
return False
|
||||
return False # pragma: full coverage
|
||||
|
||||
return True
|
||||
return False
|
||||
@@ -59,17 +56,14 @@ class WatchFilesReload(BaseReload):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[Optional[List[socket]]], None],
|
||||
sockets: List[socket],
|
||||
target: Callable[[list[socket] | None], None],
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
super().__init__(config, target, sockets)
|
||||
self.reloader_name = "WatchFiles"
|
||||
self.reload_dirs = []
|
||||
for directory in config.reload_dirs:
|
||||
if Path.cwd() not in directory.parents:
|
||||
self.reload_dirs.append(directory)
|
||||
if Path.cwd() not in self.reload_dirs:
|
||||
self.reload_dirs.append(Path.cwd())
|
||||
self.reload_dirs.append(directory)
|
||||
|
||||
self.watch_filter = FileFilter(config)
|
||||
self.watcher = watch(
|
||||
@@ -81,7 +75,7 @@ class WatchFilesReload(BaseReload):
|
||||
yield_on_timeout=True,
|
||||
)
|
||||
|
||||
def should_restart(self) -> Optional[List[Path]]:
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
self.pause()
|
||||
|
||||
changes = next(self.watcher)
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
from watchgod import DefaultWatcher
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import os
|
||||
|
||||
DirEntry = os.DirEntry[str]
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class CustomWatcher(DefaultWatcher):
|
||||
def __init__(self, root_path: Path, config: Config):
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [
|
||||
default
|
||||
for default in default_includes
|
||||
if default not in config.reload_excludes
|
||||
]
|
||||
self.includes.extend(config.reload_includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
|
||||
self.excludes = [
|
||||
default
|
||||
for default in default_excludes
|
||||
if default not in config.reload_includes
|
||||
]
|
||||
self.excludes.extend(config.reload_excludes)
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
self.watched_dirs: Dict[str, bool] = {}
|
||||
self.watched_files: Dict[str, bool] = {}
|
||||
self.dirs_includes = set(config.reload_dirs)
|
||||
self.dirs_excludes = set(config.reload_dirs_excludes)
|
||||
self.resolved_root = root_path
|
||||
super().__init__(str(root_path))
|
||||
|
||||
def should_watch_file(self, entry: "DirEntry") -> bool:
|
||||
cached_result = self.watched_files.get(entry.path)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
entry_path = Path(entry)
|
||||
|
||||
# cwd is not verified through should_watch_dir, so we need to verify here
|
||||
if entry_path.parent == Path.cwd() and Path.cwd() not in self.dirs_includes:
|
||||
self.watched_files[entry.path] = False
|
||||
return False
|
||||
for include_pattern in self.includes:
|
||||
if entry_path.match(include_pattern):
|
||||
for exclude_pattern in self.excludes:
|
||||
if entry_path.match(exclude_pattern):
|
||||
self.watched_files[entry.path] = False
|
||||
return False
|
||||
self.watched_files[entry.path] = True
|
||||
return True
|
||||
self.watched_files[entry.path] = False
|
||||
return False
|
||||
|
||||
def should_watch_dir(self, entry: "DirEntry") -> bool:
|
||||
cached_result = self.watched_dirs.get(entry.path)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
entry_path = Path(entry)
|
||||
|
||||
if entry_path in self.dirs_excludes:
|
||||
self.watched_dirs[entry.path] = False
|
||||
return False
|
||||
|
||||
for exclude_pattern in self.excludes:
|
||||
if entry_path.match(exclude_pattern):
|
||||
is_watched = False
|
||||
if entry_path in self.dirs_includes:
|
||||
is_watched = True
|
||||
|
||||
for directory in self.dirs_includes:
|
||||
if directory in entry_path.parents:
|
||||
is_watched = True
|
||||
|
||||
if is_watched:
|
||||
logger.debug(
|
||||
"WatchGodReload detected a new excluded dir '%s' in '%s'; "
|
||||
"Adding to exclude list.",
|
||||
entry_path.relative_to(self.resolved_root),
|
||||
str(self.resolved_root),
|
||||
)
|
||||
self.watched_dirs[entry.path] = False
|
||||
self.dirs_excludes.add(entry_path)
|
||||
return False
|
||||
|
||||
if entry_path in self.dirs_includes:
|
||||
self.watched_dirs[entry.path] = True
|
||||
return True
|
||||
|
||||
for directory in self.dirs_includes:
|
||||
if directory in entry_path.parents:
|
||||
self.watched_dirs[entry.path] = True
|
||||
return True
|
||||
|
||||
for include_pattern in self.includes:
|
||||
if entry_path.match(include_pattern):
|
||||
logger.info(
|
||||
"WatchGodReload detected a new reload dir '%s' in '%s'; "
|
||||
"Adding to watch list.",
|
||||
str(entry_path.relative_to(self.resolved_root)),
|
||||
str(self.resolved_root),
|
||||
)
|
||||
self.dirs_includes.add(entry_path)
|
||||
self.watched_dirs[entry.path] = True
|
||||
return True
|
||||
|
||||
self.watched_dirs[entry.path] = False
|
||||
return False
|
||||
|
||||
|
||||
class WatchGodReload(BaseReload):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
target: Callable[[Optional[List[socket]]], None],
|
||||
sockets: List[socket],
|
||||
) -> None:
|
||||
warnings.warn(
|
||||
'"watchgod" is deprecated, you should switch '
|
||||
"to watchfiles (`pip install watchfiles`).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(config, target, sockets)
|
||||
self.reloader_name = "WatchGod"
|
||||
self.watchers = []
|
||||
reload_dirs = []
|
||||
for directory in config.reload_dirs:
|
||||
if Path.cwd() not in directory.parents:
|
||||
reload_dirs.append(directory)
|
||||
if Path.cwd() not in reload_dirs:
|
||||
reload_dirs.append(Path.cwd())
|
||||
for w in reload_dirs:
|
||||
self.watchers.append(CustomWatcher(w.resolve(), self.config))
|
||||
|
||||
def should_restart(self) -> Optional[List[Path]]:
|
||||
self.pause()
|
||||
|
||||
for watcher in self.watchers:
|
||||
change = watcher.check()
|
||||
if change != set():
|
||||
return list({Path(c[1]) for c in change})
|
||||
|
||||
return None
|
||||
@@ -1,14 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.workers.base import Worker
|
||||
|
||||
from uvicorn._compat import asyncio_run
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.main import Server
|
||||
from uvicorn.server import Server
|
||||
|
||||
warnings.warn(
|
||||
"The `uvicorn.workers` module is deprecated. Please use `uvicorn-worker` package instead.\n"
|
||||
"For more details, see https://github.com/Kludex/uvicorn-worker.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
class UvicornWorker(Worker):
|
||||
@@ -17,10 +27,10 @@ class UvicornWorker(Worker):
|
||||
rather than a WSGI callable.
|
||||
"""
|
||||
|
||||
CONFIG_KWARGS: Dict[str, Any] = {"loop": "auto", "http": "auto"}
|
||||
CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"}
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super(UvicornWorker, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.handlers = self.log.error_log.handlers
|
||||
@@ -61,14 +71,10 @@ class UvicornWorker(Worker):
|
||||
|
||||
self.config = Config(**config_kwargs)
|
||||
|
||||
def init_process(self) -> None:
|
||||
self.config.setup_event_loop()
|
||||
super(UvicornWorker, self).init_process()
|
||||
|
||||
def init_signals(self) -> None:
|
||||
# Reset signals so Gunicorn doesn't swallow subprocess return codes
|
||||
# other signals are set up by Server.install_signal_handlers()
|
||||
# See: https://github.com/encode/uvicorn/issues/894
|
||||
# See: https://github.com/Kludex/uvicorn/issues/894
|
||||
for s in self.SIGNALS:
|
||||
signal.signal(s, signal.SIG_DFL)
|
||||
|
||||
@@ -79,7 +85,7 @@ class UvicornWorker(Worker):
|
||||
def _install_sigquit_handler(self) -> None:
|
||||
"""Install a SIGQUIT handler on workers.
|
||||
|
||||
- https://github.com/encode/uvicorn/issues/1116
|
||||
- https://github.com/Kludex/uvicorn/issues/1116
|
||||
- https://github.com/benoitc/gunicorn/issues/2604
|
||||
"""
|
||||
|
||||
@@ -95,7 +101,7 @@ class UvicornWorker(Worker):
|
||||
sys.exit(Arbiter.WORKER_BOOT_ERROR)
|
||||
|
||||
def run(self) -> None:
|
||||
return asyncio.run(self._serve())
|
||||
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())
|
||||
|
||||
async def callback_notify(self) -> None:
|
||||
self.notify()
|
||||
|
||||
Reference in New Issue
Block a user