@@ -1,4 +1,6 @@
import sys

from . import context
from . import op

__version__ = "1.16.5"
__version__ = "1.12.1"

@@ -1,10 +1,10 @@
from .api import _render_migration_diffs as _render_migration_diffs
from .api import compare_metadata as compare_metadata
from .api import produce_migrations as produce_migrations
from .api import render_python_code as render_python_code
from .api import RevisionContext as RevisionContext
from .compare import _produce_net_changes as _produce_net_changes
from .compare import comparators as comparators
from .render import render_op_text as render_op_text
from .render import renderers as renderers
from .rewriter import Rewriter as Rewriter
from .api import _render_migration_diffs
from .api import compare_metadata
from .api import produce_migrations
from .api import render_python_code
from .api import RevisionContext
from .compare import _produce_net_changes
from .compare import comparators
from .render import render_op_text
from .render import renderers
from .rewriter import Rewriter

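Editor's note: the `import X as X` spelling in the first block is the PEP 484 explicit re-export form, which strict type checkers (e.g. mypy with `--no-implicit-reexport`) require for names meant to be part of a package's public surface; the plain `import X` form that pairs with it does not carry that guarantee. A minimal sketch of the difference (module names hypothetical):

```python
# pkg/_impl.py  (hypothetical module)
def helper() -> str:
    return "hi"

# pkg/__init__.py, variant 1: explicit re-export; mypy with
# --no-implicit-reexport accepts `from pkg import helper` downstream
from ._impl import helper as helper

# pkg/__init__.py, variant 2: plain import; under strict re-export
# checking, downstream `from pkg import helper` is flagged
# from ._impl import helper
```
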
@@ -17,7 +17,6 @@ from . import compare
from . import render
from .. import util
from ..operations import ops
from ..util import sqla_compat

"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""
@@ -28,7 +27,6 @@ if TYPE_CHECKING:
    from sqlalchemy.engine import Inspector
    from sqlalchemy.sql.schema import MetaData
    from sqlalchemy.sql.schema import SchemaItem
    from sqlalchemy.sql.schema import Table

    from ..config import Config
    from ..operations.ops import DowngradeOps
@@ -166,7 +164,6 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
    """

    migration_script = produce_migrations(context, metadata)
    assert migration_script.upgrade_ops is not None
    return migration_script.upgrade_ops.as_diffs()

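Editor's note: `compare_metadata()` is the public entry point of this module. A typical invocation, adapted from the Alembic documentation (the engine URL here is a placeholder):

```python
from sqlalchemy import create_engine, MetaData
from alembic.migration import MigrationContext
from alembic.autogenerate import compare_metadata

engine = create_engine("sqlite://")  # hypothetical target database

# build a comparison context against the live database
mc = MigrationContext.configure(engine.connect())

metadata = MetaData()  # would normally carry the model definitions
diffs = compare_metadata(mc, metadata)  # diff tuples, e.g. ("add_table", ...)
```
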
@@ -277,7 +274,7 @@ class AutogenContext:
    """Maintains configuration and state that's specific to an
    autogenerate operation."""

    metadata: Union[MetaData, Sequence[MetaData], None] = None
    metadata: Optional[MetaData] = None
    """The :class:`~sqlalchemy.schema.MetaData` object
    representing the destination.

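Editor's note: the widened `Union[MetaData, Sequence[MetaData], None]` annotation reflects that autogenerate can target several `MetaData` collections at once, which in an `env.py` is driven by passing a list as `target_metadata` (module names hypothetical):

```python
# env.py sketch: autogenerate against two MetaData collections at once
from myapp.core.models import metadata_core        # hypothetical modules
from myapp.billing.models import metadata_billing

target_metadata = [metadata_core, metadata_billing]

# later, inside run_migrations_online():
# context.configure(connection=connection, target_metadata=target_metadata)
```
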
@@ -332,8 +329,8 @@ class AutogenContext:
    def __init__(
        self,
        migration_context: MigrationContext,
        metadata: Union[MetaData, Sequence[MetaData], None] = None,
        opts: Optional[Dict[str, Any]] = None,
        metadata: Optional[MetaData] = None,
        opts: Optional[dict] = None,
        autogenerate: bool = True,
    ) -> None:
        if (
@@ -443,7 +440,7 @@ class AutogenContext:
    def run_object_filters(
        self,
        object_: SchemaItem,
        name: sqla_compat._ConstraintName,
        name: Optional[str],
        type_: NameFilterType,
        reflected: bool,
        compare_to: Optional[SchemaItem],
@@ -467,7 +464,7 @@ class AutogenContext:
    run_filters = run_object_filters

    @util.memoized_property
    def sorted_tables(self) -> List[Table]:
    def sorted_tables(self):
        """Return an aggregate of the :attr:`.MetaData.sorted_tables`
        collection(s).

@@ -483,7 +480,7 @@ class AutogenContext:
        return result

    @util.memoized_property
    def table_key_to_table(self) -> Dict[str, Table]:
    def table_key_to_table(self):
        """Return an aggregate of the :attr:`.MetaData.tables` dictionaries.

        The :attr:`.MetaData.tables` collection is a dictionary of table key
@@ -494,7 +491,7 @@ class AutogenContext:
        objects contain the same table key, an exception is raised.

        """
        result: Dict[str, Table] = {}
        result = {}
        for m in util.to_list(self.metadata):
            intersect = set(result).intersection(set(m.tables))
            if intersect:

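Editor's note: `table_key_to_table` merges each `MetaData.tables` mapping while guarding against table-key collisions. A standalone, runnable sketch of the same pattern:

```python
from sqlalchemy import Column, Integer, MetaData, Table

def merge_table_keys(metadatas: list[MetaData]) -> dict[str, Table]:
    """Aggregate .tables dicts, raising on duplicate table keys."""
    result: dict[str, Table] = {}
    for m in metadatas:
        overlap = set(result).intersection(m.tables)
        if overlap:
            raise ValueError(f"Duplicate table keys: {sorted(overlap)}")
        result.update(m.tables)
    return result

m1, m2 = MetaData(), MetaData()
Table("a", m1, Column("id", Integer, primary_key=True))
Table("b", m2, Column("id", Integer, primary_key=True))
print(sorted(merge_table_keys([m1, m2])))  # ['a', 'b']
```
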
@@ -596,9 +593,9 @@ class RevisionContext:
        migration_script = self.generated_revisions[-1]
        if not getattr(migration_script, "_needs_render", False):
            migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
            migration_script.downgrade_ops_list[-1].downgrade_token = (
                downgrade_token
            )
            migration_script.downgrade_ops_list[
                -1
            ].downgrade_token = downgrade_token
            migration_script._needs_render = True
        else:
            migration_script._upgrade_ops.append(

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import contextlib
@@ -10,12 +7,12 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from sqlalchemy import event
@@ -24,15 +21,10 @@ from sqlalchemy import schema as sa_schema
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.util import OrderedSet

from alembic.ddl.base import _fk_spec
from .. import util
from ..ddl._autogen import is_index_sig
from ..ddl._autogen import is_uq_sig
from ..operations import ops
from ..util import sqla_compat

@@ -43,7 +35,10 @@ if TYPE_CHECKING:
    from sqlalchemy.sql.elements import quoted_name
    from sqlalchemy.sql.elements import TextClause
    from sqlalchemy.sql.schema import Column
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import Index
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint

    from alembic.autogenerate.api import AutogenContext
    from alembic.ddl.impl import DefaultImpl
@@ -51,8 +46,6 @@ if TYPE_CHECKING:
    from alembic.operations.ops import MigrationScript
    from alembic.operations.ops import ModifyTableOps
    from alembic.operations.ops import UpgradeOps
    from ..ddl._autogen import _constraint_sig


log = logging.getLogger(__name__)

@@ -217,7 +210,7 @@ def _compare_tables(
            (inspector),
            # fmt: on
        )
        _InspectorConv(inspector).reflect_table(t, include_columns=None)
        sqla_compat._reflect_table(inspector, t)
        if autogen_context.run_object_filters(t, tname, "table", True, None):
            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)

@@ -247,8 +240,7 @@ def _compare_tables(
            _compat_autogen_column_reflect(inspector),
            # fmt: on
        )
        _InspectorConv(inspector).reflect_table(t, include_columns=None)

        sqla_compat._reflect_table(inspector, t)
        conn_column_info[(s, tname)] = t

    for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
@@ -437,56 +429,102 @@ def _compare_columns(
        log.info("Detected removed column '%s.%s'", name, cname)


_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
class _constraint_sig:
    const: Union[UniqueConstraint, ForeignKeyConstraint, Index]


class _InspectorConv:
    __slots__ = ("inspector",)

    def __init__(self, inspector):
        self.inspector = inspector

    def _apply_reflectinfo_conv(self, consts):
        if not consts:
            return consts
        for const in consts:
            if const["name"] is not None and not isinstance(
                const["name"], conv
            ):
                const["name"] = conv(const["name"])
        return consts

    def _apply_constraint_conv(self, consts):
        if not consts:
            return consts
        for const in consts:
            if const.name is not None and not isinstance(const.name, conv):
                const.name = conv(const.name)
        return consts

    def get_indexes(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_indexes(*args, **kw)
    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
        return sqla_compat._get_constraint_final_name(
            self.const, context.dialect
        )

    def get_unique_constraints(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_unique_constraints(*args, **kw)
    def __eq__(self, other):
        return self.const == other.const

    def __ne__(self, other):
        return self.const != other.const

    def __hash__(self) -> int:
        return hash(self.const)


class _uq_constraint_sig(_constraint_sig):
    is_index = False
    is_unique = True

    def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None:
        self.const = const
        self.name = const.name
        self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig(
            const
        )

    def get_foreign_keys(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_foreign_keys(*args, **kw)
    @property
    def column_names(self) -> List[str]:
        return [col.name for col in self.const.columns]


class _ix_constraint_sig(_constraint_sig):
    is_index = True

    def __init__(self, const: Index, impl: DefaultImpl) -> None:
        self.const = const
        self.name = const.name
        self.sig = ("INDEX",) + impl.create_index_sig(const)
        self.is_unique = bool(const.unique)

    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
        return sqla_compat._get_constraint_final_name(
            self.const, context.dialect
        )

    def reflect_table(self, table, *, include_columns):
        self.inspector.reflect_table(table, include_columns=include_columns)
    @property
    def column_names(self) -> Union[List[quoted_name], List[None]]:
        return sqla_compat._get_index_column_names(self.const)

        # I had a cool version of this using _ReflectInfo, however that doesn't
        # work in 1.4 and it's not public API in 2.x. Then this is just a two
        # liner. So there's no competition...
        self._apply_constraint_conv(table.constraints)
        self._apply_constraint_conv(table.indexes)

class _fk_constraint_sig(_constraint_sig):
    def __init__(
        self, const: ForeignKeyConstraint, include_options: bool = False
    ) -> None:
        self.const = const
        self.name = const.name

        (
            self.source_schema,
            self.source_table,
            self.source_columns,
            self.target_schema,
            self.target_table,
            self.target_columns,
            onupdate,
            ondelete,
            deferrable,
            initially,
        ) = _fk_spec(const)

        self.sig: Tuple[Any, ...] = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        )
        if include_options:
            self.sig += (
                (None if onupdate.lower() == "no action" else onupdate.lower())
                if onupdate
                else None,
                (None if ondelete.lower() == "no action" else ondelete.lower())
                if ondelete
                else None,
                # convert initially + deferrable into one three-state value
                "initially_deferrable"
                if initially and initially.lower() == "deferred"
                else "deferrable"
                if deferrable
                else "not deferrable",
            )


@comparators.dispatch_for("table")

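Editor's note: both sides of this hunk reduce every constraint to a hashable signature tuple so that plain set arithmetic can classify constraints as added, removed, or changed. A toy illustration of the idea (the tuple layout here is simplified, not Alembic's exact one):

```python
# reduce each constraint to a hashable tuple, then diff the sets
metadata_sigs = {
    ("UNIQUE_CONSTRAINT", "user", ("email",)),
    ("INDEX", "user", ("created_at",)),
}
reflected_sigs = {
    ("UNIQUE_CONSTRAINT", "user", ("email",)),
    ("INDEX", "user", ("last_login",)),
}

added = metadata_sigs - reflected_sigs    # present in the models only
removed = reflected_sigs - metadata_sigs  # present in the database only
print(added, removed)
```
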
@@ -523,34 +561,34 @@ def _compare_indexes_and_uniques(

    if conn_table is not None:
        # 1b. ... and from connection, if the table exists
        try:
            conn_uniques = _InspectorConv(inspector).get_unique_constraints(
                tname, schema=schema
            )

            supports_unique_constraints = True
        except NotImplementedError:
            pass
        except TypeError:
            # number of arguments is off for the base
            # method in SQLAlchemy due to the cache decorator
            # not being present
            pass
        else:
            conn_uniques = [  # type:ignore[assignment]
                uq
                for uq in conn_uniques
                if autogen_context.run_name_filters(
                    uq["name"],
                    "unique_constraint",
                    {"table_name": tname, "schema_name": schema},
        if hasattr(inspector, "get_unique_constraints"):
            try:
                conn_uniques = inspector.get_unique_constraints(  # type:ignore[assignment] # noqa
                    tname, schema=schema
                )
            ]
            for uq in conn_uniques:
                if uq.get("duplicates_index"):
                    unique_constraints_duplicate_unique_indexes = True
                supports_unique_constraints = True
            except NotImplementedError:
                pass
            except TypeError:
                # number of arguments is off for the base
                # method in SQLAlchemy due to the cache decorator
                # not being present
                pass
            else:
                conn_uniques = [  # type:ignore[assignment]
                    uq
                    for uq in conn_uniques
                    if autogen_context.run_name_filters(
                        uq["name"],
                        "unique_constraint",
                        {"table_name": tname, "schema_name": schema},
                    )
                ]
                for uq in conn_uniques:
                    if uq.get("duplicates_index"):
                        unique_constraints_duplicate_unique_indexes = True
        try:
            conn_indexes = _InspectorConv(inspector).get_indexes(
            conn_indexes = inspector.get_indexes(  # type:ignore[assignment]
                tname, schema=schema
            )
        except NotImplementedError:

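Editor's note: the `duplicates_index` key handled above comes from SQLAlchemy reflection. On backends such as MySQL, a UNIQUE constraint is implemented as a unique index, so the same database object can be reported by both `get_unique_constraints()` and `get_indexes()`. A hedged sketch of how the flag surfaces (URL and table name hypothetical):

```python
from sqlalchemy import create_engine, inspect

engine = create_engine("mysql+pymysql://user:pass@localhost/db")
insp = inspect(engine)
for uq in insp.get_unique_constraints("user"):
    # on MySQL-like backends the reflected dict may carry
    # "duplicates_index", naming the index it doubles as
    if uq.get("duplicates_index"):
        print(f"{uq['name']} doubles as index {uq['duplicates_index']}")
```
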
@@ -601,7 +639,7 @@ def _compare_indexes_and_uniques(
    # 3. give the dialect a chance to omit indexes and constraints that
    # we know are either added implicitly by the DB or that the DB
    # can't accurately report on
    impl.correct_for_autogen_constraints(
    autogen_context.migration_context.impl.correct_for_autogen_constraints(
        conn_uniques,  # type: ignore[arg-type]
        conn_indexes,  # type: ignore[arg-type]
        metadata_unique_constraints,
@@ -613,31 +651,31 @@ def _compare_indexes_and_uniques(
    # Index and UniqueConstraint so we can easily work with them
    # interchangeably
    metadata_unique_constraints_sig = {
        impl._create_metadata_constraint_sig(uq)
        for uq in metadata_unique_constraints
        _uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints
    }

    metadata_indexes_sig = {
        impl._create_metadata_constraint_sig(ix) for ix in metadata_indexes
        _ix_constraint_sig(ix, impl) for ix in metadata_indexes
    }

    conn_unique_constraints = {
        impl._create_reflected_constraint_sig(uq) for uq in conn_uniques
        _uq_constraint_sig(uq, impl) for uq in conn_uniques
    }

    conn_indexes_sig = {
        impl._create_reflected_constraint_sig(ix) for ix in conn_indexes
    }
    conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}

    # 5. index things by name, for those objects that have names
    metadata_names = {
        cast(str, c.md_name_to_sql_name(autogen_context)): c
        for c in metadata_unique_constraints_sig.union(metadata_indexes_sig)
        if c.is_named
        for c in metadata_unique_constraints_sig.union(
            metadata_indexes_sig  # type:ignore[arg-type]
        )
        if isinstance(c, _ix_constraint_sig)
        or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
    }

    conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
    conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
    conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig]
    conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig]

    conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
    conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
@@ -656,12 +694,13 @@ def _compare_indexes_and_uniques(

    # 6. index things by "column signature", to help with unnamed unique
    # constraints.
    conn_uniques_by_sig = {uq.unnamed: uq for uq in conn_unique_constraints}
    conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints}
    metadata_uniques_by_sig = {
        uq.unnamed: uq for uq in metadata_unique_constraints_sig
        uq.sig: uq for uq in metadata_unique_constraints_sig
    }
    metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig}
    unnamed_metadata_uniques = {
        uq.unnamed: uq
        uq.sig: uq
        for uq in metadata_unique_constraints_sig
        if not sqla_compat._constraint_is_named(
            uq.const, autogen_context.dialect
@@ -676,18 +715,18 @@ def _compare_indexes_and_uniques(
    # 4. The backend may double up indexes as unique constraints and
    # vice versa (e.g. MySQL, Postgresql)

    def obj_added(obj: _constraint_sig):
        if is_index_sig(obj):
    def obj_added(obj):
        if obj.is_index:
            if autogen_context.run_object_filters(
                obj.const, obj.name, "index", False, None
            ):
                modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const))
                log.info(
                    "Detected added index %r on '%s'",
                    "Detected added index '%s' on %s",
                    obj.name,
                    obj.column_names,
                    ", ".join(["'%s'" % obj.column_names]),
                )
        elif is_uq_sig(obj):
        else:
            if not supports_unique_constraints:
                # can't report unique indexes as added if we don't
                # detect them
@@ -702,15 +741,13 @@ def _compare_indexes_and_uniques(
                    ops.AddConstraintOp.from_constraint(obj.const)
                )
                log.info(
                    "Detected added unique constraint %r on '%s'",
                    "Detected added unique constraint '%s' on %s",
                    obj.name,
                    obj.column_names,
                    ", ".join(["'%s'" % obj.column_names]),
                )
        else:
            assert False

    def obj_removed(obj: _constraint_sig):
        if is_index_sig(obj):
    def obj_removed(obj):
        if obj.is_index:
            if obj.is_unique and not supports_unique_constraints:
                # many databases double up unique constraints
                # as unique indexes. without that list we can't
@@ -721,8 +758,10 @@ def _compare_indexes_and_uniques(
                obj.const, obj.name, "index", True, None
            ):
                modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const))
                log.info("Detected removed index %r on %r", obj.name, tname)
        elif is_uq_sig(obj):
                log.info(
                    "Detected removed index '%s' on '%s'", obj.name, tname
                )
        else:
            if is_create_table or is_drop_table:
                # if the whole table is being dropped, we don't need to
                # consider unique constraint separately
@@ -734,40 +773,33 @@ def _compare_indexes_and_uniques(
                    ops.DropConstraintOp.from_constraint(obj.const)
                )
                log.info(
                    "Detected removed unique constraint %r on %r",
                    "Detected removed unique constraint '%s' on '%s'",
                    obj.name,
                    tname,
                )
        else:
            assert False

    def obj_changed(
        old: _constraint_sig,
        new: _constraint_sig,
        msg: str,
    ):
        if is_index_sig(old):
            assert is_index_sig(new)

    def obj_changed(old, new, msg):
        if old.is_index:
            if autogen_context.run_object_filters(
                new.const, new.name, "index", False, old.const
            ):
                log.info(
                    "Detected changed index %r on %r: %s", old.name, tname, msg
                    "Detected changed index '%s' on '%s':%s",
                    old.name,
                    tname,
                    ", ".join(msg),
                )
                modify_ops.ops.append(ops.DropIndexOp.from_index(old.const))
                modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const))
        elif is_uq_sig(old):
            assert is_uq_sig(new)

        else:
            if autogen_context.run_object_filters(
                new.const, new.name, "unique_constraint", False, old.const
            ):
                log.info(
                    "Detected changed unique constraint %r on %r: %s",
                    "Detected changed unique constraint '%s' on '%s':%s",
                    old.name,
                    tname,
                    msg,
                    ", ".join(msg),
                )
                modify_ops.ops.append(
                    ops.DropConstraintOp.from_constraint(old.const)
@@ -775,24 +807,18 @@ def _compare_indexes_and_uniques(
                modify_ops.ops.append(
                    ops.AddConstraintOp.from_constraint(new.const)
                )
        else:
            assert False

    for removed_name in sorted(set(conn_names).difference(metadata_names)):
        conn_obj = conn_names[removed_name]
        if (
            is_uq_sig(conn_obj)
            and conn_obj.unnamed in unnamed_metadata_uniques
        ):
        conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
            removed_name
        ]
        if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
            continue
        elif removed_name in doubled_constraints:
            conn_uq, conn_idx = doubled_constraints[removed_name]
            if (
                all(
                    conn_idx.unnamed != meta_idx.unnamed
                    for meta_idx in metadata_indexes_sig
                )
                and conn_uq.unnamed not in metadata_uniques_by_sig
                conn_idx.sig not in metadata_indexes_by_sig
                and conn_uq.sig not in metadata_uniques_by_sig
            ):
                obj_removed(conn_uq)
                obj_removed(conn_idx)
@@ -804,36 +830,30 @@ def _compare_indexes_and_uniques(

        if existing_name in doubled_constraints:
            conn_uq, conn_idx = doubled_constraints[existing_name]
            if is_index_sig(metadata_obj):
            if metadata_obj.is_index:
                conn_obj = conn_idx
            else:
                conn_obj = conn_uq
        else:
            conn_obj = conn_names[existing_name]

        if type(conn_obj) != type(metadata_obj):
        if conn_obj.is_index != metadata_obj.is_index:
            obj_removed(conn_obj)
            obj_added(metadata_obj)
        else:
            comparison = metadata_obj.compare_to_reflected(conn_obj)
            msg = []
            if conn_obj.is_unique != metadata_obj.is_unique:
                msg.append(
                    " unique=%r to unique=%r"
                    % (conn_obj.is_unique, metadata_obj.is_unique)
                )
            if conn_obj.sig != metadata_obj.sig:
                msg.append(
                    " expression %r to %r" % (conn_obj.sig, metadata_obj.sig)
                )

            if comparison.is_different:
                # constraint are different
                obj_changed(conn_obj, metadata_obj, comparison.message)
            elif comparison.is_skip:
                # constraint cannot be compared, skip them
                thing = (
                    "index" if is_index_sig(conn_obj) else "unique constraint"
                )
                log.info(
                    "Cannot compare %s %r, assuming equal and skipping. %s",
                    thing,
                    conn_obj.name,
                    comparison.message,
                )
            else:
                # constraint are equal
                assert comparison.is_equal
                if msg:
                    obj_changed(conn_obj, metadata_obj, msg)

    for added_name in sorted(set(metadata_names).difference(conn_names)):
        obj = metadata_names[added_name]
@@ -873,7 +893,7 @@ def _correct_for_uq_duplicates_uix(
    }

    unnamed_metadata_uqs = {
        impl._create_metadata_constraint_sig(cons).unnamed
        _uq_constraint_sig(cons, impl).sig
        for name, cons in metadata_cons_names
        if name is None
    }
@@ -897,9 +917,7 @@ def _correct_for_uq_duplicates_uix(
    for overlap in uqs_dupe_indexes:
        if overlap not in metadata_uq_names:
            if (
                impl._create_reflected_constraint_sig(
                    uqs_dupe_indexes[overlap]
                ).unnamed
                _uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig
                not in unnamed_metadata_uqs
            ):
                conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
@@ -1035,7 +1053,7 @@ def _normalize_computed_default(sqltext: str) -> str:

    """

    return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()
    return re.sub(r"[ \(\)'\"`\[\]]", "", sqltext).lower()


def _compare_computed_default(

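Editor's note: `_normalize_computed_default()` strips quoting, parentheses, brackets, and (in the newer variant) whitespace characters so that semantically identical computed-column expressions reflected from different backends compare equal. A worked example:

```python
import re

def normalize(sqltext: str) -> str:
    # same idea as _normalize_computed_default(), using the newer
    # character class that also removes \t \r \n
    return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()

# MySQL-style and plain spellings reduce to the same normalized form
assert normalize("(`price` * 2)") == normalize("price * 2")  # 'price*2'
```
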
@@ -1119,15 +1137,27 @@ def _compare_server_default(
        return False

    if sqla_compat._server_default_is_computed(metadata_default):
        return _compare_computed_default(  # type:ignore[func-returns-value]
            autogen_context,
            alter_column_op,
            schema,
            tname,
            cname,
            conn_col,
            metadata_col,
        )
        # return False in case of a computed column as the server
        # default. Note that DDL for adding or removing "GENERATED AS" from
        # an existing column is not currently known for any backend.
        # Once SQLAlchemy can reflect "GENERATED" as the "computed" element,
        # we would also want to ignore and/or warn for changes vs. the
        # metadata (or support backend specific DDL if applicable).
        if not sqla_compat.has_computed_reflection:
            return False

        else:
            return (
                _compare_computed_default(  # type:ignore[func-returns-value]
                    autogen_context,
                    alter_column_op,
                    schema,
                    tname,
                    cname,
                    conn_col,
                    metadata_col,
                )
            )
    if sqla_compat._server_default_is_computed(conn_col_default):
        _warn_computed_not_supported(tname, cname)
        return False
@@ -1213,8 +1243,8 @@ def _compare_foreign_keys(
    modify_table_ops: ModifyTableOps,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    conn_table: Table,
    metadata_table: Table,
    conn_table: Optional[Table],
    metadata_table: Optional[Table],
) -> None:
    # if we're doing CREATE TABLE, all FKs are created
    # inline within the table def
@@ -1230,9 +1260,7 @@ def _compare_foreign_keys(

    conn_fks_list = [
        fk
        for fk in _InspectorConv(inspector).get_foreign_keys(
            tname, schema=schema
        )
        for fk in inspector.get_foreign_keys(tname, schema=schema)
        if autogen_context.run_name_filters(
            fk["name"],
            "foreign_key_constraint",
@@ -1240,11 +1268,14 @@ def _compare_foreign_keys(
        )
    ]

    conn_fks = {
        _make_foreign_key(const, conn_table) for const in conn_fks_list
    }
    backend_reflects_fk_options = bool(
        conn_fks_list and "options" in conn_fks_list[0]
    )

    impl = autogen_context.migration_context.impl
    conn_fks = {
        _make_foreign_key(const, conn_table)  # type: ignore[arg-type]
        for const in conn_fks_list
    }

    # give the dialect a chance to correct the FKs to match more
    # closely
@@ -1253,24 +1284,17 @@ def _compare_foreign_keys(
    )

    metadata_fks_sig = {
        impl._create_metadata_constraint_sig(fk) for fk in metadata_fks
        _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
        for fk in metadata_fks
    }

    conn_fks_sig = {
        impl._create_reflected_constraint_sig(fk) for fk in conn_fks
        _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
        for fk in conn_fks
    }

    # check if reflected FKs include options, indicating the backend
    # can reflect FK options
    if conn_fks_list and "options" in conn_fks_list[0]:
        conn_fks_by_sig = {c.unnamed: c for c in conn_fks_sig}
        metadata_fks_by_sig = {c.unnamed: c for c in metadata_fks_sig}
    else:
        # otherwise compare by sig without options added
        conn_fks_by_sig = {c.unnamed_no_options: c for c in conn_fks_sig}
        metadata_fks_by_sig = {
            c.unnamed_no_options: c for c in metadata_fks_sig
        }
    conn_fks_by_sig = {c.sig: c for c in conn_fks_sig}
    metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig}

    metadata_fks_by_name = {
        c.name: c for c in metadata_fks_sig if c.name is not None

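Editor's note: whether FK options (ON UPDATE/ON DELETE, deferrability) participate in the comparison depends on whether the backend reflects them at all, and the probe is simply checking the first reflected row for an "options" key. A hedged sketch of that probe in isolation (URL and table name hypothetical):

```python
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql://user:pass@localhost/db")
insp = inspect(engine)
fks = insp.get_foreign_keys("order_item")

# if the dialect reflects FK options at all, every reflected row
# carries the key, so probing the first row is sufficient
backend_reflects_fk_options = bool(fks and "options" in fks[0])
```
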
@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from io import StringIO
@@ -18,9 +15,7 @@ from mako.pygen import PythonPrinter
from sqlalchemy import schema as sa_schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.base import _DialectArgView
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import Label
from sqlalchemy.sql.elements import quoted_name

from .. import util
@@ -30,8 +25,7 @@ from ..util import sqla_compat
if TYPE_CHECKING:
    from typing import Literal

    from sqlalchemy import Computed
    from sqlalchemy import Identity
    from sqlalchemy.sql.base import DialectKWArgs
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.elements import TextClause
    from sqlalchemy.sql.schema import CheckConstraint
@@ -51,6 +45,8 @@ if TYPE_CHECKING:
    from alembic.config import Config
    from alembic.operations.ops import MigrationScript
    from alembic.operations.ops import ModifyTableOps
    from alembic.util.sqla_compat import Computed
    from alembic.util.sqla_compat import Identity


MAX_PYTHON_ARGS = 255
@@ -168,31 +164,21 @@ def _render_modify_table(
def _render_create_table_comment(
    autogen_context: AutogenContext, op: ops.CreateTableCommentOp
) -> str:
    if autogen_context._has_batch:
        templ = (
            "{prefix}create_table_comment(\n"
            "{indent}{comment},\n"
            "{indent}existing_comment={existing}\n"
            ")"
        )
    else:
        templ = (
            "{prefix}create_table_comment(\n"
            "{indent}'{tname}',\n"
            "{indent}{comment},\n"
            "{indent}existing_comment={existing},\n"
            "{indent}schema={schema}\n"
            ")"
        )
    templ = (
        "{prefix}create_table_comment(\n"
        "{indent}'{tname}',\n"
        "{indent}{comment},\n"
        "{indent}existing_comment={existing},\n"
        "{indent}schema={schema}\n"
        ")"
    )
    return templ.format(
        prefix=_alembic_autogenerate_prefix(autogen_context),
        tname=op.table_name,
        comment="%r" % op.comment if op.comment is not None else None,
        existing=(
            "%r" % op.existing_comment
            if op.existing_comment is not None
            else None
        ),
        existing="%r" % op.existing_comment
        if op.existing_comment is not None
        else None,
        schema="'%s'" % op.schema if op.schema is not None else None,
        indent="    ",
    )

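Editor's note: for reference, the non-batch template above renders op text along these lines in a generated revision file (table name and comment are hypothetical values):

```python
op.create_table_comment(
    'account',
    'user accounts',
    existing_comment=None,
    schema=None
)
```
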
@@ -202,28 +188,19 @@ def _render_create_table_comment(
def _render_drop_table_comment(
    autogen_context: AutogenContext, op: ops.DropTableCommentOp
) -> str:
    if autogen_context._has_batch:
        templ = (
            "{prefix}drop_table_comment(\n"
            "{indent}existing_comment={existing}\n"
            ")"
        )
    else:
        templ = (
            "{prefix}drop_table_comment(\n"
            "{indent}'{tname}',\n"
            "{indent}existing_comment={existing},\n"
            "{indent}schema={schema}\n"
            ")"
        )
    templ = (
        "{prefix}drop_table_comment(\n"
        "{indent}'{tname}',\n"
        "{indent}existing_comment={existing},\n"
        "{indent}schema={schema}\n"
        ")"
    )
    return templ.format(
        prefix=_alembic_autogenerate_prefix(autogen_context),
        tname=op.table_name,
        existing=(
            "%r" % op.existing_comment
            if op.existing_comment is not None
            else None
        ),
        existing="%r" % op.existing_comment
        if op.existing_comment is not None
        else None,
        schema="'%s'" % op.schema if op.schema is not None else None,
        indent="    ",
    )
@@ -280,9 +257,6 @@ def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str:
        prefixes = ", ".join("'%s'" % p for p in table._prefixes)
        text += ",\nprefixes=[%s]" % prefixes

    if op.if_not_exists is not None:
        text += ",\nif_not_exists=%r" % bool(op.if_not_exists)

    text += "\n)"
    return text

@@ -295,20 +269,16 @@ def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str:
    }
    if op.schema:
        text += ", schema=%r" % _ident(op.schema)

    if op.if_exists is not None:
        text += ", if_exists=%r" % bool(op.if_exists)

    text += ")"
    return text


def _render_dialect_kwargs_items(
    autogen_context: AutogenContext, dialect_kwargs: _DialectArgView
    autogen_context: AutogenContext, item: DialectKWArgs
) -> list[str]:
    return [
        f"{key}={_render_potential_expr(val, autogen_context)}"
        for key, val in dialect_kwargs.items()
        for key, val in item.dialect_kwargs.items()
    ]


@@ -331,9 +301,7 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:

    assert index.table is not None

    opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs)
    if op.if_not_exists is not None:
        opts.append("if_not_exists=%r" % bool(op.if_not_exists))
    opts = _render_dialect_kwargs_items(autogen_context, index)
    text = tmpl % {
        "prefix": _alembic_autogenerate_prefix(autogen_context),
        "name": _render_gen_name(autogen_context, index.name),
@@ -342,11 +310,9 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
            _get_index_rendered_expressions(index, autogen_context)
        ),
        "unique": index.unique or False,
        "schema": (
            (", schema=%r" % _ident(index.table.schema))
            if index.table.schema
            else ""
        ),
        "schema": (", schema=%r" % _ident(index.table.schema))
        if index.table.schema
        else "",
        "kwargs": ", " + ", ".join(opts) if opts else "",
    }
    return text
@@ -365,9 +331,7 @@ def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str:
        "%(prefix)sdrop_index(%(name)r, "
        "table_name=%(table_name)r%(schema)s%(kwargs)s)"
    )
    opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs)
    if op.if_exists is not None:
        opts.append("if_exists=%r" % bool(op.if_exists))
    opts = _render_dialect_kwargs_items(autogen_context, index)
    text = tmpl % {
        "prefix": _alembic_autogenerate_prefix(autogen_context),
        "name": _render_gen_name(autogen_context, op.index_name),
@@ -389,7 +353,6 @@ def _add_unique_constraint(
def _add_fk_constraint(
    autogen_context: AutogenContext, op: ops.CreateForeignKeyOp
) -> str:
    constraint = op.to_constraint()
    args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
    if not autogen_context._has_batch:
        args.append(repr(_ident(op.source_table)))
@@ -419,16 +382,9 @@ def _add_fk_constraint(
        if value is not None:
            args.append("%s=%r" % (k, value))

    dialect_kwargs = _render_dialect_kwargs_items(
        autogen_context, constraint.dialect_kwargs
    )

    return "%(prefix)screate_foreign_key(%(args)s%(dialect_kwargs)s)" % {
    return "%(prefix)screate_foreign_key(%(args)s)" % {
        "prefix": _alembic_autogenerate_prefix(autogen_context),
        "args": ", ".join(args),
        "dialect_kwargs": (
            ", " + ", ".join(dialect_kwargs) if dialect_kwargs else ""
        ),
    }


@@ -450,7 +406,7 @@ def _drop_constraint(
    name = _render_gen_name(autogen_context, op.constraint_name)
    schema = _ident(op.schema) if op.schema else None
    type_ = _ident(op.constraint_type) if op.constraint_type else None
    if_exists = op.if_exists

    params_strs = []
    params_strs.append(repr(name))
    if not autogen_context._has_batch:
@@ -459,47 +415,32 @@ def _drop_constraint(
        params_strs.append(f"schema={schema!r}")
    if type_ is not None:
        params_strs.append(f"type_={type_!r}")
    if if_exists is not None:
        params_strs.append(f"if_exists={if_exists}")

    return f"{prefix}drop_constraint({', '.join(params_strs)})"


@renderers.dispatch_for(ops.AddColumnOp)
def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str:
    schema, tname, column, if_not_exists = (
        op.schema,
        op.table_name,
        op.column,
        op.if_not_exists,
    )
    schema, tname, column = op.schema, op.table_name, op.column
    if autogen_context._has_batch:
        template = "%(prefix)sadd_column(%(column)s)"
    else:
        template = "%(prefix)sadd_column(%(tname)r, %(column)s"
        if schema:
            template += ", schema=%(schema)r"
        if if_not_exists is not None:
            template += ", if_not_exists=%(if_not_exists)r"
        template += ")"
    text = template % {
        "prefix": _alembic_autogenerate_prefix(autogen_context),
        "tname": tname,
        "column": _render_column(column, autogen_context),
        "schema": schema,
        "if_not_exists": if_not_exists,
    }
    return text


@renderers.dispatch_for(ops.DropColumnOp)
def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
    schema, tname, column_name, if_exists = (
        op.schema,
        op.table_name,
        op.column_name,
        op.if_exists,
    )
    schema, tname, column_name = op.schema, op.table_name, op.column_name

    if autogen_context._has_batch:
        template = "%(prefix)sdrop_column(%(cname)r)"
@@ -507,8 +448,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
        template = "%(prefix)sdrop_column(%(tname)r, %(cname)r"
        if schema:
            template += ", schema=%(schema)r"
        if if_exists is not None:
            template += ", if_exists=%(if_exists)r"
        template += ")"

    text = template % {
@@ -516,7 +455,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
        "tname": _ident(tname),
        "cname": _ident(column_name),
        "schema": _ident(schema),
        "if_exists": if_exists,
    }
    return text

@@ -531,7 +469,6 @@ def _alter_column(
    type_ = op.modify_type
    nullable = op.modify_nullable
    comment = op.modify_comment
    newname = op.modify_name
    autoincrement = op.kw.get("autoincrement", None)
    existing_type = op.existing_type
    existing_nullable = op.existing_nullable
@@ -560,8 +497,6 @@ def _alter_column(
        rendered = _render_server_default(server_default, autogen_context)
        text += ",\n%sserver_default=%s" % (indent, rendered)

    if newname is not None:
        text += ",\n%snew_column_name=%r" % (indent, newname)
    if type_ is not None:
        text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context))
    if nullable is not None:
@@ -614,28 +549,23 @@ def _render_potential_expr(
    value: Any,
    autogen_context: AutogenContext,
    *,
    wrap_in_element: bool = True,
    wrap_in_text: bool = True,
    is_server_default: bool = False,
    is_index: bool = False,
) -> str:
    if isinstance(value, sql.ClauseElement):
        sql_text = autogen_context.migration_context.impl.render_ddl_sql_expr(
            value, is_server_default=is_server_default, is_index=is_index
        )
        if wrap_in_element:
            prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
            element = "literal_column" if is_index else "text"
            value_str = f"{prefix}{element}({sql_text!r})"
            if (
                is_index
                and isinstance(value, Label)
                and type(value.name) is str
            ):
                return value_str + f".label({value.name!r})"
            else:
                return value_str
        if wrap_in_text:
            template = "%(prefix)stext(%(sql)r)"
        else:
            return repr(sql_text)
            template = "%(sql)r"

        return template % {
            "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
            "sql": autogen_context.migration_context.impl.render_ddl_sql_expr(
                value, is_server_default=is_server_default, is_index=is_index
            ),
        }

    else:
        return repr(value)

@@ -644,11 +574,9 @@ def _get_index_rendered_expressions(
    idx: Index, autogen_context: AutogenContext
) -> List[str]:
    return [
        (
            repr(_ident(getattr(exp, "name", None)))
            if isinstance(exp, sa_schema.Column)
            else _render_potential_expr(exp, autogen_context, is_index=True)
        )
        repr(_ident(getattr(exp, "name", None)))
        if isinstance(exp, sa_schema.Column)
        else _render_potential_expr(exp, autogen_context, is_index=True)
        for exp in idx.expressions
    ]

@@ -663,18 +591,16 @@ def _uq_constraint(
    has_batch = autogen_context._has_batch

    if constraint.deferrable:
        opts.append(("deferrable", constraint.deferrable))
        opts.append(("deferrable", str(constraint.deferrable)))
    if constraint.initially:
        opts.append(("initially", constraint.initially))
        opts.append(("initially", str(constraint.initially)))
    if not has_batch and alter and constraint.table.schema:
        opts.append(("schema", _ident(constraint.table.schema)))
    if not alter and constraint.name:
        opts.append(
            ("name", _render_gen_name(autogen_context, constraint.name))
        )
    dialect_options = _render_dialect_kwargs_items(
        autogen_context, constraint.dialect_kwargs
    )
    dialect_options = _render_dialect_kwargs_items(autogen_context, constraint)

    if alter:
        args = [repr(_render_gen_name(autogen_context, constraint.name))]
@@ -778,7 +704,7 @@ def _render_column(
            + [
                "%s=%s"
                % (key, _render_potential_expr(val, autogen_context))
                for key, val in column.kwargs.items()
                for key, val in sqla_compat._column_kwargs(column).items()
            ]
        )
    ),
@@ -813,8 +739,6 @@ def _render_server_default(
        return _render_potential_expr(
            default.arg, autogen_context, is_server_default=True
        )
    elif isinstance(default, sa_schema.FetchedValue):
        return _render_fetched_value(autogen_context)

    if isinstance(default, str) and repr_:
        default = repr(re.sub(r"^'|'$", "", default))
@@ -826,7 +750,7 @@ def _render_computed(
    computed: Computed, autogen_context: AutogenContext
) -> str:
    text = _render_potential_expr(
        computed.sqltext, autogen_context, wrap_in_element=False
        computed.sqltext, autogen_context, wrap_in_text=False
    )

    kwargs = {}
@@ -852,12 +776,6 @@ def _render_identity(
    }


def _render_fetched_value(autogen_context: AutogenContext) -> str:
    return "%(prefix)sFetchedValue()" % {
        "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
    }


def _repr_type(
    type_: TypeEngine,
    autogen_context: AutogenContext,
@@ -876,10 +794,7 @@ def _repr_type(

    mod = type(type_).__module__
    imports = autogen_context.imports

    if not _skip_variants and sqla_compat._type_has_variants(type_):
        return _render_Variant_type(type_, autogen_context)
    elif mod.startswith("sqlalchemy.dialects"):
    if mod.startswith("sqlalchemy.dialects"):
        match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
        assert match is not None
        dname = match.group(1)
@@ -891,6 +806,8 @@ def _repr_type(
            return "%s.%r" % (dname, type_)
        elif impl_rt:
            return impl_rt
    elif not _skip_variants and sqla_compat._type_has_variants(type_):
        return _render_Variant_type(type_, autogen_context)
    elif mod.startswith("sqlalchemy."):
        if "_render_%s_type" % type_.__visit_name__ in globals():
            fn = globals()["_render_%s_type" % type_.__visit_name__]
@@ -917,7 +834,7 @@ def _render_Variant_type(
) -> str:
    base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
    base = _repr_type(base_type, autogen_context, _skip_variants=True)
    assert base is not None and base is not False  # type: ignore[comparison-overlap] # noqa:E501
    assert base is not None and base is not False
    for dialect in sorted(variant_mapping):
        typ = variant_mapping[dialect]
        base += ".with_variant(%s, %r)" % (
@@ -1008,13 +925,13 @@ def _render_primary_key(
def _fk_colspec(
    fk: ForeignKey,
    metadata_schema: Optional[str],
    namespace_metadata: Optional[MetaData],
    namespace_metadata: MetaData,
) -> str:
    """Implement a 'safe' version of ForeignKey._get_colspec() that
    won't fail if the remote table can't be resolved.

    """
    colspec = fk._get_colspec()
    colspec = fk._get_colspec()  # type:ignore[attr-defined]
    tokens = colspec.split(".")
    tname, colname = tokens[-2:]

@@ -1032,10 +949,7 @@ def _fk_colspec(
    # the FK constraint needs to be rendered in terms of the column
    # name.

    if (
        namespace_metadata is not None
        and table_fullname in namespace_metadata.tables
    ):
    if table_fullname in namespace_metadata.tables:
        col = namespace_metadata.tables[table_fullname].c.get(colname)
        if col is not None:
            colname = _ident(col.name)  # type: ignore[assignment]
@@ -1066,7 +980,7 @@ def _populate_render_fk_opts(
def _render_foreign_key(
    constraint: ForeignKeyConstraint,
    autogen_context: AutogenContext,
    namespace_metadata: Optional[MetaData],
    namespace_metadata: MetaData,
) -> Optional[str]:
    rendered = _user_defined_render("foreign_key", constraint, autogen_context)
    if rendered is not False:
@@ -1080,16 +994,15 @@ def _render_foreign_key(

    _populate_render_fk_opts(constraint, opts)

    apply_metadata_schema = (
        namespace_metadata.schema if namespace_metadata is not None else None
    )
    apply_metadata_schema = namespace_metadata.schema
    return (
        "%(prefix)sForeignKeyConstraint([%(cols)s], "
        "[%(refcols)s], %(args)s)"
        % {
            "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
            "cols": ", ".join(
                repr(_ident(f.parent.name)) for f in constraint.elements
                "%r" % _ident(cast("Column", f.parent).name)
                for f in constraint.elements
            ),
            "refcols": ", ".join(
                repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@@ -1130,10 +1043,12 @@ def _render_check_constraint(
    # ideally SQLAlchemy would give us more of a first class
    # way to detect this.
    if (
        constraint._create_rule
        and hasattr(constraint._create_rule, "target")
        constraint._create_rule  # type:ignore[attr-defined]
        and hasattr(
            constraint._create_rule, "target"  # type:ignore[attr-defined]
        )
        and isinstance(
            constraint._create_rule.target,
            constraint._create_rule.target,  # type:ignore[attr-defined]
            sqltypes.TypeEngine,
        )
    ):
@@ -1145,13 +1060,11 @@ def _render_check_constraint(
    )
    return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % {
        "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
        "opts": (
            ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
            if opts
            else ""
        ),
        "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
        if opts
        else "",
        "sqltext": _render_potential_expr(
            constraint.sqltext, autogen_context, wrap_in_element=False
            constraint.sqltext, autogen_context, wrap_in_text=False
        ),
    }

@@ -1163,10 +1076,7 @@ def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str:
            "Autogenerate rendering of SQL Expression language constructs "
            "not supported here; please use a plain SQL string"
        )
    return "{prefix}execute({sqltext!r})".format(
        prefix=_alembic_autogenerate_prefix(autogen_context),
        sqltext=op.sqltext,
    )
    return "op.execute(%r)" % op.sqltext


renderers = default_renderers.branch()

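Editor's note: the `wrap_in_element` logic above decides whether a SQL expression appears in rendered autogenerate code as `sa.text(...)` or, for index expressions, `sa.literal_column(...)`. In a generated revision this looks roughly like the following (table, column, and index names hypothetical):

```python
# what the renderer emits for an expression index, approximately
op.create_index(
    "ix_user_email_lower",
    "user",
    [sa.literal_column("lower(email)")],
    unique=False,
)

# ...versus a server default, rendered as sa.text()
sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"))
```
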
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
@@ -16,18 +16,12 @@ if TYPE_CHECKING:
|
||||
from ..operations.ops import AddColumnOp
|
||||
from ..operations.ops import AlterColumnOp
|
||||
from ..operations.ops import CreateTableOp
|
||||
from ..operations.ops import DowngradeOps
|
||||
from ..operations.ops import MigrateOperation
|
||||
from ..operations.ops import MigrationScript
|
||||
from ..operations.ops import ModifyTableOps
|
||||
from ..operations.ops import OpContainer
|
||||
from ..operations.ops import UpgradeOps
|
||||
from ..runtime.environment import _GetRevArg
|
||||
from ..runtime.migration import MigrationContext
|
||||
from ..script.revision import _GetRevArg
|
||||
|
||||
ProcessRevisionDirectiveFn = Callable[
|
||||
["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
|
||||
]
|
||||
|
||||
|
||||
class Rewriter:
|
||||
@@ -58,21 +52,15 @@ class Rewriter:
|
||||
|
||||
_traverse = util.Dispatcher()
|
||||
|
||||
_chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
|
||||
_chained: Optional[Rewriter] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.dispatch = util.Dispatcher()
|
||||
|
||||
def chain(
|
||||
self,
|
||||
other: Union[
|
||||
ProcessRevisionDirectiveFn,
|
||||
Rewriter,
|
||||
],
|
||||
) -> Rewriter:
|
||||
def chain(self, other: Rewriter) -> Rewriter:
|
||||
"""Produce a "chain" of this :class:`.Rewriter` to another.
|
||||
|
||||
This allows two or more rewriters to operate serially on a stream,
|
||||
This allows two rewriters to operate serially on a stream,
|
||||
e.g.::
|
||||
|
||||
writer1 = autogenerate.Rewriter()
|
||||
@@ -101,7 +89,7 @@ class Rewriter:
|
||||
"""
|
||||
wr = self.__class__.__new__(self.__class__)
|
||||
wr.__dict__.update(self.__dict__)
|
||||
wr._chained += (other,)
|
||||
wr._chained = other
|
||||
return wr
|
||||
|
||||
def rewrites(
|
||||
@@ -113,7 +101,7 @@ class Rewriter:
|
||||
Type[CreateTableOp],
|
||||
Type[ModifyTableOps],
|
||||
],
|
||||
) -> Callable[..., Any]:
|
||||
) -> Callable:
|
||||
"""Register a function as rewriter for a given type.
|
||||
|
||||
The function should receive three arguments, which are
|
||||
@@ -158,8 +146,8 @@ class Rewriter:
|
||||
directives: List[MigrationScript],
|
||||
) -> None:
|
||||
self.process_revision_directives(context, revision, directives)
|
||||
for process_revision_directives in self._chained:
|
||||
process_revision_directives(context, revision, directives)
|
||||
if self._chained:
|
||||
self._chained(context, revision, directives)
|
||||
|
||||
@_traverse.dispatch_for(ops.MigrationScript)
|
||||
def _traverse_script(
|
||||
@@ -168,7 +156,7 @@ class Rewriter:
|
||||
revision: _GetRevArg,
|
||||
directive: MigrationScript,
|
||||
) -> None:
|
||||
upgrade_ops_list: List[UpgradeOps] = []
|
||||
upgrade_ops_list = []
|
||||
for upgrade_ops in directive.upgrade_ops_list:
|
||||
ret = self._traverse_for(context, revision, upgrade_ops)
|
||||
if len(ret) != 1:
|
||||
@@ -176,10 +164,9 @@ class Rewriter:
|
||||
"Can only return single object for UpgradeOps traverse"
|
||||
)
|
||||
upgrade_ops_list.append(ret[0])
|
||||
|
||||
directive.upgrade_ops = upgrade_ops_list
|
||||
|
||||
downgrade_ops_list: List[DowngradeOps] = []
|
||||
downgrade_ops_list = []
|
||||
for downgrade_ops in directive.downgrade_ops_list:
|
||||
ret = self._traverse_for(context, revision, downgrade_ops)
|
||||
if len(ret) != 1:
|
||||
|
||||
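Editor's note: the single-successor `_chained: Optional[Rewriter]` grew into a tuple so that more than two rewriters, or plain `process_revision_directives` callables, can be composed. Usage per the Alembic documentation:

```python
from alembic import autogenerate

writer1 = autogenerate.Rewriter()
writer2 = autogenerate.Rewriter()

# writer1 runs first, then hands its directives to writer2
writer = writer1.chain(writer2)

# an env.py would then pass the composite along:
# context.configure(..., process_revision_directives=writer)
```
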
@@ -1,9 +1,6 @@
# mypy: allow-untyped-defs, allow-untyped-calls

from __future__ import annotations

import os
import pathlib
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
@@ -13,7 +10,6 @@ from . import autogenerate as autogen
from . import util
from .runtime.environment import EnvironmentContext
from .script import ScriptDirectory
from .util import compat

if TYPE_CHECKING:
    from alembic.config import Config
@@ -22,7 +18,7 @@ if TYPE_CHECKING:
    from .runtime.environment import ProcessRevisionDirectiveFn


def list_templates(config: Config) -> None:
def list_templates(config: Config):
    """List available templates.

    :param config: a :class:`.Config` object.
@@ -30,10 +26,12 @@ def list_templates(config: Config) -> None:
    """

    config.print_stdout("Available templates:\n")
    for tempname in config._get_template_path().iterdir():
        with (tempname / "README").open() as readme:
    for tempname in os.listdir(config.get_template_directory()):
        with open(
            os.path.join(config.get_template_directory(), tempname, "README")
        ) as readme:
            synopsis = next(readme).rstrip()
        config.print_stdout("%s - %s", tempname.name, synopsis)
        config.print_stdout("%s - %s", tempname, synopsis)

    config.print_stdout("\nTemplates are used via the 'init' command, e.g.:")
    config.print_stdout("\n  alembic init --template generic ./scripts")

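Editor's note: the pathlib variant walks the template directory and reads the first README line of each entry as its synopsis. The same pattern in isolation (directory layout hypothetical):

```python
import pathlib

def list_template_synopses(template_root: pathlib.Path) -> None:
    # each template directory is expected to contain a README whose
    # first line is a one-line synopsis
    for tempname in template_root.iterdir():
        with (tempname / "README").open() as readme:
            synopsis = next(readme).rstrip()
        print(f"{tempname.name} - {synopsis}")

list_template_synopses(pathlib.Path("alembic/templates"))  # hypothetical path
```
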
@@ -49,7 +47,7 @@ def init(

    :param config: a :class:`.Config` object.

    :param directory: string path of the target directory.
    :param directory: string path of the target directory

    :param template: string name of the migration environment template to
     use.
@@ -59,136 +57,65 @@ def init(

    """

    directory_path = pathlib.Path(directory)
    if directory_path.exists() and list(directory_path.iterdir()):
    if os.access(directory, os.F_OK) and os.listdir(directory):
        raise util.CommandError(
            "Directory %s already exists and is not empty" % directory_path
            "Directory %s already exists and is not empty" % directory
        )

    template_path = config._get_template_path() / template
    template_dir = os.path.join(config.get_template_directory(), template)
    if not os.access(template_dir, os.F_OK):
        raise util.CommandError("No such template %r" % template)

    if not template_path.exists():
        raise util.CommandError(f"No such template {template_path}")

    # left as os.access() to suit unit test mocking
    if not os.access(directory_path, os.F_OK):
    if not os.access(directory, os.F_OK):
        with util.status(
            f"Creating directory {directory_path.absolute()}",
            f"Creating directory {os.path.abspath(directory)!r}",
            **config.messaging_opts,
        ):
            os.makedirs(directory_path)
            os.makedirs(directory)

    versions = directory_path / "versions"
    versions = os.path.join(directory, "versions")
    with util.status(
        f"Creating directory {versions.absolute()}",
        f"Creating directory {os.path.abspath(versions)!r}",
        **config.messaging_opts,
    ):
        os.makedirs(versions)

    if not directory_path.is_absolute():
        # for non-absolute path, state config file in .ini / pyproject
        # as relative to the %(here)s token, which is where the config
        # file itself would be
    script = ScriptDirectory(directory)

        if config._config_file_path is not None:
            rel_dir = compat.path_relative_to(
                directory_path.absolute(),
                config._config_file_path.absolute().parent,
                walk_up=True,
            )
            ini_script_location_directory = ("%(here)s" / rel_dir).as_posix()
        if config._toml_file_path is not None:
            rel_dir = compat.path_relative_to(
                directory_path.absolute(),
                config._toml_file_path.absolute().parent,
                walk_up=True,
            )
            toml_script_location_directory = ("%(here)s" / rel_dir).as_posix()

    else:
        ini_script_location_directory = directory_path.as_posix()
        toml_script_location_directory = directory_path.as_posix()

    script = ScriptDirectory(directory_path)

    has_toml = False

    config_file: pathlib.Path | None = None

    for file_path in template_path.iterdir():
        file_ = file_path.name
    config_file: str | None = None
    for file_ in os.listdir(template_dir):
        file_path = os.path.join(template_dir, file_)
        if file_ == "alembic.ini.mako":
            assert config.config_file_name is not None
            config_file = pathlib.Path(config.config_file_name).absolute()
            if config_file.exists():
            config_file = os.path.abspath(config.config_file_name)
            if os.access(config_file, os.F_OK):
                util.msg(
                    f"File {config_file} already exists, skipping",
                    f"File {config_file!r} already exists, skipping",
                    **config.messaging_opts,
                )
            else:
                script._generate_template(
                    file_path,
                    config_file,
                    script_location=ini_script_location_directory,
                    file_path, config_file, script_location=directory
                )
        elif file_ == "pyproject.toml.mako":
            has_toml = True
            assert config._toml_file_path is not None
            toml_path = config._toml_file_path.absolute()

            if toml_path.exists():
                # left as open() to suit unit test mocking
                with open(toml_path, "rb") as f:
                    toml_data = compat.tomllib.load(f)
                    if "tool" in toml_data and "alembic" in toml_data["tool"]:

                        util.msg(
                            f"File {toml_path} already exists "
                            "and already has a [tool.alembic] section, "
                            "skipping",
                        )
                        continue
                script._append_template(
                    file_path,
                    toml_path,
                    script_location=toml_script_location_directory,
                )
            else:
                script._generate_template(
                    file_path,
                    toml_path,
                    script_location=toml_script_location_directory,
                )

        elif file_path.is_file():
            output_file = directory_path / file_
        elif os.path.isfile(file_path):
            output_file = os.path.join(directory, file_)
            script._copy_file(file_path, output_file)

    if package:
        for path in [
            directory_path.absolute() / "__init__.py",
            versions.absolute() / "__init__.py",
            os.path.join(os.path.abspath(directory), "__init__.py"),
            os.path.join(os.path.abspath(versions), "__init__.py"),
        ]:
            with util.status(f"Adding {path!s}", **config.messaging_opts):
            # left as open() to suit unit test mocking
            with util.status(f"Adding {path!r}", **config.messaging_opts):
|
||||
with open(path, "w"):
|
||||
pass
|
||||
|
||||
assert config_file is not None
|
||||
|
||||
if has_toml:
|
||||
util.msg(
|
||||
f"Please edit configuration settings in {toml_path} and "
|
||||
"configuration/connection/logging "
|
||||
f"settings in {config_file} before proceeding.",
|
||||
**config.messaging_opts,
|
||||
)
|
||||
else:
|
||||
util.msg(
|
||||
"Please edit configuration/connection/logging "
|
||||
f"settings in {config_file} before proceeding.",
|
||||
**config.messaging_opts,
|
||||
)
|
||||
util.msg(
|
||||
"Please edit configuration/connection/logging "
|
||||
f"settings in {config_file!r} before proceeding.",
|
||||
**config.messaging_opts,
|
||||
)
|
||||
|
||||
|
||||
def revision(
|
||||
@@ -199,7 +126,7 @@ def revision(
|
||||
head: str = "head",
|
||||
splice: bool = False,
|
||||
branch_label: Optional[_RevIdType] = None,
|
||||
version_path: Union[str, os.PathLike[str], None] = None,
|
||||
version_path: Optional[str] = None,
|
||||
rev_id: Optional[str] = None,
|
||||
depends_on: Optional[str] = None,
|
||||
process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
|
||||
@@ -245,7 +172,7 @@ def revision(
|
||||
will be applied to the structure generated by the revision process
|
||||
where it can be altered programmatically. Note that unlike all
|
||||
the other parameters, this option is only available via programmatic
|
||||
use of :func:`.command.revision`.
|
||||
use of :func:`.command.revision`
|
||||
|
||||
"""
|
||||
|
||||
@@ -269,9 +196,7 @@ def revision(
|
||||
process_revision_directives=process_revision_directives,
|
||||
)
|
||||
|
||||
environment = util.asbool(
|
||||
config.get_alembic_option("revision_environment")
|
||||
)
|
||||
environment = util.asbool(config.get_main_option("revision_environment"))
|
||||
|
||||
if autogenerate:
|
||||
environment = True
|
||||
@@ -365,15 +290,10 @@ def check(config: "Config") -> None:
|
||||
# the revision_context now has MigrationScript structure(s) present.
|
||||
|
||||
migration_script = revision_context.generated_revisions[-1]
|
||||
diffs = []
|
||||
for upgrade_ops in migration_script.upgrade_ops_list:
|
||||
diffs.extend(upgrade_ops.as_diffs())
|
||||
|
||||
diffs = migration_script.upgrade_ops.as_diffs()
|
||||
if diffs:
|
||||
raise util.AutogenerateDiffsDetected(
|
||||
f"New upgrade operations detected: {diffs}",
|
||||
revision_context=revision_context,
|
||||
diffs=diffs,
|
||||
f"New upgrade operations detected: {diffs}"
|
||||
)
|
||||
else:
|
||||
config.print_stdout("No new upgrade operations detected.")
|
||||
@@ -390,11 +310,9 @@ def merge(
|
||||
|
||||
:param config: a :class:`.Config` instance
|
||||
|
||||
:param revisions: The revisions to merge.
|
||||
:param message: string message to apply to the revision
|
||||
|
||||
:param message: string message to apply to the revision.
|
||||
|
||||
:param branch_label: string label name to apply to the new revision.
|
||||
:param branch_label: string label name to apply to the new revision
|
||||
|
||||
:param rev_id: hardcoded revision identifier instead of generating a new
|
||||
one.
|
||||
@@ -411,9 +329,7 @@ def merge(
|
||||
# e.g. multiple databases
|
||||
}
|
||||
|
||||
environment = util.asbool(
|
||||
config.get_alembic_option("revision_environment")
|
||||
)
|
||||
environment = util.asbool(config.get_main_option("revision_environment"))
|
||||
|
||||
if environment:
|
||||
|
||||
@@ -449,10 +365,9 @@ def upgrade(
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param revision: string revision target or range for --sql mode. May be
|
||||
``"heads"`` to target the most recent revision(s).
|
||||
:param revision: string revision target or range for --sql mode
|
||||
|
||||
:param sql: if True, use ``--sql`` mode.
|
||||
:param sql: if True, use ``--sql`` mode
|
||||
|
||||
:param tag: an arbitrary "tag" that can be intercepted by custom
|
||||
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
|
||||
@@ -493,10 +408,9 @@ def downgrade(
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param revision: string revision target or range for --sql mode. May
|
||||
be ``"base"`` to target the first revision.
|
||||
:param revision: string revision target or range for --sql mode
|
||||
|
||||
:param sql: if True, use ``--sql`` mode.
|
||||
:param sql: if True, use ``--sql`` mode
|
||||
|
||||
:param tag: an arbitrary "tag" that can be intercepted by custom
|
||||
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
|
||||
@@ -530,13 +444,12 @@ def downgrade(
|
||||
script.run_env()
|
||||
|
||||
|
||||
def show(config: Config, rev: str) -> None:
|
||||
def show(config, rev):
|
||||
"""Show the revision(s) denoted by the given symbol.
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param rev: string revision target. May be ``"current"`` to show the
|
||||
revision(s) currently applied in the database.
|
||||
:param revision: string revision target
|
||||
|
||||
"""
|
||||
|
||||
@@ -566,7 +479,7 @@ def history(
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param rev_range: string revision range.
|
||||
:param rev_range: string revision range
|
||||
|
||||
:param verbose: output in verbose mode.
|
||||
|
||||
@@ -586,7 +499,7 @@ def history(
|
||||
base = head = None
|
||||
|
||||
environment = (
|
||||
util.asbool(config.get_alembic_option("revision_environment"))
|
||||
util.asbool(config.get_main_option("revision_environment"))
|
||||
or indicate_current
|
||||
)
|
||||
|
||||
@@ -625,9 +538,7 @@ def history(
|
||||
_display_history(config, script, base, head)
|
||||
|
||||
|
||||
def heads(
|
||||
config: Config, verbose: bool = False, resolve_dependencies: bool = False
|
||||
) -> None:
|
||||
def heads(config, verbose=False, resolve_dependencies=False):
|
||||
"""Show current available heads in the script directory.
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
@@ -652,7 +563,7 @@ def heads(
|
||||
)
|
||||
|
||||
|
||||
def branches(config: Config, verbose: bool = False) -> None:
|
||||
def branches(config, verbose=False):
|
||||
"""Show current branch points.
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
@@ -722,9 +633,7 @@ def stamp(
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param revision: target revision or list of revisions. May be a list
|
||||
to indicate stamping of multiple branch heads; may be ``"base"``
|
||||
to remove all revisions from the table or ``"heads"`` to stamp the
|
||||
most recent revision(s).
|
||||
to indicate stamping of multiple branch heads.
|
||||
|
||||
.. note:: this parameter is called "revisions" in the command line
|
||||
interface.
|
||||
@@ -814,7 +723,7 @@ def ensure_version(config: Config, sql: bool = False) -> None:
|
||||
|
||||
:param config: a :class:`.Config` instance.
|
||||
|
||||
:param sql: use ``--sql`` mode.
|
||||
:param sql: use ``--sql`` mode
|
||||
|
||||
.. versionadded:: 1.7.6
|
||||
|
||||
|
||||
File diff suppressed because it is too large
@@ -5,6 +5,7 @@ from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import Iterable
from typing import List
@@ -13,14 +14,11 @@ from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from typing_extensions import ContextManager

if TYPE_CHECKING:
    from sqlalchemy.engine.base import Connection
    from sqlalchemy.engine.url import URL
@@ -41,9 +39,7 @@ if TYPE_CHECKING:

### end imports ###

def begin_transaction() -> (
    Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]
):
def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]:
    """Return a context manager that will
    enclose an operation within a "transaction",
    as defined by the environment's offline
@@ -101,7 +97,7 @@ def configure(
    tag: Optional[str] = None,
    template_args: Optional[Dict[str, Any]] = None,
    render_as_batch: bool = False,
    target_metadata: Union[MetaData, Sequence[MetaData], None] = None,
    target_metadata: Optional[MetaData] = None,
    include_name: Optional[
        Callable[
            [
@@ -163,8 +159,8 @@ def configure(
                MigrationContext,
                Column[Any],
                Column[Any],
                TypeEngine[Any],
                TypeEngine[Any],
                TypeEngine,
                TypeEngine,
            ],
            Optional[bool],
        ],
@@ -639,8 +635,7 @@ def configure(
    """

def execute(
    sql: Union[Executable, str],
    execution_options: Optional[Dict[str, Any]] = None,
    sql: Union[Executable, str], execution_options: Optional[dict] = None
) -> None:
    """Execute the given SQL using the current change context.

@@ -763,11 +758,7 @@ def get_x_argument(
    The return value is a list, returned directly from the ``argparse``
    structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
    are parsed using ``key=value`` format into a dictionary that is
    then returned. If there is no ``=`` in the argument, value is an empty
    string.

    .. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when
       arguments are passed without the ``=`` symbol.
    then returned.

    For example, to support passing a database URL on the command line,
    the standard ``env.py`` script can be modified like this::
@@ -809,7 +800,7 @@ def is_offline_mode() -> bool:

    """

def is_transactional_ddl() -> bool:
def is_transactional_ddl():
    """Return True if the context is configured to expect a
    transactional DDL capable backend.


@@ -3,4 +3,4 @@ from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl as DefaultImpl
from .impl import DefaultImpl

@@ -1,329 +0,0 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard

from .. import util
from ..util import sqla_compat

if TYPE_CHECKING:
    from typing import Literal

    from alembic.autogenerate.api import AutogenContext
    from alembic.ddl.impl import DefaultImpl

CompareConstraintType = Union[Constraint, Index]

_C = TypeVar("_C", bound=CompareConstraintType)

_clsreg: Dict[str, Type[_constraint_sig]] = {}


class ComparisonResult(NamedTuple):
    status: Literal["equal", "different", "skip"]
    message: str

    @property
    def is_equal(self) -> bool:
        return self.status == "equal"

    @property
    def is_different(self) -> bool:
        return self.status == "different"

    @property
    def is_skip(self) -> bool:
        return self.status == "skip"

    @classmethod
    def Equal(cls) -> ComparisonResult:
        """the constraints are equal."""
        return cls("equal", "The two constraints are equal")

    @classmethod
    def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """the constraints are different for the provided reason(s)."""
        return cls("different", ", ".join(util.to_list(reason)))

    @classmethod
    def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """the constraint cannot be compared for the provided reason(s).

        The message is logged, but the constraints will be otherwise
        considered equal, meaning that no migration command will be
        generated.
        """
        return cls("skip", ", ".join(util.to_list(reason)))


class _constraint_sig(Generic[_C]):
    const: _C

    _sig: Tuple[Any, ...]
    name: Optional[sqla_compat._ConstraintNameDefined]

    impl: DefaultImpl

    _is_index: ClassVar[bool] = False
    _is_fk: ClassVar[bool] = False
    _is_uq: ClassVar[bool] = False

    _is_metadata: bool

    def __init_subclass__(cls) -> None:
        cls._register()

    @classmethod
    def _register(cls):
        raise NotImplementedError()

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: _C
    ) -> None:
        raise NotImplementedError()

    def compare_to_reflected(
        self, other: _constraint_sig[Any]
    ) -> ComparisonResult:
        assert self.impl is other.impl
        assert self._is_metadata
        assert not other._is_metadata

        return self._compare_to_reflected(other)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        raise NotImplementedError()

    @classmethod
    def from_constraint(
        cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
    ) -> _constraint_sig[_C]:
        # these could be cached by constraint/impl, however, if the
        # constraint is modified in place, then the sig is wrong. the mysql
        # impl currently does this, and if we fixed that we can't be sure
        # someone else might do it too, so play it safe.
        sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
        return sig

    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
        return sqla_compat._get_constraint_final_name(
            self.const, context.dialect
        )

    @util.memoized_property
    def is_named(self):
        return sqla_compat._constraint_is_named(self.const, self.impl.dialect)

    @util.memoized_property
    def unnamed(self) -> Tuple[Any, ...]:
        return self._sig

    @util.memoized_property
    def unnamed_no_options(self) -> Tuple[Any, ...]:
        raise NotImplementedError()

    @util.memoized_property
    def _full_sig(self) -> Tuple[Any, ...]:
        return (self.name,) + self.unnamed

    def __eq__(self, other) -> bool:
        return self._full_sig == other._full_sig

    def __ne__(self, other) -> bool:
        return self._full_sig != other._full_sig

    def __hash__(self) -> int:
        return hash(self._full_sig)


class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
    _is_uq = True

    @classmethod
    def _register(cls) -> None:
        _clsreg["unique_constraint"] = cls

    is_unique = True

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: UniqueConstraint,
    ) -> None:
        self.impl = impl
        self.const = const
        self.name = sqla_compat.constraint_name_or_none(const.name)
        self._sig = tuple(sorted([col.name for col in const.columns]))
        self._is_metadata = is_metadata

    @property
    def column_names(self) -> Tuple[str, ...]:
        return tuple([col.name for col in self.const.columns])

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        assert self._is_metadata
        metadata_obj = self
        conn_obj = other

        assert is_uq_sig(conn_obj)
        return self.impl.compare_unique_constraint(
            metadata_obj.const, conn_obj.const
        )


class _ix_constraint_sig(_constraint_sig[Index]):
    _is_index = True

    name: sqla_compat._ConstraintName

    @classmethod
    def _register(cls) -> None:
        _clsreg["index"] = cls

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: Index
    ) -> None:
        self.impl = impl
        self.const = const
        self.name = const.name
        self.is_unique = bool(const.unique)
        self._is_metadata = is_metadata

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        assert self._is_metadata
        metadata_obj = self
        conn_obj = other

        assert is_index_sig(conn_obj)
        return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)

    @util.memoized_property
    def has_expressions(self):
        return sqla_compat.is_expression_index(self.const)

    @util.memoized_property
    def column_names(self) -> Tuple[str, ...]:
        return tuple([col.name for col in self.const.columns])

    @util.memoized_property
    def column_names_optional(self) -> Tuple[Optional[str], ...]:
        return tuple(
            [getattr(col, "name", None) for col in self.const.expressions]
        )

    @util.memoized_property
    def is_named(self):
        return True

    @util.memoized_property
    def unnamed(self):
        return (self.is_unique,) + self.column_names_optional


class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
    _is_fk = True

    @classmethod
    def _register(cls) -> None:
        _clsreg["foreign_key_constraint"] = cls

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: ForeignKeyConstraint,
    ) -> None:
        self._is_metadata = is_metadata

        self.impl = impl
        self.const = const

        self.name = sqla_compat.constraint_name_or_none(const.name)

        (
            self.source_schema,
            self.source_table,
            self.source_columns,
            self.target_schema,
            self.target_table,
            self.target_columns,
            onupdate,
            ondelete,
            deferrable,
            initially,
        ) = sqla_compat._fk_spec(const)

        self._sig: Tuple[Any, ...] = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        ) + (
            (
                (None if onupdate.lower() == "no action" else onupdate.lower())
                if onupdate
                else None
            ),
            (
                (None if ondelete.lower() == "no action" else ondelete.lower())
                if ondelete
                else None
            ),
            # convert initially + deferrable into one three-state value
            (
                "initially_deferrable"
                if initially and initially.lower() == "deferred"
                else "deferrable" if deferrable else "not deferrable"
            ),
        )

    @util.memoized_property
    def unnamed_no_options(self):
        return (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        )


def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
    return sig._is_index


def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
    return sig._is_uq


def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
    return sig._is_fk
@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import functools
@@ -25,8 +22,6 @@ from ..util.sqla_compat import _table_for_constraint  # noqa
if TYPE_CHECKING:
    from typing import Any

    from sqlalchemy import Computed
    from sqlalchemy import Identity
    from sqlalchemy.sql.compiler import Compiled
    from sqlalchemy.sql.compiler import DDLCompiler
    from sqlalchemy.sql.elements import TextClause
@@ -35,11 +30,14 @@ if TYPE_CHECKING:
    from sqlalchemy.sql.type_api import TypeEngine

    from .impl import DefaultImpl
    from ..util.sqla_compat import Computed
    from ..util.sqla_compat import Identity

_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]


class AlterTable(DDLElement):

    """Represent an ALTER TABLE statement.

    Only the string name and optional schema name of the table
@@ -154,24 +152,17 @@ class AddColumn(AlterTable):
        name: str,
        column: Column[Any],
        schema: Optional[Union[quoted_name, str]] = None,
        if_not_exists: Optional[bool] = None,
    ) -> None:
        super().__init__(name, schema=schema)
        self.column = column
        self.if_not_exists = if_not_exists


class DropColumn(AlterTable):
    def __init__(
        self,
        name: str,
        column: Column[Any],
        schema: Optional[str] = None,
        if_exists: Optional[bool] = None,
        self, name: str, column: Column[Any], schema: Optional[str] = None
    ) -> None:
        super().__init__(name, schema=schema)
        self.column = column
        self.if_exists = if_exists


class ColumnComment(AlterColumn):
@@ -196,9 +187,7 @@ def visit_rename_table(
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
    return "%s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        add_column(
            compiler, element.column, if_not_exists=element.if_not_exists, **kw
        ),
        add_column(compiler, element.column, **kw),
    )


@@ -206,9 +195,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
    return "%s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        drop_column(
            compiler, element.column.name, if_exists=element.if_exists, **kw
        ),
        drop_column(compiler, element.column.name, **kw),
    )


@@ -248,11 +235,9 @@ def visit_column_default(
    return "%s %s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        alter_column(compiler, element.column_name),
        (
            "SET DEFAULT %s" % format_server_default(compiler, element.default)
            if element.default is not None
            else "DROP DEFAULT"
        ),
        "SET DEFAULT %s" % format_server_default(compiler, element.default)
        if element.default is not None
        else "DROP DEFAULT",
    )


@@ -310,13 +295,9 @@ def format_server_default(
    compiler: DDLCompiler,
    default: Optional[_ServerDefault],
) -> str:
    # this can be updated to use compiler.render_default_string
    # for SQLAlchemy 2.0 and above; not in 1.4
    default_str = compiler.get_column_default_string(
    return compiler.get_column_default_string(
        Column("x", Integer, server_default=default)
    )
    assert default_str is not None
    return default_str


def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
@@ -331,29 +312,16 @@ def alter_table(
    return "ALTER TABLE %s" % format_table_name(compiler, name, schema)


def drop_column(
    compiler: DDLCompiler, name: str, if_exists: Optional[bool] = None, **kw
) -> str:
    return "DROP COLUMN %s%s" % (
        "IF EXISTS " if if_exists else "",
        format_column_name(compiler, name),
    )
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
    return "DROP COLUMN %s" % format_column_name(compiler, name)


def alter_column(compiler: DDLCompiler, name: str) -> str:
    return "ALTER COLUMN %s" % format_column_name(compiler, name)


def add_column(
    compiler: DDLCompiler,
    column: Column[Any],
    if_not_exists: Optional[bool] = None,
    **kw,
) -> str:
    text = "ADD COLUMN %s%s" % (
        "IF NOT EXISTS " if if_not_exists else "",
        compiler.get_column_specification(column, **kw),
    )
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
    text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)

    const = " ".join(
        compiler.process(constraint) for constraint in column.constraints

@@ -1,9 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import logging
from collections import namedtuple
import re
from typing import Any
from typing import Callable
@@ -11,7 +8,6 @@ from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
@@ -21,18 +17,10 @@ from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text

from . import _autogen
from . import base
from ._autogen import _constraint_sig as _constraint_sig
from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat

@@ -46,10 +34,13 @@ if TYPE_CHECKING:
    from sqlalchemy.engine.reflection import Inspector
    from sqlalchemy.sql import ClauseElement
    from sqlalchemy.sql import Executable
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.elements import quoted_name
    from sqlalchemy.sql.schema import Column
    from sqlalchemy.sql.schema import Constraint
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import Index
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint
    from sqlalchemy.sql.selectable import TableClause
    from sqlalchemy.sql.type_api import TypeEngine
@@ -59,8 +50,6 @@ if TYPE_CHECKING:
    from ..operations.batch import ApplyBatchImpl
    from ..operations.batch import BatchOperationsImpl

log = logging.getLogger(__name__)


class ImplMeta(type):
    def __init__(
@@ -77,8 +66,11 @@ class ImplMeta(type):

_impls: Dict[str, Type[DefaultImpl]] = {}

Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])


class DefaultImpl(metaclass=ImplMeta):

    """Provide the entrypoint for major migration operations,
    including database-specific behavioral variances.

@@ -138,40 +130,6 @@ class DefaultImpl(metaclass=ImplMeta):
        self.output_buffer.write(text + "\n\n")
        self.output_buffer.flush()

    def version_table_impl(
        self,
        *,
        version_table: str,
        version_table_schema: Optional[str],
        version_table_pk: bool,
        **kw: Any,
    ) -> Table:
        """Generate a :class:`.Table` object which will be used as the
        structure for the Alembic version table.

        Third party dialects may override this hook to provide an alternate
        structure for this :class:`.Table`; requirements are only that it
        be named based on the ``version_table`` parameter and contains
        at least a single string-holding column named ``version_num``.

        .. versionadded:: 1.14

        """
        vt = Table(
            version_table,
            MetaData(),
            Column("version_num", String(32), nullable=False),
            schema=version_table_schema,
        )
        if version_table_pk:
            vt.append_constraint(
                PrimaryKeyConstraint(
                    "version_num", name=f"{version_table}_pkc"
                )
            )

        return vt

    def requires_recreate_in_batch(
        self, batch_op: BatchOperationsImpl
    ) -> bool:
@@ -203,15 +161,16 @@ class DefaultImpl(metaclass=ImplMeta):
    def _exec(
        self,
        construct: Union[Executable, str],
        execution_options: Optional[Mapping[str, Any]] = None,
        multiparams: Optional[Sequence[Mapping[str, Any]]] = None,
        params: Mapping[str, Any] = util.immutabledict(),
        execution_options: Optional[dict[str, Any]] = None,
        multiparams: Sequence[dict] = (),
        params: Dict[str, Any] = util.immutabledict(),
    ) -> Optional[CursorResult]:
        if isinstance(construct, str):
            construct = text(construct)
        if self.as_sql:
            if multiparams is not None or params:
                raise TypeError("SQL parameters not allowed with as_sql")
            if multiparams or params:
                # TODO: coverage
                raise Exception("Execution arguments not allowed with as_sql")

            compile_kw: dict[str, Any]
            if self.literal_binds and not isinstance(
@@ -234,16 +193,11 @@ class DefaultImpl(metaclass=ImplMeta):
            assert conn is not None
            if execution_options:
                conn = conn.execution_options(**execution_options)
            if params:
                assert isinstance(multiparams, tuple)
                multiparams += (params,)

            if params and multiparams is not None:
                raise TypeError(
                    "Can't send params and multiparams at the same time"
                )

            if multiparams:
                return conn.execute(construct, multiparams)
            else:
                return conn.execute(construct, params)
            return conn.execute(construct, multiparams)

    def execute(
        self,
@@ -256,11 +210,8 @@ class DefaultImpl(metaclass=ImplMeta):
        self,
        table_name: str,
        column_name: str,
        *,
        nullable: Optional[bool] = None,
        server_default: Optional[
            Union[_ServerDefault, Literal[False]]
        ] = False,
        server_default: Union[_ServerDefault, Literal[False]] = False,
        name: Optional[str] = None,
        type_: Optional[TypeEngine] = None,
        schema: Optional[str] = None,
@@ -371,40 +322,25 @@ class DefaultImpl(metaclass=ImplMeta):
        self,
        table_name: str,
        column: Column[Any],
        *,
        schema: Optional[Union[str, quoted_name]] = None,
        if_not_exists: Optional[bool] = None,
    ) -> None:
        self._exec(
            base.AddColumn(
                table_name,
                column,
                schema=schema,
                if_not_exists=if_not_exists,
            )
        )
        self._exec(base.AddColumn(table_name, column, schema=schema))

    def drop_column(
        self,
        table_name: str,
        column: Column[Any],
        *,
        schema: Optional[str] = None,
        if_exists: Optional[bool] = None,
        **kw,
    ) -> None:
        self._exec(
            base.DropColumn(
                table_name, column, schema=schema, if_exists=if_exists
            )
        )
        self._exec(base.DropColumn(table_name, column, schema=schema))

    def add_constraint(self, const: Any) -> None:
        if const._create_rule is None or const._create_rule(self):
            self._exec(schema.AddConstraint(const))

    def drop_constraint(self, const: Constraint, **kw: Any) -> None:
        self._exec(schema.DropConstraint(const, **kw))
    def drop_constraint(self, const: Constraint) -> None:
        self._exec(schema.DropConstraint(const))

    def rename_table(
        self,
@@ -416,11 +352,11 @@ class DefaultImpl(metaclass=ImplMeta):
            base.RenameTable(old_table_name, new_table_name, schema=schema)
        )

    def create_table(self, table: Table, **kw: Any) -> None:
    def create_table(self, table: Table) -> None:
        table.dispatch.before_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        self._exec(schema.CreateTable(table, **kw))
        self._exec(schema.CreateTable(table))
        table.dispatch.after_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
@@ -439,11 +375,11 @@ class DefaultImpl(metaclass=ImplMeta):
        if comment and with_comment:
            self.create_column_comment(column)

    def drop_table(self, table: Table, **kw: Any) -> None:
    def drop_table(self, table: Table) -> None:
        table.dispatch.before_drop(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        self._exec(schema.DropTable(table, **kw))
        self._exec(schema.DropTable(table))
        table.dispatch.after_drop(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
@@ -457,7 +393,7 @@ class DefaultImpl(metaclass=ImplMeta):
    def drop_table_comment(self, table: Table) -> None:
        self._exec(schema.DropTableComment(table))

    def create_column_comment(self, column: Column[Any]) -> None:
    def create_column_comment(self, column: ColumnElement[Any]) -> None:
        self._exec(schema.SetColumnComment(column))

    def drop_index(self, index: Index, **kw: Any) -> None:
@@ -476,19 +412,15 @@ class DefaultImpl(metaclass=ImplMeta):
        if self.as_sql:
            for row in rows:
                self._exec(
                    table.insert()
                    .inline()
                    .values(
                    sqla_compat._insert_inline(table).values(
                        **{
                            k: (
                                sqla_compat._literal_bindparam(
                                    k, v, type_=table.c[k].type
                                )
                                if not isinstance(
                                    v, sqla_compat._literal_bindparam
                                )
                                else v
                            k: sqla_compat._literal_bindparam(
                                k, v, type_=table.c[k].type
                            )
                            if not isinstance(
                                v, sqla_compat._literal_bindparam
                            )
                            else v
                            for k, v in row.items()
                        }
                    )
@@ -496,13 +428,16 @@ class DefaultImpl(metaclass=ImplMeta):
        else:
            if rows:
                if multiinsert:
                    self._exec(table.insert().inline(), multiparams=rows)
                    self._exec(
                        sqla_compat._insert_inline(table), multiparams=rows
                    )
                else:
                    for row in rows:
                        self._exec(table.insert().inline().values(**row))
                        self._exec(
                            sqla_compat._insert_inline(table).values(**row)
                        )

    def _tokenize_column_type(self, column: Column) -> Params:
        definition: str
        definition = self.dialect.type_compiler.process(column.type).lower()

        # tokenize the SQLAlchemy-generated version of a type, so that
@@ -517,9 +452,9 @@ class DefaultImpl(metaclass=ImplMeta):
        # varchar character set utf8
        #

        tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
        tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)

        term_tokens: List[str] = []
        term_tokens = []
        paren_term = None

        for token in tokens:
@@ -531,7 +466,6 @@ class DefaultImpl(metaclass=ImplMeta):
        params = Params(term_tokens[0], term_tokens[1:], [], {})

        if paren_term:
            term: str
            for term in re.findall("[^(),]+", paren_term):
                if "=" in term:
                    key, val = term.split("=")
@@ -708,7 +642,7 @@ class DefaultImpl(metaclass=ImplMeta):
        diff, ignored = _compare_identity_options(
            metadata_identity,
            inspector_identity,
            schema.Identity(),
            sqla_compat.Identity(),
            skip={"always"},
        )

@@ -730,96 +664,15 @@ class DefaultImpl(metaclass=ImplMeta):
            bool(diff) or bool(metadata_identity) != bool(inspector_identity),
        )

    def _compare_index_unique(
        self, metadata_index: Index, reflected_index: Index
    ) -> Optional[str]:
        conn_unique = bool(reflected_index.unique)
        meta_unique = bool(metadata_index.unique)
        if conn_unique != meta_unique:
            return f"unique={conn_unique} to unique={meta_unique}"
        else:
            return None
    def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
        # order of col matters in an index
        return tuple(col.name for col in index.columns)

    def _create_metadata_constraint_sig(
        self, constraint: _autogen._C, **opts: Any
    ) -> _constraint_sig[_autogen._C]:
        return _constraint_sig.from_constraint(True, self, constraint, **opts)

    def _create_reflected_constraint_sig(
        self, constraint: _autogen._C, **opts: Any
    ) -> _constraint_sig[_autogen._C]:
        return _constraint_sig.from_constraint(False, self, constraint, **opts)

    def compare_indexes(
        self,
        metadata_index: Index,
        reflected_index: Index,
    ) -> ComparisonResult:
        """Compare two indexes by comparing the signature generated by
        ``create_index_sig``.

        This method returns a ``ComparisonResult``.
        """
        msg: List[str] = []
        unique_msg = self._compare_index_unique(
            metadata_index, reflected_index
        )
        if unique_msg:
            msg.append(unique_msg)
        m_sig = self._create_metadata_constraint_sig(metadata_index)
        r_sig = self._create_reflected_constraint_sig(reflected_index)

        assert _autogen.is_index_sig(m_sig)
        assert _autogen.is_index_sig(r_sig)

        # The assumption is that the index have no expression
        for sig in m_sig, r_sig:
            if sig.has_expressions:
                log.warning(
                    "Generating approximate signature for index %s. "
                    "The dialect "
                    "implementation should either skip expression indexes "
                    "or provide a custom implementation.",
                    sig.const,
                )

        if m_sig.column_names != r_sig.column_names:
            msg.append(
                f"expression {r_sig.column_names} to {m_sig.column_names}"
            )

        if msg:
            return ComparisonResult.Different(msg)
        else:
            return ComparisonResult.Equal()

    def compare_unique_constraint(
        self,
        metadata_constraint: UniqueConstraint,
        reflected_constraint: UniqueConstraint,
    ) -> ComparisonResult:
        """Compare two unique constraints by comparing the two signatures.

        The arguments are two tuples that contain the unique constraint and
        the signatures generated by ``create_unique_constraint_sig``.

        This method returns a ``ComparisonResult``.
        """
        metadata_tup = self._create_metadata_constraint_sig(
            metadata_constraint
        )
        reflected_tup = self._create_reflected_constraint_sig(
            reflected_constraint
        )

        meta_sig = metadata_tup.unnamed
        conn_sig = reflected_tup.unnamed
        if conn_sig != meta_sig:
            return ComparisonResult.Different(
                f"expression {conn_sig} to {meta_sig}"
            )
        else:
            return ComparisonResult.Equal()
    def create_unique_constraint_sig(
        self, const: UniqueConstraint
    ) -> Tuple[Any, ...]:
        # order of col does not matters in an unique constraint
        return tuple(sorted([col.name for col in const.columns]))

    def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
        conn_indexes_by_name = {c.name: c for c in conn_indexes}
@@ -844,13 +697,6 @@ class DefaultImpl(metaclass=ImplMeta):
        return reflected_object.get("dialect_options", {})


class Params(NamedTuple):
    token0: str
    tokens: List[str]
    args: List[str]
    kwargs: Dict[str, str]


def _compare_identity_options(
    metadata_io: Union[schema.Identity, schema.Sequence, None],
    inspector_io: Union[schema.Identity, schema.Sequence, None],
@@ -889,13 +735,12 @@ def _compare_identity_options(
        set(meta_d).union(insp_d),
    )
    if sqla_compat.identity_has_dialect_kwargs:
        assert hasattr(default_io, "dialect_kwargs")
        # use only the dialect kwargs in inspector_io since metadata_io
        # can have options for many backends
        check_dicts(
            getattr(metadata_io, "dialect_kwargs", {}),
            getattr(inspector_io, "dialect_kwargs", {}),
            default_io.dialect_kwargs,
            default_io.dialect_kwargs,  # type: ignore[union-attr]
            getattr(inspector_io, "dialect_kwargs", {}),
        )


@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import re
@@ -12,6 +9,7 @@ from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
@@ -32,7 +30,6 @@ from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import compiles

if TYPE_CHECKING:
    from typing import Literal
@@ -83,11 +80,10 @@ class MSSQLImpl(DefaultImpl):
        if self.as_sql and self.batch_separator:
            self.static_output(self.batch_separator)

    def alter_column(
    def alter_column(  # type:ignore[override]
        self,
        table_name: str,
        column_name: str,
        *,
        nullable: Optional[bool] = None,
        server_default: Optional[
            Union[_ServerDefault, Literal[False]]
@@ -203,7 +199,6 @@ class MSSQLImpl(DefaultImpl):
        self,
        table_name: str,
        column: Column[Any],
        *,
        schema: Optional[str] = None,
        **kw,
    ) -> None:

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import re
@@ -11,9 +8,7 @@ from typing import Union

from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import elements
from sqlalchemy.sql import functions
from sqlalchemy.sql import operators
from sqlalchemy.ext.compiler import compiles

from .base import alter_table
from .base import AlterColumn
@@ -25,16 +20,16 @@ from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..autogenerate import compare
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import compiles

if TYPE_CHECKING:
    from typing import Literal

    from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
    from sqlalchemy.sql.ddl import DropConstraint
    from sqlalchemy.sql.elements import ClauseElement
    from sqlalchemy.sql.schema import Constraint
    from sqlalchemy.sql.type_api import TypeEngine

@@ -51,40 +46,12 @@ class MySQLImpl(DefaultImpl):
    )
    type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]

    def render_ddl_sql_expr(
        self,
        expr: ClauseElement,
        is_server_default: bool = False,
        is_index: bool = False,
        **kw: Any,
    ) -> str:
        # apply Grouping to index expressions;
        # see https://github.com/sqlalchemy/sqlalchemy/blob/
        # 36da2eaf3e23269f2cf28420ae73674beafd0661/
        # lib/sqlalchemy/dialects/mysql/base.py#L2191
        if is_index and (
            isinstance(expr, elements.BinaryExpression)
            or (
                isinstance(expr, elements.UnaryExpression)
                and expr.modifier not in (operators.desc_op, operators.asc_op)
            )
            or isinstance(expr, functions.FunctionElement)
        ):
            expr = elements.Grouping(expr)

        return super().render_ddl_sql_expr(
            expr, is_server_default=is_server_default, is_index=is_index, **kw
        )

    def alter_column(
    def alter_column(  # type:ignore[override]
        self,
        table_name: str,
        column_name: str,
        *,
        nullable: Optional[bool] = None,
        server_default: Optional[
            Union[_ServerDefault, Literal[False]]
        ] = False,
        server_default: Union[_ServerDefault, Literal[False]] = False,
        name: Optional[str] = None,
        type_: Optional[TypeEngine] = None,
        schema: Optional[str] = None,
@@ -125,29 +92,21 @@ class MySQLImpl(DefaultImpl):
                column_name,
                schema=schema,
                newname=name if name is not None else column_name,
                nullable=(
                    nullable
                    if nullable is not None
                    else (
                        existing_nullable
                        if existing_nullable is not None
                        else True
                    )
                ),
                nullable=nullable
                if nullable is not None
                else existing_nullable
                if existing_nullable is not None
                else True,
                type_=type_ if type_ is not None else existing_type,
                default=(
                    server_default
                    if server_default is not False
                    else existing_server_default
                ),
                autoincrement=(
                    autoincrement
                    if autoincrement is not None
                    else existing_autoincrement
                ),
                comment=(
                    comment if comment is not False else existing_comment
                ),
                default=server_default
                if server_default is not False
                else existing_server_default,
                autoincrement=autoincrement
                if autoincrement is not None
                else existing_autoincrement,
                comment=comment
                if comment is not False
                else existing_comment,
            )
        )
        elif (
@@ -162,29 +121,21 @@ class MySQLImpl(DefaultImpl):
                column_name,
                schema=schema,
                newname=name if name is not None else column_name,
                nullable=(
                    nullable
                    if nullable is not None
                    else (
                        existing_nullable
                        if existing_nullable is not None
                        else True
                    )
                ),
                nullable=nullable
                if nullable is not None
                else existing_nullable
                if existing_nullable is not None
                else True,
                type_=type_ if type_ is not None else existing_type,
                default=(
                    server_default
                    if server_default is not False
                    else existing_server_default
                ),
                autoincrement=(
                    autoincrement
                    if autoincrement is not None
                    else existing_autoincrement
                ),
                comment=(
                    comment if comment is not False else existing_comment
                ),
                default=server_default
                if server_default is not False
                else existing_server_default,
                autoincrement=autoincrement
                if autoincrement is not None
                else existing_autoincrement,
                comment=comment
                if comment is not False
                else existing_comment,
            )
        )
        elif server_default is not False:
@@ -197,7 +148,6 @@ class MySQLImpl(DefaultImpl):
    def drop_constraint(
        self,
        const: Constraint,
        **kw: Any,
    ) -> None:
        if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
            return
@@ -207,11 +157,12 @@ class MySQLImpl(DefaultImpl):
    def _is_mysql_allowed_functional_default(
        self,
        type_: Optional[TypeEngine],
        server_default: Optional[Union[_ServerDefault, Literal[False]]],
        server_default: Union[_ServerDefault, Literal[False]],
    ) -> bool:
        return (
            type_ is not None
            and type_._type_affinity is sqltypes.DateTime
            and type_._type_affinity  # type:ignore[attr-defined]
            is sqltypes.DateTime
            and server_default is not None
        )

@@ -321,12 +272,10 @@ class MySQLImpl(DefaultImpl):

    def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
        conn_fk_by_sig = {
            self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
            for fk in conn_fks
            compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
        }
        metadata_fk_by_sig = {
            self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
            for fk in metadata_fks
            compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
        }

        for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
@@ -358,7 +307,7 @@ class MySQLAlterDefault(AlterColumn):
        self,
        name: str,
        column_name: str,
        default: Optional[_ServerDefault],
        default: _ServerDefault,
        schema: Optional[str] = None,
    ) -> None:
        super(AlterColumn, self).__init__(name, schema=schema)
@@ -416,11 +365,9 @@ def _mysql_alter_default(
    return "%s ALTER COLUMN %s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        format_column_name(compiler, element.column_name),
        (
            "SET DEFAULT %s" % format_server_default(compiler, element.default)
            if element.default is not None
            else "DROP DEFAULT"
        ),
        "SET DEFAULT %s" % format_server_default(compiler, element.default)
        if element.default is not None
        else "DROP DEFAULT",
    )


@@ -507,7 +454,7 @@ def _mysql_drop_constraint(
    # note that SQLAlchemy as of 1.2 does not yet support
    # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
    # here.
    if compiler.dialect.is_mariadb:
    if _is_mariadb(compiler.dialect):
        return "ALTER TABLE %s DROP CONSTRAINT %s" % (
            compiler.preparer.format_table(constraint.table),
            compiler.preparer.format_constraint(constraint),

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import re
@@ -8,6 +5,7 @@ from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes

from .base import AddColumn
@@ -24,7 +22,6 @@ from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from ..util.sqla_compat import compiles

if TYPE_CHECKING:
    from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
@@ -141,11 +138,9 @@ def visit_column_default(
    return "%s %s %s" % (
        alter_table(compiler, element.table_name, element.schema),
        alter_column(compiler, element.column_name),
        (
            "DEFAULT %s" % format_server_default(compiler, element.default)
            if element.default is not None
            else "DEFAULT NULL"
        ),
        "DEFAULT %s" % format_server_default(compiler, element.default)
        if element.default is not None
        else "DEFAULT NULL",
    )


@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -16,19 +13,18 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import Numeric
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.dialects.postgresql import BIGINT
|
||||
from sqlalchemy.dialects.postgresql import ExcludeConstraint
|
||||
from sqlalchemy.dialects.postgresql import INTEGER
|
||||
from sqlalchemy.schema import CreateIndex
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
from sqlalchemy.sql.functions import FunctionElement
|
||||
from sqlalchemy.types import NULLTYPE
|
||||
|
||||
@@ -36,12 +32,12 @@ from .base import alter_column
|
||||
from .base import alter_table
|
||||
from .base import AlterColumn
|
||||
from .base import ColumnComment
|
||||
from .base import compiles
|
||||
from .base import format_column_name
|
||||
from .base import format_table_name
|
||||
from .base import format_type
|
||||
from .base import IdentityColumnDefault
|
||||
from .base import RenameTable
|
||||
from .impl import ComparisonResult
|
||||
from .impl import DefaultImpl
|
||||
from .. import util
|
||||
from ..autogenerate import render
|
||||
@@ -50,8 +46,6 @@ from ..operations import schemaobj
|
||||
from ..operations.base import BatchOperations
|
||||
from ..operations.base import Operations
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import compiles
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
@@ -136,28 +130,25 @@ class PostgresqlImpl(DefaultImpl):
|
||||
metadata_default = metadata_column.server_default.arg
|
||||
|
||||
if isinstance(metadata_default, str):
|
||||
if not isinstance(inspector_column.type, (Numeric, Float)):
|
||||
if not isinstance(inspector_column.type, Numeric):
|
||||
metadata_default = re.sub(r"^'|'$", "", metadata_default)
|
||||
metadata_default = f"'{metadata_default}'"
|
||||
|
||||
metadata_default = literal_column(metadata_default)
|
||||
|
||||
# run a real compare against the server
|
||||
conn = self.connection
|
||||
assert conn is not None
|
||||
return not conn.scalar(
|
||||
select(literal_column(conn_col_default) == metadata_default)
|
||||
return not self.connection.scalar(
|
||||
sqla_compat._select(
|
||||
literal_column(conn_col_default) == metadata_default
|
||||
)
|
||||
)
|
||||
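Rather than string-comparing defaults, `compare_server_default` here asks the database itself whether the reflected default and the metadata default evaluate as equal. A minimal standalone sketch of that idea (assumes a live PostgreSQL engine; the DSN is hypothetical):

    from sqlalchemy import create_engine, literal_column, select

    engine = create_engine("postgresql://localhost/test")  # hypothetical DSN
    with engine.connect() as conn:
        # e.g. SELECT 'abc'::text = 'abc'  -- True means "defaults match"
        equal = conn.scalar(
            select(literal_column("'abc'::text") == literal_column("'abc'"))
        )
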

def alter_column(
def alter_column(  # type:ignore[override]
self,
table_name: str,
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -223,8 +214,7 @@ class PostgresqlImpl(DefaultImpl):
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and "
"c.oid=cast(:seqname as regclass)"
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
@@ -262,60 +252,62 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)

# pg behavior regarding modifiers
# | # | compiled sql     | returned sql    | regexp. group is removed |
# | - | ---------------- | --------------- | ------------------------ |
# | 1 | nulls first      | nulls first     | -                        |
# | 2 | nulls last       |                 | (?<! desc)( nulls last)$ |
# | 3 | asc              |                 | ( asc)$                  |
# | 4 | asc nulls first  | nulls first     | ( asc) nulls first$      |
# | 5 | asc nulls last   |                 | ( asc nulls last)$       |
# | 6 | desc             | desc            | -                        |
# | 7 | desc nulls first | desc            | desc( nulls first)$      |
# | 8 | desc nulls last  | desc nulls last | -                        |
_default_modifiers_re = (  # order of case 2 and 5 matters
re.compile("( asc nulls last)$"),  # case 5
re.compile("(?<! desc)( nulls last)$"),  # case 2
re.compile("( asc)$"),  # case 3
re.compile("( asc) nulls first$"),  # case 4
re.compile(" desc( nulls first)$"),  # case 7
)

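The regex tuple encodes the table above: modifiers PostgreSQL treats as defaults (asc, nulls last, and their combinations) are stripped before comparing, so "col asc nulls last" and "col" compare equal. A standalone sketch of the same normalization, with the regexes copied from the hunk:

    import re

    default_modifiers_re = (
        re.compile("( asc nulls last)$"),       # case 5
        re.compile("(?<! desc)( nulls last)$"),  # case 2
        re.compile("( asc)$"),                  # case 3
        re.compile("( asc) nulls first$"),      # case 4
        re.compile(" desc( nulls first)$"),     # case 7
    )

    def strip_default_modifiers(expr: str) -> str:
        # remove only the first matching group, mirroring the loop in the hunk
        for rx in default_modifiers_re:
            if (m := rx.search(expr)):
                start, end = m.span(1)
                return expr[:start] + expr[end:]
        return expr

    assert strip_default_modifiers("name asc nulls last") == "name"
    assert strip_default_modifiers("name desc nulls first") == "name desc"
    # "desc nulls last" is not a default combination, so it is kept as-is
    assert strip_default_modifiers("name desc nulls last") == "name desc nulls last"
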
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
def _cleanup_index_expr(
self, index: Index, expr: str, remove_suffix: str
) -> str:
# start = expr
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")

while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)

while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if remove_suffix and expr.endswith(remove_suffix):
expr = expr[: -len(remove_suffix)]

# NOTE: when parsing the connection expression this cleanup could
# be skipped
for rs in self._default_modifiers_re:
if match := rs.search(expr):
start, end = match.span(1)
expr = expr[:start] + expr[end:]
break

while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]

# strip casts
cast_re = re.compile(r"cast\s*\(")
if cast_re.match(expr):
expr = cast_re.sub("", expr)
# remove the as type
expr = re.sub(r"as\s+[^)]+\)", "", expr)
# remove spaces
expr = expr.replace(" ", "")
# print(f"START: {start} END: {expr}")
return expr

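Both versions of `_cleanup_index_expr` normalize a compiled index expression into a canonical lowercase form before comparing: quotes and table prefixes are dropped, outer parentheses unwrapped, and `::type` and `CAST(... AS type)` casts removed. A standalone sketch of that canonicalization, limited to the steps visible in the hunk:

    import re

    def canonicalize(expr: str) -> str:
        expr = expr.lower().replace('"', "").replace("'", "")
        while expr and expr[0] == "(" and expr[-1] == ")":
            expr = expr[1:-1]                       # unwrap outer parentheses
        expr = re.sub(r"(::[\w ]+\w)", "", expr)    # strip ::type casts
        return expr

    assert canonicalize('(("Name"::TEXT))') == "name"
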
def _dialect_options(
def _default_modifiers(self, exp: ClauseElement) -> str:
to_remove = ""
while isinstance(exp, UnaryExpression):
if exp.modifier is None:
exp = exp.element
else:
op = exp.modifier
if isinstance(exp.element, UnaryExpression):
inner_op = exp.element.modifier
else:
inner_op = None
if inner_op is None:
if op == operators.asc_op:
# default is asc
to_remove = " asc"
elif op == operators.nullslast_op:
# default is nulls last
to_remove = " nulls last"
else:
if (
inner_op == operators.asc_op
and op == operators.nullslast_op
):
# default is asc nulls last
to_remove = " asc nulls last"
elif (
inner_op == operators.desc_op
and op == operators.nullsfirst_op
):
# default for desc is nulls first
to_remove = " nulls first"
break
return to_remove

def _dialect_sig(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
@@ -324,93 +316,25 @@ class PostgresqlImpl(DefaultImpl):
return ("nulls_not_distinct",)
return ()

def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
msg = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_exprs = metadata_index.expressions
r_exprs = reflected_index.expressions
if len(m_exprs) != len(r_exprs):
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
if msg:
# no point going further, return early
return ComparisonResult.Different(msg)
skip = []
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
m_compile = self._compile_element(m_e)
m_text = self._cleanup_index_expr(metadata_index, m_compile)
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
r_compile = self._compile_element(r_e)
r_text = self._cleanup_index_expr(metadata_index, r_compile)
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
if m_text == r_text:
continue  # expressions these are equal
elif m_compile.strip().endswith("_ops") and (
" " in m_compile or ")" in m_compile  # is an expression
):
skip.append(
f"expression #{pos} {m_compile!r} detected "
"as including operator clause."
)
util.warn(
f"Expression #{pos} {m_compile!r} in index "
f"{reflected_index.name!r} detected to include "
"an operator clause. Expression compare cannot proceed. "
"Please move the operator clause to the "
"``postgresql_ops`` dict to enable proper compare "
"of the index expressions: "
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes",  # noqa: E501
)
else:
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")

m_options = self._dialect_options(metadata_index)
r_options = self._dialect_options(reflected_index)
if m_options != r_options:
msg.extend(f"options {r_options} to {m_options}")

if msg:
return ComparisonResult.Different(msg)
elif skip:
# if there are other changes detected don't skip the index
return ComparisonResult.Skip(skip)
else:
return ComparisonResult.Equal()

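`compare_indexes` refuses to guess when an operator class such as `jsonb_path_ops` is embedded in an index expression, and instead warns to move it into `postgresql_ops`. A sketch of the two spellings (hypothetical table and column):

    import sqlalchemy as sa

    # opclass embedded in the expression text: compare cannot proceed
    sa.Index("ix_data", sa.text("data jsonb_path_ops"))

    # opclass declared separately: the expressions stay comparable
    sa.Index("ix_data", "data", postgresql_ops={"data": "jsonb_path_ops"})
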
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)

meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
return tuple(
self._cleanup_index_expr(
index,
*(
(e, "")
if isinstance(e, str)
else (self._compile_element(e), self._default_modifiers(e))
),
)
for e in index.expressions
) + self._dialect_sig(index)

metadata_do = self._dialect_options(metadata_tup.const)
conn_do = self._dialect_options(reflected_tup.const)
if metadata_do != conn_do:
return ComparisonResult.Different(
f"expression {conn_do} to {metadata_do}"
)

return ComparisonResult.Equal()
def create_unique_constraint_sig(
self, const: UniqueConstraint
) -> Tuple[Any, ...]:
return tuple(
sorted([col.name for col in const.columns])
) + self._dialect_sig(const)

def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
@@ -421,9 +345,7 @@ class PostgresqlImpl(DefaultImpl):
options.pop("postgresql_include", None)
return options

def _compile_element(self, element: Union[ClauseElement, str]) -> str:
if isinstance(element, str):
return element
def _compile_element(self, element: ClauseElement) -> str:
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
@@ -590,7 +512,7 @@ def visit_identity_column(
)
else:
text += "SET %s " % compiler.get_identity_options(
Identity(**{attr: getattr(identity, attr)})
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text

@@ -634,8 +556,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
return cls(
constraint.name,
constraint_table.name,
[  # type: ignore
(expr, op) for expr, name, op in constraint._render_exprs
[
(expr, op)
for expr, name, op in constraint._render_exprs  # type:ignore[attr-defined] # noqa
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
@@ -662,7 +585,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
expr,
name,
oper,
) in excl._render_exprs:
) in excl._render_exprs:  # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@@ -720,7 +643,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
constraint_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.

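For reference, a typical invocation of this op from a migration, following the documented pattern (constraint, table, and column names hypothetical):

    from alembic import op

    op.create_exclude_constraint(
        "user_excl",
        "user",
        ("period", "&&"),                 # (expression, operator) pairs
        where=("group != 'some group'"),
    )
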
@@ -792,13 +715,10 @@ def _exclude_constraint(
args = [
"(%s, %r)"
% (
_render_potential_column(
sqltext,  # type:ignore[arg-type]
autogen_context,
),
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs
for sqltext, name, opstring in constraint._render_exprs  # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
@@ -850,5 +770,5 @@ def _render_potential_column(
return render._render_potential_expr(
value,
autogen_context,
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import re
@@ -11,19 +8,16 @@ from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import cast
from sqlalchemy import Computed
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles

from .base import alter_table
from .base import ColumnName
from .base import format_column_name
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util.sqla_compat import compiles

if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
@@ -65,7 +59,7 @@ class SQLiteImpl(DefaultImpl):
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, Computed)
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
@@ -77,13 +71,13 @@ class SQLiteImpl(DefaultImpl):
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None:
if const._create_rule is None:  # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self):
elif const._create_rule(self):  # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
@@ -91,8 +85,8 @@ class SQLiteImpl(DefaultImpl):
"SQLite migrations using a copy-and-move strategy."
)

def drop_constraint(self, const: Constraint, **kw: Any):
if const._create_rule is None:
def drop_constraint(self, const: Constraint):
if const._create_rule is None:  # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
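The batch mode that this error message points to recreates the table and copies rows instead of issuing ALTER statements SQLite cannot execute. Typical usage from a migration, per the documented pattern (table and column names hypothetical):

    import sqlalchemy as sa
    from alembic import op

    with op.batch_alter_table("account") as batch_op:
        batch_op.add_column(sa.Column("foo", sa.Integer()))
        batch_op.drop_column("bar")
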
@@ -183,7 +177,8 @@ class SQLiteImpl(DefaultImpl):
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity is not new_type._type_affinity
existing.type._type_affinity  # type:ignore[attr-defined]
is not new_type._type_affinity  # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
@@ -210,15 +205,6 @@ def visit_rename_table(
)


@compiles(ColumnName, "sqlite")
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)


# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
#     return "%s %s" % (

@@ -12,7 +12,6 @@ from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
@@ -27,6 +26,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
@@ -35,36 +35,16 @@ if TYPE_CHECKING:
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.util import immutabledict

from .operations.base import BatchOperations
from .operations.ops import AddColumnOp
from .operations.ops import AddConstraintOp
from .operations.ops import AlterColumnOp
from .operations.ops import AlterTableOp
from .operations.ops import BulkInsertOp
from .operations.ops import CreateIndexOp
from .operations.ops import CreateTableCommentOp
from .operations.ops import CreateTableOp
from .operations.ops import DropColumnOp
from .operations.ops import DropConstraintOp
from .operations.ops import DropIndexOp
from .operations.ops import DropTableCommentOp
from .operations.ops import DropTableOp
from .operations.ops import ExecuteSQLOp
from .operations.ops import BatchOperations
from .operations.ops import MigrateOperation
from .runtime.migration import MigrationContext
from .util.sqla_compat import _literal_bindparam

_T = TypeVar("_T")
_C = TypeVar("_C", bound=Callable[..., Any])

### end imports ###

def add_column(
table_name: str,
column: Column[Any],
*,
schema: Optional[str] = None,
if_not_exists: Optional[bool] = None,
table_name: str, column: Column[Any], *, schema: Optional[str] = None
) -> None:
"""Issue an "add column" instruction using the current
migration context.
@@ -141,10 +121,6 @@ def add_column(
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator
when creating the new column for compatible dialects

.. versionadded:: 1.16.0

"""

@@ -154,14 +130,12 @@ def alter_column(
*,
nullable: Optional[bool] = None,
comment: Union[str, Literal[False], None] = False,
server_default: Union[
str, bool, Identity, Computed, TextClause, None
] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
type_: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_server_default: Union[
str, bool, Identity, Computed, TextClause, None
str, bool, Identity, Computed, None
] = False,
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
@@ -256,7 +230,7 @@ def batch_alter_table(
table_name: str,
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
partial_reordering: Optional[Tuple[Any, ...]] = None,
partial_reordering: Optional[tuple] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = immutabledict({}),
@@ -403,7 +377,7 @@ def batch_alter_table(

def bulk_insert(
table: Union[Table, TableClause],
rows: List[Dict[str, Any]],
rows: List[dict],
*,
multiinsert: bool = True,
) -> None:
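`bulk_insert` takes a lightweight table construct plus a list of row dictionaries. The canonical usage per the documented API (names hypothetical):

    import sqlalchemy as sa
    from alembic import op

    accounts = sa.table(
        "account",
        sa.column("id", sa.Integer),
        sa.column("name", sa.String),
    )
    op.bulk_insert(
        accounts,
        [
            {"id": 1, "name": "John Smith"},
            {"id": 2, "name": "Ed Williams"},
        ],
    )
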
@@ -659,7 +633,7 @@ def create_foreign_key(
def create_index(
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
columns: Sequence[Union[str, TextClause, Function[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
@@ -756,12 +730,7 @@ def create_primary_key(

"""

def create_table(
table_name: str,
*columns: SchemaItem,
if_not_exists: Optional[bool] = None,
**kw: Any,
) -> Table:
def create_table(table_name: str, *columns: SchemaItem, **kw: Any) -> Table:
r"""Issue a "create table" instruction using the current migration
context.

@@ -832,10 +801,6 @@ def create_table(
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator when
creating the new table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

@@ -935,11 +900,6 @@ def drop_column(
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the new column for compatible dialects

.. versionadded:: 1.16.0

:param mssql_drop_check: Optional boolean. When ``True``, on
Microsoft SQL Server only, first
drop the CHECK constraint on the column using a
@@ -961,6 +921,7 @@ def drop_column(
then exec's a separate DROP CONSTRAINT for that default. Only
works if the column has exactly one FK constraint which refers to
it, at the moment.

"""

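On SQL Server a column cannot be dropped while a CHECK constraint or bound default still references it, which is what the `mssql_drop_check` family of flags documented here automates. A usage sketch (table and column names hypothetical):

    from alembic import op

    # first drops the CHECK constraint found on the column, then the column
    op.drop_column("account", "status", mssql_drop_check=True)
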
def drop_constraint(
@@ -969,7 +930,6 @@ def drop_constraint(
type_: Optional[str] = None,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
) -> None:
r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.

@@ -981,10 +941,6 @@ def drop_constraint(
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the constraint

.. versionadded:: 1.16.0

"""

@@ -1025,11 +981,7 @@ def drop_index(
"""

def drop_table(
table_name: str,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
**kw: Any,
table_name: str, *, schema: Optional[str] = None, **kw: Any
) -> None:
r"""Issue a "drop table" instruction using the current
migration context.
@@ -1044,10 +996,6 @@ def drop_table(
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

@@ -1184,7 +1132,7 @@ def f(name: str) -> conv:
names will be converted along conventions. If the ``target_metadata``
contains the naming convention
``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the
output of the following::
output of the following:

op.add_column("t", "x", Boolean(name="x"))

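`op.f()` marks a name as already final, so the naming convention is not applied to it again. Under the convention shown above, the two calls below target differently-named constraints (a sketch mirroring the docstring's abbreviated call form; names hypothetical):

    # as in the docstring: the convention is applied -> "ck_bool_t_x"
    op.add_column("t", "x", Boolean(name="x"))

    # wrapped in op.f(): the name is used exactly as given -> "x"
    op.add_column("t", "x", Boolean(name=op.f("x")))
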
@@ -1214,7 +1162,7 @@ def get_context() -> MigrationContext:

"""

def implementation_for(op_cls: Any) -> Callable[[_C], _C]:
def implementation_for(op_cls: Any) -> Callable[..., Any]:
"""Register an implementation for a given :class:`.MigrateOperation`.

This is part of the operation extensibility API.
@@ -1226,7 +1174,7 @@ def implementation_for(op_cls: Any) -> Callable[[_C], _C]:
"""

def inline_literal(
value: Union[str, int], type_: Optional[TypeEngine[Any]] = None
value: Union[str, int], type_: Optional[TypeEngine] = None
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
@@ -1270,27 +1218,6 @@ def inline_literal(

"""
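The documented use of `inline_literal` renders literals directly into offline (--sql) migration output instead of as bound parameters:

    from sqlalchemy.sql import column, table
    from alembic import op

    account = table("account", column("name"))
    op.execute(
        account.update()
        .where(account.c.name == op.inline_literal("account 1"))
        .values({"name": op.inline_literal("account 2")})
    )
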
@overload
def invoke(operation: CreateTableOp) -> Table: ...
@overload
def invoke(
operation: Union[
AddConstraintOp,
DropConstraintOp,
CreateIndexOp,
DropIndexOp,
AddColumnOp,
AlterColumnOp,
AlterTableOp,
CreateTableCommentOp,
DropTableCommentOp,
DropColumnOp,
BulkInsertOp,
DropTableOp,
ExecuteSQLOp,
],
) -> None: ...
@overload
def invoke(operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
@@ -1299,7 +1226,7 @@ def invoke(operation: MigrateOperation) -> Any:

def register_operation(
name: str, sourcename: Optional[str] = None
) -> Callable[[Type[_T]], Type[_T]]:
) -> Callable[[_T], _T]:
"""Register a new operation for this class.

This method is normally used to add new operations

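`register_operation` and `implementation_for` form the operation plugin API: a new op class gains an `op.*` method, and a separate function supplies its runtime implementation. A condensed sketch of the documented pattern (the sequence op here is illustrative, not part of this commit):

    from alembic.operations import Operations, MigrateOperation

    @Operations.register_operation("create_sequence")
    class CreateSequenceOp(MigrateOperation):
        def __init__(self, sequence_name, schema=None):
            self.sequence_name = sequence_name
            self.schema = schema

        @classmethod
        def create_sequence(cls, operations, sequence_name, **kw):
            # becomes available in migrations as op.create_sequence(...)
            return operations.invoke(CreateSequenceOp(sequence_name, **kw))

    @Operations.implementation_for(CreateSequenceOp)
    def create_sequence(operations, operation):
        operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)
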
@@ -1,5 +1,3 @@
# mypy: allow-untyped-calls

from __future__ import annotations

from contextlib import contextmanager
@@ -12,9 +10,7 @@ from typing import Dict
from typing import Iterator
from typing import List  # noqa
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence  # noqa
from typing import Tuple
from typing import Type  # noqa
@@ -43,6 +39,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
@@ -50,28 +47,12 @@ if TYPE_CHECKING:
from sqlalchemy.types import TypeEngine

from .batch import BatchOperationsImpl
from .ops import AddColumnOp
from .ops import AddConstraintOp
from .ops import AlterColumnOp
from .ops import AlterTableOp
from .ops import BulkInsertOp
from .ops import CreateIndexOp
from .ops import CreateTableCommentOp
from .ops import CreateTableOp
from .ops import DropColumnOp
from .ops import DropConstraintOp
from .ops import DropIndexOp
from .ops import DropTableCommentOp
from .ops import DropTableOp
from .ops import ExecuteSQLOp
from .ops import MigrateOperation
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
__all__ = ("Operations", "BatchOperations")
_T = TypeVar("_T")

_C = TypeVar("_C", bound=Callable[..., Any])


class AbstractOperations(util.ModuleClsProxy):
"""Base class for Operations and BatchOperations.
@@ -105,7 +86,7 @@ class AbstractOperations(util.ModuleClsProxy):
@classmethod
def register_operation(
cls, name: str, sourcename: Optional[str] = None
) -> Callable[[Type[_T]], Type[_T]]:
) -> Callable[[_T], _T]:
"""Register a new operation for this class.

This method is normally used to add new operations
@@ -122,7 +103,7 @@ class AbstractOperations(util.ModuleClsProxy):

"""

def register(op_cls: Type[_T]) -> Type[_T]:
def register(op_cls):
if sourcename is None:
fn = getattr(op_cls, name)
source_name = fn.__name__
@@ -141,11 +122,8 @@ class AbstractOperations(util.ModuleClsProxy):
*spec, formatannotation=formatannotation_fwdref
)
num_defaults = len(spec[3]) if spec[3] else 0

defaulted_vals: Tuple[Any, ...]

if num_defaults:
defaulted_vals = tuple(name_args[0 - num_defaults :])
defaulted_vals = name_args[0 - num_defaults :]
else:
defaulted_vals = ()

@@ -186,7 +164,7 @@ class AbstractOperations(util.ModuleClsProxy):

globals_ = dict(globals())
globals_.update({"op_cls": op_cls})
lcl: Dict[str, Any] = {}
lcl = {}

exec(func_text, globals_, lcl)
setattr(cls, name, lcl[name])
@@ -202,7 +180,7 @@ class AbstractOperations(util.ModuleClsProxy):
return register

@classmethod
def implementation_for(cls, op_cls: Any) -> Callable[[_C], _C]:
def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
"""Register an implementation for a given :class:`.MigrateOperation`.

This is part of the operation extensibility API.
@@ -213,7 +191,7 @@ class AbstractOperations(util.ModuleClsProxy):

"""

def decorate(fn: _C) -> _C:
def decorate(fn):
cls._to_impl.dispatch_for(op_cls)(fn)
return fn

@@ -235,7 +213,7 @@ class AbstractOperations(util.ModuleClsProxy):
table_name: str,
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
partial_reordering: Optional[Tuple[Any, ...]] = None,
partial_reordering: Optional[tuple] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = util.immutabledict(),
@@ -404,32 +382,6 @@ class AbstractOperations(util.ModuleClsProxy):

return self.migration_context

@overload
def invoke(self, operation: CreateTableOp) -> Table: ...

@overload
def invoke(
self,
operation: Union[
AddConstraintOp,
DropConstraintOp,
CreateIndexOp,
DropIndexOp,
AddColumnOp,
AlterColumnOp,
AlterTableOp,
CreateTableCommentOp,
DropTableCommentOp,
DropColumnOp,
BulkInsertOp,
DropTableOp,
ExecuteSQLOp,
],
) -> None: ...

@overload
def invoke(self, operation: MigrateOperation) -> Any: ...

def invoke(self, operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
@@ -464,7 +416,7 @@ class AbstractOperations(util.ModuleClsProxy):
names will be converted along conventions. If the ``target_metadata``
contains the naming convention
``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the
output of the following::
output of the following:

op.add_column("t", "x", Boolean(name="x"))

@@ -618,7 +570,6 @@ class Operations(AbstractOperations):
column: Column[Any],
*,
schema: Optional[str] = None,
if_not_exists: Optional[bool] = None,
) -> None:
"""Issue an "add column" instruction using the current
migration context.
@@ -695,10 +646,6 @@ class Operations(AbstractOperations):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator
when creating the new column for compatible dialects

.. versionadded:: 1.16.0

"""  # noqa: E501
...
@@ -710,16 +657,12 @@ class Operations(AbstractOperations):
*,
nullable: Optional[bool] = None,
comment: Union[str, Literal[False], None] = False,
server_default: Union[
str, bool, Identity, Computed, TextClause, None
] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
existing_type: Union[
TypeEngine[Any], Type[TypeEngine[Any]], None
] = None,
type_: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_server_default: Union[
str, bool, Identity, Computed, TextClause, None
str, bool, Identity, Computed, None
] = False,
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
@@ -813,7 +756,7 @@ class Operations(AbstractOperations):
def bulk_insert(
self,
table: Union[Table, TableClause],
rows: List[Dict[str, Any]],
rows: List[dict],
*,
multiinsert: bool = True,
) -> None:
@@ -1080,7 +1023,7 @@ class Operations(AbstractOperations):
self,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
columns: Sequence[Union[str, TextClause, Function[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
@@ -1181,11 +1124,7 @@ class Operations(AbstractOperations):
...

def create_table(
self,
table_name: str,
*columns: SchemaItem,
if_not_exists: Optional[bool] = None,
**kw: Any,
self, table_name: str, *columns: SchemaItem, **kw: Any
) -> Table:
r"""Issue a "create table" instruction using the current migration
context.
@@ -1257,10 +1196,6 @@ class Operations(AbstractOperations):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator when
creating the new table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

@@ -1366,11 +1301,6 @@ class Operations(AbstractOperations):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the new column for compatible dialects

.. versionadded:: 1.16.0

:param mssql_drop_check: Optional boolean. When ``True``, on
Microsoft SQL Server only, first
drop the CHECK constraint on the column using a
@@ -1392,6 +1322,7 @@ class Operations(AbstractOperations):
then exec's a separate DROP CONSTRAINT for that default. Only
works if the column has exactly one FK constraint which refers to
it, at the moment.

"""  # noqa: E501
...

@@ -1402,7 +1333,6 @@ class Operations(AbstractOperations):
type_: Optional[str] = None,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
) -> None:
r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.

@@ -1414,10 +1344,6 @@ class Operations(AbstractOperations):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the constraint

.. versionadded:: 1.16.0

"""  # noqa: E501
...
@@ -1461,12 +1387,7 @@ class Operations(AbstractOperations):
...

def drop_table(
self,
table_name: str,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
**kw: Any,
self, table_name: str, *, schema: Optional[str] = None, **kw: Any
) -> None:
r"""Issue a "drop table" instruction using the current
migration context.
@@ -1481,10 +1402,6 @@ class Operations(AbstractOperations):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

@@ -1643,7 +1560,7 @@ class BatchOperations(AbstractOperations):

impl: BatchOperationsImpl

def _noop(self, operation: Any) -> NoReturn:
def _noop(self, operation):
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
% operation
@@ -1660,7 +1577,6 @@ class BatchOperations(AbstractOperations):
*,
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
if_not_exists: Optional[bool] = None,
) -> None:
"""Issue an "add column" instruction using the current
batch migration context.
@@ -1680,10 +1596,8 @@ class BatchOperations(AbstractOperations):
comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
existing_type: Union[
TypeEngine[Any], Type[TypeEngine[Any]], None
] = None,
type_: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_server_default: Union[
str, bool, Identity, Computed, None
] = False,
@@ -1738,7 +1652,7 @@ class BatchOperations(AbstractOperations):

def create_exclude_constraint(
self, constraint_name: str, *elements: Any, **kw: Any
) -> Optional[Table]:
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.

@@ -1754,7 +1668,7 @@ class BatchOperations(AbstractOperations):

def create_foreign_key(
self,
constraint_name: Optional[str],
constraint_name: str,
referent_table: str,
local_cols: List[str],
remote_cols: List[str],
@@ -1804,7 +1718,7 @@ class BatchOperations(AbstractOperations):
...

def create_primary_key(
self, constraint_name: Optional[str], columns: List[str]
self, constraint_name: str, columns: List[str]
) -> None:
"""Issue a "create primary key" instruction using the
current batch migration context.

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
@@ -18,10 +15,9 @@ from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.schema import SchemaEventTarget
from sqlalchemy.events import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological

@@ -32,9 +28,11 @@ from ..util.sqla_compat import _copy_expression
from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
from ..util.sqla_compat import _idx_table_bound_expressions
from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _resolve_for_variant
from ..util.sqla_compat import _select
from ..util.sqla_compat import constraint_name_defined
from ..util.sqla_compat import constraint_name_string

@@ -376,7 +374,7 @@ class ApplyBatchImpl:
for idx_existing in self.indexes.values():
# this is a lift-and-move from Table.to_metadata

if idx_existing._column_flag:
if idx_existing._column_flag:  # type: ignore
continue

idx_copy = Index(
@@ -405,7 +403,9 @@ class ApplyBatchImpl:
def _setup_referent(
self, metadata: MetaData, constraint: ForeignKeyConstraint
) -> None:
spec = constraint.elements[0]._get_colspec()
spec = constraint.elements[
0
]._get_colspec()  # type:ignore[attr-defined]
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
@@ -448,15 +448,13 @@ class ApplyBatchImpl:

try:
op_impl._exec(
self.new_table.insert()
.inline()
.from_select(
_insert_inline(self.new_table).from_select(
list(
k
for k, transfer in self.column_transfers.items()
if "expr" in transfer
),
select(
_select(
*[
transfer["expr"]
for transfer in self.column_transfers.values()
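The copy step above is the heart of batch mode: the replacement table is populated with a single inline INSERT ... FROM SELECT against the old one, carrying per-column transfer expressions. A minimal standalone sketch of the same construct (hypothetical tables):

    import sqlalchemy as sa

    old = sa.table("account", sa.column("id"), sa.column("name"))
    new = sa.table("_alembic_tmp_account", sa.column("id"), sa.column("name"))

    stmt = (
        new.insert()
        .inline()  # render one INSERT ... SELECT, not an executemany
        .from_select(["id", "name"], sa.select(old.c.id, old.c.name))
    )
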
@@ -548,7 +546,9 @@ class ApplyBatchImpl:
else:
sql_schema.DefaultClause(
server_default  # type: ignore[arg-type]
)._set_parent(existing)
)._set_parent(  # type:ignore[attr-defined]
existing
)
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)

@@ -1,13 +1,10 @@
from __future__ import annotations

from abc import abstractmethod
import os
import pathlib
import re
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
@@ -18,7 +15,6 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from sqlalchemy.types import NULLTYPE
@@ -37,6 +33,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
@@ -56,9 +53,6 @@ if TYPE_CHECKING:
from ..runtime.migration import MigrationContext
from ..script.revision import _RevIdType

_T = TypeVar("_T", bound=Any)
_AC = TypeVar("_AC", bound="AddConstraintOp")


class MigrateOperation:
"""base class for migration command and organization objects.
@@ -76,7 +70,7 @@ class MigrateOperation:
"""

@util.memoized_property
def info(self) -> Dict[Any, Any]:
def info(self):
"""A dictionary that may be used to store arbitrary information
along with this :class:`.MigrateOperation` object.

@@ -98,14 +92,12 @@ class AddConstraintOp(MigrateOperation):
add_constraint_ops = util.Dispatcher()

@property
def constraint_type(self) -> str:
def constraint_type(self):
raise NotImplementedError()

@classmethod
def register_add_constraint(
cls, type_: str
) -> Callable[[Type[_AC]], Type[_AC]]:
def go(klass: Type[_AC]) -> Type[_AC]:
def register_add_constraint(cls, type_: str) -> Callable:
def go(klass):
cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
return klass

@@ -113,7 +105,7 @@ class AddConstraintOp(MigrateOperation):

@classmethod
def from_constraint(cls, constraint: Constraint) -> AddConstraintOp:
return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(  # type: ignore[no-any-return] # noqa: E501
return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
constraint
)

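`from_constraint` is a dispatch table keyed on the constraint's `__visit_name__`; each concrete op registers itself with the decorator. A condensed sketch of how a subclass hooks in, mirroring the pattern in this file (the constraint type here is hypothetical):

    from alembic.operations.ops import AddConstraintOp

    # hypothetical type; CreateUniqueConstraintOp registers itself the same
    # way for "unique_constraint"
    @AddConstraintOp.register_add_constraint("my_constraint")
    class CreateMyConstraintOp(AddConstraintOp):
        constraint_type = "my"

        @classmethod
        def from_constraint(cls, constraint):
            return cls()  # build the op from the given constraint
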
@@ -142,14 +134,12 @@ class DropConstraintOp(MigrateOperation):
type_: Optional[str] = None,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
_reverse: Optional[AddConstraintOp] = None,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
self.if_exists = if_exists
self._reverse = _reverse

def reverse(self) -> AddConstraintOp:
@@ -207,7 +197,6 @@ class DropConstraintOp(MigrateOperation):
type_: Optional[str] = None,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
) -> None:
r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.

@@ -219,20 +208,10 @@ class DropConstraintOp(MigrateOperation):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the constraint

.. versionadded:: 1.16.0

"""

op = cls(
constraint_name,
table_name,
type_=type_,
schema=schema,
if_exists=if_exists,
)
op = cls(constraint_name, table_name, type_=type_, schema=schema)
return operations.invoke(op)

@classmethod
@@ -363,7 +342,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
def batch_create_primary_key(
cls,
operations: BatchOperations,
constraint_name: Optional[str],
constraint_name: str,
columns: List[str],
) -> None:
"""Issue a "create primary key" instruction using the
@@ -419,7 +398,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):

uq_constraint = cast("UniqueConstraint", constraint)

kw: Dict[str, Any] = {}
kw: dict = {}
if uq_constraint.deferrable:
kw["deferrable"] = uq_constraint.deferrable
if uq_constraint.initially:
@@ -553,7 +532,7 @@ class CreateForeignKeyOp(AddConstraintOp):
@classmethod
def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp:
fk_constraint = cast("ForeignKeyConstraint", constraint)
kw: Dict[str, Any] = {}
kw: dict = {}
if fk_constraint.onupdate:
kw["onupdate"] = fk_constraint.onupdate
if fk_constraint.ondelete:
@@ -695,7 +674,7 @@ class CreateForeignKeyOp(AddConstraintOp):
def batch_create_foreign_key(
cls,
operations: BatchOperations,
constraint_name: Optional[str],
constraint_name: str,
referent_table: str,
local_cols: List[str],
remote_cols: List[str],
@@ -918,9 +897,9 @@ class CreateIndexOp(MigrateOperation):
def from_index(cls, index: Index) -> CreateIndexOp:
assert index.table is not None
return cls(
index.name,
index.name,  # type: ignore[arg-type]
index.table.name,
index.expressions,
sqla_compat._get_index_expressions(index),
schema=index.table.schema,
unique=index.unique,
**index.kwargs,
@@ -947,7 +926,7 @@ class CreateIndexOp(MigrateOperation):
operations: Operations,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
columns: Sequence[Union[str, TextClause, Function[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
@@ -1075,7 +1054,6 @@ class DropIndexOp(MigrateOperation):
table_name=index.table.name,
schema=index.table.schema,
_reverse=CreateIndexOp.from_index(index),
unique=index.unique,
**index.kwargs,
)

@@ -1173,7 +1151,6 @@ class CreateTableOp(MigrateOperation):
columns: Sequence[SchemaItem],
*,
schema: Optional[str] = None,
if_not_exists: Optional[bool] = None,
_namespace_metadata: Optional[MetaData] = None,
_constraints_included: bool = False,
**kw: Any,
@@ -1181,7 +1158,6 @@ class CreateTableOp(MigrateOperation):
self.table_name = table_name
self.columns = columns
self.schema = schema
self.if_not_exists = if_not_exists
self.info = kw.pop("info", {})
self.comment = kw.pop("comment", None)
self.prefixes = kw.pop("prefixes", None)
@@ -1206,7 +1182,7 @@ class CreateTableOp(MigrateOperation):

return cls(
table.name,
list(table.c) + list(table.constraints),
list(table.c) + list(table.constraints),  # type:ignore[arg-type]
schema=table.schema,
_namespace_metadata=_namespace_metadata,
# given a Table() object, this Table will contain full Index()
@@ -1244,7 +1220,6 @@ class CreateTableOp(MigrateOperation):
operations: Operations,
table_name: str,
*columns: SchemaItem,
if_not_exists: Optional[bool] = None,
**kw: Any,
) -> Table:
r"""Issue a "create table" instruction using the current migration
@@ -1317,10 +1292,6 @@ class CreateTableOp(MigrateOperation):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator when
creating the new table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

@@ -1328,7 +1299,7 @@ class CreateTableOp(MigrateOperation):
to the parameters given.

"""
op = cls(table_name, columns, if_not_exists=if_not_exists, **kw)
op = cls(table_name, columns, **kw)
return operations.invoke(op)

@@ -1341,13 +1312,11 @@ class DropTableOp(MigrateOperation):
table_name: str,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
table_kw: Optional[MutableMapping[Any, Any]] = None,
_reverse: Optional[CreateTableOp] = None,
) -> None:
self.table_name = table_name
self.schema = schema
self.if_exists = if_exists
self.table_kw = table_kw or {}
self.comment = self.table_kw.pop("comment", None)
self.info = self.table_kw.pop("info", None)
@@ -1394,9 +1363,9 @@ class DropTableOp(MigrateOperation):
info=self.info.copy() if self.info else {},
prefixes=list(self.prefixes) if self.prefixes else [],
schema=self.schema,
_constraints_included=(
self._reverse._constraints_included if self._reverse else False
),
_constraints_included=self._reverse._constraints_included
if self._reverse
else False,
**self.table_kw,
)
return t
@@ -1408,7 +1377,6 @@ class DropTableOp(MigrateOperation):
table_name: str,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
**kw: Any,
) -> None:
r"""Issue a "drop table" instruction using the current
@@ -1424,15 +1392,11 @@ class DropTableOp(MigrateOperation):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the table.

.. versionadded:: 1.13.3
:param \**kw: Other keyword arguments are passed to the underlying
:class:`sqlalchemy.schema.Table` object created for the command.

"""
op = cls(table_name, schema=schema, if_exists=if_exists, table_kw=kw)
op = cls(table_name, schema=schema, table_kw=kw)
operations.invoke(op)

@@ -1570,7 +1534,7 @@ class CreateTableCommentOp(AlterTableOp):
)
return operations.invoke(op)

def reverse(self) -> Union[CreateTableCommentOp, DropTableCommentOp]:
def reverse(self):
"""Reverses the COMMENT ON operation against a table."""
if self.existing_comment is None:
return DropTableCommentOp(
@@ -1586,16 +1550,14 @@ class CreateTableCommentOp(AlterTableOp):
schema=self.schema,
)

def to_table(
self, migration_context: Optional[MigrationContext] = None
) -> Table:
def to_table(self, migration_context=None):
schema_obj = schemaobj.SchemaObjects(migration_context)

return schema_obj.table(
self.table_name, schema=self.schema, comment=self.comment
)

def to_diff_tuple(self) -> Tuple[Any, ...]:
def to_diff_tuple(self):
return ("add_table_comment", self.to_table(), self.existing_comment)

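Every MigrateOperation knows how to produce its inverse, which is how downgrades are derived mechanically from upgrades. A hedged sketch of the comment-op pair above (table name and comment hypothetical):

    from alembic.operations.ops import CreateTableCommentOp

    create = CreateTableCommentOp("account", "new comment", existing_comment=None)
    drop = create.reverse()  # a DropTableCommentOp restoring the absent comment
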
|
||||
@@ -1667,20 +1629,18 @@ class DropTableCommentOp(AlterTableOp):
|
||||
)
|
||||
return operations.invoke(op)
|
||||
|
||||
def reverse(self) -> CreateTableCommentOp:
|
||||
def reverse(self):
|
||||
"""Reverses the COMMENT ON operation against a table."""
|
||||
return CreateTableCommentOp(
|
||||
self.table_name, self.existing_comment, schema=self.schema
|
||||
)
|
||||
|
||||
def to_table(
|
||||
self, migration_context: Optional[MigrationContext] = None
|
||||
) -> Table:
|
||||
def to_table(self, migration_context=None):
|
||||
schema_obj = schemaobj.SchemaObjects(migration_context)
|
||||
|
||||
return schema_obj.table(self.table_name, schema=self.schema)
|
||||
|
||||
def to_diff_tuple(self) -> Tuple[Any, ...]:
|
||||
def to_diff_tuple(self):
|
||||
return ("remove_table_comment", self.to_table())
|
||||
|
||||
|
||||
@@ -1855,16 +1815,12 @@ class AlterColumnOp(AlterTableOp):
*,
nullable: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
server_default: Union[
str, bool, Identity, Computed, TextClause, None
] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
existing_type: Optional[
Union[TypeEngine[Any], Type[TypeEngine[Any]]]
] = None,
existing_server_default: Union[
str, bool, Identity, Computed, TextClause, None
type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_server_default: Optional[
Union[str, bool, Identity, Computed]
] = False,
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
@@ -1982,10 +1938,8 @@ class AlterColumnOp(AlterTableOp):
comment: Optional[Union[str, Literal[False]]] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
existing_type: Optional[
Union[TypeEngine[Any], Type[TypeEngine[Any]]]
] = None,
type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_server_default: Optional[
Union[str, bool, Identity, Computed]
] = False,
@@ -2049,31 +2003,27 @@ class AddColumnOp(AlterTableOp):
column: Column[Any],
*,
schema: Optional[str] = None,
if_not_exists: Optional[bool] = None,
**kw: Any,
) -> None:
super().__init__(table_name, schema=schema)
self.column = column
self.if_not_exists = if_not_exists
self.kw = kw

def reverse(self) -> DropColumnOp:
op = DropColumnOp.from_column_and_tablename(
return DropColumnOp.from_column_and_tablename(
self.schema, self.table_name, self.column
)
op.if_exists = self.if_not_exists
return op

def to_diff_tuple(
self,
) -> Tuple[str, Optional[str], str, Column[Any]]:
return ("add_column", self.schema, self.table_name, self.column)

def to_column(self) -> Column[Any]:
def to_column(self) -> Column:
return self.column

@classmethod
def from_column(cls, col: Column[Any]) -> AddColumnOp:
def from_column(cls, col: Column) -> AddColumnOp:
return cls(col.table.name, col, schema=col.table.schema)

@classmethod
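
A sketch of the 1.16.x behavior removed in this hunk, where ``reverse()`` carries the conditional flag onto the generated drop (as the ``op.if_exists = self.if_not_exists`` line shows):

    import sqlalchemy as sa
    from alembic.operations import ops

    add = ops.AddColumnOp(
        "account", sa.Column("nickname", sa.String(50)), if_not_exists=True
    )
    drop = add.reverse()  # DropColumnOp for the same column
    assert drop.if_exists is True  # flag propagated on 1.16.x only
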
@@ -2093,7 +2043,6 @@ class AddColumnOp(AlterTableOp):
column: Column[Any],
*,
schema: Optional[str] = None,
if_not_exists: Optional[bool] = None,
) -> None:
"""Issue an "add column" instruction using the current
migration context.
@@ -2170,19 +2119,10 @@ class AddColumnOp(AlterTableOp):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_not_exists: If True, adds IF NOT EXISTS operator
when creating the new column for compatible dialects

.. versionadded:: 1.16.0

"""

op = cls(
table_name,
column,
schema=schema,
if_not_exists=if_not_exists,
)
op = cls(table_name, column, schema=schema)
return operations.invoke(op)

@classmethod
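
Usage sketch for the ``if_not_exists`` parameter documented above (Alembic 1.16.0+):

    import sqlalchemy as sa
    from alembic import op

    op.add_column(
        "account",
        sa.Column("last_seen", sa.DateTime(), nullable=True),
        if_not_exists=True,  # ADD COLUMN IF NOT EXISTS where supported
    )
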
@@ -2193,7 +2133,6 @@ class AddColumnOp(AlterTableOp):
*,
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
if_not_exists: Optional[bool] = None,
) -> None:
"""Issue an "add column" instruction using the current
batch migration context.
@@ -2214,7 +2153,6 @@ class AddColumnOp(AlterTableOp):
operations.impl.table_name,
column,
schema=operations.impl.schema,
if_not_exists=if_not_exists,
**kw,
)
return operations.invoke(op)
@@ -2231,14 +2169,12 @@ class DropColumnOp(AlterTableOp):
column_name: str,
*,
schema: Optional[str] = None,
if_exists: Optional[bool] = None,
_reverse: Optional[AddColumnOp] = None,
**kw: Any,
) -> None:
super().__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
self.if_exists = if_exists
self._reverse = _reverse

def to_diff_tuple(
@@ -2258,11 +2194,9 @@ class DropColumnOp(AlterTableOp):
"original column is not present"
)

op = AddColumnOp.from_column_and_tablename(
return AddColumnOp.from_column_and_tablename(
self.schema, self.table_name, self._reverse.column
)
op.if_not_exists = self.if_exists
return op

@classmethod
def from_column_and_tablename(
@@ -2280,7 +2214,7 @@ class DropColumnOp(AlterTableOp):

def to_column(
self, migration_context: Optional[MigrationContext] = None
) -> Column[Any]:
) -> Column:
if self._reverse is not None:
return self._reverse.column
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -2309,11 +2243,6 @@ class DropColumnOp(AlterTableOp):
quoting of the schema outside of the default behavior, use
the SQLAlchemy construct
:class:`~sqlalchemy.sql.elements.quoted_name`.
:param if_exists: If True, adds IF EXISTS operator when
dropping the new column for compatible dialects

.. versionadded:: 1.16.0

:param mssql_drop_check: Optional boolean. When ``True``, on
Microsoft SQL Server only, first
drop the CHECK constraint on the column using a
@@ -2335,6 +2264,7 @@ class DropColumnOp(AlterTableOp):
then exec's a separate DROP CONSTRAINT for that default. Only
works if the column has exactly one FK constraint which refers to
it, at the moment.

"""

op = cls(table_name, column_name, schema=schema, **kw)
@@ -2368,7 +2298,7 @@ class BulkInsertOp(MigrateOperation):
def __init__(
self,
table: Union[Table, TableClause],
rows: List[Dict[str, Any]],
rows: List[dict],
*,
multiinsert: bool = True,
) -> None:
@@ -2381,7 +2311,7 @@ class BulkInsertOp(MigrateOperation):
cls,
operations: Operations,
table: Union[Table, TableClause],
rows: List[Dict[str, Any]],
rows: List[dict],
*,
multiinsert: bool = True,
) -> None:
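
Only the ``rows`` annotation changes for ``bulk_insert`` in these hunks; typical usage is the same on both sides, a sketch:

    import sqlalchemy as sa
    from alembic import op

    accounts = sa.table(
        "account",
        sa.column("id", sa.Integer),
        sa.column("name", sa.String),
    )

    op.bulk_insert(
        accounts,
        [
            {"id": 1, "name": "alice"},
            {"id": 2, "name": "bob"},
        ],
    )
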
@@ -2677,7 +2607,7 @@ class UpgradeOps(OpContainer):
self.upgrade_token = upgrade_token

def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps:
downgrade_ops.ops[:] = list(
downgrade_ops.ops[:] = list( # type:ignore[index]
reversed([op.reverse() for op in self.ops])
)
return downgrade_ops
@@ -2704,7 +2634,7 @@ class DowngradeOps(OpContainer):
super().__init__(ops=ops)
self.downgrade_token = downgrade_token

def reverse(self) -> UpgradeOps:
def reverse(self):
return UpgradeOps(
ops=list(reversed([op.reverse() for op in self.ops]))
)
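
A sketch of what ``reverse_into`` does for autogenerate (same on both sides of the hunk):

    import sqlalchemy as sa
    from alembic.operations import ops

    add = ops.AddColumnOp("account", sa.Column("nickname", sa.String(50)))
    upgrade_ops = ops.UpgradeOps(ops=[add])
    downgrade_ops = upgrade_ops.reverse_into(ops.DowngradeOps(ops=[]))
    # each op is reversed and the order inverted, so the downgrade
    # undoes the upgrade back-to-front
    assert isinstance(downgrade_ops.ops[0], ops.DropColumnOp)
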
@@ -2735,8 +2665,6 @@ class MigrationScript(MigrateOperation):
"""

_needs_render: Optional[bool]
_upgrade_ops: List[UpgradeOps]
_downgrade_ops: List[DowngradeOps]

def __init__(
self,
@@ -2749,7 +2677,7 @@ class MigrationScript(MigrateOperation):
head: Optional[str] = None,
splice: Optional[bool] = None,
branch_label: Optional[_RevIdType] = None,
version_path: Union[str, os.PathLike[str], None] = None,
version_path: Optional[str] = None,
depends_on: Optional[_RevIdType] = None,
) -> None:
self.rev_id = rev_id
@@ -2758,15 +2686,13 @@ class MigrationScript(MigrateOperation):
self.head = head
self.splice = splice
self.branch_label = branch_label
self.version_path = (
pathlib.Path(version_path).as_posix() if version_path else None
)
self.version_path = version_path
self.depends_on = depends_on
self.upgrade_ops = upgrade_ops
self.downgrade_ops = downgrade_ops

@property
def upgrade_ops(self) -> Optional[UpgradeOps]:
def upgrade_ops(self):
"""An instance of :class:`.UpgradeOps`.

.. seealso::
@@ -2785,15 +2711,13 @@ class MigrationScript(MigrateOperation):
return self._upgrade_ops[0]

@upgrade_ops.setter
def upgrade_ops(
self, upgrade_ops: Union[UpgradeOps, List[UpgradeOps]]
) -> None:
def upgrade_ops(self, upgrade_ops):
self._upgrade_ops = util.to_list(upgrade_ops)
for elem in self._upgrade_ops:
assert isinstance(elem, UpgradeOps)

@property
def downgrade_ops(self) -> Optional[DowngradeOps]:
def downgrade_ops(self):
"""An instance of :class:`.DowngradeOps`.

.. seealso::
@@ -2812,9 +2736,7 @@ class MigrationScript(MigrateOperation):
return self._downgrade_ops[0]

@downgrade_ops.setter
def downgrade_ops(
self, downgrade_ops: Union[DowngradeOps, List[DowngradeOps]]
) -> None:
def downgrade_ops(self, downgrade_ops):
self._downgrade_ops = util.to_list(downgrade_ops)
for elem in self._downgrade_ops:
assert isinstance(elem, DowngradeOps)

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
@@ -223,12 +220,10 @@ class SchemaObjects:
t = sa_schema.Table(name, m, *cols, **kw)

constraints = [
(
sqla_compat._copy(elem, target_table=t)
if getattr(elem, "parent", None) is not t
and getattr(elem, "parent", None) is not None
else elem
)
sqla_compat._copy(elem, target_table=t)
if getattr(elem, "parent", None) is not t
and getattr(elem, "parent", None) is not None
else elem
for elem in columns
if isinstance(elem, (Constraint, Index))
]
@@ -279,8 +274,10 @@ class SchemaObjects:
ForeignKey.

"""
if isinstance(fk._colspec, str):
table_key, cname = fk._colspec.rsplit(".", 1)
if isinstance(fk._colspec, str): # type:ignore[attr-defined]
table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined]
".", 1
)
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from typing import TYPE_CHECKING

from sqlalchemy import schema as sa_schema
@@ -79,11 +76,8 @@ def alter_column(

@Operations.implementation_for(ops.DropTableOp)
def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
kw = {}
if operation.if_exists is not None:
kw["if_exists"] = operation.if_exists
operations.impl.drop_table(
operation.to_table(operations.migration_context), **kw
operation.to_table(operations.migration_context)
)


@@ -93,11 +87,7 @@ def drop_column(
) -> None:
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
operation.table_name,
column,
schema=operation.schema,
if_exists=operation.if_exists,
**operation.kw,
operation.table_name, column, schema=operation.schema, **operation.kw
)


@@ -108,6 +98,9 @@ def create_index(
idx = operation.to_index(operations.migration_context)
kw = {}
if operation.if_not_exists is not None:
if not sqla_2:
raise NotImplementedError("SQLAlchemy 2.0+ required")

kw["if_not_exists"] = operation.if_not_exists
operations.impl.create_index(idx, **kw)

@@ -116,6 +109,9 @@ def create_index(
def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
kw = {}
if operation.if_exists is not None:
if not sqla_2:
raise NotImplementedError("SQLAlchemy 2.0+ required")

kw["if_exists"] = operation.if_exists

operations.impl.drop_index(
@@ -128,11 +124,8 @@ def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
def create_table(
operations: "Operations", operation: "ops.CreateTableOp"
) -> "Table":
kw = {}
if operation.if_not_exists is not None:
kw["if_not_exists"] = operation.if_not_exists
table = operation.to_table(operations.migration_context)
operations.impl.create_table(table, **kw)
operations.impl.create_table(table)
return table


@@ -172,13 +165,7 @@ def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
column = _copy(column)

t = operations.schema_obj.table(table_name, column, schema=schema)
operations.impl.add_column(
table_name,
column,
schema=schema,
if_not_exists=operation.if_not_exists,
**kw,
)
operations.impl.add_column(table_name, column, schema=schema, **kw)

for constraint in t.constraints:
if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
@@ -208,19 +195,13 @@ def create_constraint(
def drop_constraint(
operations: "Operations", operation: "ops.DropConstraintOp"
) -> None:
kw = {}
if operation.if_exists is not None:
if not sqla_2:
raise NotImplementedError("SQLAlchemy 2.0 required")
kw["if_exists"] = operation.if_exists
operations.impl.drop_constraint(
operations.schema_obj.generic_constraint(
operation.constraint_name,
operation.table_name,
operation.constraint_type,
schema=operation.schema,
),
**kw,
)
)



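
For context, a sketch of what the SQLAlchemy 2.0 guards above mean for migrations using conditional index DDL:

    from alembic import op

    # both raise NotImplementedError when SQLAlchemy 2.0 is not installed
    op.create_index(
        "ix_account_email", "account", ["email"], if_not_exists=True
    )
    op.drop_index("ix_account_email", table_name="account", if_exists=True)
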
@@ -3,13 +3,13 @@ from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
@@ -17,7 +17,6 @@ from typing import Union

from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import FetchedValue
from typing_extensions import ContextManager
from typing_extensions import Literal

from .migration import _ProxyTransaction
@@ -108,6 +107,7 @@ CompareType = Callable[


class EnvironmentContext(util.ModuleClsProxy):

"""A configurational facade made available in an ``env.py`` script.

The :class:`.EnvironmentContext` acts as a *facade* to the more
@@ -227,9 +227,9 @@ class EnvironmentContext(util.ModuleClsProxy):
has been configured.

"""
return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501
return self.context_opts.get("as_sql", False)

def is_transactional_ddl(self) -> bool:
def is_transactional_ddl(self):
"""Return True if the context is configured to expect a
transactional DDL capable backend.

@@ -341,17 +341,18 @@ class EnvironmentContext(util.ModuleClsProxy):
return self.context_opts.get("tag", None)

@overload
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ...
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
...

@overload
def get_x_argument(
self, as_dictionary: Literal[True]
) -> Dict[str, str]: ...
def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
...

@overload
def get_x_argument(
self, as_dictionary: bool = ...
) -> Union[List[str], Dict[str, str]]: ...
) -> Union[List[str], Dict[str, str]]:
...

def get_x_argument(
self, as_dictionary: bool = False
@@ -365,11 +366,7 @@ class EnvironmentContext(util.ModuleClsProxy):
The return value is a list, returned directly from the ``argparse``
structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
are parsed using ``key=value`` format into a dictionary that is
then returned. If there is no ``=`` in the argument, value is an empty
string.

.. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when
arguments are passed without the ``=`` symbol.
then returned.

For example, to support passing a database URL on the command line,
the standard ``env.py`` script can be modified like this::
@@ -403,12 +400,7 @@ class EnvironmentContext(util.ModuleClsProxy):
else:
value = []
if as_dictionary:
dict_value = {}
for arg in value:
x_key, _, x_value = arg.partition("=")
dict_value[x_key] = x_value
value = dict_value

value = dict(arg.split("=", 1) for arg in value)
return value

def configure(
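
A sketch of the behavioral difference in the ``get_x_argument`` hunk above, given ``alembic -x data=somedata -x flag upgrade head``:

    x_args = context.get_x_argument(as_dictionary=True)
    # partition-based parsing (1.13.1+): {"data": "somedata", "flag": ""}
    # dict(arg.split("=", 1)) parsing: raises ValueError on "flag",
    # since the split yields a single-element sequence
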
@@ -424,7 +416,7 @@ class EnvironmentContext(util.ModuleClsProxy):
tag: Optional[str] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Union[MetaData, Sequence[MetaData], None] = None,
target_metadata: Optional[MetaData] = None,
include_name: Optional[IncludeNameFn] = None,
include_object: Optional[IncludeObjectFn] = None,
include_schemas: bool = False,
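
The widened annotation reflects that ``configure()`` accepts multiple ``MetaData`` collections for autogenerate; an ``env.py`` sketch (module names are placeholders):

    from myapp.models import core_metadata, audit_metadata  # hypothetical

    context.configure(
        connection=connection,
        target_metadata=[core_metadata, audit_metadata],
    )
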
@@ -948,7 +940,7 @@ class EnvironmentContext(util.ModuleClsProxy):
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[Dict[str, Any]] = None,
execution_options: Optional[dict] = None,
) -> None:
"""Execute the given SQL using the current change context.

@@ -976,7 +968,7 @@ class EnvironmentContext(util.ModuleClsProxy):

def begin_transaction(
self,
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
) -> Union[_ProxyTransaction, ContextManager[None]]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline

@@ -1,6 +1,3 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from contextlib import contextmanager
@@ -11,6 +8,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -23,11 +21,13 @@ from typing import Union

from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.engine import Engine
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine.strategies import MockEngineStrategy
from typing_extensions import ContextManager

from .. import ddl
from .. import util
@@ -83,6 +83,7 @@ class _ProxyTransaction:


class MigrationContext:

"""Represent the database state made available to a migration
script.

@@ -175,11 +176,7 @@ class MigrationContext:
opts["output_encoding"],
)
else:
self.output_buffer = opts.get(
"output_buffer", sys.stdout
) # type:ignore[assignment] # noqa: E501

self.transactional_ddl = transactional_ddl
self.output_buffer = opts.get("output_buffer", sys.stdout)

self._user_compare_type = opts.get("compare_type", True)
self._user_compare_server_default = opts.get(
@@ -191,6 +188,18 @@ class MigrationContext:
self.version_table_schema = version_table_schema = opts.get(
"version_table_schema", None
)
self._version = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if opts.get("version_table_pk", True):
self._version.append_constraint(
PrimaryKeyConstraint(
"version_num", name="%s_pkc" % version_table
)
)

self._start_from_rev: Optional[str] = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
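
In plain SQLAlchemy terms, the inline construction above (the 1.12.x side) builds the familiar version table; on the 1.16.x side this moves behind ``impl.version_table_impl()`` so dialects can customize it:

    from sqlalchemy import (
        Column, MetaData, PrimaryKeyConstraint, String, Table
    )

    version_table = Table(
        "alembic_version",
        MetaData(),
        Column("version_num", String(32), nullable=False),
        PrimaryKeyConstraint("version_num", name="alembic_version_pkc"),
    )
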
@@ -201,23 +210,14 @@ class MigrationContext:
self.output_buffer,
opts,
)

self._version = self.impl.version_table_impl(
version_table=version_table,
version_table_schema=version_table_schema,
version_table_pk=opts.get("version_table_pk", True),
)

log.info("Context impl %s.", self.impl.__class__.__name__)
if self.as_sql:
log.info("Generating static SQL")
log.info(
"Will assume %s DDL.",
(
"transactional"
if self.impl.transactional_ddl
else "non-transactional"
),
"transactional"
if self.impl.transactional_ddl
else "non-transactional",
)

@classmethod
@@ -342,9 +342,9 @@ class MigrationContext:
# except that it will not know it's in "autocommit" and will
# emit deprecation warnings when an autocommit action takes
# place.
self.connection = self.impl.connection = (
base_connection.execution_options(isolation_level="AUTOCOMMIT")
)
self.connection = (
self.impl.connection
) = base_connection.execution_options(isolation_level="AUTOCOMMIT")

# sqlalchemy future mode will "autobegin" in any case, so take
# control of that "transaction" here
@@ -372,7 +372,7 @@ class MigrationContext:

def begin_transaction(
self, _per_migration: bool = False
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
) -> Union[_ProxyTransaction, ContextManager[None]]:
"""Begin a logical transaction for migration operations.

This method is used within an ``env.py`` script to demarcate where
@@ -521,7 +521,7 @@ class MigrationContext:
start_from_rev = None
elif start_from_rev is not None and self.script:
start_from_rev = [
self.script.get_revision(sfr).revision
cast("Script", self.script.get_revision(sfr)).revision
for sfr in util.to_list(start_from_rev)
if sfr not in (None, "base")
]
@@ -536,10 +536,7 @@ class MigrationContext:
return ()
assert self.connection is not None
return tuple(
row[0]
for row in self.connection.execute(
select(self._version.c.version_num)
)
row[0] for row in self.connection.execute(self._version.select())
)

def _ensure_version_table(self, purge: bool = False) -> None:
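
Both query forms in the hunk above read the same data, since ``version_num`` is the version table's only column; reusing the ``version_table`` sketched earlier:

    from sqlalchemy import select

    stmt_new = select(version_table.c.version_num)  # explicit column
    stmt_old = version_table.select()  # SELECT over the one-column table
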
@@ -655,7 +652,7 @@ class MigrationContext:
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[Dict[str, Any]] = None,
execution_options: Optional[dict] = None,
) -> None:
"""Execute a SQL construct or string statement.

@@ -1003,11 +1000,6 @@ class MigrationStep:
is_upgrade: bool
migration_fn: Any

if TYPE_CHECKING:

@property
def doc(self) -> Optional[str]: ...

@property
def name(self) -> str:
return self.migration_fn.__name__
@@ -1056,9 +1048,13 @@ class RevisionStep(MigrationStep):
self.revision = revision
self.is_upgrade = is_upgrade
if is_upgrade:
self.migration_fn = revision.module.upgrade
self.migration_fn = (
revision.module.upgrade # type:ignore[attr-defined]
)
else:
self.migration_fn = revision.module.downgrade
self.migration_fn = (
revision.module.downgrade # type:ignore[attr-defined]
)

def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
@@ -1074,7 +1070,7 @@ class RevisionStep(MigrationStep):
)

@property
def doc(self) -> Optional[str]:
def doc(self) -> str:
return self.revision.doc

@property
@@ -1172,18 +1168,7 @@ class RevisionStep(MigrationStep):
}
return tuple(set(self.to_revisions).difference(ancestors))
else:
# for each revision we plan to return, compute its ancestors
# (excluding self), and remove those from the final output since
# they are already accounted for.
ancestors = {
r.revision
for to_revision in self.to_revisions
for r in self.revision_map._get_ancestor_nodes(
self.revision_map.get_revisions(to_revision), check=False
)
if r.revision != to_revision
}
return tuple(set(self.to_revisions).difference(ancestors))
return self.to_revisions

def unmerge_branch_idents(
self, heads: Set[str]
@@ -1298,7 +1283,7 @@ class StampStep(MigrationStep):
def __eq__(self, other):
return (
isinstance(other, StampStep)
and other.from_revisions == self.from_revisions
and other.from_revisions == self.revisions
and other.to_revisions == self.to_revisions
and other.branch_move == self.branch_move
and self.is_upgrade == other.is_upgrade

@@ -3,7 +3,6 @@ from __future__ import annotations
from contextlib import contextmanager
import datetime
import os
from pathlib import Path
import re
import shutil
import sys
@@ -12,6 +11,7 @@ from typing import Any
from typing import cast
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
@@ -23,9 +23,7 @@ from . import revision
from . import write_hooks
from .. import util
from ..runtime import migration
from ..util import compat
from ..util import not_none
from ..util.pyfiles import _preserving_path_as_str

if TYPE_CHECKING:
from .revision import _GetRevArg
@@ -33,28 +31,26 @@ if TYPE_CHECKING:
from .revision import Revision
from ..config import Config
from ..config import MessagingOptions
from ..config import PostWriteHookConfig
from ..runtime.migration import RevisionStep
from ..runtime.migration import StampStep

try:
if compat.py39:
from zoneinfo import ZoneInfo
from zoneinfo import ZoneInfoNotFoundError
else:
from backports.zoneinfo import ZoneInfo # type: ignore[import-not-found,no-redef] # noqa: E501
from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[no-redef] # noqa: E501
from dateutil import tz
except ImportError:
ZoneInfo = None # type: ignore[assignment, misc]
tz = None # type: ignore[assignment]

_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
_legacy_rev = re.compile(r"([a-f0-9]+)\.py$")
_slug_re = re.compile(r"\w+")
_default_file_template = "%(rev)s_%(slug)s"
_split_on_space_comma = re.compile(r", *|(?: +)")

_split_on_space_comma_colon = re.compile(r", *|(?: +)|\:")


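
A quick illustration of the separator patterns defined above:

    import re

    _split_on_space_comma = re.compile(r", *|(?: +)")
    _split_on_space_comma_colon = re.compile(r", *|(?: +)|\:")

    assert _split_on_space_comma.split("a, b c") == ["a", "b", "c"]
    assert _split_on_space_comma_colon.split("lib:src") == ["lib", "src"]
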
class ScriptDirectory:

"""Provides operations upon an Alembic script directory.

This object is useful to get information as to current revisions,
@@ -76,55 +72,40 @@ class ScriptDirectory:

def __init__(
self,
dir: Union[str, os.PathLike[str]],  # noqa: A002
dir: str,  # noqa
file_template: str = _default_file_template,
truncate_slug_length: Optional[int] = 40,
version_locations: Optional[
Sequence[Union[str, os.PathLike[str]]]
] = None,
version_locations: Optional[List[str]] = None,
sourceless: bool = False,
output_encoding: str = "utf-8",
timezone: Optional[str] = None,
hooks: list[PostWriteHookConfig] = [],
hook_config: Optional[Mapping[str, str]] = None,
recursive_version_locations: bool = False,
messaging_opts: MessagingOptions = cast(
"MessagingOptions", util.EMPTY_DICT
),
) -> None:
self.dir = _preserving_path_as_str(dir)
self.version_locations = [
_preserving_path_as_str(p) for p in version_locations or ()
]
self.dir = dir
self.file_template = file_template
self.version_locations = version_locations
self.truncate_slug_length = truncate_slug_length or 40
self.sourceless = sourceless
self.output_encoding = output_encoding
self.revision_map = revision.RevisionMap(self._load_revisions)
self.timezone = timezone
self.hooks = hooks
self.hook_config = hook_config
self.recursive_version_locations = recursive_version_locations
self.messaging_opts = messaging_opts

if not os.access(dir, os.F_OK):
raise util.CommandError(
f"Path doesn't exist: {dir}. Please use "
"Path doesn't exist: %r. Please use "
"the 'init' command to create a new "
"scripts folder."
"scripts folder." % os.path.abspath(dir)
)

@property
def versions(self) -> str:
"""return a single version location based on the sole path passed
within version_locations.

If multiple version locations are configured, an error is raised.


"""
return str(self._singular_version_location)

@util.memoized_property
def _singular_version_location(self) -> Path:
loc = self._version_locations
if len(loc) > 1:
raise util.CommandError("Multiple version_locations present")
@@ -132,31 +113,40 @@ class ScriptDirectory:
return loc[0]

@util.memoized_property
def _version_locations(self) -> Sequence[Path]:
def _version_locations(self):
if self.version_locations:
return [
util.coerce_resource_to_filename(location).absolute()
os.path.abspath(util.coerce_resource_to_filename(location))
for location in self.version_locations
]
else:
return [Path(self.dir, "versions").absolute()]
return (os.path.abspath(os.path.join(self.dir, "versions")),)

def _load_revisions(self) -> Iterator[Script]:
paths = [vers for vers in self._version_locations if vers.exists()]
if self.version_locations:
paths = [
vers
for vers in self._version_locations
if os.path.exists(vers)
]
else:
paths = [self.versions]

dupes = set()
for vers in paths:
for file_path in Script._list_py_dir(self, vers):
real_path = file_path.resolve()
real_path = os.path.realpath(file_path)
if real_path in dupes:
util.warn(
f"File {real_path} loaded twice! ignoring. "
"Please ensure version_locations is unique."
"File %s loaded twice! ignoring. Please ensure "
"version_locations is unique." % real_path
)
continue
dupes.add(real_path)

script = Script._from_path(self, real_path)
filename = os.path.basename(real_path)
dir_name = os.path.dirname(real_path)
script = Script._from_filename(self, dir_name, filename)
if script is None:
continue
yield script
@@ -170,36 +160,74 @@ class ScriptDirectory:
present.

"""
script_location = config.get_alembic_option("script_location")
script_location = config.get_main_option("script_location")
if script_location is None:
raise util.CommandError(
"No 'script_location' key found in configuration."
"No 'script_location' key " "found in configuration."
)
truncate_slug_length: Optional[int]
tsl = config.get_alembic_option("truncate_slug_length")
tsl = config.get_main_option("truncate_slug_length")
if tsl is not None:
truncate_slug_length = int(tsl)
else:
truncate_slug_length = None

prepend_sys_path = config.get_prepend_sys_paths_list()
if prepend_sys_path:
sys.path[:0] = prepend_sys_path
version_locations_str = config.get_main_option("version_locations")
version_locations: Optional[List[str]]
if version_locations_str:
version_path_separator = config.get_main_option(
"version_path_separator"
)

rvl = config.get_alembic_boolean_option("recursive_version_locations")
split_on_path = {
None: None,
"space": " ",
"os": os.pathsep,
":": ":",
";": ";",
}

try:
split_char: Optional[str] = split_on_path[
version_path_separator
]
except KeyError as ke:
raise ValueError(
"'%s' is not a valid value for "
"version_path_separator; "
"expected 'space', 'os', ':', ';'" % version_path_separator
) from ke
else:
if split_char is None:
# legacy behaviour for backwards compatibility
version_locations = _split_on_space_comma.split(
version_locations_str
)
else:
version_locations = [
x for x in version_locations_str.split(split_char) if x
]
else:
version_locations = None

prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
sys.path[:0] = list(
_split_on_space_comma_colon.split(prepend_sys_path)
)

rvl = config.get_main_option("recursive_version_locations") == "true"
return ScriptDirectory(
util.coerce_resource_to_filename(script_location),
file_template=config.get_alembic_option(
file_template=config.get_main_option(
"file_template", _default_file_template
),
truncate_slug_length=truncate_slug_length,
sourceless=config.get_alembic_boolean_option("sourceless"),
output_encoding=config.get_alembic_option(
"output_encoding", "utf-8"
),
version_locations=config.get_version_locations_list(),
timezone=config.get_alembic_option("timezone"),
hooks=config.get_hooks_list(),
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
version_locations=version_locations,
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
recursive_version_locations=rvl,
messaging_opts=config.messaging_opts,
)
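
A sketch of the ``version_path_separator`` mapping above; with the ``os`` setting, an ``alembic.ini`` value splits on ``os.pathsep``:

    import os

    split_on_path = {
        None: None, "space": " ", "os": os.pathsep, ":": ":", ";": ";"
    }
    split_char = split_on_path["os"]
    locations = [
        x
        for x in "migrations/versions:extra/versions".split(split_char)
        if x
    ]
    assert locations == ["migrations/versions", "extra/versions"]  # POSIX
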
@@ -269,22 +297,24 @@ class ScriptDirectory:
):
yield cast(Script, rev)

def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]:
def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.

"""
with self._catch_revision_errors():
return cast(
Tuple[Script, ...],
Tuple[Optional[Script], ...],
self.revision_map.get_revisions(id_),
)

def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]:
def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
with self._catch_revision_errors():
return cast(Set[Script], self.revision_map._get_all_current(id_))
return cast(
Set[Optional[Script]], self.revision_map._get_all_current(id_)
)

def get_revision(self, id_: str) -> Script:
def get_revision(self, id_: str) -> Optional[Script]:
"""Return the :class:`.Script` instance with the given rev id.

.. seealso::
@@ -294,7 +324,7 @@ class ScriptDirectory:
"""

with self._catch_revision_errors():
return cast(Script, self.revision_map.get_revision(id_))
return cast(Optional[Script], self.revision_map.get_revision(id_))

def as_revision_number(
self, id_: Optional[str]
@@ -549,37 +579,24 @@ class ScriptDirectory:
util.load_python_file(self.dir, "env.py")

@property
def env_py_location(self) -> str:
return str(Path(self.dir, "env.py"))
def env_py_location(self):
return os.path.abspath(os.path.join(self.dir, "env.py"))

def _append_template(self, src: Path, dest: Path, **kw: Any) -> None:
def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
with util.status(
f"Appending to existing {dest.absolute()}",
**self.messaging_opts,
):
util.template_to_file(
src,
dest,
self.output_encoding,
append_with_newlines=True,
**kw,
)

def _generate_template(self, src: Path, dest: Path, **kw: Any) -> None:
with util.status(
f"Generating {dest.absolute()}", **self.messaging_opts
f"Generating {os.path.abspath(dest)}", **self.messaging_opts
):
util.template_to_file(src, dest, self.output_encoding, **kw)

def _copy_file(self, src: Path, dest: Path) -> None:
def _copy_file(self, src: str, dest: str) -> None:
with util.status(
f"Generating {dest.absolute()}", **self.messaging_opts
f"Generating {os.path.abspath(dest)}", **self.messaging_opts
):
shutil.copy(src, dest)

def _ensure_directory(self, path: Path) -> None:
path = path.absolute()
if not path.exists():
def _ensure_directory(self, path: str) -> None:
path = os.path.abspath(path)
if not os.path.exists(path):
with util.status(
f"Creating directory {path}", **self.messaging_opts
):
@@ -587,27 +604,25 @@ class ScriptDirectory:

def _generate_create_date(self) -> datetime.datetime:
if self.timezone is not None:
if ZoneInfo is None:
if tz is None:
raise util.CommandError(
"Python >= 3.9 is required for timezone support or "
"the 'backports.zoneinfo' package must be installed."
"The library 'python-dateutil' is required "
"for timezone support"
)
# First, assume correct capitalization
try:
tzinfo = ZoneInfo(self.timezone)
except ZoneInfoNotFoundError:
tzinfo = None
tzinfo = tz.gettz(self.timezone)
if tzinfo is None:
try:
tzinfo = ZoneInfo(self.timezone.upper())
except ZoneInfoNotFoundError:
raise util.CommandError(
"Can't locate timezone: %s" % self.timezone
) from None

create_date = datetime.datetime.now(
tz=datetime.timezone.utc
).astimezone(tzinfo)
# Fall back to uppercase
tzinfo = tz.gettz(self.timezone.upper())
if tzinfo is None:
raise util.CommandError(
"Can't locate timezone: %s" % self.timezone
)
create_date = (
datetime.datetime.utcnow()
.replace(tzinfo=tz.tzutc())
.astimezone(tzinfo)
)
else:
create_date = datetime.datetime.now()
return create_date
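
The zoneinfo-based branch above, reduced to its essentials (standard library on Python 3.9+):

    import datetime
    from zoneinfo import ZoneInfo

    create_date = datetime.datetime.now(
        tz=datetime.timezone.utc
    ).astimezone(ZoneInfo("Europe/Vienna"))
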
@@ -619,8 +634,7 @@ class ScriptDirectory:
head: Optional[_RevIdType] = None,
splice: Optional[bool] = False,
branch_labels: Optional[_RevIdType] = None,
version_path: Union[str, os.PathLike[str], None] = None,
file_template: Optional[str] = None,
version_path: Optional[str] = None,
depends_on: Optional[_RevIdType] = None,
**kw: Any,
) -> Optional[Script]:
@@ -661,7 +675,7 @@ class ScriptDirectory:
self.revision_map.get_revisions(head),
)
for h in heads:
assert h != "base"  # type: ignore[comparison-overlap]
assert h != "base"

if len(set(heads)) != len(heads):
raise util.CommandError("Duplicate head revisions specified")
@@ -673,7 +687,7 @@ class ScriptDirectory:
for head_ in heads:
if head_ is not None:
assert isinstance(head_, Script)
version_path = head_._script_path.parent
version_path = os.path.dirname(head_.path)
break
else:
raise util.CommandError(
@@ -681,19 +695,16 @@ class ScriptDirectory:
"please specify --version-path"
)
else:
version_path = self._singular_version_location
else:
version_path = Path(version_path)
version_path = self.versions

assert isinstance(version_path, Path)
norm_path = version_path.absolute()
norm_path = os.path.normpath(os.path.abspath(version_path))
for vers_path in self._version_locations:
if vers_path.absolute() == norm_path:
if os.path.normpath(vers_path) == norm_path:
break
else:
raise util.CommandError(
f"Path {version_path} is not represented in current "
"version locations"
"Path %s is not represented in current "
"version locations" % version_path
)

if self.version_locations:
@@ -714,11 +725,9 @@ class ScriptDirectory:
if depends_on:
with self._catch_revision_errors():
resolved_depends_on = [
(
dep
if dep in rev.branch_labels  # maintain branch labels
else rev.revision
)  # resolve partial revision identifiers
dep
if dep in rev.branch_labels  # maintain branch labels
else rev.revision  # resolve partial revision identifiers
for rev, dep in [
(not_none(self.revision_map.get_revision(dep)), dep)
for dep in util.to_list(depends_on)
@@ -728,7 +737,7 @@ class ScriptDirectory:
resolved_depends_on = None

self._generate_template(
Path(self.dir, "script.py.mako"),
os.path.join(self.dir, "script.py.mako"),
path,
up_revision=str(revid),
down_revision=revision.tuple_rev_as_scalar(
@@ -742,7 +751,7 @@ class ScriptDirectory:
**kw,
)

post_write_hooks = self.hooks
post_write_hooks = self.hook_config
if post_write_hooks:
write_hooks._run_hooks(path, post_write_hooks)

@@ -765,11 +774,11 @@ class ScriptDirectory:

def _rev_path(
self,
path: Union[str, os.PathLike[str]],
path: str,
rev_id: str,
message: Optional[str],
create_date: datetime.datetime,
) -> Path:
) -> str:
epoch = int(create_date.timestamp())
slug = "_".join(_slug_re.findall(message or "")).lower()
if len(slug) > self.truncate_slug_length:
@@ -788,10 +797,11 @@ class ScriptDirectory:
"second": create_date.second,
}
)
return Path(path) / filename
return os.path.join(path, filename)


class Script(revision.Revision):

"""Represent a single revision file in a ``versions/`` directory.

The :class:`.Script` instance is returned by methods
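
A sketch of how ``_rev_path`` renders filenames with the default template (the surrounding machinery is elided):

    _default_file_template = "%(rev)s_%(slug)s"
    filename = "%s.py" % (
        _default_file_template
        % {"rev": "ae1027a6acf", "slug": "add_account_table"}
    )
    assert filename == "ae1027a6acf_add_account_table.py"
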
@@ -799,17 +809,12 @@ class Script(revision.Revision):

"""

def __init__(
self,
module: ModuleType,
rev_id: str,
path: Union[str, os.PathLike[str]],
):
def __init__(self, module: ModuleType, rev_id: str, path: str):
self.module = module
self.path = _preserving_path_as_str(path)
self.path = path
super().__init__(
rev_id,
module.down_revision,
module.down_revision, # type: ignore[attr-defined]
branch_labels=util.to_tuple(
getattr(module, "branch_labels", None), default=()
),
@@ -824,10 +829,6 @@ class Script(revision.Revision):
path: str
"""Filesystem path of the script."""

@property
def _script_path(self) -> Path:
return Path(self.path)

_db_current_indicator: Optional[bool] = None
"""Utility variable which when set will cause string output to indicate
this is a "current" version in some database"""
@@ -846,9 +847,9 @@ class Script(revision.Revision):
if doc:
if hasattr(self.module, "_alembic_source_encoding"):
doc = doc.decode( # type: ignore[attr-defined]
self.module._alembic_source_encoding
self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa
)
return doc.strip()
return doc.strip() # type: ignore[union-attr]
else:
return ""

@@ -888,7 +889,7 @@ class Script(revision.Revision):
)
return entry

def __str__(self) -> str:
def __str__(self):
return "%s -> %s%s%s%s, %s" % (
self._format_down_revision(),
self.revision,
@@ -922,11 +923,9 @@ class Script(revision.Revision):
if head_indicators or tree_indicators:
text += "%s%s%s" % (
" (head)" if self._is_real_head else "",
(
" (effective head)"
if self.is_head and not self._is_real_head
else ""
),
" (effective head)"
if self.is_head and not self._is_real_head
else "",
" (current)" if self._db_current_indicator else "",
)
if tree_indicators:
@@ -960,33 +959,36 @@ class Script(revision.Revision):
return util.format_as_comma(self._versioned_down_revisions)

@classmethod
def _list_py_dir(
cls, scriptdir: ScriptDirectory, path: Path
) -> List[Path]:
def _from_path(
cls, scriptdir: ScriptDirectory, path: str
) -> Optional[Script]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)

@classmethod
def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
paths = []
for root, dirs, files in compat.path_walk(path, top_down=True):
if root.name.endswith("__pycache__"):
for root, dirs, files in os.walk(path, topdown=True):
if root.endswith("__pycache__"):
# a special case - we may include these files
# if a `sourceless` option is specified
continue

for filename in sorted(files):
paths.append(root / filename)
paths.append(os.path.join(root, filename))

if scriptdir.sourceless:
# look for __pycache__
py_cache_path = root / "__pycache__"
if py_cache_path.exists():
py_cache_path = os.path.join(root, "__pycache__")
if os.path.exists(py_cache_path):
# add all files from __pycache__ whose filename is not
# already in the names we got from the version directory.
# add as relative paths including __pycache__ token
names = {
Path(filename).name.split(".")[0] for filename in files
}
names = {filename.split(".")[0] for filename in files}
paths.extend(
py_cache_path / pyc
for pyc in py_cache_path.iterdir()
if pyc.name.split(".")[0] not in names
os.path.join(py_cache_path, pyc)
for pyc in os.listdir(py_cache_path)
if pyc.split(".")[0] not in names
)

if not scriptdir.recursive_version_locations:
@@ -1001,13 +1003,9 @@ class Script(revision.Revision):
return paths

@classmethod
def _from_path(
cls, scriptdir: ScriptDirectory, path: Union[str, os.PathLike[str]]
def _from_filename(
cls, scriptdir: ScriptDirectory, dir_: str, filename: str
) -> Optional[Script]:

path = Path(path)
dir_, filename = path.parent, path.name

if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
@@ -1025,8 +1023,8 @@ class Script(revision.Revision):
is_c = is_o = False

if is_o or is_c:
py_exists = (dir_ / py_filename).exists()
pyc_exists = (dir_ / (py_filename + "c")).exists()
py_exists = os.path.exists(os.path.join(dir_, py_filename))
pyc_exists = os.path.exists(os.path.join(dir_, py_filename + "c"))

# prefer .py over .pyc because we'd like to get the
# source encoding; prefer .pyc over .pyo because we'd like to
@@ -1042,14 +1040,14 @@ class Script(revision.Revision):
m = _legacy_rev.match(filename)
if not m:
raise util.CommandError(
"Could not determine revision id from "
f"filename {filename}. "
"Could not determine revision id from filename %s. "
"Be sure the 'revision' variable is "
"declared inside the script (please see 'Upgrading "
"from Alembic 0.1 to 0.2' in the documentation)."
% filename
)
else:
revision = m.group(1)
else:
revision = module.revision
return Script(module, revision, dir_ / filename)
return Script(module, revision, os.path.join(dir_, filename))

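
How the module-level patterns in script/base.py classify revision files, for reference:

    import re

    _sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
    _only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")

    assert _only_source_rev_file.match("ae1027a6acf_add_account.py")
    assert not _only_source_rev_file.match("ae1027a6acf_add_account.pyc")
    assert _sourceless_rev_file.match("ae1027a6acf_add_account.pyc")
    assert not _sourceless_rev_file.match("__init__.py")
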
@@ -14,7 +14,6 @@ from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Protocol
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -48,17 +47,6 @@ _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]


class _CollectRevisionsProtocol(Protocol):
def __call__(
self,
upper: _RevisionIdentifierType,
lower: _RevisionIdentifierType,
inclusive: bool,
implicit_base: bool,
assert_relative_length: bool,
) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: ...


class RevisionError(Exception):
pass

@@ -408,7 +396,7 @@ class RevisionMap:
for rev in self._get_ancestor_nodes(
[revision],
include_dependencies=False,
map_=map_,
map_=cast(_RevisionMapType, map_),
):
if rev is revision:
continue
@@ -719,11 +707,9 @@ class RevisionMap:
resolved_target = target

resolved_test_against_revs = [
(
self._revision_for_ident(test_against_rev)
if not isinstance(test_against_rev, Revision)
else test_against_rev
)
self._revision_for_ident(test_against_rev)
if not isinstance(test_against_rev, Revision)
else test_against_rev
for test_against_rev in util.to_tuple(
test_against_revs, default=()
)
@@ -805,7 +791,7 @@ class RevisionMap:
The iterator yields :class:`.Revision` objects.

"""
fn: _CollectRevisionsProtocol
fn: Callable
if select_for_downgrade:
fn = self._collect_downgrade_revisions
else:
@@ -832,7 +818,7 @@ class RevisionMap:
) -> Iterator[Any]:
if omit_immediate_dependencies:

def fn(rev: Revision) -> Iterable[str]:
def fn(rev):
if rev not in targets:
return rev._all_nextrev
else:
@@ -840,12 +826,12 @@ class RevisionMap:

elif include_dependencies:

def fn(rev: Revision) -> Iterable[str]:
def fn(rev):
return rev._all_nextrev

else:

def fn(rev: Revision) -> Iterable[str]:
def fn(rev):
return rev.nextrev

return self._iterate_related_revisions(
@@ -861,12 +847,12 @@ class RevisionMap:
) -> Iterator[Revision]:
if include_dependencies:

def fn(rev: Revision) -> Iterable[str]:
def fn(rev):
return rev._normalized_down_revisions

else:

def fn(rev: Revision) -> Iterable[str]:
def fn(rev):
return rev._versioned_down_revisions

return self._iterate_related_revisions(
@@ -875,7 +861,7 @@ class RevisionMap:

def _iterate_related_revisions(
self,
fn: Callable[[Revision], Iterable[str]],
fn: Callable,
targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType],
check: bool = False,
@@ -937,7 +923,7 @@ class RevisionMap:

id_to_rev = self._revision_map

def get_ancestors(rev_id: str) -> Set[str]:
def get_ancestors(rev_id):
return {
r.revision
for r in self._get_ancestor_nodes([id_to_rev[rev_id]])
@@ -1017,9 +1003,9 @@ class RevisionMap:
# each time but it was getting complicated
current_heads[current_candidate_idx] = heads_to_add[0]
current_heads.extend(heads_to_add[1:])
ancestors_by_idx[current_candidate_idx] = (
get_ancestors(heads_to_add[0])
)
ancestors_by_idx[
current_candidate_idx
] = get_ancestors(heads_to_add[0])
ancestors_by_idx.extend(
get_ancestors(head) for head in heads_to_add[1:]
)
@@ -1055,7 +1041,7 @@ class RevisionMap:
children: Sequence[Optional[_RevisionOrBase]]
for _ in range(abs(steps)):
if steps > 0:
assert initial != "base"  # type: ignore[comparison-overlap]
assert initial != "base"
# Walk up
walk_up = [
is_revision(rev)
@@ -1069,7 +1055,7 @@ class RevisionMap:
children = walk_up
else:
# Walk down
if initial == "base":  # type: ignore[comparison-overlap]
if initial == "base":
children = ()
else:
children = self.get_revisions(
@@ -1184,13 +1170,9 @@ class RevisionMap:
branch_label = symbol
# Walk down the tree to find downgrade target.
rev = self._walk(
start=(
self.get_revision(symbol)
if branch_label is None
else self.get_revision(
"%s@%s" % (branch_label, symbol)
)
),
start=self.get_revision(symbol)
if branch_label is None
else self.get_revision("%s@%s" % (branch_label, symbol)),
steps=rel_int,
no_overwalk=assert_relative_length,
)
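
The ``_walk`` calls above implement relative destinations like ``branch@head-2``; the grammar is the module-level pattern shown earlier:

    import re

    _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")

    m = _relative_destination.match("mybranch@head-2")
    assert m is not None
    assert m.groups() == ("mybranch", "head", "-2")
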
@@ -1207,7 +1189,7 @@ class RevisionMap:
|
||||
# No relative destination given, revision specified is absolute.
|
||||
branch_label, _, symbol = target.rpartition("@")
|
||||
if not branch_label:
|
||||
branch_label = None
|
||||
branch_label = None # type:ignore[assignment]
|
||||
return branch_label, self.get_revision(symbol)
|
||||
|
||||
def _parse_upgrade_target(
|
||||
@@ -1308,13 +1290,9 @@ class RevisionMap:
|
||||
)
|
||||
return (
|
||||
self._walk(
|
||||
start=(
|
||||
self.get_revision(symbol)
|
||||
if branch_label is None
|
||||
else self.get_revision(
|
||||
"%s@%s" % (branch_label, symbol)
|
||||
)
|
||||
),
|
||||
start=self.get_revision(symbol)
|
||||
if branch_label is None
|
||||
else self.get_revision("%s@%s" % (branch_label, symbol)),
|
||||
steps=relative,
|
||||
no_overwalk=assert_relative_length,
|
||||
),
|
||||
@@ -1323,11 +1301,11 @@ class RevisionMap:
|
||||
def _collect_downgrade_revisions(
|
||||
self,
|
||||
upper: _RevisionIdentifierType,
|
||||
lower: _RevisionIdentifierType,
|
||||
target: _RevisionIdentifierType,
|
||||
inclusive: bool,
|
||||
implicit_base: bool,
|
||||
assert_relative_length: bool,
|
||||
) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
|
||||
) -> Any:
|
||||
"""
|
||||
Compute the set of current revisions specified by :upper, and the
|
||||
downgrade target specified by :target. Return all dependents of target
|
||||
@@ -1338,7 +1316,7 @@ class RevisionMap:
|
||||
|
||||
branch_label, target_revision = self._parse_downgrade_target(
|
||||
current_revisions=upper,
|
||||
target=lower,
|
||||
target=target,
|
||||
assert_relative_length=assert_relative_length,
|
||||
)
|
||||
if target_revision == "base":
|
||||
@@ -1430,7 +1408,7 @@ class RevisionMap:
|
||||
inclusive: bool,
|
||||
implicit_base: bool,
|
||||
assert_relative_length: bool,
|
||||
) -> Tuple[Set[Revision], Tuple[Revision, ...]]:
|
||||
) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
|
||||
"""
|
||||
Compute the set of required revisions specified by :upper, and the
|
||||
current set of active revisions specified by :lower. Find the
|
||||
@@ -1522,7 +1500,7 @@ class RevisionMap:
|
||||
)
|
||||
needs.intersection_update(lower_descendents)
|
||||
|
||||
return needs, tuple(targets)
|
||||
return needs, tuple(targets) # type:ignore[return-value]
|
||||
|
||||
def _get_all_current(
|
||||
self, id_: Tuple[str, ...]
|
||||
@@ -1703,13 +1681,15 @@ class Revision:
|
||||
|
||||
|
||||
@overload
|
||||
def tuple_rev_as_scalar(rev: None) -> None: ...
|
||||
def tuple_rev_as_scalar(rev: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tuple_rev_as_scalar(
|
||||
rev: Union[Tuple[_T, ...], List[_T]],
|
||||
) -> Union[_T, Tuple[_T, ...], List[_T]]: ...
|
||||
rev: Union[Tuple[_T, ...], List[_T]]
|
||||
) -> Union[_T, Tuple[_T, ...], List[_T]]:
|
||||
...
|
||||
|
||||
|
||||
def tuple_rev_as_scalar(
|
||||
|
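The two sides of the `tuple_rev_as_scalar` hunk differ only in where the `...` overload bodies sit; the runtime behavior lives in the final, non-overloaded definition, which the hunk truncates. As a hedged sketch, a helper with this name and these overloads plausibly collapses a one-element collection to its sole item (an assumption, since the body is not shown here):

# Assumed behavior, inferred from the overload signatures only.
def tuple_rev_as_scalar(rev):
    if rev and len(rev) == 1:
        return rev[0]
    return rev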
||||
@@ -1,10 +1,5 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -12,16 +7,13 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .. import util
|
||||
from ..util import compat
|
||||
from ..util.pyfiles import _preserving_path_as_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import PostWriteHookConfig
|
||||
|
||||
REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
|
||||
|
||||
@@ -48,19 +40,16 @@ def register(name: str) -> Callable:
|
||||
|
||||
|
||||
def _invoke(
|
||||
name: str,
|
||||
revision_path: Union[str, os.PathLike[str]],
|
||||
options: PostWriteHookConfig,
|
||||
name: str, revision: str, options: Mapping[str, Union[str, int]]
|
||||
) -> Any:
|
||||
"""Invokes the formatter registered for the given name.
|
||||
|
||||
:param name: The name of a formatter in the registry
|
||||
:param revision: string path to the revision file
|
||||
:param revision: A :class:`.MigrationRevision` instance
|
||||
:param options: A dict containing kwargs passed to the
|
||||
specified formatter.
|
||||
:raises: :class:`alembic.util.CommandError`
|
||||
"""
|
||||
revision_path = _preserving_path_as_str(revision_path)
|
||||
try:
|
||||
hook = _registry[name]
|
||||
except KeyError as ke:
|
||||
@@ -68,28 +57,36 @@ def _invoke(
|
||||
f"No formatter with name '{name}' registered"
|
||||
) from ke
|
||||
else:
|
||||
return hook(revision_path, options)
|
||||
return hook(revision, options)
|
||||
|
||||
|
||||
def _run_hooks(
|
||||
path: Union[str, os.PathLike[str]], hooks: list[PostWriteHookConfig]
|
||||
) -> None:
|
||||
def _run_hooks(path: str, hook_config: Mapping[str, str]) -> None:
|
||||
"""Invoke hooks for a generated revision."""
|
||||
|
||||
for hook in hooks:
|
||||
name = hook["_hook_name"]
|
||||
from .base import _split_on_space_comma
|
||||
|
||||
names = _split_on_space_comma.split(hook_config.get("hooks", ""))
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
continue
|
||||
opts = {
|
||||
key[len(name) + 1 :]: hook_config[key]
|
||||
for key in hook_config
|
||||
if key.startswith(name + ".")
|
||||
}
|
||||
opts["_hook_name"] = name
|
||||
try:
|
||||
type_ = hook["type"]
|
||||
type_ = opts["type"]
|
||||
except KeyError as ke:
|
||||
raise util.CommandError(
|
||||
f"Key '{name}.type' (or 'type' in toml) is required "
|
||||
f"for post write hook {name!r}"
|
||||
f"Key {name}.type is required for post write hook {name!r}"
|
||||
) from ke
|
||||
else:
|
||||
with util.status(
|
||||
f"Running post write hook {name!r}", newline=True
|
||||
):
|
||||
_invoke(type_, path, hook)
|
||||
_invoke(type_, path, opts)
|
||||
|
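The flat-mapping variant of `_run_hooks` above reads a single `[post_write_hooks]` section in which `hooks` names the hooks to run and dotted keys carry per-hook settings. A worked example of the dict comprehension that strips the `name.` prefix (the config values are illustrative):

# Hypothetical flat mapping, as configparser would hand it over.
hook_config = {
    "hooks": "black",
    "black.type": "console_scripts",
    "black.entrypoint": "black",
    "black.options": "-l 79 REVISION_SCRIPT_FILENAME",
}
name = "black"
opts = {
    key[len(name) + 1:]: hook_config[key]
    for key in hook_config
    if key.startswith(name + ".")
}
# opts == {"type": "console_scripts", "entrypoint": "black",
#          "options": "-l 79 REVISION_SCRIPT_FILENAME"}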
||||
|
||||
def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
|
||||
@@ -113,35 +110,17 @@ def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
|
||||
return cmdline_options_list
|
||||
|
||||
|
||||
def _get_required_option(options: dict, name: str) -> str:
|
||||
try:
|
||||
return options[name]
|
||||
except KeyError as ke:
|
||||
raise util.CommandError(
|
||||
f"Key {options['_hook_name']}.{name} is required for post "
|
||||
f"write hook {options['_hook_name']!r}"
|
||||
) from ke
|
||||
|
||||
|
||||
def _run_hook(
|
||||
path: str, options: dict, ignore_output: bool, command: List[str]
|
||||
) -> None:
|
||||
cwd: Optional[str] = options.get("cwd", None)
|
||||
cmdline_options_str = options.get("options", "")
|
||||
cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path)
|
||||
|
||||
kw: Dict[str, Any] = {}
|
||||
if ignore_output:
|
||||
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
|
||||
|
||||
subprocess.run([*command, *cmdline_options_list], cwd=cwd, **kw)
|
||||
|
||||
|
||||
@register("console_scripts")
|
||||
def console_scripts(
|
||||
path: str, options: dict, ignore_output: bool = False
|
||||
) -> None:
|
||||
entrypoint_name = _get_required_option(options, "entrypoint")
|
||||
try:
|
||||
entrypoint_name = options["entrypoint"]
|
||||
except KeyError as ke:
|
||||
raise util.CommandError(
|
||||
f"Key {options['_hook_name']}.entrypoint is required for post "
|
||||
f"write hook {options['_hook_name']!r}"
|
||||
) from ke
|
||||
for entry in compat.importlib_metadata_get("console_scripts"):
|
||||
if entry.name == entrypoint_name:
|
||||
impl: Any = entry
|
||||
@@ -150,27 +129,48 @@ def console_scripts(
|
||||
raise util.CommandError(
|
||||
f"Could not find entrypoint console_scripts.{entrypoint_name}"
|
||||
)
|
||||
cwd: Optional[str] = options.get("cwd", None)
|
||||
cmdline_options_str = options.get("options", "")
|
||||
cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path)
|
||||
|
||||
command = [
|
||||
sys.executable,
|
||||
"-c",
|
||||
f"import {impl.module}; {impl.module}.{impl.attr}()",
|
||||
]
|
||||
_run_hook(path, options, ignore_output, command)
|
||||
kw: Dict[str, Any] = {}
|
||||
if ignore_output:
|
||||
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
f"import {impl.module}; {impl.module}.{impl.attr}()",
|
||||
]
|
||||
+ cmdline_options_list,
|
||||
cwd=cwd,
|
||||
**kw,
|
||||
)
|
||||
|
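On both sides of this file, `@register(name)` is the extension point that `_invoke` resolves through `_registry`, so a user-defined runner can be registered the same way as the built-in `console_scripts` and `exec` runners. A minimal sketch (hook name and body invented for illustration):

from alembic.script import write_hooks

@write_hooks.register("append_marker")
def append_marker(path, options):
    # Append a marker comment to the freshly generated revision file.
    with open(path, "a") as f:
        f.write("\n# post-processed\n")

In the config file the hook is then selected by its registered name, e.g. `hooks = mymarker` with `mymarker.type = append_marker`, matching how `_run_hooks` looks up `type` above.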
||||
|
||||
@register("exec")
|
||||
def exec_(path: str, options: dict, ignore_output: bool = False) -> None:
|
||||
executable = _get_required_option(options, "executable")
|
||||
_run_hook(path, options, ignore_output, command=[executable])
|
||||
try:
|
||||
executable = options["executable"]
|
||||
except KeyError as ke:
|
||||
raise util.CommandError(
|
||||
f"Key {options['_hook_name']}.executable is required for post "
|
||||
f"write hook {options['_hook_name']!r}"
|
||||
) from ke
|
||||
cwd: Optional[str] = options.get("cwd", None)
|
||||
cmdline_options_str = options.get("options", "")
|
||||
cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path)
|
||||
|
||||
kw: Dict[str, Any] = {}
|
||||
if ignore_output:
|
||||
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
|
||||
|
||||
@register("module")
|
||||
def module(path: str, options: dict, ignore_output: bool = False) -> None:
|
||||
module_name = _get_required_option(options, "module")
|
||||
|
||||
if importlib.util.find_spec(module_name) is None:
|
||||
raise util.CommandError(f"Could not find module {module_name}")
|
||||
|
||||
command = [sys.executable, "-m", module_name]
|
||||
_run_hook(path, options, ignore_output, command)
|
||||
subprocess.run(
|
||||
[
|
||||
executable,
|
||||
*cmdline_options_list,
|
||||
],
|
||||
cwd=cwd,
|
||||
**kw,
|
||||
)
|
||||
|
||||
@@ -1,32 +1,27 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
# path to migration scripts
|
||||
script_location = ${script_location}
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
@@ -39,38 +34,20 @@ prepend_sys_path = .
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# to ${script_location}/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
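Whichever spelling of the key is in play, the mechanics are the same: the configured value selects a separator used to split one ini string into a list of paths, with `os` standing for `os.pathsep`. A rough sketch of that mapping (names are illustrative, not Alembic's implementation):

import os

# Illustrative separator table; "os" means os.pathsep
# (":" on POSIX, ";" on Windows).
separators = {"os": os.pathsep, ":": ":", ";": ";", "space": " ", "newline": "\n"}

sep = separators["os"]               # path_separator = os
paths = "model1:model2".split(sep)   # on POSIX -> ["model1", "model2"]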
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
@@ -81,9 +58,6 @@ path_separator = os
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
@@ -98,20 +72,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -122,12 +89,12 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
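The only change in the logging sections is `WARNING` versus `WARN`; both name the same level, since the stdlib keeps `WARN` as a legacy alias. A quick check:

import logging

assert logging.getLevelName("WARN") == logging.WARNING == 30
assert logging.getLevelName("WARNING") == 30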
||||
|
||||
@@ -13,16 +13,14 @@ ${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
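The widening of `down_revision` to `Union[str, Sequence[str], None]` matters for merge revisions, where Alembic renders a tuple of parent ids rather than a single string. Roughly what such a rendered header looks like (ids invented):

from typing import Sequence, Union

revision: str = "3f2a9b1c7d4e"
down_revision: Union[str, Sequence[str], None] = ("1a2b3c4d5e6f", "6f5e4d3c2b1a")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None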
||||
@@ -1,10 +1,7 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
# path to migration scripts
|
||||
script_location = ${script_location}
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
@@ -14,20 +11,19 @@ script_location = ${script_location}
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
@@ -40,37 +36,20 @@ prepend_sys_path = .
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# to ${script_location}/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
@@ -81,9 +60,6 @@ path_separator = os
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
@@ -98,20 +74,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -122,12 +91,12 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
|
||||
@@ -13,16 +13,14 @@ ${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
# a multi-database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
# path to migration scripts
|
||||
script_location = ${script_location}
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
@@ -14,19 +11,19 @@ script_location = ${script_location}
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
@@ -39,37 +36,20 @@ prepend_sys_path = .
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# to ${script_location}/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
@@ -80,13 +60,6 @@ path_separator = os
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# for multiple database configuration, new named sections are added
|
||||
# which each include a distinct ``sqlalchemy.url`` entry. A custom value
|
||||
# ``databases`` is added which indicates a listing of the per-database sections.
|
||||
# The ``databases`` entry as well as the URLs present in the ``[engine1]``
|
||||
# and ``[engine2]`` sections continue to be consumed by the user-maintained env.py
|
||||
# script only.
|
||||
|
||||
databases = engine1, engine2
|
||||
|
||||
[engine1]
|
||||
@@ -106,20 +79,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname2
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -130,12 +96,12 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
|
||||
@@ -16,18 +16,16 @@ ${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade(engine_name: str) -> None:
|
||||
"""Upgrade schema."""
|
||||
globals()["upgrade_%s" % engine_name]()
|
||||
|
||||
|
||||
def downgrade(engine_name: str) -> None:
|
||||
"""Downgrade schema."""
|
||||
globals()["downgrade_%s" % engine_name]()
|
||||
|
||||
<%
|
||||
@@ -40,12 +38,10 @@ def downgrade(engine_name: str) -> None:
|
||||
% for db_name in re.split(r',\s*', db_names):
|
||||
|
||||
def upgrade_${db_name}() -> None:
|
||||
"""Upgrade ${db_name} schema."""
|
||||
${context.get("%s_upgrades" % db_name, "pass")}
|
||||
|
||||
|
||||
def downgrade_${db_name}() -> None:
|
||||
"""Downgrade ${db_name} schema."""
|
||||
${context.get("%s_downgrades" % db_name, "pass")}
|
||||
|
||||
% endfor
|
||||
|
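The multidb template dispatches through the module namespace: the mako loop emits one `upgrade_<name>`/`downgrade_<name>` pair per entry in `databases`, and the top-level functions pick them by name. Rendered for `databases = engine1, engine2`, the result behaves like this sketch (bodies are placeholders):

def upgrade_engine1() -> None:
    pass  # would contain engine1's operations

def upgrade_engine2() -> None:
    pass  # would contain engine2's operations

def upgrade(engine_name: str) -> None:
    # Same dispatch as the template above.
    globals()["upgrade_%s" % engine_name]()

upgrade("engine1")  # calls upgrade_engine1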
||||
@@ -1 +0,0 @@
|
||||
pyproject configuration, based on the generic configuration.
|
||||
@@ -1,44 +0,0 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -1,78 +0,0 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = None
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
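As its own comments note, this env.py ships with `target_metadata = None`, so autogenerate has nothing to compare against until it is pointed at the application's metadata. The usual one-line edit, with placeholder module names:

# "myapp" is a placeholder, as in the template's own comment.
from myapp import mymodel

target_metadata = mymodel.Base.metadata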
||||
@@ -1,82 +0,0 @@
|
||||
[tool.alembic]
|
||||
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = "${script_location}"
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s"
|
||||
|
||||
# additional paths to be prepended to sys.path. defaults to the current working directory.
|
||||
prepend_sys_path = [
|
||||
"."
|
||||
]
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# version_locations = [
|
||||
# "%(here)s/alembic/versions",
|
||||
# "%(here)s/foo/bar"
|
||||
# ]
|
||||
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = "utf-8"
|
||||
|
||||
# This section defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# format using "black" - use the console_scripts runner,
|
||||
# against the "black" entrypoint
|
||||
# name = "black"
|
||||
# type = "console_scripts"
|
||||
# entrypoint = "black"
|
||||
# options = "-l 79 REVISION_SCRIPT_FILENAME"
|
||||
#
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# name = "ruff"
|
||||
# type = "module"
|
||||
# module = "ruff"
|
||||
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
||||
#
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# name = "ruff"
|
||||
# type = "exec"
|
||||
# executable = "ruff"
|
||||
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
||||
|
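Since the `[tool.alembic]` table is plain TOML, it can be read back with the stdlib parser; each `[[tool.alembic.post_write_hooks]]` block becomes one dict in a list. A sketch (Python 3.11+ for `tomllib`):

import tomllib  # use the third-party "tomli" package on Python < 3.11

with open("pyproject.toml", "rb") as f:
    cfg = tomllib.load(f)

alembic_cfg = cfg["tool"]["alembic"]
print(alembic_cfg["script_location"])
for hook in alembic_cfg.get("post_write_hooks", []):
    print(hook["name"], hook["type"])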
||||
@@ -1,28 +0,0 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -1 +0,0 @@
|
||||
pyproject configuration, with an async dbapi.
|
||||
@@ -1,44 +0,0 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -1,89 +0,0 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = None
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
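The load-bearing line in the async template is `connection.run_sync(do_run_migrations)`: it hands the async connection's underlying synchronous Connection to a plain function, which is what lets the synchronous migration context run unchanged. The same bridge works for any sync-only helper; a standalone sketch (the URL is illustrative and needs the aiosqlite driver):

import asyncio
from sqlalchemy.ext.asyncio import create_async_engine

def sync_work(connection):
    # Receives a regular synchronous Connection.
    connection.exec_driver_sql("SELECT 1")

async def main():
    engine = create_async_engine("sqlite+aiosqlite://")
    async with engine.connect() as conn:
        await conn.run_sync(sync_work)
    await engine.dispose()

asyncio.run(main())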
||||
@@ -1,82 +0,0 @@
|
||||
[tool.alembic]
|
||||
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = "${script_location}"
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s"
|
||||
|
||||
# additional paths to be prepended to sys.path. defaults to the current working directory.
|
||||
prepend_sys_path = [
|
||||
"."
|
||||
]
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# version_locations = [
|
||||
# "%(here)s/alembic/versions",
|
||||
# "%(here)s/foo/bar"
|
||||
# ]
|
||||
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = "utf-8"
|
||||
|
||||
# This section defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# format using "black" - use the console_scripts runner,
|
||||
# against the "black" entrypoint
|
||||
# name = "black"
|
||||
# type = "console_scripts"
|
||||
# entrypoint = "black"
|
||||
# options = "-l 79 REVISION_SCRIPT_FILENAME"
|
||||
#
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# name = "ruff"
|
||||
# type = "module"
|
||||
# module = "ruff"
|
||||
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
||||
#
|
||||
# [[tool.alembic.post_write_hooks]]
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# name = "ruff"
|
||||
# type = "exec"
|
||||
# executable = "ruff"
|
||||
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -9,15 +9,12 @@ from sqlalchemy.testing import uses_deprecated
|
||||
from sqlalchemy.testing.config import combinations
|
||||
from sqlalchemy.testing.config import fixture
|
||||
from sqlalchemy.testing.config import requirements as requires
|
||||
from sqlalchemy.testing.config import Variation
|
||||
from sqlalchemy.testing.config import variation
|
||||
|
||||
from .assertions import assert_raises
|
||||
from .assertions import assert_raises_message
|
||||
from .assertions import emits_python_deprecation_warning
|
||||
from .assertions import eq_
|
||||
from .assertions import eq_ignore_whitespace
|
||||
from .assertions import expect_deprecated
|
||||
from .assertions import expect_raises
|
||||
from .assertions import expect_raises_message
|
||||
from .assertions import expect_sqlalchemy_deprecated
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Dict
|
||||
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy.engine import default
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.testing.assertions import _expect_warnings
|
||||
from sqlalchemy.testing.assertions import eq_ # noqa
|
||||
from sqlalchemy.testing.assertions import is_ # noqa
|
||||
@@ -18,6 +17,8 @@ from sqlalchemy.testing.assertions import is_true # noqa
|
||||
from sqlalchemy.testing.assertions import ne_ # noqa
|
||||
from sqlalchemy.util import decorator
|
||||
|
||||
from ..util import sqla_compat
|
||||
|
||||
|
||||
def _assert_proper_exception_context(exception):
|
||||
"""assert that any exception we're catching does not have a __context__
|
||||
@@ -73,9 +74,7 @@ class _ErrorContainer:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _expect_raises(
|
||||
except_cls, msg=None, check_context=False, text_exact=False
|
||||
):
|
||||
def _expect_raises(except_cls, msg=None, check_context=False):
|
||||
ec = _ErrorContainer()
|
||||
if check_context:
|
||||
are_we_already_in_a_traceback = sys.exc_info()[0]
|
||||
@@ -86,10 +85,7 @@ def _expect_raises(
|
||||
ec.error = err
|
||||
success = True
|
||||
if msg is not None:
|
||||
if text_exact:
|
||||
assert str(err) == msg, f"{msg} != {err}"
|
||||
else:
|
||||
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
|
||||
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
|
||||
if check_context and not are_we_already_in_a_traceback:
|
||||
_assert_proper_exception_context(err)
|
||||
print(str(err).encode("utf-8"))
|
||||
@@ -102,12 +98,8 @@ def expect_raises(except_cls, check_context=True):
|
||||
return _expect_raises(except_cls, check_context=check_context)
|
||||
|
||||
|
||||
def expect_raises_message(
|
||||
except_cls, msg, check_context=True, text_exact=False
|
||||
):
|
||||
return _expect_raises(
|
||||
except_cls, msg=msg, check_context=check_context, text_exact=text_exact
|
||||
)
|
||||
def expect_raises_message(except_cls, msg, check_context=True):
|
||||
return _expect_raises(except_cls, msg=msg, check_context=check_context)
|
||||
|
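As the removed `text_exact` branch makes explicit, `expect_raises_message` matches with `re.search`, so its `msg` argument is a regular expression rather than a literal; metacharacters in expected messages need escaping. The distinction in isolation:

import re

err_text = "Can't locate revision identified by 'abc'"
assert re.search("Can't locate revision", err_text)            # pattern match
assert re.search(re.escape("identified by 'abc'"), err_text)   # escaped literal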
||||
|
||||
def eq_ignore_whitespace(a, b, msg=None):
|
||||
@@ -126,7 +118,7 @@ def _get_dialect(name):
|
||||
if name is None or name == "default":
|
||||
return default.DefaultDialect()
|
||||
else:
|
||||
d = URL.create(name).get_dialect()()
|
||||
d = sqla_compat._create_url(name).get_dialect()()
|
||||
|
||||
if name == "postgresql":
|
||||
d.implicit_returning = True
|
||||
@@ -167,10 +159,6 @@ def emits_python_deprecation_warning(*messages):
|
||||
return decorate
|
||||
|
||||
|
||||
def expect_deprecated(*messages, **kw):
|
||||
return _expect_warnings(DeprecationWarning, messages, **kw)
|
||||
|
||||
|
||||
def expect_sqlalchemy_deprecated(*messages, **kw):
|
||||
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import importlib.machinery
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import textwrap
|
||||
|
||||
@@ -17,7 +16,7 @@ from ..script import ScriptDirectory
|
||||
|
||||
def _get_staging_directory():
|
||||
if provision.FOLLOWER_IDENT:
|
||||
return f"scratch_{provision.FOLLOWER_IDENT}"
|
||||
return "scratch_%s" % provision.FOLLOWER_IDENT
|
||||
else:
|
||||
return "scratch"
|
||||
|
||||
@@ -25,7 +24,7 @@ def _get_staging_directory():
|
||||
def staging_env(create=True, template="generic", sourceless=False):
|
||||
cfg = _testing_config()
|
||||
if create:
|
||||
path = _join_path(_get_staging_directory(), "scripts")
|
||||
path = os.path.join(_get_staging_directory(), "scripts")
|
||||
assert not os.path.exists(path), (
|
||||
"staging directory %s already exists; poor cleanup?" % path
|
||||
)
|
||||
@@ -48,7 +47,7 @@ def staging_env(create=True, template="generic", sourceless=False):
|
||||
"pep3147_everything",
|
||||
), sourceless
|
||||
make_sourceless(
|
||||
_join_path(path, "env.py"),
|
||||
os.path.join(path, "env.py"),
|
||||
"pep3147" if "pep3147" in sourceless else "simple",
|
||||
)
|
||||
|
||||
@@ -64,14 +63,14 @@ def clear_staging_env():
|
||||
|
||||
|
||||
def script_file_fixture(txt):
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
path = _join_path(dir_, "script.py.mako")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
path = os.path.join(dir_, "script.py.mako")
|
||||
with open(path, "w") as f:
|
||||
f.write(txt)
|
||||
|
||||
|
||||
def env_file_fixture(txt):
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
txt = (
|
||||
"""
|
||||
from alembic import context
|
||||
@@ -81,7 +80,7 @@ config = context.config
|
||||
+ txt
|
||||
)
|
||||
|
||||
path = _join_path(dir_, "env.py")
|
||||
path = os.path.join(dir_, "env.py")
|
||||
pyc_path = util.pyc_file_from_path(path)
|
||||
if pyc_path:
|
||||
os.unlink(pyc_path)
|
||||
@@ -91,26 +90,26 @@ config = context.config
|
||||
|
||||
|
||||
def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
url = "sqlite:///%s/%s" % (dir_, tempname)
|
||||
if scope:
|
||||
if scope and util.sqla_14:
|
||||
options["scope"] = scope
|
||||
return testing_util.testing_engine(url=url, future=future, options=options)
|
||||
|
||||
|
||||
def _sqlite_testing_config(sourceless=False, future=False):
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
url = f"sqlite:///{dir_}/foo.db"
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
url = "sqlite:///%s/foo.db" % dir_
|
||||
|
||||
sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
|
||||
|
||||
return _write_config_file(
|
||||
f"""
|
||||
"""
|
||||
[alembic]
|
||||
script_location = {dir_}
|
||||
sqlalchemy.url = {url}
|
||||
sourceless = {"true" if sourceless else "false"}
|
||||
{"sqlalchemy.future = true" if sqlalchemy_future else ""}
|
||||
script_location = %s
|
||||
sqlalchemy.url = %s
|
||||
sourceless = %s
|
||||
%s
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy
|
||||
@@ -119,7 +118,7 @@ keys = root,sqlalchemy
|
||||
keys = console
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
@@ -141,25 +140,29 @@ keys = generic
|
||||
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
||||
datefmt = %%H:%%M:%%S
|
||||
"""
|
||||
% (
|
||||
dir_,
|
||||
url,
|
||||
"true" if sourceless else "false",
|
||||
"sqlalchemy.future = true" if sqlalchemy_future else "",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
sqlalchemy_future = "future" in config.db.__class__.__module__
|
||||
|
||||
url = "sqlite:///%s/foo.db" % dir_
|
||||
|
||||
return _write_config_file(
|
||||
f"""
|
||||
"""
|
||||
[alembic]
|
||||
script_location = {dir_}
|
||||
sqlalchemy.url = {url}
|
||||
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
|
||||
sourceless = {"true" if sourceless else "false"}
|
||||
path_separator = space
|
||||
version_locations = %(here)s/model1/ %(here)s/model2/ %(here)s/model3/ \
|
||||
{extra_version_location}
|
||||
script_location = %s
|
||||
sqlalchemy.url = %s
|
||||
sqlalchemy.future = %s
|
||||
sourceless = %s
|
||||
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
|
||||
|
||||
[loggers]
|
||||
keys = root
|
||||
@@ -168,7 +171,7 @@ keys = root
|
||||
keys = console
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
@@ -185,63 +188,26 @@ keys = generic
|
||||
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
||||
datefmt = %%H:%%M:%%S
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _no_sql_pyproject_config(dialect="postgresql", directives=""):
|
||||
"""use a postgresql url with no host so that
|
||||
connections are guaranteed to fail"""
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
|
||||
return _write_toml_config(
|
||||
f"""
|
||||
[tool.alembic]
|
||||
script_location ="{dir_}"
|
||||
{textwrap.dedent(directives)}
|
||||
|
||||
""",
|
||||
f"""
|
||||
[alembic]
|
||||
sqlalchemy.url = {dialect}://
|
||||
|
||||
[loggers]
|
||||
keys = root
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
||||
datefmt = %%H:%%M:%%S
|
||||
|
||||
""",
|
||||
% (
|
||||
dir_,
|
||||
url,
|
||||
"true" if sqlalchemy_future else "false",
|
||||
"true" if sourceless else "false",
|
||||
extra_version_location,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _no_sql_testing_config(dialect="postgresql", directives=""):
|
||||
"""use a postgresql url with no host so that
|
||||
connections are guaranteed to fail"""
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
return _write_config_file(
|
||||
f"""
|
||||
"""
|
||||
[alembic]
|
||||
script_location ={dir_}
|
||||
sqlalchemy.url = {dialect}://
|
||||
{directives}
|
||||
script_location = %s
|
||||
sqlalchemy.url = %s://
|
||||
%s
|
||||
|
||||
[loggers]
|
||||
keys = root
|
||||
@@ -250,7 +216,7 @@ keys = root
|
||||
keys = console
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
@@ -268,16 +234,10 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
||||
datefmt = %%H:%%M:%%S
|
||||
|
||||
"""
|
||||
% (dir_, dialect, directives)
|
||||
)
|
||||
|
||||
|
||||
def _write_toml_config(tomltext, initext):
|
||||
cfg = _write_config_file(initext)
|
||||
with open(cfg.toml_file_name, "w") as f:
|
||||
f.write(tomltext)
|
||||
return cfg
|
||||
|
||||
|
||||
def _write_config_file(text):
|
||||
cfg = _testing_config()
|
||||
with open(cfg.config_file_name, "w") as f:
|
||||
@@ -290,10 +250,7 @@ def _testing_config():
|
||||
|
||||
if not os.access(_get_staging_directory(), os.F_OK):
|
||||
os.mkdir(_get_staging_directory())
|
||||
return Config(
|
||||
_join_path(_get_staging_directory(), "test_alembic.ini"),
|
||||
_join_path(_get_staging_directory(), "pyproject.toml"),
|
||||
)
|
||||
return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
|
||||
|
||||
|
||||
def write_script(
|
||||
@@ -313,7 +270,9 @@ def write_script(
|
||||
script = Script._from_path(scriptdir, path)
|
||||
old = scriptdir.revision_map.get_revision(script.revision)
|
||||
if old.down_revision != script.down_revision:
|
||||
raise Exception("Can't change down_revision on a refresh operation.")
|
||||
raise Exception(
|
||||
"Can't change down_revision " "on a refresh operation."
|
||||
)
|
||||
scriptdir.revision_map.add_revision(script, _replace=True)
|
||||
|
||||
if sourceless:
|
||||
@@ -353,9 +312,9 @@ def three_rev_fixture(cfg):
|
||||
write_script(
|
||||
script,
|
||||
a,
|
||||
f"""\
|
||||
"""\
|
||||
"Rev A"
|
||||
revision = '{a}'
|
||||
revision = '%s'
|
||||
down_revision = None
|
||||
|
||||
from alembic import op
|
||||
@@ -368,7 +327,8 @@ def upgrade():
|
||||
def downgrade():
|
||||
op.execute("DROP STEP 1")
|
||||
|
||||
""",
|
||||
"""
|
||||
% a,
|
||||
)
|
||||
|
||||
script.generate_revision(b, "revision b", refresh=True, head=a)
|
||||
@@ -398,10 +358,10 @@ def downgrade():
|
||||
write_script(
|
||||
script,
|
||||
c,
|
||||
f"""\
|
||||
"""\
|
||||
"Rev C"
|
||||
revision = '{c}'
|
||||
down_revision = '{b}'
|
||||
revision = '%s'
|
||||
down_revision = '%s'
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -413,7 +373,8 @@ def upgrade():
|
||||
def downgrade():
|
||||
op.execute("DROP STEP 3")
|
||||
|
||||
""",
|
||||
"""
|
||||
% (c, b),
|
||||
)
|
||||
return a, b, c
|
||||
|
||||
@@ -435,10 +396,10 @@ def multi_heads_fixture(cfg, a, b, c):
|
||||
write_script(
|
||||
script,
|
||||
d,
|
||||
f"""\
|
||||
"""\
|
||||
"Rev D"
|
||||
revision = '{d}'
|
||||
down_revision = '{b}'
|
||||
revision = '%s'
|
||||
down_revision = '%s'
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -450,7 +411,8 @@ def upgrade():
|
||||
def downgrade():
|
||||
op.execute("DROP STEP 4")
|
||||
|
||||
""",
|
||||
"""
|
||||
% (d, b),
|
||||
)
|
||||
|
||||
script.generate_revision(
|
||||
@@ -459,10 +421,10 @@ def downgrade():
|
||||
write_script(
|
||||
script,
|
||||
e,
|
||||
f"""\
|
||||
"""\
|
||||
"Rev E"
|
||||
revision = '{e}'
|
||||
down_revision = '{d}'
|
||||
revision = '%s'
|
||||
down_revision = '%s'
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -474,7 +436,8 @@ def upgrade():
|
||||
def downgrade():
|
||||
op.execute("DROP STEP 5")
|
||||
|
||||
""",
|
||||
"""
|
||||
% (e, d),
|
||||
)
|
||||
|
||||
script.generate_revision(
|
||||
@@ -483,10 +446,10 @@ def downgrade():
|
||||
write_script(
|
||||
script,
|
||||
f,
|
||||
f"""\
|
||||
"""\
|
||||
"Rev F"
|
||||
revision = '{f}'
|
||||
down_revision = '{b}'
|
||||
revision = '%s'
|
||||
down_revision = '%s'
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -498,7 +461,8 @@ def upgrade():
|
||||
def downgrade():
|
||||
op.execute("DROP STEP 6")
|
||||
|
||||
""",
|
||||
"""
|
||||
% (f, b),
|
||||
)
|
||||
|
||||
return d, e, f
|
||||
@@ -507,25 +471,25 @@ def downgrade():
|
||||
def _multidb_testing_config(engines):
|
||||
"""alembic.ini fixture to work exactly with the 'multidb' template"""
|
||||
|
||||
dir_ = _join_path(_get_staging_directory(), "scripts")
|
||||
dir_ = os.path.join(_get_staging_directory(), "scripts")
|
||||
|
||||
sqlalchemy_future = "future" in config.db.__class__.__module__
|
||||
|
||||
databases = ", ".join(engines.keys())
|
||||
engines = "\n\n".join(
|
||||
f"[{key}]\nsqlalchemy.url = {value.url}"
|
||||
"[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
|
||||
for key, value in engines.items()
|
||||
)
|
||||
|
||||
return _write_config_file(
|
||||
f"""
|
||||
"""
|
||||
[alembic]
|
||||
script_location = {dir_}
|
||||
script_location = %s
|
||||
sourceless = false
|
||||
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
|
||||
databases = {databases}
|
||||
sqlalchemy.future = %s
|
||||
databases = %s
|
||||
|
||||
{engines}
|
||||
%s
|
||||
[loggers]
|
||||
keys = root
|
||||
|
||||
@@ -533,7 +497,7 @@ keys = root
|
||||
keys = console
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
@@ -550,8 +514,5 @@ keys = generic
|
||||
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
|
||||
datefmt = %%H:%%M:%%S
|
||||
"""
|
||||
% (dir_, "true" if sqlalchemy_future else "false", databases, engines)
|
||||
)
|
||||
|
||||
|
||||
def _join_path(base: str, *more: str):
|
||||
return str(Path(base).joinpath(*more).as_posix())
|
||||
|
||||
@@ -3,14 +3,11 @@ from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import os
import re
import shutil
from typing import Any
from typing import Dict

from sqlalchemy import Column
from sqlalchemy import create_mock_engine
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
@@ -20,19 +17,20 @@ from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import FutureEngineMixin
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase

import alembic
from .assertions import _get_dialect
from .env import _get_staging_directory
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2


testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])

@@ -40,31 +38,6 @@ testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
is_sqlalchemy_future = sqla_2

@testing.fixture()
def clear_staging_dir(self):
yield
location = _get_staging_directory()
for filename in os.listdir(location):
file_path = os.path.join(location, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)

@contextmanager
def pushd(self, dirname):
current_dir = os.getcwd()
try:
os.chdir(dirname)
yield
finally:
os.chdir(current_dir)

@testing.fixture()
def pop_alembic_config_env(self):
yield
os.environ.pop("ALEMBIC_CONFIG", None)

@testing.fixture()
def ops_context(self, migration_context):
with migration_context.begin_transaction(_per_migration=True):
@@ -76,12 +49,6 @@ class TestBase(SQLAlchemyTestBase):
connection, opts=dict(transaction_per_migration=True)
)

@testing.fixture
def as_sql_migration_context(self, connection):
return MigrationContext.configure(
connection, opts=dict(transaction_per_migration=True, as_sql=True)
)

@testing.fixture
def connection(self):
with config.db.connect() as conn:
@@ -92,6 +59,14 @@ class TablesTest(TestBase, SQLAlchemyTablesTest):
pass


if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:

class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)


FutureEngineMixin.is_sqlalchemy_future = True


@@ -209,8 +184,12 @@ def op_fixture(
opts["as_sql"] = as_sql
if literal_binds:
opts["literal_binds"] = literal_binds
if not sqla_14 and dialect == "mariadb":
ctx_dialect = _get_dialect("mysql")
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")

ctx_dialect = _get_dialect(dialect)
else:
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
@@ -289,11 +268,9 @@ class AlterColRoundTripFixture:
"x",
column.name,
existing_type=column.type,
existing_server_default=(
column.server_default
if column.server_default is not None
else False
),
existing_server_default=column.server_default
if column.server_default is not None
else False,
existing_nullable=True if column.nullable else False,
# existing_comment=column.comment,
nullable=to_.get("nullable", None),
@@ -321,13 +298,9 @@ class AlterColRoundTripFixture:
new_col["type"],
new_col.get("default", None),
compare.get("type", old_col["type"]),
(
compare["server_default"].text
if "server_default" in compare
else (
column.server_default.arg.text
if column.server_default is not None
else None
)
),
compare["server_default"].text
if "server_default" in compare
else column.server_default.arg.text
if column.server_default is not None
else None,
)

@@ -1,6 +1,7 @@
from sqlalchemy.testing.requirements import Requirements

from alembic import util
from alembic.util import sqla_compat
from ..testing import exclusions


@@ -73,6 +74,13 @@ class SuiteRequirements(Requirements):
def reflects_fk_options(self):
return exclusions.closed()

@property
def sqlalchemy_14(self):
return exclusions.skip_if(
lambda config: not util.sqla_14,
"SQLAlchemy 1.4 or greater required",
)

@property
def sqlalchemy_1x(self):
return exclusions.skip_if(
@@ -87,18 +95,6 @@ class SuiteRequirements(Requirements):
"SQLAlchemy 2.x test",
)

@property
def asyncio(self):
def go(config):
try:
import greenlet # noqa: F401
except ImportError:
return False
else:
return True

return exclusions.only_if(go)

@property
def comments(self):
return exclusions.only_if(
@@ -113,6 +109,26 @@ class SuiteRequirements(Requirements):
def computed_columns(self):
return exclusions.closed()

@property
def computed_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed)
)

@property
def computed_reflects_normally(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed_reflection)
)

@property
def computed_reflects_as_server_default(self):
return exclusions.closed()

@property
def computed_doesnt_reflect_as_server_default(self):
return exclusions.closed()

@property
def autoincrement_on_composite_pk(self):
return exclusions.closed()
@@ -174,3 +190,9 @@ class SuiteRequirements(Requirements):
@property
def identity_columns_alter(self):
return exclusions.closed()

@property
def identity_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_identity)
)

@@ -1,7 +1,6 @@
from itertools import zip_longest

from sqlalchemy import schema
from sqlalchemy.sql.elements import ClauseList


class CompareTable:
@@ -61,14 +60,6 @@ class CompareIndex:
def __ne__(self, other):
return not self.__eq__(other)

def __repr__(self):
expr = ClauseList(*self.index.expressions)
try:
expr_str = expr.compile().string
except Exception:
expr_str = str(expr)
return f"<CompareIndex {self.index.name}({expr_str})>"


class CompareCheckConstraint:
def __init__(self, constraint):

@@ -14,7 +14,6 @@ from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
@@ -150,118 +149,6 @@ class ModelOne:
return m


class NamingConvModel:
__requires__ = ("unique_constraint_reflection",)
configure_opts = {"conv_all_constraint_names": True}
naming_convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(constraint_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}

@classmethod
def _get_db_schema(cls):
# database side - assume all constraints have a name that
# we would assume here is a "db generated" name. need to make
# sure these all render with op.f().
m = MetaData()
Table(
"x1",
m,
Column("q", Integer),
Index("db_x1_index_q", "q"),
PrimaryKeyConstraint("q", name="db_x1_primary_q"),
)
Table(
"x2",
m,
Column("q", Integer),
Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")),
CheckConstraint("q > 5", name="db_x2_check_q"),
)
Table(
"x3",
m,
Column("q", Integer),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("q", name="db_x3_unique_q"),
)
Table(
"x4",
m,
Column("q", Integer),
PrimaryKeyConstraint("q", name="db_x4_primary_q"),
)
Table(
"x5",
m,
Column("q", Integer),
Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")),
Column("r", Integer),
Column("s", Integer),
PrimaryKeyConstraint("q", name="db_x5_primary_q"),
UniqueConstraint("r", name="db_x5_unique_r"),
CheckConstraint("s > 5", name="db_x5_check_s"),
)
# SQLite and it's "no names needed" thing. bleh.
# we can't have a name for these so you'll see "None" for the name.
Table(
"unnamed_sqlite",
m,
Column("q", Integer),
Column("r", Integer),
PrimaryKeyConstraint("q"),
UniqueConstraint("r"),
)
return m

@classmethod
def _get_model_schema(cls):
from sqlalchemy.sql.naming import conv

m = MetaData(naming_convention=cls.naming_convention)
Table(
"x1", m, Column("q", Integer, primary_key=True), Index(None, "q")
)
Table(
"x2",
m,
Column("q", Integer),
Column("p", ForeignKey("x1.q")),
CheckConstraint("q > 5", name="token_x2check1"),
)
Table(
"x3",
m,
Column("q", Integer),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("r", name="token_x3r"),
UniqueConstraint("s", name=conv("userdef_x3_unique_s")),
)
Table(
"x4",
m,
Column("q", Integer, primary_key=True),
Index("userdef_x4_idx_q", "q"),
)
Table(
"x6",
m,
Column("q", Integer, primary_key=True),
Column("p", ForeignKey("x4.q")),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("r", name="token_x6r"),
CheckConstraint("s > 5", "token_x6check1"),
CheckConstraint("s < 20", conv("userdef_x6_check_s")),
)
return m


class _ComparesFKs:
def _assert_fk_diff(
self,

@@ -6,7 +6,9 @@ from sqlalchemy import Table

from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import exclusions
from ...testing import is_
from ...testing import is_true
from ...testing import mock
@@ -61,8 +63,18 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
c = diffs[0][3]
eq_(c.name, "foo")

is_true(isinstance(c.computed, sa.Computed))
is_true(isinstance(c.server_default, sa.Computed))
if config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)

if config.requirements.computed_reflects_as_server_default.enabled:
is_true(isinstance(c.server_default, sa.DefaultClause))
eq_(str(c.server_default.arg.text), "5")
elif config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)

@testing.combinations(
lambda: (None, sa.Computed("bar*5")),
@@ -73,6 +85,7 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")),
)
@config.requirements.computed_reflects_normally
def test_cant_change_computed_warning(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
@@ -111,7 +124,10 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
lambda: (None, None),
lambda: (sa.Computed("5"), sa.Computed("5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * \r\n\t5")),
(
lambda: (sa.Computed("bar*5"), None),
config.requirements.computed_doesnt_reflect_as_server_default,
),
)
def test_computed_unchanged(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
@@ -142,3 +158,46 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
eq_(mock_warn.mock_calls, [])

eq_(list(diffs), [])

@config.requirements.computed_reflects_as_server_default
def test_remove_computed_default_on_computed(self):
"""Asserts the current behavior which is that on PG and Oracle,
the GENERATED ALWAYS AS is reflected as a server default which we can't
tell is actually "computed", so these come out as a modification to
the server default.

"""
m1 = MetaData()
m2 = MetaData()

Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, sa.Computed("bar + 42")),
)

Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer),
)

diffs = self._fixture(m1, m2)

eq_(diffs[0][0][0], "modify_default")
eq_(diffs[0][0][2], "user")
eq_(diffs[0][0][3], "foo")
old = diffs[0][0][-2]
new = diffs[0][0][-1]

is_(new, None)
is_true(isinstance(old, sa.DefaultClause))

if exclusions.against(config, "postgresql"):
eq_(str(old.arg.text), "(bar + 42)")
elif exclusions.against(config, "oracle"):
eq_(str(old.arg.text), '"BAR"+42')

@@ -24,9 +24,9 @@ class MigrationTransactionTest(TestBase):
self.context = MigrationContext.configure(
dialect=conn.dialect, opts=opts
)
self.context.output_buffer = self.context.impl.output_buffer = (
io.StringIO()
)
self.context.output_buffer = (
self.context.impl.output_buffer
) = io.StringIO()
else:
self.context = MigrationContext.configure(
connection=conn, opts=opts

@@ -10,6 +10,8 @@ import warnings

from sqlalchemy import exc as sa_exc

from ..util import sqla_14


def setup_filters():
"""Set global warning behavior for the test suite."""
@@ -21,6 +23,13 @@ def setup_filters():

# some selected deprecations...
warnings.filterwarnings("error", category=DeprecationWarning)
if not sqla_14:
# 1.3 uses pkg_resources in PluginLoader
warnings.filterwarnings(
"ignore",
"pkg_resources is deprecated as an API",
DeprecationWarning,
)
try:
import pytest
except ImportError:

@@ -1,29 +1,35 @@
from .editor import open_in_editor as open_in_editor
from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
from .exc import CommandError as CommandError
from .langhelpers import _with_legacy_names as _with_legacy_names
from .langhelpers import asbool as asbool
from .langhelpers import dedupe_tuple as dedupe_tuple
from .langhelpers import Dispatcher as Dispatcher
from .langhelpers import EMPTY_DICT as EMPTY_DICT
from .langhelpers import immutabledict as immutabledict
from .langhelpers import memoized_property as memoized_property
from .langhelpers import ModuleClsProxy as ModuleClsProxy
from .langhelpers import not_none as not_none
from .langhelpers import rev_id as rev_id
from .langhelpers import to_list as to_list
from .langhelpers import to_tuple as to_tuple
from .langhelpers import unique_list as unique_list
from .messaging import err as err
from .messaging import format_as_comma as format_as_comma
from .messaging import msg as msg
from .messaging import obfuscate_url_pw as obfuscate_url_pw
from .messaging import status as status
from .messaging import warn as warn
from .messaging import warn_deprecated as warn_deprecated
from .messaging import write_outstream as write_outstream
from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
from .pyfiles import load_python_file as load_python_file
from .pyfiles import pyc_file_from_path as pyc_file_from_path
from .pyfiles import template_to_file as template_to_file
from .sqla_compat import sqla_2 as sqla_2
from .editor import open_in_editor
from .exc import AutogenerateDiffsDetected
from .exc import CommandError
from .langhelpers import _with_legacy_names
from .langhelpers import asbool
from .langhelpers import dedupe_tuple
from .langhelpers import Dispatcher
from .langhelpers import EMPTY_DICT
from .langhelpers import immutabledict
from .langhelpers import memoized_property
from .langhelpers import ModuleClsProxy
from .langhelpers import not_none
from .langhelpers import rev_id
from .langhelpers import to_list
from .langhelpers import to_tuple
from .langhelpers import unique_list
from .messaging import err
from .messaging import format_as_comma
from .messaging import msg
from .messaging import obfuscate_url_pw
from .messaging import status
from .messaging import warn
from .messaging import write_outstream
from .pyfiles import coerce_resource_to_filename
from .pyfiles import load_python_file
from .pyfiles import pyc_file_from_path
from .pyfiles import template_to_file
from .sqla_compat import has_computed
from .sqla_compat import sqla_13
from .sqla_compat import sqla_14
from .sqla_compat import sqla_2


if not sqla_13:
raise CommandError("SQLAlchemy 1.3.0 or greater is required.")

@@ -1,37 +1,22 @@
# mypy: no-warn-unused-ignores

from __future__ import annotations

from configparser import ConfigParser
import io
import os
from pathlib import Path
import sys
import typing
from typing import Any
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union

if True:
# zimports hack for too-long names
from sqlalchemy.util import ( # noqa: F401
inspect_getfullargspec as inspect_getfullargspec,
)
from sqlalchemy.util.compat import ( # noqa: F401
inspect_formatargspec as inspect_formatargspec,
)
from sqlalchemy.util import inspect_getfullargspec # noqa
from sqlalchemy.util.compat import inspect_formatargspec # noqa

is_posix = os.name == "posix"

py314 = sys.version_info >= (3, 14)
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
py39 = sys.version_info >= (3, 9)
py38 = sys.version_info >= (3, 8)


# produce a wrapper that allows encoded text to stream
@@ -43,82 +28,24 @@ class EncodedIO(io.TextIOWrapper):


if py39:
from importlib import resources as _resources

importlib_resources = _resources
from importlib import metadata as _metadata

importlib_metadata = _metadata
from importlib.metadata import EntryPoint as EntryPoint
from importlib import resources as importlib_resources
from importlib import metadata as importlib_metadata
from importlib.metadata import EntryPoint
else:
import importlib_resources # type:ignore # noqa
import importlib_metadata # type:ignore # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa

if py311:
import tomllib as tomllib
else:
import tomli as tomllib # type: ignore # noqa


if py312:

def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
return Path.walk(path)

def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
return path.relative_to(other, walk_up=walk_up)

else:

def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
for root, dirs, files in os.walk(path, topdown=top_down):
yield Path(root), dirs, files

def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
"""
Calculate the relative path of 'path' with respect to 'other',
optionally allowing 'path' to be outside the subtree of 'other'.

OK I used AI for this, sorry

"""
try:
return path.relative_to(other)
except ValueError:
if walk_up:
other_ancestors = list(other.parents) + [other]
for ancestor in other_ancestors:
try:
return path.relative_to(ancestor)
except ValueError:
continue
raise ValueError(
f"{path} is not in the same subtree as {other}"
)
else:
raise


def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
return ep.select(group=group)
return ep.select(group=group) # type: ignore
else:
return ep.get(group, ()) # type: ignore


def formatannotation_fwdref(
annotation: Any, base_module: Optional[Any] = None
) -> str:
def formatannotation_fwdref(annotation, base_module=None):
"""vendored from python 3.7"""
# copied over _formatannotation from sqlalchemy 2.0

@@ -139,7 +66,7 @@ def formatannotation_fwdref(
def read_config_parser(
file_config: ConfigParser,
file_argument: Sequence[Union[str, os.PathLike[str]]],
) -> List[str]:
) -> list[str]:
if py310:
return file_config.read(file_argument, encoding="locale")
else:

@@ -1,25 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import List
from typing import Tuple
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from alembic.autogenerate import RevisionContext


class CommandError(Exception):
pass


class AutogenerateDiffsDetected(CommandError):
def __init__(
self,
message: str,
revision_context: RevisionContext,
diffs: List[Tuple[Any, ...]],
) -> None:
super().__init__(message)
self.revision_context = revision_context
self.diffs = diffs
pass

@@ -5,46 +5,33 @@ from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import uuid
import warnings

from sqlalchemy.util import asbool as asbool # noqa: F401
from sqlalchemy.util import immutabledict as immutabledict # noqa: F401
from sqlalchemy.util import to_list as to_list # noqa: F401
from sqlalchemy.util import unique_list as unique_list
from sqlalchemy.util import asbool # noqa
from sqlalchemy.util import immutabledict # noqa
from sqlalchemy.util import memoized_property # noqa
from sqlalchemy.util import to_list # noqa
from sqlalchemy.util import unique_list # noqa

from .compat import inspect_getfullargspec

if True:
# zimports workaround :(
from sqlalchemy.util import ( # noqa: F401
memoized_property as memoized_property,
)


EMPTY_DICT: Mapping[Any, Any] = immutabledict()
_T = TypeVar("_T", bound=Any)

_C = TypeVar("_C", bound=Callable[..., Any])
_T = TypeVar("_T")


class _ModuleClsMeta(type):
def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
def __setattr__(cls, key: str, value: Callable) -> None:
super().__setattr__(key, value)
cls._update_module_proxies(key) # type: ignore

@@ -58,13 +45,9 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):

"""

_setups: Dict[
Type[Any],
Tuple[
Set[str],
List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
],
] = collections.defaultdict(lambda: (set(), []))
_setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
lambda: (set(), [])
)

@classmethod
def _update_module_proxies(cls, name: str) -> None:
@@ -87,33 +70,18 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
del globals_[attr_name]

@classmethod
def create_module_class_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> None:
def create_module_class_proxy(cls, globals_, locals_):
attr_names, modules = cls._setups[cls]
modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)

@classmethod
def _setup_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
def _setup_proxy(cls, globals_, locals_, attr_names):
for methname in dir(cls):
cls._add_proxied_attribute(methname, globals_, locals_, attr_names)

@classmethod
def _add_proxied_attribute(
cls,
methname: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
@@ -124,15 +92,10 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
attr_names.add(methname)

@classmethod
def _create_method_proxy(
cls,
name: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> Callable[..., Any]:
def _create_method_proxy(cls, name, globals_, locals_):
fn = getattr(cls, name)

def _name_error(name: str, from_: Exception) -> NoReturn:
def _name_error(name, from_):
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
@@ -156,9 +119,7 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
translations,
)

def translate(
fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
) -> Any:
def translate(fn_name, spec, translations, args, kw):
return_kw = {}
return_args = []

@@ -215,15 +176,15 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
"doc": fn.__doc__,
}
)
lcl: MutableMapping[str, Any] = {}
lcl = {}

exec(func_text, cast("Dict[str, Any]", globals_), lcl)
return cast("Callable[..., Any]", lcl[name])
exec(func_text, globals_, lcl)
return lcl[name]


def _with_legacy_names(translations: Any) -> Any:
def decorate(fn: _C) -> _C:
fn._legacy_translations = translations # type: ignore[attr-defined]
def _with_legacy_names(translations):
def decorate(fn):
fn._legacy_translations = translations
return fn

return decorate
@@ -234,22 +195,21 @@ def rev_id() -> str:


@overload
def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ...
def to_tuple(x: Any, default: tuple) -> tuple:
...


@overload
def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ...
def to_tuple(x: None, default: Optional[_T] = None) -> _T:
...


@overload
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Tuple[Any, ...]: ...
def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
...


def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Optional[Tuple[Any, ...]]:
def to_tuple(x, default=None):
if x is None:
return default
elif isinstance(x, str):
@@ -266,13 +226,13 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:

class Dispatcher:
def __init__(self, uselist: bool = False) -> None:
self._registry: Dict[Tuple[Any, ...], Any] = {}
self._registry: Dict[tuple, Any] = {}
self.uselist = uselist

def dispatch_for(
self, target: Any, qualifier: str = "default"
) -> Callable[[_C], _C]:
def decorate(fn: _C) -> _C:
) -> Callable:
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
else:
@@ -284,7 +244,7 @@ class Dispatcher:

def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, str):
targets: Sequence[Any] = [obj]
targets: Sequence = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
@@ -299,13 +259,11 @@ class Dispatcher:
raise ValueError("no dispatch function for object: %s" % obj)

def _fn_or_list(
self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
) -> Callable[..., Any]:
self, fn_or_list: Union[List[Callable], Callable]
) -> Callable:
if self.uselist:

def go(*arg: Any, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(fn_or_list, Sequence)
def go(*arg, **kw):
for fn in fn_or_list:
fn(*arg, **kw)


@@ -5,7 +5,6 @@ from contextlib import contextmanager
import logging
import sys
import textwrap
from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Union
@@ -13,6 +12,8 @@ import warnings

from sqlalchemy.engine import url

from . import sqla_compat

log = logging.getLogger(__name__)

# disable "no handler found" errors
@@ -52,9 +53,7 @@ def write_outstream(


@contextmanager
def status(
status_msg: str, newline: bool = False, quiet: bool = False
) -> Iterator[None]:
def status(status_msg: str, newline: bool = False, quiet: bool = False):
msg(status_msg + " ...", newline, flush=True, quiet=quiet)
try:
yield
@@ -67,24 +66,21 @@ def status(
write_outstream(sys.stdout, " done\n")


def err(message: str, quiet: bool = False) -> None:
def err(message: str, quiet: bool = False):
log.error(message)
msg(f"FAILED: {message}", quiet=quiet)
sys.exit(-1)


def obfuscate_url_pw(input_url: str) -> str:
return url.make_url(input_url).render_as_string(hide_password=True)
u = url.make_url(input_url)
return sqla_compat.url_render_as_string(u, hide_password=True)


def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)


def warn_deprecated(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, DeprecationWarning, stacklevel=stacklevel)


def msg(
msg: str, newline: bool = True, flush: bool = False, quiet: bool = False
) -> None:
@@ -96,17 +92,11 @@ def msg(
write_outstream(sys.stdout, "\n")
else:
# left indent output lines
indent = " "
lines = textwrap.wrap(
msg,
TERMWIDTH,
initial_indent=indent,
subsequent_indent=indent,
)
lines = textwrap.wrap(msg, TERMWIDTH)
if len(lines) > 1:
for line in lines[0:-1]:
write_outstream(sys.stdout, line, "\n")
write_outstream(sys.stdout, lines[-1], ("\n" if newline else ""))
write_outstream(sys.stdout, " ", line, "\n")
write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else ""))
if flush:
sys.stdout.flush()

@@ -6,13 +6,9 @@ import importlib
import importlib.machinery
import importlib.util
import os
import pathlib
import re
import tempfile
from types import ModuleType
from typing import Any
from typing import Optional
from typing import Union

from mako import exceptions
from mako.template import Template
@@ -22,14 +18,9 @@ from .exc import CommandError


def template_to_file(
template_file: Union[str, os.PathLike[str]],
dest: Union[str, os.PathLike[str]],
output_encoding: str,
*,
append_with_newlines: bool = False,
**kw: Any,
template_file: str, dest: str, output_encoding: str, **kw
) -> None:
template = Template(filename=_preserving_path_as_str(template_file))
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
except:
@@ -45,13 +36,11 @@ def template_to_file(
"template-oriented traceback." % fname
)
else:
with open(dest, "ab" if append_with_newlines else "wb") as f:
if append_with_newlines:
f.write("\n\n".encode(output_encoding))
with open(dest, "wb") as f:
f.write(output)


def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.

@@ -59,9 +48,8 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
are interpreted as resources and coerced to a file location.

"""
# TODO: there seem to be zero tests for the package resource codepath
if not os.path.isabs(fname_or_resource) and ":" in fname_or_resource:
tokens = fname_or_resource.split(":")
if not os.path.isabs(fname) and ":" in fname:
tokens = fname.split(":")

# from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501

@@ -71,48 +59,37 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
ref = compat.importlib_resources.files(tokens[0])
for tok in tokens[1:]:
ref = ref / tok
fname_or_resource = file_manager.enter_context( # type: ignore[assignment] # noqa: E501
fname = file_manager.enter_context( # type: ignore[assignment]
compat.importlib_resources.as_file(ref)
)
return pathlib.Path(fname_or_resource)
return fname


def pyc_file_from_path(
path: Union[str, os.PathLike[str]],
) -> Optional[pathlib.Path]:
def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""

pathpath = pathlib.Path(path)
candidate = pathlib.Path(
importlib.util.cache_from_source(pathpath.as_posix())
)
if candidate.exists():
candidate = importlib.util.cache_from_source(path)
if os.path.exists(candidate):
return candidate

# even for pep3147, fall back to the old way of finding .pyc files,
# to support sourceless operation
ext = pathpath.suffix
filepath, ext = os.path.splitext(path)
for ext in importlib.machinery.BYTECODE_SUFFIXES:
if pathpath.with_suffix(ext).exists():
return pathpath.with_suffix(ext)
if os.path.exists(filepath + ext):
return filepath + ext
else:
return None


def load_python_file(
dir_: Union[str, os.PathLike[str]], filename: Union[str, os.PathLike[str]]
) -> ModuleType:
def load_python_file(dir_: str, filename: str):
"""Load a file from the given path as a Python module."""

dir_ = pathlib.Path(dir_)
filename_as_path = pathlib.Path(filename)
filename = filename_as_path.name

module_id = re.sub(r"\W", "_", filename)
path = dir_ / filename
ext = path.suffix
path = os.path.join(dir_, filename)
_, ext = os.path.splitext(filename)
if ext == ".py":
if path.exists():
if os.path.exists(path):
module = load_module_py(module_id, path)
else:
pyc_path = pyc_file_from_path(path)
@@ -122,32 +99,12 @@ def load_python_file(
module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
module = load_module_py(module_id, path)
else:
assert False
return module


def load_module_py(
module_id: str, path: Union[str, os.PathLike[str]]
) -> ModuleType:
def load_module_py(module_id: str, path: str):
spec = importlib.util.spec_from_file_location(module_id, path)
assert spec
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module


def _preserving_path_as_str(path: Union[str, os.PathLike[str]]) -> str:
"""receive str/pathlike and return a string.

Does not convert an incoming string path to a Path first, to help with
unit tests that are doing string path round trips without OS-specific
processing if not necessary.

"""
if isinstance(path, str):
return path
elif isinstance(path, pathlib.PurePath):
return str(path)
else:
return str(pathlib.Path(path))

@@ -1,27 +1,24 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Set
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import __version__
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import schema
|
||||
from sqlalchemy import sql
|
||||
from sqlalchemy import types as sqltypes
|
||||
from sqlalchemy.engine import url
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.schema import CheckConstraint
|
||||
from sqlalchemy.schema import Column
|
||||
from sqlalchemy.schema import ForeignKeyConstraint
|
||||
@@ -29,33 +26,31 @@ from sqlalchemy.sql import visitors
|
||||
from sqlalchemy.sql.base import DialectKWArgs
|
||||
from sqlalchemy.sql.elements import BindParameter
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.elements import quoted_name
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
|
||||
from sqlalchemy.sql.visitors import traverse
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy import ClauseElement
|
||||
from sqlalchemy import Identity
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.engine import Dialect
|
||||
from sqlalchemy.engine import Transaction
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql.base import ColumnCollection
|
||||
from sqlalchemy.sql.compiler import SQLCompiler
|
||||
from sqlalchemy.sql.dml import Insert
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.schema import Constraint
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from sqlalchemy.sql.selectable import TableClause
|
||||
|
||||
_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
|
||||
|
||||
|
||||
class _CompilerProtocol(Protocol):
|
||||
def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
|
||||
|
||||
|
||||
def _safe_int(value: str) -> Union[int, str]:
|
||||
try:
|
||||
return int(value)
|
||||
@@ -66,65 +61,90 @@ def _safe_int(value: str) -> Union[int, str]:
|
||||
_vers = tuple(
|
||||
[_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
|
||||
)
|
||||
sqla_13 = _vers >= (1, 3)
|
||||
sqla_14 = _vers >= (1, 4)
|
||||
# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
|
||||
sqla_14_18 = _vers >= (1, 4, 18)
|
||||
sqla_14_26 = _vers >= (1, 4, 26)
|
||||
sqla_2 = _vers >= (2,)
|
||||
sqlalchemy_version = __version__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
|
||||
except ImportError:
|
||||
from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
|
||||
|
||||
def compiles(
|
||||
element: Type[ClauseElement], *dialects: str
|
||||
) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...
|
||||
|
||||
class _Unsupported:
|
||||
"Placeholder for unsupported SQLAlchemy classes"
|
||||
|
||||
|
||||
try:
|
||||
from sqlalchemy import Computed
|
||||
except ImportError:
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
class Computed(_Unsupported):
|
||||
pass
|
||||
|
||||
has_computed = False
|
||||
has_computed_reflection = False
|
||||
else:
|
||||
from sqlalchemy.ext.compiler import compiles # noqa: I100,I202
|
||||
has_computed = True
|
||||
has_computed_reflection = _vers >= (1, 3, 16)
|
||||
|
||||
try:
|
||||
from sqlalchemy import Identity
|
||||
except ImportError:
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
|
||||
class Identity(_Unsupported):
|
||||
pass
|
||||
|
||||
has_identity = False
|
||||
else:
|
||||
identity_has_dialect_kwargs = issubclass(Identity, DialectKWArgs)
|
||||
|
||||
def _get_identity_options_dict(
|
||||
identity: Union[Identity, schema.Sequence, None],
|
||||
dialect_kwargs: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
if identity is None:
|
||||
return {}
|
||||
elif identity_has_dialect_kwargs:
|
||||
assert hasattr(identity, "_as_dict")
|
||||
as_dict = identity._as_dict()
|
||||
if dialect_kwargs:
|
||||
assert isinstance(identity, DialectKWArgs)
|
||||
as_dict.update(identity.dialect_kwargs)
|
||||
else:
|
||||
as_dict = {}
|
||||
if isinstance(identity, schema.Identity):
|
||||
# always=None means something different than always=False
|
||||
as_dict["always"] = identity.always
|
||||
if identity.on_null is not None:
|
||||
as_dict["on_null"] = identity.on_null
|
||||
# attributes common to Identity and Sequence
|
||||
attrs = (
|
||||
"start",
|
||||
"increment",
|
||||
"minvalue",
|
||||
"maxvalue",
|
||||
"nominvalue",
|
||||
"nomaxvalue",
|
||||
"cycle",
|
||||
"cache",
|
||||
"order",
|
||||
)
|
||||
as_dict.update(
|
||||
{
|
||||
key: getattr(identity, key, None)
|
||||
for key in attrs
|
||||
if getattr(identity, key, None) is not None
|
||||
}
|
||||
)
|
||||
return as_dict
|
||||
def _get_identity_options_dict(
|
||||
identity: Union[Identity, schema.Sequence, None],
|
||||
dialect_kwargs: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
if identity is None:
|
||||
return {}
|
||||
elif identity_has_dialect_kwargs:
|
||||
as_dict = identity._as_dict() # type: ignore
|
||||
if dialect_kwargs:
|
||||
assert isinstance(identity, DialectKWArgs)
|
||||
as_dict.update(identity.dialect_kwargs)
|
||||
else:
|
||||
as_dict = {}
|
||||
if isinstance(identity, Identity):
|
||||
# always=None means something different than always=False
|
||||
as_dict["always"] = identity.always
|
||||
if identity.on_null is not None:
|
||||
as_dict["on_null"] = identity.on_null
|
||||
# attributes common to Identity and Sequence
|
||||
attrs = (
|
||||
"start",
|
||||
"increment",
|
||||
"minvalue",
|
||||
"maxvalue",
|
||||
"nominvalue",
|
||||
"nomaxvalue",
|
||||
"cycle",
|
||||
"cache",
|
||||
"order",
|
||||
)
|
||||
as_dict.update(
|
||||
{
|
||||
key: getattr(identity, key, None)
|
||||
for key in attrs
|
||||
if getattr(identity, key, None) is not None
|
||||
}
|
||||
)
|
||||
return as_dict
|
||||
|
||||
has_identity = True
|
||||
|
||||
if sqla_2:
|
||||
from sqlalchemy.sql.base import _NoneName
|
||||
@@ -133,6 +153,7 @@ else:
|
||||
|
||||
|
||||
_ConstraintName = Union[None, str, _NoneName]
|
||||
|
||||
_ConstraintNameDefined = Union[str, _NoneName]
|
||||
|
||||
|
||||
@@ -142,11 +163,15 @@ def constraint_name_defined(
|
||||
return name is _NONE_NAME or isinstance(name, (str, _NoneName))
|
||||
|
||||
|
||||
def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
|
||||
def constraint_name_string(
|
||||
name: _ConstraintName,
|
||||
) -> TypeGuard[str]:
|
||||
return isinstance(name, str)
|
||||
|
||||
|
||||
def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
|
||||
def constraint_name_or_none(
|
||||
name: _ConstraintName,
|
||||
) -> Optional[str]:
|
||||
return name if constraint_name_string(name) else None
|
||||
|
||||
|
||||
@@ -176,10 +201,17 @@ def _ensure_scope_for_ddl(
|
||||
yield
|
||||
|
||||
|
||||
def url_render_as_string(url, hide_password=True):
|
||||
if sqla_14:
|
||||
return url.render_as_string(hide_password=hide_password)
|
||||
else:
|
||||
return url.__to_string__(hide_password=hide_password)
|
||||
|
||||
|
||||
def _safe_begin_connection_transaction(
|
||||
connection: Connection,
|
||||
) -> Transaction:
|
||||
transaction = connection.get_transaction()
|
||||
transaction = _get_connection_transaction(connection)
|
||||
if transaction:
|
||||
return transaction
|
||||
else:
|
||||
@@ -189,7 +221,7 @@ def _safe_begin_connection_transaction(
|
||||
def _safe_commit_connection_transaction(
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
transaction = connection.get_transaction()
|
||||
transaction = _get_connection_transaction(connection)
|
||||
if transaction:
|
||||
transaction.commit()
|
||||
|
||||
@@ -197,7 +229,7 @@ def _safe_commit_connection_transaction(
|
||||
def _safe_rollback_connection_transaction(
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
transaction = connection.get_transaction()
|
||||
transaction = _get_connection_transaction(connection)
|
||||
if transaction:
|
||||
transaction.rollback()
|
||||
|
||||
@@ -218,34 +250,70 @@ def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
|
||||
|
||||
def _copy(schema_item: _CE, **kw) -> _CE:
|
||||
if hasattr(schema_item, "_copy"):
|
||||
return schema_item._copy(**kw)
|
||||
return schema_item._copy(**kw) # type: ignore[union-attr]
|
||||
else:
|
||||
return schema_item.copy(**kw) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def _get_connection_transaction(
|
||||
connection: Connection,
|
||||
) -> Optional[Transaction]:
|
||||
if sqla_14:
|
||||
return connection.get_transaction()
|
||||
else:
|
||||
r = connection._root # type: ignore[attr-defined]
|
||||
return r._Connection__transaction
|
||||
|
||||
|
||||
def _create_url(*arg, **kw) -> url.URL:
|
||||
if hasattr(url.URL, "create"):
|
||||
return url.URL.create(*arg, **kw)
|
||||
else:
|
||||
return url.URL(*arg, **kw)
|
||||
|
||||
|
||||
def _connectable_has_table(
|
||||
connectable: Connection, tablename: str, schemaname: Union[str, None]
|
||||
) -> bool:
|
||||
return connectable.dialect.has_table(connectable, tablename, schemaname)
|
||||
if sqla_14:
|
||||
return inspect(connectable).has_table(tablename, schemaname)
|
||||
else:
|
||||
return connectable.dialect.has_table(
|
||||
connectable, tablename, schemaname
|
||||
)
|
||||
|
||||
|
||||
def _exec_on_inspector(inspector, statement, **params):
    with inspector._operation_context() as conn:
        return conn.execute(statement, params)
    if sqla_14:
        with inspector._operation_context() as conn:
            return conn.execute(statement, params)
    else:
        return inspector.bind.execute(statement, params)

def _nullability_might_be_unset(metadata_column):
    from sqlalchemy.sql import schema

    if not sqla_14:
        return metadata_column.nullable
    else:
        from sqlalchemy.sql import schema

    return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
        return (
            metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
        )

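The 1.4 branch distinguishes a column whose nullability was never stated from one where it was set explicitly; a short illustration using the same private attribute (shown only for illustration):

from sqlalchemy import Column, Integer
from sqlalchemy.sql import schema

implicit = Column("a", Integer)                 # nullability never stated
explicit = Column("b", Integer, nullable=True)  # nullability stated

# only the first column reports NULL_UNSPECIFIED on SQLAlchemy 1.4+
print(implicit._user_defined_nullable is schema.NULL_UNSPECIFIED)  # True
print(explicit._user_defined_nullable is schema.NULL_UNSPECIFIED)  # False
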
def _server_default_is_computed(*server_default) -> bool:
    return any(isinstance(sd, schema.Computed) for sd in server_default)
    if not has_computed:
        return False
    else:
        return any(isinstance(sd, Computed) for sd in server_default)


def _server_default_is_identity(*server_default) -> bool:
    return any(isinstance(sd, schema.Identity) for sd in server_default)
    if not sqla_14:
        return False
    else:
        return any(isinstance(sd, Identity) for sd in server_default)

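These predicates let autogenerate treat Computed and Identity constructs differently from ordinary server defaults; for example:

from sqlalchemy import Computed, Identity

# whatever lands in Column.server_default is what gets passed in here
print(_server_default_is_computed(Computed("price * qty")))  # True
print(_server_default_is_identity(Identity(start=1)))        # True
print(_server_default_is_computed("now()"))                  # False
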
def _table_for_constraint(constraint: Constraint) -> Table:
@@ -266,6 +334,15 @@ def _columns_for_constraint(constraint):
    return list(constraint.columns)


def _reflect_table(inspector: Inspector, table: Table) -> None:
    if sqla_14:
        return inspector.reflect_table(table, None)
    else:
        return inspector.reflecttable(  # type: ignore[attr-defined]
            table, None
        )

def _resolve_for_variant(type_, dialect):
    if _type_has_variants(type_):
        base_type, mapping = _get_variant_mapping(type_)
@@ -274,7 +351,7 @@ def _resolve_for_variant(type_, dialect):
    return type_


if hasattr(sqltypes.TypeEngine, "_variant_mapping"):  # 2.0
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):

    def _type_has_variants(type_):
        return bool(type_._variant_mapping)

@@ -291,12 +368,7 @@ else:
        return type_.impl, type_.mapping

def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
    if TYPE_CHECKING:
        assert constraint.columns is not None
        assert constraint.elements is not None
        assert isinstance(constraint.parent, Table)

def _fk_spec(constraint):
    source_columns = [
        constraint.columns[key].name for key in constraint.column_keys
    ]

@@ -325,7 +397,7 @@ def _fk_spec(constraint: ForeignKeyConstraint) -> Any:


def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
    spec = constraint.elements[0]._get_colspec()
    spec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
    tokens = spec.split(".")
    tokens.pop(-1)  # colname
    tablekey = ".".join(tokens)

@@ -337,13 +409,13 @@ def _is_type_bound(constraint: Constraint) -> bool:
    # this deals with SQLAlchemy #3260, don't copy CHECK constraints
    # that will be generated by the type.
    # new feature added for #3260
    return constraint._type_bound
    return constraint._type_bound  # type: ignore[attr-defined]


def _find_columns(clause):
    """locate Column objects within the given expression."""

    cols: Set[ColumnElement[Any]] = set()
    cols = set()
    traverse(clause, {}, {"column": cols.add})
    return cols

@@ -430,7 +502,7 @@ class _textual_index_element(sql.ColumnElement):
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self, **kw):
    def get_children(self):
        return [self.fake_column]

@@ -452,44 +524,116 @@ def _render_literal_bindparam(
    return compiler.render_literal_bindparam(element, **kw)


def _get_index_expressions(idx):
    return list(idx.expressions)


def _get_index_column_names(idx):
    return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]


def _column_kwargs(col: Column) -> Mapping:
    if sqla_13:
        return col.kwargs
    else:
        return {}


def _get_constraint_final_name(
    constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
    if constraint.name is None:
        return None
    assert dialect is not None
    # for SQLAlchemy 1.4 we would like to have the option to expand
    # the use of "deferred" names for constraints as well as to have
    # some flexibility with "None" name and similar; make use of new
    # SQLAlchemy API to return what would be the final compiled form of
    # the name for this dialect.
    return dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    if sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:
        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls: type = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        new_name = quoted_name_cls(str(constraint.name), quote=False)
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            d = dialect.ddl_compiler(dialect, None)  # type: ignore[arg-type]
            return d._prepared_index_name(  # type: ignore[attr-defined]
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)

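The "deferred name" case arises with MetaData naming conventions, where a constraint's final name is produced at DDL compile time. A sketch of the 1.4-era resolution path, with an invented table; this should render the convention-generated name (here uq_acct_email):

from sqlalchemy import Column, Integer, MetaData, Table, UniqueConstraint
from sqlalchemy.dialects import sqlite

md = MetaData(
    naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
)
t = Table("acct", md, Column("email", Integer), UniqueConstraint("email"))
uq = next(c for c in t.constraints if isinstance(c, UniqueConstraint))

# resolve the convention-generated name to its final compiled form
print(sqlite.dialect().identifier_preparer.format_constraint(uq))
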
def _constraint_is_named(
    constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
    if constraint.name is None:
        return False
    assert dialect is not None
    name = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return name is not None
    if sqla_14:
        if constraint.name is None:
            return False
        assert dialect is not None
        name = dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
        return name is not None
    else:
        return constraint.name is not None


def _is_mariadb(mysql_dialect: Dialect) -> bool:
    if sqla_14:
        return mysql_dialect.is_mariadb  # type: ignore[attr-defined]
    else:
        return bool(
            mysql_dialect.server_version_info
            and mysql_dialect._is_mariadb  # type: ignore[attr-defined]
        )


def _mariadb_normalized_version_info(mysql_dialect):
    return mysql_dialect._mariadb_normalized_version_info


def _insert_inline(table: Union[TableClause, Table]) -> Insert:
    if sqla_14:
        return table.insert().inline()
    else:
        return table.insert(inline=True)  # type: ignore[call-arg]

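Inline inserts matter for offline ("--sql") mode, where nothing can be fetched back from the database at execution time; a brief sketch of what the shim produces, using an invented table:

from sqlalchemy import Column, Integer, MetaData, Table

t = Table("acct", MetaData(), Column("id", Integer, primary_key=True))

stmt = _insert_inline(t).values(id=1)
# renders a plain INSERT with no RETURNING or implied post-fetch
print(stmt)  # INSERT INTO acct (id) VALUES (:id)
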

if sqla_14:
    from sqlalchemy import create_mock_engine
    from sqlalchemy import select as _select
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor, **kw):  # type: ignore[misc]
        return create_engine(
            "postgresql://", strategy="mock", executor=executor
        )

    def _select(*columns, **kw) -> Select:  # type: ignore[no-redef]
        return sql.select(list(columns), **kw)  # type: ignore[call-overload]

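A mock engine is what drives offline SQL generation: every DDL element is handed to an executor callable instead of a database. A minimal sketch against SQLAlchemy 1.4+, with an invented table:

from sqlalchemy import Column, Integer, MetaData, Table, create_mock_engine

md = MetaData()
Table("acct", md, Column("id", Integer, primary_key=True))


def dump(sql_element, *args, **kw):
    # receives each DDL construct; compile it to a string
    print(sql_element.compile(dialect=engine.dialect))


engine = create_mock_engine("sqlite://", dump)
md.create_all(engine, checkfirst=False)  # prints CREATE TABLE acct (...)
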

def is_expression_index(index: Index) -> bool:
    expr: Any
    for expr in index.expressions:
        if is_expression(expr):
        while isinstance(expr, UnaryExpression):
            expr = expr.element
        if not isinstance(expr, ColumnClause) or expr.is_literal:
            return True
    return False


def is_expression(expr: Any) -> bool:
    while isinstance(expr, UnaryExpression):
        expr = expr.element
    if not isinstance(expr, ColumnClause) or expr.is_literal:
        return True
    return False

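As a quick illustration of what counts as an expression index under this check, comparing a plain column index with a functional lower() index (names invented):

from sqlalchemy import Column, Index, MetaData, String, Table, func

t = Table("acct", MetaData(), Column("email", String(100)))

plain = Index("ix_email", t.c.email)
functional = Index("ix_email_lower", func.lower(t.c.email))

print(is_expression_index(plain))       # False: a bare ColumnClause
print(is_expression_index(functional))  # True: a SQL function expression
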