main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -9,15 +9,12 @@ from sqlalchemy.testing import uses_deprecated
from sqlalchemy.testing.config import combinations
from sqlalchemy.testing.config import fixture
from sqlalchemy.testing.config import requirements as requires
from sqlalchemy.testing.config import Variation
from sqlalchemy.testing.config import variation
from .assertions import assert_raises
from .assertions import assert_raises_message
from .assertions import emits_python_deprecation_warning
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import expect_deprecated
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_sqlalchemy_deprecated

View File

@@ -8,7 +8,6 @@ from typing import Dict
from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default
from sqlalchemy.engine import URL
from sqlalchemy.testing.assertions import _expect_warnings
from sqlalchemy.testing.assertions import eq_ # noqa
from sqlalchemy.testing.assertions import is_ # noqa
@@ -18,6 +17,8 @@ from sqlalchemy.testing.assertions import is_true # noqa
from sqlalchemy.testing.assertions import ne_ # noqa
from sqlalchemy.util import decorator
from ..util import sqla_compat
def _assert_proper_exception_context(exception):
"""assert that any exception we're catching does not have a __context__
@@ -73,9 +74,7 @@ class _ErrorContainer:
@contextlib.contextmanager
def _expect_raises(
except_cls, msg=None, check_context=False, text_exact=False
):
def _expect_raises(except_cls, msg=None, check_context=False):
ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
@@ -86,10 +85,7 @@ def _expect_raises(
ec.error = err
success = True
if msg is not None:
if text_exact:
assert str(err) == msg, f"{msg} != {err}"
else:
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
if check_context and not are_we_already_in_a_traceback:
_assert_proper_exception_context(err)
print(str(err).encode("utf-8"))
@@ -102,12 +98,8 @@ def expect_raises(except_cls, check_context=True):
return _expect_raises(except_cls, check_context=check_context)
def expect_raises_message(
except_cls, msg, check_context=True, text_exact=False
):
return _expect_raises(
except_cls, msg=msg, check_context=check_context, text_exact=text_exact
)
def expect_raises_message(except_cls, msg, check_context=True):
return _expect_raises(except_cls, msg=msg, check_context=check_context)
def eq_ignore_whitespace(a, b, msg=None):
@@ -126,7 +118,7 @@ def _get_dialect(name):
if name is None or name == "default":
return default.DefaultDialect()
else:
d = URL.create(name).get_dialect()()
d = sqla_compat._create_url(name).get_dialect()()
if name == "postgresql":
d.implicit_returning = True
@@ -167,10 +159,6 @@ def emits_python_deprecation_warning(*messages):
return decorate
def expect_deprecated(*messages, **kw):
return _expect_warnings(DeprecationWarning, messages, **kw)
def expect_sqlalchemy_deprecated(*messages, **kw):
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)

View File

@@ -1,6 +1,5 @@
import importlib.machinery
import os
from pathlib import Path
import shutil
import textwrap
@@ -17,7 +16,7 @@ from ..script import ScriptDirectory
def _get_staging_directory():
if provision.FOLLOWER_IDENT:
return f"scratch_{provision.FOLLOWER_IDENT}"
return "scratch_%s" % provision.FOLLOWER_IDENT
else:
return "scratch"
@@ -25,7 +24,7 @@ def _get_staging_directory():
def staging_env(create=True, template="generic", sourceless=False):
cfg = _testing_config()
if create:
path = _join_path(_get_staging_directory(), "scripts")
path = os.path.join(_get_staging_directory(), "scripts")
assert not os.path.exists(path), (
"staging directory %s already exists; poor cleanup?" % path
)
@@ -48,7 +47,7 @@ def staging_env(create=True, template="generic", sourceless=False):
"pep3147_everything",
), sourceless
make_sourceless(
_join_path(path, "env.py"),
os.path.join(path, "env.py"),
"pep3147" if "pep3147" in sourceless else "simple",
)
@@ -64,14 +63,14 @@ def clear_staging_env():
def script_file_fixture(txt):
dir_ = _join_path(_get_staging_directory(), "scripts")
path = _join_path(dir_, "script.py.mako")
dir_ = os.path.join(_get_staging_directory(), "scripts")
path = os.path.join(dir_, "script.py.mako")
with open(path, "w") as f:
f.write(txt)
def env_file_fixture(txt):
dir_ = _join_path(_get_staging_directory(), "scripts")
dir_ = os.path.join(_get_staging_directory(), "scripts")
txt = (
"""
from alembic import context
@@ -81,7 +80,7 @@ config = context.config
+ txt
)
path = _join_path(dir_, "env.py")
path = os.path.join(dir_, "env.py")
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
@@ -91,26 +90,26 @@ config = context.config
def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
dir_ = _join_path(_get_staging_directory(), "scripts")
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/%s" % (dir_, tempname)
if scope:
if scope and util.sqla_14:
options["scope"] = scope
return testing_util.testing_engine(url=url, future=future, options=options)
def _sqlite_testing_config(sourceless=False, future=False):
dir_ = _join_path(_get_staging_directory(), "scripts")
url = f"sqlite:///{dir_}/foo.db"
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
return _write_config_file(
f"""
"""
[alembic]
script_location = {dir_}
sqlalchemy.url = {url}
sourceless = {"true" if sourceless else "false"}
{"sqlalchemy.future = true" if sqlalchemy_future else ""}
script_location = %s
sqlalchemy.url = %s
sourceless = %s
%s
[loggers]
keys = root,sqlalchemy
@@ -119,7 +118,7 @@ keys = root,sqlalchemy
keys = console
[logger_root]
level = WARNING
level = WARN
handlers = console
qualname =
@@ -141,25 +140,29 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (
dir_,
url,
"true" if sourceless else "false",
"sqlalchemy.future = true" if sqlalchemy_future else "",
)
)
def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
dir_ = _join_path(_get_staging_directory(), "scripts")
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
url = "sqlite:///%s/foo.db" % dir_
return _write_config_file(
f"""
"""
[alembic]
script_location = {dir_}
sqlalchemy.url = {url}
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
sourceless = {"true" if sourceless else "false"}
path_separator = space
version_locations = %(here)s/model1/ %(here)s/model2/ %(here)s/model3/ \
{extra_version_location}
script_location = %s
sqlalchemy.url = %s
sqlalchemy.future = %s
sourceless = %s
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
[loggers]
keys = root
@@ -168,7 +171,7 @@ keys = root
keys = console
[logger_root]
level = WARNING
level = WARN
handlers = console
qualname =
@@ -185,63 +188,26 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
)
def _no_sql_pyproject_config(dialect="postgresql", directives=""):
"""use a postgresql url with no host so that
connections guaranteed to fail"""
dir_ = _join_path(_get_staging_directory(), "scripts")
return _write_toml_config(
f"""
[tool.alembic]
script_location ="{dir_}"
{textwrap.dedent(directives)}
""",
f"""
[alembic]
sqlalchemy.url = {dialect}://
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARNING
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
""",
% (
dir_,
url,
"true" if sqlalchemy_future else "false",
"true" if sourceless else "false",
extra_version_location,
)
)
def _no_sql_testing_config(dialect="postgresql", directives=""):
"""use a postgresql url with no host so that
connections guaranteed to fail"""
dir_ = _join_path(_get_staging_directory(), "scripts")
dir_ = os.path.join(_get_staging_directory(), "scripts")
return _write_config_file(
f"""
"""
[alembic]
script_location ={dir_}
sqlalchemy.url = {dialect}://
{directives}
script_location = %s
sqlalchemy.url = %s://
%s
[loggers]
keys = root
@@ -250,7 +216,7 @@ keys = root
keys = console
[logger_root]
level = WARNING
level = WARN
handlers = console
qualname =
@@ -268,16 +234,10 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, dialect, directives)
)
def _write_toml_config(tomltext, initext):
cfg = _write_config_file(initext)
with open(cfg.toml_file_name, "w") as f:
f.write(tomltext)
return cfg
def _write_config_file(text):
cfg = _testing_config()
with open(cfg.config_file_name, "w") as f:
@@ -290,10 +250,7 @@ def _testing_config():
if not os.access(_get_staging_directory(), os.F_OK):
os.mkdir(_get_staging_directory())
return Config(
_join_path(_get_staging_directory(), "test_alembic.ini"),
_join_path(_get_staging_directory(), "pyproject.toml"),
)
return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
def write_script(
@@ -313,7 +270,9 @@ def write_script(
script = Script._from_path(scriptdir, path)
old = scriptdir.revision_map.get_revision(script.revision)
if old.down_revision != script.down_revision:
raise Exception("Can't change down_revision on a refresh operation.")
raise Exception(
"Can't change down_revision " "on a refresh operation."
)
scriptdir.revision_map.add_revision(script, _replace=True)
if sourceless:
@@ -353,9 +312,9 @@ def three_rev_fixture(cfg):
write_script(
script,
a,
f"""\
"""\
"Rev A"
revision = '{a}'
revision = '%s'
down_revision = None
from alembic import op
@@ -368,7 +327,8 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 1")
""",
"""
% a,
)
script.generate_revision(b, "revision b", refresh=True, head=a)
@@ -398,10 +358,10 @@ def downgrade():
write_script(
script,
c,
f"""\
"""\
"Rev C"
revision = '{c}'
down_revision = '{b}'
revision = '%s'
down_revision = '%s'
from alembic import op
@@ -413,7 +373,8 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 3")
""",
"""
% (c, b),
)
return a, b, c
@@ -435,10 +396,10 @@ def multi_heads_fixture(cfg, a, b, c):
write_script(
script,
d,
f"""\
"""\
"Rev D"
revision = '{d}'
down_revision = '{b}'
revision = '%s'
down_revision = '%s'
from alembic import op
@@ -450,7 +411,8 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 4")
""",
"""
% (d, b),
)
script.generate_revision(
@@ -459,10 +421,10 @@ def downgrade():
write_script(
script,
e,
f"""\
"""\
"Rev E"
revision = '{e}'
down_revision = '{d}'
revision = '%s'
down_revision = '%s'
from alembic import op
@@ -474,7 +436,8 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 5")
""",
"""
% (e, d),
)
script.generate_revision(
@@ -483,10 +446,10 @@ def downgrade():
write_script(
script,
f,
f"""\
"""\
"Rev F"
revision = '{f}'
down_revision = '{b}'
revision = '%s'
down_revision = '%s'
from alembic import op
@@ -498,7 +461,8 @@ def upgrade():
def downgrade():
op.execute("DROP STEP 6")
""",
"""
% (f, b),
)
return d, e, f
@@ -507,25 +471,25 @@ def downgrade():
def _multidb_testing_config(engines):
"""alembic.ini fixture to work exactly with the 'multidb' template"""
dir_ = _join_path(_get_staging_directory(), "scripts")
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
databases = ", ".join(engines.keys())
engines = "\n\n".join(
f"[{key}]\nsqlalchemy.url = {value.url}"
"[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
for key, value in engines.items()
)
return _write_config_file(
f"""
"""
[alembic]
script_location = {dir_}
script_location = %s
sourceless = false
sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
databases = {databases}
sqlalchemy.future = %s
databases = %s
{engines}
%s
[loggers]
keys = root
@@ -533,7 +497,7 @@ keys = root
keys = console
[logger_root]
level = WARNING
level = WARN
handlers = console
qualname =
@@ -550,8 +514,5 @@ keys = generic
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, "true" if sqlalchemy_future else "false", databases, engines)
)
def _join_path(base: str, *more: str):
return str(Path(base).joinpath(*more).as_posix())

View File

@@ -3,14 +3,11 @@ from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import os
import re
import shutil
from typing import Any
from typing import Dict
from sqlalchemy import Column
from sqlalchemy import create_mock_engine
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
@@ -20,19 +17,20 @@ from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import FutureEngineMixin
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
import alembic
from .assertions import _get_dialect
from .env import _get_staging_directory
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
@@ -40,31 +38,6 @@ testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
is_sqlalchemy_future = sqla_2
@testing.fixture()
def clear_staging_dir(self):
yield
location = _get_staging_directory()
for filename in os.listdir(location):
file_path = os.path.join(location, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
@contextmanager
def pushd(self, dirname):
current_dir = os.getcwd()
try:
os.chdir(dirname)
yield
finally:
os.chdir(current_dir)
@testing.fixture()
def pop_alembic_config_env(self):
yield
os.environ.pop("ALEMBIC_CONFIG", None)
@testing.fixture()
def ops_context(self, migration_context):
with migration_context.begin_transaction(_per_migration=True):
@@ -76,12 +49,6 @@ class TestBase(SQLAlchemyTestBase):
connection, opts=dict(transaction_per_migration=True)
)
@testing.fixture
def as_sql_migration_context(self, connection):
return MigrationContext.configure(
connection, opts=dict(transaction_per_migration=True, as_sql=True)
)
@testing.fixture
def connection(self):
with config.db.connect() as conn:
@@ -92,6 +59,14 @@ class TablesTest(TestBase, SQLAlchemyTablesTest):
pass
if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
FutureEngineMixin.is_sqlalchemy_future = True
@@ -209,8 +184,12 @@ def op_fixture(
opts["as_sql"] = as_sql
if literal_binds:
opts["literal_binds"] = literal_binds
if not sqla_14 and dialect == "mariadb":
ctx_dialect = _get_dialect("mysql")
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")
ctx_dialect = _get_dialect(dialect)
else:
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
@@ -289,11 +268,9 @@ class AlterColRoundTripFixture:
"x",
column.name,
existing_type=column.type,
existing_server_default=(
column.server_default
if column.server_default is not None
else False
),
existing_server_default=column.server_default
if column.server_default is not None
else False,
existing_nullable=True if column.nullable else False,
# existing_comment=column.comment,
nullable=to_.get("nullable", None),
@@ -321,13 +298,9 @@ class AlterColRoundTripFixture:
new_col["type"],
new_col.get("default", None),
compare.get("type", old_col["type"]),
(
compare["server_default"].text
if "server_default" in compare
else (
column.server_default.arg.text
if column.server_default is not None
else None
)
),
compare["server_default"].text
if "server_default" in compare
else column.server_default.arg.text
if column.server_default is not None
else None,
)

View File

@@ -1,6 +1,7 @@
from sqlalchemy.testing.requirements import Requirements
from alembic import util
from alembic.util import sqla_compat
from ..testing import exclusions
@@ -73,6 +74,13 @@ class SuiteRequirements(Requirements):
def reflects_fk_options(self):
return exclusions.closed()
@property
def sqlalchemy_14(self):
return exclusions.skip_if(
lambda config: not util.sqla_14,
"SQLAlchemy 1.4 or greater required",
)
@property
def sqlalchemy_1x(self):
return exclusions.skip_if(
@@ -87,18 +95,6 @@ class SuiteRequirements(Requirements):
"SQLAlchemy 2.x test",
)
@property
def asyncio(self):
def go(config):
try:
import greenlet # noqa: F401
except ImportError:
return False
else:
return True
return exclusions.only_if(go)
@property
def comments(self):
return exclusions.only_if(
@@ -113,6 +109,26 @@ class SuiteRequirements(Requirements):
def computed_columns(self):
return exclusions.closed()
@property
def computed_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed)
)
@property
def computed_reflects_normally(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed_reflection)
)
@property
def computed_reflects_as_server_default(self):
return exclusions.closed()
@property
def computed_doesnt_reflect_as_server_default(self):
return exclusions.closed()
@property
def autoincrement_on_composite_pk(self):
return exclusions.closed()
@@ -174,3 +190,9 @@ class SuiteRequirements(Requirements):
@property
def identity_columns_alter(self):
return exclusions.closed()
@property
def identity_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_identity)
)

View File

@@ -1,7 +1,6 @@
from itertools import zip_longest
from sqlalchemy import schema
from sqlalchemy.sql.elements import ClauseList
class CompareTable:
@@ -61,14 +60,6 @@ class CompareIndex:
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
expr = ClauseList(*self.index.expressions)
try:
expr_str = expr.compile().string
except Exception:
expr_str = str(expr)
return f"<CompareIndex {self.index.name}({expr_str})>"
class CompareCheckConstraint:
def __init__(self, constraint):

View File

@@ -14,7 +14,6 @@ from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
@@ -150,118 +149,6 @@ class ModelOne:
return m
class NamingConvModel:
__requires__ = ("unique_constraint_reflection",)
configure_opts = {"conv_all_constraint_names": True}
naming_convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(constraint_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
@classmethod
def _get_db_schema(cls):
# database side - assume all constraints have a name that
# we would assume here is a "db generated" name. need to make
# sure these all render with op.f().
m = MetaData()
Table(
"x1",
m,
Column("q", Integer),
Index("db_x1_index_q", "q"),
PrimaryKeyConstraint("q", name="db_x1_primary_q"),
)
Table(
"x2",
m,
Column("q", Integer),
Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")),
CheckConstraint("q > 5", name="db_x2_check_q"),
)
Table(
"x3",
m,
Column("q", Integer),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("q", name="db_x3_unique_q"),
)
Table(
"x4",
m,
Column("q", Integer),
PrimaryKeyConstraint("q", name="db_x4_primary_q"),
)
Table(
"x5",
m,
Column("q", Integer),
Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")),
Column("r", Integer),
Column("s", Integer),
PrimaryKeyConstraint("q", name="db_x5_primary_q"),
UniqueConstraint("r", name="db_x5_unique_r"),
CheckConstraint("s > 5", name="db_x5_check_s"),
)
# SQLite and it's "no names needed" thing. bleh.
# we can't have a name for these so you'll see "None" for the name.
Table(
"unnamed_sqlite",
m,
Column("q", Integer),
Column("r", Integer),
PrimaryKeyConstraint("q"),
UniqueConstraint("r"),
)
return m
@classmethod
def _get_model_schema(cls):
from sqlalchemy.sql.naming import conv
m = MetaData(naming_convention=cls.naming_convention)
Table(
"x1", m, Column("q", Integer, primary_key=True), Index(None, "q")
)
Table(
"x2",
m,
Column("q", Integer),
Column("p", ForeignKey("x1.q")),
CheckConstraint("q > 5", name="token_x2check1"),
)
Table(
"x3",
m,
Column("q", Integer),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("r", name="token_x3r"),
UniqueConstraint("s", name=conv("userdef_x3_unique_s")),
)
Table(
"x4",
m,
Column("q", Integer, primary_key=True),
Index("userdef_x4_idx_q", "q"),
)
Table(
"x6",
m,
Column("q", Integer, primary_key=True),
Column("p", ForeignKey("x4.q")),
Column("r", Integer),
Column("s", Integer),
UniqueConstraint("r", name="token_x6r"),
CheckConstraint("s > 5", "token_x6check1"),
CheckConstraint("s < 20", conv("userdef_x6_check_s")),
)
return m
class _ComparesFKs:
def _assert_fk_diff(
self,

View File

@@ -6,7 +6,9 @@ from sqlalchemy import Table
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import exclusions
from ...testing import is_
from ...testing import is_true
from ...testing import mock
@@ -61,8 +63,18 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
c = diffs[0][3]
eq_(c.name, "foo")
is_true(isinstance(c.computed, sa.Computed))
is_true(isinstance(c.server_default, sa.Computed))
if config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
if config.requirements.computed_reflects_as_server_default.enabled:
is_true(isinstance(c.server_default, sa.DefaultClause))
eq_(str(c.server_default.arg.text), "5")
elif config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
@testing.combinations(
lambda: (None, sa.Computed("bar*5")),
@@ -73,6 +85,7 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")),
)
@config.requirements.computed_reflects_normally
def test_cant_change_computed_warning(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
@@ -111,7 +124,10 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
lambda: (None, None),
lambda: (sa.Computed("5"), sa.Computed("5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * \r\n\t5")),
(
lambda: (sa.Computed("bar*5"), None),
config.requirements.computed_doesnt_reflect_as_server_default,
),
)
def test_computed_unchanged(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
@@ -142,3 +158,46 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
eq_(mock_warn.mock_calls, [])
eq_(list(diffs), [])
@config.requirements.computed_reflects_as_server_default
def test_remove_computed_default_on_computed(self):
"""Asserts the current behavior which is that on PG and Oracle,
the GENERATED ALWAYS AS is reflected as a server default which we can't
tell is actually "computed", so these come out as a modification to
the server default.
"""
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, sa.Computed("bar + 42")),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0][0], "modify_default")
eq_(diffs[0][0][2], "user")
eq_(diffs[0][0][3], "foo")
old = diffs[0][0][-2]
new = diffs[0][0][-1]
is_(new, None)
is_true(isinstance(old, sa.DefaultClause))
if exclusions.against(config, "postgresql"):
eq_(str(old.arg.text), "(bar + 42)")
elif exclusions.against(config, "oracle"):
eq_(str(old.arg.text), '"BAR"+42')

View File

@@ -24,9 +24,9 @@ class MigrationTransactionTest(TestBase):
self.context = MigrationContext.configure(
dialect=conn.dialect, opts=opts
)
self.context.output_buffer = self.context.impl.output_buffer = (
io.StringIO()
)
self.context.output_buffer = (
self.context.impl.output_buffer
) = io.StringIO()
else:
self.context = MigrationContext.configure(
connection=conn, opts=opts

View File

@@ -10,6 +10,8 @@ import warnings
from sqlalchemy import exc as sa_exc
from ..util import sqla_14
def setup_filters():
"""Set global warning behavior for the test suite."""
@@ -21,6 +23,13 @@ def setup_filters():
# some selected deprecations...
warnings.filterwarnings("error", category=DeprecationWarning)
if not sqla_14:
# 1.3 uses pkg_resources in PluginLoader
warnings.filterwarnings(
"ignore",
"pkg_resources is deprecated as an API",
DeprecationWarning,
)
try:
import pytest
except ImportError: