"""Helper functions for Shillelagh."""
import base64
import inspect
import itertools
import json
import marshal
import math
import operator
from datetime import timedelta
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
)
import apsw
import requests_cache
from packaging.version import Version
from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ImpossibleFilterError, ProgrammingError
from shillelagh.fields import Boolean, Field, Float, Integer, Order, String
from shillelagh.filters import (
Equal,
Filter,
Impossible,
IsNotNull,
IsNull,
Like,
NotEqual,
Operator,
Range,
)
from shillelagh.typing import RequestedOrder, Row
# Sentinel range used by ``RowIDManager`` to mark a deleted row; iterating
# over it yields a single -1, which consumers treat as "row deleted".
DELETED = range(-1, 0)
# Default expiration for cached HTTP sessions created by ``get_session``.
CACHE_EXPIRATION = timedelta(minutes=3)
class RowIDManager:
    """
    Track row IDs for an append-only table, including deletions.

    Each row in the table gets a row ID. Appending a row assigns it a fresh
    ID automatically, while deleting a row replaces its ID with -1 so that
    consumers can skip it.

    An example:

    >>> data = ["zero", "one", "two"]
    >>> manager = RowIDManager([range(len(data))])

    To insert data:

    >>> data.append("three")
    >>> manager.insert()
    3
    >>> data.append("four")
    >>> manager.insert(10)  # you can specify a row ID
    10
    >>> for row_id, value in zip(manager, data):
    ...     if row_id != -1:
    ...         print(row_id, value)
    0 zero
    1 one
    2 two
    3 three
    10 four

    To delete data:

    >>> manager.delete(data.index("two"))
    >>> print(data)
    ['zero', 'one', 'two', 'three', 'four']
    >>> for row_id, value in zip(manager, data):
    ...     if row_id != -1:
    ...         print(row_id, value)
    0 zero
    1 one
    3 three
    10 four
    """

    def __init__(self, ranges: List[range]):
        if not ranges:
            # pylint: disable=broad-exception-raised
            raise Exception("Argument ``ranges`` cannot be empty")
        self.ranges = ranges

    def __iter__(self):
        # Deleted rows show up as -1, courtesy of the ``DELETED`` range.
        for range_ in self.ranges:
            yield from range_

    def get_max_row_id(self) -> int:
        """
        Find the maximum row ID.
        """
        return max(range_.stop - 1 for range_ in self.ranges)

    def check_row_id(self, row_id: int) -> None:
        """
        Check if a provided row ID is not being used.

        Raises if the ID is already present in any of the tracked ranges.
        """
        if any(row_id in range_ for range_ in self.ranges):
            # pylint: disable=broad-exception-raised
            raise Exception(f"Row ID {row_id} already present")

    def insert(self, row_id: Optional[int] = None) -> int:
        """
        Insert a new row ID, returning it.

        When no ID is passed, one past the current maximum is used.
        """
        if row_id is None:
            row_id = self.get_max_row_id() + 1
        else:
            self.check_row_id(row_id)

        tail = self.ranges[-1]
        if tail.stop == row_id:
            # contiguous with the trailing range: extend it
            self.ranges[-1] = range(tail.start, row_id + 1)
        else:
            # otherwise start a new single-element range
            self.ranges.append(range(row_id, row_id + 1))
        return row_id

    def delete(self, row_id: int) -> None:
        """Mark a given row ID as deleted."""
        for i, range_ in enumerate(self.ranges):
            if row_id not in range_:
                continue
            start, stop = range_.start, range_.stop
            if stop - start == 1:
                # the whole single-element range is gone
                self.ranges[i] = DELETED
            elif row_id == start:
                # trim the head of the range, mark the slot deleted
                self.ranges[i] = range(start + 1, stop)
                self.ranges.insert(i, DELETED)
            elif row_id == stop - 1:
                # trim the tail of the range, mark the slot deleted
                self.ranges[i] = range(start, stop - 1)
                self.ranges.insert(i + 1, DELETED)
            else:
                # split the range around the deleted ID
                self.ranges[i] = range(start, row_id)
                self.ranges.insert(i + 1, range(row_id + 1, stop))
                self.ranges.insert(i + 1, DELETED)
            return

        # pylint: disable=broad-exception-raised
        raise Exception(f"Row ID {row_id} not found")
def analyze(  # pylint: disable=too-many-branches
    data: Iterator[Row],
) -> Tuple[int, Dict[str, Order], Dict[str, Type[Field]]]:
    """
    Compute number of rows, order, and types from a stream of rows.
    """
    order: Dict[str, Order] = {}
    types: Dict[str, Type[Field]] = {}

    previous_row: Row = {}
    current_row: Row = {}
    index = -1
    for index, current_row in enumerate(data):
        for column_name, value in current_row.items():
            # track whether the column still looks sorted
            if index > 0:
                order[column_name] = update_order(
                    current_order=order.get(column_name, Order.NONE),
                    previous=previous_row.get(column_name),
                    current=value,
                    num_rows=index + 1,
                )

            # infer the column type; the precedence is String > Float >
            # Integer > Boolean, and a column never downgrades
            current_type = types.get(column_name)
            if current_type == String:
                continue
            if isinstance(value, (str, list, dict)):
                types[column_name] = String
            elif current_type == Float:
                continue
            elif isinstance(value, float):
                types[column_name] = Float
            elif current_type == Integer:
                continue
            # ``isinstance(True, int) == True`` :(
            elif isinstance(value, int) and not isinstance(value, bool):
                types[column_name] = Integer
            elif current_type == Boolean:
                continue
            elif isinstance(value, bool):
                types[column_name] = Boolean
            else:
                # something weird (eg, ``None``), fall back to string
                types[column_name] = String

        previous_row = current_row

    # a single row carries no order information
    if current_row and not order:
        order = {column_name: Order.NONE for column_name in current_row}

    return index + 1, order, types
def update_order(
    current_order: Order,
    previous: Any,
    current: Any,
    num_rows: int,
) -> Order:
    """
    Update the stored order of a given column.

    Called while traversing rows to determine whether a column is sorted
    ascending, descending, or not at all.
    """
    # a single row (or a missing previous value) carries no order information
    if num_rows < 2 or previous is None:
        return Order.NONE

    try:
        # the first comparable pair establishes the candidate order
        if num_rows == 2:
            return Order.ASCENDING if current >= previous else Order.DESCENDING

        # any pair that violates the candidate order discards it
        broke_ascending = current_order == Order.ASCENDING and current < previous
        broke_descending = current_order == Order.DESCENDING and current > previous
        if current_order == Order.NONE or broke_ascending or broke_descending:
            return Order.NONE
    except TypeError:
        # values are not mutually comparable
        return Order.NONE

    return current_order
def escape_string(value: str) -> str:
    """Escape single quotes by doubling them."""
    return "''".join(value.split("'"))
def unescape_string(value: str) -> str:
    """Collapse doubled single quotes back to single ones."""
    return "'".join(value.split("''"))
def escape_identifier(value: str) -> str:
    """Escape double quotes by doubling them."""
    return '""'.join(value.split('"'))
def unescape_identifier(value: str) -> str:
    """Collapse doubled double quotes back to single ones."""
    return '"'.join(value.split('""'))
def serialize(value: Any) -> str:
    """
    Serialize adapter arguments.

    This function is used with the SQLite backend, in order to serialize
    the arguments needed to instantiate an adapter via a virtual table.
    """
    try:
        marshalled = marshal.dumps(value)
    except ValueError as ex:
        # ``marshal`` only supports basic built-in types
        message = (
            f"The argument {value} is not serializable because it has type "
            f"{type(value)}. Make sure only basic types (list, dicts, strings, "
            "numbers) are passed as arguments to adapters."
        )
        raise ProgrammingError(message) from ex

    encoded = base64.b64encode(marshalled).decode()
    return escape_string(encoded)
def deserialize(value: str) -> Any:
    """
    Deserialize adapter arguments.

    This function is used by the SQLite backend, in order to deserialize
    the virtual table definition and instantiate an adapter.
    """
    decoded = base64.b64decode(unescape_string(value).encode())
    return marshal.loads(decoded)
def build_sql(  # pylint: disable=too-many-locals, too-many-arguments, too-many-branches
    columns: Dict[str, Field],
    bounds: Dict[str, Filter],
    order: List[Tuple[str, RequestedOrder]],
    table: Optional[str] = None,
    column_map: Optional[Dict[str, str]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    alias: Optional[str] = None,
) -> str:
    """
    Build a SQL query.

    This is used by adapters which use a simplified SQL dialect to fetch data. For
    GSheets a column map is required, since the SQL references columns by label
    ("A", "B", etc.) instead of name.
    """

    def qualify(column_name: str) -> str:
        # map the column name to its backend identifier, optionally prefixed
        # by the table alias
        identifier = column_map[column_name] if column_map else column_name
        return f"{alias}.{identifier}" if alias else identifier

    sql = "SELECT *"
    if table:
        sql = f"{sql} FROM {table}"
        if alias:
            sql = f"{sql} AS {alias}"

    conditions = []
    for column_name, filter_ in bounds.items():
        # a degenerate range (start == end, both inclusive) is an equality
        if (
            isinstance(filter_, Range)  # pylint: disable=too-many-boolean-expressions
            and filter_.start is not None
            and filter_.end is not None
            and filter_.start == filter_.end
            and filter_.include_start
            and filter_.include_end
        ):
            filter_ = Equal(filter_.start)
        field = columns[column_name]
        conditions.extend(get_conditions(qualify(column_name), field, filter_))
    if conditions:
        sql = f"{sql} WHERE {' AND '.join(conditions)}"

    column_order: List[str] = []
    for column_name, requested_order in order:
        suffix = " DESC" if requested_order == Order.DESCENDING else ""
        column_order.append(f"{qualify(column_name)}{suffix}")
    if column_order:
        sql = f"{sql} ORDER BY {', '.join(column_order)}"

    if limit is not None:
        sql = f"{sql} LIMIT {limit}"
    if offset is not None:
        sql = f"{sql} OFFSET {offset}"

    return sql
def get_conditions(id_: str, field: Field, filter_: Filter) -> List[str]:
    """
    Build a list of SQL conditions from a column ID and a filter.

    :param id_: the (possibly mapped/aliased) column identifier
    :param field: the field, used to quote values properly
    :param filter_: the filter to translate to SQL
    """
    if isinstance(filter_, Impossible):
        # no row can ever match an impossible filter
        raise ImpossibleFilterError()
    if isinstance(filter_, Equal):
        return [f"{id_} = {field.quote(filter_.value)}"]
    if isinstance(filter_, NotEqual):
        return [f"{id_} != {field.quote(filter_.value)}"]
    if isinstance(filter_, Range):
        # each populated bound contributes one condition
        conditions = []
        if filter_.start is not None:
            symbol = ">=" if filter_.include_start else ">"
            conditions.append(f"{id_} {symbol} {field.quote(filter_.start)}")
        if filter_.end is not None:
            symbol = "<=" if filter_.include_end else "<"
            conditions.append(f"{id_} {symbol} {field.quote(filter_.end)}")
        return conditions
    if isinstance(filter_, Like):
        return [f"{id_} LIKE {field.quote(filter_.value)}"]
    if isinstance(filter_, IsNull):
        return [f"{id_} IS NULL"]
    if isinstance(filter_, IsNotNull):
        return [f"{id_} IS NOT NULL"]

    raise ProgrammingError(f"Invalid filter: {filter_}")
def combine_args_kwargs(
    func: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, ...]:
    """
    Combine args and kwargs into args.

    This is needed because we allow users to pass custom kwargs to adapters,
    but when creating the virtual table we serialize only args.
    """
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.args
def is_null(column: Any, _: Any) -> bool:
    """
    Operator implementing ``IS NULL``: true when the column value is ``None``.

    The second argument exists only to match the binary-operator signature.
    """
    return column is None
def is_not_null(column: Any, _: Any) -> bool:
    """
    Operator implementing ``IS NOT NULL``: true when the value is not ``None``.

    The second argument exists only to match the binary-operator signature.
    """
    return column is not None
def filter_data(  # pylint: disable=too-many-arguments
    data: Iterator[Row],
    bounds: Dict[str, Filter],
    order: List[Tuple[str, RequestedOrder]],
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    requested_columns: Optional[Set[str]] = None,
) -> Iterator[Row]:
    """
    Apply filtering and sorting to a stream of rows.

    This is used mostly as an exercise. It's probably much more efficient to
    simply declare fields without any filtering/sorting and let the backend
    (SQLite, eg) handle it.
    """
    # project only the requested columns, if a projection was given
    data = (
        {
            k: v
            for k, v in row.items()
            if requested_columns is None or k in requested_columns
        }
        for row in data
    )

    for column_name, filter_ in bounds.items():

        def apply_filter(
            data: Iterator[Row],
            operator_: Callable[[Any, Any], bool],
            column_name: str,
            value: Any,
        ) -> Iterator[Row]:
            """
            Apply a given filter to an iterator of rows.

            This method is needed because Python uses lazy bindings in
            generator expressions, so we need to create a new scope in
            order to apply several filters that reuse the same variable
            names.
            """
            return (row for row in data if operator_(row[column_name], value))

        if isinstance(filter_, Impossible):
            # nothing can ever match; yield no rows at all
            return
        if isinstance(filter_, Equal):
            data = apply_filter(data, operator.eq, column_name, filter_.value)
        elif isinstance(filter_, NotEqual):
            data = apply_filter(data, operator.ne, column_name, filter_.value)
        elif isinstance(filter_, Range):
            if filter_.start is not None:
                operator_ = operator.ge if filter_.include_start else operator.gt
                data = apply_filter(data, operator_, column_name, filter_.start)
            if filter_.end is not None:
                operator_ = operator.le if filter_.include_end else operator.lt
                data = apply_filter(data, operator_, column_name, filter_.end)
        elif isinstance(filter_, IsNull):
            data = apply_filter(data, is_null, column_name, None)
        elif isinstance(filter_, IsNotNull):
            data = apply_filter(data, is_not_null, column_name, None)
        else:
            # NOTE(review): ``Like`` filters are not handled here — confirm
            # callers never request them for in-memory filtering
            raise ProgrammingError(f"Invalid filter: {filter_}")

    if order:
        # in order to sort we need to consume the iterator and load it into
        # memory :(
        rows = list(data)
        # apply the stable sorts in reverse priority, so that the FIRST
        # column in ``order`` ends up as the primary sort key; previously the
        # sorts ran in listed order, which made the last column dominate
        for column_name, requested_order in reversed(order):
            reverse = requested_order == Order.DESCENDING
            rows.sort(key=operator.itemgetter(column_name), reverse=reverse)
        data = iter(rows)

    data = apply_limit_and_offset(data, limit, offset)

    yield from data
# Generic element type used by ``apply_limit_and_offset``.
T = TypeVar("T")
def apply_limit_and_offset(
    rows: Iterator[T],
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Iterator[T]:
    """
    Apply limit/offset to a stream of rows.

    When neither is given the stream is returned untouched.
    """
    if limit is None and offset is None:
        return rows

    start = offset or 0
    end = None if limit is None else start + limit
    return itertools.islice(rows, start, end)
def SimpleCostModel(rows: int, fixed_cost: int = 0):  # pylint: disable=invalid-name
    """
    A simple model for estimating query costs.

    The model assumes that each filtering operation is O(n), and each sorting
    operation is O(n log n), in addition to a fixed cost.

    :param rows: estimated number of rows in the table
    :param fixed_cost: constant cost added to every query
    """

    def method(
        obj: Any,  # pylint: disable=unused-argument
        filtered_columns: List[Tuple[str, Operator]],
        order: List[Tuple[str, RequestedOrder]],
    ) -> int:
        filtering_cost = rows * len(filtered_columns)
        # guard against ``math.log2(0)``, which raises ``ValueError``; with
        # fewer than 2 rows sorting is free anyway (``log2(1) == 0``)
        sorting_cost = (
            rows * math.log2(rows) * len(order) if rows > 1 and order else 0
        )
        return int(fixed_cost + filtering_cost + sorting_cost)

    return method
def NetworkAPICostModel(
    download_cost: int,
    fixed_cost: int = 0,
):  # pylint: disable=invalid-name
    """
    A cost model for adapters with network API calls.

    In this case, transferring less data and doing less connections is more efficient.
    """

    def method(
        obj: Any,  # pylint: disable=unused-argument
        filtered_columns: List[Tuple[str, Operator]],
        order: List[Tuple[str, RequestedOrder]],  # pylint: disable=unused-argument
    ) -> int:
        # each filtered column reduces the estimated amount of data downloaded
        transfer_cost = int(download_cost / (len(filtered_columns) + 1))
        return fixed_cost + transfer_cost

    return method
def find_adapter(
    uri: str,
    adapter_kwargs: Dict[str, Any],
    adapters: List[Type[Adapter]],
) -> Tuple[Type[Adapter], Tuple[Any, ...], Dict[str, Any]]:
    """
    Find an adapter that handles a given URI.

    This is done in 2 passes: first the ``supports`` method is called with ``fast=True``.
    If no adapter returns ``True`` we do a second pass on the plugins that returned
    ``None``, passing ``fast=False`` so they can do network requests to better inspect
    the URI.
    """

    def kwargs_for(adapter: Type[Adapter]) -> Dict[str, Any]:
        # per-adapter kwargs are keyed by the lowercased class name
        return adapter_kwargs.get(adapter.__name__.lower(), {})

    # fast pass: adapters should decide without network requests
    candidates = set()
    for adapter in adapters:
        kwargs = kwargs_for(adapter)
        supported: Optional[bool] = adapter.supports(uri, fast=True, **kwargs)
        if supported:
            return adapter, adapter.parse_uri(uri), kwargs
        if supported is None:
            # undecided: revisit in the slow pass
            candidates.add(adapter)

    # slow pass: undecided adapters may hit the network to decide
    for adapter in candidates:
        kwargs = kwargs_for(adapter)
        if adapter.supports(uri, fast=False, **kwargs):
            return adapter, adapter.parse_uri(uri), kwargs

    raise ProgrammingError(f"Unsupported table: {uri}")
def flatten(row: Row) -> Row:
    """
    Flatten a row by serializing compound values (lists/dicts) to JSON strings.
    """
    flat: Row = {}
    for key, value in row.items():
        if isinstance(value, (list, dict)):
            value = json.dumps(value)
        flat[key] = value
    return flat
def best_index_object_available() -> bool:
    """
    Check if support for best index object is available.

    Requires apsw 3.41.0.0 or newer.
    """
    minimum = Version("3.41.0.0")
    return bool(Version(apsw.apswversion()) >= minimum)
def get_session(
    request_headers: Dict[str, str],
    cache_name: str,
    expire_after: timedelta = CACHE_EXPIRATION,
) -> requests_cache.CachedSession:
    """
    Return a cached requests session with the given headers applied.
    """
    if expire_after == timedelta(seconds=-1):
        # -1 second is a sentinel meaning "do not cache at all"
        expiration = requests_cache.DO_NOT_CACHE
    else:
        expiration = expire_after.total_seconds()

    session = requests_cache.CachedSession(
        cache_name=cache_name,
        backend="sqlite",
        expire_after=expiration,
    )
    session.headers.update(request_headers)

    return session