Source code for eta_ctrl.core_utils

from __future__ import annotations

import inspect
import pathlib
from functools import partial
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

    from gymnasium import Env
    from stable_baselines3.common.base_class import BaseAlgorithm, BasePolicy
    from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv, VecNormalize

    from eta_ctrl.config import ConfigRun
    from eta_ctrl.envs import BaseEnv
    from eta_ctrl.util.type_annotations import AlgoSettings, EnvSettings, Path
from logging import getLogger

# Module-level logger, named after this module.
log = getLogger(name=__name__)


def vectorize_environment(
    env: type[BaseEnv],
    config_run: ConfigRun,
    env_settings: EnvSettings,
    callback: Callable[[BaseEnv], None],
    verbose: int = 2,
    vectorizer: type[DummyVecEnv | SubprocVecEnv] | None = None,
    n: int = 1,
    seed: int | None = None,
    *,
    training: bool = False,
    monitor_wrapper: bool = False,
    norm_wrapper_obs: bool = False,
    norm_wrapper_reward: bool = False,
) -> VecNormalize | VecEnv:
    """Vectorize the environment and automatically apply normalization wrappers if configured.

    :param env: Environment class which will be instantiated and vectorized.
    :param config_run: Configuration for a specific optimization run.
    :param env_settings: Configuration settings dictionary for the environment which is being initialized.
        The dictionary itself is not modified. A ``"verbose"`` entry in it takes precedence over *verbose*.
    :param callback: Callback to call with an environment instance.
    :param verbose: Logging verbosity to use in the environment (unless ``env_settings`` contains ``"verbose"``).
    :param vectorizer: Vectorizer class to use for vectorizing the environments. Defaults to DummyVecEnv.
    :param n: Number of vectorized environments to create. Forced to 1 when DummyVecEnv is used.
    :param seed: Random seed passed to every environment instance.
    :param training: Flag to identify whether the environment should be initialized for training or playing.
        If true, it will be initialized for training.
    :param monitor_wrapper: Flag to determine whether the vectorized environments should be wrapped in a
        VecMonitor (which tracks episode rewards).
    :param norm_wrapper_obs: Flag to determine whether observations from the environments should be normalized.
    :param norm_wrapper_reward: Flag to determine whether rewards from the environments should be normalized.
    :return: Vectorized environments, possibly also wrapped in a normalizer.
    """
    from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecNormalize  # noqa: PLC0415

    if vectorizer is None:
        vectorizer = DummyVecEnv

    log.debug("Trying to vectorize the environment.")

    # Restrict DummyVecEnv to a single environment.
    if vectorizer is DummyVecEnv and n != 1:
        n = 1
        log.warning("Setting number of environments to 1 because DummyVecEnv (default) is used.")

    # Work on a shallow copy so the caller's dictionary is left untouched. Popping from the
    # original would drop a "verbose" override for every later call reusing the same settings.
    settings = dict(env_settings)
    verbose = settings.pop("verbose", verbose)

    def create_env(env_id: int) -> Env:
        # Environment IDs handed to the environments start at 1.
        return env(
            env_id=env_id + 1,
            config_run=config_run,
            verbose=verbose,
            callback=callback,
            seed=seed,
            **settings,
        )

    envs: VecEnv | VecNormalize
    envs = vectorizer([partial(create_env, i) for i in range(n)])

    # The VecMonitor knows the ep_reward and so this can be logged to tensorboard
    if monitor_wrapper:
        envs = VecMonitor(envs)

    # Automatically normalize the input features
    if norm_wrapper_obs or norm_wrapper_reward:
        # Load stored normalization data if available, otherwise create a fresh normalization wrapper.
        if config_run.vec_normalize_path.is_file():
            log.info(
                f"Normalization data detected. Loading running averages into normalization wrapper: \n"
                f"\t {config_run.vec_normalize_path}"
            )
            normalizer = VecNormalize.load(str(config_run.vec_normalize_path), envs)
            # The loaded statistics come with their own flags; override them with the requested ones.
            normalizer.training = training
            normalizer.norm_obs = norm_wrapper_obs
            normalizer.norm_reward = norm_wrapper_reward
            envs = normalizer
        else:
            log.info("No Normalization data detected.")
            envs = VecNormalize(envs, training=training, norm_obs=norm_wrapper_obs, norm_reward=norm_wrapper_reward)

    return envs
def _check_tensorboard_log(tensorboard_log: bool, log_path: Path | None) -> dict[str, str]: """Create necessary arguments for tensorboard logging if required :param tensorboard_log: Flag to enable logging to tensorboard. :param log_path: Path for tensorboard log. Online required if logging is true :return: Kwargs for the agent. """ if tensorboard_log: if log_path is None: msg = "If tensorboard logging is enabled, a path for results must be specified as well." raise ValueError(msg) _log_path = pathlib.Path(log_path) log.info(f"Tensorboard logging is enabled. Log file: {_log_path}") log.info( f"Please run the following command in the console to start tensorboard: \n" f'tensorboard --logdir "{_log_path}" --port 6006' ) return {"tensorboard_log": str(_log_path)} return {}
def initialize_model(
    algo: type[BaseAlgorithm],
    policy: type[BasePolicy],
    envs: VecEnv | VecNormalize,
    algo_settings: AlgoSettings,
    seed: int | None = None,
    *,
    tensorboard_log: bool = False,
    log_path: Path | None = None,
) -> BaseAlgorithm:
    """Initialize a new model or algorithm.

    :param algo: Algorithm to initialize.
    :param policy: The policy that should be used by the algorithm.
    :param envs: The environment which the algorithm operates on.
    :param algo_settings: Additional settings for the algorithm. The dictionary itself is not modified.
    :param seed: Random seed to be used by the algorithm. Ignored if *algo_settings* already contains
        a ``"seed"`` entry; dropped (with a warning) if the algorithm does not accept a seed.
    :param tensorboard_log: Flag to enable logging to tensorboard.
    :param log_path: Path for tensorboard log. Only required if logging is enabled.
    :raises ValueError: If tensorboard logging is enabled but no *log_path* is given.
    :return: Initialized model.
    """
    log.debug(f"Trying to initialize model: {algo.__name__}")

    # tensorboard logging
    algo_kwargs = _check_tensorboard_log(tensorboard_log, log_path)

    # Shallow copy: writing "seed" into the caller's dict would make the first call's seed
    # stick for every later call that reuses the same settings object.
    settings = dict(algo_settings)
    settings.setdefault("seed", seed)

    # Only pass "seed" if the algorithm's constructor can accept it (explicitly or via **kwargs).
    algo_params = inspect.signature(algo).parameters
    accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in algo_params.values())
    if "seed" not in algo_params and not accepts_var_kwargs:
        del settings["seed"]
        log.warning(
            f"'seed' is not a valid parameter for agent {algo.__name__}. This default parameter will be ignored."
        )

    # create model instance
    return algo(policy=policy, env=envs, **settings, **algo_kwargs)  # type: ignore[arg-type]
def load_model(
    algo: type[BaseAlgorithm],
    envs: VecEnv | VecNormalize,
    algo_settings: AlgoSettings,
    model_path: Path,
    *,
    tensorboard_log: bool = False,
    log_path: Path | None = None,
) -> BaseAlgorithm:
    """Load an existing model.

    :param algo: Algorithm type of the model to be loaded.
    :param envs: The environment which the algorithm operates on.
    :param algo_settings: Additional settings for the algorithm.
    :param model_path: Path to load the model from.
    :param tensorboard_log: Flag to enable logging to tensorboard.
    :param log_path: Path for tensorboard log. Only required if logging is enabled.
    :raises OSError: If *model_path* does not exist or loading the model fails with an OSError.
    :raises ValueError: If tensorboard logging is enabled but no *log_path* is given.
    :return: Initialized model.
    """
    log.debug(f"Trying to load existing model: {model_path}")

    _model_path = pathlib.Path(model_path)
    if not _model_path.exists():
        msg = f"Model couldn't be loaded. Path not found: {_model_path}"
        raise OSError(msg)

    # tensorboard logging
    algo_kwargs = _check_tensorboard_log(tensorboard_log, log_path)

    try:
        model = algo.load(_model_path, envs, **algo_settings, **algo_kwargs)  # type: ignore[arg-type]
    except OSError as e:
        # An OSError raised with only a message has strerror=None; fall back to its text.
        reason = e.strerror if e.strerror is not None else str(e)
        msg = f"Model couldn't be loaded: {reason}. Filename: {e.filename}"
        raise OSError(msg) from e

    log.debug("Model loaded successfully.")
    return model