This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -13,18 +16,19 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import Numeric
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.dialects.postgresql import BIGINT
|
||||
from sqlalchemy.dialects.postgresql import ExcludeConstraint
|
||||
from sqlalchemy.dialects.postgresql import INTEGER
|
||||
from sqlalchemy.schema import CreateIndex
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
from sqlalchemy.sql.functions import FunctionElement
|
||||
from sqlalchemy.types import NULLTYPE
|
||||
|
||||
@@ -32,12 +36,12 @@ from .base import alter_column
|
||||
from .base import alter_table
|
||||
from .base import AlterColumn
|
||||
from .base import ColumnComment
|
||||
from .base import compiles
|
||||
from .base import format_column_name
|
||||
from .base import format_table_name
|
||||
from .base import format_type
|
||||
from .base import IdentityColumnDefault
|
||||
from .base import RenameTable
|
||||
from .impl import ComparisonResult
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..autogenerate import render
|
||||
@@ -46,6 +50,8 @@ from ..operations import schemaobj
|
||||
from ..operations.base import BatchOperations
|
||||
from ..operations.base import Operations
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
@@ -130,25 +136,28 @@ class PostgresqlImpl(DefaultImpl):
|
||||
metadata_default = metadata_column.server_default.arg
|
||||
|
||||
if isinstance(metadata_default, str):
|
||||
if not isinstance(inspector_column.type, Numeric):
|
||||
if not isinstance(inspector_column.type, (Numeric, Float)):
|
||||
metadata_default = re.sub(r"^'|'$", "", metadata_default)
|
||||
metadata_default = f"'{metadata_default}'"
|
||||
|
||||
metadata_default = literal_column(metadata_default)
|
||||
|
||||
# run a real compare against the server
|
||||
return not self.connection.scalar(
|
||||
sqla_compat._select(
|
||||
literal_column(conn_col_default) == metadata_default
|
||||
)
|
||||
conn = self.connection
|
||||
assert conn is not None
|
||||
return not conn.scalar(
|
||||
select(literal_column(conn_col_default) == metadata_default)
|
||||
)
|
||||
|
||||
def alter_column( # type:ignore[override]
|
||||
def alter_column(
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
*,
|
||||
nullable: Optional[bool] = None,
|
||||
server_default: Union[_ServerDefault, Literal[False]] = False,
|
||||
server_default: Optional[
|
||||
Union[_ServerDefault, Literal[False]]
|
||||
] = False,
|
||||
name: Optional[str] = None,
|
||||
type_: Optional[TypeEngine] = None,
|
||||
schema: Optional[str] = None,
|
||||
@@ -214,7 +223,8 @@ class PostgresqlImpl(DefaultImpl):
|
||||
"join pg_class t on t.oid=d.refobjid "
|
||||
"join pg_attribute a on a.attrelid=t.oid and "
|
||||
"a.attnum=d.refobjsubid "
|
||||
"where c.relkind='S' and c.relname=:seqname"
|
||||
"where c.relkind='S' and "
|
||||
"c.oid=cast(:seqname as regclass)"
|
||||
),
|
||||
seqname=seq_match.group(1),
|
||||
).first()
|
||||
@@ -252,62 +262,60 @@ class PostgresqlImpl(DefaultImpl):
|
||||
if not sqla_compat.sqla_2:
|
||||
self._skip_functional_indexes(metadata_indexes, conn_indexes)
|
||||
|
||||
def _cleanup_index_expr(
|
||||
self, index: Index, expr: str, remove_suffix: str
|
||||
) -> str:
|
||||
# start = expr
|
||||
# pg behavior regarding modifiers
|
||||
# | # | compiled sql | returned sql | regexp. group is removed |
|
||||
# | - | ---------------- | -----------------| ------------------------ |
|
||||
# | 1 | nulls first | nulls first | - |
|
||||
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
|
||||
# | 3 | asc | | ( asc)$ |
|
||||
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
|
||||
# | 5 | asc nulls last | | ( asc nulls last)$ |
|
||||
# | 6 | desc | desc | - |
|
||||
# | 7 | desc nulls first | desc | desc( nulls first)$ |
|
||||
# | 8 | desc nulls last | desc nulls last | - |
|
||||
_default_modifiers_re = ( # order of case 2 and 5 matters
|
||||
re.compile("( asc nulls last)$"), # case 5
|
||||
re.compile("(?<! desc)( nulls last)$"), # case 2
|
||||
re.compile("( asc)$"), # case 3
|
||||
re.compile("( asc) nulls first$"), # case 4
|
||||
re.compile(" desc( nulls first)$"), # case 7
|
||||
)
|
||||
|
||||
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
|
||||
expr = expr.lower().replace('"', "").replace("'", "")
|
||||
if index.table is not None:
|
||||
# should not be needed, since include_table=False is in compile
|
||||
expr = expr.replace(f"{index.table.name.lower()}.", "")
|
||||
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
if "::" in expr:
|
||||
# strip :: cast. types can have spaces in them
|
||||
expr = re.sub(r"(::[\w ]+\w)", "", expr)
|
||||
|
||||
if remove_suffix and expr.endswith(remove_suffix):
|
||||
expr = expr[: -len(remove_suffix)]
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
|
||||
# print(f"START: {start} END: {expr}")
|
||||
# NOTE: when parsing the connection expression this cleanup could
|
||||
# be skipped
|
||||
for rs in self._default_modifiers_re:
|
||||
if match := rs.search(expr):
|
||||
start, end = match.span(1)
|
||||
expr = expr[:start] + expr[end:]
|
||||
break
|
||||
|
||||
while expr and expr[0] == "(" and expr[-1] == ")":
|
||||
expr = expr[1:-1]
|
||||
|
||||
# strip casts
|
||||
cast_re = re.compile(r"cast\s*\(")
|
||||
if cast_re.match(expr):
|
||||
expr = cast_re.sub("", expr)
|
||||
# remove the as type
|
||||
expr = re.sub(r"as\s+[^)]+\)", "", expr)
|
||||
# remove spaces
|
||||
expr = expr.replace(" ", "")
|
||||
return expr
|
||||
|
||||
def _default_modifiers(self, exp: ClauseElement) -> str:
|
||||
to_remove = ""
|
||||
while isinstance(exp, UnaryExpression):
|
||||
if exp.modifier is None:
|
||||
exp = exp.element
|
||||
else:
|
||||
op = exp.modifier
|
||||
if isinstance(exp.element, UnaryExpression):
|
||||
inner_op = exp.element.modifier
|
||||
else:
|
||||
inner_op = None
|
||||
if inner_op is None:
|
||||
if op == operators.asc_op:
|
||||
# default is asc
|
||||
to_remove = " asc"
|
||||
elif op == operators.nullslast_op:
|
||||
# default is nulls last
|
||||
to_remove = " nulls last"
|
||||
else:
|
||||
if (
|
||||
inner_op == operators.asc_op
|
||||
and op == operators.nullslast_op
|
||||
):
|
||||
# default is asc nulls last
|
||||
to_remove = " asc nulls last"
|
||||
elif (
|
||||
inner_op == operators.desc_op
|
||||
and op == operators.nullsfirst_op
|
||||
):
|
||||
# default for desc is nulls first
|
||||
to_remove = " nulls first"
|
||||
break
|
||||
return to_remove
|
||||
|
||||
def _dialect_sig(
|
||||
def _dialect_options(
|
||||
self, item: Union[Index, UniqueConstraint]
|
||||
) -> Tuple[Any, ...]:
|
||||
# only the positive case is returned by sqlalchemy reflection so
|
||||
@@ -316,25 +324,93 @@ class PostgresqlImpl(DefaultImpl):
|
||||
return ("nulls_not_distinct",)
|
||||
return ()
|
||||
|
||||
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
|
||||
return tuple(
|
||||
self._cleanup_index_expr(
|
||||
index,
|
||||
*(
|
||||
(e, "")
|
||||
if isinstance(e, str)
|
||||
else (self._compile_element(e), self._default_modifiers(e))
|
||||
),
|
||||
)
|
||||
for e in index.expressions
|
||||
) + self._dialect_sig(index)
|
||||
def compare_indexes(
|
||||
self,
|
||||
metadata_index: Index,
|
||||
reflected_index: Index,
|
||||
) -> ComparisonResult:
|
||||
msg = []
|
||||
unique_msg = self._compare_index_unique(
|
||||
metadata_index, reflected_index
|
||||
)
|
||||
if unique_msg:
|
||||
msg.append(unique_msg)
|
||||
m_exprs = metadata_index.expressions
|
||||
r_exprs = reflected_index.expressions
|
||||
if len(m_exprs) != len(r_exprs):
|
||||
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
|
||||
if msg:
|
||||
# no point going further, return early
|
||||
return ComparisonResult.Different(msg)
|
||||
skip = []
|
||||
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
|
||||
m_compile = self._compile_element(m_e)
|
||||
m_text = self._cleanup_index_expr(metadata_index, m_compile)
|
||||
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
|
||||
r_compile = self._compile_element(r_e)
|
||||
r_text = self._cleanup_index_expr(metadata_index, r_compile)
|
||||
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
|
||||
if m_text == r_text:
|
||||
continue # expressions these are equal
|
||||
elif m_compile.strip().endswith("_ops") and (
|
||||
" " in m_compile or ")" in m_compile # is an expression
|
||||
):
|
||||
skip.append(
|
||||
f"expression #{pos} {m_compile!r} detected "
|
||||
"as including operator clause."
|
||||
)
|
||||
util.warn(
|
||||
f"Expression #{pos} {m_compile!r} in index "
|
||||
f"{reflected_index.name!r} detected to include "
|
||||
"an operator clause. Expression compare cannot proceed. "
|
||||
"Please move the operator clause to the "
|
||||
"``postgresql_ops`` dict to enable proper compare "
|
||||
"of the index expressions: "
|
||||
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
|
||||
)
|
||||
else:
|
||||
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
|
||||
|
||||
def create_unique_constraint_sig(
|
||||
self, const: UniqueConstraint
|
||||
) -> Tuple[Any, ...]:
|
||||
return tuple(
|
||||
sorted([col.name for col in const.columns])
|
||||
) + self._dialect_sig(const)
|
||||
m_options = self._dialect_options(metadata_index)
|
||||
r_options = self._dialect_options(reflected_index)
|
||||
if m_options != r_options:
|
||||
msg.extend(f"options {r_options} to {m_options}")
|
||||
|
||||
if msg:
|
||||
return ComparisonResult.Different(msg)
|
||||
elif skip:
|
||||
# if there are other changes detected don't skip the index
|
||||
return ComparisonResult.Skip(skip)
|
||||
else:
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def compare_unique_constraint(
|
||||
self,
|
||||
metadata_constraint: UniqueConstraint,
|
||||
reflected_constraint: UniqueConstraint,
|
||||
) -> ComparisonResult:
|
||||
metadata_tup = self._create_metadata_constraint_sig(
|
||||
metadata_constraint
|
||||
)
|
||||
reflected_tup = self._create_reflected_constraint_sig(
|
||||
reflected_constraint
|
||||
)
|
||||
|
||||
meta_sig = metadata_tup.unnamed
|
||||
conn_sig = reflected_tup.unnamed
|
||||
if conn_sig != meta_sig:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_sig} to {meta_sig}"
|
||||
)
|
||||
|
||||
metadata_do = self._dialect_options(metadata_tup.const)
|
||||
conn_do = self._dialect_options(reflected_tup.const)
|
||||
if metadata_do != conn_do:
|
||||
return ComparisonResult.Different(
|
||||
f"expression {conn_do} to {metadata_do}"
|
||||
)
|
||||
|
||||
return ComparisonResult.Equal()
|
||||
|
||||
def adjust_reflected_dialect_options(
|
||||
self, reflected_options: Dict[str, Any], kind: str
|
||||
@@ -345,7 +421,9 @@ class PostgresqlImpl(DefaultImpl):
|
||||
options.pop("postgresql_include", None)
|
||||
return options
|
||||
|
||||
def _compile_element(self, element: ClauseElement) -> str:
|
||||
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
|
||||
if isinstance(element, str):
|
||||
return element
|
||||
return element.compile(
|
||||
dialect=self.dialect,
|
||||
compile_kwargs={"literal_binds": True, "include_table": False},
|
||||
@@ -512,7 +590,7 @@ def visit_identity_column(
|
||||
)
|
||||
else:
|
||||
text += "SET %s " % compiler.get_identity_options(
|
||||
sqla_compat.Identity(**{attr: getattr(identity, attr)})
|
||||
Identity(**{attr: getattr(identity, attr)})
|
||||
)
|
||||
return text
|
||||
|
||||
@@ -556,9 +634,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
return cls(
|
||||
constraint.name,
|
||||
constraint_table.name,
|
||||
[
|
||||
(expr, op)
|
||||
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
|
||||
[ # type: ignore
|
||||
(expr, op) for expr, name, op in constraint._render_exprs
|
||||
],
|
||||
where=cast("ColumnElement[bool] | None", constraint.where),
|
||||
schema=constraint_table.schema,
|
||||
@@ -585,7 +662,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
expr,
|
||||
name,
|
||||
oper,
|
||||
) in excl._render_exprs: # type:ignore[attr-defined]
|
||||
) in excl._render_exprs:
|
||||
t.append_column(Column(name, NULLTYPE))
|
||||
t.append_constraint(excl)
|
||||
return excl
|
||||
@@ -643,7 +720,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
|
||||
constraint_name: str,
|
||||
*elements: Any,
|
||||
**kw: Any,
|
||||
):
|
||||
) -> Optional[Table]:
|
||||
"""Issue a "create exclude constraint" instruction using the
|
||||
current batch migration context.
|
||||
|
||||
@@ -715,10 +792,13 @@ def _exclude_constraint(
|
||||
args = [
|
||||
"(%s, %r)"
|
||||
% (
|
||||
_render_potential_column(sqltext, autogen_context),
|
||||
_render_potential_column(
|
||||
sqltext, # type:ignore[arg-type]
|
||||
autogen_context,
|
||||
),
|
||||
opstring,
|
||||
)
|
||||
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
|
||||
for sqltext, name, opstring in constraint._render_exprs
|
||||
]
|
||||
if constraint.where is not None:
|
||||
args.append(
|
||||
@@ -770,5 +850,5 @@ def _render_potential_column(
|
||||
return render._render_potential_expr(
|
||||
value,
|
||||
autogen_context,
|
||||
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
|
||||
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user