This commit is contained in:
@@ -3,11 +3,14 @@ from __future__ import annotations
|
||||
import configparser
|
||||
from contextlib import contextmanager
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import create_mock_engine
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import String
|
||||
@@ -17,20 +20,19 @@ from sqlalchemy import text
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy.testing import mock
|
||||
from sqlalchemy.testing.assertions import eq_
|
||||
from sqlalchemy.testing.fixtures import FutureEngineMixin
|
||||
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
|
||||
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
|
||||
|
||||
import alembic
|
||||
from .assertions import _get_dialect
|
||||
from .env import _get_staging_directory
|
||||
from ..environment import EnvironmentContext
|
||||
from ..migration import MigrationContext
|
||||
from ..operations import Operations
|
||||
from ..util import sqla_compat
|
||||
from ..util.sqla_compat import create_mock_engine
|
||||
from ..util.sqla_compat import sqla_14
|
||||
from ..util.sqla_compat import sqla_2
|
||||
|
||||
|
||||
testing_config = configparser.ConfigParser()
|
||||
testing_config.read(["test.cfg"])
|
||||
|
||||
@@ -38,6 +40,31 @@ testing_config.read(["test.cfg"])
|
||||
class TestBase(SQLAlchemyTestBase):
|
||||
is_sqlalchemy_future = sqla_2
|
||||
|
||||
@testing.fixture()
|
||||
def clear_staging_dir(self):
|
||||
yield
|
||||
location = _get_staging_directory()
|
||||
for filename in os.listdir(location):
|
||||
file_path = os.path.join(location, filename)
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
|
||||
@contextmanager
|
||||
def pushd(self, dirname):
|
||||
current_dir = os.getcwd()
|
||||
try:
|
||||
os.chdir(dirname)
|
||||
yield
|
||||
finally:
|
||||
os.chdir(current_dir)
|
||||
|
||||
@testing.fixture()
|
||||
def pop_alembic_config_env(self):
|
||||
yield
|
||||
os.environ.pop("ALEMBIC_CONFIG", None)
|
||||
|
||||
@testing.fixture()
|
||||
def ops_context(self, migration_context):
|
||||
with migration_context.begin_transaction(_per_migration=True):
|
||||
@@ -49,6 +76,12 @@ class TestBase(SQLAlchemyTestBase):
|
||||
connection, opts=dict(transaction_per_migration=True)
|
||||
)
|
||||
|
||||
@testing.fixture
|
||||
def as_sql_migration_context(self, connection):
|
||||
return MigrationContext.configure(
|
||||
connection, opts=dict(transaction_per_migration=True, as_sql=True)
|
||||
)
|
||||
|
||||
@testing.fixture
|
||||
def connection(self):
|
||||
with config.db.connect() as conn:
|
||||
@@ -59,14 +92,6 @@ class TablesTest(TestBase, SQLAlchemyTablesTest):
|
||||
pass
|
||||
|
||||
|
||||
if sqla_14:
|
||||
from sqlalchemy.testing.fixtures import FutureEngineMixin
|
||||
else:
|
||||
|
||||
class FutureEngineMixin: # type:ignore[no-redef]
|
||||
__requires__ = ("sqlalchemy_14",)
|
||||
|
||||
|
||||
FutureEngineMixin.is_sqlalchemy_future = True
|
||||
|
||||
|
||||
@@ -184,12 +209,8 @@ def op_fixture(
|
||||
opts["as_sql"] = as_sql
|
||||
if literal_binds:
|
||||
opts["literal_binds"] = literal_binds
|
||||
if not sqla_14 and dialect == "mariadb":
|
||||
ctx_dialect = _get_dialect("mysql")
|
||||
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")
|
||||
|
||||
else:
|
||||
ctx_dialect = _get_dialect(dialect)
|
||||
ctx_dialect = _get_dialect(dialect)
|
||||
if native_boolean is not None:
|
||||
ctx_dialect.supports_native_boolean = native_boolean
|
||||
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
|
||||
@@ -268,9 +289,11 @@ class AlterColRoundTripFixture:
|
||||
"x",
|
||||
column.name,
|
||||
existing_type=column.type,
|
||||
existing_server_default=column.server_default
|
||||
if column.server_default is not None
|
||||
else False,
|
||||
existing_server_default=(
|
||||
column.server_default
|
||||
if column.server_default is not None
|
||||
else False
|
||||
),
|
||||
existing_nullable=True if column.nullable else False,
|
||||
# existing_comment=column.comment,
|
||||
nullable=to_.get("nullable", None),
|
||||
@@ -298,9 +321,13 @@ class AlterColRoundTripFixture:
|
||||
new_col["type"],
|
||||
new_col.get("default", None),
|
||||
compare.get("type", old_col["type"]),
|
||||
compare["server_default"].text
|
||||
if "server_default" in compare
|
||||
else column.server_default.arg.text
|
||||
if column.server_default is not None
|
||||
else None,
|
||||
(
|
||||
compare["server_default"].text
|
||||
if "server_default" in compare
|
||||
else (
|
||||
column.server_default.arg.text
|
||||
if column.server_default is not None
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user