This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.117.1"
|
||||
__version__ = "0.104.1"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from fastapi.cli import main
|
||||
|
||||
main()
|
||||
@@ -2,7 +2,6 @@ from collections import deque
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -16,7 +15,6 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from fastapi.exceptions import RequestErrorModel
|
||||
@@ -26,8 +24,7 @@ from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
@@ -46,8 +43,6 @@ sequence_annotation_to_type = {
|
||||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
Url: Type[Any]
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import TypeAdapter
|
||||
@@ -73,7 +68,7 @@ if PYDANTIC_V2:
|
||||
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||||
)
|
||||
|
||||
RequiredParam = PydanticUndefined
|
||||
Required = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
@@ -132,7 +127,7 @@ if PYDANTIC_V2:
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(include_url=False), loc_prefix=loc
|
||||
errors=exc.errors(), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
@@ -232,10 +227,6 @@ if PYDANTIC_V2:
|
||||
field_mapping, definitions = schema_generator.generate_definitions(
|
||||
inputs=inputs
|
||||
)
|
||||
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
|
||||
if "description" in item_def:
|
||||
item_description = cast(str, item_def["description"]).split("\f")[0]
|
||||
item_def["description"] = item_description
|
||||
return field_mapping, definitions # type: ignore[return-value]
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
@@ -258,12 +249,7 @@ if PYDANTIC_V2:
|
||||
return is_bytes_sequence_annotation(field.type_)
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
cls = type(field_info)
|
||||
merged_field_info = cls.from_annotation(annotation)
|
||||
new_field_info = copy(field_info)
|
||||
new_field_info.metadata = merged_field_info.metadata
|
||||
new_field_info.annotation = merged_field_info.annotation
|
||||
return new_field_info
|
||||
return type(field_info).from_annotation(annotation)
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = (
|
||||
@@ -275,7 +261,7 @@ if PYDANTIC_V2:
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors(include_url=False)[0]
|
||||
).errors()[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
@@ -286,12 +272,6 @@ if PYDANTIC_V2:
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return [
|
||||
ModelField(field_info=field_info, name=name)
|
||||
for name, field_info in model.model_fields.items()
|
||||
]
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
@@ -319,10 +299,9 @@ else:
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
ModelField as ModelField, # noqa: F401
|
||||
)
|
||||
|
||||
# Keeping old "Required" functionality from Pydantic V1, without
|
||||
# shadowing typing.Required.
|
||||
RequiredParam: Any = Ellipsis # type: ignore[no-redef]
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Required as Required, # noqa: F401
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Undefined as Undefined,
|
||||
)
|
||||
@@ -393,10 +372,9 @@ else:
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
definitions[model_name] = m_schema
|
||||
for m_schema in definitions.values():
|
||||
if "description" in m_schema:
|
||||
m_schema["description"] = m_schema["description"].split("\f")[0]
|
||||
definitions[model_name] = m_schema
|
||||
return definitions
|
||||
|
||||
def is_pv1_scalar_field(field: ModelField) -> bool:
|
||||
@@ -528,9 +506,6 @@ else:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
@@ -550,12 +525,6 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
@@ -658,8 +627,3 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return get_model_fields(model)
|
||||
|
||||
@@ -22,6 +22,7 @@ from fastapi.exception_handlers import (
|
||||
)
|
||||
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
@@ -36,11 +37,13 @@ from starlette.datastructures import State
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
|
||||
AppType = TypeVar("AppType", bound="FastAPI")
|
||||
|
||||
@@ -297,7 +300,7 @@ class FastAPI(Starlette):
|
||||
browser tabs open). Or if you want to leave fixed the possible URLs.
|
||||
|
||||
If the servers `list` is not provided, or is an empty `list`, the
|
||||
default value would be a `dict` with a `url` value of `/`.
|
||||
default value would be a a `dict` with a `url` value of `/`.
|
||||
|
||||
Each item in the `list` is a `dict` containing:
|
||||
|
||||
@@ -748,7 +751,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -810,32 +813,6 @@ class FastAPI(Starlette):
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
openapi_external_docs: Annotated[
|
||||
Optional[Dict[str, Any]],
|
||||
Doc(
|
||||
"""
|
||||
This field allows you to provide additional external documentation links.
|
||||
If provided, it must be a dictionary containing:
|
||||
|
||||
* `description`: A brief description of the external documentation.
|
||||
* `url`: The URL pointing to the external documentation. The value **MUST**
|
||||
be a valid URL format.
|
||||
|
||||
**Example**:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
external_docs = {
|
||||
"description": "Detailed API Reference",
|
||||
"url": "https://example.com/api-docs",
|
||||
}
|
||||
|
||||
app = FastAPI(openapi_external_docs=external_docs)
|
||||
```
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
**extra: Annotated[
|
||||
Any,
|
||||
Doc(
|
||||
@@ -864,7 +841,6 @@ class FastAPI(Starlette):
|
||||
self.swagger_ui_parameters = swagger_ui_parameters
|
||||
self.servers = servers or []
|
||||
self.separate_input_output_schemas = separate_input_output_schemas
|
||||
self.openapi_external_docs = openapi_external_docs
|
||||
self.extra = extra
|
||||
self.openapi_version: Annotated[
|
||||
str,
|
||||
@@ -929,7 +905,7 @@ class FastAPI(Starlette):
|
||||
A state object for the application. This is the same object for the
|
||||
entire application, it doesn't change from request to request.
|
||||
|
||||
You normally wouldn't use this in FastAPI, for most of the cases you
|
||||
You normally woudln't use this in FastAPI, for most of the cases you
|
||||
would instead use FastAPI dependencies.
|
||||
|
||||
This is simply inherited from Starlette.
|
||||
@@ -990,6 +966,55 @@ class FastAPI(Starlette):
|
||||
self.middleware_stack: Union[ASGIApp, None] = None
|
||||
self.setup()
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
|
||||
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
error_handler = value
|
||||
else:
|
||||
exception_handlers[key] = value
|
||||
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
),
|
||||
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
|
||||
# contextvars.
|
||||
# This needs to happen after user middlewares because those create a
|
||||
# new contextvars context copy by using a new AnyIO task group.
|
||||
# The initial part of dependencies with 'yield' is executed in the
|
||||
# FastAPI code, inside all the middlewares. However, the teardown part
|
||||
# (after 'yield') is executed in the AsyncExitStack in this middleware.
|
||||
# If the AsyncExitStack lived outside of the custom middlewares and
|
||||
# contextvars were set in a dependency with 'yield' in that internal
|
||||
# contextvars context, the values would not be available in the
|
||||
# outer context of the AsyncExitStack.
|
||||
# By placing the middleware and the AsyncExitStack here, inside all
|
||||
# user middlewares, the code before and after 'yield' in dependencies
|
||||
# with 'yield' is executed in the same contextvars context. Thus, all values
|
||||
# set in contextvars before 'yield' are still available after 'yield,' as
|
||||
# expected.
|
||||
# Additionally, by having this AsyncExitStack here, after the
|
||||
# ExceptionMiddleware, dependencies can now catch handled exceptions,
|
||||
# e.g. HTTPException, to customize the teardown code (e.g. DB session
|
||||
# rollback).
|
||||
Middleware(AsyncExitStackMiddleware),
|
||||
]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
return app
|
||||
|
||||
def openapi(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate the OpenAPI schema of the application. This is called by FastAPI
|
||||
@@ -1019,7 +1044,6 @@ class FastAPI(Starlette):
|
||||
tags=self.openapi_tags,
|
||||
servers=self.servers,
|
||||
separate_input_output_schemas=self.separate_input_output_schemas,
|
||||
external_docs=self.openapi_external_docs,
|
||||
)
|
||||
return self.openapi_schema
|
||||
|
||||
@@ -1047,7 +1071,7 @@ class FastAPI(Starlette):
|
||||
oauth2_redirect_url = root_path + oauth2_redirect_url
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=openapi_url,
|
||||
title=f"{self.title} - Swagger UI",
|
||||
title=self.title + " - Swagger UI",
|
||||
oauth2_redirect_url=oauth2_redirect_url,
|
||||
init_oauth=self.swagger_ui_init_oauth,
|
||||
swagger_ui_parameters=self.swagger_ui_parameters,
|
||||
@@ -1071,7 +1095,7 @@ class FastAPI(Starlette):
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
openapi_url = root_path + self.openapi_url
|
||||
return get_redoc_html(
|
||||
openapi_url=openapi_url, title=f"{self.title} - ReDoc"
|
||||
openapi_url=openapi_url, title=self.title + " - ReDoc"
|
||||
)
|
||||
|
||||
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
|
||||
@@ -1084,7 +1108,7 @@ class FastAPI(Starlette):
|
||||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Any],
|
||||
endpoint: Callable[..., Coroutine[Any, Any, Response]],
|
||||
*,
|
||||
response_model: Any = Default(None),
|
||||
status_code: Optional[int] = None,
|
||||
@@ -1748,7 +1772,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2121,7 +2145,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2499,7 +2523,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2877,7 +2901,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3250,7 +3274,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3623,7 +3647,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3996,7 +4020,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4374,7 +4398,7 @@ class FastAPI(Starlette):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4453,7 +4477,7 @@ class FastAPI(Starlette):
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.trace("/items/{item_id}")
|
||||
@app.put("/items/{item_id}")
|
||||
def trace_item(item_id: str):
|
||||
return None
|
||||
```
|
||||
@@ -4543,17 +4567,14 @@ class FastAPI(Starlette):
|
||||
|
||||
```python
|
||||
import time
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_process_time_header(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
async def add_process_time_header(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from typing_extensions import Annotated, Doc, ParamSpec
|
||||
from typing_extensions import Annotated, Doc, ParamSpec # type: ignore [attr-defined]
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
try:
|
||||
from fastapi_cli.cli import main as cli_main
|
||||
|
||||
except ImportError: # pragma: no cover
|
||||
cli_main = None # type: ignore
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not cli_main: # type: ignore[truthy-function]
|
||||
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
|
||||
print(message)
|
||||
raise RuntimeError(message) # noqa: B904
|
||||
cli_main()
|
||||
@@ -1,7 +1,8 @@
|
||||
from contextlib import AsyncExitStack as AsyncExitStack # noqa
|
||||
from contextlib import asynccontextmanager as asynccontextmanager
|
||||
from typing import AsyncGenerator, ContextManager, TypeVar
|
||||
|
||||
import anyio.to_thread
|
||||
import anyio
|
||||
from anyio import CapacityLimiter
|
||||
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||||
@@ -28,7 +29,7 @@ async def contextmanager_in_threadpool(
|
||||
except Exception as e:
|
||||
ok = bool(
|
||||
await anyio.to_thread.run_sync(
|
||||
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
|
||||
cm.__exit__, type(e), e, None, limiter=exit_limiter
|
||||
)
|
||||
)
|
||||
if not ok:
|
||||
|
||||
@@ -24,7 +24,7 @@ from starlette.datastructures import Headers as Headers # noqa: F401
|
||||
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
|
||||
from starlette.datastructures import State as State # noqa: F401
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class UploadFile(StarletteUploadFile):
|
||||
|
||||
@@ -1,37 +1,58 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
def __init__(
|
||||
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
|
||||
):
|
||||
self.security_scheme = security_scheme
|
||||
self.scopes = scopes
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
query_params: List[ModelField] = field(default_factory=list)
|
||||
header_params: List[ModelField] = field(default_factory=list)
|
||||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
websocket_param_name: Optional[str] = None
|
||||
http_connection_param_name: Optional[str] = None
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
security_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_params: Optional[List[ModelField]] = None,
|
||||
query_params: Optional[List[ModelField]] = None,
|
||||
header_params: Optional[List[ModelField]] = None,
|
||||
cookie_params: Optional[List[ModelField]] = None,
|
||||
body_params: Optional[List[ModelField]] = None,
|
||||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
self.query_params = query_params or []
|
||||
self.header_params = header_params or []
|
||||
self.cookie_params = cookie_params or []
|
||||
self.body_params = body_params or []
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
import sys
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -25,7 +23,7 @@ from fastapi._compat import (
|
||||
PYDANTIC_V2,
|
||||
ErrorWrapper,
|
||||
ModelField,
|
||||
RequiredParam,
|
||||
Required,
|
||||
Undefined,
|
||||
_regenerate_error_with_loc,
|
||||
copy_field_info,
|
||||
@@ -33,7 +31,6 @@ from fastapi._compat import (
|
||||
evaluate_forwardref,
|
||||
field_annotation_is_scalar,
|
||||
get_annotation_from_field_info,
|
||||
get_cached_model_fields,
|
||||
get_missing_field_error,
|
||||
is_bytes_field,
|
||||
is_bytes_sequence_field,
|
||||
@@ -49,6 +46,7 @@ from fastapi._compat import (
|
||||
)
|
||||
from fastapi.background import BackgroundTasks
|
||||
from fastapi.concurrency import (
|
||||
AsyncExitStack,
|
||||
asynccontextmanager,
|
||||
contextmanager_in_threadpool,
|
||||
)
|
||||
@@ -57,28 +55,16 @@ from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
from fastapi.utils import create_response_field, get_path_param_names
|
||||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import (
|
||||
FormData,
|
||||
Headers,
|
||||
ImmutableMultiDict,
|
||||
QueryParams,
|
||||
UploadFile,
|
||||
)
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
multipart_not_installed_error = (
|
||||
'Form data requires "python-multipart" to be installed. \n'
|
||||
'You can install "python-multipart" with: \n\n'
|
||||
@@ -94,23 +80,17 @@ multipart_incorrect_install_error = (
|
||||
)
|
||||
|
||||
|
||||
def ensure_multipart_is_installed() -> None:
|
||||
try:
|
||||
from python_multipart import __version__
|
||||
|
||||
# Import an attribute that can be mocked/deleted in testing
|
||||
assert __version__ > "0.0.12"
|
||||
except (ImportError, AssertionError):
|
||||
def check_file_field(field: ModelField) -> None:
|
||||
field_info = field.field_info
|
||||
if isinstance(field_info, params.Form):
|
||||
try:
|
||||
# __version__ is available in both multiparts, and can be mocked
|
||||
from multipart import __version__ # type: ignore[no-redef,import-untyped]
|
||||
from multipart import __version__ # type: ignore
|
||||
|
||||
assert __version__
|
||||
try:
|
||||
# parse_options_header is only available in the right multipart
|
||||
from multipart.multipart import ( # type: ignore[import-untyped]
|
||||
parse_options_header,
|
||||
)
|
||||
from multipart.multipart import parse_options_header # type: ignore
|
||||
|
||||
assert parse_options_header
|
||||
except ImportError:
|
||||
@@ -139,9 +119,9 @@ def get_param_sub_dependant(
|
||||
|
||||
|
||||
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
|
||||
assert callable(depends.dependency), (
|
||||
"A parameter-less dependency must have a callable dependency"
|
||||
)
|
||||
assert callable(
|
||||
depends.dependency
|
||||
), "A parameter-less dependency must have a callable dependency"
|
||||
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
|
||||
|
||||
|
||||
@@ -196,7 +176,7 @@ def get_flat_dependant(
|
||||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
security_schemes=dependant.security_requirements.copy(),
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
)
|
||||
@@ -215,23 +195,14 @@ def get_flat_dependant(
|
||||
return flat_dependant
|
||||
|
||||
|
||||
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
|
||||
if not fields:
|
||||
return fields
|
||||
first_field = fields[0]
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
return fields_to_extract
|
||||
return fields
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
return path_params + query_params + header_params + cookie_params
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
@@ -254,8 +225,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
if annotation is type(None):
|
||||
return None
|
||||
return annotation
|
||||
|
||||
|
||||
@@ -290,16 +259,16 @@ def get_dependant(
|
||||
)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
param_details = analyze_param(
|
||||
type_annotation, depends, param_field = analyze_param(
|
||||
param_name=param_name,
|
||||
annotation=param.annotation,
|
||||
value=param.default,
|
||||
is_path_param=is_path_param,
|
||||
)
|
||||
if param_details.depends is not None:
|
||||
if depends is not None:
|
||||
sub_dependant = get_param_sub_dependant(
|
||||
param_name=param_name,
|
||||
depends=param_details.depends,
|
||||
depends=depends,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
)
|
||||
@@ -307,18 +276,18 @@ def get_dependant(
|
||||
continue
|
||||
if add_non_field_param_to_dependency(
|
||||
param_name=param_name,
|
||||
type_annotation=param_details.type_annotation,
|
||||
type_annotation=type_annotation,
|
||||
dependant=dependant,
|
||||
):
|
||||
assert param_details.field is None, (
|
||||
f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
)
|
||||
assert (
|
||||
param_field is None
|
||||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
continue
|
||||
assert param_details.field is not None
|
||||
if isinstance(param_details.field.field_info, params.Body):
|
||||
dependant.body_params.append(param_details.field)
|
||||
assert param_field is not None
|
||||
if is_body_param(param_field=param_field, is_path_param=is_path_param):
|
||||
dependant.body_params.append(param_field)
|
||||
else:
|
||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
return dependant
|
||||
|
||||
|
||||
@@ -346,29 +315,20 @@ def add_non_field_param_to_dependency(
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamDetails:
|
||||
type_annotation: Any
|
||||
depends: Optional[params.Depends]
|
||||
field: Optional[ModelField]
|
||||
|
||||
|
||||
def analyze_param(
|
||||
*,
|
||||
param_name: str,
|
||||
annotation: Any,
|
||||
value: Any,
|
||||
is_path_param: bool,
|
||||
) -> ParamDetails:
|
||||
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
|
||||
field_info = None
|
||||
depends = None
|
||||
type_annotation: Any = Any
|
||||
use_annotation: Any = Any
|
||||
if annotation is not inspect.Signature.empty:
|
||||
use_annotation = annotation
|
||||
type_annotation = annotation
|
||||
# Extract Annotated info
|
||||
if get_origin(use_annotation) is Annotated:
|
||||
if (
|
||||
annotation is not inspect.Signature.empty
|
||||
and get_origin(annotation) is Annotated
|
||||
):
|
||||
annotated_args = get_args(annotation)
|
||||
type_annotation = annotated_args[0]
|
||||
fastapi_annotations = [
|
||||
@@ -376,26 +336,16 @@ def analyze_param(
|
||||
for arg in annotated_args[1:]
|
||||
if isinstance(arg, (FieldInfo, params.Depends))
|
||||
]
|
||||
fastapi_specific_annotations = [
|
||||
arg
|
||||
for arg in fastapi_annotations
|
||||
if isinstance(arg, (params.Param, params.Body, params.Depends))
|
||||
]
|
||||
if fastapi_specific_annotations:
|
||||
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
|
||||
fastapi_specific_annotations[-1]
|
||||
)
|
||||
else:
|
||||
fastapi_annotation = None
|
||||
# Set default for Annotated FieldInfo
|
||||
assert (
|
||||
len(fastapi_annotations) <= 1
|
||||
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
|
||||
fastapi_annotation = next(iter(fastapi_annotations), None)
|
||||
if isinstance(fastapi_annotation, FieldInfo):
|
||||
# Copy `field_info` because we mutate `field_info.default` below.
|
||||
field_info = copy_field_info(
|
||||
field_info=fastapi_annotation, annotation=use_annotation
|
||||
field_info=fastapi_annotation, annotation=annotation
|
||||
)
|
||||
assert (
|
||||
field_info.default is Undefined or field_info.default is RequiredParam
|
||||
), (
|
||||
assert field_info.default is Undefined or field_info.default is Required, (
|
||||
f"`{field_info.__class__.__name__}` default value cannot be set in"
|
||||
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
|
||||
)
|
||||
@@ -403,11 +353,12 @@ def analyze_param(
|
||||
assert not is_path_param, "Path parameters cannot have default values"
|
||||
field_info.default = value
|
||||
else:
|
||||
field_info.default = RequiredParam
|
||||
# Get Annotated Depends
|
||||
field_info.default = Required
|
||||
elif isinstance(fastapi_annotation, params.Depends):
|
||||
depends = fastapi_annotation
|
||||
# Get Depends from default value
|
||||
elif annotation is not inspect.Signature.empty:
|
||||
type_annotation = annotation
|
||||
|
||||
if isinstance(value, params.Depends):
|
||||
assert depends is None, (
|
||||
"Cannot specify `Depends` in `Annotated` and default value"
|
||||
@@ -418,7 +369,6 @@ def analyze_param(
|
||||
f" default value together for {param_name!r}"
|
||||
)
|
||||
depends = value
|
||||
# Get FieldInfo from default value
|
||||
elif isinstance(value, FieldInfo):
|
||||
assert field_info is None, (
|
||||
"Cannot specify FastAPI annotations in `Annotated` and default value"
|
||||
@@ -428,13 +378,9 @@ def analyze_param(
|
||||
if PYDANTIC_V2:
|
||||
field_info.annotation = type_annotation
|
||||
|
||||
# Get Depends from type annotation
|
||||
if depends is not None and depends.dependency is None:
|
||||
# Copy `depends` before mutating it
|
||||
depends = copy(depends)
|
||||
depends.dependency = type_annotation
|
||||
|
||||
# Handle non-param type annotations like Request
|
||||
if lenient_issubclass(
|
||||
type_annotation,
|
||||
(
|
||||
@@ -447,30 +393,27 @@ def analyze_param(
|
||||
),
|
||||
):
|
||||
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
|
||||
assert field_info is None, (
|
||||
f"Cannot specify FastAPI annotation for type {type_annotation!r}"
|
||||
)
|
||||
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
|
||||
assert (
|
||||
field_info is None
|
||||
), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
|
||||
elif field_info is None and depends is None:
|
||||
default_value = value if value is not inspect.Signature.empty else RequiredParam
|
||||
default_value = value if value is not inspect.Signature.empty else Required
|
||||
if is_path_param:
|
||||
# We might check here that `default_value is RequiredParam`, but the fact is that the same
|
||||
# We might check here that `default_value is Required`, but the fact is that the same
|
||||
# parameter might sometimes be a path parameter and sometimes not. See
|
||||
# `tests/test_infer_param_optionality.py` for an example.
|
||||
field_info = params.Path(annotation=use_annotation)
|
||||
field_info = params.Path(annotation=type_annotation)
|
||||
elif is_uploadfile_or_nonable_uploadfile_annotation(
|
||||
type_annotation
|
||||
) or is_uploadfile_sequence_annotation(type_annotation):
|
||||
field_info = params.File(annotation=use_annotation, default=default_value)
|
||||
field_info = params.File(annotation=type_annotation, default=default_value)
|
||||
elif not field_annotation_is_scalar(annotation=type_annotation):
|
||||
field_info = params.Body(annotation=use_annotation, default=default_value)
|
||||
field_info = params.Body(annotation=type_annotation, default=default_value)
|
||||
else:
|
||||
field_info = params.Query(annotation=use_annotation, default=default_value)
|
||||
field_info = params.Query(annotation=type_annotation, default=default_value)
|
||||
|
||||
field = None
|
||||
# It's a field_info, not a dependency
|
||||
if field_info is not None:
|
||||
# Handle field_info.in_
|
||||
if is_path_param:
|
||||
assert isinstance(field_info, params.Path), (
|
||||
f"Cannot use `{field_info.__class__.__name__}` for path param"
|
||||
@@ -481,67 +424,69 @@ def analyze_param(
|
||||
and getattr(field_info, "in_", None) is None
|
||||
):
|
||||
field_info.in_ = params.ParamTypes.query
|
||||
use_annotation_from_field_info = get_annotation_from_field_info(
|
||||
use_annotation,
|
||||
use_annotation = get_annotation_from_field_info(
|
||||
type_annotation,
|
||||
field_info,
|
||||
param_name,
|
||||
)
|
||||
if isinstance(field_info, params.Form):
|
||||
ensure_multipart_is_installed()
|
||||
if not field_info.alias and getattr(field_info, "convert_underscores", None):
|
||||
alias = param_name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param_name
|
||||
field_info.alias = alias
|
||||
field = create_model_field(
|
||||
field = create_response_field(
|
||||
name=param_name,
|
||||
type_=use_annotation_from_field_info,
|
||||
type_=use_annotation,
|
||||
default=field_info.default,
|
||||
alias=alias,
|
||||
required=field_info.default in (RequiredParam, Undefined),
|
||||
required=field_info.default in (Required, Undefined),
|
||||
field_info=field_info,
|
||||
)
|
||||
if is_path_param:
|
||||
assert is_scalar_field(field=field), (
|
||||
"Path params must be of one of the supported types"
|
||||
)
|
||||
elif isinstance(field_info, params.Query):
|
||||
assert (
|
||||
is_scalar_field(field)
|
||||
or is_scalar_sequence_field(field)
|
||||
or (
|
||||
lenient_issubclass(field.type_, BaseModel)
|
||||
# For Pydantic v1
|
||||
and getattr(field, "shape", 1) == 1
|
||||
)
|
||||
)
|
||||
|
||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||
return type_annotation, depends, field
|
||||
|
||||
|
||||
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=param_field
|
||||
), "Path params must be of one of the supported types"
|
||||
return False
|
||||
elif is_scalar_field(field=param_field):
|
||||
return False
|
||||
elif isinstance(
|
||||
param_field.field_info, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(
|
||||
param_field.field_info, params.Body
|
||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
||||
return True
|
||||
|
||||
|
||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
field_info = field.field_info
|
||||
field_info_in = getattr(field_info, "in_", None)
|
||||
if field_info_in == params.ParamTypes.path:
|
||||
field_info = cast(params.Param, field.field_info)
|
||||
if field_info.in_ == params.ParamTypes.path:
|
||||
dependant.path_params.append(field)
|
||||
elif field_info_in == params.ParamTypes.query:
|
||||
elif field_info.in_ == params.ParamTypes.query:
|
||||
dependant.query_params.append(field)
|
||||
elif field_info_in == params.ParamTypes.header:
|
||||
elif field_info.in_ == params.ParamTypes.header:
|
||||
dependant.header_params.append(field)
|
||||
else:
|
||||
assert field_info_in == params.ParamTypes.cookie, (
|
||||
f"non-body parameters must be in path, query, header or cookie: {field.name}"
|
||||
)
|
||||
assert (
|
||||
field_info.in_ == params.ParamTypes.cookie
|
||||
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
|
||||
dependant.cookie_params.append(field)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(call):
|
||||
return iscoroutinefunction(call)
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
return inspect.iscoroutinefunction(dunder_call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
@@ -568,15 +513,6 @@ async def solve_generator(
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolvedDependency:
|
||||
values: Dict[str, Any]
|
||||
errors: List[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
@@ -586,17 +522,20 @@ async def solve_dependencies(
|
||||
response: Optional[Response] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
async_exit_stack: AsyncExitStack,
|
||||
embed_body_fields: bool,
|
||||
) -> SolvedDependency:
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[Any],
|
||||
Optional[StarletteBackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
if response is None:
|
||||
response = Response()
|
||||
del response.headers["content-length"]
|
||||
response.status_code = None # type: ignore
|
||||
if dependency_cache is None:
|
||||
dependency_cache = {}
|
||||
dependency_cache = dependency_cache or {}
|
||||
sub_dependant: Dependant
|
||||
for sub_dependant in dependant.dependencies:
|
||||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
||||
@@ -629,23 +568,30 @@ async def solve_dependencies(
|
||||
response=response,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
background_tasks = solved_result.background_tasks
|
||||
if solved_result.errors:
|
||||
errors.extend(solved_result.errors)
|
||||
(
|
||||
sub_values,
|
||||
sub_errors,
|
||||
background_tasks,
|
||||
_, # the subdependency returns the same response we have
|
||||
sub_dependency_cache,
|
||||
) = solved_result
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
stack = request.scope.get("fastapi_astack")
|
||||
assert isinstance(stack, AsyncExitStack)
|
||||
solved = await solve_generator(
|
||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||
call=call, stack=stack, sub_values=sub_values
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**solved_result.values)
|
||||
solved = await call(**sub_values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **solved_result.values)
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
@@ -672,9 +618,7 @@ async def solve_dependencies(
|
||||
body_values,
|
||||
body_errors,
|
||||
) = await request_body_to_args( # body_params checked above
|
||||
body_fields=dependant.body_params,
|
||||
received_body=body,
|
||||
embed_body_fields=embed_body_fields,
|
||||
required_params=dependant.body_params, received_body=body
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
@@ -694,289 +638,142 @@ async def solve_dependencies(
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
errors=errors,
|
||||
background_tasks=background_tasks,
|
||||
response=response,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
if value is None:
|
||||
if field.required:
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
return None, [errors_]
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
return None, new_errors
|
||||
else:
|
||||
return v_, []
|
||||
|
||||
|
||||
def _get_multidict_value(
|
||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||
) -> Any:
|
||||
alias = alias or field.alias
|
||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||
value = values.getlist(alias)
|
||||
else:
|
||||
value = values.get(alias, None)
|
||||
if (
|
||||
value is None
|
||||
or (
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
and value == ""
|
||||
)
|
||||
or (is_sequence_field(field) and len(value) == 0)
|
||||
):
|
||||
if field.required:
|
||||
return
|
||||
else:
|
||||
return deepcopy(field.default)
|
||||
return value
|
||||
return values, errors, background_tasks, response, dependency_cache
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
fields: Sequence[ModelField],
|
||||
required_params: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
|
||||
if not fields:
|
||||
return values, errors
|
||||
|
||||
first_field = fields[0]
|
||||
fields_to_extract = fields
|
||||
single_not_embedded_field = False
|
||||
default_convert_underscores = True
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
single_not_embedded_field = True
|
||||
# If headers are in a Pydantic model, the way to disable convert_underscores
|
||||
# would be with Header(convert_underscores=False) at the Pydantic model level
|
||||
default_convert_underscores = getattr(
|
||||
first_field.field_info, "convert_underscores", True
|
||||
)
|
||||
|
||||
params_to_process: Dict[str, Any] = {}
|
||||
|
||||
processed_keys = set()
|
||||
|
||||
for field in fields_to_extract:
|
||||
alias = None
|
||||
if isinstance(received_params, Headers):
|
||||
# Handle fields extracted from a Pydantic Model for a header, each field
|
||||
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
|
||||
convert_underscores = getattr(
|
||||
field.field_info, "convert_underscores", default_convert_underscores
|
||||
)
|
||||
if convert_underscores:
|
||||
alias = (
|
||||
field.alias
|
||||
if field.alias != field.name
|
||||
else field.name.replace("_", "-")
|
||||
)
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
if value is not None:
|
||||
params_to_process[field.name] = value
|
||||
processed_keys.add(alias or field.alias)
|
||||
processed_keys.add(field.name)
|
||||
|
||||
for key, value in received_params.items():
|
||||
if key not in processed_keys:
|
||||
params_to_process[key] = value
|
||||
|
||||
if single_not_embedded_field:
|
||||
field_info = first_field.field_info
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
loc: Tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
values = {}
|
||||
errors = []
|
||||
for field in required_params:
|
||||
if is_scalar_sequence_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
else:
|
||||
value = received_params.get(field.alias)
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc = (field_info.in_.value, field.alias)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
if value is None:
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc=loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
errors.extend(new_errors)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def is_union_of_base_models(field_type: Any) -> bool:
|
||||
"""Check if field type is a Union where all members are BaseModel subclasses."""
|
||||
from fastapi.types import UnionType
|
||||
|
||||
origin = get_origin(field_type)
|
||||
|
||||
# Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
|
||||
if origin is not Union and origin is not UnionType:
|
||||
return False
|
||||
|
||||
union_args = get_args(field_type)
|
||||
|
||||
for arg in union_args:
|
||||
if not lenient_issubclass(arg, BaseModel):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
if not fields:
|
||||
return False
|
||||
# More than one dependency could have the same field, it would show up as multiple
|
||||
# fields but it's the same one, so count them by name
|
||||
body_param_names_set = {field.name for field in fields}
|
||||
# A top level field has to be a single field, not multiple
|
||||
if len(body_param_names_set) > 1:
|
||||
return True
|
||||
first_field = fields[0]
|
||||
# If it explicitly specifies it is embedded, it has to be embedded
|
||||
if getattr(first_field.field_info, "embed", None):
|
||||
return True
|
||||
# If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
|
||||
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||
if (
|
||||
isinstance(first_field.field_info, params.Form)
|
||||
and not lenient_issubclass(first_field.type_, BaseModel)
|
||||
and not is_union_of_base_models(first_field.type_)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _extract_form_body(
|
||||
body_fields: List[ModelField],
|
||||
received_body: FormData,
|
||||
) -> Dict[str, Any]:
|
||||
values = {}
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
field_info = field.field_info
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
if value is not None:
|
||||
values[field.alias] = value
|
||||
for key, value in received_body.items():
|
||||
if key not in values:
|
||||
values[key] = value
|
||||
return values
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
body_fields: List[ModelField],
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
embed_body_fields: bool,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values: Dict[str, Any] = {}
|
||||
values = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
assert body_fields, "request_body_to_args() should be called with fields"
|
||||
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||
first_field = body_fields[0]
|
||||
body_to_process = received_body
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
field_alias_omitted = len(required_params) == 1 and not embed
|
||||
if field_alias_omitted:
|
||||
received_body = {field.alias: received_body}
|
||||
|
||||
fields_to_extract: List[ModelField] = body_fields
|
||||
for field in required_params:
|
||||
loc: Tuple[str, ...]
|
||||
if field_alias_omitted:
|
||||
loc = ("body",)
|
||||
else:
|
||||
loc = ("body", field.alias)
|
||||
|
||||
if (
|
||||
single_not_embedded_field
|
||||
and lenient_issubclass(first_field.type_, BaseModel)
|
||||
and isinstance(received_body, FormData)
|
||||
):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
|
||||
if isinstance(received_body, FormData):
|
||||
body_to_process = await _extract_form_body(fields_to_extract, received_body)
|
||||
|
||||
if single_not_embedded_field:
|
||||
loc: Tuple[str, ...] = ("body",)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=body_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", field.alias)
|
||||
value: Optional[Any] = None
|
||||
if body_to_process is not None:
|
||||
try:
|
||||
value = body_to_process.get(field.alias)
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
value: Optional[Any] = None
|
||||
if received_body is not None:
|
||||
if (is_sequence_field(field)) and isinstance(received_body, FormData):
|
||||
value = received_body.getlist(field.alias)
|
||||
else:
|
||||
try:
|
||||
value = received_body.get(field.alias)
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(field_info, params.Form) and value == "")
|
||||
or (
|
||||
isinstance(field_info, params.Form)
|
||||
and is_sequence_field(field)
|
||||
and len(value) == 0
|
||||
)
|
||||
):
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]]
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
|
||||
if isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
elif errors_:
|
||||
errors.append(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def get_body_field(
|
||||
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
|
||||
) -> Optional[ModelField]:
|
||||
"""
|
||||
Get a ModelField representing the request body for a path operation, combining
|
||||
all body parameters into a single field if necessary.
|
||||
|
||||
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
|
||||
or JSON and to generate the JSON Schema for a request body.
|
||||
|
||||
This is **not** used to validate/parse the request body, that's done with each
|
||||
individual body parameter.
|
||||
"""
|
||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant)
|
||||
if not flat_dependant.body_params:
|
||||
return None
|
||||
first_param = flat_dependant.body_params[0]
|
||||
if not embed_body_fields:
|
||||
field_info = first_param.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
||||
if len(body_param_names_set) == 1 and not embed:
|
||||
check_file_field(first_param)
|
||||
return first_param
|
||||
# If one field requires to embed, all have to be embedded
|
||||
# in case a sub-dependency is evaluated with a single unique body field
|
||||
# That is combined (embedded) with other body fields
|
||||
for param in flat_dependant.body_params:
|
||||
setattr(param.field_info, "embed", True) # noqa: B010
|
||||
model_name = "Body_" + name
|
||||
BodyModel = create_body_model(
|
||||
fields=flat_dependant.body_params, model_name=model_name
|
||||
@@ -1002,11 +799,12 @@ def get_body_field(
|
||||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
final_field = create_model_field(
|
||||
final_field = create_response_field(
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
required=required,
|
||||
alias="body",
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
check_file_field(final_field)
|
||||
return final_field
|
||||
|
||||
@@ -22,9 +22,9 @@ from pydantic import BaseModel
|
||||
from pydantic.color import Color
|
||||
from pydantic.networks import AnyUrl, NameEmail
|
||||
from pydantic.types import SecretBytes, SecretStr
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
from ._compat import PYDANTIC_V2, UndefinedType, Url, _model_dump
|
||||
from ._compat import PYDANTIC_V2, Url, _model_dump
|
||||
|
||||
|
||||
# Taken from Pydantic v1 as is
|
||||
@@ -86,7 +86,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]],
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]]
|
||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
||||
tuple
|
||||
@@ -219,7 +219,7 @@ def jsonable_encoder(
|
||||
if not PYDANTIC_V2:
|
||||
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
|
||||
if custom_encoder:
|
||||
encoders = {**encoders, **custom_encoder}
|
||||
encoders.update(custom_encoder)
|
||||
obj_dict = _model_dump(
|
||||
obj,
|
||||
mode="json",
|
||||
@@ -241,7 +241,6 @@ def jsonable_encoder(
|
||||
sqlalchemy_safe=sqlalchemy_safe,
|
||||
)
|
||||
if dataclasses.is_dataclass(obj):
|
||||
assert not isinstance(obj, type)
|
||||
obj_dict = dataclasses.asdict(obj)
|
||||
return jsonable_encoder(
|
||||
obj_dict,
|
||||
@@ -260,8 +259,6 @@ def jsonable_encoder(
|
||||
return str(obj)
|
||||
if isinstance(obj, (str, int, float, type(None))):
|
||||
return obj
|
||||
if isinstance(obj, UndefinedType):
|
||||
return None
|
||||
if isinstance(obj, dict):
|
||||
encoded_dict = {}
|
||||
allowed_keys = set(obj.keys())
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.websockets import WebSocket
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import WS_1008_POLICY_VIOLATION
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||
@@ -21,7 +21,7 @@ async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={"detail": jsonable_encoder(exc.errors())},
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||
from pydantic import BaseModel, create_model
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.exceptions import WebSocketException as StarletteWebSocketException
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class HTTPException(StarletteHTTPException):
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AsyncExitStackMiddleware:
|
||||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
|
||||
self.app = app
|
||||
self.context_name = context_name
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
dependency_exception: Optional[Exception] = None
|
||||
async with AsyncExitStack() as stack:
|
||||
scope[self.context_name] = stack
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
dependency_exception = e
|
||||
raise e
|
||||
if dependency_exception:
|
||||
# This exception was possibly handled by the dependency but it should
|
||||
# still bubble up so that the ServerErrorMiddleware can return a 500
|
||||
# or the ExceptionMiddleware can catch and handle any other exceptions
|
||||
raise dependency_exception
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.responses import HTMLResponse
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
swagger_ui_default_parameters: Annotated[
|
||||
Dict[str, Any],
|
||||
@@ -53,7 +53,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
|
||||
swagger_css_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -63,7 +63,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
|
||||
swagger_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -188,7 +188,7 @@ def get_redoc_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/redoc@2/bundles/redoc.standalone.js",
|
||||
] = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
|
||||
redoc_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
|
||||
@@ -55,7 +55,11 @@ except ImportError: # pragma: no cover
|
||||
return with_info_plain_validator_function(cls._validate)
|
||||
|
||||
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
class Contact(BaseModel):
|
||||
name: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@@ -65,19 +69,21 @@ class BaseModelWithConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Contact(BaseModelWithConfig):
|
||||
name: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
|
||||
class License(BaseModelWithConfig):
|
||||
class License(BaseModel):
|
||||
name: str
|
||||
identifier: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Info(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Info(BaseModel):
|
||||
title: str
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -86,18 +92,42 @@ class Info(BaseModelWithConfig):
|
||||
license: Optional[License] = None
|
||||
version: str
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class ServerVariable(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ServerVariable(BaseModel):
|
||||
enum: Annotated[Optional[List[str]], Field(min_length=1)] = None
|
||||
default: str
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Server(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Server(BaseModel):
|
||||
url: Union[AnyUrl, str]
|
||||
description: Optional[str] = None
|
||||
variables: Optional[Dict[str, ServerVariable]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Reference(BaseModel):
|
||||
ref: str = Field(alias="$ref")
|
||||
@@ -108,26 +138,36 @@ class Discriminator(BaseModel):
|
||||
mapping: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class XML(BaseModelWithConfig):
|
||||
class XML(BaseModel):
|
||||
name: Optional[str] = None
|
||||
namespace: Optional[str] = None
|
||||
prefix: Optional[str] = None
|
||||
attribute: Optional[bool] = None
|
||||
wrapped: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class ExternalDocumentation(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ExternalDocumentation(BaseModel):
|
||||
description: Optional[str] = None
|
||||
url: AnyUrl
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
# Ref JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation#name-type
|
||||
SchemaType = Literal[
|
||||
"array", "boolean", "integer", "null", "number", "object", "string"
|
||||
]
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Schema(BaseModelWithConfig):
|
||||
class Schema(BaseModel):
|
||||
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu
|
||||
# Core Vocabulary
|
||||
schema_: Optional[str] = Field(default=None, alias="$schema")
|
||||
@@ -151,7 +191,7 @@ class Schema(BaseModelWithConfig):
|
||||
dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None
|
||||
prefixItems: Optional[List["SchemaOrBool"]] = None
|
||||
# TODO: uncomment and remove below when deprecating Pydantic v1
|
||||
# It generates a list of schemas for tuples, before prefixItems was available
|
||||
# It generales a list of schemas for tuples, before prefixItems was available
|
||||
# items: Optional["SchemaOrBool"] = None
|
||||
items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None
|
||||
contains: Optional["SchemaOrBool"] = None
|
||||
@@ -163,7 +203,7 @@ class Schema(BaseModelWithConfig):
|
||||
unevaluatedProperties: Optional["SchemaOrBool"] = None
|
||||
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural
|
||||
# A Vocabulary for Structural Validation
|
||||
type: Optional[Union[SchemaType, List[SchemaType]]] = None
|
||||
type: Optional[str] = None
|
||||
enum: Optional[List[Any]] = None
|
||||
const: Optional[Any] = None
|
||||
multipleOf: Optional[float] = Field(default=None, gt=0)
|
||||
@@ -213,6 +253,14 @@ class Schema(BaseModelWithConfig):
|
||||
),
|
||||
] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents
|
||||
# A JSON Schema MUST be an object or a boolean.
|
||||
@@ -241,22 +289,38 @@ class ParameterInType(Enum):
|
||||
cookie = "cookie"
|
||||
|
||||
|
||||
class Encoding(BaseModelWithConfig):
|
||||
class Encoding(BaseModel):
|
||||
contentType: Optional[str] = None
|
||||
headers: Optional[Dict[str, Union["Header", Reference]]] = None
|
||||
style: Optional[str] = None
|
||||
explode: Optional[bool] = None
|
||||
allowReserved: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class MediaType(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class MediaType(BaseModel):
|
||||
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
|
||||
example: Optional[Any] = None
|
||||
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||
encoding: Optional[Dict[str, Encoding]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class ParameterBase(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ParameterBase(BaseModel):
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
deprecated: Optional[bool] = None
|
||||
@@ -270,6 +334,14 @@ class ParameterBase(BaseModelWithConfig):
|
||||
# Serialization rules for more complex scenarios
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Parameter(ParameterBase):
|
||||
name: str
|
||||
@@ -280,13 +352,21 @@ class Header(ParameterBase):
|
||||
pass
|
||||
|
||||
|
||||
class RequestBody(BaseModelWithConfig):
|
||||
class RequestBody(BaseModel):
|
||||
description: Optional[str] = None
|
||||
content: Dict[str, MediaType]
|
||||
required: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Link(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
operationRef: Optional[str] = None
|
||||
operationId: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Union[Any, str]]] = None
|
||||
@@ -294,15 +374,31 @@ class Link(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
server: Optional[Server] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Response(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
description: str
|
||||
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
links: Optional[Dict[str, Union[Link, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Operation(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Operation(BaseModel):
|
||||
tags: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -317,8 +413,16 @@ class Operation(BaseModelWithConfig):
|
||||
security: Optional[List[Dict[str, List[str]]]] = None
|
||||
servers: Optional[List[Server]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class PathItem(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class PathItem(BaseModel):
|
||||
ref: Optional[str] = Field(default=None, alias="$ref")
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -333,6 +437,14 @@ class PathItem(BaseModelWithConfig):
|
||||
servers: Optional[List[Server]] = None
|
||||
parameters: Optional[List[Union[Parameter, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class SecuritySchemeType(Enum):
|
||||
apiKey = "apiKey"
|
||||
@@ -341,10 +453,18 @@ class SecuritySchemeType(Enum):
|
||||
openIdConnect = "openIdConnect"
|
||||
|
||||
|
||||
class SecurityBase(BaseModelWithConfig):
|
||||
class SecurityBase(BaseModel):
|
||||
type_: SecuritySchemeType = Field(alias="type")
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class APIKeyIn(Enum):
|
||||
query = "query"
|
||||
@@ -368,10 +488,18 @@ class HTTPBearer(HTTPBase):
|
||||
bearerFormat: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthFlow(BaseModelWithConfig):
|
||||
class OAuthFlow(BaseModel):
|
||||
refreshUrl: Optional[str] = None
|
||||
scopes: Dict[str, str] = {}
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuthFlowImplicit(OAuthFlow):
|
||||
authorizationUrl: str
|
||||
@@ -390,12 +518,20 @@ class OAuthFlowAuthorizationCode(OAuthFlow):
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlows(BaseModelWithConfig):
|
||||
class OAuthFlows(BaseModel):
|
||||
implicit: Optional[OAuthFlowImplicit] = None
|
||||
password: Optional[OAuthFlowPassword] = None
|
||||
clientCredentials: Optional[OAuthFlowClientCredentials] = None
|
||||
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuth2(SecurityBase):
|
||||
type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
|
||||
@@ -412,7 +548,7 @@ class OpenIdConnect(SecurityBase):
|
||||
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
|
||||
|
||||
|
||||
class Components(BaseModelWithConfig):
|
||||
class Components(BaseModel):
|
||||
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
|
||||
responses: Optional[Dict[str, Union[Response, Reference]]] = None
|
||||
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
|
||||
@@ -425,14 +561,30 @@ class Components(BaseModelWithConfig):
|
||||
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None
|
||||
pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class Tag(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
class OpenAPI(BaseModelWithConfig):
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OpenAPI(BaseModel):
|
||||
openapi: str
|
||||
info: Info
|
||||
jsonSchemaDialect: Optional[str] = None
|
||||
@@ -445,6 +597,14 @@ class OpenAPI(BaseModelWithConfig):
|
||||
tags: Optional[List[Tag]] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
_model_rebuild(Schema)
|
||||
_model_rebuild(Operation)
|
||||
|
||||
@@ -16,15 +16,11 @@ from fastapi._compat import (
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import (
|
||||
_get_flat_fields_from_params,
|
||||
get_flat_dependant,
|
||||
get_flat_params,
|
||||
)
|
||||
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, ParamTypes
|
||||
from fastapi.params import Body, Param
|
||||
from fastapi.responses import Response
|
||||
from fastapi.types import ModelNameMap
|
||||
from fastapi.utils import (
|
||||
@@ -32,9 +28,9 @@ from fastapi.utils import (
|
||||
generate_operation_id_for_path,
|
||||
is_body_allowed_for_status_code,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||
from typing_extensions import Literal
|
||||
|
||||
validation_error_definition = {
|
||||
@@ -91,9 +87,9 @@ def get_openapi_security_definitions(
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
def _get_openapi_operation_parameters(
|
||||
def get_openapi_operation_parameters(
|
||||
*,
|
||||
dependant: Dependant,
|
||||
all_route_params: Sequence[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
@@ -102,67 +98,33 @@ def _get_openapi_operation_parameters(
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
parameters = []
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
parameter_groups = [
|
||||
(ParamTypes.path, path_params),
|
||||
(ParamTypes.query, query_params),
|
||||
(ParamTypes.header, header_params),
|
||||
(ParamTypes.cookie, cookie_params),
|
||||
]
|
||||
default_convert_underscores = True
|
||||
if len(flat_dependant.header_params) == 1:
|
||||
first_field = flat_dependant.header_params[0]
|
||||
if lenient_issubclass(first_field.type_, BaseModel):
|
||||
default_convert_underscores = getattr(
|
||||
first_field.field_info, "convert_underscores", True
|
||||
)
|
||||
for param_type, param_group in parameter_groups:
|
||||
for param in param_group:
|
||||
field_info = param.field_info
|
||||
# field_info = cast(Param, field_info)
|
||||
if not getattr(field_info, "include_in_schema", True):
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
name = param.alias
|
||||
convert_underscores = getattr(
|
||||
param.field_info,
|
||||
"convert_underscores",
|
||||
default_convert_underscores,
|
||||
)
|
||||
if (
|
||||
param_type == ParamTypes.header
|
||||
and param.alias == param.name
|
||||
and convert_underscores
|
||||
):
|
||||
name = param.name.replace("_", "-")
|
||||
|
||||
parameter = {
|
||||
"name": name,
|
||||
"in": param_type.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
openapi_examples = getattr(field_info, "openapi_examples", None)
|
||||
example = getattr(field_info, "example", None)
|
||||
if openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(openapi_examples)
|
||||
elif example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(example)
|
||||
if getattr(field_info, "deprecated", None):
|
||||
parameter["deprecated"] = True
|
||||
parameters.append(parameter)
|
||||
for param in all_route_params:
|
||||
field_info = param.field_info
|
||||
field_info = cast(Param, field_info)
|
||||
if not field_info.include_in_schema:
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
if field_info.openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
|
||||
elif field_info.example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(field_info.example)
|
||||
if field_info.deprecated:
|
||||
parameter["deprecated"] = field_info.deprecated
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
|
||||
@@ -285,8 +247,9 @@ def get_openapi_path(
|
||||
operation.setdefault("security", []).extend(operation_security)
|
||||
if security_definitions:
|
||||
security_schemes.update(security_definitions)
|
||||
operation_parameters = _get_openapi_operation_parameters(
|
||||
dependant=route.dependant,
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
operation_parameters = get_openapi_operation_parameters(
|
||||
all_route_params=all_route_params,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
@@ -384,9 +347,9 @@ def get_openapi_path(
|
||||
openapi_response = operation_responses.setdefault(
|
||||
status_code_key, {}
|
||||
)
|
||||
assert isinstance(process_response, dict), (
|
||||
"An additional response must be a dict"
|
||||
)
|
||||
assert isinstance(
|
||||
process_response, dict
|
||||
), "An additional response must be a dict"
|
||||
field = route.response_fields.get(additional_status_code)
|
||||
additional_field_schema: Optional[Dict[str, Any]] = None
|
||||
if field:
|
||||
@@ -415,8 +378,7 @@ def get_openapi_path(
|
||||
)
|
||||
deep_dict_update(openapi_response, process_response)
|
||||
openapi_response["description"] = description
|
||||
http422 = "422"
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
if (all_route_params or route.body_field) and not any(
|
||||
status in operation["responses"]
|
||||
for status in [http422, "4XX", "default"]
|
||||
@@ -454,9 +416,9 @@ def get_fields_from_routes(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if route.body_field:
|
||||
assert isinstance(route.body_field, ModelField), (
|
||||
"A request body must be a Pydantic Field"
|
||||
)
|
||||
assert isinstance(
|
||||
route.body_field, ModelField
|
||||
), "A request body must be a Pydantic Field"
|
||||
body_fields_from_routes.append(route.body_field)
|
||||
if route.response_field:
|
||||
responses_from_routes.append(route.response_field)
|
||||
@@ -488,7 +450,6 @@ def get_openapi(
|
||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
separate_input_output_schemas: bool = True,
|
||||
external_docs: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
info: Dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
@@ -566,6 +527,4 @@ def get_openapi(
|
||||
output["webhooks"] = webhook_paths
|
||||
if tags:
|
||||
output["tags"] = tags
|
||||
if external_docs:
|
||||
output["externalDocs"] = external_docs
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
from fastapi import params
|
||||
from fastapi._compat import Undefined
|
||||
from fastapi.openapi.models import Example
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -240,7 +240,7 @@ def Path( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -565,7 +565,7 @@ def Query( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -880,7 +880,7 @@ def Header( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1185,7 +1185,7 @@ def Cookie( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1282,7 +1282,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = _Unset,
|
||||
embed: Annotated[
|
||||
Union[bool, None],
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
||||
@@ -1294,7 +1294,7 @@ def Body( # noqa: N802
|
||||
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
] = False,
|
||||
media_type: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -1512,7 +1512,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1827,7 +1827,7 @@ def Form( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2141,7 +2141,7 @@ def File( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Union[deprecated, str, bool, None],
|
||||
Optional[bool],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2298,7 +2298,7 @@ def Security( # noqa: N802
|
||||
dependency.
|
||||
|
||||
The term "scope" comes from the OAuth2 specification, it seems to be
|
||||
intentionally vague and interpretable. It normally refers to permissions,
|
||||
intentionaly vague and interpretable. It normally refers to permissions,
|
||||
in cases to roles.
|
||||
|
||||
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
|
||||
@@ -2343,7 +2343,7 @@ def Security( # noqa: N802
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Security, FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from .db import User
|
||||
from .security import get_current_active_user
|
||||
|
||||
@@ -6,11 +6,7 @@ from fastapi.openapi.models import Example
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
from ._compat import (
|
||||
PYDANTIC_V2,
|
||||
PYDANTIC_VERSION_MINOR_TUPLE,
|
||||
Undefined,
|
||||
)
|
||||
from ._compat import PYDANTIC_V2, Undefined
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -67,11 +63,12 @@ class Param(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -95,7 +92,7 @@ class Param(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
allow_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -109,10 +106,6 @@ class Param(FieldInfo):
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -181,7 +174,7 @@ class Path(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -267,7 +260,7 @@ class Query(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -352,7 +345,7 @@ class Header(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -437,7 +430,7 @@ class Cookie(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -483,7 +476,7 @@ class Body(FieldInfo):
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
embed: Union[bool, None] = None,
|
||||
embed: bool = False,
|
||||
media_type: str = "application/json",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
@@ -521,13 +514,14 @@ class Body(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.embed = embed
|
||||
self.media_type = media_type
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -551,7 +545,7 @@ class Body(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
allow_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -560,15 +554,11 @@ class Body(FieldInfo):
|
||||
kwargs["examples"] = examples
|
||||
if regex is not None:
|
||||
warnings.warn(
|
||||
"`regex` has been deprecated, please use `pattern` instead",
|
||||
"`regex` has been depreacated, please use `pattern` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -637,7 +627,7 @@ class Form(Body):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -646,6 +636,7 @@ class Form(Body):
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
embed=True,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
@@ -721,7 +712,7 @@ class File(Form):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import email.message
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from contextlib import AsyncExitStack
|
||||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@@ -34,10 +31,8 @@ from fastapi._compat import (
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import (
|
||||
_should_embed_body_fields,
|
||||
get_body_field,
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
get_typed_return_annotation,
|
||||
solve_dependencies,
|
||||
@@ -52,7 +47,7 @@ from fastapi.exceptions import (
|
||||
from fastapi.types import DecoratedCallable, IncEx
|
||||
from fastapi.utils import (
|
||||
create_cloned_field,
|
||||
create_model_field,
|
||||
create_response_field,
|
||||
generate_unique_id,
|
||||
get_value_or_default,
|
||||
is_body_allowed_for_status_code,
|
||||
@@ -72,14 +67,9 @@ from starlette.routing import (
|
||||
websocket_session,
|
||||
)
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.types import AppType, ASGIApp, Lifespan, Scope
|
||||
from starlette.types import ASGIApp, Lifespan, Scope
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
|
||||
|
||||
def _prepare_response_content(
|
||||
@@ -125,28 +115,10 @@ def _prepare_response_content(
|
||||
for k, v in res.items()
|
||||
}
|
||||
elif dataclasses.is_dataclass(res):
|
||||
assert not isinstance(res, type)
|
||||
return dataclasses.asdict(res)
|
||||
return res
|
||||
|
||||
|
||||
def _merge_lifespan_context(
|
||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||
) -> Lifespan[Any]:
|
||||
@asynccontextmanager
|
||||
async def merged_lifespan(
|
||||
app: AppType,
|
||||
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
|
||||
async with original_context(app) as maybe_original_state:
|
||||
async with nested_context(app) as maybe_nested_state:
|
||||
if maybe_nested_state is None and maybe_original_state is None:
|
||||
yield None # old ASGI compatibility
|
||||
else:
|
||||
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
|
||||
|
||||
return merged_lifespan # type: ignore[return-value]
|
||||
|
||||
|
||||
async def serialize_response(
|
||||
*,
|
||||
field: Optional[ModelField] = None,
|
||||
@@ -234,10 +206,9 @@ def get_request_handler(
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = iscoroutinefunction(dependant.call)
|
||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
|
||||
if isinstance(response_class, DefaultPlaceholder):
|
||||
actual_response_class: Type[Response] = response_class.value
|
||||
@@ -245,149 +216,113 @@ def get_request_handler(
|
||||
actual_response_class = response_class
|
||||
|
||||
async def app(request: Request) -> Response:
|
||||
response: Union[Response, None] = None
|
||||
async with AsyncExitStack() as file_stack:
|
||||
try:
|
||||
body: Any = None
|
||||
if body_field:
|
||||
if is_body_form:
|
||||
body = await request.form()
|
||||
file_stack.push_async_callback(body.close)
|
||||
else:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
json_body: Any = Undefined
|
||||
content_type_value = request.headers.get("content-type")
|
||||
if not content_type_value:
|
||||
json_body = await request.json()
|
||||
else:
|
||||
message = email.message.Message()
|
||||
message["content-type"] = content_type_value
|
||||
if message.get_content_maintype() == "application":
|
||||
subtype = message.get_content_subtype()
|
||||
if subtype == "json" or subtype.endswith("+json"):
|
||||
json_body = await request.json()
|
||||
if json_body != Undefined:
|
||||
body = json_body
|
||||
else:
|
||||
body = body_bytes
|
||||
except json.JSONDecodeError as e:
|
||||
validation_error = RequestValidationError(
|
||||
[
|
||||
{
|
||||
"type": "json_invalid",
|
||||
"loc": ("body", e.pos),
|
||||
"msg": "JSON decode error",
|
||||
"input": {},
|
||||
"ctx": {"error": e.msg},
|
||||
}
|
||||
],
|
||||
body=e.doc,
|
||||
)
|
||||
raise validation_error from e
|
||||
except HTTPException:
|
||||
# If a middleware raises an HTTPException, it should be raised again
|
||||
raise
|
||||
except Exception as e:
|
||||
http_error = HTTPException(
|
||||
status_code=400, detail="There was an error parsing the body"
|
||||
)
|
||||
raise http_error from e
|
||||
errors: List[Any] = []
|
||||
async with AsyncExitStack() as async_exit_stack:
|
||||
solved_result = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=dependant,
|
||||
body=body,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
errors = solved_result.errors
|
||||
if not errors:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args: Dict[str, Any] = {
|
||||
"background": solved_result.background_tasks
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code
|
||||
if status_code
|
||||
else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = (
|
||||
solved_result.response.status_code
|
||||
)
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
if errors:
|
||||
validation_error = RequestValidationError(
|
||||
_normalize_errors(errors), body=body
|
||||
)
|
||||
raise validation_error
|
||||
if response is None:
|
||||
raise FastAPIError(
|
||||
"No response object was returned. There's a high chance that the "
|
||||
"application code is raising an exception and a dependency with yield "
|
||||
"has a block with a bare except, or a block with except Exception, "
|
||||
"and is not raising the exception again. Read more about it in the "
|
||||
"docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except"
|
||||
try:
|
||||
body: Any = None
|
||||
if body_field:
|
||||
if is_body_form:
|
||||
body = await request.form()
|
||||
stack = request.scope.get("fastapi_astack")
|
||||
assert isinstance(stack, AsyncExitStack)
|
||||
stack.push_async_callback(body.close)
|
||||
else:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
json_body: Any = Undefined
|
||||
content_type_value = request.headers.get("content-type")
|
||||
if not content_type_value:
|
||||
json_body = await request.json()
|
||||
else:
|
||||
message = email.message.Message()
|
||||
message["content-type"] = content_type_value
|
||||
if message.get_content_maintype() == "application":
|
||||
subtype = message.get_content_subtype()
|
||||
if subtype == "json" or subtype.endswith("+json"):
|
||||
json_body = await request.json()
|
||||
if json_body != Undefined:
|
||||
body = json_body
|
||||
else:
|
||||
body = body_bytes
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestValidationError(
|
||||
[
|
||||
{
|
||||
"type": "json_invalid",
|
||||
"loc": ("body", e.pos),
|
||||
"msg": "JSON decode error",
|
||||
"input": {},
|
||||
"ctx": {"error": e.msg},
|
||||
}
|
||||
],
|
||||
body=e.doc,
|
||||
) from e
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="There was an error parsing the body"
|
||||
) from e
|
||||
solved_result = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=dependant,
|
||||
body=body,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
)
|
||||
values, errors, background_tasks, sub_response, _ = solved_result
|
||||
if errors:
|
||||
raise RequestValidationError(_normalize_errors(errors), body=body)
|
||||
else:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant, values=values, is_coroutine=is_coroutine
|
||||
)
|
||||
return response
|
||||
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = background_tasks
|
||||
return raw_response
|
||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else sub_response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if sub_response.status_code:
|
||||
response_args["status_code"] = sub_response.status_code
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(sub_response.headers.raw)
|
||||
return response
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_websocket_app(
|
||||
dependant: Dependant,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
async with AsyncExitStack() as async_exit_stack:
|
||||
# TODO: remove this scope later, after a few releases
|
||||
# This scope fastapi_astack is no longer used by FastAPI, kept for
|
||||
# compatibility, just in case
|
||||
websocket.scope["fastapi_astack"] = async_exit_stack
|
||||
solved_result = await solve_dependencies(
|
||||
request=websocket,
|
||||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
if solved_result.errors:
|
||||
raise WebSocketRequestValidationError(
|
||||
_normalize_errors(solved_result.errors)
|
||||
)
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
await dependant.call(**solved_result.values)
|
||||
solved_result = await solve_dependencies(
|
||||
request=websocket,
|
||||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
)
|
||||
values, errors, _, _2, _3 = solved_result
|
||||
if errors:
|
||||
raise WebSocketRequestValidationError(_normalize_errors(errors))
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
await dependant.call(**values)
|
||||
|
||||
return app
|
||||
|
||||
@@ -413,15 +348,11 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
|
||||
self.app = websocket_session(
|
||||
get_websocket_app(
|
||||
dependant=self.dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -500,9 +431,9 @@ class APIRoute(routing.Route):
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = {method.upper() for method in methods}
|
||||
if isinstance(generate_unique_id_function, DefaultPlaceholder):
|
||||
current_generate_unique_id: Callable[[APIRoute], str] = (
|
||||
generate_unique_id_function.value
|
||||
)
|
||||
current_generate_unique_id: Callable[
|
||||
["APIRoute"], str
|
||||
] = generate_unique_id_function.value
|
||||
else:
|
||||
current_generate_unique_id = generate_unique_id_function
|
||||
self.unique_id = self.operation_id or current_generate_unique_id(self)
|
||||
@@ -511,11 +442,11 @@ class APIRoute(routing.Route):
|
||||
status_code = int(status_code)
|
||||
self.status_code = status_code
|
||||
if self.response_model:
|
||||
assert is_body_allowed_for_status_code(status_code), (
|
||||
f"Status code {status_code} must not have a response body"
|
||||
)
|
||||
assert is_body_allowed_for_status_code(
|
||||
status_code
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_model_field(
|
||||
self.response_field = create_response_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
@@ -528,9 +459,9 @@ class APIRoute(routing.Route):
|
||||
# By being a new field, no inheritance will be passed as is. A new model
|
||||
# will always be created.
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
self.secure_cloned_response_field: Optional[ModelField] = (
|
||||
create_cloned_field(self.response_field)
|
||||
)
|
||||
self.secure_cloned_response_field: Optional[
|
||||
ModelField
|
||||
] = create_cloned_field(self.response_field)
|
||||
else:
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
@@ -544,13 +475,11 @@ class APIRoute(routing.Route):
|
||||
assert isinstance(response, dict), "An additional response must be a dict"
|
||||
model = response.get("model")
|
||||
if model:
|
||||
assert is_body_allowed_for_status_code(additional_status_code), (
|
||||
f"Status code {additional_status_code} must not have a response body"
|
||||
)
|
||||
assert is_body_allowed_for_status_code(
|
||||
additional_status_code
|
||||
), f"Status code {additional_status_code} must not have a response body"
|
||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||
response_field = create_model_field(
|
||||
name=response_name, type_=model, mode="serialization"
|
||||
)
|
||||
response_field = create_response_field(name=response_name, type_=model)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
@@ -564,15 +493,7 @@ class APIRoute(routing.Route):
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
@@ -589,7 +510,6 @@ class APIRoute(routing.Route):
|
||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
@@ -821,7 +741,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -851,9 +771,9 @@ class APIRouter(routing.Router):
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith("/"), (
|
||||
"A path prefix must not end with '/', as the routes will start with '/'"
|
||||
)
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
self.prefix = prefix
|
||||
self.tags: List[Union[str, Enum]] = tags or []
|
||||
self.dependencies = list(dependencies or [])
|
||||
@@ -869,7 +789,7 @@ class APIRouter(routing.Router):
|
||||
def route(
|
||||
self,
|
||||
path: str,
|
||||
methods: Optional[Collection[str]] = None,
|
||||
methods: Optional[List[str]] = None,
|
||||
name: Optional[str] = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
@@ -1263,9 +1183,9 @@ class APIRouter(routing.Router):
|
||||
"""
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith("/"), (
|
||||
"A path prefix must not end with '/', as the routes will start with '/'"
|
||||
)
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
else:
|
||||
for r in router.routes:
|
||||
path = getattr(r, "path") # noqa: B009
|
||||
@@ -1365,10 +1285,6 @@ class APIRouter(routing.Router):
|
||||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
self.lifespan_context = _merge_lifespan_context(
|
||||
self.lifespan_context,
|
||||
router.lifespan_context,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -1633,7 +1549,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2010,7 +1926,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2392,7 +2308,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -2774,7 +2690,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3151,7 +3067,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3528,7 +3444,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -3910,7 +3826,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4292,7 +4208,7 @@ class APIRouter(routing.Router):
|
||||
This affects the generated OpenAPI (e.g. visible at `/docs`).
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi).
|
||||
[FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
@@ -4377,7 +4293,7 @@ class APIRouter(routing.Router):
|
||||
app = FastAPI()
|
||||
router = APIRouter()
|
||||
|
||||
@router.trace("/items/{item_id}")
|
||||
@router.put("/items/{item_id}")
|
||||
def trace_item(item_id: str):
|
||||
return None
|
||||
|
||||
|
||||
@@ -5,19 +5,11 @@ from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
@staticmethod
|
||||
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]:
|
||||
if not api_key:
|
||||
if auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
return None
|
||||
return api_key
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyQuery(APIKeyBase):
|
||||
@@ -84,7 +76,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
Doc(
|
||||
"""
|
||||
By default, if the query parameter is not provided, `APIKeyQuery` will
|
||||
automatically cancel the request and send the client an error.
|
||||
automatically cancel the request and sebd the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the query parameter is not
|
||||
available, instead of erroring out, the dependency result will be
|
||||
@@ -100,7 +92,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.query},
|
||||
**{"in": APIKeyIn.query}, # type: ignore[arg-type]
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
@@ -109,7 +101,14 @@ class APIKeyQuery(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.query_params.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
@@ -188,7 +187,7 @@ class APIKeyHeader(APIKeyBase):
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.header},
|
||||
**{"in": APIKeyIn.header}, # type: ignore[arg-type]
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
@@ -197,7 +196,14 @@ class APIKeyHeader(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.headers.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
@@ -276,7 +282,7 @@ class APIKeyCookie(APIKeyBase):
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.cookie},
|
||||
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
@@ -285,4 +291,11 @@ class APIKeyCookie(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.cookies.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
|
||||
@@ -10,12 +10,12 @@ from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class HTTPBasicCredentials(BaseModel):
|
||||
"""
|
||||
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
|
||||
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
|
||||
dependency.
|
||||
|
||||
Read more about it in the
|
||||
@@ -277,7 +277,7 @@ class HTTPBearer(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Bearer token is not provided (in an
|
||||
By default, if the HTTP Bearer token not provided (in an
|
||||
`Authorization` header), `HTTPBearer` will automatically cancel the
|
||||
request and send the client an error.
|
||||
|
||||
@@ -380,7 +380,7 @@ class HTTPDigest(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Digest is not provided, `HTTPDigest` will
|
||||
By default, if the HTTP Digest not provided, `HTTPDigest` will
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the HTTP Digest is not
|
||||
@@ -413,11 +413,8 @@ class HTTPDigest(HTTPBase):
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "digest":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
else:
|
||||
return None
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
|
||||
@@ -10,7 +10,7 @@ from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
# TODO: import from typing when deprecating Python 3.9
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class OAuth2PasswordRequestForm:
|
||||
@@ -52,9 +52,9 @@ class OAuth2PasswordRequestForm:
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
"""
|
||||
|
||||
@@ -63,7 +63,7 @@ class OAuth2PasswordRequestForm:
|
||||
*,
|
||||
grant_type: Annotated[
|
||||
Union[str, None],
|
||||
Form(pattern="^password$"),
|
||||
Form(pattern="password"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -85,7 +85,7 @@ class OAuth2PasswordRequestForm:
|
||||
],
|
||||
password: Annotated[
|
||||
str,
|
||||
Form(json_schema_extra={"format": "password"}),
|
||||
Form(),
|
||||
Doc(
|
||||
"""
|
||||
`password` string. The OAuth2 spec requires the exact field name
|
||||
@@ -130,7 +130,7 @@ class OAuth2PasswordRequestForm:
|
||||
] = None,
|
||||
client_secret: Annotated[
|
||||
Union[str, None],
|
||||
Form(json_schema_extra={"format": "password"}),
|
||||
Form(),
|
||||
Doc(
|
||||
"""
|
||||
If there's a `client_password` (and a `client_id`), they can be sent
|
||||
@@ -194,9 +194,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
self,
|
||||
grant_type: Annotated[
|
||||
str,
|
||||
Form(pattern="^password$"),
|
||||
Form(pattern="password"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -353,7 +353,7 @@ class OAuth2(SecurityBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -441,7 +441,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -457,26 +457,11 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
refreshUrl: Annotated[
|
||||
Optional[str],
|
||||
Doc(
|
||||
"""
|
||||
The URL to refresh the token and obtain a new one.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(
|
||||
password=cast(
|
||||
Any,
|
||||
{
|
||||
"tokenUrl": tokenUrl,
|
||||
"refreshUrl": refreshUrl,
|
||||
"scopes": scopes,
|
||||
},
|
||||
)
|
||||
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
|
||||
)
|
||||
super().__init__(
|
||||
flows=flows,
|
||||
@@ -558,7 +543,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
|
||||
|
||||
class OpenIdConnect(SecurityBase):
|
||||
@@ -49,7 +49,7 @@ class OpenIdConnect(SecurityBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
OpenID Connect authentication, it will automatically cancel the request
|
||||
and send the client an error.
|
||||
|
||||
|
||||
@@ -6,5 +6,6 @@ from pydantic import BaseModel
|
||||
|
||||
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
|
||||
UnionType = getattr(types, "UnionType", Union)
|
||||
NoneType = getattr(types, "UnionType", None)
|
||||
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||
|
||||
@@ -34,9 +34,9 @@ if TYPE_CHECKING: # pragma: nocover
|
||||
from .routing import APIRoute
|
||||
|
||||
# Cache for `create_cloned_field`
|
||||
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
_CLONED_TYPES_CACHE: MutableMapping[
|
||||
Type[BaseModel], Type[BaseModel]
|
||||
] = WeakKeyDictionary()
|
||||
|
||||
|
||||
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
@@ -53,16 +53,16 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
}:
|
||||
return True
|
||||
current_status_code = int(status_code)
|
||||
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
|
||||
return not (current_status_code < 200 or current_status_code in {204, 304})
|
||||
|
||||
|
||||
def get_path_param_names(path: str) -> Set[str]:
|
||||
return set(re.findall("{(.*?)}", path))
|
||||
|
||||
|
||||
def create_model_field(
|
||||
def create_response_field(
|
||||
name: str,
|
||||
type_: Any,
|
||||
type_: Type[Any],
|
||||
class_validators: Optional[Dict[str, Validator]] = None,
|
||||
default: Optional[Any] = Undefined,
|
||||
required: Union[bool, UndefinedType] = Undefined,
|
||||
@@ -71,6 +71,9 @@ def create_model_field(
|
||||
alias: Optional[str] = None,
|
||||
mode: Literal["validation", "serialization"] = "validation",
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
"""
|
||||
class_validators = class_validators or {}
|
||||
if PYDANTIC_V2:
|
||||
field_info = field_info or FieldInfo(
|
||||
@@ -132,12 +135,11 @@ def create_cloned_field(
|
||||
use_type.__fields__[f.name] = create_cloned_field(
|
||||
f, cloned_types=cloned_types
|
||||
)
|
||||
new_field = create_model_field(name=field.name, type_=use_type)
|
||||
new_field = create_response_field(name=field.name, type_=use_type)
|
||||
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
|
||||
new_field.alias = field.alias # type: ignore[misc]
|
||||
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
|
||||
new_field.default = field.default # type: ignore[misc]
|
||||
new_field.default_factory = field.default_factory # type: ignore[attr-defined]
|
||||
new_field.required = field.required # type: ignore[misc]
|
||||
new_field.model_config = field.model_config # type: ignore[attr-defined]
|
||||
new_field.field_info = field.field_info
|
||||
@@ -171,17 +173,17 @@ def generate_operation_id_for_path(
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
operation_id = f"{name}{path}"
|
||||
operation_id = name + path
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
operation_id = f"{operation_id}_{method.lower()}"
|
||||
operation_id = operation_id + "_" + method.lower()
|
||||
return operation_id
|
||||
|
||||
|
||||
def generate_unique_id(route: "APIRoute") -> str:
|
||||
operation_id = f"{route.name}{route.path_format}"
|
||||
operation_id = route.name + route.path_format
|
||||
operation_id = re.sub(r"\W", "_", operation_id)
|
||||
assert route.methods
|
||||
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
|
||||
operation_id = operation_id + "_" + list(route.methods)[0].lower()
|
||||
return operation_id
|
||||
|
||||
|
||||
@@ -219,3 +221,9 @@ def get_value_or_default(
|
||||
if not isinstance(item, DefaultPlaceholder):
|
||||
return item
|
||||
return first_item
|
||||
|
||||
|
||||
def match_pydantic_error_url(error_type: str) -> Any:
|
||||
from dirty_equals import IsStr
|
||||
|
||||
return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}")
|
||||
|
||||
Reference in New Issue
Block a user