Source code for eta_ctrl.core

from __future__ import annotations

from contextlib import contextmanager
from datetime import datetime
from logging import getLogger
from typing import TYPE_CHECKING

import numpy as np

from eta_ctrl.common import (
    CallbackEnvironment,
    is_closed,
    log_net_arch,
    log_run_info,
    log_to_file,
    merge_callbacks,
)
from eta_ctrl.config import Config, ConfigRun

from .core_utils import (
    initialize_model,
    load_model,
    vectorize_environment,
)

if TYPE_CHECKING:
    from collections.abc import Generator, Mapping
    from typing import Any

    from stable_baselines3.common.base_class import BaseAlgorithm
    from stable_baselines3.common.type_aliases import MaybeCallback
    from stable_baselines3.common.vec_env import VecEnv, VecNormalize
    from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs

    from eta_ctrl.util.type_annotations import Path

#: Module-level logger; EtaCtrl.__init__ sets its level from the configured verbosity.
log = getLogger(__name__)


class EtaCtrl:
    """Initialize an optimization model and provide interfaces for optimization, learning and execution (play).

    :param config_name: Name of configuration file in configuration directory (should be JSON format).
    :param root_path: Root path of the application (the configuration will be interpreted relative to this).
    :param config_overwrite: Dictionary to overwrite selected configurations.
    :param config_relpath: Path to configuration file, relative to root path.
    """

    def __init__(
        self,
        config_name: str,
        root_path: Path | None = None,
        config_overwrite: Mapping[str, Any] | None = None,
        config_relpath: Path | None = None,
    ) -> None:
        #: Create Config object for the optimization run.
        self.config: Config = Config.from_file(
            root_path=root_path, config_relpath=config_relpath, config_name=config_name, overwrite=config_overwrite
        )
        # The module logger level is derived from the configured verbosity (verbose * 10).
        log.setLevel(int(self.config.settings.verbose * 10))

        #: Configuration for an optimization run.
        self.config_run: ConfigRun | None = None
        #: The vectorized environments.
        self._environments: VecEnv | VecNormalize | None = None
        #: The model or algorithm.
        self._model: BaseAlgorithm | None = None

    def __str__(self) -> str:
        """Human-readable string representation of EtaCtrl."""
        env_class = self.config.setup.environment_class.__name__
        agent_class = self.config.setup.agent_class.__name__
        return f"EtaCtrl(config='{self.config.config_name}', env={env_class}, agent={agent_class})"

    def __repr__(self) -> str:
        """Developer-friendly string representation of EtaCtrl."""
        return (
            f"EtaCtrl(config_name='{self.config.config_name}', root_path='{self.config.root_path}', "
            f"config_run_initialized={self.config_run is not None})"
        )

    @property
    def environments(self) -> VecEnv | VecNormalize:
        """Vectorized environments created by :meth:`prepare_environments`.

        :raises TypeError: If the environments have not been initialized yet.
        """
        if self._environments is None:
            msg = "Initialized environments could not be found. Call prepare_environments first."
            raise TypeError(msg)
        return self._environments

    @environments.setter
    def environments(self, environments: VecEnv | VecNormalize) -> None:
        self._environments = environments

    @property
    def model(self) -> BaseAlgorithm:
        """Model (algorithm) created or loaded by :meth:`prepare_model`.

        :raises TypeError: If the model has not been initialized yet.
        """
        if self._model is None:
            msg = "Initialized model could not be found. Call prepare_model first."
            raise TypeError(msg)
        return self._model

    @model.setter
    def model(self, model: BaseAlgorithm) -> None:
        self._model = model

    @contextmanager
    def prepare_environments_models(
        self,
        *,
        series_name: str | None,
        run_name: str | None,
        run_description: str = "",
        reset: bool = False,
        training: bool = False,
    ) -> Generator:
        """Context manager which prepares a run (if needed), the environments and the model.

        A new run is only prepared (see :meth:`prepare_run`) when ``is_closed`` reports the current environments
        as closed or when no model exists yet. The environments are then prepared via
        :meth:`prepare_environments` (and closed when the context exits) and the model via :meth:`prepare_model`.

        :param series_name: Name for a series of runs (``None`` is treated as an empty string).
        :param run_name: Name for a specific run (``None`` is treated as an empty string).
        :param run_description: Description for a specific run.
        :param reset: Indication whether possibly existing models should be reset.
        :param training: Should preparation be done for training (alternative: playing)?
        """
        if is_closed(self._environments) or self._model is None:
            _series_name = series_name if series_name is not None else ""
            _run_name = run_name if run_name is not None else ""
            self.prepare_run(_series_name, _run_name, run_description)

        with self.prepare_environments(training=training):
            self.prepare_model(reset=reset)
            yield

    def prepare_run(self, series_name: str, run_name: str, run_description: str = "") -> None:
        """Prepare the learn and play methods by reading configuration, creating results folders and the model.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        """
        self.config_run = ConfigRun(
            series=series_name,
            name=run_name,
            description=run_description,
            root_path=self.config.root_path,
            results_path=self.config.results_path,
            scenarios_path=self.config.scenarios_path,
        )
        self.config_run.create_results_folders()

        # Add file handler to parent logger to log the terminal output
        log_to_file(config=self.config, config_run=self.config_run)

        log.info("Run prepared successfully.")

    def prepare_model(self, *, reset: bool = False) -> None:
        """Check for existing model and load it or back it up and create a new model.

        :param reset: Flag to determine whether an existing model should be reset.
        """
        self._prepare_model(reset=reset)

    def _prepare_model(self, *, reset: bool = False) -> None:
        """Check for existing model and load it or back it up and create a new model.

        :param reset: Flag to determine whether an existing model should be reset.
        :raises ValueError: If :attr:`config_run` has not been set (e.g. by :meth:`prepare_run`).
        """
        if self.config_run is None:
            msg = "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
            raise ValueError(msg)

        model_path = self.config_run.run_model_path
        if model_path.is_file() and reset:
            log.info(f"Existing model detected: {model_path}")

            # The backup must live next to the model file; the model file itself is not a directory.
            bak_name = self._backup_model_path(model_path)
            model_path.rename(bak_name)
            log.info(f"Reset is active. Existing model will be backed up. Backup file name: {bak_name}")
        elif model_path.is_file():
            log.info(f"Existing model detected: {model_path}. Loading model.")

            self.model = load_model(
                self.config.setup.agent_class,
                self.environments,
                self.config.settings.agent,
                self.config_run.run_model_path,
                tensorboard_log=self.config.setup.tensorboard_log,
                log_path=self.config_run.series_results_path,
            )
            return

        # Initialize the model if it wasn't loaded from a file
        self.model = initialize_model(
            self.config.setup.agent_class,
            self.config.setup.policy_class,
            self.environments,
            self.config.settings.agent,
            self.config.settings.seed,
            tensorboard_log=self.config.setup.tensorboard_log,
            log_path=self.config_run.series_results_path,
        )

    @staticmethod
    def _backup_model_path(model_path: Path) -> Path:
        """Return the path an existing model file is renamed to when the model is reset.

        The backup is placed in the same directory as the model file and named after it, extended by the file's
        modification time and a ``.bak`` suffix, e.g. ``model.zip`` -> ``model.zip_20240131_1542.bak``.

        :param model_path: Path of the existing model file.
        :return: Path of the backup file.
        """
        timestamp = datetime.fromtimestamp(model_path.stat().st_mtime).strftime("%Y%m%d_%H%M")
        return model_path.with_name(f"{model_path.name}_{timestamp}.bak")

    @contextmanager
    def prepare_environments(self, *, training: bool = True) -> Generator:
        """Context manager which prepares the environments and closes them after it exits.

        If the agent settings contain a ``population`` parameter, ``n_environments`` is overwritten with that
        value (a warning is logged unless ``n_environments`` was 1). Errors raised while accessing or closing the
        environments on exit (``TypeError``) are logged instead of raised.

        :param training: Should preparation be done for training (alternative: playing)?
        """
        # If the agents specifies the population parameter, the number of environments usually has to be
        # equal to that value as well. See NSGA-II agent.
        if (
            "population" in self.config.settings.agent
            and self.config.settings.n_environments != self.config.settings.agent["population"]
        ):
            if self.config.settings.n_environments != 1:
                log.warning(
                    f"Agent specifies 'population' parameter but the number of environments "
                    f"({self.config.settings.n_environments}) is not equal to the population. "
                    f"Setting 'n_environments' to {self.config.settings.agent['population']}"
                )
            self.config.settings.n_environments = self.config.settings.agent["population"]

        try:
            self._prepare_environments(training=training)
            yield
        finally:
            try:
                log.debug("Closing environments.")
                self.environments.close()
            except TypeError:
                # The environments property raises TypeError when no environments were ever created.
                log.exception("Environment initialization failed.")

    def _prepare_environments(self, *, training: bool = True) -> None:
        """Vectorize and prepare the environments.

        :param training: Should preparation be done for training (alternative: playing)?
        :raises ValueError: If :attr:`config_run` has not been set (e.g. by :meth:`prepare_run`).
        """
        if self.config_run is None:
            msg = "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
            raise ValueError(msg)

        env_class = self.config.setup.environment_class
        self.config_run.set_env_info(env_class)

        callback = CallbackEnvironment(self.config.settings.plot_interval)

        # Vectorize the environments
        self.environments = vectorize_environment(
            env=env_class,
            config_run=self.config_run,
            env_settings=self.config.settings.environment,
            callback=callback,
            verbose=self.config.settings.verbose,
            vectorizer=self.config.setup.vectorizer_class,
            n=self.config.settings.n_environments,
            seed=self.config.settings.seed,
            training=training,
            monitor_wrapper=self.config.setup.monitor_wrapper,
            norm_wrapper_obs=self.config.setup.norm_wrapper_obs,
            norm_wrapper_reward=self.config.setup.norm_wrapper_reward,
        )

    def learn(
        self,
        *,
        series_name: str | None = None,
        run_name: str | None = None,
        run_description: str = "",
        reset: bool = False,
        callbacks: MaybeCallback = None,
    ) -> None:
        """Start the learning job for an agent with the specified environment.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        :param reset: Indication whether possibly existing models should be reset. Learning will be continued if
            model exists and reset is false.
        :param callbacks: Provide additional callbacks to send to the model.learn() call.
        """
        with self.prepare_environments_models(
            series_name=series_name, run_name=run_name, run_description=run_description, reset=reset, training=True
        ):
            if self.config_run is None:
                msg = (
                    "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
                )
                raise ValueError(msg)

            # Log some information about the model and configuration
            log_net_arch(self.model, self.config_run)
            log_run_info(self.config, self.config_run)

            # Genetic algorithm has a slightly different concept for saving since it does not stop between time steps
            if "n_generations" in self.config.settings.agent:
                save_freq = self.config.settings.save_model_every_x_episodes
                total_timesteps = self.config.settings.agent["n_generations"]
            else:
                # Check if all required config values are present
                if self.config.settings.n_episodes_learn is None:
                    msg = "Missing configuration values for learning: 'n_episodes_learn'."
                    raise ValueError(msg)

                # define callback for periodically saving models
                save_freq = int(
                    self.config.settings.episode_duration
                    / self.config.settings.sampling_time
                    * self.config.settings.save_model_every_x_episodes
                )
                total_timesteps = int(
                    self.config.settings.n_episodes_learn
                    * self.config.settings.episode_duration
                    / self.config.settings.sampling_time
                )

            # Set the seed for the environments before starting to learn
            self.environments.seed(self.config.settings.seed)

            from stable_baselines3.common.callbacks import CheckpointCallback  # noqa: PLC0415

            callback_learn = merge_callbacks(
                CheckpointCallback(
                    save_freq=save_freq,
                    save_path=str(self.config_run.series_results_path / "models"),
                    name_prefix=self.config_run.name,
                ),
                callbacks,
            )

            # The experiments are reset before the learning phase begins, start learning
            log.info("Start learning process of agent in environment.")

            try:
                self.model.learn(
                    total_timesteps=total_timesteps,
                    callback=callback_learn,
                    tb_log_name=self.config_run.name,
                )
            except OSError:
                # Persist the partially trained model before propagating the error.
                filename = str(self.config_run.series_results_path / f"{self.config_run.name}_model_before_error.pkl")
                log.info(f"Saving model to file: {filename}.")
                self.model.save(filename)
                raise

            try:
                log.debug("Resetting environment one more time to call environment callback one last time.")
                self.environments.reset()
            except ValueError as e:
                msg = "An error occurred when the environment is resetting."
                raise ValueError(msg) from e

            # Save model
            log.debug(f"Saving model to file: {self.config_run.run_model_path}.")
            self.model.save(self.config_run.run_model_path)

            from stable_baselines3.common.vec_env import VecNormalize  # noqa: PLC0415

            if isinstance(self.environments, VecNormalize):
                log.debug(f"Saving environment normalization data to file: {self.config_run.vec_normalize_path}.")
                self.environments.save(str(self.config_run.vec_normalize_path))

        log.info(f"Learning finished: {series_name} / {run_name}")

    def play(self, *, series_name: str | None = None, run_name: str | None = None, run_description: str = "") -> None:
        """Play with previously learned agent model in environment.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        """
        with self.prepare_environments_models(
            series_name=series_name, run_name=run_name, run_description=run_description, reset=False, training=False
        ):
            if self.config_run is None:
                msg = (
                    "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
                )
                raise ValueError(msg)

            if self.config.settings.n_episodes_play is None:
                msg = "Missing configuration value for playing: 'n_episodes_play' in section 'settings'"
                raise ValueError(msg)

            # Log some information about the model and configuration
            log_net_arch(self.model, self.config_run)
            log_run_info(self.config, self.config_run)

            n_episodes_stop = self.config.settings.n_episodes_play

            # Reset the environments before starting to play
            try:
                log.debug("Resetting environments before starting to play.")
                observations = self._reset_envs()
            except Exception:
                log.exception("Resetting environments failed")
                raise

            n_episodes = 0
            log.debug("Start playing process of agent in environment.")
            _round_actions = self.config.settings.round_actions
            _scale_actions = self.config.settings.scale_actions if self.config.settings.scale_actions is not None else 1

            while n_episodes < n_episodes_stop:
                try:
                    observations, dones = self._play_step(_round_actions, _scale_actions, observations)
                except BaseException as e:
                    log.exception(
                        "Exception occurred during an environment step. Aborting and trying to reset environments."
                    )
                    try:
                        observations = self._reset_envs()
                    except BaseException as followup_exception:
                        # Re-raise the original step exception; the failed reset is attached as its cause.
                        raise e from followup_exception
                    log.debug("Environment reset successful - re-raising exception")
                    raise
                # Each environment that reports done completes one episode.
                n_episodes += sum(dones)

    def _play_step(
        self, _round_actions: int | None, _scale_actions: float, observations: VecEnvObs
    ) -> tuple[VecEnvObs, np.ndarray]:
        """Predict actions for the given observations and apply them to the environments.

        :param _round_actions: Number of decimals the scaled actions are rounded to (``None`` disables rounding).
        :param _scale_actions: Factor the predicted actions are multiplied with.
        :param observations: Current observations of the vectorized environments.
        :return: Tuple of the new observations and the ``dones`` array returned by the environments.
        """
        # Use deterministic policy predictions for playing. arg-type is ignored because the VecEnvObs
        # annotation does not match the observation type declared by predict().
        action, _ = self.model.predict(observation=observations, deterministic=True)  # type: ignore[arg-type]

        # Scale (and optionally round) the actions. The multiplication is done out of place so that integer
        # actions (e.g. from discrete action spaces) can also be scaled by a float factor.
        action = action * _scale_actions
        if _round_actions is not None:
            action = np.round(action, _round_actions)

        observations, _rewards, dones, _ = self.environments.step(action)
        return observations, dones

    def _reset_envs(self) -> VecEnvObs:
        """Seed and reset the vectorized environments.

        Used by :meth:`play` before playing starts and to recover after a failed step.

        :return: Observations after reset.
        """
        log.debug("Resetting environments.")
        self.environments.seed(self.config.settings.seed)
        return self.environments.reset()