API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
@@ -8,7 +11,9 @@ from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import elements
from sqlalchemy.sql import functions
from sqlalchemy.sql import operators
from .base import alter_table
from .base import AlterColumn
@@ -20,16 +25,16 @@ from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..autogenerate import compare
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
@@ -46,12 +51,40 @@ class MySQLImpl(DefaultImpl):
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def alter_column( # type:ignore[override]
def render_ddl_sql_expr(
self,
expr: ClauseElement,
is_server_default: bool = False,
is_index: bool = False,
**kw: Any,
) -> str:
# apply Grouping to index expressions;
# see https://github.com/sqlalchemy/sqlalchemy/blob/
# 36da2eaf3e23269f2cf28420ae73674beafd0661/
# lib/sqlalchemy/dialects/mysql/base.py#L2191
if is_index and (
isinstance(expr, elements.BinaryExpression)
or (
isinstance(expr, elements.UnaryExpression)
and expr.modifier not in (operators.desc_op, operators.asc_op)
)
or isinstance(expr, functions.FunctionElement)
):
expr = elements.Grouping(expr)
return super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, is_index=is_index, **kw
)
def alter_column(
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -92,21 +125,29 @@ class MySQLImpl(DefaultImpl):
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
nullable=(
nullable
if nullable is not None
else (
existing_nullable
if existing_nullable is not None
else True
)
),
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
default=(
server_default
if server_default is not False
else existing_server_default
),
autoincrement=(
autoincrement
if autoincrement is not None
else existing_autoincrement
),
comment=(
comment if comment is not False else existing_comment
),
)
)
elif (
@@ -121,21 +162,29 @@ class MySQLImpl(DefaultImpl):
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
nullable=(
nullable
if nullable is not None
else (
existing_nullable
if existing_nullable is not None
else True
)
),
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
default=(
server_default
if server_default is not False
else existing_server_default
),
autoincrement=(
autoincrement
if autoincrement is not None
else existing_autoincrement
),
comment=(
comment if comment is not False else existing_comment
),
)
)
elif server_default is not False:
@@ -148,6 +197,7 @@ class MySQLImpl(DefaultImpl):
def drop_constraint(
self,
const: Constraint,
**kw: Any,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
@@ -157,12 +207,11 @@ class MySQLImpl(DefaultImpl):
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Union[_ServerDefault, Literal[False]],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
) -> bool:
return (
type_ is not None
and type_._type_affinity # type:ignore[attr-defined]
is sqltypes.DateTime
and type_._type_affinity is sqltypes.DateTime
and server_default is not None
)
@@ -272,10 +321,12 @@ class MySQLImpl(DefaultImpl):
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
for fk in conn_fks
}
metadata_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
@@ -307,7 +358,7 @@ class MySQLAlterDefault(AlterColumn):
self,
name: str,
column_name: str,
default: _ServerDefault,
default: Optional[_ServerDefault],
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
@@ -365,9 +416,11 @@ def _mysql_alter_default(
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
(
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT"
),
)
@@ -454,7 +507,7 @@ def _mysql_drop_constraint(
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if _is_mariadb(compiler.dialect):
if compiler.dialect.is_mariadb:
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),