# Source code for eta_ctrl.agents.mpc_agent

from __future__ import annotations

from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Self, cast

import pyomo.environ as pyo
from gymnasium import spaces
from gymnasium.vector.utils import create_empty_array, iterate
from pyomo import opt
from pyomo.common.errors import InfeasibleConstraintException
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.vec_env import VecEnv, VecNormalize

from eta_ctrl.common.sb3_extensions.policies import NoPolicy
from eta_ctrl.simulators import PyomoModel
from eta_ctrl.util.utils import import_class_from_module

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    import numpy as np
    from stable_baselines3.common.policies import BasePolicy
    from stable_baselines3.common.type_aliases import MaybeCallback

#: Module-level logger; ``MpcAgent.__init__`` sets its level from the ``verbose`` argument.
log = getLogger(name=__name__)


class MpcAgent(BaseAlgorithm):
    """Simple, Pyomo based optimization agent supporting multiple solvers.

    The MpcAgent requires a PyomoModel which is passed via the `model_import` parameter. It must be defined
    in the config under the 'agent_specific' section.

    Common stable-baselines3 parameters are ignored for the MpcAgent as it cannot be used for learning or
    training. It can only be used to predict actions via predict(); use EtaCtrl.play() to run experiments.

    :param env: Environment to be optimized
    :param sampling_time: Interval for one timestep. Used to calculate n_prediction_steps
    :param prediction_horizon: Duration of the prediction in seconds (usually a subsample of the episode duration)
    :param model_import: Dotted import path to the PyomoModel subclass (e.g. ``"my_package.my_module.MyModel"``)
    :param verbose: Logging verbosity; the module logger level is set to ``verbose * 10``.
    :param solver_name: Name of the solver (e.g. gurobi, cplex, or glpk). Is passed to ``pyomo.SolverFactory``.
    :param action_index: Index of the solution value to be used as action (by default the value for the first
        timestep in the solution will be used)
    :param model_parameters: Dictionary of parameters to forward to the PyomoModel
    :param solver_options: Dictionary of solver options (e.g. time limits or tolerances)
    :param solver_callback: Optional callback, called with this agent after each solve step in :meth:`predict`
    :param kwargs: Additional keyword arguments; accepted for compatibility and ignored.
    """

    #: Horizontal rule framing the solver result tables in the log output.
    _LOG_BORDER = "\t+----------------------------------+"

    def __init__(
        self,
        env: VecEnv,
        sampling_time: float,
        prediction_horizon: float,
        model_import: str,
        verbose: int = 1,
        *,
        solver_name: str = "cplex",
        action_index: int = 0,
        model_parameters: dict[str, Any] | None = None,
        solver_options: dict[str, Any] | None = None,
        solver_callback: Callable[[BaseAlgorithm], None] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            policy=NoPolicy,
            env=env,
            learning_rate=0.0,
            verbose=verbose,
            use_sde=False,
            monitor_wrapper=False,
            supported_action_spaces=(spaces.Box,),
        )
        # Note: this changes the level of the shared module logger, not a per-instance logger.
        log.setLevel(int(verbose * 10))  # Set logging verbosity

        if isinstance(self.get_env(), VecNormalize):
            msg = "The MPC agent does not allow the use of normalized environments."
            raise TypeError(msg)

        #: Specification of the order in which action values should be returned.
        self.actions_order = self.get_env().get_attr("state_config", 0)[0].actions

        # Solver parameters
        #: Name of the solver to be used
        self.solver_name: str = solver_name
        #: Index of the solution value to be used as action (if this is 0, the first value in a list
        #: of solution values will be used).
        self.action_index = action_index
        #: Additional callback for predicting
        self.solver_callback = solver_callback
        #: Pyomo solver instance
        self.solver = pyo.SolverFactory(self.solver_name)
        self.solver.options.update(solver_options or {})  # Adjust solver settings

        self.policy_class: type[BasePolicy]
        target_class: type[PyomoModel] = import_class_from_module(model_import, base_class=PyomoModel)
        #: PyomoModel instance
        self.model: PyomoModel = target_class(
            model_parameters=model_parameters,
            sampling_time=sampling_time,
            prediction_horizon=prediction_horizon,
        )
        # Shortcut for model access
        self.concrete_model: pyo.ConcreteModel = self.model.model

        self._setup_model()

    def _setup_model(self) -> None:
        """Required method by the BaseAlgorithm interface: instantiate the (placeholder) policy."""
        if self.policy_class is not None:
            self.policy = self.policy_class(
                self.observation_space,
                self.action_space,
            )

    def get_env(self) -> VecEnv:
        """Return the environment, narrowed to ``VecEnv`` for type checking.

        :raises AttributeError: If no environment has been set.
        """
        if self.env is None:
            msg = "Can't access attribute 'self.env', initialize environment first"
            raise AttributeError(msg)
        return self.env

    @staticmethod
    def _format_result_section(section: Any, exclude: tuple[str, ...] = ()) -> str:
        """Render the defined entries of one solver result section as tab-indented ``item: value`` lines.

        :param section: One entry of a Pyomo solver result section (e.g. ``result["Solver"][0]``).
        :param exclude: Item names to leave out of the output.
        :return: The formatted lines joined by newlines; undefined values are skipped.
        """
        return "\n".join(
            f"\t {item}: {value.value} "
            for item, value in section.items()
            if item not in exclude and not isinstance(value.value, opt.UndefinedData)
        )

    def solve(self) -> pyo.ConcreteModel:
        """Solve the current pyomo model instance with given parameters.

        This could also be used separately to solve normal MILP problems. Since the entire problem
        instance is returned, result handling can be outsourced.

        :return: Solved pyomo model instance.
        :raises InfeasibleConstraintException: Via :meth:`handle_solve_failed` if no feasible solution exists.
        """
        # Echo solver output only when the logger level is DEBUG (10) or NOTSET (0).
        _tee: bool = bool(log.level / 10 <= 1)
        result = self.solver.solve(self.model.model, symbolic_solver_labels=True, tee=_tee)
        if _tee:
            print("\n")  # noqa: T201 (print is ok here, because cplex prints directly to console).

        log.debug(
            "Problem information:\n%s\n%s\n%s",
            self._LOG_BORDER,
            self._format_result_section(result["Problem"][0]),
            self._LOG_BORDER,
        )
        # Log status after the optimization
        log.info(
            "Solver information:\n%s\n%s\n%s",
            self._LOG_BORDER,
            self._format_result_section(result["Solver"][0], exclude=("Statistics",)),
            self._LOG_BORDER,
        )
        if len(result["Solution"]) >= 1:
            log.debug(
                "Solution information:\n%s\n%s\n%s",
                self._LOG_BORDER,
                self._format_result_section(result["Solution"][0]),
                self._LOG_BORDER,
            )

        # Check if no optimal solution could be found
        if not opt.check_optimal_termination(result):
            self.handle_solve_failed(result=result)

        return self.model.model

    def _get_action_values(self) -> list[Any]:
        """Read the solved action values from the Pyomo model, ordered like ``self.actions_order``.

        Only the variables named in ``self.actions_order`` are read, so other variables of the model
        (e.g. auxiliary or uninitialized ones) cannot make prediction fail.

        :return: One value per entry of ``self.actions_order``.
        :raises ValueError: If an action has no matching indexed variable, or its value cannot be read
            (uninitialized variable, or ``action_index`` outside the variable's index set).
        """
        model_name = type(self.model).__name__

        # Map names of indexed variables to their components; scalar variables carry no per-step values.
        indexed_vars: dict[str, pyo.Var] = {}
        for component in self.model.model.component_objects(pyo.Var):
            var = cast("pyo.Var", component)
            if not isinstance(var, pyo.ScalarVar):
                indexed_vars[var.name] = var

        values: list[Any] = []
        for action in self.actions_order:
            var = indexed_vars.get(action)
            if var is None:
                msg = f"Couldn't find an indexed variable for action '{action}' in the PyomoModel {model_name}"
                raise ValueError(msg)
            try:
                # Pyomo's ordered Set.at() is 1-based while action_index is 0-based.
                value = pyo.value(var[var.index_set().at(self.action_index + 1)])
            except (ValueError, KeyError, IndexError) as e:
                msg = f"Couldn't fetch the value for action '{action}' in the PyomoModel {model_name}"
                raise ValueError(msg) from e
            log.debug("Action '%s' value: %s", action, value)
            values.append(value)
        return values

    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """Solve the Pyomo model for each environment and return the resulting actions.

        For every environment, the model parameters are updated with that environment's observation,
        the model is solved, and the value at ``action_index`` of each variable named in
        ``actions_order`` is used as action.

        :param observation: Observations of all environments; used to update the model parameters.
        :param state: The last states (not used here, returned unchanged).
        :param episode_start: The last masks (not used here).
        :param deterministic: Whether to return deterministic actions. This agent always returns
            deterministic actions.
        :return: Tuple of the model's action and the next state (not used here).
        :raises ValueError: If an action value cannot be read from the solved model.
        """
        action_array: np.ndarray = create_empty_array(self.action_space, n=self.get_env().num_envs)  # type: ignore[assignment]

        # Return actions for each environment
        for idx, env_obs in enumerate(iterate(self.observation_space, observation)):
            env_obs_: dict = env_obs  # for typing only, must be of type dictionary
            # Update model parameters with environment observations
            self.model.pyo_update_params(env_obs_)
            # Solve the model for actions
            self.solve()
            if self.solver_callback:
                self.solver_callback(self)

            for i, value in enumerate(self._get_action_values()):
                action_array[idx][i] = value

        return action_array, state

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 100,
        tb_log_name: str = "run",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> Self:
        """No-op: the MPC approach cannot learn. The agent is returned unchanged.

        The optimization model is provided via ``model_import`` at construction time; use
        :meth:`predict` to obtain actions.

        :param total_timesteps: The total number of samples (env steps) to train on (ignored)
        :param callback: callback(s) called at every step with state of the algorithm (ignored)
        :param log_interval: The number of timesteps before logging (ignored)
        :param tb_log_name: the name of the run for TensorBoard logging (ignored)
        :param reset_num_timesteps: whether or not to reset the current timestep number (ignored)
        :param progress_bar: Display a progress bar using tqdm and rich (ignored)
        :return: This agent.
        """
        return self

    def handle_solve_failed(self, result: Any) -> None:
        """Called when the solver did not reach an optimal solution.

        If a feasible (suboptimal) solution exists, logs a warning and returns so the caller can
        continue with that solution. If no feasible solution exists, logs full diagnostics and
        raises an :exc:`InfeasibleConstraintException`.

        :param result: Result object returned by the Pyomo solver.
        :raises InfeasibleConstraintException: If the result contains no solution.
        """
        # NOTE(review): Pyomo's default ``solve(..., load_solutions=True)`` appears to clear
        # ``result.solution`` after loading it into the model. If that holds for the solver in use,
        # this branch is never taken and feasible-but-suboptimal results (e.g. time limit with an
        # incumbent) still raise below — verify against the installed Pyomo version.
        if len(result["Solution"]) != 0:
            gap_info = ""
            if "Gap" in result["Solution"][0]:
                gap_value = result["Solution"][0]["Gap"].value
                if not isinstance(gap_value, opt.UndefinedData):
                    gap_info = f" (achieved MIP gap: {gap_value})"
            log.warning(
                "Solver did not reach optimal solution%s. "
                "Termination condition: %s, Status: %s. "
                "Continuing with best available solution.",
                gap_info,
                result.solver.termination_condition,
                result.solver.status,
            )
            return

        log.error(
            "Solver failed: no feasible solution found. Termination condition: %s, Status: %s. %s",
            result.solver.termination_condition,
            result.solver.status,
            f"Message: {getattr(result.solver, 'message', '')}",
        )

        # Very large results are written to the run's results directory instead of the log.
        result_str = str(result)
        if len(result_str) > 10000:
            try:
                log_dir = Path(self.get_env().get_attr("config_run", 0)[0].results_path)
                result_file = log_dir / f"solver_result_failure_{self.num_timesteps}.txt"
                result_file.write_text(result_str, encoding="utf-8")
                log.debug("Full solver result saved to: %s", result_file)
            except (OSError, AttributeError, IndexError, TypeError) as e:
                log.warning("Could not save result to disk: %s. Logging truncated version.", e)
                log.debug("Solver result (truncated): %s", result_str[:5000] + "...")
        else:
            log.debug("Full solver result: %s", result)

        msg = "Solver failed to find feasible solution."
        raise InfeasibleConstraintException(msg)