Source code for eta_ctrl.util.utils

from __future__ import annotations

import copy
import importlib
import math
import re
from collections.abc import Mapping
from datetime import timedelta
from logging import getLogger
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
    from typing import Any

    from eta_ctrl.util.type_annotations import TimeStep


# Module-level logger, named after this module.
log = getLogger(__name__)

# Generic type variable: ties the ``base_class`` argument of import_class_from_module
# to the class type it returns.
T = TypeVar("T")


def import_class_from_module(path: str, base_class: type[T] | None = None) -> type[T]:
    """Import a class from a given python module path.

    :param path: Dotted path to the class, e.g. 'my_package.my_module.MyClass'
    :type path: str
    :param base_class: Optional class that the imported object must be a subclass of.
    :raises ValueError: Passed path contains no '.' separating module and class name.
    :raises ModuleNotFoundError: Passed module not found (``name`` holds the missing module).
    :raises AttributeError: Passed class not found in the module.
    :raises TypeError: ``base_class`` is given and the imported object is not a class
        or not a subclass of ``base_class``.
    :return: Imported class
    :rtype: type
    """
    module_name, separator, cls_name = path.rpartition(".")
    # Without a '.', there is no module part; report that clearly instead of a tuple-unpacking error.
    if not separator:
        msg = f"Invalid class path '{path}'. Expected the form 'package.module.ClassName'."
        raise ValueError(msg)

    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        msg = f"Could not find module '{e.name}'. While importing class '{cls_name}'."
        # Keep the missing module's name available to callers inspecting the exception.
        raise ModuleNotFoundError(msg, name=e.name) from e

    try:
        cls = getattr(module, cls_name)
    except AttributeError as e:
        msg = f"Could not find class '{cls_name}' in module '{module.__name__}'. "
        raise AttributeError(msg) from e

    if base_class is not None:
        # issubclass() raises an unhelpful TypeError for non-class objects (e.g. functions),
        # so report that case explicitly.
        if not isinstance(cls, type):
            msg = f"Loaded object '{cls_name}' from {module_name} is not a class"
            raise TypeError(msg)
        if not issubclass(cls, base_class):
            msg = f"Loaded class '{cls_name}' from {module_name} is not subclass of {base_class}"
            raise TypeError(msg)
    return cls
def dict_get_any(dikt: Mapping[str, Any], *names: str, fail: bool = True, default: Any = None) -> Any:
    """Get any of the specified items from dictionary, if any are available.

    The function will return the first value it finds, even if there are multiple matches.
    The dictionary is not modified.

    :param dikt: Dictionary (or any mapping) to get values from.
    :param names: Item names to look for, in order of preference.
    :param fail: Flag to determine, if the function should fail with a KeyError, if none of the items are found.
                 If this is False, the function will return the value specified by 'default'.
    :param default: Value to return, if none of the items are found and 'fail' is False.
    :return: Value from dictionary.
    :raise: KeyError, if none of the requested items are available and fail is True.
    """
    for name in names:
        # Membership test first, so mappings with a __missing__ hook (e.g. defaultdict) are not extended.
        if name in dikt:
            # Return first value found in dictionary
            return dikt[name]

    # Truthiness check instead of ``is True`` so any truthy flag triggers the error.
    if fail:
        msg = (
            f"Did not find one of the required keys in the configuration: {names}. Possibly Check the correct spelling"
        )
        raise KeyError(msg)
    return default
def dict_pop_any(dikt: dict[str, Any], *names: str, fail: bool = True, default: Any = None) -> Any:
    """Pop any of the specified items from dictionary, if any are available.

    The function will return the first value it finds, even if there are multiple matches.
    This function removes the found value from the dictionary! Only the first matching key
    is removed; other matching keys stay in place.

    :param dikt: Dictionary to pop values from.
    :param names: Item names to look for, in order of preference.
    :param fail: Flag to determine, if the function should fail with a KeyError, if none of the items are found.
                 If this is False, the function will return the value specified by 'default'.
    :param default: Value to return, if none of the items are found and 'fail' is False.
    :return: Value from dictionary.
    :raise: KeyError, if none of the requested items are available and fail is True.
    """
    for name in names:
        if name in dikt:
            # Return (and remove) first value found in dictionary
            return dikt.pop(name)

    # Truthiness check instead of ``is True`` so any truthy flag triggers the error.
    if fail:
        msg = f"Did not find one of the required keys in the configuration: {names}"
        raise KeyError(msg)
    return default
def deep_mapping_update(
    source: Any, overrides: Mapping[str, str | Mapping[str, Any]]
) -> dict[str, str | Mapping[str, Any]]:
    """Perform a deep update of a nested dictionary or similar mapping.

    ``source`` is deep-copied and never modified. Override values that are mappings are merged
    recursively into the corresponding source entry; all other override values replace the source
    value (they are inserted as-is, not copied). A ``source`` that is not a mapping - at the top
    level or at any nested key - is treated as an empty mapping.

    :param source: Original mapping to be updated.
    :param overrides: Mapping with new values to integrate into the new mapping.
    :return: New Mapping with values from the source and overrides combined.
    """
    source_is_mapping = isinstance(source, Mapping)
    output = dict(copy.deepcopy(source)) if source_is_mapping else {}

    for key, value in overrides.items():
        if isinstance(value, Mapping):
            # Look the nested entry up on ``source`` directly: converting a non-mapping source with
            # dict(...) would fail (e.g. for None or a string leaf being overridden by a mapping).
            nested_source = source.get(key, {}) if source_is_mapping else {}
            output[key] = deep_mapping_update(nested_source, value)
        else:
            output[key] = value
    return output
def camel_to_snake_case(camel_name: str) -> str:
    """Turn a camelCase / PascalCase name into snake_case.

    Each uppercase character is lowercased and prefixed with an underscore; underscores at either
    end of the result are then stripped. Consecutive capitals are split individually
    (``"HTTPServer"`` becomes ``"h_t_t_p_server"``).
    """
    pieces: list[str] = []
    for char in camel_name:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    snake = "".join(pieces)
    return snake.strip("_")
def snake_to_camel_case(snake_name: str) -> str:
    """Turn a snake_case name into PascalCase.

    Every character outside ``[a-zA-Z0-9]`` acts as a word separator. Empty words are dropped and
    each remaining word is passed through ``str.capitalize`` (first letter upper, rest lower).
    """
    words = (word for word in re.split(r"[^a-zA-Z0-9]", snake_name) if word)
    return "".join(map(str.capitalize, words))
def timestep_to_seconds(timestep: TimeStep | str) -> float:
    """Express a time step as a float number of seconds.

    :param timestep: A :class:`~datetime.timedelta`; any other value is passed to ``float()``
        and assumed to already be in seconds.
    :return: Value in seconds
    """
    if isinstance(timestep, timedelta):
        return float(timestep.total_seconds())
    return float(timestep)
def timestep_to_timedelta(timestep: TimeStep | str) -> timedelta:
    """Express a time step as a :class:`~datetime.timedelta`.

    :param timestep: A :class:`~datetime.timedelta` (returned unchanged, same object); any other
        value is passed to ``float()`` and interpreted as seconds.
    :return: timedelta object representing the duration
    """
    if isinstance(timestep, timedelta):
        return timestep
    return timedelta(seconds=float(timestep))
def is_divisible(a: float, b: float, *, abs_tol: float = 1e-9) -> bool:
    """Check whether a is divisible by b, tolerating floating point rounding error.

    Just returning a%b==0 will not work for small divisor values.
    E.g. 15 % 0.05 will result in 0.04999.. and not 0.

    :param a: Dividend
    :type a: float
    :param b: Divisor
    :type b: float
    :param abs_tol: Absolute tolerance used when comparing the remainder to 0 and to b. The rounding
        error in the remainder can grow with the magnitude of a / b, so large dividends may need a
        larger tolerance than the default.
    :type abs_tol: float
    :raises ZeroDivisionError: If b is zero.
    :return: True if a is (approximately) an integer multiple of b.
    :rtype: bool
    """
    remainder = a % b
    # Rounding error can leave the remainder just below b instead of at 0, so accept both ends.
    return math.isclose(remainder, 0, abs_tol=abs_tol) or math.isclose(remainder, b, abs_tol=abs_tol)