This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
@@ -8,6 +11,7 @@ from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Set
|
||||
@@ -17,10 +21,18 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
from sqlalchemy import schema
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import text
|
||||
|
||||
from . import _autogen
|
||||
from . import base
|
||||
from ._autogen import _constraint_sig as _constraint_sig
|
||||
from ._autogen import ComparisonResult as ComparisonResult
|
||||
from .. import util
|
||||
from ..util import sqla_compat
|
||||
|
||||
@@ -34,13 +46,10 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import ClauseElement
|
||||
from sqlalchemy.sql import Executable
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import quoted_name
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.schema import Constraint
|
||||
from sqlalchemy.sql.schema import ForeignKeyConstraint
|
||||
from sqlalchemy.sql.schema import Index
|
||||
from sqlalchemy.sql.schema import Table
|
||||
from sqlalchemy.sql.schema import UniqueConstraint
|
||||
from sqlalchemy.sql.selectable import TableClause
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
@@ -50,6 +59,8 @@ if TYPE_CHECKING:
|
||||
from ..operations.batch import ApplyBatchImpl
|
||||
from ..operations.batch import BatchOperationsImpl
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplMeta(type):
|
||||
def __init__(
|
||||
@@ -66,11 +77,8 @@ class ImplMeta(type):
|
||||
|
||||
_impls: Dict[str, Type[DefaultImpl]] = {}
|
||||
|
||||
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
|
||||
|
||||
|
||||
class DefaultImpl(metaclass=ImplMeta):
|
||||
|
||||
"""Provide the entrypoint for major migration operations,
|
||||
including database-specific behavioral variances.
|
||||
|
||||
@@ -130,6 +138,40 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self.output_buffer.write(text + "\n\n")
|
||||
self.output_buffer.flush()
|
||||
|
||||
def version_table_impl(
|
||||
self,
|
||||
*,
|
||||
version_table: str,
|
||||
version_table_schema: Optional[str],
|
||||
version_table_pk: bool,
|
||||
**kw: Any,
|
||||
) -> Table:
|
||||
"""Generate a :class:`.Table` object which will be used as the
|
||||
structure for the Alembic version table.
|
||||
|
||||
Third party dialects may override this hook to provide an alternate
|
||||
structure for this :class:`.Table`; requirements are only that it
|
||||
be named based on the ``version_table`` parameter and contains
|
||||
at least a single string-holding column named ``version_num``.
|
||||
|
||||
.. versionadded:: 1.14
|
||||
|
||||
"""
|
||||
vt = Table(
|
||||
version_table,
|
||||
MetaData(),
|
||||
Column("version_num", String(32), nullable=False),
|
||||
schema=version_table_schema,
|
||||
)
|
||||
if version_table_pk:
|
||||
vt.append_constraint(
|
||||
PrimaryKeyConstraint(
|
||||
"version_num", name=f"{version_table}_pkc"
|
||||
)
|
||||
)
|
||||
|
||||
return vt
|
||||
|
||||
def requires_recreate_in_batch(
|
||||
self, batch_op: BatchOperationsImpl
|
||||
) -> bool:
|
||||
@@ -161,16 +203,15 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
def _exec(
|
||||
self,
|
||||
construct: Union[Executable, str],
|
||||
execution_options: Optional[dict[str, Any]] = None,
|
||||
multiparams: Sequence[dict] = (),
|
||||
params: Dict[str, Any] = util.immutabledict(),
|
||||
execution_options: Optional[Mapping[str, Any]] = None,
|
||||
multiparams: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
params: Mapping[str, Any] = util.immutabledict(),
|
||||
) -> Optional[CursorResult]:
|
||||
if isinstance(construct, str):
|
||||
construct = text(construct)
|
||||
if self.as_sql:
|
||||
if multiparams or params:
|
||||
# TODO: coverage
|
||||
raise Exception("Execution arguments not allowed with as_sql")
|
||||
if multiparams is not None or params:
|
||||
raise TypeError("SQL parameters not allowed with as_sql")
|
||||
|
||||
compile_kw: dict[str, Any]
|
||||
if self.literal_binds and not isinstance(
|
||||
@@ -193,11 +234,16 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
assert conn is not None
|
||||
if execution_options:
|
||||
conn = conn.execution_options(**execution_options)
|
||||
if params:
|
||||
assert isinstance(multiparams, tuple)
|
||||
multiparams += (params,)
|
||||
|
||||
return conn.execute(construct, multiparams)
|
||||
if params and multiparams is not None:
|
||||
raise TypeError(
|
||||
"Can't send params and multiparams at the same time"
|
||||
)
|
||||
|
||||
if multiparams:
|
||||
return conn.execute(construct, multiparams)
|
||||
else:
|
||||
return conn.execute(construct, params)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
@@ -210,8 +256,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Union[_ServerDefault, Literal[False]] = False,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
] = False,
|
||||
name: Optional[str] = None,
|
||||
type_: Optional[TypeEngine] = None,
|
||||
schema: Optional[str] = None,
|
||||
@@ -322,25 +371,40 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
self,
|
||||
table_name: str,
|
||||
column: Column[Any],
|
||||
*,
|
||||
schema: Optional[Union[str, quoted_name]] = None,
|
||||
if_not_exists: Optional[bool] = None,
|
||||
) -> None:
|
||||
self._exec(base.AddColumn(table_name, column, schema=schema))
|
||||
self._exec(
|
||||
base.AddColumn(
|
||||
table_name,
|
||||
column,
|
||||
schema=schema,
|
||||
if_not_exists=if_not_exists,
|
||||
)
|
||||
)
|
||||
|
||||
def drop_column(
|
||||
self,
|
||||
table_name: str,
|
||||
column: Column[Any],
|
||||
*,
|
||||
schema: Optional[str] = None,
|
||||
if_exists: Optional[bool] = None,
|
||||
**kw,
|
||||
) -> None:
|
||||
self._exec(base.DropColumn(table_name, column, schema=schema))
|
||||
self._exec(
|
||||
base.DropColumn(
|
||||
table_name, column, schema=schema, if_exists=if_exists
|
||||
)
|
||||
)
|
||||
|
||||
def add_constraint(self, const: Any) -> None:
|
||||
if const._create_rule is None or const._create_rule(self):
|
||||
self._exec(schema.AddConstraint(const))
|
||||
|
||||
def drop_constraint(self, const: Constraint) -> None:
|
||||
self._exec(schema.DropConstraint(const))
|
||||
def drop_constraint(self, const: Constraint, **kw: Any) -> None:
|
||||
self._exec(schema.DropConstraint(const, **kw))
|
||||
|
||||
def rename_table(
|
||||
self,
|
||||
@@ -352,11 +416,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
base.RenameTable(old_table_name, new_table_name, schema=schema)
|
||||
)
|
||||
|
||||
def create_table(self, table: Table) -> None:
|
||||
def create_table(self, table: Table, **kw: Any) -> None:
|
||||
table.dispatch.before_create(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
self._exec(schema.CreateTable(table))
|
||||
self._exec(schema.CreateTable(table, **kw))
|
||||
table.dispatch.after_create(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
@@ -375,11 +439,11 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
if comment and with_comment:
|
||||
self.create_column_comment(column)
|
||||
|
||||
def drop_table(self, table: Table) -> None:
|
||||
def drop_table(self, table: Table, **kw: Any) -> None:
|
||||
table.dispatch.before_drop(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
self._exec(schema.DropTable(table))
|
||||
self._exec(schema.DropTable(table, **kw))
|
||||
table.dispatch.after_drop(
|
||||
table, self.connection, checkfirst=False, _ddl_runner=self
|
||||
)
|
||||
@@ -393,7 +457,7 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
def drop_table_comment(self, table: Table) -> None:
|
||||
self._exec(schema.DropTableComment(table))
|
||||
|
||||
def create_column_comment(self, column: ColumnElement[Any]) -> None:
|
||||
def create_column_comment(self, column: Column[Any]) -> None:
|
||||
self._exec(schema.SetColumnComment(column))
|
||||
|
||||
def drop_index(self, index: Index, **kw: Any) -> None:
|
||||
@@ -412,15 +476,19 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
if self.as_sql:
|
||||
for row in rows:
|
||||
self._exec(
|
||||
sqla_compat._insert_inline(table).values(
|
||||
table.insert()
|
||||
.inline()
|
||||
.values(
|
||||
**{
|
||||
k: sqla_compat._literal_bindparam(
|
||||
k, v, type_=table.c[k].type
|
||||
k: (
|
||||
sqla_compat._literal_bindparam(
|
||||
k, v, type_=table.c[k].type
|
||||
)
|
||||
if not isinstance(
|
||||
v, sqla_compat._literal_bindparam
|
||||
)
|
||||
else v
|
||||
)
|
||||
if not isinstance(
|
||||
v, sqla_compat._literal_bindparam
|
||||
)
|
||||
else v
|
||||
for k, v in row.items()
|
||||
}
|
||||
)
|
||||
@@ -428,16 +496,13 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
else:
|
||||
if rows:
|
||||
if multiinsert:
|
||||
self._exec(
|
||||
sqla_compat._insert_inline(table), multiparams=rows
|
||||
)
|
||||
self._exec(table.insert().inline(), multiparams=rows)
|
||||
else:
|
||||
for row in rows:
|
||||
self._exec(
|
||||
sqla_compat._insert_inline(table).values(**row)
|
||||
)
|
||||
self._exec(table.insert().inline().values(**row))
|
||||
|
||||
def _tokenize_column_type(self, column: Column) -> Params:
|
||||
definition: str
|
||||
definition = self.dialect.type_compiler.process(column.type).lower()
|
||||
|
||||
# tokenize the SQLAlchemy-generated version of a type, so that
|
||||
@@ -452,9 +517,9 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
# varchar character set utf8
|
||||
#
|
||||
|
||||
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
|
||||
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
|
||||
|
||||
term_tokens = []
|
||||
term_tokens: List[str] = []
|
||||
paren_term = None
|
||||
|
||||
for token in tokens:
|
||||
@@ -466,6 +531,7 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
params = Params(term_tokens[0], term_tokens[1:], [], {})
|
||||
|
||||
if paren_term:
|
||||
term: str
|
||||
for term in re.findall("[^(),]+", paren_term):
|
||||
if "=" in term:
|
||||
key, val = term.split("=")
|
||||
@@ -642,7 +708,7 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
diff, ignored = _compare_identity_options(
|
||||
metadata_identity,
|
||||
inspector_identity,
|
||||
sqla_compat.Identity(),
|
||||
schema.Identity(),
|
||||
skip={"always"},
|
||||
)
|
||||
|
||||
@@ -664,15 +730,96 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
|
||||
)
|
||||
|
||||
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
|
||||
# order of col matters in an index
|
||||
return tuple(col.name for col in index.columns)
|
||||
def _compare_index_unique(
|
||||
self, metadata_index: Index, reflected_index: Index
|
||||
) -> Optional[str]:
|
||||
conn_unique = bool(reflected_index.unique)
|
||||
meta_unique = bool(metadata_index.unique)
|
||||
if conn_unique != meta_unique:
|
||||
return f"unique={conn_unique} to unique={meta_unique}"
|
||||
else:
|
||||
return None
|
||||
|
||||
def create_unique_constraint_sig(
|
||||
self, const: UniqueConstraint
|
||||
) -> Tuple[Any, ...]:
|
||||
# order of col does not matters in an unique constraint
|
||||
return tuple(sorted([col.name for col in const.columns]))
|
||||
def _create_metadata_constraint_sig(
|
||||
self, constraint: _autogen._C, **opts: Any
|
||||
) -> _constraint_sig[_autogen._C]:
|
||||
return _constraint_sig.from_constraint(True, self, constraint, **opts)
|
||||
|
||||
def _create_reflected_constraint_sig(
|
||||
self, constraint: _autogen._C, **opts: Any
|
||||
) -> _constraint_sig[_autogen._C]:
|
||||
return _constraint_sig.from_constraint(False, self, constraint, **opts)
|
||||
|
||||
def compare_indexes(
|
||||
self,
|
||||
metadata_index: Index,
|
||||
reflected_index: Index,
|
||||
) -> ComparisonResult:
|
||||
"""Compare two indexes by comparing the signature generated by
|
||||
``create_index_sig``.
|
||||
|
||||
This method returns a ``ComparisonResult``.
|
||||
"""
|
||||
msg: List[str] = []
|
||||
unique_msg = self._compare_index_unique(
|
||||
metadata_index, reflected_index
|
||||
)
|
||||
if unique_msg:
|
||||
msg.append(unique_msg)
|
||||
m_sig = self._create_metadata_constraint_sig(metadata_index)
|
||||
r_sig = self._create_reflected_constraint_sig(reflected_index)
|
||||
|
||||
assert _autogen.is_index_sig(m_sig)
|
||||
assert _autogen.is_index_sig(r_sig)
|
||||
|
||||
# The assumption is that the index have no expression
|
||||
for sig in m_sig, r_sig:
|
||||
if sig.has_expressions:
|
||||
log.warning(
|
||||
"Generating approximate signature for index %s. "
|
||||
"The dialect "
|
||||
"implementation should either skip expression indexes "
|
||||
"or provide a custom implementation.",
|
||||
sig.const,
|
||||
)
|
||||
|
||||
if m_sig.column_names != r_sig.column_names:
|
||||
msg.append(
|
||||
f"expression {r_sig.column_names} to {m_sig.column_names}"
|
||||
)
|
||||
|
||||
if msg:
|
||||
return ComparisonResult.Different(msg)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def compare_unique_constraint(
|
||||
self,
|
||||
metadata_constraint: UniqueConstraint,
|
||||
reflected_constraint: UniqueConstraint,
|
||||
) -> ComparisonResult:
|
||||
"""Compare two unique constraints by comparing the two signatures.
|
||||
|
||||
The arguments are two tuples that contain the unique constraint and
|
||||
the signatures generated by ``create_unique_constraint_sig``.
|
||||
|
||||
This method returns a ``ComparisonResult``.
|
||||
"""
|
||||
metadata_tup = self._create_metadata_constraint_sig(
|
||||
metadata_constraint
|
||||
)
|
||||
reflected_tup = self._create_reflected_constraint_sig(
|
||||
reflected_constraint
|
||||
)
|
||||
|
||||
meta_sig = metadata_tup.unnamed
|
||||
conn_sig = reflected_tup.unnamed
|
||||
if conn_sig != meta_sig:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_sig} to {meta_sig}"
|
||||
)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
|
||||
conn_indexes_by_name = {c.name: c for c in conn_indexes}
|
||||
@@ -697,6 +844,13 @@ class DefaultImpl(metaclass=ImplMeta):
|
||||
return reflected_object.get("dialect_options", {})
|
||||
|
||||
|
||||
class Params(NamedTuple):
|
||||
token0: str
|
||||
tokens: List[str]
|
||||
args: List[str]
|
||||
kwargs: Dict[str, str]
|
||||
|
||||
|
||||
def _compare_identity_options(
|
||||
metadata_io: Union[schema.Identity, schema.Sequence, None],
|
||||
inspector_io: Union[schema.Identity, schema.Sequence, None],
|
||||
@@ -735,12 +889,13 @@ def _compare_identity_options(
|
||||
set(meta_d).union(insp_d),
|
||||
)
|
||||
if sqla_compat.identity_has_dialect_kwargs:
|
||||
assert hasattr(default_io, "dialect_kwargs")
|
||||
# use only the dialect kwargs in inspector_io since metadata_io
|
||||
# can have options for many backends
|
||||
check_dicts(
|
||||
getattr(metadata_io, "dialect_kwargs", {}),
|
||||
getattr(inspector_io, "dialect_kwargs", {}),
|
||||
default_io.dialect_kwargs, # type: ignore[union-attr]
|
||||
default_io.dialect_kwargs,
|
||||
getattr(inspector_io, "dialect_kwargs", {}),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user