main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,29 +1,35 @@
from .editor import open_in_editor as open_in_editor
from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
from .exc import CommandError as CommandError
from .langhelpers import _with_legacy_names as _with_legacy_names
from .langhelpers import asbool as asbool
from .langhelpers import dedupe_tuple as dedupe_tuple
from .langhelpers import Dispatcher as Dispatcher
from .langhelpers import EMPTY_DICT as EMPTY_DICT
from .langhelpers import immutabledict as immutabledict
from .langhelpers import memoized_property as memoized_property
from .langhelpers import ModuleClsProxy as ModuleClsProxy
from .langhelpers import not_none as not_none
from .langhelpers import rev_id as rev_id
from .langhelpers import to_list as to_list
from .langhelpers import to_tuple as to_tuple
from .langhelpers import unique_list as unique_list
from .messaging import err as err
from .messaging import format_as_comma as format_as_comma
from .messaging import msg as msg
from .messaging import obfuscate_url_pw as obfuscate_url_pw
from .messaging import status as status
from .messaging import warn as warn
from .messaging import warn_deprecated as warn_deprecated
from .messaging import write_outstream as write_outstream
from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
from .pyfiles import load_python_file as load_python_file
from .pyfiles import pyc_file_from_path as pyc_file_from_path
from .pyfiles import template_to_file as template_to_file
from .sqla_compat import sqla_2 as sqla_2
from .editor import open_in_editor
from .exc import AutogenerateDiffsDetected
from .exc import CommandError
from .langhelpers import _with_legacy_names
from .langhelpers import asbool
from .langhelpers import dedupe_tuple
from .langhelpers import Dispatcher
from .langhelpers import EMPTY_DICT
from .langhelpers import immutabledict
from .langhelpers import memoized_property
from .langhelpers import ModuleClsProxy
from .langhelpers import not_none
from .langhelpers import rev_id
from .langhelpers import to_list
from .langhelpers import to_tuple
from .langhelpers import unique_list
from .messaging import err
from .messaging import format_as_comma
from .messaging import msg
from .messaging import obfuscate_url_pw
from .messaging import status
from .messaging import warn
from .messaging import write_outstream
from .pyfiles import coerce_resource_to_filename
from .pyfiles import load_python_file
from .pyfiles import pyc_file_from_path
from .pyfiles import template_to_file
from .sqla_compat import has_computed
from .sqla_compat import sqla_13
from .sqla_compat import sqla_14
from .sqla_compat import sqla_2
if not sqla_13:
raise CommandError("SQLAlchemy 1.3.0 or greater is required.")

View File

@@ -1,37 +1,22 @@
# mypy: no-warn-unused-ignores
from __future__ import annotations
from configparser import ConfigParser
import io
import os
from pathlib import Path
import sys
import typing
from typing import Any
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union
if True:
# zimports hack for too-long names
from sqlalchemy.util import ( # noqa: F401
inspect_getfullargspec as inspect_getfullargspec,
)
from sqlalchemy.util.compat import ( # noqa: F401
inspect_formatargspec as inspect_formatargspec,
)
from sqlalchemy.util import inspect_getfullargspec # noqa
from sqlalchemy.util.compat import inspect_formatargspec # noqa
is_posix = os.name == "posix"
py314 = sys.version_info >= (3, 14)
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
py39 = sys.version_info >= (3, 9)
py38 = sys.version_info >= (3, 8)
# produce a wrapper that allows encoded text to stream
@@ -43,82 +28,24 @@ class EncodedIO(io.TextIOWrapper):
if py39:
from importlib import resources as _resources
importlib_resources = _resources
from importlib import metadata as _metadata
importlib_metadata = _metadata
from importlib.metadata import EntryPoint as EntryPoint
from importlib import resources as importlib_resources
from importlib import metadata as importlib_metadata
from importlib.metadata import EntryPoint
else:
import importlib_resources # type:ignore # noqa
import importlib_metadata # type:ignore # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa
if py311:
import tomllib as tomllib
else:
import tomli as tomllib # type: ignore # noqa
if py312:
def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
return Path.walk(path)
def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
return path.relative_to(other, walk_up=walk_up)
else:
def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
for root, dirs, files in os.walk(path, topdown=top_down):
yield Path(root), dirs, files
def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
"""
Calculate the relative path of 'path' with respect to 'other',
optionally allowing 'path' to be outside the subtree of 'other'.
OK I used AI for this, sorry
"""
try:
return path.relative_to(other)
except ValueError:
if walk_up:
other_ancestors = list(other.parents) + [other]
for ancestor in other_ancestors:
try:
return path.relative_to(ancestor)
except ValueError:
continue
raise ValueError(
f"{path} is not in the same subtree as {other}"
)
else:
raise
def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
return ep.select(group=group)
return ep.select(group=group) # type: ignore
else:
return ep.get(group, ()) # type: ignore
def formatannotation_fwdref(
annotation: Any, base_module: Optional[Any] = None
) -> str:
def formatannotation_fwdref(annotation, base_module=None):
"""vendored from python 3.7"""
# copied over _formatannotation from sqlalchemy 2.0
@@ -139,7 +66,7 @@ def formatannotation_fwdref(
def read_config_parser(
file_config: ConfigParser,
file_argument: Sequence[Union[str, os.PathLike[str]]],
) -> List[str]:
) -> list[str]:
if py310:
return file_config.read(file_argument, encoding="locale")
else:

View File

@@ -1,25 +1,6 @@
from __future__ import annotations
from typing import Any
from typing import List
from typing import Tuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from alembic.autogenerate import RevisionContext
class CommandError(Exception):
pass
class AutogenerateDiffsDetected(CommandError):
def __init__(
self,
message: str,
revision_context: RevisionContext,
diffs: List[Tuple[Any, ...]],
) -> None:
super().__init__(message)
self.revision_context = revision_context
self.diffs = diffs
pass

View File

@@ -5,46 +5,33 @@ from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import uuid
import warnings
from sqlalchemy.util import asbool as asbool # noqa: F401
from sqlalchemy.util import immutabledict as immutabledict # noqa: F401
from sqlalchemy.util import to_list as to_list # noqa: F401
from sqlalchemy.util import unique_list as unique_list
from sqlalchemy.util import asbool # noqa
from sqlalchemy.util import immutabledict # noqa
from sqlalchemy.util import memoized_property # noqa
from sqlalchemy.util import to_list # noqa
from sqlalchemy.util import unique_list # noqa
from .compat import inspect_getfullargspec
if True:
# zimports workaround :(
from sqlalchemy.util import ( # noqa: F401
memoized_property as memoized_property,
)
EMPTY_DICT: Mapping[Any, Any] = immutabledict()
_T = TypeVar("_T", bound=Any)
_C = TypeVar("_C", bound=Callable[..., Any])
_T = TypeVar("_T")
class _ModuleClsMeta(type):
def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
def __setattr__(cls, key: str, value: Callable) -> None:
super().__setattr__(key, value)
cls._update_module_proxies(key) # type: ignore
@@ -58,13 +45,9 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""
_setups: Dict[
Type[Any],
Tuple[
Set[str],
List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
],
] = collections.defaultdict(lambda: (set(), []))
_setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
lambda: (set(), [])
)
@classmethod
def _update_module_proxies(cls, name: str) -> None:
@@ -87,33 +70,18 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
del globals_[attr_name]
@classmethod
def create_module_class_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> None:
def create_module_class_proxy(cls, globals_, locals_):
attr_names, modules = cls._setups[cls]
modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)
@classmethod
def _setup_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
def _setup_proxy(cls, globals_, locals_, attr_names):
for methname in dir(cls):
cls._add_proxied_attribute(methname, globals_, locals_, attr_names)
@classmethod
def _add_proxied_attribute(
cls,
methname: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
@@ -124,15 +92,10 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
attr_names.add(methname)
@classmethod
def _create_method_proxy(
cls,
name: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> Callable[..., Any]:
def _create_method_proxy(cls, name, globals_, locals_):
fn = getattr(cls, name)
def _name_error(name: str, from_: Exception) -> NoReturn:
def _name_error(name, from_):
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
@@ -156,9 +119,7 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
translations,
)
def translate(
fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
) -> Any:
def translate(fn_name, spec, translations, args, kw):
return_kw = {}
return_args = []
@@ -215,15 +176,15 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
"doc": fn.__doc__,
}
)
lcl: MutableMapping[str, Any] = {}
lcl = {}
exec(func_text, cast("Dict[str, Any]", globals_), lcl)
return cast("Callable[..., Any]", lcl[name])
exec(func_text, globals_, lcl)
return lcl[name]
def _with_legacy_names(translations: Any) -> Any:
def decorate(fn: _C) -> _C:
fn._legacy_translations = translations # type: ignore[attr-defined]
def _with_legacy_names(translations):
def decorate(fn):
fn._legacy_translations = translations
return fn
return decorate
@@ -234,22 +195,21 @@ def rev_id() -> str:
@overload
def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ...
def to_tuple(x: Any, default: tuple) -> tuple:
...
@overload
def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ...
def to_tuple(x: None, default: Optional[_T] = None) -> _T:
...
@overload
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Tuple[Any, ...]: ...
def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
...
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Optional[Tuple[Any, ...]]:
def to_tuple(x, default=None):
if x is None:
return default
elif isinstance(x, str):
@@ -266,13 +226,13 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
class Dispatcher:
def __init__(self, uselist: bool = False) -> None:
self._registry: Dict[Tuple[Any, ...], Any] = {}
self._registry: Dict[tuple, Any] = {}
self.uselist = uselist
def dispatch_for(
self, target: Any, qualifier: str = "default"
) -> Callable[[_C], _C]:
def decorate(fn: _C) -> _C:
) -> Callable:
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
else:
@@ -284,7 +244,7 @@ class Dispatcher:
def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, str):
targets: Sequence[Any] = [obj]
targets: Sequence = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
@@ -299,13 +259,11 @@ class Dispatcher:
raise ValueError("no dispatch function for object: %s" % obj)
def _fn_or_list(
self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
) -> Callable[..., Any]:
self, fn_or_list: Union[List[Callable], Callable]
) -> Callable:
if self.uselist:
def go(*arg: Any, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(fn_or_list, Sequence)
def go(*arg, **kw):
for fn in fn_or_list:
fn(*arg, **kw)

View File

@@ -5,7 +5,6 @@ from contextlib import contextmanager
import logging
import sys
import textwrap
from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Union
@@ -13,6 +12,8 @@ import warnings
from sqlalchemy.engine import url
from . import sqla_compat
log = logging.getLogger(__name__)
# disable "no handler found" errors
@@ -52,9 +53,7 @@ def write_outstream(
@contextmanager
def status(
status_msg: str, newline: bool = False, quiet: bool = False
) -> Iterator[None]:
def status(status_msg: str, newline: bool = False, quiet: bool = False):
msg(status_msg + " ...", newline, flush=True, quiet=quiet)
try:
yield
@@ -67,24 +66,21 @@ def status(
write_outstream(sys.stdout, " done\n")
def err(message: str, quiet: bool = False) -> None:
def err(message: str, quiet: bool = False):
log.error(message)
msg(f"FAILED: {message}", quiet=quiet)
sys.exit(-1)
def obfuscate_url_pw(input_url: str) -> str:
return url.make_url(input_url).render_as_string(hide_password=True)
u = url.make_url(input_url)
return sqla_compat.url_render_as_string(u, hide_password=True)
def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
def warn_deprecated(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, DeprecationWarning, stacklevel=stacklevel)
def msg(
msg: str, newline: bool = True, flush: bool = False, quiet: bool = False
) -> None:
@@ -96,17 +92,11 @@ def msg(
write_outstream(sys.stdout, "\n")
else:
# left indent output lines
indent = " "
lines = textwrap.wrap(
msg,
TERMWIDTH,
initial_indent=indent,
subsequent_indent=indent,
)
lines = textwrap.wrap(msg, TERMWIDTH)
if len(lines) > 1:
for line in lines[0:-1]:
write_outstream(sys.stdout, line, "\n")
write_outstream(sys.stdout, lines[-1], ("\n" if newline else ""))
write_outstream(sys.stdout, " ", line, "\n")
write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else ""))
if flush:
sys.stdout.flush()

View File

@@ -6,13 +6,9 @@ import importlib
import importlib.machinery
import importlib.util
import os
import pathlib
import re
import tempfile
from types import ModuleType
from typing import Any
from typing import Optional
from typing import Union
from mako import exceptions
from mako.template import Template
@@ -22,14 +18,9 @@ from .exc import CommandError
def template_to_file(
template_file: Union[str, os.PathLike[str]],
dest: Union[str, os.PathLike[str]],
output_encoding: str,
*,
append_with_newlines: bool = False,
**kw: Any,
template_file: str, dest: str, output_encoding: str, **kw
) -> None:
template = Template(filename=_preserving_path_as_str(template_file))
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
except:
@@ -45,13 +36,11 @@ def template_to_file(
"template-oriented traceback." % fname
)
else:
with open(dest, "ab" if append_with_newlines else "wb") as f:
if append_with_newlines:
f.write("\n\n".encode(output_encoding))
with open(dest, "wb") as f:
f.write(output)
def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.
@@ -59,9 +48,8 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
are interpreted as resources and coerced to a file location.
"""
# TODO: there seem to be zero tests for the package resource codepath
if not os.path.isabs(fname_or_resource) and ":" in fname_or_resource:
tokens = fname_or_resource.split(":")
if not os.path.isabs(fname) and ":" in fname:
tokens = fname.split(":")
# from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501
@@ -71,48 +59,37 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
ref = compat.importlib_resources.files(tokens[0])
for tok in tokens[1:]:
ref = ref / tok
fname_or_resource = file_manager.enter_context( # type: ignore[assignment] # noqa: E501
fname = file_manager.enter_context( # type: ignore[assignment]
compat.importlib_resources.as_file(ref)
)
return pathlib.Path(fname_or_resource)
return fname
def pyc_file_from_path(
path: Union[str, os.PathLike[str]],
) -> Optional[pathlib.Path]:
def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""
pathpath = pathlib.Path(path)
candidate = pathlib.Path(
importlib.util.cache_from_source(pathpath.as_posix())
)
if candidate.exists():
candidate = importlib.util.cache_from_source(path)
if os.path.exists(candidate):
return candidate
# even for pep3147, fall back to the old way of finding .pyc files,
# to support sourceless operation
ext = pathpath.suffix
filepath, ext = os.path.splitext(path)
for ext in importlib.machinery.BYTECODE_SUFFIXES:
if pathpath.with_suffix(ext).exists():
return pathpath.with_suffix(ext)
if os.path.exists(filepath + ext):
return filepath + ext
else:
return None
def load_python_file(
dir_: Union[str, os.PathLike[str]], filename: Union[str, os.PathLike[str]]
) -> ModuleType:
def load_python_file(dir_: str, filename: str):
"""Load a file from the given path as a Python module."""
dir_ = pathlib.Path(dir_)
filename_as_path = pathlib.Path(filename)
filename = filename_as_path.name
module_id = re.sub(r"\W", "_", filename)
path = dir_ / filename
ext = path.suffix
path = os.path.join(dir_, filename)
_, ext = os.path.splitext(filename)
if ext == ".py":
if path.exists():
if os.path.exists(path):
module = load_module_py(module_id, path)
else:
pyc_path = pyc_file_from_path(path)
@@ -122,32 +99,12 @@ def load_python_file(
module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
module = load_module_py(module_id, path)
else:
assert False
return module
def load_module_py(
module_id: str, path: Union[str, os.PathLike[str]]
) -> ModuleType:
def load_module_py(module_id: str, path: str):
spec = importlib.util.spec_from_file_location(module_id, path)
assert spec
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
def _preserving_path_as_str(path: Union[str, os.PathLike[str]]) -> str:
"""receive str/pathlike and return a string.
Does not convert an incoming string path to a Path first, to help with
unit tests that are doing string path round trips without OS-specific
processing if not necessary.
"""
if isinstance(path, str):
return path
elif isinstance(path, pathlib.PurePath):
return str(path)
else:
return str(pathlib.Path(path))

View File

@@ -1,27 +1,24 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import contextlib
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import Protocol
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
@@ -29,33 +26,31 @@ from sqlalchemy.sql import visitors
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
from sqlalchemy.sql.visitors import traverse
from typing_extensions import TypeGuard
if TYPE_CHECKING:
from sqlalchemy import ClauseElement
from sqlalchemy import Identity
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Transaction
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import TableClause
_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
class _CompilerProtocol(Protocol):
def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
@@ -66,65 +61,90 @@ def _safe_int(value: str) -> Union[int, str]:
_vers = tuple(
[_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__
if TYPE_CHECKING:
try:
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
except ImportError:
from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
def compiles(
element: Type[ClauseElement], *dialects: str
) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...
class _Unsupported:
"Placeholder for unsupported SQLAlchemy classes"
try:
from sqlalchemy import Computed
except ImportError:
if not TYPE_CHECKING:
class Computed(_Unsupported):
pass
has_computed = False
has_computed_reflection = False
else:
from sqlalchemy.ext.compiler import compiles # noqa: I100,I202
has_computed = True
has_computed_reflection = _vers >= (1, 3, 16)
try:
from sqlalchemy import Identity
except ImportError:
if not TYPE_CHECKING:
identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
class Identity(_Unsupported):
pass
has_identity = False
else:
identity_has_dialect_kwargs = issubclass(Identity, DialectKWArgs)
def _get_identity_options_dict(
identity: Union[Identity, schema.Sequence, None],
dialect_kwargs: bool = False,
) -> Dict[str, Any]:
if identity is None:
return {}
elif identity_has_dialect_kwargs:
assert hasattr(identity, "_as_dict")
as_dict = identity._as_dict()
if dialect_kwargs:
assert isinstance(identity, DialectKWArgs)
as_dict.update(identity.dialect_kwargs)
else:
as_dict = {}
if isinstance(identity, schema.Identity):
# always=None means something different than always=False
as_dict["always"] = identity.always
if identity.on_null is not None:
as_dict["on_null"] = identity.on_null
# attributes common to Identity and Sequence
attrs = (
"start",
"increment",
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
)
as_dict.update(
{
key: getattr(identity, key, None)
for key in attrs
if getattr(identity, key, None) is not None
}
)
return as_dict
def _get_identity_options_dict(
identity: Union[Identity, schema.Sequence, None],
dialect_kwargs: bool = False,
) -> Dict[str, Any]:
if identity is None:
return {}
elif identity_has_dialect_kwargs:
as_dict = identity._as_dict() # type: ignore
if dialect_kwargs:
assert isinstance(identity, DialectKWArgs)
as_dict.update(identity.dialect_kwargs)
else:
as_dict = {}
if isinstance(identity, Identity):
# always=None means something different than always=False
as_dict["always"] = identity.always
if identity.on_null is not None:
as_dict["on_null"] = identity.on_null
# attributes common to Identity and Sequence
attrs = (
"start",
"increment",
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
)
as_dict.update(
{
key: getattr(identity, key, None)
for key in attrs
if getattr(identity, key, None) is not None
}
)
return as_dict
has_identity = True
if sqla_2:
from sqlalchemy.sql.base import _NoneName
@@ -133,6 +153,7 @@ else:
_ConstraintName = Union[None, str, _NoneName]
_ConstraintNameDefined = Union[str, _NoneName]
@@ -142,11 +163,15 @@ def constraint_name_defined(
return name is _NONE_NAME or isinstance(name, (str, _NoneName))
def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
def constraint_name_string(
name: _ConstraintName,
) -> TypeGuard[str]:
return isinstance(name, str)
def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
def constraint_name_or_none(
name: _ConstraintName,
) -> Optional[str]:
return name if constraint_name_string(name) else None
@@ -176,10 +201,17 @@ def _ensure_scope_for_ddl(
yield
def url_render_as_string(url, hide_password=True):
if sqla_14:
return url.render_as_string(hide_password=hide_password)
else:
return url.__to_string__(hide_password=hide_password)
def _safe_begin_connection_transaction(
connection: Connection,
) -> Transaction:
transaction = connection.get_transaction()
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
else:
@@ -189,7 +221,7 @@ def _safe_begin_connection_transaction(
def _safe_commit_connection_transaction(
connection: Connection,
) -> None:
transaction = connection.get_transaction()
transaction = _get_connection_transaction(connection)
if transaction:
transaction.commit()
@@ -197,7 +229,7 @@ def _safe_commit_connection_transaction(
def _safe_rollback_connection_transaction(
connection: Connection,
) -> None:
transaction = connection.get_transaction()
transaction = _get_connection_transaction(connection)
if transaction:
transaction.rollback()
@@ -218,34 +250,70 @@ def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
return schema_item._copy(**kw) # type: ignore[union-attr]
else:
return schema_item.copy(**kw) # type: ignore[union-attr]
def _get_connection_transaction(
connection: Connection,
) -> Optional[Transaction]:
if sqla_14:
return connection.get_transaction()
else:
r = connection._root # type: ignore[attr-defined]
return r._Connection__transaction
def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
def _connectable_has_table(
connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
return connectable.dialect.has_table(connectable, tablename, schemaname)
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
return connectable.dialect.has_table(
connectable, tablename, schemaname
)
def _exec_on_inspector(inspector, statement, **params):
with inspector._operation_context() as conn:
return conn.execute(statement, params)
if sqla_14:
with inspector._operation_context() as conn:
return conn.execute(statement, params)
else:
return inspector.bind.execute(statement, params)
def _nullability_might_be_unset(metadata_column):
from sqlalchemy.sql import schema
if not sqla_14:
return metadata_column.nullable
else:
from sqlalchemy.sql import schema
return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
return (
metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
)
def _server_default_is_computed(*server_default) -> bool:
return any(isinstance(sd, schema.Computed) for sd in server_default)
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
def _server_default_is_identity(*server_default) -> bool:
return any(isinstance(sd, schema.Identity) for sd in server_default)
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
def _table_for_constraint(constraint: Constraint) -> Table:
@@ -266,6 +334,15 @@ def _columns_for_constraint(constraint):
return list(constraint.columns)
def _reflect_table(inspector: Inspector, table: Table) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
return inspector.reflecttable( # type: ignore[attr-defined]
table, None
)
def _resolve_for_variant(type_, dialect):
if _type_has_variants(type_):
base_type, mapping = _get_variant_mapping(type_)
@@ -274,7 +351,7 @@ def _resolve_for_variant(type_, dialect):
return type_
if hasattr(sqltypes.TypeEngine, "_variant_mapping"): # 2.0
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):
def _type_has_variants(type_):
return bool(type_._variant_mapping)
@@ -291,12 +368,7 @@ else:
return type_.impl, type_.mapping
def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
if TYPE_CHECKING:
assert constraint.columns is not None
assert constraint.elements is not None
assert isinstance(constraint.parent, Table)
def _fk_spec(constraint):
source_columns = [
constraint.columns[key].name for key in constraint.column_keys
]
@@ -325,7 +397,7 @@ def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
spec = constraint.elements[0]._get_colspec()
spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
@@ -337,13 +409,13 @@ def _is_type_bound(constraint: Constraint) -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
return constraint._type_bound
return constraint._type_bound # type: ignore[attr-defined]
def _find_columns(clause):
"""locate Column objects within the given expression."""
cols: Set[ColumnElement[Any]] = set()
cols = set()
traverse(clause, {}, {"column": cols.add})
return cols
@@ -430,7 +502,7 @@ class _textual_index_element(sql.ColumnElement):
self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
table.append_column(self.fake_column)
def get_children(self, **kw):
def get_children(self):
return [self.fake_column]
@@ -452,44 +524,116 @@ def _render_literal_bindparam(
return compiler.render_literal_bindparam(element, **kw)
def _get_index_expressions(idx):
return list(idx.expressions)
def _get_index_column_names(idx):
return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
def _column_kwargs(col: Column) -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
def _get_constraint_final_name(
constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
if constraint.name is None:
return None
assert dialect is not None
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
# SQLAlchemy API to return what would be the final compiled form of
# the name for this dialect.
return dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
# SQLAlchemy API to return what would be the final compiled form of
# the name for this dialect.
return dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
else:
# prior to SQLAlchemy 1.4, work around quoting logic to get at the
# final compiled name without quotes.
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
new_name = quoted_name_cls(str(constraint.name), quote=False)
constraint = constraint.__class__(name=new_name)
if isinstance(constraint, schema.Index):
# name should not be quoted.
d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
else:
# name should not be quoted.
return dialect.identifier_preparer.format_constraint(constraint)
def _constraint_is_named(
constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
if constraint.name is None:
return False
assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return name is not None
if sqla_14:
if constraint.name is None:
return False
assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return name is not None
else:
return constraint.name is not None
def _is_mariadb(mysql_dialect: Dialect) -> bool:
if sqla_14:
return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
return bool(
mysql_dialect.server_version_info
and mysql_dialect._is_mariadb # type: ignore[attr-defined]
)
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
def _insert_inline(table: Union[TableClause, Table]) -> Insert:
if sqla_14:
return table.insert().inline()
else:
return table.insert(inline=True) # type: ignore[call-arg]
if sqla_14:
from sqlalchemy import create_mock_engine
from sqlalchemy import select as _select
else:
from sqlalchemy import create_engine
def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
def _select(*columns, **kw) -> Select: # type: ignore[no-redef]
return sql.select(list(columns), **kw) # type: ignore[call-overload]
def is_expression_index(index: Index) -> bool:
expr: Any
for expr in index.expressions:
if is_expression(expr):
while isinstance(expr, UnaryExpression):
expr = expr.element
if not isinstance(expr, ColumnClause) or expr.is_literal:
return True
return False
def is_expression(expr: Any) -> bool:
while isinstance(expr, UnaryExpression):
expr = expr.element
if not isinstance(expr, ColumnClause) or expr.is_literal:
return True
return False