This commit is contained in:
@@ -3,13 +3,13 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import ContextManager
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import MutableMapping
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Sequence
|
||||
from typing import TextIO
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -17,7 +17,6 @@ from typing import Union
|
||||
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.schema import FetchedValue
|
||||
from typing_extensions import ContextManager
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .migration import _ProxyTransaction
|
||||
@@ -108,6 +107,7 @@ CompareType = Callable[
|
||||
|
||||
|
||||
class EnvironmentContext(util.ModuleClsProxy):
|
||||
|
||||
"""A configurational facade made available in an ``env.py`` script.
|
||||
|
||||
The :class:`.EnvironmentContext` acts as a *facade* to the more
|
||||
@@ -227,9 +227,9 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
has been configured.
|
||||
|
||||
"""
|
||||
return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501
|
||||
return self.context_opts.get("as_sql", False)
|
||||
|
||||
def is_transactional_ddl(self) -> bool:
|
||||
def is_transactional_ddl(self):
|
||||
"""Return True if the context is configured to expect a
|
||||
transactional DDL capable backend.
|
||||
|
||||
@@ -341,17 +341,18 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
return self.context_opts.get("tag", None)
|
||||
|
||||
@overload
|
||||
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ...
|
||||
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_x_argument(
|
||||
self, as_dictionary: Literal[True]
|
||||
) -> Dict[str, str]: ...
|
||||
def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_x_argument(
|
||||
self, as_dictionary: bool = ...
|
||||
) -> Union[List[str], Dict[str, str]]: ...
|
||||
) -> Union[List[str], Dict[str, str]]:
|
||||
...
|
||||
|
||||
def get_x_argument(
|
||||
self, as_dictionary: bool = False
|
||||
@@ -365,11 +366,7 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
The return value is a list, returned directly from the ``argparse``
|
||||
structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
|
||||
are parsed using ``key=value`` format into a dictionary that is
|
||||
then returned. If there is no ``=`` in the argument, value is an empty
|
||||
string.
|
||||
|
||||
.. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when
|
||||
arguments are passed without the ``=`` symbol.
|
||||
then returned.
|
||||
|
||||
For example, to support passing a database URL on the command line,
|
||||
the standard ``env.py`` script can be modified like this::
|
||||
@@ -403,12 +400,7 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
else:
|
||||
value = []
|
||||
if as_dictionary:
|
||||
dict_value = {}
|
||||
for arg in value:
|
||||
x_key, _, x_value = arg.partition("=")
|
||||
dict_value[x_key] = x_value
|
||||
value = dict_value
|
||||
|
||||
value = dict(arg.split("=", 1) for arg in value)
|
||||
return value
|
||||
|
||||
def configure(
|
||||
@@ -424,7 +416,7 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
tag: Optional[str] = None,
|
||||
template_args: Optional[Dict[str, Any]] = None,
|
||||
render_as_batch: bool = False,
|
||||
target_metadata: Union[MetaData, Sequence[MetaData], None] = None,
|
||||
target_metadata: Optional[MetaData] = None,
|
||||
include_name: Optional[IncludeNameFn] = None,
|
||||
include_object: Optional[IncludeObjectFn] = None,
|
||||
include_schemas: bool = False,
|
||||
@@ -948,7 +940,7 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
def execute(
|
||||
self,
|
||||
sql: Union[Executable, str],
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Execute the given SQL using the current change context.
|
||||
|
||||
@@ -976,7 +968,7 @@ class EnvironmentContext(util.ModuleClsProxy):
|
||||
|
||||
def begin_transaction(
|
||||
self,
|
||||
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
|
||||
) -> Union[_ProxyTransaction, ContextManager[None]]:
|
||||
"""Return a context manager that will
|
||||
enclose an operation within a "transaction",
|
||||
as defined by the environment's offline
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
||||
# mypy: no-warn-return-any, allow-any-generics
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -11,6 +8,7 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Collection
|
||||
from typing import ContextManager
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
@@ -23,11 +21,13 @@ from typing import Union
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine import url as sqla_url
|
||||
from sqlalchemy.engine.strategies import MockEngineStrategy
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from .. import ddl
|
||||
from .. import util
|
||||
@@ -83,6 +83,7 @@ class _ProxyTransaction:
|
||||
|
||||
|
||||
class MigrationContext:
|
||||
|
||||
"""Represent the database state made available to a migration
|
||||
script.
|
||||
|
||||
@@ -175,11 +176,7 @@ class MigrationContext:
|
||||
opts["output_encoding"],
|
||||
)
|
||||
else:
|
||||
self.output_buffer = opts.get(
|
||||
"output_buffer", sys.stdout
|
||||
) # type:ignore[assignment] # noqa: E501
|
||||
|
||||
self.transactional_ddl = transactional_ddl
|
||||
self.output_buffer = opts.get("output_buffer", sys.stdout)
|
||||
|
||||
self._user_compare_type = opts.get("compare_type", True)
|
||||
self._user_compare_server_default = opts.get(
|
||||
@@ -191,6 +188,18 @@ class MigrationContext:
|
||||
self.version_table_schema = version_table_schema = opts.get(
|
||||
"version_table_schema", None
|
||||
)
|
||||
self._version = Table(
|
||||
version_table,
|
||||
MetaData(),
|
||||
Column("version_num", String(32), nullable=False),
|
||||
schema=version_table_schema,
|
||||
)
|
||||
if opts.get("version_table_pk", True):
|
||||
self._version.append_constraint(
|
||||
PrimaryKeyConstraint(
|
||||
"version_num", name="%s_pkc" % version_table
|
||||
)
|
||||
)
|
||||
|
||||
self._start_from_rev: Optional[str] = opts.get("starting_rev")
|
||||
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
|
||||
@@ -201,23 +210,14 @@ class MigrationContext:
|
||||
self.output_buffer,
|
||||
opts,
|
||||
)
|
||||
|
||||
self._version = self.impl.version_table_impl(
|
||||
version_table=version_table,
|
||||
version_table_schema=version_table_schema,
|
||||
version_table_pk=opts.get("version_table_pk", True),
|
||||
)
|
||||
|
||||
log.info("Context impl %s.", self.impl.__class__.__name__)
|
||||
if self.as_sql:
|
||||
log.info("Generating static SQL")
|
||||
log.info(
|
||||
"Will assume %s DDL.",
|
||||
(
|
||||
"transactional"
|
||||
if self.impl.transactional_ddl
|
||||
else "non-transactional"
|
||||
),
|
||||
"transactional"
|
||||
if self.impl.transactional_ddl
|
||||
else "non-transactional",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -342,9 +342,9 @@ class MigrationContext:
|
||||
# except that it will not know it's in "autocommit" and will
|
||||
# emit deprecation warnings when an autocommit action takes
|
||||
# place.
|
||||
self.connection = self.impl.connection = (
|
||||
base_connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
)
|
||||
self.connection = (
|
||||
self.impl.connection
|
||||
) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
|
||||
# sqlalchemy future mode will "autobegin" in any case, so take
|
||||
# control of that "transaction" here
|
||||
@@ -372,7 +372,7 @@ class MigrationContext:
|
||||
|
||||
def begin_transaction(
|
||||
self, _per_migration: bool = False
|
||||
) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
|
||||
) -> Union[_ProxyTransaction, ContextManager[None]]:
|
||||
"""Begin a logical transaction for migration operations.
|
||||
|
||||
This method is used within an ``env.py`` script to demarcate where
|
||||
@@ -521,7 +521,7 @@ class MigrationContext:
|
||||
start_from_rev = None
|
||||
elif start_from_rev is not None and self.script:
|
||||
start_from_rev = [
|
||||
self.script.get_revision(sfr).revision
|
||||
cast("Script", self.script.get_revision(sfr)).revision
|
||||
for sfr in util.to_list(start_from_rev)
|
||||
if sfr not in (None, "base")
|
||||
]
|
||||
@@ -536,10 +536,7 @@ class MigrationContext:
|
||||
return ()
|
||||
assert self.connection is not None
|
||||
return tuple(
|
||||
row[0]
|
||||
for row in self.connection.execute(
|
||||
select(self._version.c.version_num)
|
||||
)
|
||||
row[0] for row in self.connection.execute(self._version.select())
|
||||
)
|
||||
|
||||
def _ensure_version_table(self, purge: bool = False) -> None:
|
||||
@@ -655,7 +652,7 @@ class MigrationContext:
|
||||
def execute(
|
||||
self,
|
||||
sql: Union[Executable, str],
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Execute a SQL construct or string statement.
|
||||
|
||||
@@ -1003,11 +1000,6 @@ class MigrationStep:
|
||||
is_upgrade: bool
|
||||
migration_fn: Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def doc(self) -> Optional[str]: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.migration_fn.__name__
|
||||
@@ -1056,9 +1048,13 @@ class RevisionStep(MigrationStep):
|
||||
self.revision = revision
|
||||
self.is_upgrade = is_upgrade
|
||||
if is_upgrade:
|
||||
self.migration_fn = revision.module.upgrade
|
||||
self.migration_fn = (
|
||||
revision.module.upgrade # type:ignore[attr-defined]
|
||||
)
|
||||
else:
|
||||
self.migration_fn = revision.module.downgrade
|
||||
self.migration_fn = (
|
||||
revision.module.downgrade # type:ignore[attr-defined]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "RevisionStep(%r, is_upgrade=%r)" % (
|
||||
@@ -1074,7 +1070,7 @@ class RevisionStep(MigrationStep):
|
||||
)
|
||||
|
||||
@property
|
||||
def doc(self) -> Optional[str]:
|
||||
def doc(self) -> str:
|
||||
return self.revision.doc
|
||||
|
||||
@property
|
||||
@@ -1172,18 +1168,7 @@ class RevisionStep(MigrationStep):
|
||||
}
|
||||
return tuple(set(self.to_revisions).difference(ancestors))
|
||||
else:
|
||||
# for each revision we plan to return, compute its ancestors
|
||||
# (excluding self), and remove those from the final output since
|
||||
# they are already accounted for.
|
||||
ancestors = {
|
||||
r.revision
|
||||
for to_revision in self.to_revisions
|
||||
for r in self.revision_map._get_ancestor_nodes(
|
||||
self.revision_map.get_revisions(to_revision), check=False
|
||||
)
|
||||
if r.revision != to_revision
|
||||
}
|
||||
return tuple(set(self.to_revisions).difference(ancestors))
|
||||
return self.to_revisions
|
||||
|
||||
def unmerge_branch_idents(
|
||||
self, heads: Set[str]
|
||||
@@ -1298,7 +1283,7 @@ class StampStep(MigrationStep):
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, StampStep)
|
||||
and other.from_revisions == self.from_revisions
|
||||
and other.from_revisions == self.revisions
|
||||
and other.to_revisions == self.to_revisions
|
||||
and other.branch_move == self.branch_move
|
||||
and self.is_upgrade == other.is_upgrade
|
||||
|
||||
Reference in New Issue
Block a user