Source code for eta_ctrl.config.config_setup

from __future__ import annotations

import importlib
from logging import getLogger
from typing import TYPE_CHECKING

from attrs import converters, define, field, fields

if TYPE_CHECKING:
    from typing import Any

    from attrs import Attribute
    from stable_baselines3.common.base_class import BaseAlgorithm, BasePolicy
    from stable_baselines3.common.vec_env import DummyVecEnv

    from eta_ctrl.envs import BaseEnv


# Module-level logger, named after this module.
log = getLogger(__name__)


def _get_class(instance: ConfigSetup, attrib: Attribute, new_value: str | None) -> str | None:
    """Import the class named by an import string and store it on the instance.

    The import string has the form ``"package.module.ClassName"``. The imported class
    is written to the attribute whose name is derived from ``attrib.name`` by replacing
    the last ``_``-separated part with ``class`` (for example ``agent_import`` ->
    ``agent_class``). The signature matches what is passed as ``on_setattr`` hook to the
    fields of ConfigSetup, which is why the (unchanged) value is returned.

    :param instance: Object on which the imported class is stored.
    :param attrib: Attribute description; only its ``name`` is used.
    :param new_value: Import string, or None to skip importing altogether.
    :return: ``new_value``, unchanged.
    :raises ValueError: If the import string does not contain a module part.
    :raises ModuleNotFoundError: If the module (or one of its dependencies) cannot be found.
    :raises AttributeError: If the module does not define the requested class.
    """
    if new_value is None:
        return None

    module, separator, cls_name = new_value.rpartition(".")
    if not separator or not module:
        # Without this check the failure would surface as an unrelated unpacking or
        # "Empty module name" error that does not mention the offending option.
        msg = (
            f"Invalid import string '{new_value}' in '{attrib.name}' value. "
            f"Expected the format '<module>.<class>'."
        )
        raise ValueError(msg)

    try:
        imported_module = importlib.import_module(module)
    except ModuleNotFoundError as e:
        msg = f"Could not find module '{e.name}'. While importing class '{cls_name}' from '{attrib.name}' value."
        # Preserve the name of the missing module on the new exception.
        raise ModuleNotFoundError(msg, name=e.name) from e

    # Looked up separately so that an AttributeError raised while the module itself is
    # executing is not misreported as a missing class.
    try:
        cls = getattr(imported_module, cls_name)
    except AttributeError as e:
        msg = (
            f"Could not find class '{cls_name}' in module '{module}'. "
            f"While importing class '{cls_name}' from '{attrib.name}' value."
        )
        raise AttributeError(msg) from e

    cls_attr_name = f"{attrib.name.rsplit('_', 1)[0]}_class"
    setattr(instance, cls_attr_name, cls)

    return new_value


@define(frozen=False, kw_only=True)
class ConfigSetup:
    """Configuration options as specified in the "setup" section of the configuration file."""

    #: Import description string for the agent class.
    agent_import: str = field(on_setattr=_get_class)
    #: Agent class (automatically determined from agent_import).
    agent_class: type[BaseAlgorithm] = field(init=False)

    #: Import description string for the environment class.
    environment_import: str = field(on_setattr=_get_class)
    #: Imported Environment class (automatically determined from environment_import).
    environment_class: type[BaseEnv] = field(init=False)

    #: Import description string for the interaction environment (default: None).
    interaction_env_import: str | None = field(default=None, on_setattr=_get_class)
    #: Interaction environment class (default: None) (automatically determined from interaction_env_import).
    interaction_env_class: type[BaseEnv] | None = field(init=False, default=None)

    #: Import description string for the environment vectorizer
    #: (default: stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv).
    vectorizer_import: str = field(
        default="stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv",
        on_setattr=_get_class,
        converter=converters.default_if_none(  # type: ignore[misc]
            "stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv"
        ),
    )  # mypy currently does not recognize converters.default_if_none
    #: Environment vectorizer class (automatically determined from vectorizer_import).
    vectorizer_class: type[DummyVecEnv] = field(init=False)

    #: Import description string for the policy class (default: eta_ctrl.common.NoPolicy).
    policy_import: str = field(
        default="eta_ctrl.common.NoPolicy",
        on_setattr=_get_class,
        converter=converters.default_if_none("eta_ctrl.common.NoPolicy"),  # type: ignore[misc]
    )  # mypy currently does not recognize converters.default_if_none
    #: Policy class (automatically determined from policy_import).
    policy_class: type[BasePolicy] = field(init=False)

    #: Flag which is true if the environment should be wrapped for monitoring (default: False).
    monitor_wrapper: bool = field(default=False, converter=bool)
    #: Flag which is true if the observations should be normalized (default: False).
    norm_wrapper_obs: bool = field(default=False, converter=bool)
    #: Flag which is true if the rewards should be normalized (default: False).
    norm_wrapper_reward: bool = field(default=False, converter=bool)
    #: Flag to enable tensorboard logging (default: False).
    tensorboard_log: bool = field(default=False, converter=bool)

    def __attrs_post_init__(self) -> None:
        """Import the classes for all ``*_import`` values given at construction time."""
        _fields = fields(ConfigSetup)
        for import_name in (
            "agent_import",
            "environment_import",
            "interaction_env_import",
            "vectorizer_import",
            "policy_import",
        ):
            _get_class(self, getattr(_fields, import_name), getattr(self, import_name))

    @classmethod
    def from_dict(cls, dikt: dict[str, Any]) -> ConfigSetup:
        """Create the setup configuration from a dictionary.

        Each class can be given either as ``<name>_import`` or as the pair
        ``<name>_package`` and ``<name>_class``. Keys which are not recognized are
        reported with a warning and otherwise ignored. The passed dictionary is not
        modified.

        :param dikt: Dictionary with the values of the setup section.
        :return: New configuration object.
        :raises ValueError: If the agent or the environment import is missing.
        """
        # Work on a shallow copy: the values are popped below to detect unknown keys and
        # the caller's dictionary should stay usable afterwards.
        remaining = dict(dikt)
        errors: list[str] = []

        def get_import(name: str, required: bool = False) -> str | None:
            """Get import string or combination of package and class name from dictionary.

            :param name: Name of the configuration value.
            :param required: Flag to determine if the value is required.
            :return: Import string or None if it is not (completely) specified.
            """
            import_value = remaining.pop(f"{name}_import", None)
            package_value = remaining.pop(f"{name}_package", None)
            class_value = remaining.pop(f"{name}_class", None)

            # The import string takes precedence over package and class.
            if import_value is not None:
                return import_value

            if package_value is not None and class_value is not None:
                return f"{package_value}.{class_value}"

            # Exactly one of package and class is specified; this is only reported.
            if (package_value is None) ^ (class_value is None):
                msg = f"Only one of '{name}_package' and '{name}_class' is specified."
                log.info(msg)

            # Collect missing required values so that all of them can be reported at once.
            if required:
                msg = f"'{name}_import' or both of '{name}_package' and '{name}_class' parameters must be specified."
                log.error(msg)
                errors.append(name)

            return None

        agent_import = get_import("agent", required=True)
        environment_import = get_import("environment", required=True)
        interaction_env_import = get_import("interaction_env")
        vectorizer_import = get_import("vectorizer")
        policy_import = get_import("policy")

        # Missing flags are passed on as None; the bool converters of the fields turn that into False.
        monitor_wrapper = remaining.pop("monitor_wrapper", None)
        norm_wrapper_obs = remaining.pop("norm_wrapper_obs", None)
        norm_wrapper_reward = remaining.pop("norm_wrapper_reward", None)
        tensorboard_log = remaining.pop("tensorboard_log", None)

        # Log configuration values which were not recognized.
        if remaining:
            msg = "Following values were not recognized in the config setup section and are ignored: "
            msg += ", ".join(remaining)
            log.warning(msg)

        if errors:
            msg = "Not all required values were found in setup section (see log). Could not load config file. "
            msg += f"Missing values: {', '.join(errors)}"
            raise ValueError(msg)

        # Use cls instead of the hard-coded class name so that subclasses get their own type.
        return cls(
            agent_import=agent_import,
            environment_import=environment_import,
            interaction_env_import=interaction_env_import,
            vectorizer_import=vectorizer_import,
            policy_import=policy_import,
            monitor_wrapper=monitor_wrapper,
            norm_wrapper_obs=norm_wrapper_obs,
            norm_wrapper_reward=norm_wrapper_reward,
            tensorboard_log=tensorboard_log,
        )

    def __getitem__(self, name: str) -> Any:
        """Return the attribute called ``name`` (dictionary style read access)."""
        return getattr(self, name)

    def __setitem__(self, name: str, value: Any) -> None:
        """Set the existing attribute called ``name`` (dictionary style write access).

        :raises KeyError: If the object has no attribute called ``name``.
        """
        if not hasattr(self, name):
            msg = f"The key {name} does not exist - it cannot be set."
            raise KeyError(msg)
        setattr(self, name, value)