Source code for eta_nexus.connections.connection

"""Connection base class and protocols for the ETA Nexus framework."""

from __future__ import annotations

import concurrent.futures
import os
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from logging import getLogger
from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, overload, runtime_checkable

import pandas as pd
from attr import field
from dateutil import tz
from requests.adapters import HTTPAdapter
from requests.exceptions import HTTPError, RequestException, Timeout
from typing_extensions import deprecated
from urllib3.util.retry import Retry

from eta_nexus.nodes.node import Node
from eta_nexus.subscription_handlers.subscription_handler import SubscriptionHandler
from eta_nexus.util import ensure_timezone, round_timestamp, url_parse
from eta_nexus.util.type_annotations import N, N_contra

if TYPE_CHECKING:
    from collections.abc import Mapping
    from datetime import datetime
    from logging import Logger
    from typing import Any, ClassVar
    from urllib.parse import ParseResult

    from pandas._typing import ArrayLike
    from requests.auth import AuthBase
    from requests_cache import AnyResponse, CachedSession

    from eta_nexus.subscription_handlers import SubscriptionHandler
    from eta_nexus.util.type_annotations import Nodes, Self, TimeStep


@runtime_checkable
class StatusReadable(Protocol, Generic[N_contra]):
    """Non-data Protocol (methods only) for Connections that can read current values."""

    @abstractmethod
    def read(self, nodes: N_contra | Nodes[N_contra] | None = None) -> pd.DataFrame:
        """Read the current value of each requested Node.

        :param nodes: Single Node or Sequence/Set of Nodes to read from. When no nodes
            are passed, the connection's selected_nodes are used.
        :return: Pandas DataFrame with read values.
        """
@runtime_checkable
class StatusWritable(Protocol, Generic[N]):
    """Non-data Protocol (methods only) for Connections that can write values."""

    @abstractmethod
    def write(self, values: Mapping[N, Any]) -> None:
        """Write the given values to their nodes.

        :param values: Mapping (e.g. dict) of Nodes to the value to write: {node: value}.
        """
@runtime_checkable
class StatusSubscribable(Protocol, Generic[N_contra]):
    """Non-data Protocol (methods only) for Connections that can subscribe to values."""

    @abstractmethod
    def subscribe(
        self,
        handler: SubscriptionHandler,
        nodes: N_contra | Nodes[N_contra] | None = None,
        request_frequency: TimeStep = 1,
    ) -> None:
        """Subscribe to nodes and call the handler whenever new data is available.

        Connection protocols without native subscriptions are expected to poll the
        nodes instead, using ``request_frequency`` as the polling period. When no
        nodes are passed, subscription_nodes are used.

        :param handler: A SubscriptionHandler object.
        :param nodes: Single Node or Sequence/Set of nodes to subscribe to.
        :param request_frequency: Time period between two requests (despite the name
            this is a period, not a frequency). Interpreted as seconds if Numeric is
            given.
        """

    @abstractmethod
    def close_sub(self) -> None:
        """Close an open subscription.

        Implementations should handle the case of no existing subscription gracefully.
        """
@runtime_checkable
class SeriesReadable(Protocol, Generic[N_contra]):
    """Non-data Protocol for Connections with the ability to read historic data."""

    @abstractmethod
    def read_series(
        self,
        from_time: datetime,
        to_time: datetime,
        nodes: N_contra | Nodes[N_contra] | None = None,
        interval: TimeStep = 1,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Read time series data for each Node in nodes.

        Retrieves values for the partly open time interval [from_time, to_time), with
        consecutive values spaced ``interval`` apart. Uses selected_nodes if no nodes
        are passed. The same ``interval`` is applied to all nodes.

        :param from_time: Start of the time series (included in the interval).
        :param to_time: End of the time series (excluded from the interval).
        :param nodes: Single Node or Sequence/Set of nodes to read values from.
        :param interval: Time between the time series' values. Interpreted as seconds if
            Numeric is given.
        :param kwargs: Additional Subclass arguments.
        :return: pandas.DataFrame containing the timeseries read from the connection.
        """
@runtime_checkable
class SeriesWritable(Protocol, Generic[N]):
    """Non-data Protocol (methods only) for Connections that can write time series data."""

    @overload
    def write_series(
        self,
        values: Mapping[N, pd.Series],
        *,
        allow_overwrite: bool = True,
        **kwargs: Any,
    ) -> None: ...

    @overload
    def write_series(
        self,
        values: pd.DataFrame,
        *,
        allow_overwrite: bool = True,
        **kwargs: Any,
    ) -> None: ...

    def write_series(
        self,
        values: Mapping[N, pd.Series] | pd.DataFrame,
        *,
        allow_overwrite: bool = True,
        **kwargs: Any,
    ) -> None:
        """Write time series data for the given nodes.

        Two input shapes are accepted:

        - a mapping Node -> pandas.Series, where each Series has a datetime-like index
          and its values are the samples, or
        - a pandas.DataFrame with a datetime-like index and one column per Node, where
          the column names must match ``node.name``.

        Implementations may round/align timestamps to node-specific intervals and should
        keep timestamps timezone-aware, consistent with the Connection utilities.

        :param values: Mapping of nodes to Series, or a DataFrame with datetime-like index.
        :param allow_overwrite: If True, upsert points at identical timestamps; if False,
            avoid overwriting.
        :param kwargs: Additional subclass arguments.
        """
@runtime_checkable
class SeriesSubscribable(Protocol, Generic[N_contra]):
    """Non-data Protocol (methods only) for Connections that can subscribe to time series."""

    @abstractmethod
    def subscribe_series(
        self,
        handler: SubscriptionHandler,
        req_interval: TimeStep,
        offset: TimeStep | None = None,
        nodes: N_contra | Nodes[N_contra] | None = None,
        interval: TimeStep = 1,
        data_interval: TimeStep = 1,
        **kwargs: Any,
    ) -> None:
        """Subscribe to nodes and call the handler whenever new data is available.

        Each request retrieves a time series over a partly open interval that starts at
        ``now + offset`` and covers a fixed duration, with a fixed distance between
        consecutive values. Connection protocols without native subscriptions are
        expected to poll the nodes periodically instead. When no nodes are passed,
        subscription_nodes are used. The same value spacing is applied to all nodes.

        NOTE(review): earlier documentation described the three timing settings under
        the names ``data_duration`` (duration of the returned timeseries interval),
        ``request_frequency`` (time period between two requests) and ``resolution``
        (time between the timeseries' values). None of those names exist in the
        signature; which of ``req_interval``, ``interval`` and ``data_interval`` carries
        which meaning cannot be determined from this protocol alone and must be
        confirmed against a concrete implementation.

        :param handler: A SubscriptionHandler object.
        :param req_interval: Timing setting; interpreted as seconds if Numeric is given
            (see note above for its exact role).
        :param offset: Offset between time of request and start of returned timeseries.
            Can be negative.
        :param nodes: Single Node or Sequence/Set of nodes to subscribe to.
        :param interval: Timing setting; interpreted as seconds if Numeric is given
            (see note above for its exact role).
        :param data_interval: Timing setting; interpreted as seconds if Numeric is given
            (see note above for its exact role).
        :param kwargs: Subclass arguments.
        """

    @abstractmethod
    def close_sub(self) -> None:
        """Close an open subscription.

        Implementations should handle the case of no existing subscription gracefully.
        """
class Connection(Generic[N], ABC):
    """Common connection interface class.

    The URL (netloc) may contain the username and password
    (schema://username:password@hostname:port/path). In this case, the parameters usr
    and pwd are not required. BUT the keyword parameters of the function will take
    precedence over username and password configured in the url.

    Subclasses register themselves by passing a ``protocol`` class keyword, e.g.
    ``class MyConnection(Connection[MyNode], protocol="my_protocol")``; the registry is
    used by :meth:`from_node` to pick the concrete class for a node.

    :param url: Netloc of the server to connect to.
    :param usr: Username for login to server.
    :param pwd: Password for login to server.
    :param nodes: List of nodes to select as a standard case.
    """

    logger: Logger
    #: Maps protocol name -> Connection subclass registered for that protocol.
    _registry: ClassVar[dict[str, type[Connection]]] = {}
    # NOTE(review): ``field`` is imported from ``attr`` although no attrs decorator is
    # visible on this class, so the base-class value is a bare attrs placeholder object.
    # Subclasses that pass ``protocol=...`` overwrite it in ``__init_subclass__``.
    _PROTOCOL: ClassVar[str] = field(repr=False, eq=False, order=False)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Store subclass definitions to instantiate based on protocol."""
        protocol = kwargs.pop("protocol", None)
        if protocol:
            cls._PROTOCOL = protocol
            cls._registry[protocol] = cls

        return super().__init_subclass__(**kwargs)

    def __init__(
        self, url: str, usr: str | None = None, pwd: str | None = None, *, nodes: Nodes[N] | None = None
    ) -> None:
        #: URL of the server to connect to
        self.url_parsed: ParseResult
        #: Username for login to server
        self.usr: str | None
        #: Password for login to server
        self.pwd: str | None
        self.url_parsed, self.usr, self.pwd = url_parse(url)

        if nodes is not None:
            #: Preselected nodes which will be used for reading and writing, if no other nodes are specified
            self.selected_nodes = self._validate_nodes(nodes)
        else:
            self.selected_nodes = set()

        # Get username and password either from the arguments, from the parsed URL string or from a Node object
        node = next(iter(self.selected_nodes)) if len(self.selected_nodes) > 0 else None

        def validate_and_set(attribute: str, value: str | Any, node_value: str | None) -> None:
            """Apply credentials by precedence: explicit value, then URL value, then node value."""
            if value is not None:
                if not isinstance(value, str):
                    raise TypeError(f"{attribute.capitalize()} should be a string value.")
                setattr(self, attribute, value)
            elif getattr(self, attribute) is None and node_value is not None:
                # Nothing came from the arguments or the URL: fall back to the node's value.
                setattr(self, attribute, node_value)

        validate_and_set("usr", usr, node.usr if node else None)
        validate_and_set("pwd", pwd, node.pwd if node else None)

        #: Store local time zone
        self._local_tz = tz.tzlocal()
        #: :py:func:`eta_nexus.util.round_timestamp`
        self._round_timestamp = round_timestamp
        #: :py:func:`eta_nexus.util.ensure_timezone`
        self._assert_tz_awareness = ensure_timezone

        self.exc: BaseException | None = None

    def __eq__(self, other: object) -> bool:
        """Connections are equal when netloc and the extra equality key both match."""
        if not isinstance(other, Connection):
            return False
        return (self.url_parsed.netloc, self._extra_equality_key()) == (
            other.url_parsed.netloc,
            other._extra_equality_key(),
        )

    def __hash__(self) -> int:
        """Hash over the same tuple that ``__eq__`` compares."""
        return hash((self.url_parsed.netloc, self._extra_equality_key()))

    def _extra_equality_key(self) -> Any | None:
        """Additional attributes that are relevant for deciding if nodes belong to a connection.

        Override this in case extra keys are necessary, don't forget to also set this in
        the node class. Enforce presence of attributes used in this method!
        """
        return None

    @classmethod
    def from_node(cls, node: Nodes[N] | N, usr: str | None = None, pwd: str | None = None, **kwargs: Any) -> Self:
        """Return a single connection for one node or an iterable of nodes.

        The concrete connection class is looked up in the registry by the protocol of
        the first node. All nodes must share the same ``connection_identifier()``;
        the remaining nodes are added to the connection's ``selected_nodes``.

        :param node: Node, or iterable of nodes, to initialize from.
        :param usr: Username used only when the first node does not define one itself.
        :param pwd: Password used only when the first node does not define one itself.
        :param kwargs: Forwarded to the ``_from_node`` method of the selected connection class.
        :raises ValueError: If not all nodes match the same connection.
        :return: Connection object
        """
        nodes = {node} if not isinstance(node, Iterable) else set(node)

        # Check if all nodes belong to the same connection
        if len({_node.connection_identifier() for _node in nodes}) != 1:
            raise ValueError("Nodes must all have the same netloc to be used with the same connection.")

        for index, _node in enumerate(nodes):
            # Instantiate connection from the first node
            if index == 0:
                # Credentials stored on the node win over the ones passed as arguments.
                usr = _node.usr or usr
                pwd = _node.pwd or pwd

                connection_cls = cls._registry[_node.protocol]
                connection = cast("Self", connection_cls._from_node(_node, usr=usr, pwd=pwd, **kwargs))
            # Add node to existing connection
            else:
                connection.selected_nodes.add(_node)

        return connection

    @classmethod
    def from_nodes(cls, nodes: Nodes[N], **kwargs: Any) -> dict[str, Connection[N]]:
        """Return a dictionary of connections, one per distinct connection identifier.

        This method handles different Connections, unlike from_node(). Nodes are grouped
        by ``str(node.connection_identifier())`` (described as the netloc in earlier
        documentation); that string is the dictionary key and each connection contains
        the nodes of its group. Uses from_node to initialize each connection.

        :param nodes: List of nodes to initialize from.
        :param kwargs: Forwarded to from_node for every newly created connection.
        :return: Dictionary of Connection objects keyed by connection identifier.
        """
        connections: dict[str, Connection[N]] = {}

        for node in nodes:
            connection_id = str(node.connection_identifier())

            # If we already have a connection for this URL, add the node to connection
            if connection_id in connections:
                connections[connection_id].selected_nodes.add(node)
                continue  # Skip creating a new connection

            connections[connection_id] = cls.from_node(node, **kwargs)

        return connections

    @classmethod
    @abstractmethod
    def _from_node(cls, node: N, **kwargs: Any) -> Self:
        """Initialize the object from a node with corresponding protocol.

        Abstract, but provides a default implementation that subclasses may call.

        :param node: Node to initialize from.
        :param kwargs: Passed on to the class constructor.
        :raises TypeError: If node is not a Node object.
        :raises ValueError: If the node's protocol differs from the class protocol.
        :return: Initialized connection object.
        """
        if not isinstance(node, Node):
            raise TypeError("Node must be a Node object.")

        if node.protocol != cls._PROTOCOL:
            raise ValueError(
                f"Tried to initialize {cls.__name__} from a node "
                f"that does not specify {cls._PROTOCOL} as its protocol: {node.name}."
            )
        return cls(url=node.url, nodes=[node], **kwargs)

    @property
    def url(self) -> str:
        """Full URL string rebuilt from the parsed URL."""
        return self.url_parsed.geturl()

    def _validate_nodes(self, nodes: N | Nodes[N] | None) -> set[N]:
        """Make sure that nodes are a Set of nodes and that all nodes correspond to the connection.

        :param nodes: Single node or list/set of nodes to validate. ``None`` selects the
            preselected ``selected_nodes``.
        :raises ValueError: If no node remains after filtering.
        :return: Set of valid node objects for this connection.
        """
        if nodes is None:
            _nodes = self.selected_nodes
        else:
            nodes = {nodes} if not isinstance(nodes, Iterable) else nodes
            # If not using preselected nodes from self.selected_nodes, check if nodes correspond to the connection
            _nodes = {
                node
                for node in nodes
                if (
                    node.protocol == self._PROTOCOL
                    and node.url_parsed.netloc == self.url_parsed.netloc
                    and node._extra_equality_key() == self._extra_equality_key()
                )
            }

        # Make sure that some nodes remain after the checks and raise an error if there are none.
        if len(_nodes) == 0:
            raise ValueError(
                f"Some nodes to read from/write to must be specified. If nodes were specified, they do not "
                f"match the connection {self.url}"
            )

        return _nodes

    def _preprocess_series_context(
        self,
        from_time: datetime,
        to_time: datetime,
        nodes: N | Nodes[N] | None = None,
        interval: TimeStep = 1,
        **kwargs: Any,
    ) -> tuple[datetime, datetime, set[N], timedelta]:
        """Preprocess the series context to ensure it is ready for reading.

        Validates the nodes, converts the interval to a timedelta, floors from_time and
        ceils to_time to the interval. If the two timestamps end up in different
        timezones, a warning is logged and to_time is converted to from_time's timezone
        (the instant it denotes is unchanged).

        :param from_time: The start time of the series.
        :param to_time: The end time of the series.
        :param nodes: The nodes to read from.
        :param interval: The time interval for the series. Interpreted as seconds unless
            a timedelta is given.
        :param kwargs: Accepted for signature compatibility; not used here.
        :return: A tuple containing the processed from_time, to_time, nodes, and interval.
        """
        nodes = self._validate_nodes(nodes)

        interval = interval if isinstance(interval, timedelta) else timedelta(seconds=interval)

        from_time = round_timestamp(from_time, interval.total_seconds(), method="floor", ensure_tz=True)
        to_time = round_timestamp(to_time, interval.total_seconds(), method="ceil", ensure_tz=True)

        if from_time.tzinfo != to_time.tzinfo:
            log = getLogger(self.__module__)
            log.warning(
                f"Timezone of from_time and to_time are different. Using from_time timezone: {from_time.tzinfo}"
            )
            # Do what the warning announces. astimezone keeps the instant and only changes
            # the representation; skipped if from_time unexpectedly carries no timezone.
            if from_time.tzinfo is not None:
                to_time = to_time.astimezone(from_time.tzinfo)

        return (from_time, to_time, nodes, interval)
@deprecated("Use `Connection` and the appropriate protocols instead.")
class SeriesConnection(
    StatusReadable[N],
    StatusWritable[N],
    StatusSubscribable[N],
    SeriesReadable[N],
    SeriesWritable[N],
    SeriesSubscribable[N],
    Connection[N],
    ABC,
):
    """Deprecated base for connections that provide access to timeseries data.

    Combines every status and series protocol of this module with ``Connection``.

    :param url: URL of the server to connect to.
    """
class RESTConnection(Connection[N], ABC):
    """Abstract base class for RESTful API connections in the ETA Nexus framework.

    Extends ``Connection`` with shared plumbing for HTTP-based connections so that new
    REST connections need little boilerplate.

    Key Features:

    - Centralized HTTP request handling with consistent error management and logging.
    - Lazy-loaded session management using a cached session.
    - API token retrieval from an environment variable named after the connection
      protocol (``<PROTOCOL>_API_TOKEN``).
    - Authentication hook (``authentication``) that subclasses can override.

    Subclasses must implement ``_initialize_session`` (session creation),
    ``_parse_response`` (JSON -> timestamps/values) and ``read_node`` (node-specific
    data reading).

    :param url: URL of the REST API endpoint.
    :param usr: Username for authentication (optional).
    :param pwd: Password for authentication (optional).
    :param nodes: List of nodes to connect to (optional).
    :param retry_total: Total number of retries for failed HTTP requests (default: 3).
    :param retry_backoff_factor: Backoff factor for retries (default: 1s-> e.g. 1s, 2s, 4s for 3 retries).
    """

    def __init__(
        self,
        url: str,
        usr: str | None = None,
        pwd: str | None = None,
        *,
        nodes: Nodes[N] | None = None,
        retry_total: int = 3,
        retry_backoff_factor: float = 1.0,
    ) -> None:
        super().__init__(url, usr, pwd, nodes=nodes)
        # Retry settings are only consumed when the session is first built.
        self._retry_backoff_factor = retry_backoff_factor
        self._retry_total = retry_total

    @property
    def _api_token(self) -> str | None:
        """API token read from the ``<PROTOCOL>_API_TOKEN`` environment variable, if set."""
        env_name = self._PROTOCOL.upper() + "_API_TOKEN"
        token = os.getenv(env_name)
        if token is None:
            self.logger.warning(f"[{self._PROTOCOL.capitalize()}] {env_name} not found in environment.")
        return token

    @property
    def session(self) -> CachedSession:
        """Session used for all requests; built on first access and reused afterwards."""
        if hasattr(self, "_cached_session"):
            return self._cached_session

        new_session = self._initialize_session()
        retry_policy = Retry(
            total=self._retry_total,  # Number of total retries
            connect=self._retry_total,  # Retry on connect timeouts
            read=self._retry_total,  # Retry on read timeouts
            status_forcelist=[429, 500, 502, 503, 504],  # HTTP status codes to retry on
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],  # Methods to retry
            backoff_factor=self._retry_backoff_factor,
            raise_on_status=False,
        )
        retrying_adapter = HTTPAdapter(max_retries=retry_policy)
        for scheme in ("https://", "http://"):
            new_session.mount(scheme, retrying_adapter)

        self._cached_session = new_session
        return new_session

    @property
    def authentication(self) -> None | AuthBase:
        """Authentication object passed to every request; ``None`` unless overridden."""
        return None

    @abstractmethod
    def _initialize_session(self) -> CachedSession:
        """Create and return the session that the ``session`` property will cache."""

    def _raw_request(
        self, method: str, url: str, params: dict[str, Any] | None = None, **kwargs: Any
    ) -> AnyResponse | None:
        """Send a raw HTTP request to the REST API.

        Request failures are logged and reported as ``None`` instead of being raised.

        :param method: HTTP method to use (e.g., GET, POST).
        :param url: URL of the API endpoint.
        :param params: Query parameters for the request.
        :param kwargs: Passed on to the session's ``request``; ``timeout`` defaults to 10.
        :return: Response object, or ``None`` if the request failed.
        """
        timeout = kwargs.setdefault("timeout", 10)

        try:
            response = self.session.request(method, url, params=params, auth=self.authentication, **kwargs)
            response.raise_for_status()
            return response
        except HTTPError:
            # Bad Response (4xx, 5xx after retries exhausted)
            self.logger.exception(f"[{self._PROTOCOL.capitalize()}] Request failed:")
        except Timeout:
            # Timeout errors (after retries exhausted)
            self.logger.exception(f"[{self._PROTOCOL.capitalize()}] Request timed out after {timeout}s")
        except RequestException:
            # Other errors (ConnectionError, SSLError, etc.)
            self.logger.exception(f"[{self._PROTOCOL.capitalize()}] Request error occurred")
        return None

    @abstractmethod
    def _parse_response(self, json_data: dict[Any, Any]) -> tuple[pd.DatetimeIndex, ArrayLike]:
        """Parse the JSON data from the REST API into index and values for a DataFrame.

        :param json_data: JSON data from the API response.
        :return: Tuple of (timestamps, values) where:

            - timestamps: DatetimeIndex to use as the DataFrame index. Must be
              timezone-aware.
            - values: Array-like data for the DataFrame. If a pd.Series is returned, its
              index MUST match the timestamps to avoid NaN values from misalignment.
              For other array-like types (list, tuple, ndarray), the values will be
              automatically aligned with timestamps.
        """

    @abstractmethod
    def read_node(
        self,
        node: N,
        from_time: datetime,
        to_time: datetime,
        interval: timedelta,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Read data for a single node from a REST API endpoint.

        :param node: Node to read data from.
        :param from_time: Start of the time series (timezone-aware).
        :param to_time: End of the time series (timezone-aware).
        :param interval: Time interval between data points.
        :param kwargs: Additional subclass-specific arguments.
        :return: DataFrame containing the data read from the API.
        """
        raise NotImplementedError(
            "Subclasses must implement read_node to define how data is read from a REST API endpoint for a given node."
        )

    def _read_node(self, node: N, url: str, params: dict[str, Any] | None = None) -> pd.DataFrame:
        """Issue a GET request for one node and turn the JSON answer into a DataFrame.

        On a failed request or unparsable payload, an empty DataFrame with the node's
        column and an empty datetime index is returned and the problem is logged.

        :param node: Node to read data from.
        :param url: URL of the API endpoint.
        :param params: Query parameters for the request.
        :return: DataFrame containing the data read from the API.
        """
        index_name = "Time (with timezone)"
        fallback = pd.DataFrame(columns=[node.name], index=pd.DatetimeIndex([], name=index_name))

        response = self._raw_request("GET", url, params=params)
        if response is None:
            self.logger.warning(f"[{self._PROTOCOL}] No response from {url} for node {node.name}")
            return fallback

        # Process the data into a DataFrame
        try:
            timestamps, node_values = self._parse_response(response.json())
            frame = pd.DataFrame(
                data=node_values,
                index=timestamps.tz_convert(self._local_tz),
                columns=[node.name],
                dtype="float64",
            )
            frame.index.name = index_name
            return frame
        except (KeyError, ValueError, AttributeError, TypeError):
            self.logger.exception(f"[{self._PROTOCOL}] Failed to process data for node {node.name}")
            return fallback

    def _get_data(
        self,
        from_time: datetime,
        to_time: datetime,
        nodes: N | Nodes[N] | None = None,
        interval: TimeStep = 60,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Get data from the REST API for the specified nodes and time interval.

        Each node is read through ``read_node`` on a thread pool; the non-empty results
        are concatenated column-wise.

        :param from_time: Start of the time series, treated as partly open interval [from_time, to_time).
        :param to_time: End of the time series, treated as partly open interval [from_time, to_time).
        :param nodes: Single node or list/set of nodes to read values from.
        :param interval: Time between time series' values. Interpreted as seconds if Numeric is given.
        :param kwargs: Additional subclass arguments.
        :return: DataFrame with one column per node that delivered data, or an empty
            DataFrame with the node names as columns if none did.
        """
        from_time, to_time, nodes, interval = super()._preprocess_series_context(
            from_time, to_time, nodes, interval, **kwargs
        )

        def fetch(single_node: N) -> pd.DataFrame:
            return self.read_node(single_node, from_time, to_time, interval, **kwargs)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = executor.map(fetch, nodes)

        # Filter out empty or all-NA DataFrames
        usable = [frame for frame in results if not frame.empty and not frame.isna().all().all()]

        if usable:
            return pd.concat(usable, axis=1, sort=False)

        self.logger.warning(f"[{self._PROTOCOL.capitalize()}] No valid data retrieved from any node.")
        col_names = [node.name for node in nodes] or ["__placeholder__"]
        return pd.DataFrame(columns=col_names)