Source code for shillelagh.adapters.file.csvfile

"""
An adapter for CSV files.

This adapter treats a CSV file as a table, allowing rows to be inserted,
deleted, and updated. It's not very practical since it requires the data
to be written with the ``QUOTE_NONNUMERIC`` format option, with strings
explicitly quoted. It's also not very efficient, since it implements the
filtering and sorting in Python, instead of relying on the backend.

Remote files (HTTP/HTTPS) are also supported in read-only mode.
"""

import csv
import logging
import os
import tempfile
import urllib.parse
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast

import requests

from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import Field
from shillelagh.filters import (
    Equal,
    Filter,
    IsNotNull,
    IsNull,
    NotEqual,
    Operator,
    Range,
)
from shillelagh.lib import RowIDManager, analyze, filter_data, update_order
from shillelagh.typing import Maybe, MaybeType, RequestedOrder, Row

_logger = logging.getLogger(__name__)

INITIAL_COST = 0
FILTERING_COST = 1000
SORTING_COST = 10000

DEFAULT_TIMEOUT = timedelta(minutes=3)

SUPPORTED_PROTOCOLS = {"http", "https"}


class RowTracker:
    """An iterator that keeps track of the last yielded row."""

    def __init__(self, iterable: Iterator[Row]):
        self.iterable = iterable
        self.last_row: Optional[Row] = None

    def __iter__(self) -> Iterator[Row]:
        for row in self.iterable:
            self.last_row = row
            yield row

    def __next__(self) -> Row:
        return self.iterable.__next__()

class CSVFile(Adapter):
    r"""
    An adapter for CSV files.

    The files must be written with the ``QUOTE_NONNUMERIC`` format option, with
    strings explicitly quoted::

        "index","temperature","site"
        10.0,15.2,"Diamond_St"
        11.0,13.1,"Blacktail_Loop"
        12.0,13.3,"Platinum_St"
        13.0,12.1,"Kodiak_Trail"

    The adapter will first scan the whole file to determine the number of rows,
    as well as the type and order of each column.

    The adapter has no index. When data is ``SELECT``\ed the adapter will stream
    over all the rows in the file, filtering them on the fly. If a specific
    order is requested the resulting rows will be loaded into memory so they can
    be sorted.

    Inserted rows are appended to the end of the file. Deleted rows simply have
    their row ID marked as deleted (-1), and are ignored when the data is
    scanned for results. When the adapter is closed deleted rows will be garbage
    collected.

    Updates are handled with a delete followed by an insert.
    """

    # the adapter is not safe, since it could be used to read files from
    # the filesystem, or potentially overwrite existing files
    safe = False

    supports_limit = True
    supports_offset = True

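    # A sketch of writing a compatible file with the standard library (purely
    # illustrative; the file name and values are assumptions). Using
    # ``QUOTE_NONNUMERIC`` quotes every string, which is what lets the adapter
    # infer column types when it scans the file:
    #
    #     import csv
    #
    #     with open("example.csv", "w", newline="", encoding="utf-8") as fp:
    #         writer = csv.writer(fp, quoting=csv.QUOTE_NONNUMERIC)
    #         writer.writerow(["index", "temperature", "site"])
    #         writer.writerow([10.0, 15.2, "Diamond_St"])
    #         writer.writerow([11.0, 13.1, "Blacktail_Loop"])
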
    @staticmethod
    def supports(uri: str, fast: bool = True, **kwargs: Any) -> MaybeType:
        # local file
        path = Path(uri)
        if path.suffix == ".csv" and path.exists():
            return True

        # remote file
        parsed = urllib.parse.urlparse(uri)
        if parsed.scheme not in SUPPORTED_PROTOCOLS:
            return False

        # URLs ending in ``.csv`` are probably CSV files
        if parsed.path.endswith(".csv"):
            return True

        # do a HEAD request to get the MIME type
        if fast:
            return Maybe

        response = requests.head(uri, timeout=DEFAULT_TIMEOUT.total_seconds())
        return "text/csv" in response.headers.get("content-type", "")

    @staticmethod
    def parse_uri(uri: str) -> Tuple[str]:
        return (uri,)

    def __init__(self, path_or_uri: str):
        super().__init__()

        path = Path(path_or_uri)
        if path.suffix == ".csv" and path.exists():
            self.local = True
        else:
            self.local = False

            # download CSV file
            with tempfile.NamedTemporaryFile(delete=False) as output:
                response = requests.get(
                    path_or_uri,
                    timeout=DEFAULT_TIMEOUT.total_seconds(),
                )
                output.write(response.content)
            path = Path(output.name)

        self.path = path
        self.modified = False

        _logger.info("Opening CSV file %s to load metadata", self.path)
        with open(self.path, encoding="utf-8") as csvfile:
            reader = csv.reader(csvfile, quoting=csv.QUOTE_NONNUMERIC)
            try:
                column_names = next(reader)
            except StopIteration as ex:
                raise ProgrammingError("The file has no rows") from ex
            data = (dict(zip(column_names, row)) for row in reader)

            # put data in a ``RowTracker``, so we can monitor the last row
            # and keep track of the column order
            row_tracker = RowTracker(data)

            # analyze data to determine the number of rows, as well as the
            # order and type of each column
            num_rows, order, types = analyze(row_tracker)
            _logger.debug("Read %d rows", num_rows)

        self.columns = {
            column_name: types[column_name](
                filters=[Range, Equal, NotEqual, IsNull, IsNotNull],
                order=order[column_name],
                exact=True,
            )
            for column_name in column_names
        }

        # the row ID manager is used to keep track of insertions and deletions
        self.row_id_manager = RowIDManager([range(0, num_rows + 1)])

        self.last_row = row_tracker.last_row
        self.num_rows = num_rows

    def get_columns(self) -> Dict[str, Field]:
        return self.columns

    def get_cost(
        self,
        filtered_columns: List[Tuple[str, Operator]],
        order: List[Tuple[str, RequestedOrder]],
    ) -> float:
        cost = INITIAL_COST

        # filtering the data has linear cost, since ``filter_data`` builds a
        # single filter function applied as the data is streamed
        if filtered_columns:
            cost += FILTERING_COST

        # sorting, on the other hand, is costly, requiring loading all the data
        # in memory and calling ``.sort()`` for each column that needs to be
        # ordered
        cost += SORTING_COST * len(order)

        return cost

    def get_data(
        self,
        bounds: Dict[str, Filter],
        order: List[Tuple[str, RequestedOrder]],
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterator[Row]:
        _logger.info("Opening CSV file %s to load data", self.path)
        with open(self.path, encoding="utf-8") as csvfile:
            reader = csv.reader(csvfile, quoting=csv.QUOTE_NONNUMERIC)
            try:
                header = next(reader)
            except StopIteration as ex:
                raise ProgrammingError("The file has no rows") from ex
            column_names = ["rowid", *header]

            rows = (
                [i, *row]
                for i, row in zip(self.row_id_manager, reader)
                if i != -1
            )
            data = (dict(zip(column_names, row)) for row in rows)

            # Filter and sort the data. It would probably be more efficient to
            # simply declare the columns as having no filter and no sort order,
            # and let the backend handle this; but it's nice to have an example
            # of how to do this.
            for row in filter_data(data, bounds, order, limit, offset):
                _logger.debug(row)
                yield row

    def insert_data(self, row: Row) -> int:
        if not self.local:
            raise ProgrammingError("Cannot apply DML to a remote file")

        row_id: Optional[int] = row.pop("rowid")
        row_id = cast(int, self.row_id_manager.insert(row_id))

        # append row
        column_names = list(self.get_columns().keys())
        _logger.info("Appending row with ID %d to CSV file %s", row_id, self.path)
        _logger.debug(row)
        with open(self.path, "a", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC)
            writer.writerow([row[column_name] for column_name in column_names])
        self.num_rows += 1

        # update order, in case it has changed
        for column_name, column_type in self.columns.items():
            column_type.order = update_order(
                current_order=column_type.order,
                previous=self.last_row[column_name] if self.last_row else None,
                current=row[column_name],
                num_rows=self.num_rows,
            )
        self.last_row = row

        self.modified = True

        return row_id

    def delete_data(self, row_id: int) -> None:
        if not self.local:
            raise ProgrammingError("Cannot apply DML to a remote file")

        _logger.info("Deleting row with ID %d from CSV file %s", row_id, self.path)
        # on ``DELETE``\s we simply mark the row as deleted, so that it will be
        # ignored on ``SELECT``\s
        self.row_id_manager.delete(row_id)
        self.num_rows -= 1

        self.modified = True

    def close(self) -> None:
        """
        Garbage collect the file.

        This method will get rid of deleted rows in the file.
        """
        if not self.local:
            try:
                self.path.unlink()
            except FileNotFoundError:
                pass
            return

        if not self.modified:
            return

        # garbage collect -- should we sort the data according to the initial
        # sort order when writing to the new file?
        with open(self.path, encoding="utf-8") as csvfile:
            reader = csv.reader(csvfile, quoting=csv.QUOTE_NONNUMERIC)
            column_names = next(reader)
            data = (row for i, row in zip(self.row_id_manager, reader) if i != -1)

            with open(
                self.path.with_suffix(".csv.bak"),
                "w",
                encoding="utf-8",
            ) as copy:
                writer = csv.writer(copy, quoting=csv.QUOTE_NONNUMERIC)
                writer.writerow(column_names)
                writer.writerows(data)

        os.replace(self.path.with_suffix(".csv.bak"), self.path)

        self.modified = False

    def drop_table(self) -> None:
        self.path.unlink()