"""Setup configuration (``eta_ctrl.config.config_setup``): import paths and setup parameters used by ``Config``."""

from __future__ import annotations

import importlib
from logging import getLogger
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field, GetJsonSchemaHandler, model_validator
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv

from eta_ctrl.envs import BaseEnv

if TYPE_CHECKING:
    from pydantic.json_schema import JsonSchemaValue


log = getLogger(__name__)


def _import_class(import_path: str, expected_base: type) -> type:
    """Import a class from *import_path* and verify it is a subclass of *expected_base*.

    :param import_path: Fully qualified dotted path of the class, e.g.
        ``"stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv"``.
    :param expected_base: Class the imported object must derive from.
    :return: The imported class.
    :raises ValueError: If *import_path* is not of the form ``"module.ClassName"``.
    :raises ImportError: If the module part of *import_path* cannot be imported.
    :raises AttributeError: If the module has no attribute with the given name.
    :raises TypeError: If the imported object is not a class, or not a subclass of *expected_base*.
    """
    module_name, _, cls_name = import_path.rpartition(".")
    if not module_name or not cls_name:
        msg = f"Invalid import path '{import_path}': expected the form 'package.module.ClassName'."
        raise ValueError(msg)

    cls = getattr(importlib.import_module(module_name), cls_name)
    # issubclass() raises an unhelpful TypeError for non-class objects (functions, modules, ...),
    # so check for a class first and report both cases with the same descriptive message.
    if not isinstance(cls, type) or not issubclass(cls, expected_base):
        msg = f"'{import_path}' resolved to {cls}, which is not a subclass of {expected_base.__name__}"
        raise TypeError(msg)
    return cls


# (import field, resolved class field, required base class) for every class that
# ConfigSetup imports automatically; shared by the validator and the JSON schema hook.
_CLASS_FIELDS: tuple[tuple[str, str, type], ...] = (
    ("agent_import", "agent_class", BaseAlgorithm),
    ("environment_import", "environment_class", BaseEnv),
    ("vectorizer_import", "vectorizer_class", VecEnv),
    ("policy_import", "policy_class", BasePolicy),
)


class ConfigSetup(BaseModel):
    """Helper class, which is part of `Config`, for import and setup parameters."""

    model_config = ConfigDict(frozen=True, extra="forbid", use_attribute_docstrings=True)

    agent_import: str
    """Import description string for the agent class."""
    agent_class: type[BaseAlgorithm] = Field(exclude=True)
    """Agent class (automatically determined from agent_import)."""

    environment_import: str
    """Import description string for the environment class."""
    environment_class: type[BaseEnv] = Field(exclude=True)
    """Imported Environment class (automatically determined from environment_import)."""

    vectorizer_import: str = "stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv"
    """Import description string for the environment vectorizer (default: stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv)."""
    # NOTE(review): _resolve_classes accepts any VecEnv subclass, but this annotation only
    # accepts DummyVecEnv / SubprocVecEnv, so other VecEnv classes fail field validation.
    vectorizer_class: type[DummyVecEnv | SubprocVecEnv] = Field(exclude=True)
    """Environment vectorizer class (automatically determined from vectorizer_import)."""

    policy_import: str = "eta_ctrl.common.NoPolicy"
    """Import description string for the policy class (default: eta_ctrl.common.NoPolicy)."""
    policy_class: type[BasePolicy] = Field(exclude=True)
    """Policy class (automatically determined from policy_import)."""

    monitor_wrapper: bool = False
    """Flag which is true if the environment should be wrapped for monitoring (default: False)."""
    norm_wrapper_obs: bool = False
    """Flag which is true if the observations should be normalized (default: False)."""
    norm_wrapper_reward: bool = False
    """Flag which is true if the rewards should be normalized (default: False)."""
    tensorboard_log: bool = False
    """Flag to enable tensorboard logging (default: False)."""

    @model_validator(mode="before")
    @classmethod
    def _resolve_classes(cls, data: Any) -> Any:
        """Import the classes named by the ``*_import`` fields into the matching ``*_class`` fields.

        Non-dict input is passed through unchanged. An absent ``*_import`` key falls back to the
        field's declared default. When there is no usable string (a required field is missing or
        has the wrong type), the class field is left unset so pydantic reports a regular
        validation error instead of a bare ``KeyError``/``AttributeError``.

        :raises ImportError, AttributeError, TypeError: Propagated from ``_import_class`` when a
            path cannot be imported or does not name a subclass of the expected base.
        """
        if not isinstance(data, dict):
            return data

        data = dict(data)  # copy so the caller's mapping is not mutated
        for import_field, class_field, expected_base in _CLASS_FIELDS:
            # Required fields have a PydanticUndefined default, which is not a str and is skipped.
            import_path = data.get(import_field, cls.model_fields[import_field].default)
            if isinstance(import_path, str):
                data[class_field] = _import_class(import_path, expected_base)
        return data

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema: Any, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        """Build the JSON schema without the ``*_class`` fields, which are derived from ``*_import``.

        :param core_schema: Core schema of this model, forwarded to *handler*.
        :param handler: Pydantic's JSON schema generation handler.
        :return: The (possibly ``$ref``) JSON schema produced by *handler*, with the resolved
            class fields removed from the model definition.
        """
        json_schema = handler(core_schema)
        # If the model schema carries a reference, handler() returns {"$ref": ...} and the real
        # properties live in the referenced definition; edit that definition in place.
        model_schema = handler.resolve_ref_schema(json_schema)

        resolved_fields = {class_field for _, class_field, _ in _CLASS_FIELDS}
        properties = model_schema.get("properties", {})
        for field in resolved_fields:
            properties.pop(field, None)
        if "required" in model_schema:
            model_schema["required"] = [name for name in model_schema["required"] if name not in resolved_fields]
        return json_schema

    def __str__(self) -> str:
        """Human-readable string representation of ConfigSetup."""
        return f"ConfigSetup(env={self.environment_class.__name__}, agent={self.agent_class.__name__})"