This commit is contained in:
@@ -3,4 +3,4 @@ from . import mysql
|
||||
from . import oracle
|
||||
from . import postgresql
|
||||
from . import sqlite
|
||||
from .impl import DefaultImpl as DefaultImpl
|
||||
from .impl import DefaultImpl
|
||||
|
||||
@@ -1,329 +0,0 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
from typing import Dict
|
||||
from typing import Generic
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.sql.schema import Constraint
|
||||
from sqlalchemy.sql.schema import ForeignKeyConstraint
|
||||
from sqlalchemy.sql.schema import Index
|
||||
from sqlalchemy.sql.schema import UniqueConstraint
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from .. import util
|
||||
from ..util import sqla_compat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
from alembic.autogenerate.api import AutogenContext
|
||||
from alembic.ddl.impl import DefaultImpl
|
||||
|
||||
CompareConstraintType = Union[Constraint, Index]
|
||||
|
||||
_C = TypeVar("_C", bound=CompareConstraintType)
|
||||
|
||||
_clsreg: Dict[str, Type[_constraint_sig]] = {}
|
||||
|
||||
|
||||
class ComparisonResult(NamedTuple):
|
||||
status: Literal["equal", "different", "skip"]
|
||||
message: str
|
||||
|
||||
@property
|
||||
def is_equal(self) -> bool:
|
||||
return self.status == "equal"
|
||||
|
||||
@property
|
||||
def is_different(self) -> bool:
|
||||
return self.status == "different"
|
||||
|
||||
@property
|
||||
def is_skip(self) -> bool:
|
||||
return self.status == "skip"
|
||||
|
||||
@classmethod
|
||||
def Equal(cls) -> ComparisonResult:
|
||||
"""the constraints are equal."""
|
||||
return cls("equal", "The two constraints are equal")
|
||||
|
||||
@classmethod
|
||||
def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
|
||||
"""the constraints are different for the provided reason(s)."""
|
||||
return cls("different", ", ".join(util.to_list(reason)))
|
||||
|
||||
@classmethod
|
||||
def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
|
||||
"""the constraint cannot be compared for the provided reason(s).
|
||||
|
||||
The message is logged, but the constraints will be otherwise
|
||||
considered equal, meaning that no migration command will be
|
||||
generated.
|
||||
"""
|
||||
return cls("skip", ", ".join(util.to_list(reason)))
|
||||
|
||||
|
||||
class _constraint_sig(Generic[_C]):
|
||||
const: _C
|
||||
|
||||
_sig: Tuple[Any, ...]
|
||||
name: Optional[sqla_compat._ConstraintNameDefined]
|
||||
|
||||
impl: DefaultImpl
|
||||
|
||||
_is_index: ClassVar[bool] = False
|
||||
_is_fk: ClassVar[bool] = False
|
||||
_is_uq: ClassVar[bool] = False
|
||||
|
||||
_is_metadata: bool
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
cls._register()
|
||||
|
||||
@classmethod
|
||||
def _register(cls):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __init__(
|
||||
self, is_metadata: bool, impl: DefaultImpl, const: _C
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def compare_to_reflected(
|
||||
self, other: _constraint_sig[Any]
|
||||
) -> ComparisonResult:
|
||||
assert self.impl is other.impl
|
||||
assert self._is_metadata
|
||||
assert not other._is_metadata
|
||||
|
||||
return self._compare_to_reflected(other)
|
||||
|
||||
def _compare_to_reflected(
|
||||
self, other: _constraint_sig[_C]
|
||||
) -> ComparisonResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_constraint(
|
||||
cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
|
||||
) -> _constraint_sig[_C]:
|
||||
# these could be cached by constraint/impl, however, if the
|
||||
# constraint is modified in place, then the sig is wrong. the mysql
|
||||
# impl currently does this, and if we fixed that we can't be sure
|
||||
# someone else might do it too, so play it safe.
|
||||
sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
|
||||
return sig
|
||||
|
||||
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
|
||||
return sqla_compat._get_constraint_final_name(
|
||||
self.const, context.dialect
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def is_named(self):
|
||||
return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
|
||||
|
||||
@util.memoized_property
|
||||
def unnamed(self) -> Tuple[Any, ...]:
|
||||
return self._sig
|
||||
|
||||
@util.memoized_property
|
||||
def unnamed_no_options(self) -> Tuple[Any, ...]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@util.memoized_property
|
||||
def _full_sig(self) -> Tuple[Any, ...]:
|
||||
return (self.name,) + self.unnamed
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return self._full_sig == other._full_sig
|
||||
|
||||
def __ne__(self, other) -> bool:
|
||||
return self._full_sig != other._full_sig
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._full_sig)
|
||||
|
||||
|
||||
class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
|
||||
_is_uq = True
|
||||
|
||||
@classmethod
|
||||
def _register(cls) -> None:
|
||||
_clsreg["unique_constraint"] = cls
|
||||
|
||||
is_unique = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_metadata: bool,
|
||||
impl: DefaultImpl,
|
||||
const: UniqueConstraint,
|
||||
) -> None:
|
||||
self.impl = impl
|
||||
self.const = const
|
||||
self.name = sqla_compat.constraint_name_or_none(const.name)
|
||||
self._sig = tuple(sorted([col.name for col in const.columns]))
|
||||
self._is_metadata = is_metadata
|
||||
|
||||
@property
|
||||
def column_names(self) -> Tuple[str, ...]:
|
||||
return tuple([col.name for col in self.const.columns])
|
||||
|
||||
def _compare_to_reflected(
|
||||
self, other: _constraint_sig[_C]
|
||||
) -> ComparisonResult:
|
||||
assert self._is_metadata
|
||||
metadata_obj = self
|
||||
conn_obj = other
|
||||
|
||||
assert is_uq_sig(conn_obj)
|
||||
return self.impl.compare_unique_constraint(
|
||||
metadata_obj.const, conn_obj.const
|
||||
)
|
||||
|
||||
|
||||
class _ix_constraint_sig(_constraint_sig[Index]):
|
||||
_is_index = True
|
||||
|
||||
name: sqla_compat._ConstraintName
|
||||
|
||||
@classmethod
|
||||
def _register(cls) -> None:
|
||||
_clsreg["index"] = cls
|
||||
|
||||
def __init__(
|
||||
self, is_metadata: bool, impl: DefaultImpl, const: Index
|
||||
) -> None:
|
||||
self.impl = impl
|
||||
self.const = const
|
||||
self.name = const.name
|
||||
self.is_unique = bool(const.unique)
|
||||
self._is_metadata = is_metadata
|
||||
|
||||
def _compare_to_reflected(
|
||||
self, other: _constraint_sig[_C]
|
||||
) -> ComparisonResult:
|
||||
assert self._is_metadata
|
||||
metadata_obj = self
|
||||
conn_obj = other
|
||||
|
||||
assert is_index_sig(conn_obj)
|
||||
return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
|
||||
|
||||
@util.memoized_property
|
||||
def has_expressions(self):
|
||||
return sqla_compat.is_expression_index(self.const)
|
||||
|
||||
@util.memoized_property
|
||||
def column_names(self) -> Tuple[str, ...]:
|
||||
return tuple([col.name for col in self.const.columns])
|
||||
|
||||
@util.memoized_property
|
||||
def column_names_optional(self) -> Tuple[Optional[str], ...]:
|
||||
return tuple(
|
||||
[getattr(col, "name", None) for col in self.const.expressions]
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def is_named(self):
|
||||
return True
|
||||
|
||||
@util.memoized_property
|
||||
def unnamed(self):
|
||||
return (self.is_unique,) + self.column_names_optional
|
||||
|
||||
|
||||
class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
|
||||
_is_fk = True
|
||||
|
||||
@classmethod
|
||||
def _register(cls) -> None:
|
||||
_clsreg["foreign_key_constraint"] = cls
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_metadata: bool,
|
||||
impl: DefaultImpl,
|
||||
const: ForeignKeyConstraint,
|
||||
) -> None:
|
||||
self._is_metadata = is_metadata
|
||||
|
||||
self.impl = impl
|
||||
self.const = const
|
||||
|
||||
self.name = sqla_compat.constraint_name_or_none(const.name)
|
||||
|
||||
(
|
||||
self.source_schema,
|
||||
self.source_table,
|
||||
self.source_columns,
|
||||
self.target_schema,
|
||||
self.target_table,
|
||||
self.target_columns,
|
||||
onupdate,
|
||||
ondelete,
|
||||
deferrable,
|
||||
initially,
|
||||
) = sqla_compat._fk_spec(const)
|
||||
|
||||
self._sig: Tuple[Any, ...] = (
|
||||
self.source_schema,
|
||||
self.source_table,
|
||||
tuple(self.source_columns),
|
||||
self.target_schema,
|
||||
self.target_table,
|
||||
tuple(self.target_columns),
|
||||
) + (
|
||||
(
|
||||
(None if onupdate.lower() == "no action" else onupdate.lower())
|
||||
if onupdate
|
||||
else None
|
||||
),
|
||||
(
|
||||
(None if ondelete.lower() == "no action" else ondelete.lower())
|
||||
if ondelete
|
||||
else None
|
||||
),
|
||||
# convert initially + deferrable into one three-state value
|
||||
(
|
||||
"initially_deferrable"
|
||||
if initially and initially.lower() == "deferred"
|
||||
else "deferrable" if deferrable else "not deferrable"
|
||||
),
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def unnamed_no_options(self):
|
||||
return (
|
||||
self.source_schema,
|
||||
self.source_table,
|
||||
tuple(self.source_columns),
|
||||
self.target_schema,
|
||||
self.target_table,
|
||||
tuple(self.target_columns),
|
||||
)
|
||||
|
||||
|
||||
def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
|
||||
return sig._is_index
|
||||
|
||||
|
||||
def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
|
||||
return sig._is_uq
|
||||
|
||||
|
||||
def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
|
||||
return sig._is_fk
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
@@ -25,8 +22,6 @@ from ..util.sqla_compat import _table_for_constraint # noqa
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Computed
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy.sql.compiler import Compiled
|
||||
from sqlalchemy.sql.compiler import DDLCompiler
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
@@ -35,11 +30,14 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from .impl import DefaultImpl
|
||||
from ..util.sqla_compat import Computed
|
||||
from ..util.sqla_compat import Identity
|
||||
|
||||
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
|
||||
|
||||
|
||||
class AlterTable(DDLElement):
|
||||
|
||||
"""Represent an ALTER TABLE statement.
|
||||
|
||||
Only the string name and optional schema name of the table
|
||||
@@ -154,24 +152,17 @@ class AddColumn(AlterTable):
|
||||
name: str,
|
||||
column: Column[Any],
|
||||
schema: Optional[Union[quoted_name, str]] = None,
|
||||
if_not_exists: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__(name, schema=schema)
|
||||
self.column = column
|
||||
self.if_not_exists = if_not_exists
|
||||
|
||||
|
||||
class DropColumn(AlterTable):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
column: Column[Any],
|
||||
schema: Optional[str] = None,
|
||||
if_exists: Optional[bool] = None,
|
||||
self, name: str, column: Column[Any], schema: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(name, schema=schema)
|
||||
self.column = column
|
||||
self.if_exists = if_exists
|
||||
|
||||
|
||||
class ColumnComment(AlterColumn):
|
||||
@@ -196,9 +187,7 @@ def visit_rename_table(
|
||||
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
|
||||
return "%s %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
add_column(
|
||||
compiler, element.column, if_not_exists=element.if_not_exists, **kw
|
||||
),
|
||||
add_column(compiler, element.column, **kw),
|
||||
)
|
||||
|
||||
|
||||
@@ -206,9 +195,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
|
||||
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
|
||||
return "%s %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
drop_column(
|
||||
compiler, element.column.name, if_exists=element.if_exists, **kw
|
||||
),
|
||||
drop_column(compiler, element.column.name, **kw),
|
||||
)
|
||||
|
||||
|
||||
@@ -248,11 +235,9 @@ def visit_column_default(
|
||||
return "%s %s %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
alter_column(compiler, element.column_name),
|
||||
(
|
||||
"SET DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DROP DEFAULT"
|
||||
),
|
||||
"SET DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DROP DEFAULT",
|
||||
)
|
||||
|
||||
|
||||
@@ -310,13 +295,9 @@ def format_server_default(
|
||||
compiler: DDLCompiler,
|
||||
default: Optional[_ServerDefault],
|
||||
) -> str:
|
||||
# this can be updated to use compiler.render_default_string
|
||||
# for SQLAlchemy 2.0 and above; not in 1.4
|
||||
default_str = compiler.get_column_default_string(
|
||||
return compiler.get_column_default_string(
|
||||
Column("x", Integer, server_default=default)
|
||||
)
|
||||
assert default_str is not None
|
||||
return default_str
|
||||
|
||||
|
||||
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
|
||||
@@ -331,29 +312,16 @@ def alter_table(
|
||||
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
|
||||
|
||||
|
||||
def drop_column(
|
||||
compiler: DDLCompiler, name: str, if_exists: Optional[bool] = None, **kw
|
||||
) -> str:
|
||||
return "DROP COLUMN %s%s" % (
|
||||
"IF EXISTS " if if_exists else "",
|
||||
format_column_name(compiler, name),
|
||||
)
|
||||
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
|
||||
return "DROP COLUMN %s" % format_column_name(compiler, name)
|
||||
|
||||
|
||||
def alter_column(compiler: DDLCompiler, name: str) -> str:
|
||||
return "ALTER COLUMN %s" % format_column_name(compiler, name)
|
||||
|
||||
|
||||
def add_column(
|
||||
compiler: DDLCompiler,
|
||||
column: Column[Any],
|
||||
if_not_exists: Optional[bool] = None,
|
||||
**kw,
|
||||
) -> str:
|
||||
text = "ADD COLUMN %s%s" % (
|
||||
"IF NOT EXISTS " if if_not_exists else "",
|
||||
compiler.get_column_specification(column, **kw),
|
||||
)
|
||||
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
|
||||
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
|
||||
|
||||
const = " ".join(
|
||||
compiler.process(constraint) for constraint in column.constraints
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
@@ -11,7 +8,6 @@ from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Set
|
||||
@@ -21,18 +17,10 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
from sqlalchemy import schema
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import text
|
||||
|
||||
from . import _autogen
|
||||
from . import base
|
||||
from ._autogen import _constraint_sig as _constraint_sig
|
||||
from ._autogen import ComparisonResult as ComparisonResult
|
||||
from .. import util
|
||||
from ..util import sqla_compat
|
||||
|
||||
@@ -46,10 +34,13 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import ClauseElement
|
||||
from sqlalchemy.sql import Executable
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import quoted_name
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.schema import Constraint
|
||||
from sqlalchemy.sql.schema import ForeignKeyConstraint
|
||||
from sqlalchemy.sql.schema import Index
|
||||
from sqlalchemy.sql.schema import Table
|
||||
from sqlalchemy.sql.schema import UniqueConstraint
|
||||
from sqlalchemy.sql.selectable import TableClause
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
@@ -59,8 +50,6 @@ if TYPE_CHECKING:
|
||||
from ..operations.batch import ApplyBatchImpl
|
||||
from ..operations.batch import BatchOperationsImpl
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplMeta(type):
|
||||
def __init__(
|
||||
@@ -77,8 +66,11 @@ class ImplMeta(type):
|
||||
|
||||
_impls: Dict[str, Type[DefaultImpl]] = {}
|
||||
|
||||
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
|
||||
|
||||
|
||||
class DefaultImpl(metaclass=ImplMeta):
|
||||
|
||||
"""Provide the entrypoint for major migration operations,
|
||||
including database-specific behavioral variances.
|
||||
|
||||
@@ -138,40 +130,6 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self.output_buffer.write(text + "\n\n")
|
||||
self.output_buffer.flush()
|
||||
|
||||
def version_table_impl(
|
||||
self,
|
||||
*,
|
||||
version_table: str,
|
||||
version_table_schema: Optional[str],
|
||||
version_table_pk: bool,
|
||||
**kw: Any,
|
||||
) -> Table:
|
||||
"""Generate a :class:`.Table` object which will be used as the
|
||||
structure for the Alembic version table.
|
||||
|
||||
Third party dialects may override this hook to provide an alternate
|
||||
structure for this :class:`.Table`; requirements are only that it
|
||||
be named based on the ``version_table`` parameter and contains
|
||||
at least a single string-holding column named ``version_num``.
|
||||
|
||||
.. versionadded:: 1.14
|
||||
|
||||
"""
|
||||
vt = Table(
|
||||
version_table,
|
||||
MetaData(),
|
||||
Column("version_num", String(32), nullable=False),
|
||||
schema=version_table_schema,
|
||||
)
|
||||
if version_table_pk:
|
||||
vt.append_constraint(
|
||||
PrimaryKeyConstraint(
|
||||
"version_num", name=f"{version_table}_pkc"
|
||||
)
|
||||
)
|
||||
|
||||
return vt
|
||||
|
||||
def requires_recreate_in_batch(
|
||||
self, batch_op: BatchOperationsImpl
|
||||
) -> bool:
|
||||
@@ -203,15 +161,16 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
def _exec(
|
||||
self,
|
||||
construct: Union[Executable, str],
|
||||
execution_options: Optional[Mapping[str, Any]] = None,
|
||||
multiparams: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
params: Mapping[str, Any] = util.immutabledict(),
|
||||
execution_options: Optional[dict[str, Any]] = None,
|
||||
multiparams: Sequence[dict] = (),
|
||||
params: Dict[str, Any] = util.immutabledict(),
|
||||
) -> Optional[CursorResult]:
|
||||
if isinstance(construct, str):
|
||||
construct = text(construct)
|
||||
if self.as_sql:
|
||||
if multiparams is not None or params:
|
||||
raise TypeError("SQL parameters not allowed with as_sql")
|
||||
if multiparams or params:
|
||||
# TODO: coverage
|
||||
raise Exception("Execution arguments not allowed with as_sql")
|
||||
|
||||
compile_kw: dict[str, Any]
|
||||
if self.literal_binds and not isinstance(
|
||||
@@ -234,16 +193,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
assert conn is not None
|
||||
if execution_options:
|
||||
conn = conn.execution_options(**execution_options)
|
||||
if params:
|
||||
assert isinstance(multiparams, tuple)
|
||||
multiparams += (params,)
|
||||
|
||||
if params and multiparams is not None:
|
||||
raise TypeError(
|
||||
"Can't send params and multiparams at the same time"
|
||||
)
|
||||
|
||||
if multiparams:
|
||||
return conn.execute(construct, multiparams)
|
||||
else:
|
||||
return conn.execute(construct, params)
|
||||
return conn.execute(construct, multiparams)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
@@ -256,11 +210,8 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
] = False,
|
||||
server_default: Union[_ServerDefault, Literal[False]] = False,
|
||||
name: Optional[str] = None,
|
||||
type_: Optional[TypeEngine] = None,
|
||||
schema: Optional[str] = None,
|
||||
@@ -371,40 +322,25 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self,
|
||||
table_name: str,
|
||||
column: Column[Any],
|
||||
*,
|
||||
schema: Optional[Union[str, quoted_name]] = None,
|
||||
if_not_exists: Optional[bool] = None,
|
||||
) -> None:
|
||||
self._exec(
|
||||
base.AddColumn(
|
||||
table_name,
|
||||
column,
|
||||
schema=schema,
|
||||
if_not_exists=if_not_exists,
|
||||
)
|
||||
)
|
||||
self._exec(base.AddColumn(table_name, column, schema=schema))
|
||||
|
||||
def drop_column(
|
||||
self,
|
||||
table_name: str,
|
||||
column: Column[Any],
|
||||
*,
|
||||
schema: Optional[str] = None,
|
||||
if_exists: Optional[bool] = None,
|
||||
**kw,
|
||||
) -> None:
|
||||
self._exec(
|
||||
base.DropColumn(
|
||||
table_name, column, schema=schema, if_exists=if_exists
|
||||
)
|
||||
)
|
||||
self._exec(base.DropColumn(table_name, column, schema=schema))
|
||||
|
||||
def add_constraint(self, const: Any) -> None:
|
||||
if const._create_rule is None or const._create_rule(self):
|
||||
self._exec(schema.AddConstraint(const))
|
||||
|
||||
def drop_constraint(self, const: Constraint, **kw: Any) -> None:
|
||||
self._exec(schema.DropConstraint(const, **kw))
|
||||
def drop_constraint(self, const: Constraint) -> None:
|
||||
self._exec(schema.DropConstraint(const))
|
||||
|
||||
def rename_table(
|
||||
self,
|
||||
@@ -416,11 +352,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
base.RenameTable(old_table_name, new_table_name, schema=schema)
|
||||
)
|
||||
|
||||
def create_table(self, table: Table, **kw: Any) -> None:
|
||||
def create_table(self, table: Table) -> None:
|
||||
table.dispatch.before_create(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
self._exec(schema.CreateTable(table, **kw))
|
||||
self._exec(schema.CreateTable(table))
|
||||
table.dispatch.after_create(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
@@ -439,11 +375,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
if comment and with_comment:
|
||||
self.create_column_comment(column)
|
||||
|
||||
def drop_table(self, table: Table, **kw: Any) -> None:
|
||||
def drop_table(self, table: Table) -> None:
|
||||
table.dispatch.before_drop(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
self._exec(schema.DropTable(table, **kw))
|
||||
self._exec(schema.DropTable(table))
|
||||
table.dispatch.after_drop(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
@@ -457,7 +393,7 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
def drop_table_comment(self, table: Table) -> None:
|
||||
self._exec(schema.DropTableComment(table))
|
||||
|
||||
def create_column_comment(self, column: Column[Any]) -> None:
|
||||
def create_column_comment(self, column: ColumnElement[Any]) -> None:
|
||||
self._exec(schema.SetColumnComment(column))
|
||||
|
||||
def drop_index(self, index: Index, **kw: Any) -> None:
|
||||
@@ -476,19 +412,15 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
if self.as_sql:
|
||||
for row in rows:
|
||||
self._exec(
|
||||
table.insert()
|
||||
.inline()
|
||||
.values(
|
||||
sqla_compat._insert_inline(table).values(
|
||||
**{
|
||||
k: (
|
||||
sqla_compat._literal_bindparam(
|
||||
k, v, type_=table.c[k].type
|
||||
)
|
||||
if not isinstance(
|
||||
v, sqla_compat._literal_bindparam
|
||||
)
|
||||
else v
|
||||
k: sqla_compat._literal_bindparam(
|
||||
k, v, type_=table.c[k].type
|
||||
)
|
||||
if not isinstance(
|
||||
v, sqla_compat._literal_bindparam
|
||||
)
|
||||
else v
|
||||
for k, v in row.items()
|
||||
}
|
||||
)
|
||||
@@ -496,13 +428,16 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
else:
|
||||
if rows:
|
||||
if multiinsert:
|
||||
self._exec(table.insert().inline(), multiparams=rows)
|
||||
self._exec(
|
||||
sqla_compat._insert_inline(table), multiparams=rows
|
||||
)
|
||||
else:
|
||||
for row in rows:
|
||||
self._exec(table.insert().inline().values(**row))
|
||||
self._exec(
|
||||
sqla_compat._insert_inline(table).values(**row)
|
||||
)
|
||||
|
||||
def _tokenize_column_type(self, column: Column) -> Params:
|
||||
definition: str
|
||||
definition = self.dialect.type_compiler.process(column.type).lower()
|
||||
|
||||
# tokenize the SQLAlchemy-generated version of a type, so that
|
||||
@@ -517,9 +452,9 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
# varchar character set utf8
|
||||
#
|
||||
|
||||
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
|
||||
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
|
||||
|
||||
term_tokens: List[str] = []
|
||||
term_tokens = []
|
||||
paren_term = None
|
||||
|
||||
for token in tokens:
|
||||
@@ -531,7 +466,6 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
params = Params(term_tokens[0], term_tokens[1:], [], {})
|
||||
|
||||
if paren_term:
|
||||
term: str
|
||||
for term in re.findall("[^(),]+", paren_term):
|
||||
if "=" in term:
|
||||
key, val = term.split("=")
|
||||
@@ -708,7 +642,7 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
diff, ignored = _compare_identity_options(
|
||||
metadata_identity,
|
||||
inspector_identity,
|
||||
schema.Identity(),
|
||||
sqla_compat.Identity(),
|
||||
skip={"always"},
|
||||
)
|
||||
|
||||
@@ -730,96 +664,15 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
|
||||
)
|
||||
|
||||
def _compare_index_unique(
|
||||
self, metadata_index: Index, reflected_index: Index
|
||||
) -> Optional[str]:
|
||||
conn_unique = bool(reflected_index.unique)
|
||||
meta_unique = bool(metadata_index.unique)
|
||||
if conn_unique != meta_unique:
|
||||
return f"unique={conn_unique} to unique={meta_unique}"
|
||||
else:
|
||||
return None
|
||||
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
|
||||
# order of col matters in an index
|
||||
return tuple(col.name for col in index.columns)
|
||||
|
||||
def _create_metadata_constraint_sig(
|
||||
self, constraint: _autogen._C, **opts: Any
|
||||
) -> _constraint_sig[_autogen._C]:
|
||||
return _constraint_sig.from_constraint(True, self, constraint, **opts)
|
||||
|
||||
def _create_reflected_constraint_sig(
|
||||
self, constraint: _autogen._C, **opts: Any
|
||||
) -> _constraint_sig[_autogen._C]:
|
||||
return _constraint_sig.from_constraint(False, self, constraint, **opts)
|
||||
|
||||
def compare_indexes(
|
||||
self,
|
||||
metadata_index: Index,
|
||||
reflected_index: Index,
|
||||
) -> ComparisonResult:
|
||||
"""Compare two indexes by comparing the signature generated by
|
||||
``create_index_sig``.
|
||||
|
||||
This method returns a ``ComparisonResult``.
|
||||
"""
|
||||
msg: List[str] = []
|
||||
unique_msg = self._compare_index_unique(
|
||||
metadata_index, reflected_index
|
||||
)
|
||||
if unique_msg:
|
||||
msg.append(unique_msg)
|
||||
m_sig = self._create_metadata_constraint_sig(metadata_index)
|
||||
r_sig = self._create_reflected_constraint_sig(reflected_index)
|
||||
|
||||
assert _autogen.is_index_sig(m_sig)
|
||||
assert _autogen.is_index_sig(r_sig)
|
||||
|
||||
# The assumption is that the index have no expression
|
||||
for sig in m_sig, r_sig:
|
||||
if sig.has_expressions:
|
||||
log.warning(
|
||||
"Generating approximate signature for index %s. "
|
||||
"The dialect "
|
||||
"implementation should either skip expression indexes "
|
||||
"or provide a custom implementation.",
|
||||
sig.const,
|
||||
)
|
||||
|
||||
if m_sig.column_names != r_sig.column_names:
|
||||
msg.append(
|
||||
f"expression {r_sig.column_names} to {m_sig.column_names}"
|
||||
)
|
||||
|
||||
if msg:
|
||||
return ComparisonResult.Different(msg)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def compare_unique_constraint(
|
||||
self,
|
||||
metadata_constraint: UniqueConstraint,
|
||||
reflected_constraint: UniqueConstraint,
|
||||
) -> ComparisonResult:
|
||||
"""Compare two unique constraints by comparing the two signatures.
|
||||
|
||||
The arguments are two tuples that contain the unique constraint and
|
||||
the signatures generated by ``create_unique_constraint_sig``.
|
||||
|
||||
This method returns a ``ComparisonResult``.
|
||||
"""
|
||||
metadata_tup = self._create_metadata_constraint_sig(
|
||||
metadata_constraint
|
||||
)
|
||||
reflected_tup = self._create_reflected_constraint_sig(
|
||||
reflected_constraint
|
||||
)
|
||||
|
||||
meta_sig = metadata_tup.unnamed
|
||||
conn_sig = reflected_tup.unnamed
|
||||
if conn_sig != meta_sig:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_sig} to {meta_sig}"
|
||||
)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
def create_unique_constraint_sig(
|
||||
self, const: UniqueConstraint
|
||||
) -> Tuple[Any, ...]:
|
||||
# order of col does not matters in an unique constraint
|
||||
return tuple(sorted([col.name for col in const.columns]))
|
||||
|
||||
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
|
||||
conn_indexes_by_name = {c.name: c for c in conn_indexes}
|
||||
@@ -844,13 +697,6 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
return reflected_object.get("dialect_options", {})
|
||||
|
||||
|
||||
class Params(NamedTuple):
|
||||
token0: str
|
||||
tokens: List[str]
|
||||
args: List[str]
|
||||
kwargs: Dict[str, str]
|
||||
|
||||
|
||||
def _compare_identity_options(
|
||||
metadata_io: Union[schema.Identity, schema.Sequence, None],
|
||||
inspector_io: Union[schema.Identity, schema.Sequence, None],
|
||||
@@ -889,13 +735,12 @@ def _compare_identity_options(
|
||||
set(meta_d).union(insp_d),
|
||||
)
|
||||
if sqla_compat.identity_has_dialect_kwargs:
|
||||
assert hasattr(default_io, "dialect_kwargs")
|
||||
# use only the dialect kwargs in inspector_io since metadata_io
|
||||
# can have options for many backends
|
||||
check_dicts(
|
||||
getattr(metadata_io, "dialect_kwargs", {}),
|
||||
getattr(inspector_io, "dialect_kwargs", {}),
|
||||
default_io.dialect_kwargs,
|
||||
default_io.dialect_kwargs, # type: ignore[union-attr]
|
||||
getattr(inspector_io, "dialect_kwargs", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
@@ -12,6 +9,7 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.schema import Column
|
||||
from sqlalchemy.schema import CreateIndex
|
||||
from sqlalchemy.sql.base import Executable
|
||||
@@ -32,7 +30,6 @@ from .base import RenameTable
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
@@ -83,11 +80,10 @@ class MSSQLImpl(DefaultImpl):
|
||||
if self.as_sql and self.batch_separator:
|
||||
self.static_output(self.batch_separator)
|
||||
|
||||
def alter_column(
|
||||
def alter_column( # type:ignore[override]
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
@@ -203,7 +199,6 @@ class MSSQLImpl(DefaultImpl):
|
||||
self,
|
||||
table_name: str,
|
||||
column: Column[Any],
|
||||
*,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> None:
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
@@ -11,9 +8,7 @@ from typing import Union
|
||||
|
||||
from sqlalchemy import schema
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.sql import elements
|
||||
from sqlalchemy.sql import functions
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
|
||||
from .base import alter_table
|
||||
from .base import AlterColumn
|
||||
@@ -25,16 +20,16 @@ from .base import format_column_name
|
||||
from .base import format_server_default
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..autogenerate import compare
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import _is_mariadb
|
||||
from ..util.sqla_compat import _is_type_bound
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
|
||||
from sqlalchemy.sql.ddl import DropConstraint
|
||||
from sqlalchemy.sql.elements import ClauseElement
|
||||
from sqlalchemy.sql.schema import Constraint
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
@@ -51,40 +46,12 @@ class MySQLImpl(DefaultImpl):
|
||||
)
|
||||
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
|
||||
|
||||
def render_ddl_sql_expr(
|
||||
self,
|
||||
expr: ClauseElement,
|
||||
is_server_default: bool = False,
|
||||
is_index: bool = False,
|
||||
**kw: Any,
|
||||
) -> str:
|
||||
# apply Grouping to index expressions;
|
||||
# see https://github.com/sqlalchemy/sqlalchemy/blob/
|
||||
# 36da2eaf3e23269f2cf28420ae73674beafd0661/
|
||||
# lib/sqlalchemy/dialects/mysql/base.py#L2191
|
||||
if is_index and (
|
||||
isinstance(expr, elements.BinaryExpression)
|
||||
or (
|
||||
isinstance(expr, elements.UnaryExpression)
|
||||
and expr.modifier not in (operators.desc_op, operators.asc_op)
|
||||
)
|
||||
or isinstance(expr, functions.FunctionElement)
|
||||
):
|
||||
expr = elements.Grouping(expr)
|
||||
|
||||
return super().render_ddl_sql_expr(
|
||||
expr, is_server_default=is_server_default, is_index=is_index, **kw
|
||||
)
|
||||
|
||||
def alter_column(
|
||||
def alter_column( # type:ignore[override]
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
] = False,
|
||||
server_default: Union[_ServerDefault, Literal[False]] = False,
|
||||
name: Optional[str] = None,
|
||||
type_: Optional[TypeEngine] = None,
|
||||
schema: Optional[str] = None,
|
||||
@@ -125,29 +92,21 @@ class MySQLImpl(DefaultImpl):
|
||||
column_name,
|
||||
schema=schema,
|
||||
newname=name if name is not None else column_name,
|
||||
nullable=(
|
||||
nullable
|
||||
if nullable is not None
|
||||
else (
|
||||
existing_nullable
|
||||
if existing_nullable is not None
|
||||
else True
|
||||
)
|
||||
),
|
||||
nullable=nullable
|
||||
if nullable is not None
|
||||
else existing_nullable
|
||||
if existing_nullable is not None
|
||||
else True,
|
||||
type_=type_ if type_ is not None else existing_type,
|
||||
default=(
|
||||
server_default
|
||||
if server_default is not False
|
||||
else existing_server_default
|
||||
),
|
||||
autoincrement=(
|
||||
autoincrement
|
||||
if autoincrement is not None
|
||||
else existing_autoincrement
|
||||
),
|
||||
comment=(
|
||||
comment if comment is not False else existing_comment
|
||||
),
|
||||
default=server_default
|
||||
if server_default is not False
|
||||
else existing_server_default,
|
||||
autoincrement=autoincrement
|
||||
if autoincrement is not None
|
||||
else existing_autoincrement,
|
||||
comment=comment
|
||||
if comment is not False
|
||||
else existing_comment,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
@@ -162,29 +121,21 @@ class MySQLImpl(DefaultImpl):
|
||||
column_name,
|
||||
schema=schema,
|
||||
newname=name if name is not None else column_name,
|
||||
nullable=(
|
||||
nullable
|
||||
if nullable is not None
|
||||
else (
|
||||
existing_nullable
|
||||
if existing_nullable is not None
|
||||
else True
|
||||
)
|
||||
),
|
||||
nullable=nullable
|
||||
if nullable is not None
|
||||
else existing_nullable
|
||||
if existing_nullable is not None
|
||||
else True,
|
||||
type_=type_ if type_ is not None else existing_type,
|
||||
default=(
|
||||
server_default
|
||||
if server_default is not False
|
||||
else existing_server_default
|
||||
),
|
||||
autoincrement=(
|
||||
autoincrement
|
||||
if autoincrement is not None
|
||||
else existing_autoincrement
|
||||
),
|
||||
comment=(
|
||||
comment if comment is not False else existing_comment
|
||||
),
|
||||
default=server_default
|
||||
if server_default is not False
|
||||
else existing_server_default,
|
||||
autoincrement=autoincrement
|
||||
if autoincrement is not None
|
||||
else existing_autoincrement,
|
||||
comment=comment
|
||||
if comment is not False
|
||||
else existing_comment,
|
||||
)
|
||||
)
|
||||
elif server_default is not False:
|
||||
@@ -197,7 +148,6 @@ class MySQLImpl(DefaultImpl):
|
||||
def drop_constraint(
|
||||
self,
|
||||
const: Constraint,
|
||||
**kw: Any,
|
||||
) -> None:
|
||||
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
|
||||
return
|
||||
@@ -207,11 +157,12 @@ class MySQLImpl(DefaultImpl):
|
||||
def _is_mysql_allowed_functional_default(
|
||||
self,
|
||||
type_: Optional[TypeEngine],
|
||||
server_default: Optional[Union[_ServerDefault, Literal[False]]],
|
||||
server_default: Union[_ServerDefault, Literal[False]],
|
||||
) -> bool:
|
||||
return (
|
||||
type_ is not None
|
||||
and type_._type_affinity is sqltypes.DateTime
|
||||
and type_._type_affinity # type:ignore[attr-defined]
|
||||
is sqltypes.DateTime
|
||||
and server_default is not None
|
||||
)
|
||||
|
||||
@@ -321,12 +272,10 @@ class MySQLImpl(DefaultImpl):
|
||||
|
||||
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
|
||||
conn_fk_by_sig = {
|
||||
self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
|
||||
for fk in conn_fks
|
||||
compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
|
||||
}
|
||||
metadata_fk_by_sig = {
|
||||
self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
|
||||
for fk in metadata_fks
|
||||
compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
|
||||
}
|
||||
|
||||
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
|
||||
@@ -358,7 +307,7 @@ class MySQLAlterDefault(AlterColumn):
|
||||
self,
|
||||
name: str,
|
||||
column_name: str,
|
||||
default: Optional[_ServerDefault],
|
||||
default: _ServerDefault,
|
||||
schema: Optional[str] = None,
|
||||
) -> None:
|
||||
super(AlterColumn, self).__init__(name, schema=schema)
|
||||
@@ -416,11 +365,9 @@ def _mysql_alter_default(
|
||||
return "%s ALTER COLUMN %s %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
format_column_name(compiler, element.column_name),
|
||||
(
|
||||
"SET DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DROP DEFAULT"
|
||||
),
|
||||
"SET DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DROP DEFAULT",
|
||||
)
|
||||
|
||||
|
||||
@@ -507,7 +454,7 @@ def _mysql_drop_constraint(
|
||||
# note that SQLAlchemy as of 1.2 does not yet support
|
||||
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
|
||||
# here.
|
||||
if compiler.dialect.is_mariadb:
|
||||
if _is_mariadb(compiler.dialect):
|
||||
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
|
||||
compiler.preparer.format_table(constraint.table),
|
||||
compiler.preparer.format_constraint(constraint),
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
@@ -8,6 +5,7 @@ from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from .base import AddColumn
|
||||
@@ -24,7 +22,6 @@ from .base import format_type
|
||||
from .base import IdentityColumnDefault
|
||||
from .base import RenameTable
|
||||
from .impl import DefaultImpl
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
|
||||
@@ -141,11 +138,9 @@ def visit_column_default(
|
||||
return "%s %s %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
alter_column(compiler, element.column_name),
|
||||
(
|
||||
"DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DEFAULT NULL"
|
||||
),
|
||||
"DEFAULT %s" % format_server_default(compiler, element.default)
|
||||
if element.default is not None
|
||||
else "DEFAULT NULL",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -16,19 +13,18 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import Numeric
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.dialects.postgresql import BIGINT
|
||||
from sqlalchemy.dialects.postgresql import ExcludeConstraint
|
||||
from sqlalchemy.dialects.postgresql import INTEGER
|
||||
from sqlalchemy.schema import CreateIndex
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
from sqlalchemy.sql.functions import FunctionElement
|
||||
from sqlalchemy.types import NULLTYPE
|
||||
|
||||
@@ -36,12 +32,12 @@ from .base import alter_column
|
||||
from .base import alter_table
|
||||
from .base import AlterColumn
|
||||
from .base import ColumnComment
|
||||
from .base import compiles
|
||||
from .base import format_column_name
|
||||
from .base import format_table_name
|
||||
from .base import format_type
|
||||
from .base import IdentityColumnDefault
|
||||
from .base import RenameTable
|
||||
from .impl import ComparisonResult
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..autogenerate import render
|
||||
@@ -50,8 +46,6 @@ from ..operations import schemaobj
|
||||
from ..operations.base import BatchOperations
|
||||
from ..operations.base import Operations
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
@@ -136,28 +130,25 @@ class PostgresqlImpl(DefaultImpl):
|
||||
metadata_default = metadata_column.server_default.arg
|
||||
|
||||
if isinstance(metadata_default, str):
|
||||
if not isinstance(inspector_column.type, (Numeric, Float)):
|
||||
if not isinstance(inspector_column.type, Numeric):
|
||||
metadata_default = re.sub(r"^'|'$", "", metadata_default)
|
||||
metadata_default = f"'{metadata_default}'"
|
||||
|
||||
metadata_default = literal_column(metadata_default)
|
||||
|
||||
# run a real compare against the server
|
||||
conn = self.connection
|
||||
assert conn is not None
|
||||
return not conn.scalar(
|
||||
select(literal_column(conn_col_default) == metadata_default)
|
||||
return not self.connection.scalar(
|
||||
sqla_compat._select(
|
||||
literal_column(conn_col_default) == metadata_default
|
||||
)
|
||||
)
|
||||
|
||||
def alter_column(
|
||||
def alter_column( # type:ignore[override]
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
] = False,
|
||||
server_default: Union[_ServerDefault, Literal[False]] = False,
|
||||
name: Optional[str] = None,
|
||||
type_: Optional[TypeEngine] = None,
|
||||
schema: Optional[str] = None,
|
||||
@@ -223,8 +214,7 @@ class PostgresqlImpl(DefaultImpl):
|
||||
"join pg_class t on t.oid=d.refobjid "
|
||||
"join pg_attribute a on a.attrelid=t.oid and "
|
||||
"a.attnum=d.refobjsubid "
|
||||
"where c.relkind='S' and "
|
||||
"c.oid=cast(:seqname as regclass)"
|
||||
"where c.relkind='S' and c.relname=:seqname"
|
||||
),
|
||||
seqname=seq_match.group(1),
|
||||
).first()
|
||||
@@ -262,60 +252,62 @@ class PostgresqlImpl(DefaultImpl):
|
||||
if not sqla_compat.sqla_2:
|
||||
self._skip_functional_indexes(metadata_indexes, conn_indexes)
|
||||
|
||||
# pg behavior regarding modifiers
|
||||
# | # | compiled sql | returned sql | regexp. group is removed |
|
||||
# | - | ---------------- | -----------------| ------------------------ |
|
||||
# | 1 | nulls first | nulls first | - |
|
||||
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
|
||||
# | 3 | asc | | ( asc)$ |
|
||||
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
|
||||
# | 5 | asc nulls last | | ( asc nulls last)$ |
|
||||
# | 6 | desc | desc | - |
|
||||
# | 7 | desc nulls first | desc | desc( nulls first)$ |
|
||||
# | 8 | desc nulls last | desc nulls last | - |
|
||||
_default_modifiers_re = ( # order of case 2 and 5 matters
|
||||
re.compile("( asc nulls last)$"), # case 5
|
||||
re.compile("(?<! desc)( nulls last)$"), # case 2
|
||||
re.compile("( asc)$"), # case 3
|
||||
re.compile("( asc) nulls first$"), # case 4
|
||||
re.compile(" desc( nulls first)$"), # case 7
|
||||
)
|
||||
|
||||
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
|
||||
def _cleanup_index_expr(
|
||||
self, index: Index, expr: str, remove_suffix: str
|
||||
) -> str:
|
||||
# start = expr
|
||||
expr = expr.lower().replace('"', "").replace("'", "")
|
||||
if index.table is not None:
|
||||
# should not be needed, since include_table=False is in compile
|
||||
expr = expr.replace(f"{index.table.name.lower()}.", "")
|
||||
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
if "::" in expr:
|
||||
# strip :: cast. types can have spaces in them
|
||||
expr = re.sub(r"(::[\w ]+\w)", "", expr)
|
||||
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
if remove_suffix and expr.endswith(remove_suffix):
|
||||
expr = expr[: -len(remove_suffix)]
|
||||
|
||||
# NOTE: when parsing the connection expression this cleanup could
|
||||
# be skipped
|
||||
for rs in self._default_modifiers_re:
|
||||
if match := rs.search(expr):
|
||||
start, end = match.span(1)
|
||||
expr = expr[:start] + expr[end:]
|
||||
break
|
||||
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
|
||||
# strip casts
|
||||
cast_re = re.compile(r"cast\s*\(")
|
||||
if cast_re.match(expr):
|
||||
expr = cast_re.sub("", expr)
|
||||
# remove the as type
|
||||
expr = re.sub(r"as\s+[^)]+\)", "", expr)
|
||||
# remove spaces
|
||||
expr = expr.replace(" ", "")
|
||||
# print(f"START: {start} END: {expr}")
|
||||
return expr
|
||||
|
||||
def _dialect_options(
|
||||
def _default_modifiers(self, exp: ClauseElement) -> str:
|
||||
to_remove = ""
|
||||
while isinstance(exp, UnaryExpression):
|
||||
if exp.modifier is None:
|
||||
exp = exp.element
|
||||
else:
|
||||
op = exp.modifier
|
||||
if isinstance(exp.element, UnaryExpression):
|
||||
inner_op = exp.element.modifier
|
||||
else:
|
||||
inner_op = None
|
||||
if inner_op is None:
|
||||
if op == operators.asc_op:
|
||||
# default is asc
|
||||
to_remove = " asc"
|
||||
elif op == operators.nullslast_op:
|
||||
# default is nulls last
|
||||
to_remove = " nulls last"
|
||||
else:
|
||||
if (
|
||||
inner_op == operators.asc_op
|
||||
and op == operators.nullslast_op
|
||||
):
|
||||
# default is asc nulls last
|
||||
to_remove = " asc nulls last"
|
||||
elif (
|
||||
inner_op == operators.desc_op
|
||||
and op == operators.nullsfirst_op
|
||||
):
|
||||
# default for desc is nulls first
|
||||
to_remove = " nulls first"
|
||||
break
|
||||
return to_remove
|
||||
|
||||
def _dialect_sig(
|
||||
self, item: Union[Index, UniqueConstraint]
|
||||
) -> Tuple[Any, ...]:
|
||||
# only the positive case is returned by sqlalchemy reflection so
|
||||
@@ -324,93 +316,25 @@ class PostgresqlImpl(DefaultImpl):
|
||||
return ("nulls_not_distinct",)
|
||||
return ()
|
||||
|
||||
def compare_indexes(
|
||||
self,
|
||||
metadata_index: Index,
|
||||
reflected_index: Index,
|
||||
) -> ComparisonResult:
|
||||
msg = []
|
||||
unique_msg = self._compare_index_unique(
|
||||
metadata_index, reflected_index
|
||||
)
|
||||
if unique_msg:
|
||||
msg.append(unique_msg)
|
||||
m_exprs = metadata_index.expressions
|
||||
r_exprs = reflected_index.expressions
|
||||
if len(m_exprs) != len(r_exprs):
|
||||
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
|
||||
if msg:
|
||||
# no point going further, return early
|
||||
return ComparisonResult.Different(msg)
|
||||
skip = []
|
||||
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
|
||||
m_compile = self._compile_element(m_e)
|
||||
m_text = self._cleanup_index_expr(metadata_index, m_compile)
|
||||
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
|
||||
r_compile = self._compile_element(r_e)
|
||||
r_text = self._cleanup_index_expr(metadata_index, r_compile)
|
||||
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
|
||||
if m_text == r_text:
|
||||
continue # expressions these are equal
|
||||
elif m_compile.strip().endswith("_ops") and (
|
||||
" " in m_compile or ")" in m_compile # is an expression
|
||||
):
|
||||
skip.append(
|
||||
f"expression #{pos} {m_compile!r} detected "
|
||||
"as including operator clause."
|
||||
)
|
||||
util.warn(
|
||||
f"Expression #{pos} {m_compile!r} in index "
|
||||
f"{reflected_index.name!r} detected to include "
|
||||
"an operator clause. Expression compare cannot proceed. "
|
||||
"Please move the operator clause to the "
|
||||
"``postgresql_ops`` dict to enable proper compare "
|
||||
"of the index expressions: "
|
||||
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
|
||||
)
|
||||
else:
|
||||
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
|
||||
|
||||
m_options = self._dialect_options(metadata_index)
|
||||
r_options = self._dialect_options(reflected_index)
|
||||
if m_options != r_options:
|
||||
msg.extend(f"options {r_options} to {m_options}")
|
||||
|
||||
if msg:
|
||||
return ComparisonResult.Different(msg)
|
||||
elif skip:
|
||||
# if there are other changes detected don't skip the index
|
||||
return ComparisonResult.Skip(skip)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def compare_unique_constraint(
|
||||
self,
|
||||
metadata_constraint: UniqueConstraint,
|
||||
reflected_constraint: UniqueConstraint,
|
||||
) -> ComparisonResult:
|
||||
metadata_tup = self._create_metadata_constraint_sig(
|
||||
metadata_constraint
|
||||
)
|
||||
reflected_tup = self._create_reflected_constraint_sig(
|
||||
reflected_constraint
|
||||
)
|
||||
|
||||
meta_sig = metadata_tup.unnamed
|
||||
conn_sig = reflected_tup.unnamed
|
||||
if conn_sig != meta_sig:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_sig} to {meta_sig}"
|
||||
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
|
||||
return tuple(
|
||||
self._cleanup_index_expr(
|
||||
index,
|
||||
*(
|
||||
(e, "")
|
||||
if isinstance(e, str)
|
||||
else (self._compile_element(e), self._default_modifiers(e))
|
||||
),
|
||||
)
|
||||
for e in index.expressions
|
||||
) + self._dialect_sig(index)
|
||||
|
||||
metadata_do = self._dialect_options(metadata_tup.const)
|
||||
conn_do = self._dialect_options(reflected_tup.const)
|
||||
if metadata_do != conn_do:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_do} to {metadata_do}"
|
||||
)
|
||||
|
||||
return ComparisonResult.Equal()
|
||||
def create_unique_constraint_sig(
|
||||
self, const: UniqueConstraint
|
||||
) -> Tuple[Any, ...]:
|
||||
return tuple(
|
||||
sorted([col.name for col in const.columns])
|
||||
) + self._dialect_sig(const)
|
||||
|
||||
def adjust_reflected_dialect_options(
|
||||
self, reflected_options: Dict[str, Any], kind: str
|
||||
@@ -421,9 +345,7 @@ class PostgresqlImpl(DefaultImpl):
|
||||
options.pop("postgresql_include", None)
|
||||
return options
|
||||
|
||||
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
|
||||
if isinstance(element, str):
|
||||
return element
|
||||
def _compile_element(self, element: ClauseElement) -> str:
|
||||
return element.compile(
|
||||
dialect=self.dialect,
|
||||
compile_kwargs={"literal_binds": True, "include_table": False},
|
||||
@@ -590,7 +512,7 @@ def visit_identity_column(
|
||||
)
|
||||
else:
|
||||
text += "SET %s " % compiler.get_identity_options(
|
||||
Identity(**{attr: getattr(identity, attr)})
|
||||
sqla_compat.Identity(**{attr: getattr(identity, attr)})
|
||||
)
|
||||
return text
|
||||
|
||||
@@ -634,8 +556,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
return cls(
|
||||
constraint.name,
|
||||
constraint_table.name,
|
||||
[ # type: ignore
|
||||
(expr, op) for expr, name, op in constraint._render_exprs
|
||||
[
|
||||
(expr, op)
|
||||
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
|
||||
],
|
||||
where=cast("ColumnElement[bool] | None", constraint.where),
|
||||
schema=constraint_table.schema,
|
||||
@@ -662,7 +585,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
expr,
|
||||
name,
|
||||
oper,
|
||||
) in excl._render_exprs:
|
||||
) in excl._render_exprs: # type:ignore[attr-defined]
|
||||
t.append_column(Column(name, NULLTYPE))
|
||||
t.append_constraint(excl)
|
||||
return excl
|
||||
@@ -720,7 +643,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
constraint_name: str,
|
||||
*elements: Any,
|
||||
**kw: Any,
|
||||
) -> Optional[Table]:
|
||||
):
|
||||
"""Issue a "create exclude constraint" instruction using the
|
||||
current batch migration context.
|
||||
|
||||
@@ -792,13 +715,10 @@ def _exclude_constraint(
|
||||
args = [
|
||||
"(%s, %r)"
|
||||
% (
|
||||
_render_potential_column(
|
||||
sqltext, # type:ignore[arg-type]
|
||||
autogen_context,
|
||||
),
|
||||
_render_potential_column(sqltext, autogen_context),
|
||||
opstring,
|
||||
)
|
||||
for sqltext, name, opstring in constraint._render_exprs
|
||||
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
|
||||
]
|
||||
if constraint.where is not None:
|
||||
args.append(
|
||||
@@ -850,5 +770,5 @@ def _render_potential_column(
|
||||
return render._render_potential_expr(
|
||||
value,
|
||||
autogen_context,
|
||||
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
|
||||
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
@@ -11,19 +8,16 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Computed
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy import schema
|
||||
from sqlalchemy import sql
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
|
||||
from .base import alter_table
|
||||
from .base import ColumnName
|
||||
from .base import format_column_name
|
||||
from .base import format_table_name
|
||||
from .base import RenameTable
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
@@ -65,7 +59,7 @@ class SQLiteImpl(DefaultImpl):
|
||||
) and isinstance(col.server_default.arg, sql.ClauseElement):
|
||||
return True
|
||||
elif (
|
||||
isinstance(col.server_default, Computed)
|
||||
isinstance(col.server_default, util.sqla_compat.Computed)
|
||||
and col.server_default.persisted
|
||||
):
|
||||
return True
|
||||
@@ -77,13 +71,13 @@ class SQLiteImpl(DefaultImpl):
|
||||
def add_constraint(self, const: Constraint):
|
||||
# attempt to distinguish between an
|
||||
# auto-gen constraint and an explicit one
|
||||
if const._create_rule is None:
|
||||
if const._create_rule is None: # type:ignore[attr-defined]
|
||||
raise NotImplementedError(
|
||||
"No support for ALTER of constraints in SQLite dialect. "
|
||||
"Please refer to the batch mode feature which allows for "
|
||||
"SQLite migrations using a copy-and-move strategy."
|
||||
)
|
||||
elif const._create_rule(self):
|
||||
elif const._create_rule(self): # type:ignore[attr-defined]
|
||||
util.warn(
|
||||
"Skipping unsupported ALTER for "
|
||||
"creation of implicit constraint. "
|
||||
@@ -91,8 +85,8 @@ class SQLiteImpl(DefaultImpl):
|
||||
"SQLite migrations using a copy-and-move strategy."
|
||||
)
|
||||
|
||||
def drop_constraint(self, const: Constraint, **kw: Any):
|
||||
if const._create_rule is None:
|
||||
def drop_constraint(self, const: Constraint):
|
||||
if const._create_rule is None: # type:ignore[attr-defined]
|
||||
raise NotImplementedError(
|
||||
"No support for ALTER of constraints in SQLite dialect. "
|
||||
"Please refer to the batch mode feature which allows for "
|
||||
@@ -183,7 +177,8 @@ class SQLiteImpl(DefaultImpl):
|
||||
new_type: TypeEngine,
|
||||
) -> None:
|
||||
if (
|
||||
existing.type._type_affinity is not new_type._type_affinity
|
||||
existing.type._type_affinity # type:ignore[attr-defined]
|
||||
is not new_type._type_affinity # type:ignore[attr-defined]
|
||||
and not isinstance(new_type, JSON)
|
||||
):
|
||||
existing_transfer["expr"] = cast(
|
||||
@@ -210,15 +205,6 @@ def visit_rename_table(
|
||||
)
|
||||
|
||||
|
||||
@compiles(ColumnName, "sqlite")
|
||||
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
|
||||
return "%s RENAME COLUMN %s TO %s" % (
|
||||
alter_table(compiler, element.table_name, element.schema),
|
||||
format_column_name(compiler, element.column_name),
|
||||
format_column_name(compiler, element.newname),
|
||||
)
|
||||
|
||||
|
||||
# @compiles(AddColumn, 'sqlite')
|
||||
# def visit_add_column(element, compiler, **kw):
|
||||
# return "%s %s" % (
|
||||
|
||||
Reference in New Issue
Block a user