API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,6 +1,9 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from collections import namedtuple
import logging
import re
from typing import Any
from typing import Callable
@@ -8,6 +11,7 @@ from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
@@ -17,10 +21,18 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text
from . import _autogen
from . import base
from ._autogen import _constraint_sig as _constraint_sig
from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat
@@ -34,13 +46,10 @@ if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
@@ -50,6 +59,8 @@ if TYPE_CHECKING:
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
log = logging.getLogger(__name__)
class ImplMeta(type):
def __init__(
@@ -66,11 +77,8 @@ class ImplMeta(type):
_impls: Dict[str, Type[DefaultImpl]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
@@ -130,6 +138,40 @@ class DefaultImpl(metaclass=ImplMeta):
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def version_table_impl(
self,
*,
version_table: str,
version_table_schema: Optional[str],
version_table_pk: bool,
**kw: Any,
) -> Table:
"""Generate a :class:`.Table` object which will be used as the
structure for the Alembic version table.
Third party dialects may override this hook to provide an alternate
structure for this :class:`.Table`; requirements are only that it
be named based on the ``version_table`` parameter and contains
at least a single string-holding column named ``version_num``.
.. versionadded:: 1.14
"""
vt = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if version_table_pk:
vt.append_constraint(
PrimaryKeyConstraint(
"version_num", name=f"{version_table}_pkc"
)
)
return vt
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
@@ -161,16 +203,15 @@ class DefaultImpl(metaclass=ImplMeta):
def _exec(
self,
construct: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, Any] = util.immutabledict(),
execution_options: Optional[Mapping[str, Any]] = None,
multiparams: Optional[Sequence[Mapping[str, Any]]] = None,
params: Mapping[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
if multiparams is not None or params:
raise TypeError("SQL parameters not allowed with as_sql")
compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
@@ -193,11 +234,16 @@ class DefaultImpl(metaclass=ImplMeta):
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
if params and multiparams is not None:
raise TypeError(
"Can't send params and multiparams at the same time"
)
if multiparams:
return conn.execute(construct, multiparams)
else:
return conn.execute(construct, params)
def execute(
self,
@@ -210,8 +256,11 @@ class DefaultImpl(metaclass=ImplMeta):
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -322,25 +371,40 @@ class DefaultImpl(metaclass=ImplMeta):
self,
table_name: str,
column: Column[Any],
*,
schema: Optional[Union[str, quoted_name]] = None,
if_not_exists: Optional[bool] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
self._exec(
base.AddColumn(
table_name,
column,
schema=schema,
if_not_exists=if_not_exists,
)
)
def drop_column(
self,
table_name: str,
column: Column[Any],
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
**kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
self._exec(
base.DropColumn(
table_name, column, schema=schema, if_exists=if_exists
)
)
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def drop_constraint(self, const: Constraint, **kw: Any) -> None:
self._exec(schema.DropConstraint(const, **kw))
def rename_table(
self,
@@ -352,11 +416,11 @@ class DefaultImpl(metaclass=ImplMeta):
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table) -> None:
def create_table(self, table: Table, **kw: Any) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table))
self._exec(schema.CreateTable(table, **kw))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -375,11 +439,11 @@ class DefaultImpl(metaclass=ImplMeta):
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table) -> None:
def drop_table(self, table: Table, **kw: Any) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table))
self._exec(schema.DropTable(table, **kw))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -393,7 +457,7 @@ class DefaultImpl(metaclass=ImplMeta):
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: ColumnElement[Any]) -> None:
def create_column_comment(self, column: Column[Any]) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index, **kw: Any) -> None:
@@ -412,15 +476,19 @@ class DefaultImpl(metaclass=ImplMeta):
if self.as_sql:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(
table.insert()
.inline()
.values(
**{
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
k: (
sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
@@ -428,16 +496,13 @@ class DefaultImpl(metaclass=ImplMeta):
else:
if rows:
if multiinsert:
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
self._exec(table.insert().inline(), multiparams=rows)
else:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
self._exec(table.insert().inline().values(**row))
def _tokenize_column_type(self, column: Column) -> Params:
definition: str
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
@@ -452,9 +517,9 @@ class DefaultImpl(metaclass=ImplMeta):
# varchar character set utf8
#
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens = []
term_tokens: List[str] = []
paren_term = None
for token in tokens:
@@ -466,6 +531,7 @@ class DefaultImpl(metaclass=ImplMeta):
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
term: str
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
@@ -642,7 +708,7 @@ class DefaultImpl(metaclass=ImplMeta):
diff, ignored = _compare_identity_options(
metadata_identity,
inspector_identity,
sqla_compat.Identity(),
schema.Identity(),
skip={"always"},
)
@@ -664,15 +730,96 @@ class DefaultImpl(metaclass=ImplMeta):
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
# order of col matters in an index
return tuple(col.name for col in index.columns)
def _compare_index_unique(
self, metadata_index: Index, reflected_index: Index
) -> Optional[str]:
conn_unique = bool(reflected_index.unique)
meta_unique = bool(metadata_index.unique)
if conn_unique != meta_unique:
return f"unique={conn_unique} to unique={meta_unique}"
else:
return None
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
# order of col does not matters in an unique constraint
return tuple(sorted([col.name for col in const.columns]))
def _create_metadata_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(True, self, constraint, **opts)
def _create_reflected_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(False, self, constraint, **opts)
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
"""Compare two indexes by comparing the signature generated by
``create_index_sig``.
This method returns a ``ComparisonResult``.
"""
msg: List[str] = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_sig = self._create_metadata_constraint_sig(metadata_index)
r_sig = self._create_reflected_constraint_sig(reflected_index)
assert _autogen.is_index_sig(m_sig)
assert _autogen.is_index_sig(r_sig)
# The assumption is that the index have no expression
for sig in m_sig, r_sig:
if sig.has_expressions:
log.warning(
"Generating approximate signature for index %s. "
"The dialect "
"implementation should either skip expression indexes "
"or provide a custom implementation.",
sig.const,
)
if m_sig.column_names != r_sig.column_names:
msg.append(
f"expression {r_sig.column_names} to {m_sig.column_names}"
)
if msg:
return ComparisonResult.Different(msg)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
"""Compare two unique constraints by comparing the two signatures.
The arguments are two tuples that contain the unique constraint and
the signatures generated by ``create_unique_constraint_sig``.
This method returns a ``ComparisonResult``.
"""
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
else:
return ComparisonResult.Equal()
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
@@ -697,6 +844,13 @@ class DefaultImpl(metaclass=ImplMeta):
return reflected_object.get("dialect_options", {})
class Params(NamedTuple):
token0: str
tokens: List[str]
args: List[str]
kwargs: Dict[str, str]
def _compare_identity_options(
metadata_io: Union[schema.Identity, schema.Sequence, None],
inspector_io: Union[schema.Identity, schema.Sequence, None],
@@ -735,12 +889,13 @@ def _compare_identity_options(
set(meta_d).union(insp_d),
)
if sqla_compat.identity_has_dialect_kwargs:
assert hasattr(default_io, "dialect_kwargs")
# use only the dialect kwargs in inspector_io since metadata_io
# can have options for many backends
check_dicts(
getattr(metadata_io, "dialect_kwargs", {}),
getattr(inspector_io, "dialect_kwargs", {}),
default_io.dialect_kwargs, # type: ignore[union-attr]
default_io.dialect_kwargs,
getattr(inspector_io, "dialect_kwargs", {}),
)