main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
@@ -16,19 +13,18 @@ from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import Float
from sqlalchemy import Identity
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE
@@ -36,12 +32,12 @@ from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import ComparisonResult
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
@@ -50,8 +46,6 @@ from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
@@ -136,28 +130,25 @@ class PostgresqlImpl(DefaultImpl):
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, (Numeric, Float)):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
conn = self.connection
assert conn is not None
return not conn.scalar(
select(literal_column(conn_col_default) == metadata_default)
return not self.connection.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column(
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -223,8 +214,7 @@ class PostgresqlImpl(DefaultImpl):
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and "
"c.oid=cast(:seqname as regclass)"
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
@@ -262,60 +252,62 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
# pg behavior regarding modifiers
# | # | compiled sql | returned sql | regexp. group is removed |
# | - | ---------------- | -----------------| ------------------------ |
# | 1 | nulls first | nulls first | - |
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
# | 3 | asc | | ( asc)$ |
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
# | 5 | asc nulls last | | ( asc nulls last)$ |
# | 6 | desc | desc | - |
# | 7 | desc nulls first | desc | desc( nulls first)$ |
# | 8 | desc nulls last | desc nulls last | - |
_default_modifiers_re = ( # order of case 2 and 5 matters
re.compile("( asc nulls last)$"), # case 5
re.compile("(?<! desc)( nulls last)$"), # case 2
re.compile("( asc)$"), # case 3
re.compile("( asc) nulls first$"), # case 4
re.compile(" desc( nulls first)$"), # case 7
)
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
def _cleanup_index_expr(
self, index: Index, expr: str, remove_suffix: str
) -> str:
# start = expr
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if remove_suffix and expr.endswith(remove_suffix):
expr = expr[: -len(remove_suffix)]
# NOTE: when parsing the connection expression this cleanup could
# be skipped
for rs in self._default_modifiers_re:
if match := rs.search(expr):
start, end = match.span(1)
expr = expr[:start] + expr[end:]
break
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# strip casts
cast_re = re.compile(r"cast\s*\(")
if cast_re.match(expr):
expr = cast_re.sub("", expr)
# remove the as type
expr = re.sub(r"as\s+[^)]+\)", "", expr)
# remove spaces
expr = expr.replace(" ", "")
# print(f"START: {start} END: {expr}")
return expr
def _dialect_options(
def _default_modifiers(self, exp: ClauseElement) -> str:
to_remove = ""
while isinstance(exp, UnaryExpression):
if exp.modifier is None:
exp = exp.element
else:
op = exp.modifier
if isinstance(exp.element, UnaryExpression):
inner_op = exp.element.modifier
else:
inner_op = None
if inner_op is None:
if op == operators.asc_op:
# default is asc
to_remove = " asc"
elif op == operators.nullslast_op:
# default is nulls last
to_remove = " nulls last"
else:
if (
inner_op == operators.asc_op
and op == operators.nullslast_op
):
# default is asc nulls last
to_remove = " asc nulls last"
elif (
inner_op == operators.desc_op
and op == operators.nullsfirst_op
):
# default for desc is nulls first
to_remove = " nulls first"
break
return to_remove
def _dialect_sig(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
@@ -324,93 +316,25 @@ class PostgresqlImpl(DefaultImpl):
return ("nulls_not_distinct",)
return ()
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
msg = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_exprs = metadata_index.expressions
r_exprs = reflected_index.expressions
if len(m_exprs) != len(r_exprs):
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
if msg:
# no point going further, return early
return ComparisonResult.Different(msg)
skip = []
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
m_compile = self._compile_element(m_e)
m_text = self._cleanup_index_expr(metadata_index, m_compile)
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
r_compile = self._compile_element(r_e)
r_text = self._cleanup_index_expr(metadata_index, r_compile)
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
if m_text == r_text:
continue # expressions these are equal
elif m_compile.strip().endswith("_ops") and (
" " in m_compile or ")" in m_compile # is an expression
):
skip.append(
f"expression #{pos} {m_compile!r} detected "
"as including operator clause."
)
util.warn(
f"Expression #{pos} {m_compile!r} in index "
f"{reflected_index.name!r} detected to include "
"an operator clause. Expression compare cannot proceed. "
"Please move the operator clause to the "
"``postgresql_ops`` dict to enable proper compare "
"of the index expressions: "
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
)
else:
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
m_options = self._dialect_options(metadata_index)
r_options = self._dialect_options(reflected_index)
if m_options != r_options:
msg.extend(f"options {r_options} to {m_options}")
if msg:
return ComparisonResult.Different(msg)
elif skip:
# if there are other changes detected don't skip the index
return ComparisonResult.Skip(skip)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
return tuple(
self._cleanup_index_expr(
index,
*(
(e, "")
if isinstance(e, str)
else (self._compile_element(e), self._default_modifiers(e))
),
)
for e in index.expressions
) + self._dialect_sig(index)
metadata_do = self._dialect_options(metadata_tup.const)
conn_do = self._dialect_options(reflected_tup.const)
if metadata_do != conn_do:
return ComparisonResult.Different(
f"expression {conn_do} to {metadata_do}"
)
return ComparisonResult.Equal()
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
return tuple(
sorted([col.name for col in const.columns])
) + self._dialect_sig(const)
def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
@@ -421,9 +345,7 @@ class PostgresqlImpl(DefaultImpl):
options.pop("postgresql_include", None)
return options
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
if isinstance(element, str):
return element
def _compile_element(self, element: ClauseElement) -> str:
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
@@ -590,7 +512,7 @@ def visit_identity_column(
)
else:
text += "SET %s " % compiler.get_identity_options(
Identity(**{attr: getattr(identity, attr)})
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@@ -634,8 +556,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
return cls(
constraint.name,
constraint_table.name,
[ # type: ignore
(expr, op) for expr, name, op in constraint._render_exprs
[
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
@@ -662,7 +585,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
expr,
name,
oper,
) in excl._render_exprs:
) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@@ -720,7 +643,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
constraint_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
@@ -792,13 +715,10 @@ def _exclude_constraint(
args = [
"(%s, %r)"
% (
_render_potential_column(
sqltext, # type:ignore[arg-type]
autogen_context,
),
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
@@ -850,5 +770,5 @@ def _render_potential_column(
return render._render_potential_expr(
value,
autogen_context,
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)