from __future__ import annotations
import copy
import importlib
import math
import re
from collections.abc import Mapping
from datetime import timedelta
from logging import getLogger
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from typing import Any
from eta_ctrl.util.type_annotations import TimeStep
# Module-level logger named after this module.
log = getLogger(__name__)
# Generic type variable; import_class_from_module uses it to tie its return type to 'base_class'.
T = TypeVar("T")
[docs]
def import_class_from_module(path: str, base_class: type[T] | None = None) -> type[T]:
    """Import a class from a given python module path.

    :param path: Fully qualified path to the class, e.g. 'my_package.my_module.MyClass'
    :type path: str
    :param base_class: Optional base class. If given, the imported object must be a class
        that is a subclass of it.
    :type base_class: type | None
    :raises ValueError: Passed path has no module part (e.g. 'MyClass' or '.MyClass')
    :raises ModuleNotFoundError: Passed module not found
    :raises AttributeError: Passed class not found
    :raises TypeError: 'base_class' is given and the imported object is not a class,
        or not a subclass of 'base_class'
    :return: Imported class
    :rtype: type
    """
    # rpartition always yields three parts, so a path without a dot reaches the explicit
    # check below instead of failing with an opaque tuple-unpacking error.
    module_name, _, cls_name = path.rpartition(".")
    if not module_name:
        msg = f"Invalid class path '{path}'. Expected the form 'package.module.ClassName'."
        raise ValueError(msg)

    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        msg = f"Could not find module '{e.name}'. While importing class '{cls_name}'."
        raise ModuleNotFoundError(msg) from e

    try:
        cls = getattr(module, cls_name)
    except AttributeError as e:
        msg = f"Could not find class '{cls_name}' in module '{module.__name__}'. "
        raise AttributeError(msg) from e

    if base_class is not None:
        # issubclass() raises an unhelpful TypeError for non-class objects (e.g. functions),
        # so report that case explicitly.
        if not isinstance(cls, type):
            msg = f"Loaded object '{cls_name}' from {module_name} is not a class"
            raise TypeError(msg)
        if not issubclass(cls, base_class):
            msg = f"Loaded class '{cls_name}' from {module_name} is not subclass of {base_class}"
            raise TypeError(msg)
    return cls
[docs]
def dict_get_any(dikt: Mapping[str, Any], *names: str, fail: bool = True, default: Any = None) -> Any:
    """Get the value of the first of the given keys that is present in a dictionary.

    The names are checked in the order they are passed. The value of the first name found
    is returned, even if several of the names are present.

    :param dikt: Dictionary (or any other mapping) to get values from. It is not modified.
    :param names: Item names to look for, in order of priority.
    :param fail: Flag to determine whether the function should raise a KeyError if none of the
        items are found. If this is falsy, the function returns the value given by 'default'.
    :param default: Value to return if none of the items are found and 'fail' is falsy.
    :return: Value from dictionary, or 'default'.
    :raises KeyError: If none of the requested items are available and 'fail' is truthy.
    """
    for name in names:
        if name in dikt:
            # Return first value found in dictionary
            return dikt[name]
    # Any truthy flag requests failure, not only the literal True object.
    if fail:
        msg = (
            f"Did not find one of the required keys in the configuration: {names}. Possibly Check the correct spelling"
        )
        raise KeyError(msg)
    return default
[docs]
def dict_pop_any(dikt: dict[str, Any], *names: str, fail: bool = True, default: Any = None) -> Any:
    """Pop the first of the given keys that is present in a dictionary and return its value.

    The names are checked in the order they are passed. Only the first name found is removed
    from the dictionary and its value returned, even if several of the names are present.
    This function modifies the passed dictionary in place!

    :param dikt: Dictionary to pop the value from.
    :param names: Item names to look for, in order of priority.
    :param fail: Flag to determine whether the function should raise a KeyError if none of the
        items are found. If this is falsy, the function returns the value given by 'default'.
    :param default: Value to return if none of the items are found and 'fail' is falsy.
    :return: Value from dictionary, or 'default'.
    :raises KeyError: If none of the requested items are available and 'fail' is truthy.
    """
    for name in names:
        if name in dikt:
            # Remove and return only the first match; other matching names stay in the dict.
            return dikt.pop(name)
    # Any truthy flag requests failure, not only the literal True object.
    if fail:
        msg = f"Did not find one of the required keys in the configuration: {names}"
        raise KeyError(msg)
    return default
[docs]
def dict_search(dikt: dict[str, str], val: str) -> str:
    """Get the key of a dictionary that maps to the given value (reverse lookup).

    Keys are checked in the dictionary's iteration order, so if several keys map to
    ``val`` the first one encountered is returned.

    :param dikt: Dictionary to search for the value.
    :param val: Value to search for (compared with ``==``).
    :return: First key whose value equals ``val``.
    :raises ValueError: If no key in ``dikt`` maps to ``val``.
    """
    for key, value in dikt.items():
        if val == value:
            return key
    msg = f"Value: {val} not specified in specified dictionary"
    raise ValueError(msg)
[docs]
def deep_mapping_update(
    source: Any, overrides: Mapping[str, str | Mapping[str, Any]]
) -> dict[str, str | Mapping[str, Any]]:
    """Perform a deep update of a nested dictionary or similar mapping.

    Neither input is modified. The source is deep-copied, and nested mappings in 'overrides'
    are merged recursively into the corresponding sub-mappings of the source. Non-mapping
    override values replace the source value and are inserted without copying.

    :param source: Original mapping to be updated. Anything that is not a Mapping (e.g. None)
        is treated as an empty mapping.
    :param overrides: Mapping with new values to integrate into the new mapping.
    :return: New dict with values from the source and overrides combined.
    """
    # Treat a non-mapping source as empty consistently, both for the copy and for the
    # nested lookups below (previously the nested lookup still called dict(source)).
    base: Mapping[str, Any] = source if isinstance(source, Mapping) else {}
    output = dict(copy.deepcopy(base))
    for key, value in overrides.items():
        if isinstance(value, Mapping):
            output[key] = deep_mapping_update(base.get(key, {}), value)
        else:
            output[key] = value
    return output
[docs]
def camel_to_snake_case(camel_name: str) -> str:
    """Convert a string from camel case (or PascalCase) to snake case convention.

    Every uppercase character is lowercased and prefixed with an underscore, then leading and
    trailing underscores are stripped from the result, e.g. 'MyClassName' -> 'my_class_name'.
    Consecutive capitals are split individually ('HTTPServer' -> 'h_t_t_p_server').

    :param camel_name: Name in camel case.
    :return: Name in snake case.
    """
    pieces: list[str] = []
    for char in camel_name:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    return "".join(pieces).strip("_")
[docs]
def snake_to_camel_case(snake_name: str) -> str:
    """Convert a string from snake_case to PascalCase convention.

    Any character other than an ASCII letter or digit acts as a word separator, so
    'kebab-case' or 'dotted.names' are converted as well. Empty words are dropped, and each
    remaining word is capitalized. Capitalizing also lowercases the rest of the word,
    e.g. 'my_HTTP_name' -> 'MyHttpName'.

    :param snake_name: Name in snake case.
    :return: Name in PascalCase.
    """
    words = re.split(r"[^a-zA-Z0-9]", snake_name)
    return "".join(word.capitalize() for word in words if word)
[docs]
def timestep_to_seconds(timestep: TimeStep | str) -> float:
    """Convert a time step to a number of seconds.

    A timedelta is converted via its total seconds. Any other value is passed to float(),
    e.g. a number of seconds or a numeric string.

    :param timestep: Original time step value
    :return: Duration in seconds as float
    """
    if isinstance(timestep, timedelta):
        return float(timestep.total_seconds())
    return float(timestep)
[docs]
def timestep_to_timedelta(timestep: TimeStep | str) -> timedelta:
    """Convert a time step to a timedelta object.

    A timedelta is returned unchanged (the same object). Any other value is passed to
    float() and interpreted as a number of seconds.

    :param timestep: Original time step value
    :return: timedelta object representing the duration
    """
    if isinstance(timestep, timedelta):
        return timestep
    seconds = float(timestep)
    return timedelta(seconds=seconds)
[docs]
def is_divisible(a: float, b: float, *, abs_tol: float = 1e-9) -> bool:
    """Check whether a is divisible by b, tolerating floating point rounding errors.

    Just returning a%b==0 will not work for small divisor values.
    E.g. 15 % 0.05 will result in 0.04999.. and not 0.
    Therefore the remainder is accepted if it is close to 0 or close to b (i.e. one full
    divisor minus a rounding error).

    The tolerance is absolute. For divisors smaller in magnitude than 'abs_tol', every
    dividend is reported as divisible. For dividends that are very large relative to b,
    floating point representation errors can push the remainder further than the default
    tolerance away from 0 or b. Adjust 'abs_tol' for such value ranges.

    :param a: Dividend
    :type a: float
    :param b: Divisor
    :type b: float
    :param abs_tol: Absolute tolerance used when comparing the remainder with 0 and with b.
    :type abs_tol: float
    :raises ZeroDivisionError: If b is zero.
    :return: True if a is (approximately) an integer multiple of b.
    :rtype: bool
    """
    remainder = a % b
    near_zero = math.isclose(remainder, 0, abs_tol=abs_tol)
    return near_zero or math.isclose(remainder, b, abs_tol=abs_tol)