Source code for eta_ctrl.common.common

from __future__ import annotations

import inspect
import pathlib
from logging import getLogger
from typing import TYPE_CHECKING

from eta_ctrl.util import dict_get_any

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence
    from typing import Any

    import torch as th
    from stable_baselines3.common.vec_env import VecEnv, VecNormalize

    from eta_ctrl.envs import BaseEnv
    from eta_ctrl.util.type_annotations import Path

log = getLogger(__name__)


def _instantiate_module(factory: Any, features: int, kwargs: dict[str, Any], description: str) -> th.nn.Module:
    """Instantiate a single network module from its class and keyword arguments.

    The number of input features is passed as first positional argument only if the signature of
    the class has an 'in_channels' or 'in_features' parameter.

    :param factory: Class (or other callable) creating the module.
    :param features: Number of input features available at this position in the network.
    :param kwargs: Keyword arguments for the instantiation.
    :param description: Kind of module (for example 'layer'), only used in the error message.
    :return: The instantiated module.
    :raises TypeError: If the module cannot be instantiated with the given arguments.
    """
    try:
        if {"in_channels", "in_features"} & inspect.signature(factory).parameters.keys():
            return factory(features, **kwargs)
        return factory(**kwargs)
    except TypeError as e:
        msg = f"Could not instantiate {description} module {factory.__name__}: {e}"
        raise TypeError(msg) from e


def deserialize_net_arch(
    net_arch: Sequence[Mapping[str, Any]], in_features: int, device: th.device | str = "auto"
) -> th.nn.Sequential:
    """Deserialize_net_arch can take a list of dictionaries describing a sequential torch network and deserialize
    it by instantiating the corresponding classes.

    An example for a possible net_arch would be:

    .. code-block::

        [{"layer": "Linear", "out_features": 60},
         {"activation_func": "Tanh"},
         {"layer": "Linear", "out_features": 60},
         {"activation_func": "Tanh"}]

    One key of the dictionary should be either 'layer', 'activation_func' or 'process'. If the 'layer' key is
    present, a layer from the :py:mod:`torch.nn` module is instantiated, if the 'activation_func' key is present,
    the value will be instantiated as an activation function from :py:mod:`torch.nn`. If the key 'process' is
    present, the value will be interpreted as a data processor from :py:mod:`eta_ctrl.common.processors`.

    All other keys of each dictionary will be used as keyword parameters to the instantiation of the layer,
    activation function or processor.

    Only the number of input features for the first layer must be specified (using the 'in_features' parameter).
    The function will then automatically determine the number of input features for all other layers in the
    sequential network.

    :param net_arch: List of dictionaries describing the network architecture.
    :param in_features: Number of input features for the first layer.
    :param device: Torch device to use for training the network. The value "auto" selects CUDA if it is
        available and the CPU otherwise.
    :return: Sequential torch network.
    :raises TypeError: If a layer, activation function or processor cannot be instantiated with the given
        parameters.
    :raises ValueError: If a dictionary contains none of the keys 'layer', 'activation_func' or 'process'.
    """
    import torch as th  # noqa: PLC0415

    from .sb3_extensions import processors  # noqa: PLC0415

    # "auto" is not a device name torch understands, so it has to be resolved before calling Module.to().
    if isinstance(device, str) and device == "auto":
        device = "cuda" if th.cuda.is_available() else "cpu"

    network = th.nn.Sequential()
    _features = in_features

    for net in net_arch:
        # Work on a copy, because the type key is popped and the caller's mapping must stay untouched.
        _net = dict(net)

        if "process" in _net:
            process = getattr(processors, _net.pop("process"))

            # The "Split" process must be treated differently, because it needs to be deserialized recursively.
            # It is recognized by accepting both a 'net_arch' and a 'sizes' parameter.
            if {"net_arch", "sizes"} <= inspect.signature(process).parameters.keys():
                sizes = process.get_full_sizes(_features, _net["sizes"])
                _net["net_arch"] = [
                    deserialize_net_arch(sub_arch, sizes[i], device) for i, sub_arch in enumerate(_net["net_arch"])
                ]

            network.append(_instantiate_module(process, _features, _net, "processing"))

        elif "layer" in _net:
            layer = getattr(th.nn, _net.pop("layer"))
            # The number of input features is set if required by the layer class.
            network.append(_instantiate_module(layer, _features, _net, "layer"))

        elif "activation_func" in _net:
            activation_func = _net.pop("activation_func")
            try:
                network.append(getattr(th.nn, activation_func)(**_net))
            except TypeError as e:
                msg = f"Could not instantiate activation function module {activation_func}: {e}"
                raise TypeError(msg) from e

        else:
            msg = f"Unknown process or layer type: {net}."
            raise ValueError(msg)

        # Modules without an explicit output size leave the number of features unchanged.
        _features = dict_get_any(_net, "out_channels", "out_features", fail=False, default=_features)

    network.to(device)
    return network
def is_vectorized(env: BaseEnv | VecEnv | VecNormalize | None) -> bool:
    """Tell whether the given environment is a vectorized environment.

    An environment counts as vectorized if it has a ``num_envs`` attribute. ``None`` is never vectorized.

    :param env: The environment to check.
    :return: True if the environment is vectorized, False otherwise.
    """
    return env is not None and hasattr(env, "num_envs")
def is_closed(env: BaseEnv | VecEnv | VecNormalize | None) -> bool:
    """Determine whether an environment has already been closed.

    ``None`` counts as closed. An environment with a ``closed`` attribute reports that value. Otherwise,
    if it has a ``venv`` attribute, the answer is taken from that wrapped environment. Anything else is
    considered open.

    :param env: The environment to check.
    :return: True if the environment is closed, False otherwise.
    """
    # No environment at all is treated like a closed one.
    if env is None:
        return True

    # The environment's own flag takes precedence over anything it wraps.
    if hasattr(env, "closed"):
        return env.closed

    # Without a wrapped environment there is nothing left to ask.
    if not hasattr(env, "venv"):
        return False

    wrapped_env = env.venv
    return is_closed(wrapped_env)
def episode_results_path(series_results_path: Path, run_name: str, episode: int, env_id: int = 1) -> pathlib.Path:
    """Build the path of the csv file in which the results of one episode of one environment are stored.

    The file name has the format ThisRun_001_01.csv (run name _ episode number _ environment id .csv).

    :param series_results_path: Path for results of the series of optimization runs.
    :param run_name: Name of the optimization run.
    :param episode: Number of the episode the environment is working on.
    :param env_id: Identification of the environment.
    :return: Path of the csv file inside the series results path.
    """
    # Anything that is not yet a pathlib.Path is converted first.
    base_dir = series_results_path
    if not isinstance(base_dir, pathlib.Path):
        base_dir = pathlib.Path(base_dir)

    file_name = f"{episode_name_string(run_name, episode, env_id)}.csv"
    return base_dir / file_name
def episode_name_string(run_name: str, episode: int, env_id: int = 1) -> str:
    """Build a name that can be used as a prefix or postfix for files of a specific episode and run.

    The name has the format ThisRun_001_01 (run name _ episode number _ environment id).

    :param run_name: Name of the optimization run.
    :param episode: Number of the episode the environment is working on.
    :param env_id: Identification of the environment.
    :return: Name made up of the run name, the episode number and the environment id.
    """
    # The episode number is padded with zeros to three characters, the environment id to two.
    episode_part = f"{episode:0>#3}"
    env_part = f"{env_id:0>#2}"
    return f"{run_name}_{episode_part}_{env_part}"