@@ -9,12 +9,15 @@ from sqlalchemy.testing import uses_deprecated
from sqlalchemy.testing.config import combinations
from sqlalchemy.testing.config import fixture
from sqlalchemy.testing.config import requirements as requires
from sqlalchemy.testing.config import Variation
from sqlalchemy.testing.config import variation

from .assertions import assert_raises
from .assertions import assert_raises_message
from .assertions import emits_python_deprecation_warning
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import expect_deprecated
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_sqlalchemy_deprecated

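# Illustrative sketch (not part of the commit): with the re-exports above, a
# test module can import the assertion helpers straight from alembic.testing.
# The example test below is hypothetical.
from alembic.testing import eq_
from alembic.testing import expect_raises_message


def test_example():
    eq_(1 + 1, 2)
    with expect_raises_message(ValueError, "bad value"):
        raise ValueError("bad value")
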
@@ -8,6 +8,7 @@ from typing import Dict

from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default
from sqlalchemy.engine import URL
from sqlalchemy.testing.assertions import _expect_warnings
from sqlalchemy.testing.assertions import eq_ # noqa
from sqlalchemy.testing.assertions import is_ # noqa
@@ -17,8 +18,6 @@ from sqlalchemy.testing.assertions import is_true # noqa
from sqlalchemy.testing.assertions import ne_ # noqa
from sqlalchemy.util import decorator

from ..util import sqla_compat


def _assert_proper_exception_context(exception):
    """assert that any exception we're catching does not have a __context__
@@ -74,7 +73,9 @@ class _ErrorContainer:


@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
def _expect_raises(
    except_cls, msg=None, check_context=False, text_exact=False
):
    ec = _ErrorContainer()
    if check_context:
        are_we_already_in_a_traceback = sys.exc_info()[0]
@@ -85,7 +86,10 @@ def _expect_raises(except_cls, msg=None, check_context=False):
        ec.error = err
        success = True
        if msg is not None:
            assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
            if text_exact:
                assert str(err) == msg, f"{msg} != {err}"
            else:
                assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
        if check_context and not are_we_already_in_a_traceback:
            _assert_proper_exception_context(err)
        print(str(err).encode("utf-8"))
@@ -98,8 +102,12 @@ def expect_raises(except_cls, check_context=True):
    return _expect_raises(except_cls, check_context=check_context)


def expect_raises_message(except_cls, msg, check_context=True):
    return _expect_raises(except_cls, msg=msg, check_context=check_context)
def expect_raises_message(
    except_cls, msg, check_context=True, text_exact=False
):
    return _expect_raises(
        except_cls, msg=msg, check_context=check_context, text_exact=text_exact
    )


def eq_ignore_whitespace(a, b, msg=None):
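# Illustrative sketch (not part of the commit): with text_exact=True the error
# message is compared with == instead of re.search(), so regex metacharacters
# in the expected text do not need escaping.
from alembic.testing.assertions import expect_raises_message

with expect_raises_message(
    ValueError, "bad value (exactly this text)", text_exact=True
):
    raise ValueError("bad value (exactly this text)")
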
@@ -118,7 +126,7 @@ def _get_dialect(name):
    if name is None or name == "default":
        return default.DefaultDialect()
    else:
        d = sqla_compat._create_url(name).get_dialect()()
        d = URL.create(name).get_dialect()()

        if name == "postgresql":
            d.implicit_returning = True
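# Illustrative sketch (not part of the commit): URL.create() is public
# SQLAlchemy API, so the private sqla_compat._create_url() shim is no longer
# needed; a URL built from just a dialect name resolves the dialect class.
from sqlalchemy.engine import URL

dialect = URL.create("sqlite").get_dialect()()
print(dialect.name)  # "sqlite"
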
@@ -159,6 +167,10 @@ def emits_python_deprecation_warning(*messages):
    return decorate


def expect_deprecated(*messages, **kw):
    return _expect_warnings(DeprecationWarning, messages, **kw)


def expect_sqlalchemy_deprecated(*messages, **kw):
    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)


@@ -1,5 +1,6 @@
import importlib.machinery
import os
from pathlib import Path
import shutil
import textwrap

@@ -16,7 +17,7 @@ from ..script import ScriptDirectory

def _get_staging_directory():
    if provision.FOLLOWER_IDENT:
        return "scratch_%s" % provision.FOLLOWER_IDENT
        return f"scratch_{provision.FOLLOWER_IDENT}"
    else:
        return "scratch"

@@ -24,7 +25,7 @@ def _get_staging_directory():
def staging_env(create=True, template="generic", sourceless=False):
    cfg = _testing_config()
    if create:
        path = os.path.join(_get_staging_directory(), "scripts")
        path = _join_path(_get_staging_directory(), "scripts")
        assert not os.path.exists(path), (
            "staging directory %s already exists; poor cleanup?" % path
        )
@@ -47,7 +48,7 @@ def staging_env(create=True, template="generic", sourceless=False):
            "pep3147_everything",
        ), sourceless
        make_sourceless(
            os.path.join(path, "env.py"),
            _join_path(path, "env.py"),
            "pep3147" if "pep3147" in sourceless else "simple",
        )

@@ -63,14 +64,14 @@ def clear_staging_env():


def script_file_fixture(txt):
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    path = os.path.join(dir_, "script.py.mako")
    dir_ = _join_path(_get_staging_directory(), "scripts")
    path = _join_path(dir_, "script.py.mako")
    with open(path, "w") as f:
        f.write(txt)


def env_file_fixture(txt):
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    dir_ = _join_path(_get_staging_directory(), "scripts")
    txt = (
        """
from alembic import context
@@ -80,7 +81,7 @@ config = context.config
        + txt
    )

    path = os.path.join(dir_, "env.py")
    path = _join_path(dir_, "env.py")
    pyc_path = util.pyc_file_from_path(path)
    if pyc_path:
        os.unlink(pyc_path)
@@ -90,26 +91,26 @@ config = context.config


def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    dir_ = _join_path(_get_staging_directory(), "scripts")
    url = "sqlite:///%s/%s" % (dir_, tempname)
    if scope and util.sqla_14:
    if scope:
        options["scope"] = scope
    return testing_util.testing_engine(url=url, future=future, options=options)


def _sqlite_testing_config(sourceless=False, future=False):
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    url = "sqlite:///%s/foo.db" % dir_
    dir_ = _join_path(_get_staging_directory(), "scripts")
    url = f"sqlite:///{dir_}/foo.db"

    sqlalchemy_future = future or ("future" in config.db.__class__.__module__)

    return _write_config_file(
        """
        f"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sourceless = %s
%s
script_location = {dir_}
sqlalchemy.url = {url}
sourceless = {"true" if sourceless else "false"}
{"sqlalchemy.future = true" if sqlalchemy_future else ""}

[loggers]
keys = root,sqlalchemy
@@ -118,7 +119,7 @@ keys = root,sqlalchemy
keys = console

[logger_root]
level = WARN
level = WARNING
handlers = console
qualname =

@@ -140,29 +141,25 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
        % (
            dir_,
            url,
            "true" if sourceless else "false",
            "sqlalchemy.future = true" if sqlalchemy_future else "",
        )
    )


def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    dir_ = _join_path(_get_staging_directory(), "scripts")
    sqlalchemy_future = "future" in config.db.__class__.__module__

    url = "sqlite:///%s/foo.db" % dir_

    return _write_config_file(
        """
        f"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sqlalchemy.future = %s
sourceless = %s
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
script_location = {dir_}
sqlalchemy.url = {url}
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
sourceless = {"true" if sourceless else "false"}
path_separator = space
version_locations = %(here)s/model1/ %(here)s/model2/ %(here)s/model3/ \
{extra_version_location}

[loggers]
keys = root
@@ -171,7 +168,7 @@ keys = root
keys = console

[logger_root]
level = WARN
level = WARNING
handlers = console
qualname =

@@ -188,26 +185,24 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
        % (
            dir_,
            url,
            "true" if sqlalchemy_future else "false",
            "true" if sourceless else "false",
            extra_version_location,
        )
    )


def _no_sql_testing_config(dialect="postgresql", directives=""):
def _no_sql_pyproject_config(dialect="postgresql", directives=""):
    """use a postgresql url with no host so that
    connections guaranteed to fail"""
    dir_ = os.path.join(_get_staging_directory(), "scripts")
    return _write_config_file(
        """
    dir_ = _join_path(_get_staging_directory(), "scripts")

    return _write_toml_config(
        f"""
[tool.alembic]
script_location ="{dir_}"
{textwrap.dedent(directives)}

""",
        f"""
[alembic]
script_location = %s
sqlalchemy.url = %s://
%s
sqlalchemy.url = {dialect}://

[loggers]
keys = root
@@ -216,7 +211,46 @@ keys = root
keys = console

[logger_root]
level = WARN
level = WARNING
handlers = console
qualname =

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatters]
keys = generic

[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S

""",
    )


def _no_sql_testing_config(dialect="postgresql", directives=""):
    """use a postgresql url with no host so that
    connections guaranteed to fail"""
    dir_ = _join_path(_get_staging_directory(), "scripts")
    return _write_config_file(
        f"""
[alembic]
script_location ={dir_}
sqlalchemy.url = {dialect}://
{directives}

[loggers]
keys = root

[handlers]
keys = console

[logger_root]
level = WARNING
handlers = console
qualname =

@@ -234,10 +268,16 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S

"""
        % (dir_, dialect, directives)
    )


def _write_toml_config(tomltext, initext):
    cfg = _write_config_file(initext)
    with open(cfg.toml_file_name, "w") as f:
        f.write(tomltext)
    return cfg


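# Illustrative sketch (not part of the commit): _write_toml_config() produces
# one testing Config backed by two files -- the ini text (logging sections)
# goes to cfg.config_file_name and the [tool.alembic] table goes to
# cfg.toml_file_name. A rough standalone equivalent with hypothetical content:
import tempfile
from pathlib import Path

staging = Path(tempfile.mkdtemp())
(staging / "test_alembic.ini").write_text("[alembic]\n")
(staging / "pyproject.toml").write_text(
    '[tool.alembic]\nscript_location = "scratch/scripts"\n'
)
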
def _write_config_file(text):
    cfg = _testing_config()
    with open(cfg.config_file_name, "w") as f:
@@ -250,7 +290,10 @@ def _testing_config():

    if not os.access(_get_staging_directory(), os.F_OK):
        os.mkdir(_get_staging_directory())
    return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
    return Config(
        _join_path(_get_staging_directory(), "test_alembic.ini"),
        _join_path(_get_staging_directory(), "pyproject.toml"),
    )


def write_script(
@@ -270,9 +313,7 @@ def write_script(
    script = Script._from_path(scriptdir, path)
    old = scriptdir.revision_map.get_revision(script.revision)
    if old.down_revision != script.down_revision:
        raise Exception(
            "Can't change down_revision " "on a refresh operation."
        )
        raise Exception("Can't change down_revision on a refresh operation.")
    scriptdir.revision_map.add_revision(script, _replace=True)

    if sourceless:
@@ -312,9 +353,9 @@ def three_rev_fixture(cfg):
    write_script(
        script,
        a,
        """\
        f"""\
"Rev A"
revision = '%s'
revision = '{a}'
down_revision = None

from alembic import op
@@ -327,8 +368,7 @@ def upgrade():
def downgrade():
    op.execute("DROP STEP 1")

"""
        % a,
""",
    )

    script.generate_revision(b, "revision b", refresh=True, head=a)
@@ -358,10 +398,10 @@ def downgrade():
    write_script(
        script,
        c,
        """\
        f"""\
"Rev C"
revision = '%s'
down_revision = '%s'
revision = '{c}'
down_revision = '{b}'

from alembic import op

@@ -373,8 +413,7 @@ def upgrade():
def downgrade():
    op.execute("DROP STEP 3")

"""
        % (c, b),
""",
    )
    return a, b, c

@@ -396,10 +435,10 @@ def multi_heads_fixture(cfg, a, b, c):
    write_script(
        script,
        d,
        """\
        f"""\
"Rev D"
revision = '%s'
down_revision = '%s'
revision = '{d}'
down_revision = '{b}'

from alembic import op

@@ -411,8 +450,7 @@ def upgrade():
def downgrade():
    op.execute("DROP STEP 4")

"""
        % (d, b),
""",
    )

    script.generate_revision(
@@ -421,10 +459,10 @@ def downgrade():
    write_script(
        script,
        e,
        """\
        f"""\
"Rev E"
revision = '%s'
down_revision = '%s'
revision = '{e}'
down_revision = '{d}'

from alembic import op

@@ -436,8 +474,7 @@ def upgrade():
def downgrade():
    op.execute("DROP STEP 5")

"""
        % (e, d),
""",
    )

    script.generate_revision(
@@ -446,10 +483,10 @@ def downgrade():
    write_script(
        script,
        f,
        """\
        f"""\
"Rev F"
revision = '%s'
down_revision = '%s'
revision = '{f}'
down_revision = '{b}'

from alembic import op

@@ -461,8 +498,7 @@ def upgrade():
def downgrade():
    op.execute("DROP STEP 6")

"""
        % (f, b),
""",
    )

    return d, e, f
@@ -471,25 +507,25 @@ def downgrade():
def _multidb_testing_config(engines):
    """alembic.ini fixture to work exactly with the 'multidb' template"""

    dir_ = os.path.join(_get_staging_directory(), "scripts")
    dir_ = _join_path(_get_staging_directory(), "scripts")

    sqlalchemy_future = "future" in config.db.__class__.__module__

    databases = ", ".join(engines.keys())
    engines = "\n\n".join(
        "[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
        f"[{key}]\nsqlalchemy.url = {value.url}"
        for key, value in engines.items()
    )

    return _write_config_file(
        """
        f"""
[alembic]
script_location = %s
script_location = {dir_}
sourceless = false
sqlalchemy.future = %s
databases = %s
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
databases = {databases}

%s
{engines}
[loggers]
keys = root

@@ -497,7 +533,7 @@ keys = root
keys = console

[logger_root]
level = WARN
level = WARNING
handlers = console
qualname =

@@ -514,5 +550,8 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
        % (dir_, "true" if sqlalchemy_future else "false", databases, engines)
    )


def _join_path(base: str, *more: str):
    return str(Path(base).joinpath(*more).as_posix())

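# Illustrative sketch (not part of the commit): _join_path() joins segments
# and always returns a POSIX-style string, so paths written into the
# generated config files use forward slashes on Windows as well.
from pathlib import Path


def _join_path(base: str, *more: str):
    return str(Path(base).joinpath(*more).as_posix())


print(_join_path("scratch", "scripts", "env.py"))  # scratch/scripts/env.py
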
@@ -3,11 +3,14 @@ from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import os
import re
import shutil
from typing import Any
from typing import Dict

from sqlalchemy import Column
from sqlalchemy import create_mock_engine
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
@@ -17,20 +20,19 @@ from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import FutureEngineMixin
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase

import alembic
from .assertions import _get_dialect
from .env import _get_staging_directory
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2


testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])

@@ -38,6 +40,31 @@ testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
    is_sqlalchemy_future = sqla_2

    @testing.fixture()
    def clear_staging_dir(self):
        yield
        location = _get_staging_directory()
        for filename in os.listdir(location):
            file_path = os.path.join(location, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)

    @contextmanager
    def pushd(self, dirname):
        current_dir = os.getcwd()
        try:
            os.chdir(dirname)
            yield
        finally:
            os.chdir(current_dir)

    @testing.fixture()
    def pop_alembic_config_env(self):
        yield
        os.environ.pop("ALEMBIC_CONFIG", None)

    @testing.fixture()
    def ops_context(self, migration_context):
        with migration_context.begin_transaction(_per_migration=True):
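# Illustrative sketch (not part of the commit): pushd() temporarily changes
# the working directory and always restores it, while clear_staging_dir and
# pop_alembic_config_env are teardown-only fixtures (they yield first and
# clean up afterwards). The same pattern, standalone:
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def pushd(dirname):
    current_dir = os.getcwd()
    try:
        os.chdir(dirname)
        yield
    finally:
        os.chdir(current_dir)


with tempfile.TemporaryDirectory() as tmp:
    with pushd(tmp):
        assert os.path.samefile(os.getcwd(), tmp)
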
@@ -49,6 +76,12 @@ class TestBase(SQLAlchemyTestBase):
            connection, opts=dict(transaction_per_migration=True)
        )

    @testing.fixture
    def as_sql_migration_context(self, connection):
        return MigrationContext.configure(
            connection, opts=dict(transaction_per_migration=True, as_sql=True)
        )

    @testing.fixture
    def connection(self):
        with config.db.connect() as conn:
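# Illustrative sketch (not part of the commit): with as_sql=True the migration
# context emits SQL to an output buffer instead of executing it ("offline"
# mode). A rough standalone equivalent built from just a dialect:
import io

from sqlalchemy.engine import URL

from alembic.runtime.migration import MigrationContext

buf = io.StringIO()
ctx = MigrationContext.configure(
    dialect=URL.create("sqlite").get_dialect()(),
    opts={"as_sql": True, "output_buffer": buf},
)
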
@@ -59,14 +92,6 @@ class TablesTest(TestBase, SQLAlchemyTablesTest):
    pass


if sqla_14:
    from sqlalchemy.testing.fixtures import FutureEngineMixin
else:

    class FutureEngineMixin: # type:ignore[no-redef]
        __requires__ = ("sqlalchemy_14",)


FutureEngineMixin.is_sqlalchemy_future = True


@@ -184,12 +209,8 @@ def op_fixture(
        opts["as_sql"] = as_sql
    if literal_binds:
        opts["literal_binds"] = literal_binds
    if not sqla_14 and dialect == "mariadb":
        ctx_dialect = _get_dialect("mysql")
        ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")

    else:
        ctx_dialect = _get_dialect(dialect)
    ctx_dialect = _get_dialect(dialect)
    if native_boolean is not None:
        ctx_dialect.supports_native_boolean = native_boolean
    # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
@@ -268,9 +289,11 @@ class AlterColRoundTripFixture:
            "x",
            column.name,
            existing_type=column.type,
            existing_server_default=column.server_default
            if column.server_default is not None
            else False,
            existing_server_default=(
                column.server_default
                if column.server_default is not None
                else False
            ),
            existing_nullable=True if column.nullable else False,
            # existing_comment=column.comment,
            nullable=to_.get("nullable", None),
@@ -298,9 +321,13 @@ class AlterColRoundTripFixture:
            new_col["type"],
            new_col.get("default", None),
            compare.get("type", old_col["type"]),
            compare["server_default"].text
            if "server_default" in compare
            else column.server_default.arg.text
            if column.server_default is not None
            else None,
            (
                compare["server_default"].text
                if "server_default" in compare
                else (
                    column.server_default.arg.text
                    if column.server_default is not None
                    else None
                )
            ),
        )

@@ -1,7 +1,6 @@
from sqlalchemy.testing.requirements import Requirements

from alembic import util
from alembic.util import sqla_compat
from ..testing import exclusions


@@ -74,13 +73,6 @@ class SuiteRequirements(Requirements):
    def reflects_fk_options(self):
        return exclusions.closed()

    @property
    def sqlalchemy_14(self):
        return exclusions.skip_if(
            lambda config: not util.sqla_14,
            "SQLAlchemy 1.4 or greater required",
        )

    @property
    def sqlalchemy_1x(self):
        return exclusions.skip_if(
@@ -95,6 +87,18 @@ class SuiteRequirements(Requirements):
            "SQLAlchemy 2.x test",
        )

    @property
    def asyncio(self):
        def go(config):
            try:
                import greenlet  # noqa: F401
            except ImportError:
                return False
            else:
                return True

        return exclusions.only_if(go)

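# Illustrative sketch (not part of the commit): the requirement just wraps a
# plain predicate in exclusions.only_if(); the same check can be run directly
# to see whether asyncio-dependent tests would be enabled.
def _greenlet_available(config=None):
    try:
        import greenlet  # noqa: F401
    except ImportError:
        return False
    else:
        return True


print("asyncio tests enabled:", _greenlet_available())
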
    @property
    def comments(self):
        return exclusions.only_if(
@@ -109,26 +113,6 @@ class SuiteRequirements(Requirements):
    def computed_columns(self):
        return exclusions.closed()

    @property
    def computed_columns_api(self):
        return exclusions.only_if(
            exclusions.BooleanPredicate(sqla_compat.has_computed)
        )

    @property
    def computed_reflects_normally(self):
        return exclusions.only_if(
            exclusions.BooleanPredicate(sqla_compat.has_computed_reflection)
        )

    @property
    def computed_reflects_as_server_default(self):
        return exclusions.closed()

    @property
    def computed_doesnt_reflect_as_server_default(self):
        return exclusions.closed()

    @property
    def autoincrement_on_composite_pk(self):
        return exclusions.closed()
@@ -190,9 +174,3 @@ class SuiteRequirements(Requirements):
    @property
    def identity_columns_alter(self):
        return exclusions.closed()

    @property
    def identity_columns_api(self):
        return exclusions.only_if(
            exclusions.BooleanPredicate(sqla_compat.has_identity)
        )

@@ -1,6 +1,7 @@
from itertools import zip_longest

from sqlalchemy import schema
from sqlalchemy.sql.elements import ClauseList


class CompareTable:
@@ -60,6 +61,14 @@ class CompareIndex:
    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        expr = ClauseList(*self.index.expressions)
        try:
            expr_str = expr.compile().string
        except Exception:
            expr_str = str(expr)
        return f"<CompareIndex {self.index.name}({expr_str})>"

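# Illustrative sketch (not part of the commit): the new __repr__ compiles the
# index expressions so that assertion failures show something readable.
from sqlalchemy import Column, Index, Integer, MetaData, Table
from sqlalchemy.sql.elements import ClauseList

m = MetaData()
t = Table("x", m, Column("q", Integer))
ix = Index("ix_x_q", t.c.q)
print(ClauseList(*ix.expressions).compile().string)  # "x.q"
# CompareIndex(ix) would then repr roughly as: <CompareIndex ix_x_q(x.q)>
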
class CompareCheckConstraint:
    def __init__(self, constraint):

@@ -14,6 +14,7 @@ from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
@@ -149,6 +150,118 @@ class ModelOne:
        return m


class NamingConvModel:
    __requires__ = ("unique_constraint_reflection",)
    configure_opts = {"conv_all_constraint_names": True}
    naming_convention = {
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(constraint_name)s",
        "ck": "ck_%(table_name)s_%(constraint_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    }

    @classmethod
    def _get_db_schema(cls):
        # database side - assume all constraints have a name that
        # we would assume here is a "db generated" name. need to make
        # sure these all render with op.f().
        m = MetaData()
        Table(
            "x1",
            m,
            Column("q", Integer),
            Index("db_x1_index_q", "q"),
            PrimaryKeyConstraint("q", name="db_x1_primary_q"),
        )
        Table(
            "x2",
            m,
            Column("q", Integer),
            Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")),
            CheckConstraint("q > 5", name="db_x2_check_q"),
        )
        Table(
            "x3",
            m,
            Column("q", Integer),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("q", name="db_x3_unique_q"),
        )
        Table(
            "x4",
            m,
            Column("q", Integer),
            PrimaryKeyConstraint("q", name="db_x4_primary_q"),
        )
        Table(
            "x5",
            m,
            Column("q", Integer),
            Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")),
            Column("r", Integer),
            Column("s", Integer),
            PrimaryKeyConstraint("q", name="db_x5_primary_q"),
            UniqueConstraint("r", name="db_x5_unique_r"),
            CheckConstraint("s > 5", name="db_x5_check_s"),
        )
        # SQLite and it's "no names needed" thing. bleh.
        # we can't have a name for these so you'll see "None" for the name.
        Table(
            "unnamed_sqlite",
            m,
            Column("q", Integer),
            Column("r", Integer),
            PrimaryKeyConstraint("q"),
            UniqueConstraint("r"),
        )
        return m

    @classmethod
    def _get_model_schema(cls):
        from sqlalchemy.sql.naming import conv

        m = MetaData(naming_convention=cls.naming_convention)
        Table(
            "x1", m, Column("q", Integer, primary_key=True), Index(None, "q")
        )
        Table(
            "x2",
            m,
            Column("q", Integer),
            Column("p", ForeignKey("x1.q")),
            CheckConstraint("q > 5", name="token_x2check1"),
        )
        Table(
            "x3",
            m,
            Column("q", Integer),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("r", name="token_x3r"),
            UniqueConstraint("s", name=conv("userdef_x3_unique_s")),
        )
        Table(
            "x4",
            m,
            Column("q", Integer, primary_key=True),
            Index("userdef_x4_idx_q", "q"),
        )
        Table(
            "x6",
            m,
            Column("q", Integer, primary_key=True),
            Column("p", ForeignKey("x4.q")),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("r", name="token_x6r"),
            CheckConstraint("s > 5", "token_x6check1"),
            CheckConstraint("s < 20", conv("userdef_x6_check_s")),
        )
        return m

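# Illustrative sketch (not part of the commit): with the naming convention
# above, unnamed constraints get deterministic names and "token" names are
# expanded through the convention, e.g. roughly:
from sqlalchemy import Column, Integer, MetaData, Table, UniqueConstraint
from sqlalchemy.schema import CreateTable

convention = {
    "uq": "uq_%(table_name)s_%(constraint_name)s",
    "pk": "pk_%(table_name)s",
}
m = MetaData(naming_convention=convention)
t = Table(
    "x3",
    m,
    Column("q", Integer, primary_key=True),
    Column("r", Integer),
    UniqueConstraint("r", name="token_x3r"),
)
print(CreateTable(t))
# ... CONSTRAINT pk_x3 PRIMARY KEY (q), CONSTRAINT uq_x3_token_x3r UNIQUE (r)
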
class _ComparesFKs:
    def _assert_fk_diff(
        self,

@@ -6,9 +6,7 @@ from sqlalchemy import Table

from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import exclusions
from ...testing import is_
from ...testing import is_true
from ...testing import mock
@@ -63,18 +61,8 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
        c = diffs[0][3]
        eq_(c.name, "foo")

        if config.requirements.computed_reflects_normally.enabled:
            is_true(isinstance(c.computed, sa.Computed))
        else:
            is_(c.computed, None)

        if config.requirements.computed_reflects_as_server_default.enabled:
            is_true(isinstance(c.server_default, sa.DefaultClause))
            eq_(str(c.server_default.arg.text), "5")
        elif config.requirements.computed_reflects_normally.enabled:
            is_true(isinstance(c.computed, sa.Computed))
        else:
            is_(c.computed, None)
        is_true(isinstance(c.computed, sa.Computed))
        is_true(isinstance(c.server_default, sa.Computed))

    @testing.combinations(
        lambda: (None, sa.Computed("bar*5")),
@@ -85,7 +73,6 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
        ),
        lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")),
    )
    @config.requirements.computed_reflects_normally
    def test_cant_change_computed_warning(self, test_case):
        arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
        m1 = MetaData()
@@ -124,10 +111,7 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
        lambda: (None, None),
        lambda: (sa.Computed("5"), sa.Computed("5")),
        lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")),
        (
            lambda: (sa.Computed("bar*5"), None),
            config.requirements.computed_doesnt_reflect_as_server_default,
        ),
        lambda: (sa.Computed("bar*5"), sa.Computed("bar * \r\n\t5")),
    )
    def test_computed_unchanged(self, test_case):
        arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
@@ -158,46 +142,3 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
            eq_(mock_warn.mock_calls, [])

        eq_(list(diffs), [])

    @config.requirements.computed_reflects_as_server_default
    def test_remove_computed_default_on_computed(self):
        """Asserts the current behavior which is that on PG and Oracle,
        the GENERATED ALWAYS AS is reflected as a server default which we can't
        tell is actually "computed", so these come out as a modification to
        the server default.

        """
        m1 = MetaData()
        m2 = MetaData()

        Table(
            "user",
            m1,
            Column("id", Integer, primary_key=True),
            Column("bar", Integer),
            Column("foo", Integer, sa.Computed("bar + 42")),
        )

        Table(
            "user",
            m2,
            Column("id", Integer, primary_key=True),
            Column("bar", Integer),
            Column("foo", Integer),
        )

        diffs = self._fixture(m1, m2)

        eq_(diffs[0][0][0], "modify_default")
        eq_(diffs[0][0][2], "user")
        eq_(diffs[0][0][3], "foo")
        old = diffs[0][0][-2]
        new = diffs[0][0][-1]

        is_(new, None)
        is_true(isinstance(old, sa.DefaultClause))

        if exclusions.against(config, "postgresql"):
            eq_(str(old.arg.text), "(bar + 42)")
        elif exclusions.against(config, "oracle"):
            eq_(str(old.arg.text), '"BAR"+42')

@@ -24,9 +24,9 @@ class MigrationTransactionTest(TestBase):
            self.context = MigrationContext.configure(
                dialect=conn.dialect, opts=opts
            )
            self.context.output_buffer = (
                self.context.impl.output_buffer
            ) = io.StringIO()
            self.context.output_buffer = self.context.impl.output_buffer = (
                io.StringIO()
            )
        else:
            self.context = MigrationContext.configure(
                connection=conn, opts=opts

@@ -10,8 +10,6 @@ import warnings

from sqlalchemy import exc as sa_exc

from ..util import sqla_14


def setup_filters():
    """Set global warning behavior for the test suite."""
@@ -23,13 +21,6 @@ def setup_filters():

    # some selected deprecations...
    warnings.filterwarnings("error", category=DeprecationWarning)
    if not sqla_14:
        # 1.3 uses pkg_resources in PluginLoader
        warnings.filterwarnings(
            "ignore",
            "pkg_resources is deprecated as an API",
            DeprecationWarning,
        )
    try:
        import pytest
    except ImportError: