This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -11,6 +8,7 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Collection
|
||||
from typing import ContextManager
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
@@ -23,11 +21,13 @@ from typing import Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine import url as sqla_url
|
||||
from sqlalchemy.engine.strategies import MockEngineStrategy
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from .. import ddl
|
||||
from .. import util
|
||||
@@ -83,6 +83,7 @@ class _ProxyTransaction:
|
||||
|
||||
|
||||
class MigrationContext:
|
||||
|
||||
"""Represent the database state made available to a migration
|
||||
script.
|
||||
|
||||
@@ -175,11 +176,7 @@ class MigrationContext:
|
||||
opts["output_encoding"],
|
||||
)
|
||||
else:
|
||||
self.output_buffer = opts.get(
|
||||
"output_buffer", sys.stdout
|
||||
) # type:ignore[assignment] # noqa: E501
|
||||
|
||||
self.transactional_ddl = transactional_ddl
|
||||
self.output_buffer = opts.get("output_buffer", sys.stdout)
|
||||
|
||||
self._user_compare_type = opts.get("compare_type", True)
|
||||
self._user_compare_server_default = opts.get(
|
||||
@@ -191,6 +188,18 @@ class MigrationContext:
|
||||
self.version_table_schema = version_table_schema = opts.get(
|
||||
"version_table_schema", None
|
||||
)
|
||||
self._version = Table(
|
||||
version_table,
|
||||
MetaData(),
|
||||
Column("version_num", String(32), nullable=False),
|
||||
schema=version_table_schema,
|
||||
)
|
||||
if opts.get("version_table_pk", True):
|
||||
self._version.append_constraint(
|
||||
PrimaryKeyConstraint(
|
||||
"version_num", name="%s_pkc" % version_table
|
||||
)
|
||||
)
|
||||
|
||||
self._start_from_rev: Optional[str] = opts.get("starting_rev")
|
||||
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
|
||||
@@ -201,23 +210,14 @@ class MigrationContext:
|
||||
self.output_buffer,
|
||||
opts,
|
||||
)
|
||||
|
||||
self._version = self.impl.version_table_impl(
|
||||
version_table=version_table,
|
||||
version_table_schema=version_table_schema,
|
||||
version_table_pk=opts.get("version_table_pk", True),
|
||||
)
|
||||
|
||||
log.info("Context impl %s.", self.impl.__class__.__name__)
|
||||
if self.as_sql:
|
||||
log.info("Generating static SQL")
|
||||
log.info(
|
||||
"Will assume %s DDL.",
|
||||
(
|
||||
"transactional"
|
||||
if self.impl.transactional_ddl
|
||||
else "non-transactional"
|
||||
),
|
||||
"transactional"
|
||||
if self.impl.transactional_ddl
|
||||
else "non-transactional",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -342,9 +342,9 @@ class MigrationContext:
|
||||
# except that it will not know it's in "autocommit" and will
|
||||
# emit deprecation warnings when an autocommit action takes
|
||||
# place.
|
||||
self.connection = self.impl.connection = (
|
||||
base_connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
)
|
||||
self.connection = (
|
||||
self.impl.connection
|
||||
) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
|
||||
# sqlalchemy future mode will "autobegin" in any case, so take
|
||||
# control of that "transaction" here
|
||||
@@ -372,7 +372,7 @@ class MigrationContext:
|
||||
|
||||
def begin_transaction(
|
||||
self, _per_migration: bool = False
|
||||
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
|
||||
) -> Union[_ProxyTransaction, ContextManager[None]]:
|
||||
"""Begin a logical transaction for migration operations.
|
||||
|
||||
This method is used within an ``env.py`` script to demarcate where
|
||||
@@ -521,7 +521,7 @@ class MigrationContext:
|
||||
start_from_rev = None
|
||||
elif start_from_rev is not None and self.script:
|
||||
start_from_rev = [
|
||||
self.script.get_revision(sfr).revision
|
||||
cast("Script", self.script.get_revision(sfr)).revision
|
||||
for sfr in util.to_list(start_from_rev)
|
||||
if sfr not in (None, "base")
|
||||
]
|
||||
@@ -536,10 +536,7 @@ class MigrationContext:
|
||||
return ()
|
||||
assert self.connection is not None
|
||||
return tuple(
|
||||
row[0]
|
||||
for row in self.connection.execute(
|
||||
select(self._version.c.version_num)
|
||||
)
|
||||
row[0] for row in self.connection.execute(self._version.select())
|
||||
)
|
||||
|
||||
def _ensure_version_table(self, purge: bool = False) -> None:
|
||||
@@ -655,7 +652,7 @@ class MigrationContext:
|
||||
def execute(
|
||||
self,
|
||||
sql: Union[Executable, str],
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Execute a SQL construct or string statement.
|
||||
|
||||
@@ -1003,11 +1000,6 @@ class MigrationStep:
|
||||
is_upgrade: bool
|
||||
migration_fn: Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def doc(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.migration_fn.__name__
|
||||
@@ -1056,9 +1048,13 @@ class RevisionStep(MigrationStep):
|
||||
self.revision = revision
|
||||
self.is_upgrade = is_upgrade
|
||||
if is_upgrade:
|
||||
self.migration_fn = revision.module.upgrade
|
||||
self.migration_fn = (
|
||||
revision.module.upgrade # type:ignore[attr-defined]
|
||||
)
|
||||
else:
|
||||
self.migration_fn = revision.module.downgrade
|
||||
self.migration_fn = (
|
||||
revision.module.downgrade # type:ignore[attr-defined]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "RevisionStep(%r, is_upgrade=%r)" % (
|
||||
@@ -1074,7 +1070,7 @@ class RevisionStep(MigrationStep):
|
||||
)
|
||||
|
||||
@property
|
||||
def doc(self) -> Optional[str]:
|
||||
def doc(self) -> str:
|
||||
return self.revision.doc
|
||||
|
||||
@property
|
||||
@@ -1172,18 +1168,7 @@ class RevisionStep(MigrationStep):
|
||||
}
|
||||
return tuple(set(self.to_revisions).difference(ancestors))
|
||||
else:
|
||||
# for each revision we plan to return, compute its ancestors
|
||||
# (excluding self), and remove those from the final output since
|
||||
# they are already accounted for.
|
||||
ancestors = {
|
||||
r.revision
|
||||
for to_revision in self.to_revisions
|
||||
for r in self.revision_map._get_ancestor_nodes(
|
||||
self.revision_map.get_revisions(to_revision), check=False
|
||||
)
|
||||
if r.revision != to_revision
|
||||
}
|
||||
return tuple(set(self.to_revisions).difference(ancestors))
|
||||
return self.to_revisions
|
||||
|
||||
def unmerge_branch_idents(
|
||||
self, heads: Set[str]
|
||||
@@ -1298,7 +1283,7 @@ class StampStep(MigrationStep):
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, StampStep)
|
||||
and other.from_revisions == self.from_revisions
|
||||
and other.from_revisions == self.revisions
|
||||
and other.to_revisions == self.to_revisions
|
||||
and other.branch_move == self.branch_move
|
||||
and self.is_upgrade == other.is_upgrade
|
||||
|
||||
Reference in New Issue
Block a user