main commit

2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

alembic/autogenerate/__init__.py

@@ -1,10 +1,10 @@
from .api import _render_migration_diffs as _render_migration_diffs
from .api import compare_metadata as compare_metadata
from .api import produce_migrations as produce_migrations
from .api import render_python_code as render_python_code
from .api import RevisionContext as RevisionContext
from .compare import _produce_net_changes as _produce_net_changes
from .compare import comparators as comparators
from .render import render_op_text as render_op_text
from .render import renderers as renderers
from .rewriter import Rewriter as Rewriter
from .api import _render_migration_diffs
from .api import compare_metadata
from .api import produce_migrations
from .api import render_python_code
from .api import RevisionContext
from .compare import _produce_net_changes
from .compare import comparators
from .render import render_op_text
from .render import renderers
from .rewriter import Rewriter
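
The dropped `name as name` spelling is the explicit re-export idiom from PEP 484: under mypy's `--no-implicit-reexport` (implied by `--strict`), only the aliased form marks a name as part of the package's public surface. A minimal sketch of the difference, with hypothetical module names:

    # pkg/__init__.py
    from ._impl import helper as helper  # explicit re-export; "from pkg import helper" type-checks
    from ._impl import other             # implicit import; mypy --no-implicit-reexport flags external use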

alembic/autogenerate/api.py

@@ -17,7 +17,6 @@ from . import compare
from . import render
from .. import util
from ..operations import ops
from ..util import sqla_compat
"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""
@@ -28,7 +27,6 @@ if TYPE_CHECKING:
from sqlalchemy.engine import Inspector
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.schema import Table
from ..config import Config
from ..operations.ops import DowngradeOps
@@ -166,7 +164,6 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
"""
migration_script = produce_migrations(context, metadata)
assert migration_script.upgrade_ops is not None
return migration_script.upgrade_ops.as_diffs()
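
For context, `compare_metadata` wraps `produce_migrations` and flattens the result into diff tuples; a sketch of typical usage against an empty in-memory database:

    from sqlalchemy import MetaData, create_engine
    from alembic.migration import MigrationContext
    from alembic.autogenerate import compare_metadata

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        mc = MigrationContext.configure(conn)
        diffs = compare_metadata(mc, MetaData())  # e.g. [('add_table', Table(...)), ...]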
@@ -277,7 +274,7 @@ class AutogenContext:
"""Maintains configuration and state that's specific to an
autogenerate operation."""
metadata: Union[MetaData, Sequence[MetaData], None] = None
metadata: Optional[MetaData] = None
"""The :class:`~sqlalchemy.schema.MetaData` object
representing the destination.
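
The wider `Union[MetaData, Sequence[MetaData], None]` annotation reflects that autogenerate can aggregate several MetaData collections (see `sorted_tables` and `table_key_to_table` below); in an `env.py` this is usually spelled as a list. A sketch, with hypothetical model modules:

    # env.py
    from myapp import billing, identity  # hypothetical modules
    target_metadata = [billing.metadata, identity.metadata]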
@@ -332,8 +329,8 @@ class AutogenContext:
def __init__(
self,
migration_context: MigrationContext,
metadata: Union[MetaData, Sequence[MetaData], None] = None,
opts: Optional[Dict[str, Any]] = None,
metadata: Optional[MetaData] = None,
opts: Optional[dict] = None,
autogenerate: bool = True,
) -> None:
if (
@@ -443,7 +440,7 @@ class AutogenContext:
def run_object_filters(
self,
object_: SchemaItem,
name: sqla_compat._ConstraintName,
name: Optional[str],
type_: NameFilterType,
reflected: bool,
compare_to: Optional[SchemaItem],
@@ -467,7 +464,7 @@ class AutogenContext:
run_filters = run_object_filters
@util.memoized_property
def sorted_tables(self) -> List[Table]:
def sorted_tables(self):
"""Return an aggregate of the :attr:`.MetaData.sorted_tables`
collection(s).
@@ -483,7 +480,7 @@ class AutogenContext:
return result
@util.memoized_property
def table_key_to_table(self) -> Dict[str, Table]:
def table_key_to_table(self):
"""Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
The :attr:`.MetaData.tables` collection is a dictionary of table key
@@ -494,7 +491,7 @@ class AutogenContext:
objects contain the same table key, an exception is raised.
"""
result: Dict[str, Table] = {}
result = {}
for m in util.to_list(self.metadata):
intersect = set(result).intersection(set(m.tables))
if intersect:
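
The intersection check operates on `MetaData.tables` keys, which are schema-qualified strings, so two MetaData objects can only be combined when their table keys are disjoint. For example:

    from sqlalchemy import Column, Integer, MetaData, Table

    m = MetaData(schema="reporting")
    t = Table("events", m, Column("id", Integer, primary_key=True))
    assert m.tables["reporting.events"] is t  # key is "<schema>.<name>" when a schema is set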
@@ -596,9 +593,9 @@ class RevisionContext:
migration_script = self.generated_revisions[-1]
if not getattr(migration_script, "_needs_render", False):
migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
migration_script.downgrade_ops_list[-1].downgrade_token = (
downgrade_token
)
migration_script.downgrade_ops_list[
-1
].downgrade_token = downgrade_token
migration_script._needs_render = True
else:
migration_script._upgrade_ops.append(

alembic/autogenerate/compare.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import contextlib
@@ -10,12 +7,12 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import event
@@ -24,15 +21,10 @@ from sqlalchemy import schema as sa_schema
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.util import OrderedSet
from alembic.ddl.base import _fk_spec
from .. import util
from ..ddl._autogen import is_index_sig
from ..ddl._autogen import is_uq_sig
from ..operations import ops
from ..util import sqla_compat
@@ -43,7 +35,10 @@ if TYPE_CHECKING:
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from alembic.autogenerate.api import AutogenContext
from alembic.ddl.impl import DefaultImpl
@@ -51,8 +46,6 @@ if TYPE_CHECKING:
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import ModifyTableOps
from alembic.operations.ops import UpgradeOps
from ..ddl._autogen import _constraint_sig
log = logging.getLogger(__name__)
@@ -217,7 +210,7 @@ def _compare_tables(
(inspector),
# fmt: on
)
_InspectorConv(inspector).reflect_table(t, include_columns=None)
sqla_compat._reflect_table(inspector, t)
if autogen_context.run_object_filters(t, tname, "table", True, None):
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
@@ -247,8 +240,7 @@ def _compare_tables(
_compat_autogen_column_reflect(inspector),
# fmt: on
)
_InspectorConv(inspector).reflect_table(t, include_columns=None)
sqla_compat._reflect_table(inspector, t)
conn_column_info[(s, tname)] = t
for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
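
Both code paths ultimately reflect each live table into a blank `Table` object; `Inspector.reflect_table` (SQLAlchemy 1.4+) can be exercised on its own, as in this sketch:

    from sqlalchemy import MetaData, Table, create_engine, inspect, text

    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY)"))
    t = Table("t", MetaData())
    inspect(engine).reflect_table(t, include_columns=None)  # fills in columns/constraints
    assert "id" in t.c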
@@ -437,56 +429,102 @@ def _compare_columns(
log.info("Detected removed column '%s.%s'", name, cname)
_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
class _constraint_sig:
const: Union[UniqueConstraint, ForeignKeyConstraint, Index]
class _InspectorConv:
__slots__ = ("inspector",)
def __init__(self, inspector):
self.inspector = inspector
def _apply_reflectinfo_conv(self, consts):
if not consts:
return consts
for const in consts:
if const["name"] is not None and not isinstance(
const["name"], conv
):
const["name"] = conv(const["name"])
return consts
def _apply_constraint_conv(self, consts):
if not consts:
return consts
for const in consts:
if const.name is not None and not isinstance(const.name, conv):
const.name = conv(const.name)
return consts
def get_indexes(self, *args, **kw):
return self._apply_reflectinfo_conv(
self.inspector.get_indexes(*args, **kw)
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
def get_unique_constraints(self, *args, **kw):
return self._apply_reflectinfo_conv(
self.inspector.get_unique_constraints(*args, **kw)
def __eq__(self, other):
return self.const == other.const
def __ne__(self, other):
return self.const != other.const
def __hash__(self) -> int:
return hash(self.const)
class _uq_constraint_sig(_constraint_sig):
is_index = False
is_unique = True
def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None:
self.const = const
self.name = const.name
self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig(
const
)
def get_foreign_keys(self, *args, **kw):
return self._apply_reflectinfo_conv(
self.inspector.get_foreign_keys(*args, **kw)
@property
def column_names(self) -> List[str]:
return [col.name for col in self.const.columns]
class _ix_constraint_sig(_constraint_sig):
is_index = True
def __init__(self, const: Index, impl: DefaultImpl) -> None:
self.const = const
self.name = const.name
self.sig = ("INDEX",) + impl.create_index_sig(const)
self.is_unique = bool(const.unique)
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
def reflect_table(self, table, *, include_columns):
self.inspector.reflect_table(table, include_columns=include_columns)
@property
def column_names(self) -> Union[List[quoted_name], List[None]]:
return sqla_compat._get_index_column_names(self.const)
# A neater version of this used _ReflectInfo; however, that doesn't
# work in 1.4 and isn't public API in 2.x, whereas this is just a
# two-liner. So there's no competition...
self._apply_constraint_conv(table.constraints)
self._apply_constraint_conv(table.indexes)
class _fk_constraint_sig(_constraint_sig):
def __init__(
self, const: ForeignKeyConstraint, include_options: bool = False
) -> None:
self.const = const
self.name = const.name
(
self.source_schema,
self.source_table,
self.source_columns,
self.target_schema,
self.target_table,
self.target_columns,
onupdate,
ondelete,
deferrable,
initially,
) = _fk_spec(const)
self.sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
)
if include_options:
self.sig += (
(None if onupdate.lower() == "no action" else onupdate.lower())
if onupdate
else None,
(None if ondelete.lower() == "no action" else ondelete.lower())
if ondelete
else None,
# convert initially + deferrable into one three-state value
"initially_deferrable"
if initially and initially.lower() == "deferred"
else "deferrable"
if deferrable
else "not deferrable",
)
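
The trailing options collapse "NO ACTION" to None and fold `initially` plus `deferrable` into a single three-state value; restated as a standalone helper (not the library API):

    def fk_option_state(initially, deferrable):
        # mirrors the three-state collapse in _fk_constraint_sig above
        if initially and initially.lower() == "deferred":
            return "initially_deferrable"
        return "deferrable" if deferrable else "not deferrable"

    assert fk_option_state("DEFERRED", True) == "initially_deferrable"
    assert fk_option_state(None, True) == "deferrable"
    assert fk_option_state(None, None) == "not deferrable"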
@comparators.dispatch_for("table")
@@ -523,34 +561,34 @@ def _compare_indexes_and_uniques(
if conn_table is not None:
# 1b. ... and from connection, if the table exists
try:
conn_uniques = _InspectorConv(inspector).get_unique_constraints(
tname, schema=schema
)
supports_unique_constraints = True
except NotImplementedError:
pass
except TypeError:
# number of arguments is off for the base
# method in SQLAlchemy due to the cache decorator
# not being present
pass
else:
conn_uniques = [ # type:ignore[assignment]
uq
for uq in conn_uniques
if autogen_context.run_name_filters(
uq["name"],
"unique_constraint",
{"table_name": tname, "schema_name": schema},
if hasattr(inspector, "get_unique_constraints"):
try:
conn_uniques = inspector.get_unique_constraints( # type:ignore[assignment] # noqa
tname, schema=schema
)
]
for uq in conn_uniques:
if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
supports_unique_constraints = True
except NotImplementedError:
pass
except TypeError:
# number of arguments is off for the base
# method in SQLAlchemy due to the cache decorator
# not being present
pass
else:
conn_uniques = [ # type:ignore[assignment]
uq
for uq in conn_uniques
if autogen_context.run_name_filters(
uq["name"],
"unique_constraint",
{"table_name": tname, "schema_name": schema},
)
]
for uq in conn_uniques:
if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
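
On backends that double unique constraints up as unique indexes (MySQL, for example), the reflected dictionaries carry a `duplicates_index` key, which is what sets the flag above. An illustrative (hand-written, not reflected) value:

    uq = {
        "name": "uq_user_email",
        "column_names": ["email"],
        "duplicates_index": "uq_user_email",  # a same-named index mirrors the constraint
    }
    assert bool(uq.get("duplicates_index"))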
try:
conn_indexes = _InspectorConv(inspector).get_indexes(
conn_indexes = inspector.get_indexes( # type:ignore[assignment]
tname, schema=schema
)
except NotImplementedError:
@@ -601,7 +639,7 @@ def _compare_indexes_and_uniques(
# 3. give the dialect a chance to omit indexes and constraints that
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
impl.correct_for_autogen_constraints(
autogen_context.migration_context.impl.correct_for_autogen_constraints(
conn_uniques, # type: ignore[arg-type]
conn_indexes, # type: ignore[arg-type]
metadata_unique_constraints,
@@ -613,31 +651,31 @@ def _compare_indexes_and_uniques(
# Index and UniqueConstraint so we can easily work with them
# interchangeably
metadata_unique_constraints_sig = {
impl._create_metadata_constraint_sig(uq)
for uq in metadata_unique_constraints
_uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints
}
metadata_indexes_sig = {
impl._create_metadata_constraint_sig(ix) for ix in metadata_indexes
_ix_constraint_sig(ix, impl) for ix in metadata_indexes
}
conn_unique_constraints = {
impl._create_reflected_constraint_sig(uq) for uq in conn_uniques
_uq_constraint_sig(uq, impl) for uq in conn_uniques
}
conn_indexes_sig = {
impl._create_reflected_constraint_sig(ix) for ix in conn_indexes
}
conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}
# 5. index things by name, for those objects that have names
metadata_names = {
cast(str, c.md_name_to_sql_name(autogen_context)): c
for c in metadata_unique_constraints_sig.union(metadata_indexes_sig)
if c.is_named
for c in metadata_unique_constraints_sig.union(
metadata_indexes_sig # type:ignore[arg-type]
)
if isinstance(c, _ix_constraint_sig)
or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
}
conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig]
conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig]
conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
@@ -656,12 +694,13 @@ def _compare_indexes_and_uniques(
# 6. index things by "column signature", to help with unnamed unique
# constraints.
conn_uniques_by_sig = {uq.unnamed: uq for uq in conn_unique_constraints}
conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints}
metadata_uniques_by_sig = {
uq.unnamed: uq for uq in metadata_unique_constraints_sig
uq.sig: uq for uq in metadata_unique_constraints_sig
}
metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig}
unnamed_metadata_uniques = {
uq.unnamed: uq
uq.sig: uq
for uq in metadata_unique_constraints_sig
if not sqla_compat._constraint_is_named(
uq.const, autogen_context.dialect
@@ -676,18 +715,18 @@ def _compare_indexes_and_uniques(
# 4. The backend may double up indexes as unique constraints and
# vice versa (e.g. MySQL, Postgresql)
def obj_added(obj: _constraint_sig):
if is_index_sig(obj):
def obj_added(obj):
if obj.is_index:
if autogen_context.run_object_filters(
obj.const, obj.name, "index", False, None
):
modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const))
log.info(
"Detected added index %r on '%s'",
"Detected added index '%s' on %s",
obj.name,
obj.column_names,
", ".join(["'%s'" % obj.column_names]),
)
elif is_uq_sig(obj):
else:
if not supports_unique_constraints:
# can't report unique indexes as added if we don't
# detect them
@@ -702,15 +741,13 @@ def _compare_indexes_and_uniques(
ops.AddConstraintOp.from_constraint(obj.const)
)
log.info(
"Detected added unique constraint %r on '%s'",
"Detected added unique constraint '%s' on %s",
obj.name,
obj.column_names,
", ".join(["'%s'" % obj.column_names]),
)
else:
assert False
def obj_removed(obj: _constraint_sig):
if is_index_sig(obj):
def obj_removed(obj):
if obj.is_index:
if obj.is_unique and not supports_unique_constraints:
# many databases double up unique constraints
# as unique indexes. without that list we can't
@@ -721,8 +758,10 @@ def _compare_indexes_and_uniques(
obj.const, obj.name, "index", True, None
):
modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const))
log.info("Detected removed index %r on %r", obj.name, tname)
elif is_uq_sig(obj):
log.info(
"Detected removed index '%s' on '%s'", obj.name, tname
)
else:
if is_create_table or is_drop_table:
# if the whole table is being dropped, we don't need to
# consider unique constraint separately
@@ -734,40 +773,33 @@ def _compare_indexes_and_uniques(
ops.DropConstraintOp.from_constraint(obj.const)
)
log.info(
"Detected removed unique constraint %r on %r",
"Detected removed unique constraint '%s' on '%s'",
obj.name,
tname,
)
else:
assert False
def obj_changed(
old: _constraint_sig,
new: _constraint_sig,
msg: str,
):
if is_index_sig(old):
assert is_index_sig(new)
def obj_changed(old, new, msg):
if old.is_index:
if autogen_context.run_object_filters(
new.const, new.name, "index", False, old.const
):
log.info(
"Detected changed index %r on %r: %s", old.name, tname, msg
"Detected changed index '%s' on '%s':%s",
old.name,
tname,
", ".join(msg),
)
modify_ops.ops.append(ops.DropIndexOp.from_index(old.const))
modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const))
elif is_uq_sig(old):
assert is_uq_sig(new)
else:
if autogen_context.run_object_filters(
new.const, new.name, "unique_constraint", False, old.const
):
log.info(
"Detected changed unique constraint %r on %r: %s",
"Detected changed unique constraint '%s' on '%s':%s",
old.name,
tname,
msg,
", ".join(msg),
)
modify_ops.ops.append(
ops.DropConstraintOp.from_constraint(old.const)
@@ -775,24 +807,18 @@ def _compare_indexes_and_uniques(
modify_ops.ops.append(
ops.AddConstraintOp.from_constraint(new.const)
)
else:
assert False
for removed_name in sorted(set(conn_names).difference(metadata_names)):
conn_obj = conn_names[removed_name]
if (
is_uq_sig(conn_obj)
and conn_obj.unnamed in unnamed_metadata_uniques
):
conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
removed_name
]
if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
continue
elif removed_name in doubled_constraints:
conn_uq, conn_idx = doubled_constraints[removed_name]
if (
all(
conn_idx.unnamed != meta_idx.unnamed
for meta_idx in metadata_indexes_sig
)
and conn_uq.unnamed not in metadata_uniques_by_sig
conn_idx.sig not in metadata_indexes_by_sig
and conn_uq.sig not in metadata_uniques_by_sig
):
obj_removed(conn_uq)
obj_removed(conn_idx)
@@ -804,36 +830,30 @@ def _compare_indexes_and_uniques(
if existing_name in doubled_constraints:
conn_uq, conn_idx = doubled_constraints[existing_name]
if is_index_sig(metadata_obj):
if metadata_obj.is_index:
conn_obj = conn_idx
else:
conn_obj = conn_uq
else:
conn_obj = conn_names[existing_name]
if type(conn_obj) != type(metadata_obj):
if conn_obj.is_index != metadata_obj.is_index:
obj_removed(conn_obj)
obj_added(metadata_obj)
else:
comparison = metadata_obj.compare_to_reflected(conn_obj)
msg = []
if conn_obj.is_unique != metadata_obj.is_unique:
msg.append(
" unique=%r to unique=%r"
% (conn_obj.is_unique, metadata_obj.is_unique)
)
if conn_obj.sig != metadata_obj.sig:
msg.append(
" expression %r to %r" % (conn_obj.sig, metadata_obj.sig)
)
if comparison.is_different:
# constraint are different
obj_changed(conn_obj, metadata_obj, comparison.message)
elif comparison.is_skip:
# constraint cannot be compared, skip them
thing = (
"index" if is_index_sig(conn_obj) else "unique constraint"
)
log.info(
"Cannot compare %s %r, assuming equal and skipping. %s",
thing,
conn_obj.name,
comparison.message,
)
else:
# constraint are equal
assert comparison.is_equal
if msg:
obj_changed(conn_obj, metadata_obj, msg)
for added_name in sorted(set(metadata_names).difference(conn_names)):
obj = metadata_names[added_name]
@@ -873,7 +893,7 @@ def _correct_for_uq_duplicates_uix(
}
unnamed_metadata_uqs = {
impl._create_metadata_constraint_sig(cons).unnamed
_uq_constraint_sig(cons, impl).sig
for name, cons in metadata_cons_names
if name is None
}
@@ -897,9 +917,7 @@ def _correct_for_uq_duplicates_uix(
for overlap in uqs_dupe_indexes:
if overlap not in metadata_uq_names:
if (
impl._create_reflected_constraint_sig(
uqs_dupe_indexes[overlap]
).unnamed
_uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig
not in unnamed_metadata_uqs
):
conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
@@ -1035,7 +1053,7 @@ def _normalize_computed_default(sqltext: str) -> str:
"""
return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()
return re.sub(r"[ \(\)'\"`\[\]]", "", sqltext).lower()
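
Both variants of `_normalize_computed_default` strip quoting and whitespace noise so that a server-reported expression and the metadata-declared one compare equal; for example:

    import re

    def normalize(sqltext):  # the variant that also strips tabs/newlines
        return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()

    assert normalize("(`a` + 1)") == normalize("A+1")  # both become "a+1"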
def _compare_computed_default(
@@ -1119,15 +1137,27 @@ def _compare_server_default(
return False
if sqla_compat._server_default_is_computed(metadata_default):
return _compare_computed_default( # type:ignore[func-returns-value]
autogen_context,
alter_column_op,
schema,
tname,
cname,
conn_col,
metadata_col,
)
# return False in case of a computed column as the server
# default. Note that DDL for adding or removing "GENERATED AS" from
# an existing column is not currently known for any backend.
# Once SQLAlchemy can reflect "GENERATED" as the "computed" element,
# we would also want to ignore and/or warn for changes vs. the
# metadata (or support backend specific DDL if applicable).
if not sqla_compat.has_computed_reflection:
return False
else:
return (
_compare_computed_default( # type:ignore[func-returns-value]
autogen_context,
alter_column_op,
schema,
tname,
cname,
conn_col,
metadata_col,
)
)
if sqla_compat._server_default_is_computed(conn_col_default):
_warn_computed_not_supported(tname, cname)
return False
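
`_server_default_is_computed` distinguishes `Computed` (GENERATED ... AS) columns from ordinary server defaults; declaring one looks like this sketch:

    from sqlalchemy import Column, Computed, Integer, MetaData, Table

    t = Table(
        "t",
        MetaData(),
        Column("x", Integer),
        Column("x2", Integer, Computed("x * 2")),  # GENERATED ALWAYS AS (x * 2)
    )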
@@ -1213,8 +1243,8 @@ def _compare_foreign_keys(
modify_table_ops: ModifyTableOps,
schema: Optional[str],
tname: Union[quoted_name, str],
conn_table: Table,
metadata_table: Table,
conn_table: Optional[Table],
metadata_table: Optional[Table],
) -> None:
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
@@ -1230,9 +1260,7 @@ def _compare_foreign_keys(
conn_fks_list = [
fk
for fk in _InspectorConv(inspector).get_foreign_keys(
tname, schema=schema
)
for fk in inspector.get_foreign_keys(tname, schema=schema)
if autogen_context.run_name_filters(
fk["name"],
"foreign_key_constraint",
@@ -1240,11 +1268,14 @@ def _compare_foreign_keys(
)
]
conn_fks = {
_make_foreign_key(const, conn_table) for const in conn_fks_list
}
backend_reflects_fk_options = bool(
conn_fks_list and "options" in conn_fks_list[0]
)
impl = autogen_context.migration_context.impl
conn_fks = {
_make_foreign_key(const, conn_table) # type: ignore[arg-type]
for const in conn_fks_list
}
# give the dialect a chance to correct the FKs to match more
# closely
@@ -1253,24 +1284,17 @@ def _compare_foreign_keys(
)
metadata_fks_sig = {
impl._create_metadata_constraint_sig(fk) for fk in metadata_fks
_fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
for fk in metadata_fks
}
conn_fks_sig = {
impl._create_reflected_constraint_sig(fk) for fk in conn_fks
_fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
for fk in conn_fks
}
# check if reflected FKs include options, indicating the backend
# can reflect FK options
if conn_fks_list and "options" in conn_fks_list[0]:
conn_fks_by_sig = {c.unnamed: c for c in conn_fks_sig}
metadata_fks_by_sig = {c.unnamed: c for c in metadata_fks_sig}
else:
# otherwise compare by sig without options added
conn_fks_by_sig = {c.unnamed_no_options: c for c in conn_fks_sig}
metadata_fks_by_sig = {
c.unnamed_no_options: c for c in metadata_fks_sig
}
conn_fks_by_sig = {c.sig: c for c in conn_fks_sig}
metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig}
metadata_fks_by_name = {
c.name: c for c in metadata_fks_sig if c.name is not None

alembic/autogenerate/render.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from io import StringIO
@@ -18,9 +15,7 @@ from mako.pygen import PythonPrinter
from sqlalchemy import schema as sa_schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.base import _DialectArgView
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import Label
from sqlalchemy.sql.elements import quoted_name
from .. import util
@@ -30,8 +25,7 @@ from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy import Computed
from sqlalchemy import Identity
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import CheckConstraint
@@ -51,6 +45,8 @@ if TYPE_CHECKING:
from alembic.config import Config
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import ModifyTableOps
from alembic.util.sqla_compat import Computed
from alembic.util.sqla_compat import Identity
MAX_PYTHON_ARGS = 255
@@ -168,31 +164,21 @@ def _render_modify_table(
def _render_create_table_comment(
autogen_context: AutogenContext, op: ops.CreateTableCommentOp
) -> str:
if autogen_context._has_batch:
templ = (
"{prefix}create_table_comment(\n"
"{indent}{comment},\n"
"{indent}existing_comment={existing}\n"
")"
)
else:
templ = (
"{prefix}create_table_comment(\n"
"{indent}'{tname}',\n"
"{indent}{comment},\n"
"{indent}existing_comment={existing},\n"
"{indent}schema={schema}\n"
")"
)
templ = (
"{prefix}create_table_comment(\n"
"{indent}'{tname}',\n"
"{indent}{comment},\n"
"{indent}existing_comment={existing},\n"
"{indent}schema={schema}\n"
")"
)
return templ.format(
prefix=_alembic_autogenerate_prefix(autogen_context),
tname=op.table_name,
comment="%r" % op.comment if op.comment is not None else None,
existing=(
"%r" % op.existing_comment
if op.existing_comment is not None
else None
),
existing="%r" % op.existing_comment
if op.existing_comment is not None
else None,
schema="'%s'" % op.schema if op.schema is not None else None,
indent=" ",
)
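
Filled in, the non-batch template produces migration source along these lines (illustrative table name and comment):

    op.create_table_comment(
        'accounts',
        'holds account rows',
        existing_comment=None,
        schema=None
    )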
@@ -202,28 +188,19 @@ def _render_create_table_comment(
def _render_drop_table_comment(
autogen_context: AutogenContext, op: ops.DropTableCommentOp
) -> str:
if autogen_context._has_batch:
templ = (
"{prefix}drop_table_comment(\n"
"{indent}existing_comment={existing}\n"
")"
)
else:
templ = (
"{prefix}drop_table_comment(\n"
"{indent}'{tname}',\n"
"{indent}existing_comment={existing},\n"
"{indent}schema={schema}\n"
")"
)
templ = (
"{prefix}drop_table_comment(\n"
"{indent}'{tname}',\n"
"{indent}existing_comment={existing},\n"
"{indent}schema={schema}\n"
")"
)
return templ.format(
prefix=_alembic_autogenerate_prefix(autogen_context),
tname=op.table_name,
existing=(
"%r" % op.existing_comment
if op.existing_comment is not None
else None
),
existing="%r" % op.existing_comment
if op.existing_comment is not None
else None,
schema="'%s'" % op.schema if op.schema is not None else None,
indent=" ",
)
@@ -280,9 +257,6 @@ def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str:
prefixes = ", ".join("'%s'" % p for p in table._prefixes)
text += ",\nprefixes=[%s]" % prefixes
if op.if_not_exists is not None:
text += ",\nif_not_exists=%r" % bool(op.if_not_exists)
text += "\n)"
return text
@@ -295,20 +269,16 @@ def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str:
}
if op.schema:
text += ", schema=%r" % _ident(op.schema)
if op.if_exists is not None:
text += ", if_exists=%r" % bool(op.if_exists)
text += ")"
return text
def _render_dialect_kwargs_items(
autogen_context: AutogenContext, dialect_kwargs: _DialectArgView
autogen_context: AutogenContext, item: DialectKWArgs
) -> list[str]:
return [
f"{key}={_render_potential_expr(val, autogen_context)}"
for key, val in dialect_kwargs.items()
for key, val in item.dialect_kwargs.items()
]
@@ -331,9 +301,7 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
assert index.table is not None
opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs)
if op.if_not_exists is not None:
opts.append("if_not_exists=%r" % bool(op.if_not_exists))
opts = _render_dialect_kwargs_items(autogen_context, index)
text = tmpl % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"name": _render_gen_name(autogen_context, index.name),
@@ -342,11 +310,9 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
_get_index_rendered_expressions(index, autogen_context)
),
"unique": index.unique or False,
"schema": (
(", schema=%r" % _ident(index.table.schema))
if index.table.schema
else ""
),
"schema": (", schema=%r" % _ident(index.table.schema))
if index.table.schema
else "",
"kwargs": ", " + ", ".join(opts) if opts else "",
}
return text
@@ -365,9 +331,7 @@ def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str:
"%(prefix)sdrop_index(%(name)r, "
"table_name=%(table_name)r%(schema)s%(kwargs)s)"
)
opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs)
if op.if_exists is not None:
opts.append("if_exists=%r" % bool(op.if_exists))
opts = _render_dialect_kwargs_items(autogen_context, index)
text = tmpl % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"name": _render_gen_name(autogen_context, op.index_name),
@@ -389,7 +353,6 @@ def _add_unique_constraint(
def _add_fk_constraint(
autogen_context: AutogenContext, op: ops.CreateForeignKeyOp
) -> str:
constraint = op.to_constraint()
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
args.append(repr(_ident(op.source_table)))
@@ -419,16 +382,9 @@ def _add_fk_constraint(
if value is not None:
args.append("%s=%r" % (k, value))
dialect_kwargs = _render_dialect_kwargs_items(
autogen_context, constraint.dialect_kwargs
)
return "%(prefix)screate_foreign_key(%(args)s%(dialect_kwargs)s)" % {
return "%(prefix)screate_foreign_key(%(args)s)" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
"dialect_kwargs": (
", " + ", ".join(dialect_kwargs) if dialect_kwargs else ""
),
}
@@ -450,7 +406,7 @@ def _drop_constraint(
name = _render_gen_name(autogen_context, op.constraint_name)
schema = _ident(op.schema) if op.schema else None
type_ = _ident(op.constraint_type) if op.constraint_type else None
if_exists = op.if_exists
params_strs = []
params_strs.append(repr(name))
if not autogen_context._has_batch:
@@ -459,47 +415,32 @@ def _drop_constraint(
params_strs.append(f"schema={schema!r}")
if type_ is not None:
params_strs.append(f"type_={type_!r}")
if if_exists is not None:
params_strs.append(f"if_exists={if_exists}")
return f"{prefix}drop_constraint({', '.join(params_strs)})"
@renderers.dispatch_for(ops.AddColumnOp)
def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str:
schema, tname, column, if_not_exists = (
op.schema,
op.table_name,
op.column,
op.if_not_exists,
)
schema, tname, column = op.schema, op.table_name, op.column
if autogen_context._has_batch:
template = "%(prefix)sadd_column(%(column)s)"
else:
template = "%(prefix)sadd_column(%(tname)r, %(column)s"
if schema:
template += ", schema=%(schema)r"
if if_not_exists is not None:
template += ", if_not_exists=%(if_not_exists)r"
template += ")"
text = template % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"tname": tname,
"column": _render_column(column, autogen_context),
"schema": schema,
"if_not_exists": if_not_exists,
}
return text
@renderers.dispatch_for(ops.DropColumnOp)
def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
schema, tname, column_name, if_exists = (
op.schema,
op.table_name,
op.column_name,
op.if_exists,
)
schema, tname, column_name = op.schema, op.table_name, op.column_name
if autogen_context._has_batch:
template = "%(prefix)sdrop_column(%(cname)r)"
@@ -507,8 +448,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
template = "%(prefix)sdrop_column(%(tname)r, %(cname)r"
if schema:
template += ", schema=%(schema)r"
if if_exists is not None:
template += ", if_exists=%(if_exists)r"
template += ")"
text = template % {
@@ -516,7 +455,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
"tname": _ident(tname),
"cname": _ident(column_name),
"schema": _ident(schema),
"if_exists": if_exists,
}
return text
@@ -531,7 +469,6 @@ def _alter_column(
type_ = op.modify_type
nullable = op.modify_nullable
comment = op.modify_comment
newname = op.modify_name
autoincrement = op.kw.get("autoincrement", None)
existing_type = op.existing_type
existing_nullable = op.existing_nullable
@@ -560,8 +497,6 @@ def _alter_column(
rendered = _render_server_default(server_default, autogen_context)
text += ",\n%sserver_default=%s" % (indent, rendered)
if newname is not None:
text += ",\n%snew_column_name=%r" % (indent, newname)
if type_ is not None:
text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context))
if nullable is not None:
@@ -614,28 +549,23 @@ def _render_potential_expr(
value: Any,
autogen_context: AutogenContext,
*,
wrap_in_element: bool = True,
wrap_in_text: bool = True,
is_server_default: bool = False,
is_index: bool = False,
) -> str:
if isinstance(value, sql.ClauseElement):
sql_text = autogen_context.migration_context.impl.render_ddl_sql_expr(
value, is_server_default=is_server_default, is_index=is_index
)
if wrap_in_element:
prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
element = "literal_column" if is_index else "text"
value_str = f"{prefix}{element}({sql_text!r})"
if (
is_index
and isinstance(value, Label)
and type(value.name) is str
):
return value_str + f".label({value.name!r})"
else:
return value_str
if wrap_in_text:
template = "%(prefix)stext(%(sql)r)"
else:
return repr(sql_text)
template = "%(sql)r"
return template % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": autogen_context.migration_context.impl.render_ddl_sql_expr(
value, is_server_default=is_server_default, is_index=is_index
),
}
else:
return repr(value)
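
In the `is_index` branch the raw SQL is wrapped in `literal_column` rather than `text`, so a functional index renders roughly as (illustrative names):

    op.create_index('ix_t_lower_name', 't', [sa.literal_column('lower(name)')], unique=False)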
@@ -644,11 +574,9 @@ def _get_index_rendered_expressions(
idx: Index, autogen_context: AutogenContext
) -> List[str]:
return [
(
repr(_ident(getattr(exp, "name", None)))
if isinstance(exp, sa_schema.Column)
else _render_potential_expr(exp, autogen_context, is_index=True)
)
repr(_ident(getattr(exp, "name", None)))
if isinstance(exp, sa_schema.Column)
else _render_potential_expr(exp, autogen_context, is_index=True)
for exp in idx.expressions
]
@@ -663,18 +591,16 @@ def _uq_constraint(
has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", constraint.deferrable))
opts.append(("deferrable", str(constraint.deferrable)))
if constraint.initially:
opts.append(("initially", constraint.initially))
opts.append(("initially", str(constraint.initially)))
if not has_batch and alter and constraint.table.schema:
opts.append(("schema", _ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
("name", _render_gen_name(autogen_context, constraint.name))
)
dialect_options = _render_dialect_kwargs_items(
autogen_context, constraint.dialect_kwargs
)
dialect_options = _render_dialect_kwargs_items(autogen_context, constraint)
if alter:
args = [repr(_render_gen_name(autogen_context, constraint.name))]
@@ -778,7 +704,7 @@ def _render_column(
+ [
"%s=%s"
% (key, _render_potential_expr(val, autogen_context))
for key, val in column.kwargs.items()
for key, val in sqla_compat._column_kwargs(column).items()
]
)
),
@@ -813,8 +739,6 @@ def _render_server_default(
return _render_potential_expr(
default.arg, autogen_context, is_server_default=True
)
elif isinstance(default, sa_schema.FetchedValue):
return _render_fetched_value(autogen_context)
if isinstance(default, str) and repr_:
default = repr(re.sub(r"^'|'$", "", default))
@@ -826,7 +750,7 @@ def _render_computed(
computed: Computed, autogen_context: AutogenContext
) -> str:
text = _render_potential_expr(
computed.sqltext, autogen_context, wrap_in_element=False
computed.sqltext, autogen_context, wrap_in_text=False
)
kwargs = {}
@@ -852,12 +776,6 @@ def _render_identity(
}
def _render_fetched_value(autogen_context: AutogenContext) -> str:
return "%(prefix)sFetchedValue()" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
}
def _repr_type(
type_: TypeEngine,
autogen_context: AutogenContext,
@@ -876,10 +794,7 @@ def _repr_type(
mod = type(type_).__module__
imports = autogen_context.imports
if not _skip_variants and sqla_compat._type_has_variants(type_):
return _render_Variant_type(type_, autogen_context)
elif mod.startswith("sqlalchemy.dialects"):
if mod.startswith("sqlalchemy.dialects"):
match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
assert match is not None
dname = match.group(1)
@@ -891,6 +806,8 @@ def _repr_type(
return "%s.%r" % (dname, type_)
elif impl_rt:
return impl_rt
elif not _skip_variants and sqla_compat._type_has_variants(type_):
return _render_Variant_type(type_, autogen_context)
elif mod.startswith("sqlalchemy."):
if "_render_%s_type" % type_.__visit_name__ in globals():
fn = globals()["_render_%s_type" % type_.__visit_name__]
@@ -917,7 +834,7 @@ def _render_Variant_type(
) -> str:
base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
base = _repr_type(base_type, autogen_context, _skip_variants=True)
assert base is not None and base is not False # type: ignore[comparison-overlap] # noqa:E501
assert base is not None and base is not False
for dialect in sorted(variant_mapping):
typ = variant_mapping[dialect]
base += ".with_variant(%s, %r)" % (
@@ -1008,13 +925,13 @@ def _render_primary_key(
def _fk_colspec(
fk: ForeignKey,
metadata_schema: Optional[str],
namespace_metadata: Optional[MetaData],
namespace_metadata: MetaData,
) -> str:
"""Implement a 'safe' version of ForeignKey._get_colspec() that
won't fail if the remote table can't be resolved.
"""
colspec = fk._get_colspec()
colspec = fk._get_colspec() # type:ignore[attr-defined]
tokens = colspec.split(".")
tname, colname = tokens[-2:]
@@ -1032,10 +949,7 @@ def _fk_colspec(
# the FK constraint needs to be rendered in terms of the column
# name.
if (
namespace_metadata is not None
and table_fullname in namespace_metadata.tables
):
if table_fullname in namespace_metadata.tables:
col = namespace_metadata.tables[table_fullname].c.get(colname)
if col is not None:
colname = _ident(col.name) # type: ignore[assignment]
@@ -1066,7 +980,7 @@ def _populate_render_fk_opts(
def _render_foreign_key(
constraint: ForeignKeyConstraint,
autogen_context: AutogenContext,
namespace_metadata: Optional[MetaData],
namespace_metadata: MetaData,
) -> Optional[str]:
rendered = _user_defined_render("foreign_key", constraint, autogen_context)
if rendered is not False:
@@ -1080,16 +994,15 @@ def _render_foreign_key(
_populate_render_fk_opts(constraint, opts)
apply_metadata_schema = (
namespace_metadata.schema if namespace_metadata is not None else None
)
apply_metadata_schema = namespace_metadata.schema
return (
"%(prefix)sForeignKeyConstraint([%(cols)s], "
"[%(refcols)s], %(args)s)"
% {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
repr(_ident(f.parent.name)) for f in constraint.elements
"%r" % _ident(cast("Column", f.parent).name)
for f in constraint.elements
),
"refcols": ", ".join(
repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@@ -1130,10 +1043,12 @@ def _render_check_constraint(
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
if (
constraint._create_rule
and hasattr(constraint._create_rule, "target")
constraint._create_rule # type:ignore[attr-defined]
and hasattr(
constraint._create_rule, "target" # type:ignore[attr-defined]
)
and isinstance(
constraint._create_rule.target,
constraint._create_rule.target, # type:ignore[attr-defined]
sqltypes.TypeEngine,
)
):
@@ -1145,13 +1060,11 @@ def _render_check_constraint(
)
return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"opts": (
", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
if opts
else ""
),
"opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
if opts
else "",
"sqltext": _render_potential_expr(
constraint.sqltext, autogen_context, wrap_in_element=False
constraint.sqltext, autogen_context, wrap_in_text=False
),
}
@@ -1163,10 +1076,7 @@ def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str:
"Autogenerate rendering of SQL Expression language constructs "
"not supported here; please use a plain SQL string"
)
return "{prefix}execute({sqltext!r})".format(
prefix=_alembic_autogenerate_prefix(autogen_context),
sqltext=op.sqltext,
)
return "op.execute(%r)" % op.sqltext
renderers = default_renderers.branch()

alembic/autogenerate/rewriter.py

@@ -4,7 +4,7 @@ from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Tuple
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
@@ -16,18 +16,12 @@ if TYPE_CHECKING:
from ..operations.ops import AddColumnOp
from ..operations.ops import AlterColumnOp
from ..operations.ops import CreateTableOp
from ..operations.ops import DowngradeOps
from ..operations.ops import MigrateOperation
from ..operations.ops import MigrationScript
from ..operations.ops import ModifyTableOps
from ..operations.ops import OpContainer
from ..operations.ops import UpgradeOps
from ..runtime.environment import _GetRevArg
from ..runtime.migration import MigrationContext
from ..script.revision import _GetRevArg
ProcessRevisionDirectiveFn = Callable[
["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
]
class Rewriter:
@@ -58,21 +52,15 @@ class Rewriter:
_traverse = util.Dispatcher()
_chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
_chained: Optional[Rewriter] = None
def __init__(self) -> None:
self.dispatch = util.Dispatcher()
def chain(
self,
other: Union[
ProcessRevisionDirectiveFn,
Rewriter,
],
) -> Rewriter:
def chain(self, other: Rewriter) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two or more rewriters to operate serially on a stream,
This allows two rewriters to operate serially on a stream,
e.g.::
writer1 = autogenerate.Rewriter()
@@ -101,7 +89,7 @@ class Rewriter:
"""
wr = self.__class__.__new__(self.__class__)
wr.__dict__.update(self.__dict__)
wr._chained += (other,)
wr._chained = other
return wr
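
The `_chained` hook is how several rewriters compose in `env.py`; a sketch of the documented pattern:

    from alembic.autogenerate import rewriter
    from alembic.operations import ops

    writer1 = rewriter.Rewriter()
    writer2 = rewriter.Rewriter()

    @writer1.rewrites(ops.AddColumnOp)
    def add_column_nullable(context, revision, op):
        op.column.nullable = True  # force newly added columns to be nullable
        return op

    writer = writer1.chain(writer2)  # writer2 processes directives after writer1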
def rewrites(
@@ -113,7 +101,7 @@ class Rewriter:
Type[CreateTableOp],
Type[ModifyTableOps],
],
) -> Callable[..., Any]:
) -> Callable:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
@@ -158,8 +146,8 @@ class Rewriter:
directives: List[MigrationScript],
) -> None:
self.process_revision_directives(context, revision, directives)
for process_revision_directives in self._chained:
process_revision_directives(context, revision, directives)
if self._chained:
self._chained(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
def _traverse_script(
@@ -168,7 +156,7 @@ class Rewriter:
revision: _GetRevArg,
directive: MigrationScript,
) -> None:
upgrade_ops_list: List[UpgradeOps] = []
upgrade_ops_list = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
if len(ret) != 1:
@@ -176,10 +164,9 @@ class Rewriter:
"Can only return single object for UpgradeOps traverse"
)
upgrade_ops_list.append(ret[0])
directive.upgrade_ops = upgrade_ops_list
downgrade_ops_list: List[DowngradeOps] = []
downgrade_ops_list = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(context, revision, downgrade_ops)
if len(ret) != 1: