API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,41 +1,42 @@
from __future__ import annotations
import inspect
import re
import typing
from typing import Any, Callable, NamedTuple
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: typing.Any) -> bytes:
def render(self, content: Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(
content, dict
), "The schema passed to OpenAPIResponse should be a dictionary."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(typing.NamedTuple):
class EndpointInfo(NamedTuple):
path: str
http_method: str
func: typing.Callable
func: Callable[..., Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(
self, routes: typing.List[BaseRoute]
) -> typing.List[EndpointInfo]:
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, yields the following information:
@@ -46,12 +47,15 @@ class BaseSchemaGenerator:
- func
method ready to extract the docstring
"""
endpoints_info: list = []
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, Mount):
path = self._remove_converter(route.path)
if isinstance(route, (Mount, Host)):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
@@ -70,9 +74,7 @@ class BaseSchemaGenerator:
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(path, method.lower(), route.endpoint)
)
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -90,9 +92,9 @@ class BaseSchemaGenerator:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)
return _remove_converter_pattern.sub("}", path)
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
@@ -123,10 +125,10 @@ class BaseSchemaGenerator:
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict) -> None:
def __init__(self, base_schema: dict[str, Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)