"""
A SQLAlchemy dialect.
"""
# pylint: disable=protected-access, abstract-method
from typing import Any, Optional, cast
import sqlalchemy.types
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.url import URL
from sqlalchemy.pool.base import _ConnectionFairy
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import TypedDict
from shillelagh.adapters.base import Adapter
from shillelagh.backends.apsw import db
from shillelagh.backends.apsw.vt import VTTable
from shillelagh.exceptions import ProgrammingError
from shillelagh.lib import find_adapter
[docs]
class ComputedColumn(TypedDict):
    """
    Typed dictionary describing a computed column for SQLAlchemy.

    Both keys are required.
    """

    # Stored as a string; presumably the SQL expression of the column -- confirm.
    sqltext: str
    # Boolean flag; presumably whether the computed value is stored -- confirm.
    persisted: bool
[docs]
class SQLAlchemyColumn(TypedDict, total=False):
    """
    Typed dictionary describing a single column for SQLAlchemy.

    The class is declared with ``total=False``, so every key is optional.
    """

    # Column name and its SQLAlchemy type.
    name: str
    type: TypeEngine

    # Column attributes.
    nullable: bool
    default: Optional[str]
    autoincrement: str
    primary_key: int

    # Optional extras; ``None`` is an accepted value for both.
    comment: Optional[str]
    computed: Optional[ComputedColumn]
[docs]
class APSWDialect(SQLiteDialect):
    """
    SQLAlchemy dialect for Shillelagh.

    Because the driver is APSW, the implementation builds on ``SQLiteDialect``
    and overrides connection arguments and table reflection so that they are
    resolved through Shillelagh adapters.
    """

    name = "shillelagh"
    driver = "apsw"

    # ``SQLiteDialect`` supports statement caching and so does this dialect; the
    # flag has to be set explicitly on third-party dialects. Background:
    # https://docs.sqlalchemy.org/en/14/core/connections.html#caching-for-third-party-dialects
    supports_statement_cache = True

    # ``SQLiteDialect.colspecs`` maps types that SQLite keeps as strings (eg,
    # timestamps) to custom representations. The Shillelagh DB API driver already
    # hands back proper objects, so no custom representation is registered.
    colspecs: dict[TypeEngine, TypeEngine] = {}

    supports_sane_rowcount = False

    @classmethod
    def dbapi(cls):  # pylint: disable=method-hidden
        """
        Return the module that implements the DB API.
        """
        return db

    @classmethod
    def import_dbapi(cls):  # pylint: disable=method-hidden
        """
        Return the module that implements the DB API.
        """
        return db

    def __init__(
        self,
        adapters: Optional[list[str]] = None,
        adapter_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        safe: bool = False,
        **kwargs: Any,
    ):
        """
        Remember the Shillelagh settings and initialize the base dialect.

        ``adapters``, ``adapter_kwargs`` and ``safe`` are stored on the instance
        and later handed over by ``create_connect_args``; any other keyword
        argument is forwarded to ``SQLiteDialect``.
        """
        super().__init__(**kwargs)

        self._safe = safe
        self._adapters = adapters
        # A missing or empty mapping is replaced by a new empty dict.
        self._adapter_kwargs = adapter_kwargs if adapter_kwargs else {}

    def create_connect_args(
        self,
        url: URL,
    ) -> tuple[tuple[()], dict[str, Any]]:
        """
        Translate a SQLAlchemy URL into connection arguments.

        No positional arguments are produced. The keyword arguments carry the
        database path (``":memory:"`` when the URL names no database), the
        adapter configuration stored on the dialect, the ``safe`` flag and the
        dialect's isolation level.
        """
        if url.database:
            path = str(url.database)
        else:
            path = ":memory:"

        connect_kwargs: dict[str, Any] = {
            "path": path,
            "adapters": self._adapters,
            "adapter_kwargs": self._adapter_kwargs,
            "safe": self._safe,
            "isolation_level": self.isolation_level,
        }
        return (), connect_kwargs

    def do_ping(self, dbapi_connection: _ConnectionFairy) -> bool:
        """
        Report the database as online.

        The connection is not probed; the answer is always ``True``.
        """
        return True

    def has_table(  # pylint: disable=unused-argument
        self,
        connection: _ConnectionFairy,
        table_name: str,
        schema: Optional[str] = None,
        **kwargs: Any,
    ) -> bool:
        """
        Tell whether a table exists.

        A table exists when an adapter can be instantiated for its name. If
        that lookup raises ``ProgrammingError`` the question is delegated to
        ``SQLiteDialect.has_table``.
        """
        try:
            get_adapter_for_table_name(connection, table_name)
        except ProgrammingError:
            # No adapter for this name: defer to the SQLite implementation.
            found_by_sqlite = super().has_table(
                connection,
                table_name,
                schema,
                **kwargs,  # pylint: disable=unused-argument
            )
            return bool(found_by_sqlite)
        else:
            return True

    # needed for SQLAlchemy
    def _get_table_sql(  # pylint: disable=unused-argument
        self,
        connection: _ConnectionFairy,
        table_name: str,
        schema: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        Return the ``CREATE TABLE`` statement built from the table's adapter.
        """
        adapter = get_adapter_for_table_name(connection, table_name)
        return VTTable(adapter).get_create_table(table_name)

    def get_columns(  # pylint: disable=unused-argument
        self,
        connection: _ConnectionFairy,
        table_name: str,
        schema: Optional[str] = None,
        **kwargs: Any,
    ) -> list[SQLAlchemyColumn]:
        """
        Describe the columns of a table, as reported by its adapter.

        The type of each column is the attribute of ``sqlalchemy.types`` named
        by the adapter field's ``type``. Every column is reported as nullable,
        without a default, with ``"auto"`` autoincrement and a ``primary_key``
        of ``0``.
        """
        adapter = get_adapter_for_table_name(connection, table_name)

        reflected: list[SQLAlchemyColumn] = []
        for column_name, field in adapter.get_columns().items():
            reflected.append(
                {
                    "name": column_name,
                    "type": getattr(sqlalchemy.types, field.type),
                    "nullable": True,
                    "default": None,
                    "autoincrement": "auto",
                    "primary_key": 0,
                },
            )
        return reflected
[docs]
def get_adapter_for_table_name(
    connection: _ConnectionFairy,
    table_name: str,
) -> Adapter:
    """
    Return an adapter associated with a connection.

    This function instantiates the adapter responsible for a given table name,
    using the connection to properly pass any adapter kwargs.

    ``connection`` must expose ``engine.raw_connection()``; the raw connection
    it returns provides the ``_adapters`` and ``_adapter_kwargs`` that are
    passed to ``find_adapter``. The adapter class found is instantiated with
    the arguments ``find_adapter`` returns.

    Exceptions raised by ``find_adapter`` propagate to the caller. Presumably
    that is ``ProgrammingError`` when no adapter handles the table, since
    ``has_table`` catches it -- confirm against ``find_adapter``.
    """
    # ``raw_connection()`` checks a connection out of the engine's pool. It is
    # only needed to read the adapter configuration, so hand it back as soon as
    # the lookup is done instead of leaving it checked out until it is
    # garbage-collected.
    raw_connection = cast(db.Connection, connection.engine.raw_connection())
    try:
        adapter, args, kwargs = find_adapter(
            table_name,
            raw_connection._adapter_kwargs,
            raw_connection._adapters,
        )
    finally:
        raw_connection.close()

    return adapter(*args, **kwargs)  # type: ignore