"""Connection base class and protocols for the ETA Nexus framework."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta
from logging import getLogger
from typing import TYPE_CHECKING, Any, Generic, Protocol, cast, runtime_checkable
from attr import field
from dateutil import tz
from typing_extensions import deprecated
from eta_nexus.nodes.node import Node
from eta_nexus.subhandlers.subhandler import SubscriptionHandler
from eta_nexus.util import ensure_timezone, round_timestamp, url_parse
from eta_nexus.util.type_annotations import N, N_contra
if TYPE_CHECKING:
from collections.abc import Mapping
from datetime import datetime
from typing import Any, ClassVar
from urllib.parse import ParseResult
import pandas as pd
from eta_nexus.subhandlers import SubscriptionHandler
from eta_nexus.util.type_annotations import Nodes, Self, TimeStep
@runtime_checkable
class Readable(Protocol, Generic[N_contra]):
    """Non-data Protocol describing Connections that are able to read data."""

    @abstractmethod
    def read(self, nodes: N_contra | Nodes[N_contra] | None = None) -> pd.DataFrame:
        """Read the current value of every Node in nodes.

        Falls back to selected_nodes when no nodes are passed.

        :param nodes: Single Node or Sequence/Set of Nodes to read from.
        :return: Pandas DataFrame with read values.
        """
        ...
@runtime_checkable
class Writable(Protocol, Generic[N]):
    """Non-data Protocol describing Connections that are able to write data."""

    @abstractmethod
    def write(self, values: Mapping[N, Any]) -> None:
        """Write the given values to their nodes.

        :param values: Mapping (e.g. dict) of Nodes and the respective values to write: {node: value}.
        """
        ...
@runtime_checkable
class Subscribable(Protocol, Generic[N_contra]):
    """Non-data Protocol describing Connections that are able to subscribe to data."""

    @abstractmethod
    def subscribe(
        self,
        handler: SubscriptionHandler,
        nodes: N_contra | Nodes[N_contra] | None = None,
        request_frequency: TimeStep = 1,
    ) -> None:
        """Subscribe to nodes and call handler whenever new data is available.

        If the connection protocol doesn't implement subscriptions natively, this method polls the nodes
        with the given frequency. Falls back to subscription_nodes when no nodes are passed.

        :param handler: A SubscriptionHandler object
        :param nodes: Single Node or Sequence/Set of nodes to subscribe to.
        :param request_frequency: Time period between two requests. Interpreted as seconds if Numeric is
            given. Technically no frequency!
        """
        ...

    @abstractmethod
    def close_sub(self) -> None:
        """Close an open subscription. This should gracefully handle non-existent subscriptions."""
        ...
@runtime_checkable
class SeriesReadable(Protocol, Generic[N_contra]):
    """Non-data Protocol for Connections with the ability to read historic data."""

    @abstractmethod
    def read_series(
        self,
        from_time: datetime,
        to_time: datetime,
        nodes: N_contra | Nodes[N_contra] | None = None,
        interval: TimeStep = 1,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Read time series data for each Node in nodes.

        Retrieves values for the partly open time interval [from_time, to_time), adhering to the
        value-to-value distance given as ``interval``. Uses selected_nodes if no nodes are passed.
        The same interval is applied to all nodes.

        :param from_time: Start of the timeseries; included in the partly open interval.
        :param to_time: End of the timeseries; excluded from the partly open interval.
        :param nodes: Single Node or Sequence/Set of nodes to read values from.
        :param interval: Time between timeseries' values. Interpreted as seconds if Numeric is given.
        :param kwargs: Additional Subclass arguments.
        :return: pandas.DataFrame containing the timeseries read from the connection.
        """
@runtime_checkable
class SeriesSubscribable(Protocol, Generic[N_contra]):
    """Non-data Protocol for Connections with the ability to subscribe to historic data."""

    @abstractmethod
    def subscribe_series(
        self,
        handler: SubscriptionHandler,
        req_interval: TimeStep,
        offset: TimeStep | None = None,
        nodes: N_contra | Nodes[N_contra] | None = None,
        interval: TimeStep = 1,
        data_interval: TimeStep = 1,
        **kwargs: Any,
    ) -> None:
        """Subscribe to nodes and call handler when new data is available.

        Each request retrieves a timeseries for a partly open time interval that starts at
        ``now + offset`` and has a fixed duration, with a fixed value-to-value distance. If the
        connection protocol doesn't implement subscriptions natively, this method polls the nodes
        with the given request period. Uses subscription_nodes if no nodes are passed. The same
        value-to-value distance is applied to all nodes.

        NOTE(review): this docstring previously described three time settings under the names
        ``data_duration``, ``request_frequency`` and ``resolution``, none of which exist in the
        signature. Their descriptions are preserved here; which of ``req_interval``, ``interval``
        and ``data_interval`` carries which meaning must be confirmed against the implementing
        connections:

        * Duration of the returned timeseries interval.
        * Time period between two requests. Interpreted as seconds if Numeric is given.
          Technically no frequency!
        * Time between timeseries' values. Interpreted as seconds if Numeric is given.

        :param handler: A SubscriptionHandler object
        :param req_interval: One of the three time settings listed above (to be confirmed).
        :param offset: Offset between time of request and start of returned timeseries. Can be negative.
        :param nodes: Single Node or Sequence/Set of nodes to subscribe to.
        :param interval: One of the three time settings listed above (to be confirmed).
        :param data_interval: One of the three time settings listed above (to be confirmed).
        :param kwargs: Subclass arguments
        """

    @abstractmethod
    def close_sub(self) -> None:
        """Closes an open subscription. This should gracefully handle non-existent subscriptions."""
class Connection(Generic[N], ABC):
    """Common connection interface class.

    The URL (netloc) may contain the username and password (schema://username:password@hostname:port/path).
    In this case, the parameters usr and pwd are not required. BUT the keyword parameters of the
    constructor take precedence over username and password configured in the url.

    :param url: Netloc of the server to connect to.
    :param usr: Username for login to server.
    :param pwd: Password for login to server.
    :param nodes: List of nodes to select as a standard case.
    """

    #: Maps a protocol identifier to the connection subclass registered for it.
    _registry: ClassVar[dict[str, type[Connection]]] = {}
    #: Protocol identifier of the class; replaced in __init_subclass__ when a subclass passes ``protocol``.
    _PROTOCOL: ClassVar[str] = field(repr=False, eq=False, order=False)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Store subclass definitions to instantiate based on protocol.

        The ``protocol`` class keyword is removed from kwargs; if it is truthy, it is stored as
        ``_PROTOCOL`` and the subclass is registered under it. Remaining keywords go to the parent hook.
        """
        protocol = kwargs.pop("protocol", None)
        if protocol:
            cls._PROTOCOL = protocol
            cls._registry[protocol] = cls
        return super().__init_subclass__(**kwargs)

    def __init__(
        self, url: str, usr: str | None = None, pwd: str | None = None, *, nodes: Nodes[N] | None = None
    ) -> None:
        #: URL of the server to connect to
        self._url: ParseResult
        #: Username for login to server
        self.usr: str | None
        #: Password for login to server
        self.pwd: str | None
        # url_parse returns the parsed URL together with the username and password it extracted.
        self._url, self.usr, self.pwd = url_parse(url)

        if nodes is not None:
            #: Preselected nodes which will be used for reading and writing, if no other nodes are specified
            self.selected_nodes = self._validate_nodes(nodes)
        else:
            self.selected_nodes = set()

        # Get username and password either from the arguments, from the parsed URL string or from a Node object
        node = next(iter(self.selected_nodes)) if len(self.selected_nodes) > 0 else None

        def validate_and_set(attribute: str, value: str | Any, node_value: str | None) -> None:
            """Apply the credential precedence: explicit value, then the existing attribute, then node_value.

            An explicit ``value`` always overwrites the attribute; ``node_value`` is only used when no
            explicit value was given and the attribute is still None.

            :raises TypeError: if an explicit value is not a string.
            """
            if value is not None:
                if not isinstance(value, str):
                    raise TypeError(f"{attribute.capitalize()} should be a string value.")
                setattr(self, attribute, value)
            elif getattr(self, attribute) is None and node_value is not None:
                setattr(self, attribute, node_value)

        validate_and_set("usr", usr, node.usr if node else None)
        validate_and_set("pwd", pwd, node.pwd if node else None)

        #: Store local time zone
        self._local_tz = tz.tzlocal()
        #: :py:func:`eta_nexus.util.round_timestamp`
        self._round_timestamp = round_timestamp
        #: :py:func:`eta_nexus.util.ensure_timezone`
        self._assert_tz_awareness = ensure_timezone

        self.exc: BaseException | None = None

    @classmethod
    def from_node(cls, node: Nodes[N] | N, usr: str | None = None, pwd: str | None = None, **kwargs: Any) -> Self:
        """Will return a single connection for an enumerable of nodes with the same url netloc.

        Initialize the connection object from a node object. When a list of Node objects is provided,
        from_node checks if all nodes match the same connection; it throws an error if they don't.
        A node matches a connection if it has the same url netloc.

        The connection class is looked up in the registry by the protocol of the first node. The
        credentials of that node are preferred; usr and pwd are only used where the node has none.

        :param node: Node (or iterable of nodes) to initialize from.
        :param usr: Username used if the first node does not provide one.
        :param pwd: Password used if the first node does not provide one.
        :param kwargs: Forwarded to ``_from_node`` of the selected connection class.
        :raises: ValueError: if not all nodes match the same connection (also raised for an empty iterable).
        :return: Connection object
        """
        # dict.fromkeys removes duplicates like a set, but keeps the order in which the nodes were passed.
        # This makes the "first node" (source of protocol and credentials) the first node given by the
        # caller instead of an arbitrary element of a set.
        nodes = dict.fromkeys(node) if isinstance(node, Iterable) else {node: None}

        # Check if all nodes have the same netloc
        if len({f"{_node.url_parsed.netloc}" for _node in nodes}) != 1:
            raise ValueError("Nodes must all have the same netloc to be used with the same connection.")

        first_node, *remaining_nodes = nodes

        # Instantiate connection from the first node, preferring its username and password
        usr = first_node.usr or usr
        pwd = first_node.pwd or pwd
        connection_cls = cls._registry[first_node.protocol]
        connection = cast("Self", connection_cls._from_node(first_node, usr=usr, pwd=pwd, **kwargs))

        # Add all other nodes to the existing connection
        for _node in remaining_nodes:
            connection.selected_nodes.add(_node)

        return connection

    @classmethod
    def from_nodes(cls, nodes: Nodes[N], **kwargs: Any) -> dict[str, Connection[N]]:
        """Returns a dictionary of connections for nodes with the same url netloc.

        This method handles different Connections, unlike from_node().
        The keys of the dictionary are the netlocs of the nodes and
        each connection contains the nodes with the same netloc.
        (Uses from_node to initialize connections from nodes.).

        :param nodes: List of nodes to initialize from.
        :param kwargs: Forwarded to from_node for every connection that is created.
        :return: Dictionary of Connection objects with the netloc as key.
        """
        connections: dict[str, Connection[N]] = {}

        for node in nodes:
            node_id = f"{node.url_parsed.netloc}"
            existing = connections.get(node_id)

            if existing is None:
                # First node for this netloc: create the connection from it
                connections[node_id] = cls.from_node(node, **kwargs)
            else:
                # We already have a connection for this netloc, add the node to it
                existing.selected_nodes.add(node)

        return connections

    @classmethod
    @abstractmethod
    def _from_node(cls, node: N, **kwargs: Any) -> Self:
        """Initialize the object from a node with corresponding protocol.

        :param node: Node to initialize from; its url is used and it becomes the only preselected node.
        :param kwargs: Passed on to the constructor.
        :raises TypeError: if node is not a Node object.
        :raises ValueError: if the protocol of the node differs from the protocol of this class.
        :return: Initialized connection object.
        """
        if not isinstance(node, Node):
            raise TypeError("Node must be a Node object.")

        if node.protocol != cls._PROTOCOL:
            raise ValueError(
                f"Tried to initialize {cls.__name__} from a node "
                f"that does not specify {cls._PROTOCOL} as its protocol: {node.name}."
            )

        return cls(url=node.url, nodes=[node], **kwargs)

    @property
    def url(self) -> str:
        """URL of the connection as a string, rebuilt from the parsed URL."""
        return self._url.geturl()

    def _validate_nodes(self, nodes: N | Nodes[N] | None) -> set[N]:
        """Make sure that nodes are a Set of nodes and that all nodes correspond to the protocol and url
        of the connection.

        Nodes which do not match the protocol or netloc of the connection are dropped without notice.
        If nodes is None, the preselected nodes (selected_nodes) are returned as they are.

        :param nodes: Single node or list/set of nodes to validate.
        :raises ValueError: if no nodes remain.
        :return: Set of valid node objects for this connection.
        """
        if nodes is None:
            _nodes = self.selected_nodes
        else:
            candidates = nodes if isinstance(nodes, Iterable) else {nodes}

            # If not using preselected nodes from self.selected_nodes, check if nodes correspond to the connection
            _nodes = {
                node
                for node in candidates
                if node.protocol == self._PROTOCOL and node.url_parsed.netloc == self._url.netloc
            }

        # Make sure that some nodes remain after the checks and raise an error if there are none.
        if not _nodes:
            raise ValueError(
                f"Some nodes to read from/write to must be specified. If nodes were specified, they do not "
                f"match the connection {self.url}"
            )

        return _nodes

    def _preprocess_series_context(
        self,
        from_time: datetime,
        to_time: datetime,
        nodes: N | Nodes[N] | None = None,
        interval: TimeStep = 1,
        **kwargs: Any,
    ) -> tuple[datetime, datetime, set[N], timedelta]:
        """Preprocesses the series context to ensure it is ready for reading.

        This method validates the nodes, ensures the time interval is a timedelta, and rounds the
        timestamps to the interval (from_time is floored, to_time is ceiled). A warning is logged if
        the timezones of from_time and to_time differ.

        :param from_time: The start time of the series.
        :param to_time: The end time of the series.
        :param nodes: The nodes to read from.
        :param interval: The time interval for the series. Interpreted as seconds if it is not a timedelta.
        :param kwargs: Accepted for signature compatibility; not used here.
        :return: A tuple containing the processed from_time, to_time, nodes, and interval.
        """
        nodes = self._validate_nodes(nodes)

        interval = interval if isinstance(interval, timedelta) else timedelta(seconds=interval)

        from_time = round_timestamp(from_time, interval.total_seconds(), method="floor", ensure_tz=True)
        to_time = round_timestamp(to_time, interval.total_seconds(), method="ceil", ensure_tz=True)

        if from_time.tzinfo != to_time.tzinfo:
            # NOTE(review): only a warning is emitted; to_time is returned with its own timezone.
            # Confirm that callers really continue with the timezone of from_time.
            log = getLogger(self.__module__)
            log.warning(
                f"Timezone of from_time and to_time are different. Using from_time timezone: {from_time.tzinfo}"
            )

        return (from_time, to_time, nodes, interval)
@deprecated("Use `Connection` and the appropriate protocols instead.")
class SeriesConnection(
    Readable[N], Writable[N], Subscribable[N], SeriesReadable[N], SeriesSubscribable[N], Connection[N], ABC
):
    """Connection object for protocols which are able to provide access to timeseries data.

    Deprecated: combines Connection with the Readable, Writable, Subscribable, SeriesReadable and
    SeriesSubscribable protocols.

    :param url: URL of the server to connect to.
    """