main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,64 +1,43 @@
from __future__ import annotations
import contextlib
import inspect
import io
import json
import math
import queue
import sys
import typing
import warnings
from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence
from concurrent.futures import Future
from contextlib import AbstractContextManager
from types import GeneratorType
from typing import (
Any,
Callable,
Literal,
TypedDict,
Union,
cast,
)
from urllib.parse import unquote, urljoin
import anyio
import anyio.abc
import anyio.from_thread
import httpx
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
else: # pragma: no cover
from typing_extensions import TypeGuard
from typing_extensions import TypedDict
if sys.version_info >= (3, 11): # pragma: no cover
from typing import Self
else: # pragma: no cover
from typing_extensions import Self
_PortalFactoryType = typing.Callable[
[], typing.ContextManager[anyio.abc.BlockingPortal]
]
try:
import httpx
except ModuleNotFoundError: # pragma: no cover
raise RuntimeError(
"The starlette.testclient module requires the httpx package to be installed.\n"
"You can install this with:\n"
" $ pip install httpx\n"
)
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
return is_async_callable(app)
@@ -79,24 +58,14 @@ class _WrapASGI2:
class _AsyncBackend(TypedDict):
backend: str
backend_options: dict[str, Any]
backend_options: typing.Dict[str, typing.Any]
class _Upgrade(Exception):
def __init__(self, session: WebSocketTestSession) -> None:
def __init__(self, session: "WebSocketTestSession") -> None:
self.session = session
class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""
class WebSocketTestSession:
def __init__(
self,
@@ -108,60 +77,65 @@ class WebSocketTestSession:
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
self.extra_headers = None
def __enter__(self) -> WebSocketTestSession:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self
def __exit__(self, *args: Any) -> bool | None:
return self.exit_stack.__exit__(*args)
def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
send_tx, send_rx = send
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
receive_tx, receive_rx = receive
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
raise
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
await anyio.sleep(0)
return self._receive_queue.get()
async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
)
def send(self, message: Message) -> None:
self.portal.call(self._receive_tx.send, message)
self._receive_queue.put(message)
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
@@ -169,30 +143,35 @@ class WebSocketTestSession:
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
def send_json(self, data: typing.Any, mode: str = "text") -> None:
assert mode in ["text", "binary"]
text = json.dumps(data, separators=(",", ":"))
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
else:
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def close(self, code: int = 1000) -> None:
self.send({"type": "websocket.disconnect", "code": code})
def receive(self) -> Message:
return self.portal.call(self._send_rx.receive)
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
return message
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
return cast(str, message["text"])
return message["text"]
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
return cast(bytes, message["bytes"])
return message["bytes"]
def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
message = self.receive()
self._raise_on_close(message)
if mode == "text":
@@ -210,15 +189,13 @@ class _TestClientTransport(httpx.BaseTransport):
raise_server_exceptions: bool = True,
root_path: str = "",
*,
client: tuple[str, int],
app_state: dict[str, Any],
app_state: typing.Dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client
def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
@@ -238,36 +215,38 @@ class _TestClientTransport(httpx.BaseTransport):
# Include the 'host' header.
if "host" in request.headers:
headers: list[tuple[bytes, bytes]] = []
headers: typing.List[typing.Tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
else: # pragma: no cover
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
headers += [
(key.lower().encode(), value.encode())
for key, value in request.headers.items()
]
scope: dict[str, Any]
scope: typing.Dict[str, typing.Any]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
subprotocols: Sequence[str] = []
subprotocols: typing.Sequence[str] = []
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
"type": "websocket",
"path": unquote(path),
"raw_path": raw_path.split(b"?", 1)[0],
"raw_path": raw_path,
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
@@ -277,12 +256,12 @@ class _TestClientTransport(httpx.BaseTransport):
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"raw_path": raw_path.split(b"?", 1)[0],
"raw_path": raw_path,
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
@@ -291,7 +270,7 @@ class _TestClientTransport(httpx.BaseTransport):
request_complete = False
response_started = False
response_complete: anyio.Event
raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
template = None
context = None
@@ -327,13 +306,22 @@ class _TestClientTransport(httpx.BaseTransport):
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert not response_started, 'Received multiple "http.response.start" messages.'
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
raw_kwargs["headers"] = [
(key.decode(), value.decode())
for key, value in message.get("headers", [])
]
response_started = True
elif message["type"] == "http.response.body":
assert response_started, 'Received "http.response.body" without "http.response.start".'
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
assert (
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
not response_complete.is_set()
), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
@@ -373,8 +361,8 @@ class _TestClientTransport(httpx.BaseTransport):
class TestClient(httpx.Client):
__test__ = False
task: Future[None]
portal: anyio.abc.BlockingPortal | None = None
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None
def __init__(
self,
@@ -382,84 +370,110 @@ class TestClient(httpx.Client):
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: Literal["asyncio", "trio"] = "asyncio",
backend_options: dict[str, Any] | None = None,
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
cookies: httpx._client.CookieTypes = None,
headers: typing.Dict[str, str] = None,
) -> None:
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
else:
app = cast(ASGI2App, app) # type: ignore[assignment]
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: dict[str, Any] = {}
self.app_state: typing.Dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
)
if headers is None:
headers = {}
headers.setdefault("user-agent", "testclient")
super().__init__(
app=self.app,
base_url=base_url,
headers=headers,
transport=transport,
follow_redirects=follow_redirects,
follow_redirects=True,
cookies=cookies,
)
@contextlib.contextmanager
def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
with anyio.from_thread.start_blocking_portal(
**self.async_backend
) as portal:
yield portal
def _choose_redirect_arg(
self,
follow_redirects: typing.Optional[bool],
allow_redirects: typing.Optional[bool],
) -> typing.Union[bool, httpx._client.UseClientDefault]:
redirect: typing.Union[
bool, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT
if allow_redirects is not None:
message = (
"The `allow_redirects` argument is deprecated. "
"Use `follow_redirects` instead."
)
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
redirect = follow_redirects
elif allow_redirects is not None and follow_redirects is not None:
raise RuntimeError( # pragma: no cover
"Cannot use both `allow_redirects` and `follow_redirects`."
)
return redirect
def request( # type: ignore[override]
self,
method: str,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
if timeout is not httpx.USE_CLIENT_DEFAULT:
warnings.warn(
"You should not use the 'timeout' argument with the TestClient. "
"See https://github.com/Kludex/starlette/issues/1108 for more information.",
DeprecationWarning,
)
url = self._merge_url(url)
url = self.base_url.join(url)
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().request(
method,
url,
content=content,
data=data,
data=data, # type: ignore[arg-type]
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -468,21 +482,27 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().get(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -491,21 +511,27 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().options(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -514,21 +540,27 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().head(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -537,29 +569,35 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().post(
url,
content=content,
data=data,
data=data, # type: ignore[arg-type]
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -568,29 +606,35 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().put(
url,
content=content,
data=data,
data=data, # type: ignore[arg-type]
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -599,29 +643,35 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
content: typing.Optional[httpx._types.RequestContent] = None,
data: typing.Optional[_RequestData] = None,
files: typing.Optional[httpx._types.RequestFiles] = None,
json: typing.Any = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().patch(
url,
content=content,
data=data,
data=data, # type: ignore[arg-type]
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
@@ -630,31 +680,34 @@ class TestClient(httpx.Client):
self,
url: httpx._types.URLTypes,
*,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, Any] | None = None,
params: typing.Optional[httpx._types.QueryParamTypes] = None,
headers: typing.Optional[httpx._types.HeaderTypes] = None,
cookies: typing.Optional[httpx._types.CookieTypes] = None,
auth: typing.Union[
httpx._types.AuthTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
httpx._client.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().delete(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
follow_redirects=redirect,
timeout=timeout,
extensions=extensions,
)
def websocket_connect(
self,
url: str,
subprotocols: Sequence[str] | None = None,
**kwargs: Any,
) -> WebSocketTestSession:
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
@@ -672,24 +725,22 @@ class TestClient(httpx.Client):
return session
def __enter__(self) -> Self:
def __enter__(self) -> "TestClient":
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
self.portal = portal = stack.enter_context(
anyio.from_thread.start_blocking_portal(**self.async_backend)
)
@stack.callback
def reset_portal() -> None:
self.portal = None
send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
anyio.create_memory_object_stream(math.inf)
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
math.inf
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
for channel in (*send, *receive):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(*send)
self.stream_receive = StapledObjectStream(*receive)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
@@ -701,7 +752,7 @@ class TestClient(httpx.Client):
return self
def __exit__(self, *args: Any) -> None:
def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()
async def lifespan(self) -> None:
@@ -714,7 +765,7 @@ class TestClient(httpx.Client):
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
async def receive() -> Any:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
@@ -729,17 +780,18 @@ class TestClient(httpx.Client):
await receive()
async def wait_shutdown(self) -> None:
async def receive() -> Any:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()