This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# dialects/mysql/__init__.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/__init__.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -53,8 +53,7 @@ from .base import YEAR
|
||||
from .dml import Insert
|
||||
from .dml import insert
|
||||
from .expression import match
|
||||
from .mariadb import INET4
|
||||
from .mariadb import INET6
|
||||
from ...util import compat
|
||||
|
||||
# default dialect
|
||||
base.dialect = dialect = mysqldb.dialect
|
||||
@@ -72,8 +71,6 @@ __all__ = (
|
||||
"DOUBLE",
|
||||
"ENUM",
|
||||
"FLOAT",
|
||||
"INET4",
|
||||
"INET6",
|
||||
"INTEGER",
|
||||
"INTEGER",
|
||||
"JSON",
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# dialects/mysql/aiomysql.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# mysql/aiomysql.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: mysql+aiomysql
|
||||
@@ -22,108 +23,207 @@ This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
|
||||
|
||||
engine = create_async_engine(
|
||||
"mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4"
|
||||
)
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .pymysql import MySQLDialect_pymysql
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_module
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
|
||||
from ...engine import AdaptedConnection
|
||||
from ...util.concurrency import asyncio
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ...connectors.asyncio import AsyncIODBAPIConnection
|
||||
from ...connectors.asyncio import AsyncIODBAPICursor
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...engine.url import URL
|
||||
class AsyncAdapt_aiomysql_cursor:
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
server_side = False
|
||||
__slots__ = (
|
||||
"_adapt_connection",
|
||||
"_connection",
|
||||
"await_",
|
||||
"_cursor",
|
||||
"_rows",
|
||||
)
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
|
||||
__slots__ = ()
|
||||
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
|
||||
|
||||
def _make_new_cursor(
|
||||
self, connection: AsyncIODBAPIConnection
|
||||
) -> AsyncIODBAPICursor:
|
||||
return connection.cursor(self._adapt_connection.dbapi.Cursor)
|
||||
# see https://github.com/aio-libs/aiomysql/issues/543
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
self._rows = []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
class AsyncAdapt_aiomysql_ss_cursor(
|
||||
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor
|
||||
):
|
||||
__slots__ = ()
|
||||
@property
|
||||
def rowcount(self):
|
||||
return self._cursor.rowcount
|
||||
|
||||
def _make_new_cursor(
|
||||
self, connection: AsyncIODBAPIConnection
|
||||
) -> AsyncIODBAPICursor:
|
||||
return connection.cursor(
|
||||
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._cursor.arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._cursor.arraysize = value
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
return self._cursor.lastrowid
|
||||
|
||||
def close(self):
|
||||
# note we aren't actually closing the cursor here,
|
||||
# we are just letting GC do it. to allow this to be async
|
||||
# we would need the Result to change how it does "Safe close cursor".
|
||||
# MySQL "cursors" don't actually have state to be "closed" besides
|
||||
# exhausting rows, which we already have done for sync cursor.
|
||||
# another option would be to emulate aiosqlite dialect and assign
|
||||
# cursor only if we are doing server side cursor operation.
|
||||
self._rows[:] = []
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
return self.await_(self._execute_async(operation, parameters))
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
return self.await_(
|
||||
self._executemany_async(operation, seq_of_parameters)
|
||||
)
|
||||
|
||||
async def _execute_async(self, operation, parameters):
|
||||
async with self._adapt_connection._execute_mutex:
|
||||
result = await self._cursor.execute(operation, parameters)
|
||||
|
||||
class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
|
||||
if not self.server_side:
|
||||
# aiomysql has a "fake" async result, so we have to pull it out
|
||||
# of that here since our default result is not async.
|
||||
# we could just as easily grab "_rows" here and be done with it
|
||||
# but this is safer.
|
||||
self._rows = list(await self._cursor.fetchall())
|
||||
return result
|
||||
|
||||
async def _executemany_async(self, operation, seq_of_parameters):
|
||||
async with self._adapt_connection._execute_mutex:
|
||||
return await self._cursor.executemany(operation, seq_of_parameters)
|
||||
|
||||
def setinputsizes(self, *inputsizes):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
server_side = True
|
||||
|
||||
_cursor_cls = AsyncAdapt_aiomysql_cursor
|
||||
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
def ping(self, reconnect: bool) -> None:
|
||||
assert not reconnect
|
||||
self.await_(self._connection.ping(reconnect))
|
||||
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
|
||||
|
||||
def character_set_name(self) -> Optional[str]:
|
||||
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
|
||||
def autocommit(self, value: Any) -> None:
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi", "_execute_mutex")
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self._execute_mutex = asyncio.Lock()
|
||||
|
||||
def ping(self, reconnect):
|
||||
return self.await_(self._connection.ping(reconnect))
|
||||
|
||||
def character_set_name(self):
|
||||
return self._connection.character_set_name()
|
||||
|
||||
def autocommit(self, value):
|
||||
self.await_(self._connection.autocommit(value))
|
||||
|
||||
def get_autocommit(self) -> bool:
|
||||
return self._connection.get_autocommit() # type: ignore
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_aiomysql_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_aiomysql_cursor(self)
|
||||
|
||||
def terminate(self) -> None:
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def close(self):
|
||||
# it's not awaitable.
|
||||
self._connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.await_(self._connection.ensure_closed())
|
||||
|
||||
|
||||
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
|
||||
def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
|
||||
class AsyncAdapt_aiomysql_dbapi:
|
||||
def __init__(self, aiomysql, pymysql):
|
||||
self.aiomysql = aiomysql
|
||||
self.pymysql = pymysql
|
||||
self.paramstyle = "format"
|
||||
self._init_dbapi_attributes()
|
||||
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
|
||||
|
||||
def _init_dbapi_attributes(self) -> None:
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"Warning",
|
||||
"Error",
|
||||
@@ -149,7 +249,7 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
|
||||
):
|
||||
setattr(self, name, getattr(self.pymysql, name))
|
||||
|
||||
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
|
||||
|
||||
@@ -164,23 +264,17 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
|
||||
await_only(creator_fn(*arg, **kw)),
|
||||
)
|
||||
|
||||
def _init_cursors_subclasses(
|
||||
self,
|
||||
) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
|
||||
def _init_cursors_subclasses(self):
|
||||
# suppress unconditional warning emitted by aiomysql
|
||||
class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
|
||||
async def _show_warnings(
|
||||
self, conn: AsyncIODBAPIConnection
|
||||
) -> None:
|
||||
class Cursor(self.aiomysql.Cursor):
|
||||
async def _show_warnings(self, conn):
|
||||
pass
|
||||
|
||||
class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
|
||||
async def _show_warnings(
|
||||
self, conn: AsyncIODBAPIConnection
|
||||
) -> None:
|
||||
class SSCursor(self.aiomysql.SSCursor):
|
||||
async def _show_warnings(self, conn):
|
||||
pass
|
||||
|
||||
return Cursor, SSCursor # type: ignore[return-value]
|
||||
return Cursor, SSCursor
|
||||
|
||||
|
||||
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
|
||||
@@ -191,16 +285,15 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
|
||||
_sscursor = AsyncAdapt_aiomysql_ss_cursor
|
||||
|
||||
is_async = True
|
||||
has_terminate = True
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
|
||||
def import_dbapi(cls):
|
||||
return AsyncAdapt_aiomysql_dbapi(
|
||||
__import__("aiomysql"), __import__("pymysql")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url: URL) -> type:
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
@@ -208,37 +301,25 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
|
||||
dbapi_connection.terminate()
|
||||
|
||||
def create_connect_args(
|
||||
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
|
||||
) -> ConnectArgsType:
|
||||
def create_connect_args(self, url):
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=dict(username="user", database="db")
|
||||
)
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: DBAPIModule.Error,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
else:
|
||||
str_e = str(e).lower()
|
||||
return "not connected" in str_e
|
||||
|
||||
def _found_rows_client_flag(self) -> int:
|
||||
from pymysql.constants import CLIENT # type: ignore
|
||||
def _found_rows_client_flag(self):
|
||||
from pymysql.constants import CLIENT
|
||||
|
||||
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
|
||||
return CLIENT.FOUND_ROWS
|
||||
|
||||
def get_driver_connection(
|
||||
self, connection: DBAPIConnection
|
||||
) -> AsyncIODBAPIConnection:
|
||||
return connection._connection # type: ignore[no-any-return]
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = MySQLDialect_aiomysql
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# dialects/mysql/asyncmy.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# mysql/asyncmy.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: mysql+asyncmy
|
||||
@@ -20,100 +21,210 @@ This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
|
||||
|
||||
engine = create_async_engine(
|
||||
"mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4"
|
||||
)
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import NoReturn
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from .pymysql import MySQLDialect_pymysql
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_module
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
|
||||
from ...engine import AdaptedConnection
|
||||
from ...util.concurrency import asyncio
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...connectors.asyncio import AsyncIODBAPIConnection
|
||||
from ...connectors.asyncio import AsyncIODBAPICursor
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...engine.url import URL
|
||||
|
||||
class AsyncAdapt_asyncmy_cursor:
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
server_side = False
|
||||
__slots__ = (
|
||||
"_adapt_connection",
|
||||
"_connection",
|
||||
"await_",
|
||||
"_cursor",
|
||||
"_rows",
|
||||
)
|
||||
|
||||
class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
|
||||
__slots__ = ()
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
cursor = self._connection.cursor()
|
||||
|
||||
class AsyncAdapt_asyncmy_ss_cursor(
|
||||
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor
|
||||
):
|
||||
__slots__ = ()
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
self._rows = []
|
||||
|
||||
def _make_new_cursor(
|
||||
self, connection: AsyncIODBAPIConnection
|
||||
) -> AsyncIODBAPICursor:
|
||||
return connection.cursor(
|
||||
self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
return self._cursor.rowcount
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._cursor.arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._cursor.arraysize = value
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
return self._cursor.lastrowid
|
||||
|
||||
def close(self):
|
||||
# note we aren't actually closing the cursor here,
|
||||
# we are just letting GC do it. to allow this to be async
|
||||
# we would need the Result to change how it does "Safe close cursor".
|
||||
# MySQL "cursors" don't actually have state to be "closed" besides
|
||||
# exhausting rows, which we already have done for sync cursor.
|
||||
# another option would be to emulate aiosqlite dialect and assign
|
||||
# cursor only if we are doing server side cursor operation.
|
||||
self._rows[:] = []
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
return self.await_(self._execute_async(operation, parameters))
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
return self.await_(
|
||||
self._executemany_async(operation, seq_of_parameters)
|
||||
)
|
||||
|
||||
async def _execute_async(self, operation, parameters):
|
||||
async with self._adapt_connection._mutex_and_adapt_errors():
|
||||
if parameters is None:
|
||||
result = await self._cursor.execute(operation)
|
||||
else:
|
||||
result = await self._cursor.execute(operation, parameters)
|
||||
|
||||
class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
|
||||
if not self.server_side:
|
||||
# asyncmy has a "fake" async result, so we have to pull it out
|
||||
# of that here since our default result is not async.
|
||||
# we could just as easily grab "_rows" here and be done with it
|
||||
# but this is safer.
|
||||
self._rows = list(await self._cursor.fetchall())
|
||||
return result
|
||||
|
||||
async def _executemany_async(self, operation, seq_of_parameters):
|
||||
async with self._adapt_connection._mutex_and_adapt_errors():
|
||||
return await self._cursor.executemany(operation, seq_of_parameters)
|
||||
|
||||
def setinputsizes(self, *inputsizes):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
server_side = True
|
||||
|
||||
_cursor_cls = AsyncAdapt_asyncmy_cursor
|
||||
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
def _handle_exception(self, error: Exception) -> NoReturn:
|
||||
if isinstance(error, AttributeError):
|
||||
raise self.dbapi.InternalError(
|
||||
"network operation failed due to asyncmy attribute error"
|
||||
)
|
||||
cursor = self._connection.cursor(
|
||||
adapt_connection.dbapi.asyncmy.cursors.SSCursor
|
||||
)
|
||||
|
||||
raise error
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
|
||||
def ping(self, reconnect: bool) -> None:
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi", "_execute_mutex")
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self._execute_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _mutex_and_adapt_errors(self):
|
||||
async with self._execute_mutex:
|
||||
try:
|
||||
yield
|
||||
except AttributeError:
|
||||
raise self.dbapi.InternalError(
|
||||
"network operation failed due to asyncmy attribute error"
|
||||
)
|
||||
|
||||
def ping(self, reconnect):
|
||||
assert not reconnect
|
||||
return self.await_(self._do_ping())
|
||||
|
||||
async def _do_ping(self) -> None:
|
||||
try:
|
||||
async with self._execute_mutex:
|
||||
await self._connection.ping(False)
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
async def _do_ping(self):
|
||||
async with self._mutex_and_adapt_errors():
|
||||
return await self._connection.ping(False)
|
||||
|
||||
def character_set_name(self) -> Optional[str]:
|
||||
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
|
||||
def character_set_name(self):
|
||||
return self._connection.character_set_name()
|
||||
|
||||
def autocommit(self, value: Any) -> None:
|
||||
def autocommit(self, value):
|
||||
self.await_(self._connection.autocommit(value))
|
||||
|
||||
def get_autocommit(self) -> bool:
|
||||
return self._connection.get_autocommit() # type: ignore
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_asyncmy_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_asyncmy_cursor(self)
|
||||
|
||||
def terminate(self) -> None:
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def close(self):
|
||||
# it's not awaitable.
|
||||
self._connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.await_(self._connection.ensure_closed())
|
||||
|
||||
|
||||
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
|
||||
__slots__ = ()
|
||||
@@ -121,13 +232,18 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
|
||||
def __init__(self, asyncmy: ModuleType):
|
||||
def _Binary(x):
|
||||
"""Return x as a binary type."""
|
||||
return bytes(x)
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_dbapi:
|
||||
def __init__(self, asyncmy):
|
||||
self.asyncmy = asyncmy
|
||||
self.paramstyle = "format"
|
||||
self._init_dbapi_attributes()
|
||||
|
||||
def _init_dbapi_attributes(self) -> None:
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"Warning",
|
||||
"Error",
|
||||
@@ -148,9 +264,9 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
|
||||
BINARY = util.symbol("BINARY")
|
||||
DATETIME = util.symbol("DATETIME")
|
||||
TIMESTAMP = util.symbol("TIMESTAMP")
|
||||
Binary = staticmethod(bytes)
|
||||
Binary = staticmethod(_Binary)
|
||||
|
||||
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
|
||||
|
||||
@@ -174,14 +290,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
|
||||
_sscursor = AsyncAdapt_asyncmy_ss_cursor
|
||||
|
||||
is_async = True
|
||||
has_terminate = True
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
def import_dbapi(cls):
|
||||
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url: URL) -> type:
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
@@ -189,20 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
|
||||
dbapi_connection.terminate()
|
||||
|
||||
def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: E501
|
||||
def create_connect_args(self, url):
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=dict(username="user", database="db")
|
||||
)
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: DBAPIModule.Error,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
else:
|
||||
@@ -211,15 +318,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
|
||||
"not connected" in str_e or "network operation failed" in str_e
|
||||
)
|
||||
|
||||
def _found_rows_client_flag(self) -> int:
|
||||
from asyncmy.constants import CLIENT # type: ignore
|
||||
def _found_rows_client_flag(self):
|
||||
from asyncmy.constants import CLIENT
|
||||
|
||||
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
|
||||
return CLIENT.FOUND_ROWS
|
||||
|
||||
def get_driver_connection(
|
||||
self, connection: DBAPIConnection
|
||||
) -> AsyncIODBAPIConnection:
|
||||
return connection._connection # type: ignore[no-any-return]
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = MySQLDialect_asyncmy
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,10 @@
|
||||
# dialects/mysql/cymysql.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/cymysql.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
|
||||
@@ -20,36 +21,18 @@ r"""
|
||||
dialects are mysqlclient and PyMySQL.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Iterable
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .base import BIT
|
||||
from .base import MySQLDialect
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from .types import BIT
|
||||
from ... import util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.base import Connection
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...sql.type_api import _ResultProcessorType
|
||||
|
||||
|
||||
class _cymysqlBIT(BIT):
|
||||
def result_processor(
|
||||
self, dialect: Dialect, coltype: object
|
||||
) -> Optional[_ResultProcessorType[Any]]:
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Convert MySQL's 64 bit, variable length binary string to a long."""
|
||||
|
||||
def process(value: Optional[Iterable[int]]) -> Optional[int]:
|
||||
def process(value):
|
||||
if value is not None:
|
||||
v = 0
|
||||
for i in iter(value):
|
||||
@@ -72,22 +55,17 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
|
||||
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
def import_dbapi(cls):
|
||||
return __import__("cymysql")
|
||||
|
||||
def _detect_charset(self, connection: Connection) -> str:
|
||||
return connection.connection.charset # type: ignore[no-any-return]
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
|
||||
return exception.errno # type: ignore[no-any-return]
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: DBAPIModule.Error,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
if isinstance(e, self.loaded_dbapi.OperationalError):
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.OperationalError):
|
||||
return self._extract_error_code(e) in (
|
||||
2006,
|
||||
2013,
|
||||
@@ -95,7 +73,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
|
||||
2045,
|
||||
2055,
|
||||
)
|
||||
elif isinstance(e, self.loaded_dbapi.InterfaceError):
|
||||
elif isinstance(e, self.dbapi.InterfaceError):
|
||||
# if underlying connection is closed,
|
||||
# this is the error you get
|
||||
return True
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# dialects/mysql/dml.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/dml.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -7,7 +7,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
@@ -142,11 +141,7 @@ class Insert(StandardInsert):
|
||||
in :ref:`tutorial_parameter_ordered_updates`::
|
||||
|
||||
insert().on_duplicate_key_update(
|
||||
[
|
||||
("name", "some name"),
|
||||
("value", "some value"),
|
||||
]
|
||||
)
|
||||
[("name", "some name"), ("value", "some value")])
|
||||
|
||||
.. versionchanged:: 1.3 parameters can be specified as a dictionary
|
||||
or list of 2-tuples; the latter form provides for parameter
|
||||
@@ -186,7 +181,6 @@ class OnDuplicateClause(ClauseElement):
|
||||
|
||||
_parameter_ordering: Optional[List[str]] = None
|
||||
|
||||
update: Dict[str, Any]
|
||||
stringify_dialect = "mysql"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,51 +1,34 @@
|
||||
# dialects/mysql/enumerated.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/enumerated.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .types import _StringType
|
||||
from ... import exc
|
||||
from ... import sql
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
from ...sql import type_api
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.elements import ColumnElement
|
||||
from ...sql.type_api import _BindProcessorType
|
||||
from ...sql.type_api import _ResultProcessorType
|
||||
from ...sql.type_api import TypeEngine
|
||||
from ...sql.type_api import TypeEngineMixin
|
||||
|
||||
|
||||
class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
|
||||
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
|
||||
"""MySQL ENUM type."""
|
||||
|
||||
__visit_name__ = "ENUM"
|
||||
|
||||
native_enum = True
|
||||
|
||||
def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
|
||||
def __init__(self, *enums, **kw):
|
||||
"""Construct an ENUM.
|
||||
|
||||
E.g.::
|
||||
|
||||
Column("myenum", ENUM("foo", "bar", "baz"))
|
||||
Column('myenum', ENUM("foo", "bar", "baz"))
|
||||
|
||||
:param enums: The range of valid values for this ENUM. Values in
|
||||
enums are not quoted, they will be escaped and surrounded by single
|
||||
@@ -79,27 +62,21 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
|
||||
|
||||
"""
|
||||
kw.pop("strict", None)
|
||||
self._enum_init(enums, kw) # type: ignore[arg-type]
|
||||
self._enum_init(enums, kw)
|
||||
_StringType.__init__(self, length=self.length, **kw)
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(
|
||||
cls,
|
||||
impl: Union[TypeEngine[Any], TypeEngineMixin],
|
||||
**kw: Any,
|
||||
) -> ENUM:
|
||||
def adapt_emulated_to_native(cls, impl, **kw):
|
||||
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
|
||||
:class:`.Enum`.
|
||||
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(impl, ENUM)
|
||||
kw.setdefault("validate_strings", impl.validate_strings)
|
||||
kw.setdefault("values_callable", impl.values_callable)
|
||||
kw.setdefault("omit_aliases", impl._omit_aliases)
|
||||
return cls(**kw)
|
||||
|
||||
def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
|
||||
def _object_value_for_elem(self, elem):
|
||||
# mysql sends back a blank string for any value that
|
||||
# was persisted that was not in the enums; that is, it does no
|
||||
# validation on the incoming data, it "truncates" it to be
|
||||
@@ -109,27 +86,24 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
|
||||
else:
|
||||
return super()._object_value_for_elem(elem)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
|
||||
)
|
||||
|
||||
|
||||
# TODO: SET is a string as far as configuration but does not act like
|
||||
# a string at the python level. We either need to make a py-type agnostic
|
||||
# version of String as a base to be used for this, make this some kind of
|
||||
# TypeDecorator, or just vendor it out as its own type.
|
||||
class SET(_StringType):
|
||||
"""MySQL SET type."""
|
||||
|
||||
__visit_name__ = "SET"
|
||||
|
||||
def __init__(self, *values: str, **kw: Any):
|
||||
def __init__(self, *values, **kw):
|
||||
"""Construct a SET.
|
||||
|
||||
E.g.::
|
||||
|
||||
Column("myset", SET("foo", "bar", "baz"))
|
||||
Column('myset', SET("foo", "bar", "baz"))
|
||||
|
||||
|
||||
The list of potential values is required in the case that this
|
||||
set will be used to generate DDL for a table, or if the
|
||||
@@ -177,19 +151,17 @@ class SET(_StringType):
|
||||
"setting retrieve_as_bitwise=True"
|
||||
)
|
||||
if self.retrieve_as_bitwise:
|
||||
self._inversed_bitmap: Dict[str, int] = {
|
||||
self._bitmap = {
|
||||
value: 2**idx for idx, value in enumerate(self.values)
|
||||
}
|
||||
self._bitmap: Dict[int, str] = {
|
||||
2**idx: value for idx, value in enumerate(self.values)
|
||||
}
|
||||
self._bitmap.update(
|
||||
(2**idx, value) for idx, value in enumerate(self.values)
|
||||
)
|
||||
length = max([len(v) for v in values] + [0])
|
||||
kw.setdefault("length", length)
|
||||
super().__init__(**kw)
|
||||
|
||||
def column_expression(
|
||||
self, colexpr: ColumnElement[Any]
|
||||
) -> ColumnElement[Any]:
|
||||
def column_expression(self, colexpr):
|
||||
if self.retrieve_as_bitwise:
|
||||
return sql.type_coerce(
|
||||
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
|
||||
@@ -197,12 +169,10 @@ class SET(_StringType):
|
||||
else:
|
||||
return colexpr
|
||||
|
||||
def result_processor(
|
||||
self, dialect: Dialect, coltype: Any
|
||||
) -> Optional[_ResultProcessorType[Any]]:
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.retrieve_as_bitwise:
|
||||
|
||||
def process(value: Union[str, int, None]) -> Optional[Set[str]]:
|
||||
def process(value):
|
||||
if value is not None:
|
||||
value = int(value)
|
||||
|
||||
@@ -213,14 +183,11 @@ class SET(_StringType):
|
||||
else:
|
||||
super_convert = super().result_processor(dialect, coltype)
|
||||
|
||||
def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501
|
||||
def process(value):
|
||||
if isinstance(value, str):
|
||||
# MySQLdb returns a string, let's parse
|
||||
if super_convert:
|
||||
value = super_convert(value)
|
||||
assert value is not None
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(value, str)
|
||||
return set(re.findall(r"[^,]+", value))
|
||||
else:
|
||||
# mysql-connector-python does a naive
|
||||
@@ -231,48 +198,43 @@ class SET(_StringType):
|
||||
|
||||
return process
|
||||
|
||||
def bind_processor(
|
||||
self, dialect: Dialect
|
||||
) -> _BindProcessorType[Union[str, int]]:
|
||||
def bind_processor(self, dialect):
|
||||
super_convert = super().bind_processor(dialect)
|
||||
if self.retrieve_as_bitwise:
|
||||
|
||||
def process(
|
||||
value: Union[str, int, set[str], None],
|
||||
) -> Union[str, int, None]:
|
||||
def process(value):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, (int, str)):
|
||||
if super_convert:
|
||||
return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
|
||||
return super_convert(value)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
int_value = 0
|
||||
for v in value:
|
||||
int_value |= self._inversed_bitmap[v]
|
||||
int_value |= self._bitmap[v]
|
||||
return int_value
|
||||
|
||||
else:
|
||||
|
||||
def process(
|
||||
value: Union[str, int, set[str], None],
|
||||
) -> Union[str, int, None]:
|
||||
def process(value):
|
||||
# accept strings and int (actually bitflag) values directly
|
||||
if value is not None and not isinstance(value, (int, str)):
|
||||
value = ",".join(value)
|
||||
|
||||
if super_convert:
|
||||
return super_convert(value) # type: ignore
|
||||
return super_convert(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def adapt(self, cls: type, **kw: Any) -> Any:
|
||||
def adapt(self, impltype, **kw):
|
||||
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
|
||||
return util.constructor_copy(self, cls, *self.values, **kw)
|
||||
return util.constructor_copy(self, impltype, *self.values, **kw)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self,
|
||||
to_inspect=[SET, _StringType],
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
# dialects/mysql/expression.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ... import exc
|
||||
from ... import util
|
||||
@@ -20,7 +17,7 @@ from ...sql.base import Generative
|
||||
from ...util.typing import Self
|
||||
|
||||
|
||||
class match(Generative, elements.BinaryExpression[Any]):
|
||||
class match(Generative, elements.BinaryExpression):
|
||||
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
|
||||
|
||||
E.g.::
|
||||
@@ -40,9 +37,7 @@ class match(Generative, elements.BinaryExpression[Any]):
|
||||
.order_by(desc(match_expr))
|
||||
)
|
||||
|
||||
Would produce SQL resembling:
|
||||
|
||||
.. sourcecode:: sql
|
||||
Would produce SQL resembling::
|
||||
|
||||
SELECT id, firstname, lastname
|
||||
FROM user
|
||||
@@ -75,9 +70,8 @@ class match(Generative, elements.BinaryExpression[Any]):
|
||||
__visit_name__ = "mysql_match"
|
||||
|
||||
inherit_cache = True
|
||||
modifiers: util.immutabledict[str, Any]
|
||||
|
||||
def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
|
||||
def __init__(self, *cols, **kw):
|
||||
if not cols:
|
||||
raise exc.ArgumentError("columns are required")
|
||||
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
# dialects/mysql/json.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/json.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import types as sqltypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.type_api import _BindProcessorType
|
||||
from ...sql.type_api import _LiteralProcessorType
|
||||
|
||||
|
||||
class JSON(sqltypes.JSON):
|
||||
"""MySQL JSON type.
|
||||
@@ -42,13 +34,13 @@ class JSON(sqltypes.JSON):
|
||||
|
||||
|
||||
class _FormatTypeMixin:
|
||||
def _format_value(self, value: Any) -> str:
|
||||
def _format_value(self, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
|
||||
super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
|
||||
def bind_processor(self, dialect):
|
||||
super_proc = self.string_bind_processor(dialect)
|
||||
|
||||
def process(value: Any) -> Any:
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
@@ -56,31 +48,29 @@ class _FormatTypeMixin:
|
||||
|
||||
return process
|
||||
|
||||
def literal_processor(
|
||||
self, dialect: Dialect
|
||||
) -> _LiteralProcessorType[Any]:
|
||||
super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
|
||||
def literal_processor(self, dialect):
|
||||
super_proc = self.string_literal_processor(dialect)
|
||||
|
||||
def process(value: Any) -> str:
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value # type: ignore[no-any-return]
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
|
||||
def _format_value(self, value: Any) -> str:
|
||||
def _format_value(self, value):
|
||||
if isinstance(value, int):
|
||||
formatted_value = "$[%s]" % value
|
||||
value = "$[%s]" % value
|
||||
else:
|
||||
formatted_value = '$."%s"' % value
|
||||
return formatted_value
|
||||
value = '$."%s"' % value
|
||||
return value
|
||||
|
||||
|
||||
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
|
||||
def _format_value(self, value: Any) -> str:
|
||||
def _format_value(self, value):
|
||||
return "$%s" % (
|
||||
"".join(
|
||||
[
|
||||
|
||||
@@ -1,73 +1,32 @@
|
||||
# dialects/mysql/mariadb.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/mariadb.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
# mypy: ignore-errors
|
||||
from .base import MariaDBIdentifierPreparer
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from .base import MySQLTypeCompiler
|
||||
from ...sql import sqltypes
|
||||
|
||||
|
||||
class INET4(sqltypes.TypeEngine[str]):
|
||||
"""INET4 column type for MariaDB
|
||||
|
||||
.. versionadded:: 2.0.37
|
||||
"""
|
||||
|
||||
__visit_name__ = "INET4"
|
||||
|
||||
|
||||
class INET6(sqltypes.TypeEngine[str]):
|
||||
"""INET6 column type for MariaDB
|
||||
|
||||
.. versionadded:: 2.0.37
|
||||
"""
|
||||
|
||||
__visit_name__ = "INET6"
|
||||
|
||||
|
||||
class MariaDBTypeCompiler(MySQLTypeCompiler):
|
||||
def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
|
||||
return "INET4"
|
||||
|
||||
def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
|
||||
return "INET6"
|
||||
|
||||
|
||||
class MariaDBDialect(MySQLDialect):
|
||||
is_mariadb = True
|
||||
supports_statement_cache = True
|
||||
name = "mariadb"
|
||||
preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
|
||||
type_compiler_cls = MariaDBTypeCompiler
|
||||
preparer = MariaDBIdentifierPreparer
|
||||
|
||||
|
||||
def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
|
||||
dialect_mod = __import__(
|
||||
def loader(driver):
|
||||
driver_mod = __import__(
|
||||
"sqlalchemy.dialects.mysql.%s" % driver
|
||||
).dialects.mysql
|
||||
driver_cls = getattr(driver_mod, driver).dialect
|
||||
|
||||
driver_mod = getattr(dialect_mod, driver)
|
||||
if hasattr(driver_mod, "mariadb_dialect"):
|
||||
driver_cls = driver_mod.mariadb_dialect
|
||||
return driver_cls # type: ignore[no-any-return]
|
||||
else:
|
||||
driver_cls = driver_mod.dialect
|
||||
|
||||
return type(
|
||||
"MariaDBDialect_%s" % driver,
|
||||
(
|
||||
MariaDBDialect,
|
||||
driver_cls,
|
||||
),
|
||||
{"supports_statement_cache": True},
|
||||
)
|
||||
return type(
|
||||
"MariaDBDialect_%s" % driver,
|
||||
(
|
||||
MariaDBDialect,
|
||||
driver_cls,
|
||||
),
|
||||
{"supports_statement_cache": True},
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# dialects/mysql/mariadbconnector.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/mariadbconnector.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -27,15 +29,7 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
|
||||
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
from uuid import UUID as _python_UUID
|
||||
|
||||
from .base import MySQLCompiler
|
||||
@@ -45,19 +39,6 @@ from ... import sql
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.base import Connection
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...engine.interfaces import IsolationLevel
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...engine.url import URL
|
||||
from ...sql.compiler import SQLCompiler
|
||||
from ...sql.type_api import _ResultProcessorType
|
||||
|
||||
|
||||
mariadb_cpy_minimum_version = (1, 0, 1)
|
||||
|
||||
@@ -66,12 +47,10 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
|
||||
# work around JIRA issue
|
||||
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
|
||||
# this type can be removed.
|
||||
def result_processor(
|
||||
self, dialect: Dialect, coltype: object
|
||||
) -> Optional[_ResultProcessorType[Any]]:
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.as_uuid:
|
||||
|
||||
def process(value: Any) -> Any:
|
||||
def process(value):
|
||||
if value is not None:
|
||||
if hasattr(value, "decode"):
|
||||
value = value.decode("ascii")
|
||||
@@ -81,7 +60,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
|
||||
return process
|
||||
else:
|
||||
|
||||
def process(value: Any) -> Any:
|
||||
def process(value):
|
||||
if value is not None:
|
||||
if hasattr(value, "decode"):
|
||||
value = value.decode("ascii")
|
||||
@@ -92,27 +71,30 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
|
||||
|
||||
|
||||
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
|
||||
_lastrowid: Optional[int] = None
|
||||
_lastrowid = None
|
||||
|
||||
def create_server_side_cursor(self) -> DBAPICursor:
|
||||
def create_server_side_cursor(self):
|
||||
return self._dbapi_connection.cursor(buffered=False)
|
||||
|
||||
def create_default_cursor(self) -> DBAPICursor:
|
||||
def create_default_cursor(self):
|
||||
return self._dbapi_connection.cursor(buffered=True)
|
||||
|
||||
def post_exec(self) -> None:
|
||||
def post_exec(self):
|
||||
super().post_exec()
|
||||
|
||||
self._rowcount = self.cursor.rowcount
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(self.compiled, SQLCompiler)
|
||||
if self.isinsert and self.compiled.postfetch_lastrowid:
|
||||
self._lastrowid = self.cursor.lastrowid
|
||||
|
||||
def get_lastrowid(self) -> int:
|
||||
if TYPE_CHECKING:
|
||||
assert self._lastrowid is not None
|
||||
@property
|
||||
def rowcount(self):
|
||||
if self._rowcount is not None:
|
||||
return self._rowcount
|
||||
else:
|
||||
return self.cursor.rowcount
|
||||
|
||||
def get_lastrowid(self):
|
||||
return self._lastrowid
|
||||
|
||||
|
||||
@@ -151,7 +133,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def _dbapi_version(self) -> Tuple[int, ...]:
|
||||
def _dbapi_version(self):
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
return tuple(
|
||||
[
|
||||
@@ -164,7 +146,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
else:
|
||||
return (99, 99, 99)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.paramstyle = "qmark"
|
||||
if self.dbapi is not None:
|
||||
@@ -176,26 +158,20 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
def import_dbapi(cls):
|
||||
return __import__("mariadb")
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: DBAPIModule.Error,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
elif isinstance(e, self.loaded_dbapi.Error):
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
str_e = str(e).lower()
|
||||
return "not connected" in str_e or "isn't valid" in str_e
|
||||
else:
|
||||
return False
|
||||
|
||||
def create_connect_args(self, url: URL) -> ConnectArgsType:
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args()
|
||||
opts.update(url.query)
|
||||
|
||||
int_params = [
|
||||
"connect_timeout",
|
||||
@@ -210,7 +186,6 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
"ssl_verify_cert",
|
||||
"ssl",
|
||||
"pool_reset_connection",
|
||||
"compress",
|
||||
]
|
||||
|
||||
for key in int_params:
|
||||
@@ -230,21 +205,19 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
except (AttributeError, ImportError):
|
||||
self.supports_sane_rowcount = False
|
||||
opts["client_flag"] = client_flag
|
||||
return [], opts
|
||||
return [[], opts]
|
||||
|
||||
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
|
||||
def _extract_error_code(self, exception):
|
||||
try:
|
||||
rc: int = exception.errno
|
||||
rc = exception.errno
|
||||
except:
|
||||
rc = -1
|
||||
return rc
|
||||
|
||||
def _detect_charset(self, connection: Connection) -> str:
|
||||
def _detect_charset(self, connection):
|
||||
return "utf8mb4"
|
||||
|
||||
def get_isolation_level_values(
|
||||
self, dbapi_conn: DBAPIConnection
|
||||
) -> Sequence[IsolationLevel]:
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
@@ -253,26 +226,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
|
||||
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
|
||||
return bool(dbapi_conn.autocommit)
|
||||
|
||||
def set_isolation_level(
|
||||
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
|
||||
) -> None:
|
||||
def set_isolation_level(self, connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit = True
|
||||
connection.autocommit = True
|
||||
else:
|
||||
dbapi_connection.autocommit = False
|
||||
super().set_isolation_level(dbapi_connection, level)
|
||||
connection.autocommit = False
|
||||
super().set_isolation_level(connection, level)
|
||||
|
||||
def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
connection.execute(
|
||||
sql.text("XA BEGIN :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
|
||||
def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
connection.execute(
|
||||
sql.text("XA END :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
@@ -285,12 +253,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
)
|
||||
|
||||
def do_rollback_twophase(
|
||||
self,
|
||||
connection: Connection,
|
||||
xid: Any,
|
||||
is_prepared: bool = True,
|
||||
recover: bool = False,
|
||||
) -> None:
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if not is_prepared:
|
||||
connection.execute(
|
||||
sql.text("XA END :xid").bindparams(
|
||||
@@ -304,12 +268,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
)
|
||||
|
||||
def do_commit_twophase(
|
||||
self,
|
||||
connection: Connection,
|
||||
xid: Any,
|
||||
is_prepared: bool = True,
|
||||
recover: bool = False,
|
||||
) -> None:
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if not is_prepared:
|
||||
self.do_prepare_twophase(connection, xid)
|
||||
connection.execute(
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# dialects/mysql/mysqlconnector.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/mysqlconnector.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
@@ -13,85 +14,26 @@ r"""
|
||||
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: https://pypi.org/project/mysql-connector-python/
|
||||
|
||||
Driver Status
|
||||
-------------
|
||||
|
||||
MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the
|
||||
degree which the driver is functional. There are still ongoing issues
|
||||
with features such as server side cursors which remain disabled until
|
||||
upstream issues are repaired.
|
||||
|
||||
.. warning:: The MySQL Connector/Python driver published by Oracle is subject
|
||||
to frequent, major regressions of essential functionality such as being able
|
||||
to correctly persist simple binary strings which indicate it is not well
|
||||
tested. The SQLAlchemy project is not able to maintain this dialect fully as
|
||||
regressions in the driver prevent it from being included in continuous
|
||||
integration.
|
||||
|
||||
.. versionchanged:: 2.0.39
|
||||
|
||||
The MySQL Connector/Python dialect has been updated to support the
|
||||
latest version of this DBAPI. Previously, MySQL Connector/Python
|
||||
was not fully supported. However, support remains limited due to ongoing
|
||||
regressions introduced in this driver.
|
||||
|
||||
Connecting to MariaDB with MySQL Connector/Python
|
||||
--------------------------------------------------
|
||||
|
||||
MySQL Connector/Python may attempt to pass an incompatible collation to the
|
||||
database when connecting to MariaDB. Experimentation has shown that using
|
||||
``?charset=utf8mb4&collation=utfmb4_general_ci`` or similar MariaDB-compatible
|
||||
charset/collation will allow connectivity.
|
||||
.. note::
|
||||
|
||||
The MySQL Connector/Python DBAPI has had many issues since its release,
|
||||
some of which may remain unresolved, and the mysqlconnector dialect is
|
||||
**not tested as part of SQLAlchemy's continuous integration**.
|
||||
The recommended MySQL dialects are mysqlclient and PyMySQL.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .base import MariaDBIdentifierPreparer
|
||||
from .base import BIT
|
||||
from .base import MySQLCompiler
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from .mariadb import MariaDBDialect
|
||||
from .types import BIT
|
||||
from ... import util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ...engine.base import Connection
|
||||
from ...engine.cursor import CursorResult
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import IsolationLevel
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...engine.row import Row
|
||||
from ...engine.url import URL
|
||||
from ...sql.elements import BinaryExpression
|
||||
|
||||
|
||||
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
|
||||
def create_server_side_cursor(self) -> DBAPICursor:
|
||||
return self._dbapi_connection.cursor(buffered=False)
|
||||
|
||||
def create_default_cursor(self) -> DBAPICursor:
|
||||
return self._dbapi_connection.cursor(buffered=True)
|
||||
|
||||
|
||||
class MySQLCompiler_mysqlconnector(MySQLCompiler):
|
||||
def visit_mod_binary(
|
||||
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
|
||||
) -> str:
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return (
|
||||
self.process(binary.left, **kw)
|
||||
+ " % "
|
||||
@@ -99,37 +41,22 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
|
||||
)
|
||||
|
||||
|
||||
class IdentifierPreparerCommon_mysqlconnector:
|
||||
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
|
||||
@property
|
||||
def _double_percents(self) -> bool:
|
||||
def _double_percents(self):
|
||||
return False
|
||||
|
||||
@_double_percents.setter
|
||||
def _double_percents(self, value: Any) -> None:
|
||||
def _double_percents(self, value):
|
||||
pass
|
||||
|
||||
def _escape_identifier(self, value: str) -> str:
|
||||
value = value.replace(
|
||||
self.escape_quote, # type:ignore[attr-defined]
|
||||
self.escape_to_quote, # type:ignore[attr-defined]
|
||||
)
|
||||
def _escape_identifier(self, value):
|
||||
value = value.replace(self.escape_quote, self.escape_to_quote)
|
||||
return value
|
||||
|
||||
|
||||
class MySQLIdentifierPreparer_mysqlconnector(
|
||||
IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class MariaDBIdentifierPreparer_mysqlconnector(
|
||||
IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class _myconnpyBIT(BIT):
|
||||
def result_processor(self, dialect: Any, coltype: Any) -> None:
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""MySQL-connector already converts mysql bits, so."""
|
||||
|
||||
return None
|
||||
@@ -144,31 +71,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
supports_native_bit = True
|
||||
|
||||
# not until https://bugs.mysql.com/bug.php?id=117548
|
||||
supports_server_side_cursors = False
|
||||
|
||||
default_paramstyle = "format"
|
||||
statement_compiler = MySQLCompiler_mysqlconnector
|
||||
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
|
||||
|
||||
preparer: type[MySQLIdentifierPreparer] = (
|
||||
MySQLIdentifierPreparer_mysqlconnector
|
||||
)
|
||||
preparer = MySQLIdentifierPreparer_mysqlconnector
|
||||
|
||||
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
return cast("DBAPIModule", __import__("mysql.connector").connector)
|
||||
def import_dbapi(cls):
|
||||
from mysql import connector
|
||||
|
||||
def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
|
||||
return connector
|
||||
|
||||
def do_ping(self, dbapi_connection):
|
||||
dbapi_connection.ping(False)
|
||||
return True
|
||||
|
||||
def create_connect_args(self, url: URL) -> ConnectArgsType:
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username="user")
|
||||
|
||||
opts.update(url.query)
|
||||
@@ -176,7 +96,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
util.coerce_kw_type(opts, "allow_local_infile", bool)
|
||||
util.coerce_kw_type(opts, "autocommit", bool)
|
||||
util.coerce_kw_type(opts, "buffered", bool)
|
||||
util.coerce_kw_type(opts, "client_flag", int)
|
||||
util.coerce_kw_type(opts, "compress", bool)
|
||||
util.coerce_kw_type(opts, "connection_timeout", int)
|
||||
util.coerce_kw_type(opts, "connect_timeout", int)
|
||||
@@ -191,21 +110,15 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
util.coerce_kw_type(opts, "use_pure", bool)
|
||||
util.coerce_kw_type(opts, "use_unicode", bool)
|
||||
|
||||
# note that "buffered" is set to False by default in MySQL/connector
|
||||
# python. If you set it to True, then there is no way to get a server
|
||||
# side cursor because the logic is written to disallow that.
|
||||
|
||||
# leaving this at True until
|
||||
# https://bugs.mysql.com/bug.php?id=117548 can be fixed
|
||||
opts["buffered"] = True
|
||||
# unfortunately, MySQL/connector python refuses to release a
|
||||
# cursor without reading fully, so non-buffered isn't an option
|
||||
opts.setdefault("buffered", True)
|
||||
|
||||
# FOUND_ROWS must be set in ClientFlag to enable
|
||||
# supports_sane_rowcount.
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
from mysql.connector import constants # type: ignore
|
||||
|
||||
ClientFlag = constants.ClientFlag
|
||||
from mysql.connector.constants import ClientFlag
|
||||
|
||||
client_flags = opts.get(
|
||||
"client_flags", ClientFlag.get_default()
|
||||
@@ -214,35 +127,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
opts["client_flags"] = client_flags
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return [], opts
|
||||
return [[], opts]
|
||||
|
||||
@util.memoized_property
|
||||
def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]:
|
||||
def _mysqlconnector_version_info(self):
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
|
||||
return None
|
||||
|
||||
def _detect_charset(self, connection: Connection) -> str:
|
||||
return connection.connection.charset # type: ignore
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception: BaseException) -> int:
|
||||
return exception.errno # type: ignore
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: Exception,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
|
||||
exceptions = (
|
||||
self.loaded_dbapi.OperationalError, #
|
||||
self.loaded_dbapi.InterfaceError,
|
||||
self.loaded_dbapi.ProgrammingError,
|
||||
)
|
||||
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
|
||||
if isinstance(e, exceptions):
|
||||
return (
|
||||
e.errno in errnos
|
||||
@@ -252,51 +154,26 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _compat_fetchall(
|
||||
self,
|
||||
rp: CursorResult[Tuple[Any, ...]],
|
||||
charset: Optional[str] = None,
|
||||
) -> Sequence[Row[Tuple[Any, ...]]]:
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(
|
||||
self,
|
||||
rp: CursorResult[Tuple[Any, ...]],
|
||||
charset: Optional[str] = None,
|
||||
) -> Optional[Row[Tuple[Any, ...]]]:
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
return rp.fetchone()
|
||||
|
||||
def get_isolation_level_values(
|
||||
self, dbapi_conn: DBAPIConnection
|
||||
) -> Sequence[IsolationLevel]:
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
_isolation_lookup = {
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
}
|
||||
|
||||
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
|
||||
return bool(dbapi_conn.autocommit)
|
||||
|
||||
def set_isolation_level(
|
||||
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
|
||||
) -> None:
|
||||
def _set_isolation_level(self, connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit = True
|
||||
connection.autocommit = True
|
||||
else:
|
||||
dbapi_connection.autocommit = False
|
||||
super().set_isolation_level(dbapi_connection, level)
|
||||
|
||||
|
||||
class MariaDBDialect_mysqlconnector(
|
||||
MariaDBDialect, MySQLDialect_mysqlconnector
|
||||
):
|
||||
supports_statement_cache = True
|
||||
_allows_uuid_binds = False
|
||||
preparer = MariaDBIdentifierPreparer_mysqlconnector
|
||||
connection.autocommit = False
|
||||
super()._set_isolation_level(connection, level)
|
||||
|
||||
|
||||
dialect = MySQLDialect_mysqlconnector
|
||||
mariadb_dialect = MariaDBDialect_mysqlconnector
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# dialects/mysql/mysqldb.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/mysqldb.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -46,9 +48,9 @@ key "ssl", which may be specified using the
|
||||
"ssl": {
|
||||
"ca": "/home/gord/client-ssl/ca.pem",
|
||||
"cert": "/home/gord/client-ssl/client-cert.pem",
|
||||
"key": "/home/gord/client-ssl/client-key.pem",
|
||||
"key": "/home/gord/client-ssl/client-key.pem"
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
For convenience, the following keys may also be specified inline within the URL
|
||||
@@ -72,9 +74,7 @@ Using MySQLdb with Google Cloud SQL
|
||||
-----------------------------------
|
||||
|
||||
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
|
||||
using a URL like the following:
|
||||
|
||||
.. sourcecode:: text
|
||||
using a URL like the following::
|
||||
|
||||
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
|
||||
|
||||
@@ -84,39 +84,25 @@ Server Side Cursors
|
||||
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import MySQLCompiler
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from .base import TEXT
|
||||
from ... import sql
|
||||
from ... import util
|
||||
from ...util.typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ...engine.base import Connection
|
||||
from ...engine.interfaces import _DBAPIMultiExecuteParams
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import ExecutionContext
|
||||
from ...engine.interfaces import IsolationLevel
|
||||
from ...engine.url import URL
|
||||
|
||||
|
||||
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
|
||||
pass
|
||||
@property
|
||||
def rowcount(self):
|
||||
if hasattr(self, "_rowcount"):
|
||||
return self._rowcount
|
||||
else:
|
||||
return self.cursor.rowcount
|
||||
|
||||
|
||||
class MySQLCompiler_mysqldb(MySQLCompiler):
|
||||
@@ -136,9 +122,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqldb
|
||||
statement_compiler = MySQLCompiler_mysqldb
|
||||
preparer = MySQLIdentifierPreparer
|
||||
server_version_info: Tuple[int, ...]
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._mysql_dbapi_version = (
|
||||
self._parse_dbapi_version(self.dbapi.__version__)
|
||||
@@ -146,7 +131,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
else (0, 0, 0)
|
||||
)
|
||||
|
||||
def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]:
|
||||
def _parse_dbapi_version(self, version):
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
|
||||
@@ -154,7 +139,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
return (0, 0, 0)
|
||||
|
||||
@util.langhelpers.memoized_property
|
||||
def supports_server_side_cursors(self) -> bool:
|
||||
def supports_server_side_cursors(self):
|
||||
try:
|
||||
cursors = __import__("MySQLdb.cursors").cursors
|
||||
self._sscursor = cursors.SSCursor
|
||||
@@ -163,13 +148,13 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
def import_dbapi(cls):
|
||||
return __import__("MySQLdb")
|
||||
|
||||
def on_connect(self) -> Callable[[DBAPIConnection], None]:
|
||||
def on_connect(self):
|
||||
super_ = super().on_connect()
|
||||
|
||||
def on_connect(conn: DBAPIConnection) -> None:
|
||||
def on_connect(conn):
|
||||
if super_ is not None:
|
||||
super_(conn)
|
||||
|
||||
@@ -182,24 +167,43 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
|
||||
return on_connect
|
||||
|
||||
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
|
||||
def do_ping(self, dbapi_connection):
|
||||
dbapi_connection.ping()
|
||||
return True
|
||||
|
||||
def do_executemany(
|
||||
self,
|
||||
cursor: DBAPICursor,
|
||||
statement: str,
|
||||
parameters: _DBAPIMultiExecuteParams,
|
||||
context: Optional[ExecutionContext] = None,
|
||||
) -> None:
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
rowcount = cursor.executemany(statement, parameters)
|
||||
if context is not None:
|
||||
cast(MySQLExecutionContext, context)._rowcount = rowcount
|
||||
context._rowcount = rowcount
|
||||
|
||||
def create_connect_args(
|
||||
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
|
||||
) -> ConnectArgsType:
|
||||
def _check_unicode_returns(self, connection):
|
||||
# work around issue fixed in
|
||||
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
|
||||
# specific issue w/ the utf8mb4_bin collation and unicode returns
|
||||
|
||||
collation = connection.exec_driver_sql(
|
||||
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
|
||||
% (
|
||||
self.identifier_preparer.quote("Charset"),
|
||||
self.identifier_preparer.quote("Collation"),
|
||||
)
|
||||
).scalar()
|
||||
has_utf8mb4_bin = self.server_version_info > (5,) and collation
|
||||
if has_utf8mb4_bin:
|
||||
additional_tests = [
|
||||
sql.collate(
|
||||
sql.cast(
|
||||
sql.literal_column("'test collated returns'"),
|
||||
TEXT(charset="utf8mb4"),
|
||||
),
|
||||
"utf8mb4_bin",
|
||||
)
|
||||
]
|
||||
else:
|
||||
additional_tests = []
|
||||
return super()._check_unicode_returns(connection, additional_tests)
|
||||
|
||||
def create_connect_args(self, url, _translate_args=None):
|
||||
if _translate_args is None:
|
||||
_translate_args = dict(
|
||||
database="db", username="user", password="passwd"
|
||||
@@ -213,7 +217,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
util.coerce_kw_type(opts, "read_timeout", int)
|
||||
util.coerce_kw_type(opts, "write_timeout", int)
|
||||
util.coerce_kw_type(opts, "client_flag", int)
|
||||
util.coerce_kw_type(opts, "local_infile", bool)
|
||||
util.coerce_kw_type(opts, "local_infile", int)
|
||||
# Note: using either of the below will cause all strings to be
|
||||
# returned as Unicode, both in raw SQL operations and with column
|
||||
# types like String and MSString.
|
||||
@@ -248,9 +252,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
if client_flag_found_rows is not None:
|
||||
client_flag |= client_flag_found_rows
|
||||
opts["client_flag"] = client_flag
|
||||
return [], opts
|
||||
return [[], opts]
|
||||
|
||||
def _found_rows_client_flag(self) -> Optional[int]:
|
||||
def _found_rows_client_flag(self):
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
CLIENT_FLAGS = __import__(
|
||||
@@ -259,23 +263,20 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
except (AttributeError, ImportError):
|
||||
return None
|
||||
else:
|
||||
return CLIENT_FLAGS.FOUND_ROWS # type: ignore
|
||||
return CLIENT_FLAGS.FOUND_ROWS
|
||||
else:
|
||||
return None
|
||||
|
||||
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
|
||||
return exception.args[0] # type: ignore[no-any-return]
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.args[0]
|
||||
|
||||
def _detect_charset(self, connection: Connection) -> str:
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
try:
|
||||
# note: the SQL here would be
|
||||
# "SHOW VARIABLES LIKE 'character_set%%'"
|
||||
|
||||
cset_name: Callable[[], str] = (
|
||||
connection.connection.character_set_name
|
||||
)
|
||||
cset_name = connection.connection.character_set_name
|
||||
except AttributeError:
|
||||
util.warn(
|
||||
"No 'character_set_name' can be detected with "
|
||||
@@ -287,9 +288,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
else:
|
||||
return cset_name()
|
||||
|
||||
def get_isolation_level_values(
|
||||
self, dbapi_conn: DBAPIConnection
|
||||
) -> Tuple[IsolationLevel, ...]:
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
@@ -298,12 +297,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
|
||||
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
|
||||
return dbapi_conn.get_autocommit() # type: ignore[no-any-return]
|
||||
|
||||
def set_isolation_level(
|
||||
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
|
||||
) -> None:
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit(True)
|
||||
else:
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
# dialects/mysql/provision.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import exc
|
||||
from ...testing.provision import configure_follower
|
||||
from ...testing.provision import create_db
|
||||
@@ -39,13 +34,6 @@ def generate_driver_url(url, driver, query_str):
|
||||
drivername="%s+%s" % (backend, driver)
|
||||
).update_query_string(query_str)
|
||||
|
||||
if driver == "mariadbconnector":
|
||||
new_url = new_url.difference_update_query(["charset"])
|
||||
elif driver == "mysqlconnector":
|
||||
new_url = new_url.update_query_pairs(
|
||||
[("collation", "utf8mb4_general_ci")]
|
||||
)
|
||||
|
||||
try:
|
||||
new_url.get_dialect()
|
||||
except exc.NoSuchModuleError:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# dialects/mysql/pymysql.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/pymysql.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
|
||||
@@ -39,6 +41,7 @@ necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
|
||||
"&ssl_check_hostname=false"
|
||||
)
|
||||
|
||||
|
||||
MySQL-Python Compatibility
|
||||
--------------------------
|
||||
|
||||
@@ -47,26 +50,9 @@ and targets 100% compatibility. Most behavioral notes for MySQL-python apply
|
||||
to the pymysql driver as well.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from ...util import langhelpers
|
||||
from ...util.typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ...engine.interfaces import ConnectArgsType
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import DBAPICursor
|
||||
from ...engine.interfaces import DBAPIModule
|
||||
from ...engine.interfaces import PoolProxiedConnection
|
||||
from ...engine.url import URL
|
||||
|
||||
|
||||
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
@@ -76,7 +62,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
description_encoding = None
|
||||
|
||||
@langhelpers.memoized_property
|
||||
def supports_server_side_cursors(self) -> bool:
|
||||
def supports_server_side_cursors(self):
|
||||
try:
|
||||
cursors = __import__("pymysql.cursors").cursors
|
||||
self._sscursor = cursors.SSCursor
|
||||
@@ -85,11 +71,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls) -> DBAPIModule:
|
||||
def import_dbapi(cls):
|
||||
return __import__("pymysql")
|
||||
|
||||
@langhelpers.memoized_property
|
||||
def _send_false_to_ping(self) -> bool:
|
||||
def _send_false_to_ping(self):
|
||||
"""determine if pymysql has deprecated, changed the default of,
|
||||
or removed the 'reconnect' argument of connection.ping().
|
||||
|
||||
@@ -100,9 +86,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
""" # noqa: E501
|
||||
|
||||
try:
|
||||
Connection = __import__(
|
||||
"pymysql.connections"
|
||||
).connections.Connection
|
||||
Connection = __import__("pymysql.connections").Connection
|
||||
except (ImportError, AttributeError):
|
||||
return True
|
||||
else:
|
||||
@@ -116,7 +100,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
not insp.defaults or insp.defaults[0] is not False
|
||||
)
|
||||
|
||||
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
|
||||
def do_ping(self, dbapi_connection):
|
||||
if self._send_false_to_ping:
|
||||
dbapi_connection.ping(False)
|
||||
else:
|
||||
@@ -124,24 +108,17 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
|
||||
return True
|
||||
|
||||
def create_connect_args(
|
||||
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
|
||||
) -> ConnectArgsType:
|
||||
def create_connect_args(self, url, _translate_args=None):
|
||||
if _translate_args is None:
|
||||
_translate_args = dict(username="user")
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=_translate_args
|
||||
)
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: DBAPIModule.Error,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
elif isinstance(e, self.loaded_dbapi.Error):
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
str_e = str(e).lower()
|
||||
return (
|
||||
"already closed" in str_e or "connection was killed" in str_e
|
||||
@@ -149,7 +126,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _extract_error_code(self, exception: BaseException) -> Any:
|
||||
def _extract_error_code(self, exception):
|
||||
if isinstance(exception.args[0], Exception):
|
||||
exception = exception.args[0]
|
||||
return exception.args[0]
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# dialects/mysql/pyodbc.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/pyodbc.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
|
||||
|
||||
.. dialect:: mysql+pyodbc
|
||||
:name: PyODBC
|
||||
:dbapi: pyodbc
|
||||
@@ -28,30 +30,21 @@ r"""
|
||||
Pass through exact pyodbc connection string::
|
||||
|
||||
import urllib
|
||||
|
||||
connection_string = (
|
||||
"DRIVER=MySQL ODBC 8.0 ANSI Driver;"
|
||||
"SERVER=localhost;"
|
||||
"PORT=3307;"
|
||||
"DATABASE=mydb;"
|
||||
"UID=root;"
|
||||
"PWD=(whatever);"
|
||||
"charset=utf8mb4;"
|
||||
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
|
||||
'SERVER=localhost;'
|
||||
'PORT=3307;'
|
||||
'DATABASE=mydb;'
|
||||
'UID=root;'
|
||||
'PWD=(whatever);'
|
||||
'charset=utf8mb4;'
|
||||
)
|
||||
params = urllib.parse.quote_plus(connection_string)
|
||||
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
@@ -61,31 +54,23 @@ from ... import util
|
||||
from ...connectors.pyodbc import PyODBCConnector
|
||||
from ...sql.sqltypes import Time
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine import Connection
|
||||
from ...engine.interfaces import DBAPIConnection
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.type_api import _ResultProcessorType
|
||||
|
||||
|
||||
class _pyodbcTIME(TIME):
|
||||
def result_processor(
|
||||
self, dialect: Dialect, coltype: object
|
||||
) -> _ResultProcessorType[datetime.time]:
|
||||
def process(value: Any) -> Union[datetime.time, None]:
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
# pyodbc returns a datetime.time object; no need to convert
|
||||
return value # type: ignore[no-any-return]
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
|
||||
def get_lastrowid(self) -> int:
|
||||
def get_lastrowid(self):
|
||||
cursor = self.create_cursor()
|
||||
cursor.execute("SELECT LAST_INSERT_ID()")
|
||||
lastrowid = cursor.fetchone()[0] # type: ignore[index]
|
||||
lastrowid = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
return lastrowid # type: ignore[no-any-return]
|
||||
return lastrowid
|
||||
|
||||
|
||||
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
|
||||
@@ -96,7 +81,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
|
||||
|
||||
pyodbc_driver_name = "MySQL"
|
||||
|
||||
def _detect_charset(self, connection: Connection) -> str:
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
# Prefer 'character_set_results' for the current connection over the
|
||||
@@ -121,25 +106,21 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
|
||||
)
|
||||
return "latin1"
|
||||
|
||||
def _get_server_version_info(
|
||||
self, connection: Connection
|
||||
) -> Tuple[int, ...]:
|
||||
def _get_server_version_info(self, connection):
|
||||
return MySQLDialect._get_server_version_info(self, connection)
|
||||
|
||||
def _extract_error_code(self, exception: BaseException) -> Optional[int]:
|
||||
def _extract_error_code(self, exception):
|
||||
m = re.compile(r"\((\d+)\)").search(str(exception.args))
|
||||
if m is None:
|
||||
return None
|
||||
c: Optional[str] = m.group(1)
|
||||
c = m.group(1)
|
||||
if c:
|
||||
return int(c)
|
||||
else:
|
||||
return None
|
||||
|
||||
def on_connect(self) -> Callable[[DBAPIConnection], None]:
|
||||
def on_connect(self):
|
||||
super_ = super().on_connect()
|
||||
|
||||
def on_connect(conn: DBAPIConnection) -> None:
|
||||
def on_connect(conn):
|
||||
if super_ is not None:
|
||||
super_(conn)
|
||||
|
||||
|
||||
@@ -1,65 +1,46 @@
|
||||
# dialects/mysql/reflection.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/reflection.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .enumerated import ENUM
|
||||
from .enumerated import SET
|
||||
from .types import DATETIME
|
||||
from .types import TIME
|
||||
from .types import TIMESTAMP
|
||||
from ... import log
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...util.typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from ...engine.interfaces import ReflectedColumn
|
||||
|
||||
|
||||
class ReflectedState:
|
||||
"""Stores raw information about a SHOW CREATE TABLE statement."""
|
||||
|
||||
charset: Optional[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.columns: List[ReflectedColumn] = []
|
||||
self.table_options: Dict[str, str] = {}
|
||||
self.table_name: Optional[str] = None
|
||||
self.keys: List[Dict[str, Any]] = []
|
||||
self.fk_constraints: List[Dict[str, Any]] = []
|
||||
self.ck_constraints: List[Dict[str, Any]] = []
|
||||
def __init__(self):
|
||||
self.columns = []
|
||||
self.table_options = {}
|
||||
self.table_name = None
|
||||
self.keys = []
|
||||
self.fk_constraints = []
|
||||
self.ck_constraints = []
|
||||
|
||||
|
||||
@log.class_logger
|
||||
class MySQLTableDefinitionParser:
|
||||
"""Parses the results of a SHOW CREATE TABLE statement."""
|
||||
|
||||
def __init__(
|
||||
self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
|
||||
):
|
||||
def __init__(self, dialect, preparer):
|
||||
self.dialect = dialect
|
||||
self.preparer = preparer
|
||||
self._prep_regexes()
|
||||
|
||||
def parse(
|
||||
self, show_create: str, charset: Optional[str]
|
||||
) -> ReflectedState:
|
||||
def parse(self, show_create, charset):
|
||||
state = ReflectedState()
|
||||
state.charset = charset
|
||||
for line in re.split(r"\r?\n", show_create):
|
||||
@@ -84,11 +65,11 @@ class MySQLTableDefinitionParser:
|
||||
if type_ is None:
|
||||
util.warn("Unknown schema content: %r" % line)
|
||||
elif type_ == "key":
|
||||
state.keys.append(spec) # type: ignore[arg-type]
|
||||
state.keys.append(spec)
|
||||
elif type_ == "fk_constraint":
|
||||
state.fk_constraints.append(spec) # type: ignore[arg-type]
|
||||
state.fk_constraints.append(spec)
|
||||
elif type_ == "ck_constraint":
|
||||
state.ck_constraints.append(spec) # type: ignore[arg-type]
|
||||
state.ck_constraints.append(spec)
|
||||
else:
|
||||
pass
|
||||
return state
|
||||
@@ -96,13 +77,7 @@ class MySQLTableDefinitionParser:
|
||||
def _check_view(self, sql: str) -> bool:
|
||||
return bool(self._re_is_view.match(sql))
|
||||
|
||||
def _parse_constraints(self, line: str) -> Union[
|
||||
Tuple[None, str],
|
||||
Tuple[Literal["partition"], str],
|
||||
Tuple[
|
||||
Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str]
|
||||
],
|
||||
]:
|
||||
def _parse_constraints(self, line):
|
||||
"""Parse a KEY or CONSTRAINT line.
|
||||
|
||||
:param line: A line of SHOW CREATE TABLE output
|
||||
@@ -152,7 +127,7 @@ class MySQLTableDefinitionParser:
|
||||
# No match.
|
||||
return (None, line)
|
||||
|
||||
def _parse_table_name(self, line: str, state: ReflectedState) -> None:
|
||||
def _parse_table_name(self, line, state):
|
||||
"""Extract the table name.
|
||||
|
||||
:param line: The first line of SHOW CREATE TABLE
|
||||
@@ -163,7 +138,7 @@ class MySQLTableDefinitionParser:
|
||||
if m:
|
||||
state.table_name = cleanup(m.group("name"))
|
||||
|
||||
def _parse_table_options(self, line: str, state: ReflectedState) -> None:
|
||||
def _parse_table_options(self, line, state):
|
||||
"""Build a dictionary of all reflected table-level options.
|
||||
|
||||
:param line: The final line of SHOW CREATE TABLE output.
|
||||
@@ -189,9 +164,7 @@ class MySQLTableDefinitionParser:
|
||||
for opt, val in options.items():
|
||||
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
|
||||
|
||||
def _parse_partition_options(
|
||||
self, line: str, state: ReflectedState
|
||||
) -> None:
|
||||
def _parse_partition_options(self, line, state):
|
||||
options = {}
|
||||
new_line = line[:]
|
||||
|
||||
@@ -247,7 +220,7 @@ class MySQLTableDefinitionParser:
|
||||
else:
|
||||
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
|
||||
|
||||
def _parse_column(self, line: str, state: ReflectedState) -> None:
|
||||
def _parse_column(self, line, state):
|
||||
"""Extract column details.
|
||||
|
||||
Falls back to a 'minimal support' variant if full parse fails.
|
||||
@@ -310,16 +283,13 @@ class MySQLTableDefinitionParser:
|
||||
|
||||
type_instance = col_type(*type_args, **type_kw)
|
||||
|
||||
col_kw: Dict[str, Any] = {}
|
||||
col_kw = {}
|
||||
|
||||
# NOT NULL
|
||||
col_kw["nullable"] = True
|
||||
# this can be "NULL" in the case of TIMESTAMP
|
||||
if spec.get("notnull", False) == "NOT NULL":
|
||||
col_kw["nullable"] = False
|
||||
# For generated columns, the nullability is marked in a different place
|
||||
if spec.get("notnull_generated", False) == "NOT NULL":
|
||||
col_kw["nullable"] = False
|
||||
|
||||
# AUTO_INCREMENT
|
||||
if spec.get("autoincr", False):
|
||||
@@ -351,13 +321,9 @@ class MySQLTableDefinitionParser:
|
||||
name=name, type=type_instance, default=default, comment=comment
|
||||
)
|
||||
col_d.update(col_kw)
|
||||
state.columns.append(col_d) # type: ignore[arg-type]
|
||||
state.columns.append(col_d)
|
||||
|
||||
def _describe_to_create(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: Sequence[Tuple[str, str, str, str, str, str]],
|
||||
) -> str:
|
||||
def _describe_to_create(self, table_name, columns):
|
||||
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
|
||||
|
||||
DESCRIBE is a much simpler reflection and is sufficient for
|
||||
@@ -410,9 +376,7 @@ class MySQLTableDefinitionParser:
|
||||
]
|
||||
)
|
||||
|
||||
def _parse_keyexprs(
|
||||
self, identifiers: str
|
||||
) -> List[Tuple[str, Optional[int], str]]:
|
||||
def _parse_keyexprs(self, identifiers):
|
||||
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
|
||||
|
||||
return [
|
||||
@@ -422,12 +386,11 @@ class MySQLTableDefinitionParser:
|
||||
)
|
||||
]
|
||||
|
||||
def _prep_regexes(self) -> None:
|
||||
def _prep_regexes(self):
|
||||
"""Pre-compile regular expressions."""
|
||||
|
||||
self._pr_options: List[
|
||||
Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
|
||||
] = []
|
||||
self._re_columns = []
|
||||
self._pr_options = []
|
||||
|
||||
_final = self.preparer.final_quote
|
||||
|
||||
@@ -485,13 +448,11 @@ class MySQLTableDefinitionParser:
|
||||
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
|
||||
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
|
||||
r"(?: +DEFAULT +(?P<default>"
|
||||
r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
|
||||
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
|
||||
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
|
||||
r"))?"
|
||||
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
|
||||
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
|
||||
r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
|
||||
r")?"
|
||||
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
|
||||
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
|
||||
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
|
||||
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
|
||||
@@ -539,7 +500,7 @@ class MySQLTableDefinitionParser:
|
||||
#
|
||||
# unique constraints come back as KEYs
|
||||
kw = quotes.copy()
|
||||
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT"
|
||||
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
|
||||
self._re_fk_constraint = _re_compile(
|
||||
r" "
|
||||
r"CONSTRAINT +"
|
||||
@@ -616,21 +577,21 @@ class MySQLTableDefinitionParser:
|
||||
|
||||
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
|
||||
|
||||
def _add_option_string(self, directive: str) -> None:
|
||||
def _add_option_string(self, directive):
|
||||
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
self._pr_options.append(_pr_compile(regex, cleanup_text))
|
||||
|
||||
def _add_option_word(self, directive: str) -> None:
|
||||
def _add_option_word(self, directive):
|
||||
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
self._pr_options.append(_pr_compile(regex))
|
||||
|
||||
def _add_partition_option_word(self, directive: str) -> None:
|
||||
def _add_partition_option_word(self, directive):
|
||||
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
|
||||
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
|
||||
re.escape(directive),
|
||||
@@ -645,7 +606,7 @@ class MySQLTableDefinitionParser:
|
||||
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
|
||||
self._pr_options.append(_pr_compile(regex))
|
||||
|
||||
def _add_option_regex(self, directive: str, regex: str) -> None:
|
||||
def _add_option_regex(self, directive, regex):
|
||||
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
@@ -663,35 +624,21 @@ _options_of_type_string = (
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def _pr_compile(
|
||||
regex: str, cleanup: Callable[[str], str]
|
||||
) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _pr_compile(
|
||||
regex: str, cleanup: None = None
|
||||
) -> Tuple[re.Pattern[Any], None]: ...
|
||||
|
||||
|
||||
def _pr_compile(
|
||||
regex: str, cleanup: Optional[Callable[[str], str]] = None
|
||||
) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
|
||||
def _pr_compile(regex, cleanup=None):
|
||||
"""Prepare a 2-tuple of compiled regex and callable."""
|
||||
|
||||
return (_re_compile(regex), cleanup)
|
||||
|
||||
|
||||
def _re_compile(regex: str) -> re.Pattern[Any]:
|
||||
def _re_compile(regex):
|
||||
"""Compile a string to regex, I and UNICODE."""
|
||||
|
||||
return re.compile(regex, re.I | re.UNICODE)
|
||||
|
||||
|
||||
def _strip_values(values: Sequence[str]) -> List[str]:
|
||||
def _strip_values(values):
|
||||
"Strip reflected values quotes"
|
||||
strip_values: List[str] = []
|
||||
strip_values = []
|
||||
for a in values:
|
||||
if a[0:1] == '"' or a[0:1] == "'":
|
||||
# strip enclosing quotes and unquote interior
|
||||
@@ -703,9 +650,7 @@ def _strip_values(values: Sequence[str]) -> List[str]:
|
||||
def cleanup_text(raw_text: str) -> str:
|
||||
if "\\" in raw_text:
|
||||
raw_text = re.sub(
|
||||
_control_char_regexp,
|
||||
lambda s: _control_char_map[s[0]], # type: ignore[index]
|
||||
raw_text,
|
||||
_control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
|
||||
)
|
||||
return raw_text.replace("''", "'")
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# dialects/mysql/reserved_words.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/reserved_words.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -11,6 +11,7 @@
|
||||
# https://mariadb.com/kb/en/reserved-words/
|
||||
# includes: Reserved Words, Oracle Mode (separate set unioned)
|
||||
# excludes: Exceptions, Function Names
|
||||
# mypy: ignore-errors
|
||||
|
||||
RESERVED_WORDS_MARIADB = {
|
||||
"accessible",
|
||||
@@ -281,7 +282,6 @@ RESERVED_WORDS_MARIADB = {
|
||||
}
|
||||
)
|
||||
|
||||
# https://dev.mysql.com/doc/refman/8.3/en/keywords.html
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
|
||||
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
|
||||
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
|
||||
@@ -403,7 +403,6 @@ RESERVED_WORDS_MYSQL = {
|
||||
"int4",
|
||||
"int8",
|
||||
"integer",
|
||||
"intersect",
|
||||
"interval",
|
||||
"into",
|
||||
"io_after_gtids",
|
||||
@@ -469,7 +468,6 @@ RESERVED_WORDS_MYSQL = {
|
||||
"outfile",
|
||||
"over",
|
||||
"parse_gcol_expr",
|
||||
"parallel",
|
||||
"partition",
|
||||
"percent_rank",
|
||||
"persist",
|
||||
@@ -478,7 +476,6 @@ RESERVED_WORDS_MYSQL = {
|
||||
"primary",
|
||||
"procedure",
|
||||
"purge",
|
||||
"qualify",
|
||||
"range",
|
||||
"rank",
|
||||
"read",
|
||||
|
||||
@@ -1,30 +1,18 @@
|
||||
# dialects/mysql/types.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# mysql/types.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
from typing import Any
|
||||
from typing import Iterable
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import MySQLDialect
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.type_api import _BindProcessorType
|
||||
from ...sql.type_api import _ResultProcessorType
|
||||
from ...sql.type_api import TypeEngine
|
||||
|
||||
|
||||
class _NumericType:
|
||||
"""Base for MySQL numeric types.
|
||||
@@ -34,27 +22,19 @@ class _NumericType:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, unsigned: bool = False, zerofill: bool = False, **kw: Any
|
||||
):
|
||||
def __init__(self, unsigned=False, zerofill=False, **kw):
|
||||
self.unsigned = unsigned
|
||||
self.zerofill = zerofill
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_NumericType, sqltypes.Numeric]
|
||||
)
|
||||
|
||||
|
||||
class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
class _FloatType(_NumericType, sqltypes.Float):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
if isinstance(self, (REAL, DOUBLE)) and (
|
||||
(precision is None and scale is not None)
|
||||
or (precision is not None and scale is None)
|
||||
@@ -66,18 +46,18 @@ class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
|
||||
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
|
||||
self.scale = scale
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
|
||||
)
|
||||
|
||||
|
||||
class _IntegerType(_NumericType, sqltypes.Integer):
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
self.display_width = display_width
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
|
||||
)
|
||||
@@ -88,13 +68,13 @@ class _StringType(sqltypes.String):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
charset: Optional[str] = None,
|
||||
collation: Optional[str] = None,
|
||||
ascii: bool = False, # noqa
|
||||
binary: bool = False,
|
||||
unicode: bool = False,
|
||||
national: bool = False,
|
||||
**kw: Any,
|
||||
charset=None,
|
||||
collation=None,
|
||||
ascii=False, # noqa
|
||||
binary=False,
|
||||
unicode=False,
|
||||
national=False,
|
||||
**kw,
|
||||
):
|
||||
self.charset = charset
|
||||
|
||||
@@ -107,33 +87,25 @@ class _StringType(sqltypes.String):
|
||||
self.national = national
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_StringType, sqltypes.String]
|
||||
)
|
||||
|
||||
|
||||
class _MatchType(
|
||||
sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
|
||||
):
|
||||
def __init__(self, **kw: Any):
|
||||
class _MatchType(sqltypes.Float, sqltypes.MatchType):
|
||||
def __init__(self, **kw):
|
||||
# TODO: float arguments?
|
||||
sqltypes.Float.__init__(self) # type: ignore[arg-type]
|
||||
sqltypes.Float.__init__(self)
|
||||
sqltypes.MatchType.__init__(self)
|
||||
|
||||
|
||||
class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
|
||||
class NUMERIC(_NumericType, sqltypes.NUMERIC):
|
||||
"""MySQL NUMERIC type."""
|
||||
|
||||
__visit_name__ = "NUMERIC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a NUMERIC.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
@@ -154,18 +126,12 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
|
||||
)
|
||||
|
||||
|
||||
class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
|
||||
class DECIMAL(_NumericType, sqltypes.DECIMAL):
|
||||
"""MySQL DECIMAL type."""
|
||||
|
||||
__visit_name__ = "DECIMAL"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DECIMAL.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
@@ -186,18 +152,12 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
|
||||
)
|
||||
|
||||
|
||||
class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
|
||||
class DOUBLE(_FloatType, sqltypes.DOUBLE):
|
||||
"""MySQL DOUBLE type."""
|
||||
|
||||
__visit_name__ = "DOUBLE"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DOUBLE.
|
||||
|
||||
.. note::
|
||||
@@ -226,18 +186,12 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
|
||||
)
|
||||
|
||||
|
||||
class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
|
||||
class REAL(_FloatType, sqltypes.REAL):
|
||||
"""MySQL REAL type."""
|
||||
|
||||
__visit_name__ = "REAL"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a REAL.
|
||||
|
||||
.. note::
|
||||
@@ -266,18 +220,12 @@ class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
|
||||
)
|
||||
|
||||
|
||||
class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
|
||||
class FLOAT(_FloatType, sqltypes.FLOAT):
|
||||
"""MySQL FLOAT type."""
|
||||
|
||||
__visit_name__ = "FLOAT"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision: Optional[int] = None,
|
||||
scale: Optional[int] = None,
|
||||
asdecimal: bool = False,
|
||||
**kw: Any,
|
||||
):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
|
||||
"""Construct a FLOAT.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
@@ -297,9 +245,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
def bind_processor(
|
||||
self, dialect: Dialect
|
||||
) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
|
||||
@@ -308,7 +254,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):
|
||||
|
||||
__visit_name__ = "INTEGER"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct an INTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
@@ -329,7 +275,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):
|
||||
|
||||
__visit_name__ = "BIGINT"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a BIGINTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
@@ -350,7 +296,7 @@ class MEDIUMINT(_IntegerType):
|
||||
|
||||
__visit_name__ = "MEDIUMINT"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a MEDIUMINTEGER
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
@@ -371,7 +317,7 @@ class TINYINT(_IntegerType):
|
||||
|
||||
__visit_name__ = "TINYINT"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a TINYINT.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
@@ -386,19 +332,13 @@ class TINYINT(_IntegerType):
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
|
||||
return (
|
||||
self._type_affinity is other._type_affinity
|
||||
or other._type_affinity is sqltypes.Boolean
|
||||
)
|
||||
|
||||
|
||||
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
|
||||
"""MySQL SMALLINTEGER type."""
|
||||
|
||||
__visit_name__ = "SMALLINT"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a SMALLINTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
@@ -414,7 +354,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class BIT(sqltypes.TypeEngine[Any]):
|
||||
class BIT(sqltypes.TypeEngine):
|
||||
"""MySQL BIT type.
|
||||
|
||||
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
|
||||
@@ -425,7 +365,7 @@ class BIT(sqltypes.TypeEngine[Any]):
|
||||
|
||||
__visit_name__ = "BIT"
|
||||
|
||||
def __init__(self, length: Optional[int] = None):
|
||||
def __init__(self, length=None):
|
||||
"""Construct a BIT.
|
||||
|
||||
:param length: Optional, number of bits.
|
||||
@@ -433,19 +373,20 @@ class BIT(sqltypes.TypeEngine[Any]):
|
||||
"""
|
||||
self.length = length
|
||||
|
||||
def result_processor(
|
||||
self, dialect: MySQLDialect, coltype: object # type: ignore[override]
|
||||
) -> Optional[_ResultProcessorType[Any]]:
|
||||
"""Convert a MySQL's 64 bit, variable length binary string to a
|
||||
long."""
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Convert a MySQL's 64 bit, variable length binary string to a long.
|
||||
|
||||
if dialect.supports_native_bit:
|
||||
return None
|
||||
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
|
||||
already do this, so this logic should be moved to those dialects.
|
||||
|
||||
def process(value: Optional[Iterable[int]]) -> Optional[int]:
|
||||
"""
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
v = 0
|
||||
for i in value:
|
||||
if not isinstance(i, int):
|
||||
i = ord(i) # convert byte to int on Python 2
|
||||
v = v << 8 | i
|
||||
return v
|
||||
return value
|
||||
@@ -458,7 +399,7 @@ class TIME(sqltypes.TIME):
|
||||
|
||||
__visit_name__ = "TIME"
|
||||
|
||||
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL TIME type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
@@ -477,12 +418,10 @@ class TIME(sqltypes.TIME):
|
||||
super().__init__(timezone=timezone)
|
||||
self.fsp = fsp
|
||||
|
||||
def result_processor(
|
||||
self, dialect: Dialect, coltype: object
|
||||
) -> _ResultProcessorType[datetime.time]:
|
||||
def result_processor(self, dialect, coltype):
|
||||
time = datetime.time
|
||||
|
||||
def process(value: Any) -> Optional[datetime.time]:
|
||||
def process(value):
|
||||
# convert from a timedelta value
|
||||
if value is not None:
|
||||
microseconds = value.microseconds
|
||||
@@ -505,7 +444,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
|
||||
|
||||
__visit_name__ = "TIMESTAMP"
|
||||
|
||||
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL TIMESTAMP type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
@@ -530,7 +469,7 @@ class DATETIME(sqltypes.DATETIME):
|
||||
|
||||
__visit_name__ = "DATETIME"
|
||||
|
||||
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL DATETIME type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
@@ -550,26 +489,26 @@ class DATETIME(sqltypes.DATETIME):
|
||||
self.fsp = fsp
|
||||
|
||||
|
||||
class YEAR(sqltypes.TypeEngine[Any]):
|
||||
class YEAR(sqltypes.TypeEngine):
|
||||
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
|
||||
|
||||
__visit_name__ = "YEAR"
|
||||
|
||||
def __init__(self, display_width: Optional[int] = None):
|
||||
def __init__(self, display_width=None):
|
||||
self.display_width = display_width
|
||||
|
||||
|
||||
class TEXT(_StringType, sqltypes.TEXT):
|
||||
"""MySQL TEXT type, for character storage encoded up to 2^16 bytes."""
|
||||
"""MySQL TEXT type, for text up to 2^16 characters."""
|
||||
|
||||
__visit_name__ = "TEXT"
|
||||
|
||||
def __init__(self, length: Optional[int] = None, **kw: Any):
|
||||
def __init__(self, length=None, **kw):
|
||||
"""Construct a TEXT.
|
||||
|
||||
:param length: Optional, if provided the server may optimize storage
|
||||
by substituting the smallest TEXT type sufficient to store
|
||||
``length`` bytes of characters.
|
||||
``length`` characters.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
@@ -596,11 +535,11 @@ class TEXT(_StringType, sqltypes.TEXT):
|
||||
|
||||
|
||||
class TINYTEXT(_StringType):
|
||||
"""MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes."""
|
||||
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
|
||||
|
||||
__visit_name__ = "TINYTEXT"
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a TINYTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
@@ -628,12 +567,11 @@ class TINYTEXT(_StringType):
|
||||
|
||||
|
||||
class MEDIUMTEXT(_StringType):
|
||||
"""MySQL MEDIUMTEXT type, for character storage encoded up
|
||||
to 2^24 bytes."""
|
||||
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
|
||||
|
||||
__visit_name__ = "MEDIUMTEXT"
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a MEDIUMTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
@@ -661,11 +599,11 @@ class MEDIUMTEXT(_StringType):
|
||||
|
||||
|
||||
class LONGTEXT(_StringType):
|
||||
"""MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes."""
|
||||
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
|
||||
|
||||
__visit_name__ = "LONGTEXT"
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a LONGTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
@@ -697,7 +635,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
|
||||
|
||||
__visit_name__ = "VARCHAR"
|
||||
|
||||
def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a VARCHAR.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
@@ -729,7 +667,7 @@ class CHAR(_StringType, sqltypes.CHAR):
|
||||
|
||||
__visit_name__ = "CHAR"
|
||||
|
||||
def __init__(self, length: Optional[int] = None, **kwargs: Any):
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a CHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
@@ -745,7 +683,7 @@ class CHAR(_StringType, sqltypes.CHAR):
|
||||
super().__init__(length=length, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
|
||||
def _adapt_string_for_cast(self, type_):
|
||||
# copy the given string type into a CHAR
|
||||
# for the purposes of rendering a CAST expression
|
||||
type_ = sqltypes.to_instance(type_)
|
||||
@@ -774,7 +712,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
|
||||
|
||||
__visit_name__ = "NVARCHAR"
|
||||
|
||||
def __init__(self, length: Optional[int] = None, **kwargs: Any):
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NVARCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
@@ -800,7 +738,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):
|
||||
|
||||
__visit_name__ = "NCHAR"
|
||||
|
||||
def __init__(self, length: Optional[int] = None, **kwargs: Any):
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
Reference in New Issue
Block a user