Source code for shillelagh.adapters.registry

"""
Registry for adapters.

Inspired by SQLAlchemy's ``PluginLoader``.
"""

import importlib
import logging
import sys
from collections import defaultdict
from typing import Callable, Optional, cast

from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import InterfaceError

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points
else:
    from importlib.metadata import entry_points

# Logger for this module, named after the module's dotted import path.
_logger = logging.getLogger(name=__name__)


class UnsafeAdaptersError(InterfaceError):
    """
    Error signalling that one adapter name resolves to several adapters.

    ``AdapterLoader.load`` raises it when called with ``safe=True`` and more
    than one loader is registered under the requested name.
    """
class AdapterLoader:
    """
    Adapter registry, allowing new adapters to be registered.

    Each adapter name maps to a list of zero-argument callables ("loaders")
    that return the adapter class. Loaders are discovered from the
    ``shillelagh.adapter`` entry point group when the registry is created,
    and more can be added later with ``register`` or ``add``. Several loaders
    may share a name; ``load`` tries them in registration order.
    """

    def __init__(self):
        self.loaders: defaultdict[str, list[Callable[[], type[Adapter]]]] = (
            defaultdict(list)
        )
        for entry_point in entry_points(group="shillelagh.adapter"):
            # Store the bound ``load`` method without calling it, so the
            # adapter's module is only imported when the adapter is loaded.
            self.loaders[entry_point.name].append(entry_point.load)

    def load(self, name: str, safe: bool = False, warn: bool = False) -> type[Adapter]:
        """
        Load a given entry point by its name.

        Loaders registered under ``name`` are tried in registration order; the
        first one that does not raise ``ImportError`` wins.

        ``safe``: if true, refuse to load when more than one loader is
        registered under ``name``.
        ``warn``: if true, log a warning for every loader that fails to import.

        Raises ``UnsafeAdaptersError`` when ``safe`` is true and ``name`` is
        ambiguous, and ``InterfaceError`` when no loader for ``name`` succeeds
        (including when ``name`` is not registered at all).
        """
        # ``.get`` (not ``[]``) so that looking up an unknown name does not
        # insert an empty entry into the defaultdict; such phantom entries
        # would otherwise be iterated by ``load_all`` and change its result.
        loaders = self.loaders.get(name, [])
        if safe and len(loaders) > 1:
            raise UnsafeAdaptersError(f"Multiple adapters found with name {name}")

        last_error: Optional[ImportError] = None
        for load in loaders:
            try:
                return cast(type[Adapter], load())
            except ImportError as ex:  # also covers ModuleNotFoundError
                if warn:
                    _logger.warning("Couldn't load adapter %s", name)
                    _logger.debug(ex)
                last_error = ex

        # Chain the last import failure (if any) so callers can see why.
        raise InterfaceError(f"Unable to load adapter {name}") from last_error

    def load_all(
        self,
        adapters: Optional[list[str]] = None,
        safe: bool = False,
    ) -> dict[str, type[Adapter]]:
        """
        Load all the adapters given a list of names.

        If ``safe`` is True all adapters must be safe and present in the list of
        names: an empty or ``None`` list returns nothing, adapters whose
        ``safe`` attribute is false are dropped, and a requested adapter that
        is ambiguous or fails to import raises.

        Otherwise adapters can be unsafe. If the list is ``None`` everything
        that imports successfully is returned; if a list is given, every listed
        adapter that is registered must load or ``InterfaceError`` is raised.

        In both modes, names in the list that are not registered are ignored.
        """
        return self._load_all_safe(adapters) if safe else self._load_all(adapters)

    def _load_all_safe(
        self,
        adapters: Optional[list[str]] = None,
    ) -> dict[str, type[Adapter]]:
        """
        Load all safe adapters.

        If no adapters are specified, return an empty dictionary.
        """
        if not adapters:
            return {}

        requested = set(adapters)
        # Iterate over a snapshot: loading an adapter imports its module, which
        # could register further adapters and resize ``self.loaders``.
        loaded_adapters = {
            name: self.load(name, safe=True, warn=True)
            for name in list(self.loaders)
            if name in requested
        }
        # Only keep adapters that declare themselves safe.
        return {
            name: adapter
            for name, adapter in loaded_adapters.items()
            if adapter.safe
        }

    def _load_all(
        self,
        adapters: Optional[list[str]] = None,
    ) -> dict[str, type[Adapter]]:
        """
        Load all adapters, safe and unsafe.

        If no adapters are specified, return all that load successfully.
        """
        requested = None if adapters is None else set(adapters)
        loaded_adapters = {}
        # Snapshot for the same reason as in ``_load_all_safe``.
        for name in list(self.loaders):
            if requested is None:
                # Best effort: skip adapters that cannot be imported.
                try:
                    loaded_adapters[name] = self.load(name, safe=False)
                except InterfaceError:
                    _logger.debug("Skipping adapter %s", name, exc_info=True)
            elif name in requested:
                # Explicitly requested adapters must load, so errors propagate.
                loaded_adapters[name] = self.load(name, safe=False)
        return loaded_adapters

    def register(self, name: str, modulepath: str, classname: str) -> None:
        """
        Register a new adapter.

        The module is not imported here; it is imported the first time the
        adapter is loaded.

        ``modulepath`` is the dotted path of the module and ``classname`` the
        name of the adapter class inside it.
        """

        def load() -> type[Adapter]:
            # A missing module raises ModuleNotFoundError directly.
            module = importlib.import_module(modulepath)
            try:
                return cast(type[Adapter], getattr(module, classname))
            except AttributeError as ex:
                # Report a missing class as an import failure so ``load``
                # handles it like a missing module and tries the next loader.
                raise ModuleNotFoundError(
                    f"Unable to load {classname} from {modulepath}",
                ) from ex

        self.loaders[name].append(load)

    def add(self, name: str, adapter: type[Adapter]) -> None:
        """
        Add an adapter class directly.

        The class is wrapped in a loader that simply returns it.
        """
        self.loaders[name].append(lambda: adapter)

    def clear(self) -> None:
        """
        Remove all registered adapters.
        """
        self.loaders = defaultdict(list)
# Module-level registry instance. Constructing it scans the
# ``shillelagh.adapter`` entry point group when this module is imported.
registry: AdapterLoader = AdapterLoader()