This commit is contained in:
@@ -1,17 +1,42 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Protocol
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(self, cls: type, **options: typing.Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cls: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
self.cls = cls
|
||||
self.options = options
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
as_tuple = (self.cls, self.options)
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
as_tuple = (self.cls, self.args, self.kwargs)
|
||||
return iter(as_tuple)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
|
||||
args_repr = ", ".join([self.cls.__name__] + option_strings)
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Optional[
|
||||
typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
] = None,
|
||||
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = (on_error if on_error is not None else self.default_on_error)
|
||||
self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
|
||||
@@ -1,23 +1,100 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Mapping, MutableMapping
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
|
||||
import anyio
|
||||
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import ContentStream, Response, StreamingResponse
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
T = typing.TypeVar("T")
|
||||
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
|
||||
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
|
||||
BodyStreamGenerator = AsyncGenerator[Union[bytes, MutableMapping[str, Any]], None]
|
||||
AsyncContentStream = AsyncIterable[Union[str, bytes, memoryview, MutableMapping[str, Any]]]
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _CachedRequest(Request):
|
||||
"""
|
||||
If the user calls Request.body() from their dispatch function
|
||||
we cache the entire request body in memory and pass that to downstream middlewares,
|
||||
but if they call Request.stream() then all we do is send an
|
||||
empty body so that downstream things don't hang forever.
|
||||
"""
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive):
|
||||
super().__init__(scope, receive)
|
||||
self._wrapped_rcv_disconnected = False
|
||||
self._wrapped_rcv_consumed = False
|
||||
self._wrapped_rc_stream = self.stream()
|
||||
|
||||
async def wrapped_receive(self) -> Message:
|
||||
# wrapped_rcv state 1: disconnected
|
||||
if self._wrapped_rcv_disconnected:
|
||||
# we've already sent a disconnect to the downstream app
|
||||
# we don't need to wait to get another one
|
||||
# (although most ASGI servers will just keep sending it)
|
||||
return {"type": "http.disconnect"}
|
||||
# wrapped_rcv state 1: consumed but not yet disconnected
|
||||
if self._wrapped_rcv_consumed:
|
||||
# since the downstream app has consumed us all that is left
|
||||
# is to send it a disconnect
|
||||
if self._is_disconnected:
|
||||
# the middleware has already seen the disconnect
|
||||
# since we know the client is disconnected no need to wait
|
||||
# for the message
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
# we don't know yet if the client is disconnected or not
|
||||
# so we'll wait until we get that message
|
||||
msg = await self.receive()
|
||||
if msg["type"] != "http.disconnect": # pragma: no cover
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
if getattr(self, "_body", None) is not None:
|
||||
# body() was called, we return it even if the client disconnected
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": self._body,
|
||||
"more_body": False,
|
||||
}
|
||||
elif self._stream_consumed:
|
||||
# stream() was called to completion
|
||||
# return an empty body so that downstream apps don't hang
|
||||
# waiting for a disconnect
|
||||
self._wrapped_rcv_consumed = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
else:
|
||||
# body() was never called and stream() wasn't consumed
|
||||
try:
|
||||
stream = self.stream()
|
||||
chunk = await stream.__anext__()
|
||||
self._wrapped_rcv_consumed = self._stream_consumed
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": not self._stream_consumed,
|
||||
}
|
||||
except ClientDisconnect:
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
@@ -26,35 +103,32 @@ class BaseHTTPMiddleware:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = _CachedRequest(scope, receive)
|
||||
wrapped_receive = request.wrapped_receive
|
||||
response_sent = anyio.Event()
|
||||
app_exc: Exception | None = None
|
||||
exception_already_raised = False
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
app_exc: typing.Optional[Exception] = None
|
||||
send_stream, recv_stream = anyio.create_memory_object_stream()
|
||||
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
|
||||
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
|
||||
result = await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
return result
|
||||
|
||||
task_group.start_soon(wrap, response_sent.wait)
|
||||
message = await wrap(request.receive)
|
||||
message = await wrap(wrapped_receive)
|
||||
|
||||
if response_sent.is_set():
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
return message
|
||||
|
||||
async def close_recv_stream_on_response_sent() -> None:
|
||||
await response_sent.wait()
|
||||
recv_stream.close()
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
@@ -65,13 +139,12 @@ class BaseHTTPMiddleware:
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
async with send_stream:
|
||||
with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(close_recv_stream_on_response_sent)
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
@@ -81,54 +154,82 @@ class BaseHTTPMiddleware:
|
||||
message = await recv_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if app_exc is not None:
|
||||
nonlocal exception_already_raised
|
||||
exception_already_raised = True
|
||||
raise app_exc
|
||||
raise RuntimeError("No response returned.")
|
||||
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async with recv_stream:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
async def body_stream() -> BodyStreamGenerator:
|
||||
async for message in recv_stream:
|
||||
if message["type"] == "http.response.pathsend":
|
||||
yield message
|
||||
break
|
||||
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
response = _StreamingResponse(
|
||||
status_code=message["status"], content=body_stream(), info=info
|
||||
)
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
request = Request(scope, receive=receive)
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, receive, send)
|
||||
response_sent.set()
|
||||
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
|
||||
send_stream, recv_stream = streams
|
||||
with recv_stream, send_stream, collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
recv_stream.close()
|
||||
if app_exc is not None and not exception_already_raised:
|
||||
raise app_exc
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(StreamingResponse):
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: ContentStream,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._info = info
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
if self._info:
|
||||
await send({"type": "http.response.debug", "info": self._info})
|
||||
return await super().stream_response(send)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
should_close_body = True
|
||||
async for chunk in self.body_iterator:
|
||||
if isinstance(chunk, dict):
|
||||
# We got an ASGI message which is not response body (eg: pathsend)
|
||||
should_close_body = False
|
||||
await send(chunk)
|
||||
continue
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
if should_close_body:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
@@ -14,12 +16,12 @@ class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_origins: Sequence[str] = (),
|
||||
allow_methods: Sequence[str] = ("GET",),
|
||||
allow_headers: Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: typing.Optional[str] = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
@@ -94,9 +96,7 @@ class CORSMiddleware:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
@@ -139,15 +139,11 @@ class CORSMiddleware:
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
|
||||
STYLES = """
|
||||
p {
|
||||
@@ -137,7 +139,7 @@ class ServerErrorMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Optional[typing.Callable] = None,
|
||||
handler: ExceptionHandler | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -183,9 +185,7 @@ class ServerErrorMiddleware:
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(
|
||||
self, index: int, line: str, frame_lineno: int, frame_index: int
|
||||
) -> str:
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
@@ -199,7 +199,10 @@ class ServerErrorMiddleware:
|
||||
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
|
||||
code_context = "".join(
|
||||
self.format_line(
|
||||
index, line, frame.lineno, frame.index # type: ignore[arg-type]
|
||||
index,
|
||||
line,
|
||||
frame.lineno,
|
||||
frame.index, # type: ignore[arg-type]
|
||||
)
|
||||
for index, line in enumerate(frame.code_context or [])
|
||||
)
|
||||
@@ -219,9 +222,7 @@ class ServerErrorMiddleware:
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(
|
||||
exc, capture_locals=True
|
||||
)
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
@@ -232,11 +233,13 @@ class ServerErrorMiddleware:
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = (
|
||||
f"{html.escape(traceback_obj.exc_type.__name__)}: "
|
||||
f"{html.escape(str(traceback_obj))}"
|
||||
)
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Optional[
|
||||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
|
||||
] = None,
|
||||
handlers: Mapping[Any, ExceptionHandler] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: typing.Dict[int, typing.Callable] = {}
|
||||
self._exception_handlers: typing.Dict[
|
||||
typing.Type[Exception], typing.Callable
|
||||
] = {
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def _lookup_exception_handler(
|
||||
self, exc: Exception
|
||||
) -> typing.Optional[typing.Callable]:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in self._exception_handlers:
|
||||
return self._exception_handlers[cls]
|
||||
return None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = self._status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = self._lookup_exception_handler(exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
if scope["type"] == "http":
|
||||
request = Request(scope, receive=receive)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request, exc)
|
||||
await response(scope, receive, sender)
|
||||
elif scope["type"] == "websocket":
|
||||
websocket = WebSocket(scope, receive=receive, send=send)
|
||||
if is_async_callable(handler):
|
||||
await handler(websocket, exc)
|
||||
else:
|
||||
await run_in_threadpool(handler, websocket, exc)
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(
|
||||
exc.detail, status_code=exc.status_code, headers=exc.headers
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
async def websocket_exception(
|
||||
self, websocket: WebSocket, exc: WebSocketException
|
||||
) -> None:
|
||||
await websocket.close(code=exc.code, reason=exc.reason)
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
async def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
||||
|
||||
@@ -1,49 +1,51 @@
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
from typing import NoReturn
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(
|
||||
self.app, self.minimum_size, compresslevel=self.compresslevel
|
||||
)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
responder: ASGIApp
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
else:
|
||||
responder = IdentityResponder(self.app, self.minimum_size)
|
||||
|
||||
await responder(scope, receive, send)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
class IdentityResponder:
|
||||
content_encoding: str
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.send: Send = unattached_send
|
||||
self.initial_message: Message = {}
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(
|
||||
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
|
||||
)
|
||||
self.content_type_is_excluded = False
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
await self.app(scope, receive, self.send_with_compression)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
async def send_with_compression(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
@@ -51,7 +53,8 @@ class GZipResponder:
|
||||
self.initial_message = message
|
||||
headers = Headers(raw=self.initial_message["headers"])
|
||||
self.content_encoding_set = "content-encoding" in headers
|
||||
elif message_type == "http.response.body" and self.content_encoding_set:
|
||||
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
|
||||
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self.send(self.initial_message)
|
||||
@@ -61,53 +64,82 @@ class GZipResponder:
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
# Don't apply compression to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
# Standard response.
|
||||
body = self.apply_compression(body, more_body=False)
|
||||
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
headers["Content-Length"] = str(len(body))
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
# Initial body in streaming response.
|
||||
body = self.apply_compression(body, more_body=True)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
headers = MutableHeaders(raw=self.initial_message["headers"])
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
if body != message["body"]:
|
||||
headers["Content-Encoding"] = self.content_encoding
|
||||
del headers["Content-Length"]
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
# Remaining body in streaming response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
message["body"] = self.apply_compression(body, more_body=more_body)
|
||||
|
||||
await self.send(message)
|
||||
elif message_type == "http.response.pathsend": # pragma: no branch
|
||||
# Don't apply GZip to pathsend responses
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
"""Apply compression on the response body.
|
||||
|
||||
If more_body is False, any compression file should be closed. If it
|
||||
isn't, it won't be closed automatically until all background tasks
|
||||
complete.
|
||||
"""
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> typing.NoReturn:
|
||||
class GZipResponder(IdentityResponder):
|
||||
content_encoding = "gzip"
|
||||
|
||||
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
|
||||
super().__init__(app, minimum_size)
|
||||
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
body = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
return body
|
||||
|
||||
|
||||
async def unattached_send(message: Message) -> NoReturn:
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Literal
|
||||
|
||||
import itsdangerous
|
||||
from itsdangerous.exc import BadSignature
|
||||
@@ -10,22 +11,18 @@ from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
from typing import Literal
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: typing.Union[str, Secret],
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
@@ -35,6 +32,8 @@ class SessionMiddleware:
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
if domain is not None:
|
||||
self.security_flags += f"; domain={domain}"
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
@@ -62,7 +61,7 @@ class SessionMiddleware:
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
@@ -73,7 +72,7 @@ class SessionMiddleware:
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
|
||||
allowed_hosts: Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (
|
||||
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||
):
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any, Callable
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
@@ -15,14 +19,20 @@ warnings.warn(
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
|
||||
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
|
||||
path_info = scope["path"].encode("utf8").decode("latin1")
|
||||
if path_info.startswith(script_name):
|
||||
path_info = path_info[len(script_name) :]
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
|
||||
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
|
||||
"SCRIPT_NAME": script_name,
|
||||
"PATH_INFO": path_info,
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
|
||||
"wsgi.version": (1, 0),
|
||||
@@ -62,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
|
||||
|
||||
|
||||
class WSGIMiddleware:
|
||||
def __init__(self, app: typing.Callable) -> None:
|
||||
def __init__(self, app: Callable[..., Any]) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
@@ -72,16 +82,17 @@ class WSGIMiddleware:
|
||||
|
||||
|
||||
class WSGIResponder:
|
||||
def __init__(self, app: typing.Callable, scope: Scope) -> None:
|
||||
stream_send: ObjectSendStream[MutableMapping[str, Any]]
|
||||
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
|
||||
|
||||
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
|
||||
self.app = app
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
self.exc_info: Any = None
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
body = b""
|
||||
@@ -107,11 +118,11 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
if not self.response_started: # pragma: no branch
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
@@ -128,13 +139,15 @@ class WSGIResponder:
|
||||
},
|
||||
)
|
||||
|
||||
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
|
||||
def wsgi(
|
||||
self,
|
||||
environ: dict[str, Any],
|
||||
start_response: Callable[..., Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send,
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send, {"type": "http.response.body", "body": b""}
|
||||
)
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
||||
|
||||
Reference in New Issue
Block a user