"""
An adapter to Datasette instances.
See https://datasette.io/ for more information.
"""
import logging
import urllib.parse
from datetime import timedelta
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, cast
import dateutil.parser
from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import Field, Float, Integer, ISODate, ISODateTime, Order, String
from shillelagh.filters import Equal, Filter, IsNotNull, IsNull, Like, NotEqual, Range
from shillelagh.lib import SimpleCostModel, build_sql, get_session
from shillelagh.typing import RequestedOrder, Row
_logger = logging.getLogger(__name__)
KNOWN_DOMAINS = {"datasette.io", "datasettes.com"}
# this is just a wild guess; used to estimate query cost
AVERAGE_NUMBER_OF_ROWS = 1000
# how many rows to get when performing our own pagination
DEFAULT_LIMIT = 1000
CACHE_EXPIRATION = timedelta(minutes=3)
def is_known_domain(netloc: str) -> bool:
"""
    Identify well-known Datasette domains.
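
    A doctest-style sketch (``latest.datasette.io`` is Datasette's public
    demo instance; ``example.com`` is not a known domain)::

        >>> is_known_domain("datasette.io")
        True
        >>> is_known_domain("latest.datasette.io")
        True
        >>> is_known_domain("example.com")
        False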
"""
for domain in KNOWN_DOMAINS:
if netloc == domain or netloc.endswith("." + domain):
return True
return False
def is_datasette(uri: str) -> bool:
"""
    Identify Datasette servers via a GET request to their
    ``/-/versions.json`` endpoint.
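
    The last two path segments (database and table) are stripped from the
    URI before the request; a JSON payload containing a ``datasette`` key
    identifies the server.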
"""
parsed = urllib.parse.urlparse(uri)
try:
# pylint: disable=unused-variable
mountpoint, database, table = parsed.path.rsplit("/", 2)
except ValueError:
return False
parsed = parsed._replace(path=f"{mountpoint}/-/versions.json")
uri = urllib.parse.urlunparse(parsed)
session = get_session({}, "datasette_cache", CACHE_EXPIRATION)
response = session.get(uri)
try:
payload = response.json()
except Exception: # pylint: disable=broad-exception-caught
return False
return "datasette" in payload
def get_field(value: Any) -> Field:
"""
Return a Shillelagh ``Field`` based on the value type.
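
    A doctest-style sketch of the inferred field types::

        >>> type(get_field(42)).__name__
        'Integer'
        >>> type(get_field(1.0)).__name__
        'Float'
        >>> type(get_field("2021-01-01")).__name__
        'ISODate'
        >>> type(get_field("2021-01-01T12:00:00")).__name__
        'ISODateTime'
        >>> type(get_field("hello")).__name__
        'String'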
"""
class_: Type[Field] = String
filters = [Range, Equal, NotEqual, IsNull, IsNotNull]
if isinstance(value, int):
class_ = Integer
elif isinstance(value, float):
class_ = Float
elif isinstance(value, str):
try:
dateutil.parser.isoparse(value)
except Exception: # pylint: disable=broad-except
# regular string
filters.append(Like)
else:
class_ = ISODate if len(value) == 10 else ISODateTime # type: ignore
return class_(filters=filters, order=Order.ANY, exact=True)
class DatasetteAPI(Adapter):
"""
An adapter to Datasette instances (https://datasette.io/).
"""
safe = True
supports_limit = True
supports_offset = True
@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
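        """
        Return ``True`` if the URI is supported, ``None`` if undetermined.

        URIs on well-known Datasette domains are accepted immediately; any
        other URI requires a network check, which is skipped when ``fast``
        is true.
        """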
parsed = urllib.parse.urlparse(uri)
if parsed.scheme in {"http", "https"} and is_known_domain(parsed.netloc):
return True
if fast:
return None
return is_datasette(uri)
@staticmethod
def parse_uri(uri: str) -> Tuple[str, str, str]:
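        """
        Split a table URI into server URL, database, and table name.

        A doctest-style sketch (the URI is illustrative)::

            >>> DatasetteAPI.parse_uri("https://example.datasette.io/db/table")
            ('https://example.datasette.io', 'db', 'table')
        """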
server_url, database, table = uri.rsplit("/", 2)
return server_url, database, table
def __init__(self, server_url: str, database: str, table: str):
super().__init__()
self.server_url = server_url
self.database = database
self.table = table
# use a cache for the API requests
self._session = get_session({}, "datasette_cache", CACHE_EXPIRATION)
self._set_columns()
def _run_query(self, sql: str) -> Dict[str, Any]:
"""
Run a query and return the JSON payload.
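
        Queries are sent to Datasette's JSON endpoint, e.g. (illustrative)::

            GET {server_url}/{database}.json?sql=SELECT+1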
"""
url = f"{self.server_url}/{self.database}.json"
_logger.info("GET %s", url)
response = self._session.get(url, params={"sql": sql})
payload = response.json()
return cast(Dict[str, Any], payload)
def _set_columns(self) -> None:
# get column names first
payload = self._run_query(f'SELECT * FROM "{self.table}" LIMIT 0')
columns = payload["columns"]
        # now try to get a non-NULL sample value for each column; MAX()
        # skips NULLs, so types can be inferred even from sparse data
select = ", ".join(f'MAX("{column}")' for column in columns)
payload = self._run_query(f'SELECT {select} FROM "{self.table}" LIMIT 1')
rows = payload["rows"]
if not rows:
raise ProgrammingError(f'Table "{self.table}" has no data')
self.columns = {key: get_field(value) for key, value in zip(columns, rows[0])}
def get_columns(self) -> Dict[str, Field]:
return self.columns
get_cost = SimpleCostModel(AVERAGE_NUMBER_OF_ROWS)
def get_data(
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
**kwargs: Any,
) -> Iterator[Row]:
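        # Paginate in chunks of DEFAULT_LIMIT rows, requesting one extra row
        # per page so we can tell whether another page exists.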
offset = offset or 0
while True:
if limit is None:
# request 1 more, so we know if there are more pages to be fetched
end = DEFAULT_LIMIT + 1
else:
end = min(limit, DEFAULT_LIMIT + 1)
sql = build_sql(
self.columns,
bounds,
order,
f'"{self.table}"',
limit=end,
offset=offset,
)
payload = self._run_query(sql)
if payload.get("error"):
raise ProgrammingError(
f'Error ({payload["title"]}): {payload["error"]}',
)
columns = payload["columns"]
rows = payload["rows"]
            i = -1  # track the last row index, in case the page is empty
for i, values in enumerate(rows[:DEFAULT_LIMIT]):
row = dict(zip(columns, values))
row["rowid"] = i
_logger.debug(row)
yield row
            # stop when the server returned a complete, untruncated page
            if not payload["truncated"] and len(rows) <= DEFAULT_LIMIT:
break
offset += i + 1
if limit is not None:
limit -= i + 1
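

if __name__ == "__main__":  # pragma: no cover
    # A minimal usage sketch, assuming network access. The URI points at the
    # public "global power plants" demo instance, used here for illustration.
    from shillelagh.backends.apsw.db import connect

    connection = connect(":memory:")
    cursor = connection.cursor()
    cursor.execute(
        'SELECT * FROM '
        '"https://global-power-plants.datasettes.com'
        '/global-power-plants/global-power-plants" LIMIT 3',
    )
    for row in cursor.fetchall():
        print(row)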