from __future__ import annotations

from contextlib import contextmanager
from datetime import datetime
from logging import getLogger
from typing import TYPE_CHECKING

import numpy as np

from eta_ctrl.common import (
    CallbackEnvironment,
    is_closed,
    log_net_arch,
    log_run_info,
    log_to_file,
    merge_callbacks,
)
from eta_ctrl.config import Config, ConfigRun

from .core_utils import (
    initialize_model,
    load_model,
    vectorize_environment,
)

if TYPE_CHECKING:
    from collections.abc import Generator, Mapping
    from typing import Any

    from stable_baselines3.common.base_class import BaseAlgorithm
    from stable_baselines3.common.type_aliases import MaybeCallback
    from stable_baselines3.common.vec_env import VecEnv, VecNormalize
    from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs

    from eta_ctrl.util.type_annotations import Path
# Module-level logger; EtaCtrl.__init__ adjusts its level from the configured verbosity.
log = getLogger(name=__name__)
class EtaCtrl:
    """Initialize an optimization model and provide interfaces for optimization, learning and execution (play).

    :param config_name: Name of configuration file in configuration directory (should be JSON format).
    :param root_path: Root path of the application (the configuration will be interpreted relative to this).
    :param config_overwrite: Dictionary to overwrite selected configurations.
    :param config_relpath: Path to configuration file, relative to root path.
    """

    def __init__(
        self,
        config_name: str,
        root_path: Path | None = None,
        config_overwrite: Mapping[str, Any] | None = None,
        config_relpath: Path | None = None,
    ) -> None:
        #: Create Config object for the optimization run.
        self.config: Config = Config.from_file(
            root_path=root_path, config_relpath=config_relpath, config_name=config_name, overwrite=config_overwrite
        )
        # The module logger level is derived from the configured verbosity (verbose * 10).
        log.setLevel(int(self.config.settings.verbose * 10))

        #: Configuration for an optimization run (set by prepare_run).
        self.config_run: ConfigRun | None = None
        #: The vectorized environments (set by prepare_environments).
        self._environments: VecEnv | VecNormalize | None = None
        #: The model or algorithm (set by prepare_model).
        self._model: BaseAlgorithm | None = None

    def __str__(self) -> str:
        """Human-readable string representation of EtaCtrl."""
        env_class = self.config.setup.environment_class.__name__
        agent_class = self.config.setup.agent_class.__name__
        return f"EtaCtrl(config='{self.config.config_name}', env={env_class}, agent={agent_class})"

    def __repr__(self) -> str:
        """Developer-friendly string representation of EtaCtrl."""
        return (
            f"EtaCtrl(config_name='{self.config.config_name}', root_path='{self.config.root_path}', "
            f"config_run_initialized={self.config_run is not None})"
        )

    @property
    def environments(self) -> VecEnv | VecNormalize:
        """Vectorized environments created by ``prepare_environments``.

        :raises TypeError: If the environments have not been initialized yet.
        """
        if self._environments is None:
            msg = "Initialized environments could not be found. Call prepare_environments first."
            raise TypeError(msg)
        return self._environments

    @environments.setter
    def environments(self, environments: VecEnv | VecNormalize) -> None:
        self._environments = environments

    @property
    def model(self) -> BaseAlgorithm:
        """Model or algorithm created by ``prepare_model``.

        :raises TypeError: If the model has not been initialized yet.
        """
        if self._model is None:
            # The model is assigned in prepare_model (via _prepare_model), so that is the method to point to.
            msg = "Initialized model could not be found. Call prepare_model first."
            raise TypeError(msg)
        return self._model

    @model.setter
    def model(self, model: BaseAlgorithm) -> None:
        self._model = model

    @contextmanager
    def prepare_environments_models(
        self,
        *,
        series_name: str | None,
        run_name: str | None,
        run_description: str = "",
        reset: bool = False,
        training: bool = False,
    ) -> Generator:
        """Context manager which makes sure a run, environments and a model are available.

        If the environments are closed or no model exists, a new run is prepared, the environments are created
        (and closed again when the context exits) and the model is loaded or initialized. Otherwise the existing
        environments and model are reused unchanged.

        :param series_name: Name for a series of runs (``None`` is treated as an empty string).
        :param run_name: Name for a specific run (``None`` is treated as an empty string).
        :param run_description: Description for a specific run.
        :param reset: Flag to determine whether an existing model should be reset.
        :param training: Should preparation be done for training (alternative: playing)?
        """
        if is_closed(self._environments) or self._model is None:
            _series_name = series_name if series_name is not None else ""
            _run_name = run_name if run_name is not None else ""
            self.prepare_run(_series_name, _run_name, run_description)
            with self.prepare_environments(training=training):
                self.prepare_model(reset=reset)
                yield
        else:
            # Environments and model are already usable; nothing to prepare or clean up.
            yield

    def prepare_run(self, series_name: str, run_name: str, run_description: str = "") -> None:
        """Prepare the learn and play methods by reading configuration, creating results folders and the model.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        """
        self.config_run = ConfigRun(
            series=series_name,
            name=run_name,
            description=run_description,
            root_path=self.config.root_path,
            results_path=self.config.results_path,
            scenarios_path=self.config.scenarios_path,
        )
        self.config_run.create_results_folders()

        # Add file handler to parent logger to log the terminal output
        log_to_file(config=self.config, config_run=self.config_run)
        log.info("Run prepared successfully.")

    def prepare_model(self, *, reset: bool = False) -> None:
        """Check for existing model and load it or back it up and create a new model.

        :param reset: Flag to determine whether an existing model should be reset.
        """
        self._prepare_model(reset=reset)

    @staticmethod
    def _backup_model_path(model_path: Path) -> Path:
        """Return the path used to back up an existing model file.

        The backup is placed next to the model file and carries the file's modification time in its name,
        e.g. ``model.zip`` -> ``model.zip_20240131_1542.bak``.

        :param model_path: Path to the existing model file.
        :return: Path of the backup file in the same directory as the model file.
        """
        timestamp = datetime.fromtimestamp(model_path.stat().st_mtime).strftime("%Y%m%d_%H%M")
        # with_name keeps the backup beside the model; joining with "/" would treat the file as a directory
        # and make the subsequent rename fail.
        return model_path.with_name(f"{model_path.name}_{timestamp}.bak")

    def _prepare_model(self, *, reset: bool = False) -> None:
        """Check for existing model and load it or back it up and create a new model.

        :param reset: Flag to determine whether an existing model should be reset.
        :raises ValueError: If no run has been prepared yet (config_run is None).
        """
        if self.config_run is None:
            msg = "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
            raise ValueError(msg)

        model_path = self.config_run.run_model_path

        if model_path.is_file() and reset:
            log.info(f"Existing model detected: {model_path}")
            bak_name = self._backup_model_path(model_path)
            model_path.rename(bak_name)
            log.info(f"Reset is active. Existing model will be backed up. Backup file name: {bak_name}")
        elif model_path.is_file():
            log.info(f"Existing model detected: {model_path}. Loading model.")
            self.model = load_model(
                self.config.setup.agent_class,
                self.environments,
                self.config.settings.agent,
                self.config_run.run_model_path,
                tensorboard_log=self.config.setup.tensorboard_log,
                log_path=self.config_run.series_results_path,
            )
            return

        # Initialize the model if it wasn't loaded from a file
        self.model = initialize_model(
            self.config.setup.agent_class,
            self.config.setup.policy_class,
            self.environments,
            self.config.settings.agent,
            self.config.settings.seed,
            tensorboard_log=self.config.setup.tensorboard_log,
            log_path=self.config_run.series_results_path,
        )

    @contextmanager
    def prepare_environments(self, *, training: bool = True) -> Generator:
        """Context manager which prepares the environments and closes them after it exits.

        :param training: Should preparation be done for training (alternative: playing)?
        """
        # If the agents specifies the population parameter, the number of environments usually has to be
        # equal to that value as well. See NSGA-II agent.
        if (
            "population" in self.config.settings.agent
            and self.config.settings.n_environments != self.config.settings.agent["population"]
        ):
            if self.config.settings.n_environments != 1:
                log.warning(
                    f"Agent specifies 'population' parameter but the number of environments "
                    f"({self.config.settings.n_environments}) is not equal to the population. "
                    f"Setting 'n_environments' to {self.config.settings.agent['population']}"
                )
            self.config.settings.n_environments = self.config.settings.agent["population"]

        try:
            self._prepare_environments(training=training)
            yield
        finally:
            try:
                log.debug("Closing environments.")
                self.environments.close()
            except TypeError:
                # The environments property raises TypeError when no environments were ever created.
                log.exception("Environment initialization failed.")

    def _prepare_environments(self, *, training: bool = True) -> None:
        """Vectorize and prepare the environments.

        :param training: Should preparation be done for training (alternative: playing)?
        :raises ValueError: If no run has been prepared yet (config_run is None).
        """
        if self.config_run is None:
            msg = "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
            raise ValueError(msg)

        env_class = self.config.setup.environment_class
        self.config_run.set_env_info(env_class)

        callback = CallbackEnvironment(self.config.settings.plot_interval)

        # Vectorize the environments
        self.environments = vectorize_environment(
            env=env_class,
            config_run=self.config_run,
            env_settings=self.config.settings.environment,
            callback=callback,
            verbose=self.config.settings.verbose,
            vectorizer=self.config.setup.vectorizer_class,
            n=self.config.settings.n_environments,
            seed=self.config.settings.seed,
            training=training,
            monitor_wrapper=self.config.setup.monitor_wrapper,
            norm_wrapper_obs=self.config.setup.norm_wrapper_obs,
            norm_wrapper_reward=self.config.setup.norm_wrapper_reward,
        )

    def learn(
        self,
        *,
        series_name: str | None = None,
        run_name: str | None = None,
        run_description: str = "",
        reset: bool = False,
        callbacks: MaybeCallback = None,
    ) -> None:
        """Start the learning job for an agent with the specified environment.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        :param reset: Indication whether possibly existing models should be reset. Learning will be continued if
            model exists and reset is false.
        :param callbacks: Provide additional callbacks to send to the model.learn() call.
        :raises ValueError: If no run is prepared, 'n_episodes_learn' is missing for non-genetic agents, or the
            final environment reset fails.
        """
        with self.prepare_environments_models(
            series_name=series_name, run_name=run_name, run_description=run_description, reset=reset, training=True
        ):
            if self.config_run is None:
                msg = (
                    "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
                )
                raise ValueError(msg)

            # Log some information about the model and configuration
            log_net_arch(self.model, self.config_run)
            log_run_info(self.config, self.config_run)

            # Genetic algorithm has a slightly different concept for saving since it does not stop between time steps
            if "n_generations" in self.config.settings.agent:
                save_freq = self.config.settings.save_model_every_x_episodes
                total_timesteps = self.config.settings.agent["n_generations"]
            else:
                # Check if all required config values are present
                if self.config.settings.n_episodes_learn is None:
                    msg = "Missing configuration values for learning: 'n_episodes_learn'."
                    raise ValueError(msg)

                # Convert episode counts into environment steps (episode_duration / sampling_time steps each).
                save_freq = int(
                    self.config.settings.episode_duration
                    / self.config.settings.sampling_time
                    * self.config.settings.save_model_every_x_episodes
                )
                total_timesteps = int(
                    self.config.settings.n_episodes_learn
                    * self.config.settings.episode_duration
                    / self.config.settings.sampling_time
                )

            # Set the seed for the environments before starting to learn
            self.environments.seed(self.config.settings.seed)

            from stable_baselines3.common.callbacks import CheckpointCallback  # noqa: PLC0415

            # Periodically save model checkpoints alongside any user-provided callbacks.
            callback_learn = merge_callbacks(
                CheckpointCallback(
                    save_freq=save_freq,
                    save_path=str(self.config_run.series_results_path / "models"),
                    name_prefix=self.config_run.name,
                ),
                callbacks,
            )

            # The experiments are reset before the learning phase begins, start learning
            log.info("Start learning process of agent in environment.")
            try:
                self.model.learn(
                    total_timesteps=total_timesteps,
                    callback=callback_learn,
                    tb_log_name=self.config_run.name,
                )
            except OSError:
                # Keep the partially trained model so the work is not lost, then propagate the error.
                filename = str(self.config_run.series_results_path / f"{self.config_run.name}_model_before_error.pkl")
                log.info(f"Saving model to file: {filename}.")
                self.model.save(filename)
                raise

            try:
                log.debug("Resetting environment one more time to call environment callback one last time.")
                self.environments.reset()
            except ValueError as e:
                msg = "An error occurred when the environment is resetting."
                raise ValueError(msg) from e

            # Save model
            log.debug(f"Saving model to file: {self.config_run.run_model_path}.")
            self.model.save(self.config_run.run_model_path)

            from stable_baselines3.common.vec_env import VecNormalize  # noqa: PLC0415

            # Normalization statistics are needed to reproduce observations/rewards when playing later.
            if isinstance(self.environments, VecNormalize):
                log.debug(f"Saving environment normalization data to file: {self.config_run.vec_normalize_path}.")
                self.environments.save(str(self.config_run.vec_normalize_path))

        log.info(f"Learning finished: {series_name} / {run_name}")

    def play(self, *, series_name: str | None = None, run_name: str | None = None, run_description: str = "") -> None:
        """Play with previously learned agent model in environment.

        :param series_name: Name for a series of runs.
        :param run_name: Name for a specific run.
        :param run_description: Description for a specific run.
        :raises ValueError: If no run is prepared or 'n_episodes_play' is missing.
        """
        with self.prepare_environments_models(
            series_name=series_name, run_name=run_name, run_description=run_description, reset=False, training=False
        ):
            if self.config_run is None:
                msg = (
                    "Set the config_run attribute before trying to initialize the model (e.g. by calling prepare_run)."
                )
                raise ValueError(msg)
            if self.config.settings.n_episodes_play is None:
                msg = "Missing configuration value for playing: 'n_episodes_play' in section 'settings'"
                raise ValueError(msg)

            # Log some information about the model and configuration
            log_net_arch(self.model, self.config_run)
            log_run_info(self.config, self.config_run)

            n_episodes_stop = self.config.settings.n_episodes_play

            # Reset the environments before starting to play
            try:
                log.debug("Resetting environments before starting to play.")
                observations = self._reset_envs()
            except Exception:
                log.exception("Resetting environments failed")
                raise

            n_episodes = 0
            log.debug("Start playing process of agent in environment.")
            _round_actions = self.config.settings.round_actions
            _scale_actions = self.config.settings.scale_actions if self.config.settings.scale_actions is not None else 1

            while n_episodes < n_episodes_stop:
                try:
                    observations, dones = self._play_step(_round_actions, _scale_actions, observations)
                except BaseException as e:
                    log.exception(
                        "Exception occurred during an environment step. Aborting and trying to reset environments."
                    )
                    try:
                        observations = self._reset_envs()
                    except BaseException as followup_exception:
                        # Re-raise the original step error (so callers see its type); the failed reset is
                        # attached as its cause.
                        raise e from followup_exception
                    log.debug("Environment reset successful - re-raising exception")
                    raise

                # Each finished environment contributes one completed episode.
                n_episodes += sum(dones)

    def _play_step(
        self, _round_actions: int | None, _scale_actions: float, observations: VecEnvObs
    ) -> tuple[VecEnvObs, np.ndarray]:
        """Predict actions deterministically, optionally scale/round them, and step the environments.

        :param _round_actions: Number of decimals to round the scaled actions to, or None to skip rounding.
        :param _scale_actions: Factor applied to the predicted actions.
        :param observations: Current observations of the vectorized environments.
        :return: Tuple of the new observations and the done flags returned by the environments.
        """
        # Deterministic prediction for playing. The type is ignored because the observation typing of the
        # vectorized environments does not match the signature expected by predict.
        action, _ = self.model.predict(observation=observations, deterministic=True)  # type: ignore[arg-type]

        # Round and scale actions if required
        if _round_actions is not None:
            action = np.round(action * _scale_actions, _round_actions)
        else:
            action *= _scale_actions

        observations, _rewards, dones, _ = self.environments.step(action)
        return observations, dones

    def _reset_envs(self) -> VecEnvObs:
        """Seed and reset the vectorized environments.

        :return: Observations after reset.
        """
        log.debug("Resetting environments.")
        self.environments.seed(self.config.settings.seed)
        return self.environments.reset()