Source code for shillelagh.filters

"""
Filters for representing SQL predicates.
"""

import re
from enum import Enum
from typing import Any, Optional, Set, Tuple


class Operator(Enum):
    """
    Enum listing the supported operators.

    The value of each member is the textual spelling of the operator.
    """

    # equality and inequality
    EQ = "=="
    NE = "!="

    # ordering comparisons
    GE = ">="
    GT = ">"
    LE = "<="
    LT = "<"

    # null tests
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"

    # pattern matching
    LIKE = "LIKE"

    # remaining keywords
    LIMIT = "LIMIT"
    OFFSET = "OFFSET"
class Side(Enum):
    """Which side of an interval an endpoint sits on."""

    LEFT = "LEFT"
    RIGHT = "RIGHT"
class Endpoint:
    """
    One of the two endpoints of a ``Range``.

    Used to compare ranges. Eg, the range ``>10`` can be represented by:

        >>> start = Endpoint(10, False, Side.LEFT)
        >>> end = Endpoint(None, True, Side.RIGHT)
        >>> print(f'{start},{end}')
        (10,∞]

    The first endpoint represents the value 10 at the left side, in an open
    interval. The second endpoint represents infinity in this case.

    A ``value`` of ``None`` stands for an unbounded endpoint: -∞ when the
    side is ``Side.LEFT`` and ∞ when the side is ``Side.RIGHT``.
    """

    def __init__(self, value: Any, include: bool, side: Side):
        self.value = value
        self.include = include
        self.side = side

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Endpoint):
            return NotImplemented

        # NOTE(review): ``side`` is not part of the comparison, so a left and
        # a right endpoint with the same value and inclusiveness are equal.
        return self.value == other.value and self.include == other.include

    def __gt__(self, other: Any) -> bool:  # pylint: disable=too-many-return-statements
        """
        Return whether this endpoint lies strictly after ``other``.

        When a left endpoint is compared with a right endpoint (as in
        ``start > end``), a true result means the interval is empty.
        """
        if not isinstance(other, Endpoint):
            return NotImplemented

        # unbounded endpoints: ∞ is after everything, -∞ after nothing
        if self.value is None:
            return self.side == Side.RIGHT

        if other.value is None:
            return other.side == Side.LEFT

        if self.value == other.value:
            if self.side == Side.LEFT:
                if other.side == Side.LEFT:
                    # "(x" starts after "[x"
                    return not self.include and other.include
                # Left vs right at the same value: the interval contains a
                # point only when both endpoints are inclusive ("[x, x]").
                # "[x, x)", "(x, x]" and "(x, x)" are all empty.
                return not (self.include and other.include)

            # self.side == Side.RIGHT
            if other.side == Side.RIGHT:
                # "x]" ends after "x)"
                return not other.include and self.include
            return False

        return bool(self.value > other.value)

    # needed for ``min()``, which compares with ``<``
    def __lt__(self, other: Any) -> bool:
        # defined as "not greater than", so it also holds for endpoints
        # that compare as neither greater nor smaller
        return not self > other

    def __repr__(self) -> str:
        """
        Representation of an endpoint.

            >>> print(Endpoint(10, False, Side.LEFT))
            (10

        """
        if self.side == Side.LEFT:
            symbol = "[" if self.include else "("
            value = "-∞" if self.value is None else self.value
            return f"{symbol}{value}"

        symbol = "]" if self.include else ")"
        value = "∞" if self.value is None else self.value
        return f"{value}{symbol}"
def get_endpoints_from_operation(
    operator: Operator,
    value: Any,
) -> Tuple[Endpoint, Endpoint]:
    """
    Translate a single comparison into the interval it describes.

    Returns a ``(left, right)`` pair of endpoints. The side that the
    comparison leaves unbounded gets an inclusive endpoint whose value is
    ``None``. Raises ``Exception`` for any operator other than ``EQ``,
    ``GE``, ``GT``, ``LE`` and ``LT``.
    """
    if operator == Operator.EQ:
        left = Endpoint(value, True, Side.LEFT)
        right = Endpoint(value, True, Side.RIGHT)
    elif operator == Operator.GE:
        left = Endpoint(value, True, Side.LEFT)
        right = Endpoint(None, True, Side.RIGHT)
    elif operator == Operator.GT:
        left = Endpoint(value, False, Side.LEFT)
        right = Endpoint(None, True, Side.RIGHT)
    elif operator == Operator.LE:
        left = Endpoint(None, True, Side.LEFT)
        right = Endpoint(value, True, Side.RIGHT)
    elif operator == Operator.LT:
        left = Endpoint(None, True, Side.LEFT)
        right = Endpoint(value, False, Side.RIGHT)
    else:
        # pylint: disable=broad-exception-raised
        raise Exception(f"Invalid operator: {operator}")

    return left, right
class Filter:
    """
    Base class for filters, each representing a SQL predicate.

    Subclasses declare the operators they handle in ``operators`` and
    override ``build`` and ``check``; both raise ``NotImplementedError``
    here.
    """

    # operators handled by the filter; empty on the base class
    operators: Set[Operator] = set()

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> "Filter":
        """
        Create a filter from a collection of ``(operator, value)`` pairs:

            >>> operations = [(Operator.GT, 10), (Operator.GT, 20)]
            >>> print(Range.build(operations))
            >20

        """
        raise NotImplementedError("Subclass must implement ``build``")

    def check(self, value: Any) -> bool:
        """
        Tell whether ``value`` satisfies the filter:

            >>> operations = [(Operator.GT, 10), (Operator.GT, 20)]
            >>> filter_ = Range.build(operations)
            >>> filter_.check(10)
            False
            >>> filter_.check(30)
            True

        """
        raise NotImplementedError("Subclass must implement ``check``")
class Impossible(Filter):
    """
    Filter that no value satisfies.

    Returned by other filters when the conditions passed cannot be met.
    """

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # the operations play no role: the result never matches anything
        return Impossible()

    def check(self, value: Any) -> bool:
        return False

    def __eq__(self, other: Any) -> bool:
        # instances carry no state, so any two of them are equal
        if isinstance(other, Impossible):
            return True
        return NotImplemented

    def __repr__(self) -> str:
        return "1 = 0"
class IsNull(Filter):
    """
    Filter matching only ``None``, the counterpart of SQL ``IS NULL``.
    """

    operators: Set[Operator] = {Operator.IS_NULL}

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # the operations carry no information needed by the filter
        return IsNull()

    def check(self, value: Any) -> bool:
        return value is None

    def __eq__(self, other: Any) -> bool:
        # instances carry no state, so any two of them are equal
        if isinstance(other, IsNull):
            return True
        return NotImplemented

    def __repr__(self) -> str:
        return "IS NULL"
class IsNotNull(Filter):
    """
    Filter matching anything but ``None``, the counterpart of SQL
    ``IS NOT NULL``.
    """

    operators: Set[Operator] = {Operator.IS_NOT_NULL}

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # the operations carry no information needed by the filter
        return IsNotNull()

    def check(self, value: Any) -> bool:
        return value is not None

    def __eq__(self, other: Any) -> bool:
        # instances carry no state, so any two of them are equal
        if isinstance(other, IsNotNull):
            return True
        return NotImplemented

    def __repr__(self) -> str:
        return "IS NOT NULL"
class Equal(Filter):
    """
    Filter matching values equal to a given one.
    """

    operators: Set[Operator] = {Operator.EQ}

    def __init__(self, value: Any):
        self.value = value

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # collapse repeated operations on the same value
        distinct = {operand for _, operand in operations}

        if len(distinct) == 1:
            (operand,) = distinct
            return cls(operand)

        # no values at all, or several different ones
        return Impossible()

    def check(self, value: Any) -> bool:
        matches = value == self.value
        return bool(matches)

    def __repr__(self) -> str:
        return f"=={self.value}"
class NotEqual(Filter):
    """
    Filter matching values different from a given one.
    """

    operators: Set[Operator] = {Operator.NE}

    def __init__(self, value: Any):
        self.value = value

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # collapse repeated operations on the same value
        distinct = {operand for _, operand in operations}

        if len(distinct) == 1:
            (operand,) = distinct
            return cls(operand)

        # NOTE(review): several distinct values (eg, ``!=1`` and ``!=2``) can
        # be satisfied together, yet they end up here as ``Impossible``;
        # confirm whether that is intended.
        return Impossible()

    def check(self, value: Any) -> bool:
        differs = value != self.value
        return bool(differs)

    def __repr__(self) -> str:
        return f"!={self.value}"
class Like(Filter):
    """
    Pattern matching with the wildcards of SQL ``LIKE``.

    In the pattern ``%`` matches any sequence of characters (including an
    empty one) and ``_`` matches exactly one character. Every other
    character matches only itself. Matching ignores case and the pattern
    has to cover the whole value:

        >>> Like("a%").check("ABC")
        True
        >>> Like("abc").check("abcdef")
        False
        >>> Like("a.c").check("abc")
        False

    The original pattern is kept in ``value`` and the compiled regular
    expression in ``regex``.
    """

    operators: Set[Operator] = {
        Operator.LIKE,
    }

    def __init__(self, value: Any):
        self.value = value

        # Translate character by character so that only the two SQL
        # wildcards become regex wildcards. Everything else is escaped,
        # otherwise characters such as "." or "(" in the pattern would be
        # interpreted as regular expression syntax.
        pattern = "".join(
            ".*" if char == "%" else "." if char == "_" else re.escape(char)
            for char in self.value
        )
        # DOTALL lets the wildcards match newlines as well
        self.regex = re.compile(pattern, re.IGNORECASE | re.DOTALL)

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # we only accept a single value
        # NOTE(review): several distinct patterns can be satisfied together,
        # yet they are reported as ``Impossible``; confirm this is intended.
        values = {value for operator, value in operations}
        if len(values) != 1:
            return Impossible()

        return cls(values.pop())

    def check(self, value: Any) -> bool:
        # a missing value matches no pattern
        if value is None:
            return False

        # the pattern must match the whole value, not just a prefix of it
        return bool(self.regex.fullmatch(value))

    def __repr__(self) -> str:
        return f"LIKE {self.value}"
class Range(Filter):
    """
    A range comparison.

    This filter represents a range, with an optional start and an
    optional end. Start and end can be inclusive or exclusive. A start or
    end of ``None`` means the range is unbounded on that side.

    Ranges can be combined by adding them:

        >>> range1 = Range(start=10)
        >>> range2 = Range(start=20)
        >>> print(range1 + range2)
        >20
        >>> range3 = Range(end=40)
        >>> print(range2 + range3)
        >20,<40

    Adding two ranges that do not overlap gives an ``Impossible`` filter.
    """

    def __init__(
        self,
        start: Optional[Any] = None,
        end: Optional[Any] = None,
        include_start: bool = False,
        include_end: bool = False,
    ):
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    operators: Set[Operator] = {
        Operator.EQ,
        Operator.GE,
        Operator.GT,
        Operator.LE,
        Operator.LT,
    }

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Range):
            return NotImplemented

        return (
            self.start == other.start
            and self.end == other.end
            and self.include_start == other.include_start
            and self.include_end == other.include_end
        )

    def __add__(self, other: Any) -> Filter:
        """
        Intersect two ranges, returning ``Impossible`` if nothing is left.
        """
        if not isinstance(other, Range):
            return NotImplemented

        start = Endpoint(self.start, self.include_start, Side.LEFT)
        end = Endpoint(self.end, self.include_end, Side.RIGHT)

        new_start = Endpoint(other.start, other.include_start, Side.LEFT)
        new_end = Endpoint(other.end, other.include_end, Side.RIGHT)

        # keep the most restrictive endpoint on each side
        start = max(start, new_start)
        end = min(end, new_end)

        if start > end:
            return Impossible()

        return Range(start.value, end.value, start.include, end.include)

    @classmethod
    def build(cls, operations: Set[Tuple[Operator, Any]]) -> Filter:
        # begin with the unbounded range and narrow it down with each
        # operation
        start = Endpoint(None, True, Side.LEFT)
        end = Endpoint(None, True, Side.RIGHT)

        for operator, value in operations:
            new_start, new_end = get_endpoints_from_operation(operator, value)

            start = max(start, new_start)
            end = min(end, new_end)

            if start > end:
                return Impossible()

        return cls(start.value, end.value, start.include, end.include)

    def check(self, value: Any) -> bool:
        if self.start is not None:
            if self.include_start and value < self.start:
                return False
            if not self.include_start and value <= self.start:
                return False

        if self.end is not None:
            if self.include_end and value > self.end:
                return False
            if not self.include_end and value >= self.end:
                return False

        return True

    def __repr__(self) -> str:
        # A range closed on both sides around a single value is an equality.
        # The start must be set: with both endpoints unbounded ``start`` and
        # ``end`` are both ``None``, and that range is not ``==None``.
        if (
            self.start is not None
            and self.start == self.end
            and self.include_start
            and self.include_end
        ):
            return f"=={self.start}"

        comparisons = []
        if self.start is not None:
            operator = ">=" if self.include_start else ">"
            comparisons.append(f"{operator}{self.start}")
        if self.end is not None:
            operator = "<=" if self.include_end else "<"
            comparisons.append(f"{operator}{self.end}")

        return ",".join(comparisons)