API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
@@ -17,6 +17,7 @@ from typing import Union
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import FetchedValue
from typing_extensions import ContextManager
from typing_extensions import Literal
from .migration import _ProxyTransaction
@@ -107,7 +108,6 @@ CompareType = Callable[
class EnvironmentContext(util.ModuleClsProxy):
"""A configurational facade made available in an ``env.py`` script.
The :class:`.EnvironmentContext` acts as a *facade* to the more
@@ -227,9 +227,9 @@ class EnvironmentContext(util.ModuleClsProxy):
has been configured.
"""
return self.context_opts.get("as_sql", False)
return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501
def is_transactional_ddl(self):
def is_transactional_ddl(self) -> bool:
"""Return True if the context is configured to expect a
transactional DDL capable backend.
@@ -341,18 +341,17 @@ class EnvironmentContext(util.ModuleClsProxy):
return self.context_opts.get("tag", None)
@overload
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
...
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ...
@overload
def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
...
def get_x_argument(
self, as_dictionary: Literal[True]
) -> Dict[str, str]: ...
@overload
def get_x_argument(
self, as_dictionary: bool = ...
) -> Union[List[str], Dict[str, str]]:
...
) -> Union[List[str], Dict[str, str]]: ...
def get_x_argument(
self, as_dictionary: bool = False
@@ -366,7 +365,11 @@ class EnvironmentContext(util.ModuleClsProxy):
The return value is a list, returned directly from the ``argparse``
structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
are parsed using ``key=value`` format into a dictionary that is
then returned.
then returned. If there is no ``=`` in the argument, value is an empty
string.
.. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when
arguments are passed without the ``=`` symbol.
For example, to support passing a database URL on the command line,
the standard ``env.py`` script can be modified like this::
@@ -400,7 +403,12 @@ class EnvironmentContext(util.ModuleClsProxy):
else:
value = []
if as_dictionary:
value = dict(arg.split("=", 1) for arg in value)
dict_value = {}
for arg in value:
x_key, _, x_value = arg.partition("=")
dict_value[x_key] = x_value
value = dict_value
return value
def configure(
@@ -416,7 +424,7 @@ class EnvironmentContext(util.ModuleClsProxy):
tag: Optional[str] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional[MetaData] = None,
target_metadata: Union[MetaData, Sequence[MetaData], None] = None,
include_name: Optional[IncludeNameFn] = None,
include_object: Optional[IncludeObjectFn] = None,
include_schemas: bool = False,
@@ -940,7 +948,7 @@ class EnvironmentContext(util.ModuleClsProxy):
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[dict] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute the given SQL using the current change context.
@@ -968,7 +976,7 @@ class EnvironmentContext(util.ModuleClsProxy):
def begin_transaction(
self,
) -> Union[_ProxyTransaction, ContextManager[None]]:
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline

View File

@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from contextlib import contextmanager
@@ -8,7 +11,6 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -21,13 +23,11 @@ from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine.strategies import MockEngineStrategy
from typing_extensions import ContextManager
from .. import ddl
from .. import util
@@ -83,7 +83,6 @@ class _ProxyTransaction:
class MigrationContext:
"""Represent the database state made available to a migration
script.
@@ -176,7 +175,11 @@ class MigrationContext:
opts["output_encoding"],
)
else:
self.output_buffer = opts.get("output_buffer", sys.stdout)
self.output_buffer = opts.get(
"output_buffer", sys.stdout
) # type:ignore[assignment] # noqa: E501
self.transactional_ddl = transactional_ddl
self._user_compare_type = opts.get("compare_type", True)
self._user_compare_server_default = opts.get(
@@ -188,18 +191,6 @@ class MigrationContext:
self.version_table_schema = version_table_schema = opts.get(
"version_table_schema", None
)
self._version = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if opts.get("version_table_pk", True):
self._version.append_constraint(
PrimaryKeyConstraint(
"version_num", name="%s_pkc" % version_table
)
)
self._start_from_rev: Optional[str] = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
@@ -210,14 +201,23 @@ class MigrationContext:
self.output_buffer,
opts,
)
self._version = self.impl.version_table_impl(
version_table=version_table,
version_table_schema=version_table_schema,
version_table_pk=opts.get("version_table_pk", True),
)
log.info("Context impl %s.", self.impl.__class__.__name__)
if self.as_sql:
log.info("Generating static SQL")
log.info(
"Will assume %s DDL.",
"transactional"
if self.impl.transactional_ddl
else "non-transactional",
(
"transactional"
if self.impl.transactional_ddl
else "non-transactional"
),
)
@classmethod
@@ -342,9 +342,9 @@ class MigrationContext:
# except that it will not know it's in "autocommit" and will
# emit deprecation warnings when an autocommit action takes
# place.
self.connection = (
self.impl.connection
) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
self.connection = self.impl.connection = (
base_connection.execution_options(isolation_level="AUTOCOMMIT")
)
# sqlalchemy future mode will "autobegin" in any case, so take
# control of that "transaction" here
@@ -372,7 +372,7 @@ class MigrationContext:
def begin_transaction(
self, _per_migration: bool = False
) -> Union[_ProxyTransaction, ContextManager[None]]:
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
"""Begin a logical transaction for migration operations.
This method is used within an ``env.py`` script to demarcate where
@@ -521,7 +521,7 @@ class MigrationContext:
start_from_rev = None
elif start_from_rev is not None and self.script:
start_from_rev = [
cast("Script", self.script.get_revision(sfr)).revision
self.script.get_revision(sfr).revision
for sfr in util.to_list(start_from_rev)
if sfr not in (None, "base")
]
@@ -536,7 +536,10 @@ class MigrationContext:
return ()
assert self.connection is not None
return tuple(
row[0] for row in self.connection.execute(self._version.select())
row[0]
for row in self.connection.execute(
select(self._version.c.version_num)
)
)
def _ensure_version_table(self, purge: bool = False) -> None:
@@ -652,7 +655,7 @@ class MigrationContext:
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[dict] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute a SQL construct or string statement.
@@ -1000,6 +1003,11 @@ class MigrationStep:
is_upgrade: bool
migration_fn: Any
if TYPE_CHECKING:
@property
def doc(self) -> Optional[str]: ...
@property
def name(self) -> str:
return self.migration_fn.__name__
@@ -1048,13 +1056,9 @@ class RevisionStep(MigrationStep):
self.revision = revision
self.is_upgrade = is_upgrade
if is_upgrade:
self.migration_fn = (
revision.module.upgrade # type:ignore[attr-defined]
)
self.migration_fn = revision.module.upgrade
else:
self.migration_fn = (
revision.module.downgrade # type:ignore[attr-defined]
)
self.migration_fn = revision.module.downgrade
def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
@@ -1070,7 +1074,7 @@ class RevisionStep(MigrationStep):
)
@property
def doc(self) -> str:
def doc(self) -> Optional[str]:
return self.revision.doc
@property
@@ -1168,7 +1172,18 @@ class RevisionStep(MigrationStep):
}
return tuple(set(self.to_revisions).difference(ancestors))
else:
return self.to_revisions
# for each revision we plan to return, compute its ancestors
# (excluding self), and remove those from the final output since
# they are already accounted for.
ancestors = {
r.revision
for to_revision in self.to_revisions
for r in self.revision_map._get_ancestor_nodes(
self.revision_map.get_revisions(to_revision), check=False
)
if r.revision != to_revision
}
return tuple(set(self.to_revisions).difference(ancestors))
def unmerge_branch_idents(
self, heads: Set[str]
@@ -1283,7 +1298,7 @@ class StampStep(MigrationStep):
def __eq__(self, other):
return (
isinstance(other, StampStep)
and other.from_revisions == self.revisions
and other.from_revisions == self.from_revisions
and other.to_revisions == self.to_revisions
and other.branch_move == self.branch_move
and self.is_upgrade == other.is_upgrade