API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import os
import re
import shutil
from typing import Any
from typing import Dict
from sqlalchemy import Column
from sqlalchemy import create_mock_engine
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
@@ -17,20 +20,19 @@ from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import FutureEngineMixin
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
import alembic
from .assertions import _get_dialect
from .env import _get_staging_directory
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
@@ -38,6 +40,31 @@ testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
is_sqlalchemy_future = sqla_2
@testing.fixture()
def clear_staging_dir(self):
yield
location = _get_staging_directory()
for filename in os.listdir(location):
file_path = os.path.join(location, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
@contextmanager
def pushd(self, dirname):
current_dir = os.getcwd()
try:
os.chdir(dirname)
yield
finally:
os.chdir(current_dir)
@testing.fixture()
def pop_alembic_config_env(self):
yield
os.environ.pop("ALEMBIC_CONFIG", None)
@testing.fixture()
def ops_context(self, migration_context):
with migration_context.begin_transaction(_per_migration=True):
@@ -49,6 +76,12 @@ class TestBase(SQLAlchemyTestBase):
connection, opts=dict(transaction_per_migration=True)
)
@testing.fixture
def as_sql_migration_context(self, connection):
return MigrationContext.configure(
connection, opts=dict(transaction_per_migration=True, as_sql=True)
)
@testing.fixture
def connection(self):
with config.db.connect() as conn:
@@ -59,14 +92,6 @@ class TablesTest(TestBase, SQLAlchemyTablesTest):
pass
if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
FutureEngineMixin.is_sqlalchemy_future = True
@@ -184,12 +209,8 @@ def op_fixture(
opts["as_sql"] = as_sql
if literal_binds:
opts["literal_binds"] = literal_binds
if not sqla_14 and dialect == "mariadb":
ctx_dialect = _get_dialect("mysql")
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")
else:
ctx_dialect = _get_dialect(dialect)
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
@@ -268,9 +289,11 @@ class AlterColRoundTripFixture:
"x",
column.name,
existing_type=column.type,
existing_server_default=column.server_default
if column.server_default is not None
else False,
existing_server_default=(
column.server_default
if column.server_default is not None
else False
),
existing_nullable=True if column.nullable else False,
# existing_comment=column.comment,
nullable=to_.get("nullable", None),
@@ -298,9 +321,13 @@ class AlterColRoundTripFixture:
new_col["type"],
new_col.get("default", None),
compare.get("type", old_col["type"]),
compare["server_default"].text
if "server_default" in compare
else column.server_default.arg.text
if column.server_default is not None
else None,
(
compare["server_default"].text
if "server_default" in compare
else (
column.server_default.arg.text
if column.server_default is not None
else None
)
),
)