main commit
All checks were successful
continuous-integration/drone/push Build is passing

2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File: alembic/ddl/__init__.py

@@ -3,4 +3,4 @@ from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl as DefaultImpl
from .impl import DefaultImpl

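The one-line change above drops the redundant-alias form of the re-export. Under mypy's no_implicit_reexport mode (and similar strict checks), `from .impl import DefaultImpl as DefaultImpl` is the spelling that marks the name as an intentional public re-export, while the plain import is treated as implicit. A minimal sketch of the convention; the package and module names here are illustrative, not part of this diff:

# pkg/impl.py (illustrative)
class DefaultImpl:
    pass

# pkg/__init__.py, variant 1: explicit re-export -- strict type checkers
# (e.g. mypy --no-implicit-reexport) accept `pkg.DefaultImpl` as public API
from .impl import DefaultImpl as DefaultImpl

# pkg/__init__.py, variant 2: implicit re-export -- the same checkers warn
# when other modules import DefaultImpl from `pkg` instead of `pkg.impl`
from .impl import DefaultImpl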
View File: alembic/ddl/_autogen.py

@@ -1,329 +0,0 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from alembic.autogenerate.api import AutogenContext
from alembic.ddl.impl import DefaultImpl
CompareConstraintType = Union[Constraint, Index]
_C = TypeVar("_C", bound=CompareConstraintType)
_clsreg: Dict[str, Type[_constraint_sig]] = {}
class ComparisonResult(NamedTuple):
status: Literal["equal", "different", "skip"]
message: str
@property
def is_equal(self) -> bool:
return self.status == "equal"
@property
def is_different(self) -> bool:
return self.status == "different"
@property
def is_skip(self) -> bool:
return self.status == "skip"
@classmethod
def Equal(cls) -> ComparisonResult:
"""the constraints are equal."""
return cls("equal", "The two constraints are equal")
@classmethod
def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraints are different for the provided reason(s)."""
return cls("different", ", ".join(util.to_list(reason)))
@classmethod
def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraint cannot be compared for the provided reason(s).
The message is logged, but the constraints will be otherwise
considered equal, meaning that no migration command will be
generated.
"""
return cls("skip", ", ".join(util.to_list(reason)))
class _constraint_sig(Generic[_C]):
const: _C
_sig: Tuple[Any, ...]
name: Optional[sqla_compat._ConstraintNameDefined]
impl: DefaultImpl
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_metadata: bool
def __init_subclass__(cls) -> None:
cls._register()
@classmethod
def _register(cls):
raise NotImplementedError()
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: _C
) -> None:
raise NotImplementedError()
def compare_to_reflected(
self, other: _constraint_sig[Any]
) -> ComparisonResult:
assert self.impl is other.impl
assert self._is_metadata
assert not other._is_metadata
return self._compare_to_reflected(other)
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
raise NotImplementedError()
@classmethod
def from_constraint(
cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
) -> _constraint_sig[_C]:
# these could be cached by constraint/impl, however, if the
# constraint is modified in place, then the sig is wrong. the mysql
# impl currently does this, and if we fixed that we can't be sure
# someone else might do it too, so play it safe.
sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
return sig
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@util.memoized_property
def is_named(self):
return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
@util.memoized_property
def unnamed(self) -> Tuple[Any, ...]:
return self._sig
@util.memoized_property
def unnamed_no_options(self) -> Tuple[Any, ...]:
raise NotImplementedError()
@util.memoized_property
def _full_sig(self) -> Tuple[Any, ...]:
return (self.name,) + self.unnamed
def __eq__(self, other) -> bool:
return self._full_sig == other._full_sig
def __ne__(self, other) -> bool:
return self._full_sig != other._full_sig
def __hash__(self) -> int:
return hash(self._full_sig)
class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
_is_uq = True
@classmethod
def _register(cls) -> None:
_clsreg["unique_constraint"] = cls
is_unique = True
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: UniqueConstraint,
) -> None:
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._sig = tuple(sorted([col.name for col in const.columns]))
self._is_metadata = is_metadata
@property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_uq_sig(conn_obj)
return self.impl.compare_unique_constraint(
metadata_obj.const, conn_obj.const
)
class _ix_constraint_sig(_constraint_sig[Index]):
_is_index = True
name: sqla_compat._ConstraintName
@classmethod
def _register(cls) -> None:
_clsreg["index"] = cls
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: Index
) -> None:
self.impl = impl
self.const = const
self.name = const.name
self.is_unique = bool(const.unique)
self._is_metadata = is_metadata
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_index_sig(conn_obj)
return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
@util.memoized_property
def has_expressions(self):
return sqla_compat.is_expression_index(self.const)
@util.memoized_property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
@util.memoized_property
def column_names_optional(self) -> Tuple[Optional[str], ...]:
return tuple(
[getattr(col, "name", None) for col in self.const.expressions]
)
@util.memoized_property
def is_named(self):
return True
@util.memoized_property
def unnamed(self):
return (self.is_unique,) + self.column_names_optional
class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
_is_fk = True
@classmethod
def _register(cls) -> None:
_clsreg["foreign_key_constraint"] = cls
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: ForeignKeyConstraint,
) -> None:
self._is_metadata = is_metadata
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
(
self.source_schema,
self.source_table,
self.source_columns,
self.target_schema,
self.target_table,
self.target_columns,
onupdate,
ondelete,
deferrable,
initially,
) = sqla_compat._fk_spec(const)
self._sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
) + (
(
(None if onupdate.lower() == "no action" else onupdate.lower())
if onupdate
else None
),
(
(None if ondelete.lower() == "no action" else ondelete.lower())
if ondelete
else None
),
# convert initially + deferrable into one three-state value
(
"initially_deferrable"
if initially and initially.lower() == "deferred"
else "deferrable" if deferrable else "not deferrable"
),
)
@util.memoized_property
def unnamed_no_options(self):
return (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
)
def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
return sig._is_index
def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
return sig._is_uq
def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk

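The deleted _autogen.py above registers each _constraint_sig subclass into _clsreg keyed by the constraint's __visit_name__ via __init_subclass__, then uses TypeGuard helpers so callers can narrow a generic signature back to its index/unique/foreign-key variant. A self-contained sketch of that registration-and-narrowing pattern; the class and key names below are illustrative, not alembic's:

from __future__ import annotations

from typing import ClassVar, Dict, Type

try:  # TypeGuard lives in typing on Python 3.10+
    from typing import TypeGuard
except ImportError:  # pragma: no cover
    from typing_extensions import TypeGuard

_registry: Dict[str, Type["_SigBase"]] = {}


class _SigBase:
    visit_name: ClassVar[str] = ""
    _is_index: ClassVar[bool] = False

    def __init_subclass__(cls) -> None:
        # every concrete subclass adds itself to the registry at class-definition time
        _registry[cls.visit_name] = cls


class _IndexSig(_SigBase):
    visit_name = "index"
    _is_index = True


def is_index_sig(sig: _SigBase) -> TypeGuard[_IndexSig]:
    # returns True only for index signatures, and tells the type checker so
    return sig._is_index


sig = _registry["index"]()  # look up the class by its registered key
assert is_index_sig(sig)    # narrows `sig` to _IndexSig for type checkers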
View File: alembic/ddl/base.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import functools
@@ -25,8 +22,6 @@ from ..util.sqla_compat import _table_for_constraint # noqa
if TYPE_CHECKING:
from typing import Any
from sqlalchemy import Computed
from sqlalchemy import Identity
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import TextClause
@@ -35,11 +30,14 @@ if TYPE_CHECKING:
from sqlalchemy.sql.type_api import TypeEngine
from .impl import DefaultImpl
from ..util.sqla_compat import Computed
from ..util.sqla_compat import Identity
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
class AlterTable(DDLElement):
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
@@ -154,24 +152,17 @@ class AddColumn(AlterTable):
name: str,
column: Column[Any],
schema: Optional[Union[quoted_name, str]] = None,
if_not_exists: Optional[bool] = None,
) -> None:
super().__init__(name, schema=schema)
self.column = column
self.if_not_exists = if_not_exists
class DropColumn(AlterTable):
def __init__(
self,
name: str,
column: Column[Any],
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
self, name: str, column: Column[Any], schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
self.if_exists = if_exists
class ColumnComment(AlterColumn):
@@ -196,9 +187,7 @@ def visit_rename_table(
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(
compiler, element.column, if_not_exists=element.if_not_exists, **kw
),
add_column(compiler, element.column, **kw),
)
@@ -206,9 +195,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(
compiler, element.column.name, if_exists=element.if_exists, **kw
),
drop_column(compiler, element.column.name, **kw),
)
@@ -248,11 +235,9 @@ def visit_column_default(
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
(
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT"
),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@@ -310,13 +295,9 @@ def format_server_default(
compiler: DDLCompiler,
default: Optional[_ServerDefault],
) -> str:
# this can be updated to use compiler.render_default_string
# for SQLAlchemy 2.0 and above; not in 1.4
default_str = compiler.get_column_default_string(
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
assert default_str is not None
return default_str
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
@@ -331,29 +312,16 @@ def alter_table(
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
def drop_column(
compiler: DDLCompiler, name: str, if_exists: Optional[bool] = None, **kw
) -> str:
return "DROP COLUMN %s%s" % (
"IF EXISTS " if if_exists else "",
format_column_name(compiler, name),
)
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(
compiler: DDLCompiler,
column: Column[Any],
if_not_exists: Optional[bool] = None,
**kw,
) -> str:
text = "ADD COLUMN %s%s" % (
"IF NOT EXISTS " if if_not_exists else "",
compiler.get_column_specification(column, **kw),
)
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
compiler.process(constraint) for constraint in column.constraints

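base.py pairs DDLElement subclasses such as AddColumn and DropColumn with @compiles visitors that render the ALTER TABLE strings, and the hunks above add optional IF EXISTS / IF NOT EXISTS prefixes on one side of the diff. A rough standalone sketch of that compiler-extension pattern using SQLAlchemy's custom-DDL machinery; the construct below is illustrative, not one of alembic's:

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import DDLElement


class AddColumnIfNotExists(DDLElement):
    """Illustrative construct, not alembic's AddColumn."""

    def __init__(self, table_name, column_spec, if_not_exists=False):
        self.table_name = table_name
        self.column_spec = column_spec  # e.g. "y INTEGER"
        self.if_not_exists = if_not_exists


@compiles(AddColumnIfNotExists)
def _visit_add_column(element, compiler, **kw):
    # same shape as visit_add_column above: ALTER TABLE prefix + ADD COLUMN clause
    return "ALTER TABLE %s ADD COLUMN %s%s" % (
        compiler.preparer.quote(element.table_name),
        "IF NOT EXISTS " if element.if_not_exists else "",
        element.column_spec,
    )


# e.g. connection.execute(AddColumnIfNotExists("t", "y INTEGER", if_not_exists=True))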
View File: alembic/ddl/impl.py

@@ -1,9 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
from collections import namedtuple
import re
from typing import Any
from typing import Callable
@@ -11,7 +8,6 @@ from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
@@ -21,18 +17,10 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text
from . import _autogen
from . import base
from ._autogen import _constraint_sig as _constraint_sig
from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat
@@ -46,10 +34,13 @@ if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
@@ -59,8 +50,6 @@ if TYPE_CHECKING:
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
log = logging.getLogger(__name__)
class ImplMeta(type):
def __init__(
@@ -77,8 +66,11 @@ class ImplMeta(type):
_impls: Dict[str, Type[DefaultImpl]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
@@ -138,40 +130,6 @@ class DefaultImpl(metaclass=ImplMeta):
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def version_table_impl(
self,
*,
version_table: str,
version_table_schema: Optional[str],
version_table_pk: bool,
**kw: Any,
) -> Table:
"""Generate a :class:`.Table` object which will be used as the
structure for the Alembic version table.
Third party dialects may override this hook to provide an alternate
structure for this :class:`.Table`; requirements are only that it
be named based on the ``version_table`` parameter and contains
at least a single string-holding column named ``version_num``.
.. versionadded:: 1.14
"""
vt = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if version_table_pk:
vt.append_constraint(
PrimaryKeyConstraint(
"version_num", name=f"{version_table}_pkc"
)
)
return vt
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
@@ -203,15 +161,16 @@ class DefaultImpl(metaclass=ImplMeta):
def _exec(
self,
construct: Union[Executable, str],
execution_options: Optional[Mapping[str, Any]] = None,
multiparams: Optional[Sequence[Mapping[str, Any]]] = None,
params: Mapping[str, Any] = util.immutabledict(),
execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams is not None or params:
raise TypeError("SQL parameters not allowed with as_sql")
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
@@ -234,16 +193,11 @@ class DefaultImpl(metaclass=ImplMeta):
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
if params and multiparams is not None:
raise TypeError(
"Can't send params and multiparams at the same time"
)
if multiparams:
return conn.execute(construct, multiparams)
else:
return conn.execute(construct, params)
return conn.execute(construct, multiparams)
def execute(
self,
@@ -256,11 +210,8 @@ class DefaultImpl(metaclass=ImplMeta):
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -371,40 +322,25 @@ class DefaultImpl(metaclass=ImplMeta):
self,
table_name: str,
column: Column[Any],
*,
schema: Optional[Union[str, quoted_name]] = None,
if_not_exists: Optional[bool] = None,
) -> None:
self._exec(
base.AddColumn(
table_name,
column,
schema=schema,
if_not_exists=if_not_exists,
)
)
self._exec(base.AddColumn(table_name, column, schema=schema))
def drop_column(
self,
table_name: str,
column: Column[Any],
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
**kw,
) -> None:
self._exec(
base.DropColumn(
table_name, column, schema=schema, if_exists=if_exists
)
)
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint, **kw: Any) -> None:
self._exec(schema.DropConstraint(const, **kw))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def rename_table(
self,
@@ -416,11 +352,11 @@ class DefaultImpl(metaclass=ImplMeta):
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table, **kw: Any) -> None:
def create_table(self, table: Table) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table, **kw))
self._exec(schema.CreateTable(table))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -439,11 +375,11 @@ class DefaultImpl(metaclass=ImplMeta):
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table, **kw: Any) -> None:
def drop_table(self, table: Table) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table, **kw))
self._exec(schema.DropTable(table))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -457,7 +393,7 @@ class DefaultImpl(metaclass=ImplMeta):
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: Column[Any]) -> None:
def create_column_comment(self, column: ColumnElement[Any]) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index, **kw: Any) -> None:
@@ -476,19 +412,15 @@ class DefaultImpl(metaclass=ImplMeta):
if self.as_sql:
for row in rows:
self._exec(
table.insert()
.inline()
.values(
sqla_compat._insert_inline(table).values(
**{
k: (
sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
@@ -496,13 +428,16 @@ class DefaultImpl(metaclass=ImplMeta):
else:
if rows:
if multiinsert:
self._exec(table.insert().inline(), multiparams=rows)
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
else:
for row in rows:
self._exec(table.insert().inline().values(**row))
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
def _tokenize_column_type(self, column: Column) -> Params:
definition: str
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
@@ -517,9 +452,9 @@ class DefaultImpl(metaclass=ImplMeta):
# varchar character set utf8
#
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens: List[str] = []
term_tokens = []
paren_term = None
for token in tokens:
@@ -531,7 +466,6 @@ class DefaultImpl(metaclass=ImplMeta):
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
term: str
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
@@ -708,7 +642,7 @@ class DefaultImpl(metaclass=ImplMeta):
diff, ignored = _compare_identity_options(
metadata_identity,
inspector_identity,
schema.Identity(),
sqla_compat.Identity(),
skip={"always"},
)
@@ -730,96 +664,15 @@ class DefaultImpl(metaclass=ImplMeta):
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def _compare_index_unique(
self, metadata_index: Index, reflected_index: Index
) -> Optional[str]:
conn_unique = bool(reflected_index.unique)
meta_unique = bool(metadata_index.unique)
if conn_unique != meta_unique:
return f"unique={conn_unique} to unique={meta_unique}"
else:
return None
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
# order of col matters in an index
return tuple(col.name for col in index.columns)
def _create_metadata_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(True, self, constraint, **opts)
def _create_reflected_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(False, self, constraint, **opts)
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
"""Compare two indexes by comparing the signature generated by
``create_index_sig``.
This method returns a ``ComparisonResult``.
"""
msg: List[str] = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_sig = self._create_metadata_constraint_sig(metadata_index)
r_sig = self._create_reflected_constraint_sig(reflected_index)
assert _autogen.is_index_sig(m_sig)
assert _autogen.is_index_sig(r_sig)
# The assumption is that the index have no expression
for sig in m_sig, r_sig:
if sig.has_expressions:
log.warning(
"Generating approximate signature for index %s. "
"The dialect "
"implementation should either skip expression indexes "
"or provide a custom implementation.",
sig.const,
)
if m_sig.column_names != r_sig.column_names:
msg.append(
f"expression {r_sig.column_names} to {m_sig.column_names}"
)
if msg:
return ComparisonResult.Different(msg)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
"""Compare two unique constraints by comparing the two signatures.
The arguments are two tuples that contain the unique constraint and
the signatures generated by ``create_unique_constraint_sig``.
This method returns a ``ComparisonResult``.
"""
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
else:
return ComparisonResult.Equal()
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
# order of col does not matters in an unique constraint
return tuple(sorted([col.name for col in const.columns]))
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
@@ -844,13 +697,6 @@ class DefaultImpl(metaclass=ImplMeta):
return reflected_object.get("dialect_options", {})
class Params(NamedTuple):
token0: str
tokens: List[str]
args: List[str]
kwargs: Dict[str, str]
def _compare_identity_options(
metadata_io: Union[schema.Identity, schema.Sequence, None],
inspector_io: Union[schema.Identity, schema.Sequence, None],
@@ -889,13 +735,12 @@ def _compare_identity_options(
set(meta_d).union(insp_d),
)
if sqla_compat.identity_has_dialect_kwargs:
assert hasattr(default_io, "dialect_kwargs")
# use only the dialect kwargs in inspector_io since metadata_io
# can have options for many backends
check_dicts(
getattr(metadata_io, "dialect_kwargs", {}),
getattr(inspector_io, "dialect_kwargs", {}),
default_io.dialect_kwargs,
default_io.dialect_kwargs, # type: ignore[union-attr]
getattr(inspector_io, "dialect_kwargs", {}),
)

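The impl.py hunks above include the version_table_impl hook (present on the removed side; added in alembic 1.14 per its docstring), and ImplMeta registers every DefaultImpl subclass that declares __dialect__. A hedged sketch of how a third-party dialect might plug in, assuming an alembic version that ships the hook; the dialect name is a placeholder:

from sqlalchemy import Column, MetaData, String, Table

from alembic.ddl.impl import DefaultImpl


class ThirdPartyImpl(DefaultImpl):
    """Sketch only; 'thirdpartydb' is a made-up dialect name."""

    # ImplMeta records this class in _impls under "thirdpartydb", so alembic
    # can pick it up when migrations run against that backend.
    __dialect__ = "thirdpartydb"

    def version_table_impl(
        self,
        *,
        version_table,
        version_table_schema,
        version_table_pk,
        **kw,
    ) -> Table:
        # per the docstring above, the only hard requirements are keeping the
        # given table name and a string-holding ``version_num`` column
        return Table(
            version_table,
            MetaData(),
            Column("version_num", String(64), nullable=False),
            schema=version_table_schema,
        )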
View File: alembic/ddl/mssql.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
@@ -12,6 +9,7 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
@@ -32,7 +30,6 @@ from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
@@ -83,11 +80,10 @@ class MSSQLImpl(DefaultImpl):
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
def alter_column(
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
@@ -203,7 +199,6 @@ class MSSQLImpl(DefaultImpl):
self,
table_name: str,
column: Column[Any],
*,
schema: Optional[str] = None,
**kw,
) -> None:

View File: alembic/ddl/mysql.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
@@ -11,9 +8,7 @@ from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import elements
from sqlalchemy.sql import functions
from sqlalchemy.sql import operators
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import AlterColumn
@@ -25,16 +20,16 @@ from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..autogenerate import compare
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
@@ -51,40 +46,12 @@ class MySQLImpl(DefaultImpl):
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def render_ddl_sql_expr(
self,
expr: ClauseElement,
is_server_default: bool = False,
is_index: bool = False,
**kw: Any,
) -> str:
# apply Grouping to index expressions;
# see https://github.com/sqlalchemy/sqlalchemy/blob/
# 36da2eaf3e23269f2cf28420ae73674beafd0661/
# lib/sqlalchemy/dialects/mysql/base.py#L2191
if is_index and (
isinstance(expr, elements.BinaryExpression)
or (
isinstance(expr, elements.UnaryExpression)
and expr.modifier not in (operators.desc_op, operators.asc_op)
)
or isinstance(expr, functions.FunctionElement)
):
expr = elements.Grouping(expr)
return super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, is_index=is_index, **kw
)
def alter_column(
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -125,29 +92,21 @@ class MySQLImpl(DefaultImpl):
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=(
nullable
if nullable is not None
else (
existing_nullable
if existing_nullable is not None
else True
)
),
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=(
server_default
if server_default is not False
else existing_server_default
),
autoincrement=(
autoincrement
if autoincrement is not None
else existing_autoincrement
),
comment=(
comment if comment is not False else existing_comment
),
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif (
@@ -162,29 +121,21 @@ class MySQLImpl(DefaultImpl):
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=(
nullable
if nullable is not None
else (
existing_nullable
if existing_nullable is not None
else True
)
),
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=(
server_default
if server_default is not False
else existing_server_default
),
autoincrement=(
autoincrement
if autoincrement is not None
else existing_autoincrement
),
comment=(
comment if comment is not False else existing_comment
),
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif server_default is not False:
@@ -197,7 +148,6 @@ class MySQLImpl(DefaultImpl):
def drop_constraint(
self,
const: Constraint,
**kw: Any,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
@@ -207,11 +157,12 @@ class MySQLImpl(DefaultImpl):
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
server_default: Union[_ServerDefault, Literal[False]],
) -> bool:
return (
type_ is not None
and type_._type_affinity is sqltypes.DateTime
and type_._type_affinity # type:ignore[attr-defined]
is sqltypes.DateTime
and server_default is not None
)
@@ -321,12 +272,10 @@ class MySQLImpl(DefaultImpl):
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
for fk in conn_fks
compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
}
metadata_fk_by_sig = {
self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
for fk in metadata_fks
compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
@@ -358,7 +307,7 @@ class MySQLAlterDefault(AlterColumn):
self,
name: str,
column_name: str,
default: Optional[_ServerDefault],
default: _ServerDefault,
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
@@ -416,11 +365,9 @@ def _mysql_alter_default(
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
(
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT"
),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@@ -507,7 +454,7 @@ def _mysql_drop_constraint(
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if compiler.dialect.is_mariadb:
if _is_mariadb(compiler.dialect):
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),

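The removed render_ddl_sql_expr in the MySQL hunk above wraps functional and binary index expressions in elements.Grouping so the rendered DDL keeps the parentheses MySQL requires around expression indexes. A small standalone illustration of the effect; this is not alembic code:

from sqlalchemy import Column, Integer, MetaData, Table, func
from sqlalchemy.dialects import mysql
from sqlalchemy.sql import elements

t = Table("t", MetaData(), Column("x", Integer))

expr = func.lower(t.c.x)           # a FunctionElement used as an index expression
wrapped = elements.Grouping(expr)  # what the removed override applies for MySQL

print(expr.compile(dialect=mysql.dialect()))     # lower(t.x)
print(wrapped.compile(dialect=mysql.dialect()))  # (lower(t.x))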
View File: alembic/ddl/oracle.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
@@ -8,6 +5,7 @@ from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from .base import AddColumn
@@ -24,7 +22,6 @@ from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
@@ -141,11 +138,9 @@ def visit_column_default(
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
(
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL"
),
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL",
)

View File: alembic/ddl/postgresql.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
@@ -16,19 +13,18 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import Float
from sqlalchemy import Identity
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE
@@ -36,12 +32,12 @@ from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import ComparisonResult
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
@@ -50,8 +46,6 @@ from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
@@ -136,28 +130,25 @@ class PostgresqlImpl(DefaultImpl):
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, (Numeric, Float)):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
conn = self.connection
assert conn is not None
return not conn.scalar(
select(literal_column(conn_col_default) == metadata_default)
return not self.connection.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column(
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -223,8 +214,7 @@ class PostgresqlImpl(DefaultImpl):
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and "
"c.oid=cast(:seqname as regclass)"
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
@@ -262,60 +252,62 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
# pg behavior regarding modifiers
# | # | compiled sql | returned sql | regexp. group is removed |
# | - | ---------------- | -----------------| ------------------------ |
# | 1 | nulls first | nulls first | - |
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
# | 3 | asc | | ( asc)$ |
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
# | 5 | asc nulls last | | ( asc nulls last)$ |
# | 6 | desc | desc | - |
# | 7 | desc nulls first | desc | desc( nulls first)$ |
# | 8 | desc nulls last | desc nulls last | - |
_default_modifiers_re = ( # order of case 2 and 5 matters
re.compile("( asc nulls last)$"), # case 5
re.compile("(?<! desc)( nulls last)$"), # case 2
re.compile("( asc)$"), # case 3
re.compile("( asc) nulls first$"), # case 4
re.compile(" desc( nulls first)$"), # case 7
)
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
def _cleanup_index_expr(
self, index: Index, expr: str, remove_suffix: str
) -> str:
# start = expr
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if remove_suffix and expr.endswith(remove_suffix):
expr = expr[: -len(remove_suffix)]
# NOTE: when parsing the connection expression this cleanup could
# be skipped
for rs in self._default_modifiers_re:
if match := rs.search(expr):
start, end = match.span(1)
expr = expr[:start] + expr[end:]
break
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# strip casts
cast_re = re.compile(r"cast\s*\(")
if cast_re.match(expr):
expr = cast_re.sub("", expr)
# remove the as type
expr = re.sub(r"as\s+[^)]+\)", "", expr)
# remove spaces
expr = expr.replace(" ", "")
# print(f"START: {start} END: {expr}")
return expr
def _dialect_options(
def _default_modifiers(self, exp: ClauseElement) -> str:
to_remove = ""
while isinstance(exp, UnaryExpression):
if exp.modifier is None:
exp = exp.element
else:
op = exp.modifier
if isinstance(exp.element, UnaryExpression):
inner_op = exp.element.modifier
else:
inner_op = None
if inner_op is None:
if op == operators.asc_op:
# default is asc
to_remove = " asc"
elif op == operators.nullslast_op:
# default is nulls last
to_remove = " nulls last"
else:
if (
inner_op == operators.asc_op
and op == operators.nullslast_op
):
# default is asc nulls last
to_remove = " asc nulls last"
elif (
inner_op == operators.desc_op
and op == operators.nullsfirst_op
):
# default for desc is nulls first
to_remove = " nulls first"
break
return to_remove
def _dialect_sig(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
@@ -324,93 +316,25 @@ class PostgresqlImpl(DefaultImpl):
return ("nulls_not_distinct",)
return ()
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
msg = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_exprs = metadata_index.expressions
r_exprs = reflected_index.expressions
if len(m_exprs) != len(r_exprs):
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
if msg:
# no point going further, return early
return ComparisonResult.Different(msg)
skip = []
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
m_compile = self._compile_element(m_e)
m_text = self._cleanup_index_expr(metadata_index, m_compile)
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
r_compile = self._compile_element(r_e)
r_text = self._cleanup_index_expr(metadata_index, r_compile)
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
if m_text == r_text:
continue # expressions these are equal
elif m_compile.strip().endswith("_ops") and (
" " in m_compile or ")" in m_compile # is an expression
):
skip.append(
f"expression #{pos} {m_compile!r} detected "
"as including operator clause."
)
util.warn(
f"Expression #{pos} {m_compile!r} in index "
f"{reflected_index.name!r} detected to include "
"an operator clause. Expression compare cannot proceed. "
"Please move the operator clause to the "
"``postgresql_ops`` dict to enable proper compare "
"of the index expressions: "
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
)
else:
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
m_options = self._dialect_options(metadata_index)
r_options = self._dialect_options(reflected_index)
if m_options != r_options:
msg.extend(f"options {r_options} to {m_options}")
if msg:
return ComparisonResult.Different(msg)
elif skip:
# if there are other changes detected don't skip the index
return ComparisonResult.Skip(skip)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
return tuple(
self._cleanup_index_expr(
index,
*(
(e, "")
if isinstance(e, str)
else (self._compile_element(e), self._default_modifiers(e))
),
)
for e in index.expressions
) + self._dialect_sig(index)
metadata_do = self._dialect_options(metadata_tup.const)
conn_do = self._dialect_options(reflected_tup.const)
if metadata_do != conn_do:
return ComparisonResult.Different(
f"expression {conn_do} to {metadata_do}"
)
return ComparisonResult.Equal()
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
return tuple(
sorted([col.name for col in const.columns])
) + self._dialect_sig(const)
def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
@@ -421,9 +345,7 @@ class PostgresqlImpl(DefaultImpl):
options.pop("postgresql_include", None)
return options
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
if isinstance(element, str):
return element
def _compile_element(self, element: ClauseElement) -> str:
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
@@ -590,7 +512,7 @@ def visit_identity_column(
)
else:
text += "SET %s " % compiler.get_identity_options(
Identity(**{attr: getattr(identity, attr)})
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@@ -634,8 +556,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
return cls(
constraint.name,
constraint_table.name,
[ # type: ignore
(expr, op) for expr, name, op in constraint._render_exprs
[
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
@@ -662,7 +585,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
expr,
name,
oper,
) in excl._render_exprs:
) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@@ -720,7 +643,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
constraint_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
@@ -792,13 +715,10 @@ def _exclude_constraint(
args = [
"(%s, %r)"
% (
_render_potential_column(
sqltext, # type:ignore[arg-type]
autogen_context,
),
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
@@ -850,5 +770,5 @@ def _render_potential_column(
return render._render_potential_expr(
value,
autogen_context,
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)

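The _default_modifiers_re table quoted in the PostgreSQL hunk above strips the ordering modifiers that PostgreSQL does not echo back when an index is reflected, so metadata and reflected expressions compare equal. A standalone illustration using the same patterns; the helper name below is made up for the example:

import re

# the same patterns quoted above; case 5 must be tried before case 2
_default_modifiers_re = (
    re.compile("( asc nulls last)$"),        # case 5
    re.compile("(?<! desc)( nulls last)$"),  # case 2
    re.compile("( asc)$"),                   # case 3
    re.compile("( asc) nulls first$"),       # case 4
    re.compile(" desc( nulls first)$"),      # case 7
)


def strip_default_modifiers(expr: str) -> str:
    # drop only the modifier spelling that PostgreSQL would not echo back
    for rs in _default_modifiers_re:
        if (match := rs.search(expr)) is not None:
            start, end = match.span(1)
            expr = expr[:start] + expr[end:]
            break
    return expr


for sample in ("x asc nulls last", "x nulls last", "x asc nulls first",
               "x desc nulls first", "x desc nulls last"):
    print(f"{sample!r} -> {strip_default_modifiers(sample)!r}")
# 'x asc nulls last'   -> 'x'
# 'x nulls last'       -> 'x'
# 'x asc nulls first'  -> 'x nulls first'
# 'x desc nulls first' -> 'x desc'
# 'x desc nulls last'  -> 'x desc nulls last'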
View File: alembic/ddl/sqlite.py

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
@@ -11,19 +8,16 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import Computed
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import ColumnName
from .base import format_column_name
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
@@ -65,7 +59,7 @@ class SQLiteImpl(DefaultImpl):
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, Computed)
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
@@ -77,13 +71,13 @@ class SQLiteImpl(DefaultImpl):
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None:
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self):
elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
@@ -91,8 +85,8 @@ class SQLiteImpl(DefaultImpl):
"SQLite migrations using a copy-and-move strategy."
)
def drop_constraint(self, const: Constraint, **kw: Any):
if const._create_rule is None:
def drop_constraint(self, const: Constraint):
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
@@ -183,7 +177,8 @@ class SQLiteImpl(DefaultImpl):
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity is not new_type._type_affinity
existing.type._type_affinity # type:ignore[attr-defined]
is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
@@ -210,15 +205,6 @@ def visit_rename_table(
)
@compiles(ColumnName, "sqlite")
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (