main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,19 +1,16 @@
import asyncio
import dataclasses
import email.message
import inspect
import json
import sys
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import AsyncExitStack
from enum import Enum, IntEnum
from typing import (
Any,
AsyncIterator,
Callable,
Collection,
Coroutine,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -34,10 +31,8 @@ from fastapi._compat import (
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
solve_dependencies,
@@ -52,7 +47,7 @@ from fastapi.exceptions import (
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import (
create_cloned_field,
create_model_field,
create_response_field,
generate_unique_id,
get_value_or_default,
is_body_allowed_for_status_code,
@@ -72,14 +67,9 @@ from starlette.routing import (
websocket_session,
)
from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.types import ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
def _prepare_response_content(
@@ -125,28 +115,10 @@ def _prepare_response_content(
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
assert not isinstance(res, type)
return dataclasses.asdict(res)
return res
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@asynccontextmanager
async def merged_lifespan(
app: AppType,
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
async with original_context(app) as maybe_original_state:
async with nested_context(app) as maybe_nested_state:
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
return merged_lifespan # type: ignore[return-value]
async def serialize_response(
*,
field: Optional[ModelField] = None,
@@ -234,10 +206,9 @@ def get_request_handler(
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = iscoroutinefunction(dependant.call)
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value
@@ -245,149 +216,113 @@ def get_request_handler(
actual_response_class = response_class
async def app(request: Request) -> Response:
response: Union[Response, None] = None
async with AsyncExitStack() as file_stack:
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
)
raise validation_error from e
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
except Exception as e:
http_error = HTTPException(
status_code=400, detail="There was an error parsing the body"
)
raise http_error from e
errors: List[Any] = []
async with AsyncExitStack() as async_exit_stack:
solved_result = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code
if status_code
else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = (
solved_result.response.status_code
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
)
raise validation_error
if response is None:
raise FastAPIError(
"No response object was returned. There's a high chance that the "
"application code is raising an exception and a dependency with yield "
"has a block with a bare except, or a block with except Exception, "
"and is not raising the exception again. Read more about it in the "
"docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except"
try:
body: Any = None
if body_field:
if is_body_form:
body = await request.form()
stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack)
stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
raise RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
) from e
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=400, detail="There was an error parsing the body"
) from e
solved_result = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
)
values, errors, background_tasks, sub_response, _ = solved_result
if errors:
raise RequestValidationError(_normalize_errors(errors), body=body)
else:
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
)
return response
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
return raw_response
response_args: Dict[str, Any] = {"background": background_tasks}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(sub_response.headers.raw)
return response
return app
def get_websocket_app(
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
async with AsyncExitStack() as async_exit_stack:
# TODO: remove this scope later, after a few releases
# This scope fastapi_astack is no longer used by FastAPI, kept for
# compatibility, just in case
websocket.scope["fastapi_astack"] = async_exit_stack
solved_result = await solve_dependencies(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values)
solved_result = await solve_dependencies(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
)
values, errors, _, _2, _3 = solved_result
if errors:
raise WebSocketRequestValidationError(_normalize_errors(errors))
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
return app
@@ -413,15 +348,11 @@ class APIWebSocketRoute(routing.WebSocketRoute):
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
)
@@ -500,9 +431,9 @@ class APIRoute(routing.Route):
methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods}
if isinstance(generate_unique_id_function, DefaultPlaceholder):
current_generate_unique_id: Callable[[APIRoute], str] = (
generate_unique_id_function.value
)
current_generate_unique_id: Callable[
["APIRoute"], str
] = generate_unique_id_function.value
else:
current_generate_unique_id = generate_unique_id_function
self.unique_id = self.operation_id or current_generate_unique_id(self)
@@ -511,11 +442,11 @@ class APIRoute(routing.Route):
status_code = int(status_code)
self.status_code = status_code
if self.response_model:
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
assert is_body_allowed_for_status_code(
status_code
), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id
self.response_field = create_model_field(
self.response_field = create_response_field(
name=response_name,
type_=self.response_model,
mode="serialization",
@@ -528,9 +459,9 @@ class APIRoute(routing.Route):
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
self.secure_cloned_response_field: Optional[
ModelField
] = create_cloned_field(self.response_field)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
@@ -544,13 +475,11 @@ class APIRoute(routing.Route):
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(additional_status_code), (
f"Status code {additional_status_code} must not have a response body"
)
assert is_body_allowed_for_status_code(
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_field = create_response_field(name=response_name, type_=model)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
@@ -564,15 +493,7 @@ class APIRoute(routing.Route):
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
@@ -589,7 +510,6 @@ class APIRoute(routing.Route):
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
@@ -821,7 +741,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -851,9 +771,9 @@ class APIRouter(routing.Router):
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or [])
@@ -869,7 +789,7 @@ class APIRouter(routing.Router):
def route(
self,
path: str,
methods: Optional[Collection[str]] = None,
methods: Optional[List[str]] = None,
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
@@ -1263,9 +1183,9 @@ class APIRouter(routing.Router):
"""
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith("/"), (
"A path prefix must not end with '/', as the routes will start with '/'"
)
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
else:
for r in router.routes:
path = getattr(r, "path") # noqa: B009
@@ -1365,10 +1285,6 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
def get(
self,
@@ -1633,7 +1549,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -2010,7 +1926,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -2392,7 +2308,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -2774,7 +2690,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -3151,7 +3067,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -3528,7 +3444,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -3910,7 +3826,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -4292,7 +4208,7 @@ class APIRouter(routing.Router):
This affects the generated OpenAPI (e.g. visible at `/docs`).
Read more about it in the
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
"""
),
] = True,
@@ -4377,7 +4293,7 @@ class APIRouter(routing.Router):
app = FastAPI()
router = APIRouter()
@router.trace("/items/{item_id}")
@router.put("/items/{item_id}")
def trace_item(item_id: str):
return None