main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,5 +1,5 @@
# dialects/mysql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -53,8 +53,7 @@ from .base import YEAR
from .dml import Insert
from .dml import insert
from .expression import match
from .mariadb import INET4
from .mariadb import INET6
from ...util import compat
# default dialect
base.dialect = dialect = mysqldb.dialect
@@ -72,8 +71,6 @@ __all__ = (
"DOUBLE",
"ENUM",
"FLOAT",
"INET4",
"INET6",
"INTEGER",
"INTEGER",
"JSON",

View File

@@ -1,9 +1,10 @@
# dialects/mysql/aiomysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/aiomysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+aiomysql
@@ -22,108 +23,207 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
engine = create_async_engine(
"mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4"
)
""" # noqa
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class AsyncAdapt_aiomysql_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(self._adapt_connection.dbapi.Cursor)
# see https://github.com/aio-libs/aiomysql/issues/543
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
class AsyncAdapt_aiomysql_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor
):
__slots__ = ()
@property
def rowcount(self):
return self._cursor.rowcount
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
result = await self._cursor.execute(operation, parameters)
class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# aiomysql has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
_cursor_cls = AsyncAdapt_aiomysql_cursor
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
def ping(self, reconnect: bool) -> None:
assert not reconnect
self.await_(self._connection.ping(reconnect))
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
def character_set_name(self) -> Optional[str]:
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
self._cursor = self.await_(cursor.__aenter__())
def autocommit(self, value: Any) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
def ping(self, reconnect):
return self.await_(self._connection.ping(reconnect))
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def get_autocommit(self) -> bool:
return self._connection.get_autocommit() # type: ignore
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiomysql_ss_cursor(self)
else:
return AsyncAdapt_aiomysql_cursor(self)
def terminate(self) -> None:
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
class AsyncAdapt_aiomysql_dbapi:
def __init__(self, aiomysql, pymysql):
self.aiomysql = aiomysql
self.pymysql = pymysql
self.paramstyle = "format"
self._init_dbapi_attributes()
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
@@ -149,7 +249,7 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
):
setattr(self, name, getattr(self.pymysql, name))
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
@@ -164,23 +264,17 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
await_only(creator_fn(*arg, **kw)),
)
def _init_cursors_subclasses(
self,
) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
def _init_cursors_subclasses(self):
# suppress unconditional warning emitted by aiomysql
class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class Cursor(self.aiomysql.Cursor):
async def _show_warnings(self, conn):
pass
class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class SSCursor(self.aiomysql.SSCursor):
async def _show_warnings(self, conn):
pass
return Cursor, SSCursor # type: ignore[return-value]
return Cursor, SSCursor
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
@@ -191,16 +285,15 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_aiomysql_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
def import_dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@@ -208,37 +301,25 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e
def _found_rows_client_flag(self) -> int:
from pymysql.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from pymysql.constants import CLIENT
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS
def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_aiomysql

View File

@@ -1,9 +1,10 @@
# dialects/mysql/asyncmy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/asyncmy.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+asyncmy
@@ -20,100 +21,210 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
engine = create_async_engine(
"mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4"
)
""" # noqa
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import NoReturn
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from contextlib import asynccontextmanager
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class AsyncAdapt_asyncmy_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
class AsyncAdapt_asyncmy_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor
):
__slots__ = ()
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# asyncmy has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
_cursor_cls = AsyncAdapt_asyncmy_cursor
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
def _handle_exception(self, error: Exception) -> NoReturn:
if isinstance(error, AttributeError):
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
cursor = self._connection.cursor(
adapt_connection.dbapi.asyncmy.cursors.SSCursor
)
raise error
self._cursor = self.await_(cursor.__aenter__())
def ping(self, reconnect: bool) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
@asynccontextmanager
async def _mutex_and_adapt_errors(self):
async with self._execute_mutex:
try:
yield
except AttributeError:
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
def ping(self, reconnect):
assert not reconnect
return self.await_(self._do_ping())
async def _do_ping(self) -> None:
try:
async with self._execute_mutex:
await self._connection.ping(False)
except Exception as error:
self._handle_exception(error)
async def _do_ping(self):
async with self._mutex_and_adapt_errors():
return await self._connection.ping(False)
def character_set_name(self) -> Optional[str]:
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value: Any) -> None:
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def get_autocommit(self) -> bool:
return self._connection.get_autocommit() # type: ignore
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_asyncmy_ss_cursor(self)
else:
return AsyncAdapt_asyncmy_cursor(self)
def terminate(self) -> None:
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
__slots__ = ()
@@ -121,13 +232,18 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
await_ = staticmethod(await_fallback)
class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, asyncmy: ModuleType):
def _Binary(x):
"""Return x as a binary type."""
return bytes(x)
class AsyncAdapt_asyncmy_dbapi:
def __init__(self, asyncmy):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
@@ -148,9 +264,9 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
Binary = staticmethod(bytes)
Binary = staticmethod(_Binary)
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
@@ -174,14 +290,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_asyncmy_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@@ -189,20 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: E501
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
@@ -211,15 +318,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
"not connected" in str_e or "network operation failed" in str_e
)
def _found_rows_client_flag(self) -> int:
from asyncmy.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from asyncmy.constants import CLIENT
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS
def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_asyncmy

View File

@@ -1,9 +1,10 @@
# dialects/mysql/cymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/cymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -20,36 +21,18 @@ r"""
dialects are mysqlclient and PyMySQL.
""" # noqa
from __future__ import annotations
from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
from .types import BIT
from ... import util
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import PoolProxiedConnection
from ...sql.type_api import _ResultProcessorType
class _cymysqlBIT(BIT):
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
"""Convert MySQL's 64 bit, variable length binary string to a long."""
def process(value: Optional[Iterable[int]]) -> Optional[int]:
def process(value):
if value is not None:
v = 0
for i in iter(value):
@@ -72,22 +55,17 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("cymysql")
def _detect_charset(self, connection: Connection) -> str:
return connection.connection.charset # type: ignore[no-any-return]
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
return exception.errno # type: ignore[no-any-return]
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
if isinstance(e, self.loaded_dbapi.OperationalError):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
@@ -95,7 +73,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
2045,
2055,
)
elif isinstance(e, self.loaded_dbapi.InterfaceError):
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True

View File

@@ -1,5 +1,5 @@
# dialects/mysql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,6 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
@@ -142,11 +141,7 @@ class Insert(StandardInsert):
in :ref:`tutorial_parameter_ordered_updates`::
insert().on_duplicate_key_update(
[
("name", "some name"),
("value", "some value"),
]
)
[("name", "some name"), ("value", "some value")])
.. versionchanged:: 1.3 parameters can be specified as a dictionary
or list of 2-tuples; the latter form provides for parameter
@@ -186,7 +181,6 @@ class OnDuplicateClause(ClauseElement):
_parameter_ordering: Optional[List[str]] = None
update: Dict[str, Any]
stringify_dialect = "mysql"
def __init__(

View File

@@ -1,51 +1,34 @@
# dialects/mysql/enumerated.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/enumerated.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import enum
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
from ...sql import type_api
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.elements import ColumnElement
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
from ...sql.type_api import TypeEngineMixin
class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""
__visit_name__ = "ENUM"
native_enum = True
def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
def __init__(self, *enums, **kw):
"""Construct an ENUM.
E.g.::
Column("myenum", ENUM("foo", "bar", "baz"))
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values in
enums are not quoted, they will be escaped and surrounded by single
@@ -79,27 +62,21 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
"""
kw.pop("strict", None)
self._enum_init(enums, kw) # type: ignore[arg-type]
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
@classmethod
def adapt_emulated_to_native(
cls,
impl: Union[TypeEngine[Any], TypeEngineMixin],
**kw: Any,
) -> ENUM:
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.
"""
if TYPE_CHECKING:
assert isinstance(impl, ENUM)
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
@@ -109,27 +86,24 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
else:
return super()._object_value_for_elem(elem)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)
# TODO: SET is a string as far as configuration but does not act like
# a string at the python level. We either need to make a py-type agnostic
# version of String as a base to be used for this, make this some kind of
# TypeDecorator, or just vendor it out as its own type.
class SET(_StringType):
"""MySQL SET type."""
__visit_name__ = "SET"
def __init__(self, *values: str, **kw: Any):
def __init__(self, *values, **kw):
"""Construct a SET.
E.g.::
Column("myset", SET("foo", "bar", "baz"))
Column('myset', SET("foo", "bar", "baz"))
The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
@@ -177,19 +151,17 @@ class SET(_StringType):
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
self._inversed_bitmap: Dict[str, int] = {
self._bitmap = {
value: 2**idx for idx, value in enumerate(self.values)
}
self._bitmap: Dict[int, str] = {
2**idx: value for idx, value in enumerate(self.values)
}
self._bitmap.update(
(2**idx, value) for idx, value in enumerate(self.values)
)
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super().__init__(**kw)
def column_expression(
self, colexpr: ColumnElement[Any]
) -> ColumnElement[Any]:
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
@@ -197,12 +169,10 @@ class SET(_StringType):
else:
return colexpr
def result_processor(
self, dialect: Dialect, coltype: Any
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value: Union[str, int, None]) -> Optional[Set[str]]:
def process(value):
if value is not None:
value = int(value)
@@ -213,14 +183,11 @@ class SET(_StringType):
else:
super_convert = super().result_processor(dialect, coltype)
def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501
def process(value):
if isinstance(value, str):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
assert value is not None
if TYPE_CHECKING:
assert isinstance(value, str)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
@@ -231,48 +198,43 @@ class SET(_StringType):
return process
def bind_processor(
self, dialect: Dialect
) -> _BindProcessorType[Union[str, int]]:
def bind_processor(self, dialect):
super_convert = super().bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
if value is None:
return None
elif isinstance(value, (int, str)):
if super_convert:
return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._inversed_bitmap[v]
int_value |= self._bitmap[v]
return int_value
else:
def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(value, (int, str)):
value = ",".join(value)
if super_convert:
return super_convert(value) # type: ignore
return super_convert(value)
else:
return value
return process
def adapt(self, cls: type, **kw: Any) -> Any:
def adapt(self, impltype, **kw):
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
return util.constructor_copy(self, cls, *self.values, **kw)
return util.constructor_copy(self, impltype, *self.values, **kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self,
to_inspect=[SET, _StringType],

View File

@@ -1,13 +1,10 @@
# dialects/mysql/expression.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from ... import exc
from ... import util
@@ -20,7 +17,7 @@ from ...sql.base import Generative
from ...util.typing import Self
class match(Generative, elements.BinaryExpression[Any]):
class match(Generative, elements.BinaryExpression):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
E.g.::
@@ -40,9 +37,7 @@ class match(Generative, elements.BinaryExpression[Any]):
.order_by(desc(match_expr))
)
Would produce SQL resembling:
.. sourcecode:: sql
Would produce SQL resembling::
SELECT id, firstname, lastname
FROM user
@@ -75,9 +70,8 @@ class match(Generative, elements.BinaryExpression[Any]):
__visit_name__ = "mysql_match"
inherit_cache = True
modifiers: util.immutabledict[str, Any]
def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
def __init__(self, *cols, **kw):
if not cols:
raise exc.ArgumentError("columns are required")

View File

@@ -1,21 +1,13 @@
# dialects/mysql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/json.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
# mypy: ignore-errors
from ... import types as sqltypes
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
class JSON(sqltypes.JSON):
"""MySQL JSON type.
@@ -42,13 +34,13 @@ class JSON(sqltypes.JSON):
class _FormatTypeMixin:
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value: Any) -> Any:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
@@ -56,31 +48,29 @@ class _FormatTypeMixin:
return process
def literal_processor(
self, dialect: Dialect
) -> _LiteralProcessorType[Any]:
super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value: Any) -> str:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value # type: ignore[no-any-return]
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
if isinstance(value, int):
formatted_value = "$[%s]" % value
value = "$[%s]" % value
else:
formatted_value = '$."%s"' % value
return formatted_value
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
return "$%s" % (
"".join(
[

View File

@@ -1,73 +1,32 @@
# dialects/mysql/mariadb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import Callable
# mypy: ignore-errors
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from .base import MySQLTypeCompiler
from ...sql import sqltypes
class INET4(sqltypes.TypeEngine[str]):
"""INET4 column type for MariaDB
.. versionadded:: 2.0.37
"""
__visit_name__ = "INET4"
class INET6(sqltypes.TypeEngine[str]):
"""INET6 column type for MariaDB
.. versionadded:: 2.0.37
"""
__visit_name__ = "INET6"
class MariaDBTypeCompiler(MySQLTypeCompiler):
def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
return "INET4"
def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
return "INET6"
class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
name = "mariadb"
preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
type_compiler_cls = MariaDBTypeCompiler
preparer = MariaDBIdentifierPreparer
def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
dialect_mod = __import__(
def loader(driver):
driver_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_cls = getattr(driver_mod, driver).dialect
driver_mod = getattr(dialect_mod, driver)
if hasattr(driver_mod, "mariadb_dialect"):
driver_cls = driver_mod.mariadb_dialect
return driver_cls # type: ignore[no-any-return]
else:
driver_cls = driver_mod.dialect
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)

View File

@@ -1,9 +1,11 @@
# dialects/mysql/mariadbconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadbconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
@@ -27,15 +29,7 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
""" # noqa
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from uuid import UUID as _python_UUID
from .base import MySQLCompiler
@@ -45,19 +39,6 @@ from ... import sql
from ... import util
from ...sql import sqltypes
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
from ...sql.compiler import SQLCompiler
from ...sql.type_api import _ResultProcessorType
mariadb_cpy_minimum_version = (1, 0, 1)
@@ -66,12 +47,10 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
# work around JIRA issue
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
# this type can be removed.
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.as_uuid:
def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -81,7 +60,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
return process
else:
def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -92,27 +71,30 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
_lastrowid: Optional[int] = None
_lastrowid = None
def create_server_side_cursor(self) -> DBAPICursor:
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self) -> DBAPICursor:
def create_default_cursor(self):
return self._dbapi_connection.cursor(buffered=True)
def post_exec(self) -> None:
def post_exec(self):
super().post_exec()
self._rowcount = self.cursor.rowcount
if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid
def get_lastrowid(self) -> int:
if TYPE_CHECKING:
assert self._lastrowid is not None
@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
def get_lastrowid(self):
return self._lastrowid
@@ -151,7 +133,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
@util.memoized_property
def _dbapi_version(self) -> Tuple[int, ...]:
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
@@ -164,7 +146,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
else:
return (99, 99, 99)
def __init__(self, **kwargs: Any) -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
@@ -176,26 +158,20 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("mariadb")
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.loaded_dbapi.Error):
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False
def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args()
opts.update(url.query)
int_params = [
"connect_timeout",
@@ -210,7 +186,6 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"ssl_verify_cert",
"ssl",
"pool_reset_connection",
"compress",
]
for key in int_params:
@@ -230,21 +205,19 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
return [], opts
return [[], opts]
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
def _extract_error_code(self, exception):
try:
rc: int = exception.errno
rc = exception.errno
except:
rc = -1
return rc
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
return "utf8mb4"
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Sequence[IsolationLevel]:
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
@@ -253,26 +226,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"AUTOCOMMIT",
)
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return bool(dbapi_conn.autocommit)
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
connection.autocommit = False
super().set_isolation_level(connection, level)
def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
def do_begin_twophase(self, connection, xid):
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
def do_prepare_twophase(self, connection, xid):
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
@@ -285,12 +253,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
def do_rollback_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
@@ -304,12 +268,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
def do_commit_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(

View File

@@ -1,9 +1,10 @@
# dialects/mysql/mysqlconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqlconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -13,85 +14,26 @@ r"""
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysql-connector-python/
Driver Status
-------------
MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the
degree which the driver is functional. There are still ongoing issues
with features such as server side cursors which remain disabled until
upstream issues are repaired.
.. warning:: The MySQL Connector/Python driver published by Oracle is subject
to frequent, major regressions of essential functionality such as being able
to correctly persist simple binary strings which indicate it is not well
tested. The SQLAlchemy project is not able to maintain this dialect fully as
regressions in the driver prevent it from being included in continuous
integration.
.. versionchanged:: 2.0.39
The MySQL Connector/Python dialect has been updated to support the
latest version of this DBAPI. Previously, MySQL Connector/Python
was not fully supported. However, support remains limited due to ongoing
regressions introduced in this driver.
Connecting to MariaDB with MySQL Connector/Python
--------------------------------------------------
MySQL Connector/Python may attempt to pass an incompatible collation to the
database when connecting to MariaDB. Experimentation has shown that using
``?charset=utf8mb4&collation=utfmb4_general_ci`` or similar MariaDB-compatible
charset/collation will allow connectivity.
.. note::
The MySQL Connector/Python DBAPI has had many issues since its release,
some of which may remain unresolved, and the mysqlconnector dialect is
**not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
""" # noqa
from __future__ import annotations
import re
from typing import Any
from typing import cast
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .base import MariaDBIdentifierPreparer
from .base import BIT
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .mariadb import MariaDBDialect
from .types import BIT
from ... import util
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.cursor import CursorResult
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.row import Row
from ...engine.url import URL
from ...sql.elements import BinaryExpression
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def create_server_side_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=True)
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " % "
@@ -99,37 +41,22 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
)
class IdentifierPreparerCommon_mysqlconnector:
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
@property
def _double_percents(self) -> bool:
def _double_percents(self):
return False
@_double_percents.setter
def _double_percents(self, value: Any) -> None:
def _double_percents(self, value):
pass
def _escape_identifier(self, value: str) -> str:
value = value.replace(
self.escape_quote, # type:ignore[attr-defined]
self.escape_to_quote, # type:ignore[attr-defined]
)
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value
class MySQLIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
):
pass
class MariaDBIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
):
pass
class _myconnpyBIT(BIT):
def result_processor(self, dialect: Any, coltype: Any) -> None:
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
@@ -144,31 +71,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
supports_native_decimal = True
supports_native_bit = True
# not until https://bugs.mysql.com/bug.php?id=117548
supports_server_side_cursors = False
default_paramstyle = "format"
statement_compiler = MySQLCompiler_mysqlconnector
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
preparer: type[MySQLIdentifierPreparer] = (
MySQLIdentifierPreparer_mysqlconnector
)
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
@classmethod
def import_dbapi(cls) -> DBAPIModule:
return cast("DBAPIModule", __import__("mysql.connector").connector)
def import_dbapi(cls):
from mysql import connector
def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
return connector
def do_ping(self, dbapi_connection):
dbapi_connection.ping(False)
return True
def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
@@ -176,7 +96,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "allow_local_infile", bool)
util.coerce_kw_type(opts, "autocommit", bool)
util.coerce_kw_type(opts, "buffered", bool)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connection_timeout", int)
util.coerce_kw_type(opts, "connect_timeout", int)
@@ -191,21 +110,15 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "use_pure", bool)
util.coerce_kw_type(opts, "use_unicode", bool)
# note that "buffered" is set to False by default in MySQL/connector
# python. If you set it to True, then there is no way to get a server
# side cursor because the logic is written to disallow that.
# leaving this at True until
# https://bugs.mysql.com/bug.php?id=117548 can be fixed
opts["buffered"] = True
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault("buffered", True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector import constants # type: ignore
ClientFlag = constants.ClientFlag
from mysql.connector.constants import ClientFlag
client_flags = opts.get(
"client_flags", ClientFlag.get_default()
@@ -214,35 +127,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
opts["client_flags"] = client_flags
except Exception:
pass
return [], opts
return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]:
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
return None
def _detect_charset(self, connection: Connection) -> str:
return connection.connection.charset # type: ignore
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception: BaseException) -> int:
return exception.errno # type: ignore
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(
self,
e: Exception,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (
self.loaded_dbapi.OperationalError, #
self.loaded_dbapi.InterfaceError,
self.loaded_dbapi.ProgrammingError,
)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return (
e.errno in errnos
@@ -252,51 +154,26 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
else:
return False
def _compat_fetchall(
self,
rp: CursorResult[Tuple[Any, ...]],
charset: Optional[str] = None,
) -> Sequence[Row[Tuple[Any, ...]]]:
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(
self,
rp: CursorResult[Tuple[Any, ...]],
charset: Optional[str] = None,
) -> Optional[Row[Tuple[Any, ...]]]:
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Sequence[IsolationLevel]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
)
_isolation_lookup = {
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
}
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return bool(dbapi_conn.autocommit)
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
class MariaDBDialect_mysqlconnector(
MariaDBDialect, MySQLDialect_mysqlconnector
):
supports_statement_cache = True
_allows_uuid_binds = False
preparer = MariaDBIdentifierPreparer_mysqlconnector
connection.autocommit = False
super()._set_isolation_level(connection, level)
dialect = MySQLDialect_mysqlconnector
mariadb_dialect = MariaDBDialect_mysqlconnector

View File

@@ -1,9 +1,11 @@
# dialects/mysql/mysqldb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqldb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
@@ -46,9 +48,9 @@ key "ssl", which may be specified using the
"ssl": {
"ca": "/home/gord/client-ssl/ca.pem",
"cert": "/home/gord/client-ssl/client-cert.pem",
"key": "/home/gord/client-ssl/client-key.pem",
"key": "/home/gord/client-ssl/client-key.pem"
}
},
}
)
For convenience, the following keys may also be specified inline within the URL
@@ -72,9 +74,7 @@ Using MySQLdb with Google Cloud SQL
-----------------------------------
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following:
.. sourcecode:: text
using a URL like the following::
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
@@ -84,39 +84,25 @@ Server Side Cursors
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
from __future__ import annotations
import re
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .base import TEXT
from ... import sql
from ... import util
from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import _DBAPIMultiExecuteParams
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import ExecutionContext
from ...engine.interfaces import IsolationLevel
from ...engine.url import URL
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
pass
@property
def rowcount(self):
if hasattr(self, "_rowcount"):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLCompiler_mysqldb(MySQLCompiler):
@@ -136,9 +122,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer
server_version_info: Tuple[int, ...]
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._mysql_dbapi_version = (
self._parse_dbapi_version(self.dbapi.__version__)
@@ -146,7 +131,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
else (0, 0, 0)
)
def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]:
def _parse_dbapi_version(self, version):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@@ -154,7 +139,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
return (0, 0, 0)
@util.langhelpers.memoized_property
def supports_server_side_cursors(self) -> bool:
def supports_server_side_cursors(self):
try:
cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
@@ -163,13 +148,13 @@ class MySQLDialect_mysqldb(MySQLDialect):
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("MySQLdb")
def on_connect(self) -> Callable[[DBAPIConnection], None]:
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn: DBAPIConnection) -> None:
def on_connect(conn):
if super_ is not None:
super_(conn)
@@ -182,24 +167,43 @@ class MySQLDialect_mysqldb(MySQLDialect):
return on_connect
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
def do_ping(self, dbapi_connection):
dbapi_connection.ping()
return True
def do_executemany(
self,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIMultiExecuteParams,
context: Optional[ExecutionContext] = None,
) -> None:
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
cast(MySQLExecutionContext, context)._rowcount = rowcount
context._rowcount = rowcount
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8mb4_bin collation and unicode returns
collation = connection.exec_driver_sql(
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation"),
)
).scalar()
has_utf8mb4_bin = self.server_version_info > (5,) and collation
if has_utf8mb4_bin:
additional_tests = [
sql.collate(
sql.cast(
sql.literal_column("'test collated returns'"),
TEXT(charset="utf8mb4"),
),
"utf8mb4_bin",
)
]
else:
additional_tests = []
return super()._check_unicode_returns(connection, additional_tests)
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(
database="db", username="user", password="passwd"
@@ -213,7 +217,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
util.coerce_kw_type(opts, "read_timeout", int)
util.coerce_kw_type(opts, "write_timeout", int)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "local_infile", bool)
util.coerce_kw_type(opts, "local_infile", int)
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
@@ -248,9 +252,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
if client_flag_found_rows is not None:
client_flag |= client_flag_found_rows
opts["client_flag"] = client_flag
return [], opts
return [[], opts]
def _found_rows_client_flag(self) -> Optional[int]:
def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
@@ -259,23 +263,20 @@ class MySQLDialect_mysqldb(MySQLDialect):
except (AttributeError, ImportError):
return None
else:
return CLIENT_FLAGS.FOUND_ROWS # type: ignore
return CLIENT_FLAGS.FOUND_ROWS
else:
return None
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
return exception.args[0] # type: ignore[no-any-return]
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name: Callable[[], str] = (
connection.connection.character_set_name
)
cset_name = connection.connection.character_set_name
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
@@ -287,9 +288,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
else:
return cset_name()
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Tuple[IsolationLevel, ...]:
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
@@ -298,12 +297,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
"AUTOCOMMIT",
)
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return dbapi_conn.get_autocommit() # type: ignore[no-any-return]
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit(True)
else:

View File

@@ -1,10 +1,5 @@
# dialects/mysql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
@@ -39,13 +34,6 @@ def generate_driver_url(url, driver, query_str):
drivername="%s+%s" % (backend, driver)
).update_query_string(query_str)
if driver == "mariadbconnector":
new_url = new_url.difference_update_query(["charset"])
elif driver == "mysqlconnector":
new_url = new_url.update_query_pairs(
[("collation", "utf8mb4_general_ci")]
)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:

View File

@@ -1,9 +1,11 @@
# dialects/mysql/pymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -39,6 +41,7 @@ necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
"&ssl_check_hostname=false"
)
MySQL-Python Compatibility
--------------------------
@@ -47,26 +50,9 @@ and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.
""" # noqa
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@@ -76,7 +62,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
description_encoding = None
@langhelpers.memoized_property
def supports_server_side_cursors(self) -> bool:
def supports_server_side_cursors(self):
try:
cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
@@ -85,11 +71,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("pymysql")
@langhelpers.memoized_property
def _send_false_to_ping(self) -> bool:
def _send_false_to_ping(self):
"""determine if pymysql has deprecated, changed the default of,
or removed the 'reconnect' argument of connection.ping().
@@ -100,9 +86,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
""" # noqa: E501
try:
Connection = __import__(
"pymysql.connections"
).connections.Connection
Connection = __import__("pymysql.connections").Connection
except (ImportError, AttributeError):
return True
else:
@@ -116,7 +100,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
not insp.defaults or insp.defaults[0] is not False
)
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
def do_ping(self, dbapi_connection):
if self._send_false_to_ping:
dbapi_connection.ping(False)
else:
@@ -124,24 +108,17 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return True
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(username="user")
return super().create_connect_args(
url, _translate_args=_translate_args
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.loaded_dbapi.Error):
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return (
"already closed" in str_e or "connection was killed" in str_e
@@ -149,7 +126,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
else:
return False
def _extract_error_code(self, exception: BaseException) -> Any:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]

View File

@@ -1,13 +1,15 @@
# dialects/mysql/pyodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pyodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
@@ -28,30 +30,21 @@ r"""
Pass through exact pyodbc connection string::
import urllib
connection_string = (
"DRIVER=MySQL ODBC 8.0 ANSI Driver;"
"SERVER=localhost;"
"PORT=3307;"
"DATABASE=mydb;"
"UID=root;"
"PWD=(whatever);"
"charset=utf8mb4;"
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
'SERVER=localhost;'
'PORT=3307;'
'DATABASE=mydb;'
'UID=root;'
'PWD=(whatever);'
'charset=utf8mb4;'
)
params = urllib.parse.quote_plus(connection_string)
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
""" # noqa
from __future__ import annotations
import datetime
import re
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .base import MySQLDialect
from .base import MySQLExecutionContext
@@ -61,31 +54,23 @@ from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time
if TYPE_CHECKING:
from ...engine import Connection
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import Dialect
from ...sql.type_api import _ResultProcessorType
class _pyodbcTIME(TIME):
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[datetime.time]:
def process(value: Any) -> Union[datetime.time, None]:
def result_processor(self, dialect, coltype):
def process(value):
# pyodbc returns a datetime.time object; no need to convert
return value # type: ignore[no-any-return]
return value
return process
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self) -> int:
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0] # type: ignore[index]
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid # type: ignore[no-any-return]
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
@@ -96,7 +81,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
pyodbc_driver_name = "MySQL"
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
@@ -121,25 +106,21 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
)
return "latin1"
def _get_server_version_info(
self, connection: Connection
) -> Tuple[int, ...]:
def _get_server_version_info(self, connection):
return MySQLDialect._get_server_version_info(self, connection)
def _extract_error_code(self, exception: BaseException) -> Optional[int]:
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
if m is None:
return None
c: Optional[str] = m.group(1)
c = m.group(1)
if c:
return int(c)
else:
return None
def on_connect(self) -> Callable[[DBAPIConnection], None]:
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn: DBAPIConnection) -> None:
def on_connect(conn):
if super_ is not None:
super_(conn)

View File

@@ -1,65 +1,46 @@
# dialects/mysql/reflection.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reflection.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
from ... import log
from ... import types as sqltypes
from ... import util
from ...util.typing import Literal
if TYPE_CHECKING:
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from ...engine.interfaces import ReflectedColumn
class ReflectedState:
"""Stores raw information about a SHOW CREATE TABLE statement."""
charset: Optional[str]
def __init__(self) -> None:
self.columns: List[ReflectedColumn] = []
self.table_options: Dict[str, str] = {}
self.table_name: Optional[str] = None
self.keys: List[Dict[str, Any]] = []
self.fk_constraints: List[Dict[str, Any]] = []
self.ck_constraints: List[Dict[str, Any]] = []
def __init__(self):
self.columns = []
self.table_options = {}
self.table_name = None
self.keys = []
self.fk_constraints = []
self.ck_constraints = []
@log.class_logger
class MySQLTableDefinitionParser:
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(
self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
):
def __init__(self, dialect, preparer):
self.dialect = dialect
self.preparer = preparer
self._prep_regexes()
def parse(
self, show_create: str, charset: Optional[str]
) -> ReflectedState:
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
for line in re.split(r"\r?\n", show_create):
@@ -84,11 +65,11 @@ class MySQLTableDefinitionParser:
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == "key":
state.keys.append(spec) # type: ignore[arg-type]
state.keys.append(spec)
elif type_ == "fk_constraint":
state.fk_constraints.append(spec) # type: ignore[arg-type]
state.fk_constraints.append(spec)
elif type_ == "ck_constraint":
state.ck_constraints.append(spec) # type: ignore[arg-type]
state.ck_constraints.append(spec)
else:
pass
return state
@@ -96,13 +77,7 @@ class MySQLTableDefinitionParser:
def _check_view(self, sql: str) -> bool:
return bool(self._re_is_view.match(sql))
def _parse_constraints(self, line: str) -> Union[
Tuple[None, str],
Tuple[Literal["partition"], str],
Tuple[
Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str]
],
]:
def _parse_constraints(self, line):
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
@@ -152,7 +127,7 @@ class MySQLTableDefinitionParser:
# No match.
return (None, line)
def _parse_table_name(self, line: str, state: ReflectedState) -> None:
def _parse_table_name(self, line, state):
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
@@ -163,7 +138,7 @@ class MySQLTableDefinitionParser:
if m:
state.table_name = cleanup(m.group("name"))
def _parse_table_options(self, line: str, state: ReflectedState) -> None:
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
@@ -189,9 +164,7 @@ class MySQLTableDefinitionParser:
for opt, val in options.items():
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_partition_options(
self, line: str, state: ReflectedState
) -> None:
def _parse_partition_options(self, line, state):
options = {}
new_line = line[:]
@@ -247,7 +220,7 @@ class MySQLTableDefinitionParser:
else:
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_column(self, line: str, state: ReflectedState) -> None:
def _parse_column(self, line, state):
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
@@ -310,16 +283,13 @@ class MySQLTableDefinitionParser:
type_instance = col_type(*type_args, **type_kw)
col_kw: Dict[str, Any] = {}
col_kw = {}
# NOT NULL
col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get("notnull", False) == "NOT NULL":
col_kw["nullable"] = False
# For generated columns, the nullability is marked in a different place
if spec.get("notnull_generated", False) == "NOT NULL":
col_kw["nullable"] = False
# AUTO_INCREMENT
if spec.get("autoincr", False):
@@ -351,13 +321,9 @@ class MySQLTableDefinitionParser:
name=name, type=type_instance, default=default, comment=comment
)
col_d.update(col_kw)
state.columns.append(col_d) # type: ignore[arg-type]
state.columns.append(col_d)
def _describe_to_create(
self,
table_name: str,
columns: Sequence[Tuple[str, str, str, str, str, str]],
) -> str:
def _describe_to_create(self, table_name, columns):
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
@@ -410,9 +376,7 @@ class MySQLTableDefinitionParser:
]
)
def _parse_keyexprs(
self, identifiers: str
) -> List[Tuple[str, Optional[int], str]]:
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return [
@@ -422,12 +386,11 @@ class MySQLTableDefinitionParser:
)
]
def _prep_regexes(self) -> None:
def _prep_regexes(self):
"""Pre-compile regular expressions."""
self._pr_options: List[
Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
] = []
self._re_columns = []
self._pr_options = []
_final = self.preparer.final_quote
@@ -485,13 +448,11 @@ class MySQLTableDefinitionParser:
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
r"(?: +DEFAULT +(?P<default>"
r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
r"))?"
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
r")?"
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
@@ -539,7 +500,7 @@ class MySQLTableDefinitionParser:
#
# unique constraints come back as KEYs
kw = quotes.copy()
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT"
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
self._re_fk_constraint = _re_compile(
r" "
r"CONSTRAINT +"
@@ -616,21 +577,21 @@ class MySQLTableDefinitionParser:
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive: str) -> None:
def _add_option_string(self, directive):
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex, cleanup_text))
def _add_option_word(self, directive: str) -> None:
def _add_option_word(self, directive):
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex))
def _add_partition_option_word(self, directive: str) -> None:
def _add_partition_option_word(self, directive):
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
re.escape(directive),
@@ -645,7 +606,7 @@ class MySQLTableDefinitionParser:
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive: str, regex: str) -> None:
def _add_option_regex(self, directive, regex):
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
re.escape(directive),
self._optional_equals,
@@ -663,35 +624,21 @@ _options_of_type_string = (
)
@overload
def _pr_compile(
regex: str, cleanup: Callable[[str], str]
) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ...
@overload
def _pr_compile(
regex: str, cleanup: None = None
) -> Tuple[re.Pattern[Any], None]: ...
def _pr_compile(
regex: str, cleanup: Optional[Callable[[str], str]] = None
) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
def _re_compile(regex: str) -> re.Pattern[Any]:
def _re_compile(regex):
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)
def _strip_values(values: Sequence[str]) -> List[str]:
def _strip_values(values):
"Strip reflected values quotes"
strip_values: List[str] = []
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
@@ -703,9 +650,7 @@ def _strip_values(values: Sequence[str]) -> List[str]:
def cleanup_text(raw_text: str) -> str:
if "\\" in raw_text:
raw_text = re.sub(
_control_char_regexp,
lambda s: _control_char_map[s[0]], # type: ignore[index]
raw_text,
_control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
)
return raw_text.replace("''", "'")

View File

@@ -1,5 +1,5 @@
# dialects/mysql/reserved_words.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reserved_words.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,6 +11,7 @@
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
# mypy: ignore-errors
RESERVED_WORDS_MARIADB = {
"accessible",
@@ -281,7 +282,6 @@ RESERVED_WORDS_MARIADB = {
}
)
# https://dev.mysql.com/doc/refman/8.3/en/keywords.html
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
@@ -403,7 +403,6 @@ RESERVED_WORDS_MYSQL = {
"int4",
"int8",
"integer",
"intersect",
"interval",
"into",
"io_after_gtids",
@@ -469,7 +468,6 @@ RESERVED_WORDS_MYSQL = {
"outfile",
"over",
"parse_gcol_expr",
"parallel",
"partition",
"percent_rank",
"persist",
@@ -478,7 +476,6 @@ RESERVED_WORDS_MYSQL = {
"primary",
"procedure",
"purge",
"qualify",
"range",
"rank",
"read",

View File

@@ -1,30 +1,18 @@
# dialects/mysql/types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors
import datetime
import decimal
from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from ... import exc
from ... import util
from ...sql import sqltypes
if TYPE_CHECKING:
from .base import MySQLDialect
from ...engine.interfaces import Dialect
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
class _NumericType:
"""Base for MySQL numeric types.
@@ -34,27 +22,19 @@ class _NumericType:
"""
def __init__(
self, unsigned: bool = False, zerofill: bool = False, **kw: Any
):
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_NumericType, sqltypes.Numeric]
)
class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and (
(precision is None and scale is not None)
or (precision is not None and scale is None)
@@ -66,18 +46,18 @@ class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
)
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
)
@@ -88,13 +68,13 @@ class _StringType(sqltypes.String):
def __init__(
self,
charset: Optional[str] = None,
collation: Optional[str] = None,
ascii: bool = False, # noqa
binary: bool = False,
unicode: bool = False,
national: bool = False,
**kw: Any,
charset=None,
collation=None,
ascii=False, # noqa
binary=False,
unicode=False,
national=False,
**kw,
):
self.charset = charset
@@ -107,33 +87,25 @@ class _StringType(sqltypes.String):
self.national = national
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_StringType, sqltypes.String]
)
class _MatchType(
sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
):
def __init__(self, **kw: Any):
class _MatchType(sqltypes.Float, sqltypes.MatchType):
def __init__(self, **kw):
# TODO: float arguments?
sqltypes.Float.__init__(self) # type: ignore[arg-type]
sqltypes.Float.__init__(self)
sqltypes.MatchType.__init__(self)
class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
__visit_name__ = "NUMERIC"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
@@ -154,18 +126,12 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
)
class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
__visit_name__ = "DECIMAL"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
@@ -186,18 +152,12 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
)
class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
class DOUBLE(_FloatType, sqltypes.DOUBLE):
"""MySQL DOUBLE type."""
__visit_name__ = "DOUBLE"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
.. note::
@@ -226,18 +186,12 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
)
class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
__visit_name__ = "REAL"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
.. note::
@@ -266,18 +220,12 @@ class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
)
class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
__visit_name__ = "FLOAT"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = False,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
@@ -297,9 +245,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
def bind_processor(
self, dialect: Dialect
) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
def bind_processor(self, dialect):
return None
@@ -308,7 +254,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):
__visit_name__ = "INTEGER"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -329,7 +275,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):
__visit_name__ = "BIGINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -350,7 +296,7 @@ class MEDIUMINT(_IntegerType):
__visit_name__ = "MEDIUMINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
@@ -371,7 +317,7 @@ class TINYINT(_IntegerType):
__visit_name__ = "TINYINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
@@ -386,19 +332,13 @@ class TINYINT(_IntegerType):
"""
super().__init__(display_width=display_width, **kw)
def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
return (
self._type_affinity is other._type_affinity
or other._type_affinity is sqltypes.Boolean
)
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
__visit_name__ = "SMALLINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -414,7 +354,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
super().__init__(display_width=display_width, **kw)
class BIT(sqltypes.TypeEngine[Any]):
class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
@@ -425,7 +365,7 @@ class BIT(sqltypes.TypeEngine[Any]):
__visit_name__ = "BIT"
def __init__(self, length: Optional[int] = None):
def __init__(self, length=None):
"""Construct a BIT.
:param length: Optional, number of bits.
@@ -433,19 +373,20 @@ class BIT(sqltypes.TypeEngine[Any]):
"""
self.length = length
def result_processor(
self, dialect: MySQLDialect, coltype: object # type: ignore[override]
) -> Optional[_ResultProcessorType[Any]]:
"""Convert a MySQL's 64 bit, variable length binary string to a
long."""
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
if dialect.supports_native_bit:
return None
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
already do this, so this logic should be moved to those dialects.
def process(value: Optional[Iterable[int]]) -> Optional[int]:
"""
def process(value):
if value is not None:
v = 0
for i in value:
if not isinstance(i, int):
i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
@@ -458,7 +399,7 @@ class TIME(sqltypes.TIME):
__visit_name__ = "TIME"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
@@ -477,12 +418,10 @@ class TIME(sqltypes.TIME):
super().__init__(timezone=timezone)
self.fsp = fsp
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[datetime.time]:
def result_processor(self, dialect, coltype):
time = datetime.time
def process(value: Any) -> Optional[datetime.time]:
def process(value):
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
@@ -505,7 +444,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
@@ -530,7 +469,7 @@ class DATETIME(sqltypes.DATETIME):
__visit_name__ = "DATETIME"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
@@ -550,26 +489,26 @@ class DATETIME(sqltypes.DATETIME):
self.fsp = fsp
class YEAR(sqltypes.TypeEngine[Any]):
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
__visit_name__ = "YEAR"
def __init__(self, display_width: Optional[int] = None):
def __init__(self, display_width=None):
self.display_width = display_width
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for character storage encoded up to 2^16 bytes."""
"""MySQL TEXT type, for text up to 2^16 characters."""
__visit_name__ = "TEXT"
def __init__(self, length: Optional[int] = None, **kw: Any):
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` bytes of characters.
``length`` characters.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
@@ -596,11 +535,11 @@ class TEXT(_StringType, sqltypes.TEXT):
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes."""
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
__visit_name__ = "TINYTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
@@ -628,12 +567,11 @@ class TINYTEXT(_StringType):
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for character storage encoded up
to 2^24 bytes."""
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
__visit_name__ = "MEDIUMTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
@@ -661,11 +599,11 @@ class MEDIUMTEXT(_StringType):
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes."""
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
__visit_name__ = "LONGTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
@@ -697,7 +635,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
__visit_name__ = "VARCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
@@ -729,7 +667,7 @@ class CHAR(_StringType, sqltypes.CHAR):
__visit_name__ = "CHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
@@ -745,7 +683,7 @@ class CHAR(_StringType, sqltypes.CHAR):
super().__init__(length=length, **kwargs)
@classmethod
def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
def _adapt_string_for_cast(self, type_):
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
@@ -774,7 +712,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
__visit_name__ = "NVARCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
@@ -800,7 +738,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):
__visit_name__ = "NCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.