Source code for eta_nexus.util.io_utils

from __future__ import annotations

import csv
import io
import json
import pathlib
import re
import sys
from collections.abc import Mapping, Sequence
from logging import getLogger
from typing import TYPE_CHECKING

import pandas as pd
import toml
import yaml
from dotenv import find_dotenv, load_dotenv

if TYPE_CHECKING:
    import types
    from collections.abc import Callable
    from typing import Any

    from eta_nexus.util.type_annotations import Path, Self


# Module-level logger named after this module; the functions below report load/export
# results and I/O errors through it.
log = getLogger(__name__)


def json_import(path: Path) -> list[Any] | dict[str, Any]:
    """Extend standard JSON import to allow '//' comments in JSON files.

    Text from a ``//`` to the end of its line is removed before the remaining text is parsed
    with :func:`json.loads`. Simple double-quoted string literals are skipped as a whole, so a
    ``//`` inside them (e.g. in a URL) is kept. String literals that contain escaped quotes are
    not recognized by the cleanup pattern and may be cut at a ``//``.

    The file is always decoded as UTF-8, the encoding required for JSON by RFC 8259, instead of
    the platform's locale-dependent default encoding.

    :param path: Path to JSON file.
    :raises OSError: If the file cannot be opened or read (the error is logged first).
    :raises json.JSONDecodeError: If the text left after removing comments is not valid JSON.
    :return: Parsed JSON content (a dictionary or a list).
    """
    path = pathlib.Path(path) if not isinstance(path, pathlib.Path) else path
    # Capture everything from the start of a line up to a '//' comment. Quoted strings are
    # consumed as a unit, and a single '/' followed by another character is allowed.
    cleanup = re.compile(r"^((?:(?:[^\/\"])*(?:\"[^\"]*\")*(?:\/[^\/])*)*)", re.MULTILINE)
    try:
        # Explicit UTF-8: the default encoding of open() depends on the locale (e.g. cp1252 on
        # Windows), which would garble non-ASCII content.
        with path.open("r", encoding="utf-8") as f:
            file = "\n".join(cleanup.findall(f.read()))
    except OSError as e:
        log.exception(f"JSON file couldn't be loaded: {e.strerror}. Filename: {e.filename}")
        raise

    result = json.loads(file)
    log.info(f"JSON file {path} loaded successfully.")
    return result
def toml_import(path: Path) -> dict[str, Any]:
    """Import a TOML file and return the parsed dictionary.

    The file is always decoded as UTF-8, the encoding the TOML specification requires, instead
    of the platform's locale-dependent default encoding.

    :param path: Path to TOML file.
    :raises OSError: If the file cannot be opened or read (the error is logged first).
    :return: Parsed dictionary.
    """
    path = pathlib.Path(path)
    try:
        # Explicit UTF-8: the default encoding of open() depends on the locale (e.g. cp1252 on
        # Windows), which would garble non-ASCII content.
        with path.open("r", encoding="utf-8") as f:
            result = toml.load(f)
    except OSError as e:
        log.exception(f"TOML file couldn't be loaded: {e.strerror}. Filename: {e.filename}")
        raise

    log.info(f"TOML file {path} loaded successfully.")
    return result
def yaml_import(path: Path) -> dict[str, Any]:
    """Import a YAML file and return the parsed dictionary.

    The file is opened in binary mode so that :func:`yaml.safe_load` determines the encoding
    itself (UTF-8, or UTF-16 when a byte order mark is present), independent of the platform's
    locale-dependent default encoding.

    :param path: Path to YAML file.
    :raises OSError: If the file cannot be opened or read (the error is logged first).
    :return: Parsed dictionary.
    """
    path = pathlib.Path(path)
    try:
        # Binary mode: a text-mode open() would decode with the locale encoding (e.g. cp1252 on
        # Windows) and garble non-ASCII content before PyYAML sees it.
        with path.open("rb") as f:
            result = yaml.safe_load(f)
    except OSError as e:
        log.exception(f"YAML file couldn't be loaded: {e.strerror}. Filename: {e.filename}")
        raise

    log.info(f"YAML file {path} loaded successfully.")
    return result
def load_config(file: Path) -> dict[str, Any]:
    """Load configuration options from a JSON, TOML or YAML file.

    The file must contain a dictionary of configuration options at its top level.

    If *file* carries an extension, that extension (compared case-insensitively) selects the
    importer. Without an extension, the candidates are probed in this priority order and the
    first existing one is loaded:

    1. JSON
    2. TOML
    3. YML
    4. YAML

    :param file: Path to the configuration file, with or without extension.
    :raises ValueError: If the given extension is not supported.
    :raises FileNotFoundError: If no extension is given and none of the candidates exists.
    :raises TypeError: If the loaded content is not a dictionary.
    :return: Dictionary of configuration options.
    """
    # Insertion order doubles as the probing priority for extension-less paths.
    importers: dict[str, Callable] = {
        ".json": json_import,
        ".toml": toml_import,
        ".yml": yaml_import,
        ".yaml": yaml_import,
    }

    file_path = pathlib.Path(file)
    if file_path.suffix == "":
        for suffix, importer in importers.items():
            candidate = file_path.with_suffix(suffix)
            if candidate.exists():
                file_path, import_method = candidate, importer
                break
        else:
            raise FileNotFoundError(f"Config file not found: {file}")
    else:
        ext = file_path.suffix.lower()
        import_method = importers.get(ext)
        if import_method is None:
            raise ValueError(f"Unsupported config file extension: {ext}")

    config = import_method(file_path)
    if isinstance(config, dict):
        return config
    raise TypeError(f"Config file {file} must define a dictionary of options.")
def replace_decimal_str(value: str | float, decimal: str = ".") -> str:
    """Convert *value* to text and substitute every ``.`` with another decimal sign.

    :param value: The value to replace in; non-string values are converted with :func:`str`.
    :param decimal: New decimal sign.
    :return: The string form of *value* with each ``.`` replaced by *decimal*.
    """
    text = str(value)
    return text.replace(".", decimal)
def csv_export(
    path: Path,
    data: Mapping[str, Any] | Sequence[Mapping[str, Any] | Any] | pd.DataFrame,
    names: Sequence[str] | None = None,
    index: Sequence[int] | pd.DatetimeIndex | None = None,
    *,
    sep: str = ";",
    decimal: str = ".",
) -> None:
    """Export data to a CSV file.

    The behavior depends on the type of *data*:

    * A single mapping is **appended** to the file as one row. Its keys are written as the
      header line only when the file does not exist yet or is empty. Every value is converted
      with :func:`replace_decimal_str`.
    * A :class:`pandas.DataFrame` is written with :meth:`pandas.DataFrame.to_csv`, replacing
      the file's content. The caller's DataFrame is not modified.
    * Any other sequence (of mappings or of row sequences) is converted into a DataFrame and
      written the same way.

    Data of any other type is not written; a warning is logged instead.

    :param path: Path of the CSV file. If its suffix is not ``.csv``, the suffix is replaced by
        ``.csv`` (a path without a suffix gets ``.csv`` appended).
    :param data: Data to be saved.
    :param names: Field names used when data is a matrix without column names. When omitted
        for a sequence, the keys of the last mapping in *data* are used.
    :param index: Optional sequence to set as the row index (DataFrame and sequence data).
    :param sep: Separator to use between the fields.
    :param decimal: Sign to use for decimal points.
    :raises ValueError: If *data* is a sequence and no column names can be determined.
    """
    _path = path if isinstance(path, pathlib.Path) else pathlib.Path(path)
    if _path.suffix != ".csv":
        # with_suffix returns a new path; it must be assigned for the suffix to take effect.
        _path = _path.with_suffix(".csv")

    if isinstance(data, Mapping):
        # Decide about the header BEFORE opening: append mode creates the file, so checking
        # afterwards would always find it and never write a header.
        write_header = not _path.exists() or _path.stat().st_size == 0
        # newline="" as required by the csv module (avoids blank lines on Windows); utf-8
        # matches the default encoding of pandas' to_csv used by the other branches.
        with _path.open("a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=data.keys(), delimiter=sep)
            if write_header:
                writer.writeheader()
            writer.writerow({key: replace_decimal_str(val, decimal) for key, val in data.items()})
    elif isinstance(data, pd.DataFrame):
        frame = data
        if index is not None:
            # set_axis returns a new object, so the caller's DataFrame keeps its own index.
            frame = data.set_axis(index, axis=0)
        frame.to_csv(path_or_buf=str(_path), sep=sep, decimal=decimal)
    elif isinstance(data, Sequence):
        if names is not None:
            cols = names
        elif len(data) > 0 and isinstance(data[-1], Mapping):
            cols = list(data[-1].keys())
        else:
            # Also covers an empty sequence, which would otherwise fail with IndexError.
            raise ValueError("Column names for csv export not specified.")

        _data = pd.DataFrame(data=data, columns=cols)
        if index is not None:
            _data.index = index
        _data.to_csv(path_or_buf=str(_path), sep=sep, decimal=decimal)
    else:
        # Previously this case wrote nothing but still reported a successful export.
        log.warning(f"Data of type {type(data).__name__} is not supported for CSV export; nothing was written.")
        return

    log.info(f"Exported CSV data to {_path}.")
def autoload_env() -> None:
    """Load the variables of the nearest ``.env`` file into the process environment.

    The search starts in the user's current working directory and walks upward from there.
    Variables that are already set are not overridden.
    """
    load_dotenv(find_dotenv(usecwd=True), override=False)
class Suppressor(io.TextIOBase):
    """Context manager that temporarily discards everything written to standard error.

    On entry, :data:`sys.stderr` is replaced by this object, whose :meth:`write` drops all
    text. On exit, the previous stream is restored. An exception raised inside the ``with``
    block propagates unchanged (the same exception object).
    """

    def __enter__(self) -> Self:
        # Remember the current stream so that __exit__ can restore it.
        self.stderr = sys.stderr
        sys.stderr = self
        return self

    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None
    ) -> None:
        sys.stderr = self.stderr
        # Returning None (falsy) already lets any exception from the with-block propagate.
        # Re-raising via ``exc_type(exc_val)`` would wrap the original exception in a new one
        # (losing its args/attributes) and fails for exception types whose constructor takes
        # several arguments, e.g. UnicodeDecodeError.

    def write(self, x: Any) -> int:
        """Discard *x* instead of writing it anywhere.

        :param x: Text that would have been written to standard error.
        :return: Always 0.
        """
        return 0