# Source code for eta_ctrl.envs.base_env

from __future__ import annotations

import abc
import inspect
import pathlib
import time
from datetime import datetime, timedelta
from logging import getLogger
from typing import TYPE_CHECKING, cast

import numpy as np
from gymnasium import Env, spaces

from eta_ctrl.util import csv_export
from eta_ctrl.util.utils import timestep_to_seconds

if TYPE_CHECKING:
    from collections.abc import Callable, Iterator, Mapping, Sequence
    from typing import Any

    from eta_ctrl.config import ConfigRun
    from eta_ctrl.envs.state import StateConfig
    from eta_ctrl.timeseries.scenario_manager import ScenarioManager
    from eta_ctrl.util.type_annotations import ObservationType, Path, StepResult, TimeStep


log = getLogger(__name__)


class BaseEnv(Env, abc.ABC):
    """Abstract environment definition, providing some basic functionality for concrete environments to use.

    The class implements and adapts functions from gymnasium.Env. It provides additional functionality as
    required by the ETA Ctrl framework and should be used as the starting point for new environments.

    The initialization of this superclass performs many of the necessary tasks, required to specify a concrete
    environment. Read the documentation carefully to understand, how new environments can be developed, building
    on this starting point.

    There are some class attributes that must be set and some methods that must be implemented to satisfy the
    interface. This is required to create concrete environments.
    The required class attributes are:

        - **version**: Version number of the environment.
        - **description**: Short description string of the environment.

    The gymnasium interface requires the following methods for the environment to work correctly within the
    framework. Consult the documentation of each method for more detail.

        - **step()**
        - **reset()**
        - **close()**
        - **render()**

    .. note::
        Subclasses should implement the private _step and _reset methods rather than overriding the public
        step and reset methods. The public methods handle the Gymnasium interface and state management
        automatically.

    :param env_id: Identification for the environment, useful when creating multiple environments.
    :param config_run: Configuration of the optimization run.
    :param state_config: State configuration defining the state variables. The action and observation
        spaces are taken from its ``continuous_spaces()``.
    :param verbose: Verbosity to use for logging.
    :param callback: callback that should be called after each episode.
    :param state_modification_callback: callback that should be called after state setup, before logging the state.
    :param seed: Seed applied to the action and observation spaces (only when it is not None).
    :param episode_duration: Duration of the episode in seconds.
    :param sampling_time: Duration of a single time sample / time step in seconds.
    :param sim_steps_per_sample: Number of simulation steps to be taken for each sample.
    :param scenario_manager: Manager to load scenario data into the state (optional).
    :param render_mode: Renders the environment to help visualise what the agent sees, example modes are
        "human", "rgb_array", "ansi" for text.
    :param path_env: Explicit path to the environment directory. If not provided, the path will be
        automatically detected from the file defining the concrete class. If detection fails, falls back
        to current working directory.
    :param kwargs: Other keyword arguments (for subclasses).
    """

    @property
    @abc.abstractmethod
    def version(self) -> str:
        """Version of the environment.

        Needs to be implemented for each subclass as a class attribute.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def description(self) -> str:
        """Long description of the environment.

        Needs to be implemented for each subclass as a class attribute.
        """
        raise NotImplementedError

    def __init__(
        self,
        env_id: int,
        config_run: ConfigRun,
        state_config: StateConfig,
        verbose: int = 2,
        callback: Callable | None = None,
        state_modification_callback: Callable | None = None,
        seed: int | None = None,
        *,
        episode_duration: TimeStep | str,
        sampling_time: TimeStep | str,
        sim_steps_per_sample: int | str = 1,
        scenario_manager: ScenarioManager | None = None,
        render_mode: str | None = None,
        path_env: Path | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        #: Verbosity level used for logging.
        self.verbose: int = verbose
        log.setLevel(int(verbose * 10))

        # Set some standard path settings
        #: Information about the optimization run and information about the paths.
        #: For example, it defines results_path and scenarios_path.
        self.config_run: ConfigRun = config_run

        #: Callback can be used for logging and plotting.
        self.callback: Callable | None = callback

        #: Callback can be used for modifying the state at each time step.
        self.state_modification_callback: Callable | None = state_modification_callback

        #: ID of the environment (useful for vectorized environments).
        self.env_id: int = int(env_id)

        #: Render mode for rendering the environment
        self.render_mode: str | None = render_mode

        #: Duration of one episode in seconds.
        self.episode_duration: float = timestep_to_seconds(episode_duration)
        #: Sampling time (interval between optimization time steps) in seconds.
        self.sampling_time: float = timestep_to_seconds(sampling_time)
        #: Number of time steps (of width sampling_time) in each episode.
        self.n_episode_steps: int = int(self.episode_duration // self.sampling_time)
        #: Number of simulation steps to be taken for each sample. This must be a divisor of 'sampling_time'.
        self.sim_steps_per_sample: int = int(sim_steps_per_sample)

        #: State Configuration for defining State Variables.
        self.state_config: StateConfig = state_config
        self.action_space, self.observation_space = self.state_config.continuous_spaces()

        #: Manager to load scenario data into the state
        self.scenario_manager = scenario_manager

        #: Explicit path override for environment directory (used if auto-detection fails)
        self._path_env_override: pathlib.Path | None = pathlib.Path(path_env) if path_env is not None else None

        if seed is not None:
            # Initialize random generator
            self.action_space.seed(seed=seed)
            self.observation_space.seed(seed=seed)

        self._init_attributes()

    def _init_attributes(self) -> None:
        """Initialize environment attributes that don't depend on constructor arguments."""
        detected_path: pathlib.Path | None = None

        if self._path_env_override is not None:
            detected_path = self._path_env_override
        else:
            try:
                # Directory of the file in which the concrete environment class is defined.
                detected_path = pathlib.Path(inspect.getfile(type(self))).resolve().parent
            except (TypeError, OSError):
                detected_path = None

        if detected_path is None:
            detected_path = pathlib.Path.cwd()
            log.warning(
                f"Could not automatically detect environment path for {type(self).__name__}. "
                f"Falling back to current working directory: {detected_path}. "
                "Consider passing 'path_env' explicitly to the constructor if this is incorrect."
            )

        #: Path of the environment file.
        self.path_env: pathlib.Path = detected_path

        # Store data logs and log other information
        #: Episode timer (stores the start time of the episode).
        self.episode_timer: float = time.time()
        #: Current state of the environment.
        self.state: dict[str, np.ndarray]
        #: Additional state information to append to the state during stepping and reset
        self.additional_state: dict[str, float] | None = None
        #: Log of the environment state.
        self.state_log: list[dict[str, np.ndarray]] = []
        #: Log of the environment state over multiple episodes.
        self.state_log_longtime: list[list[dict[str, np.ndarray]]] = []
        #: Number of completed episodes.
        self.n_episodes: int = 0
        #: Current step of the model (number of completed steps) in the current episode.
        self.n_steps: int = 0
        #: Current step of the model (total over all episodes).
        self.n_steps_longtime: int = 0

    @property
    def run_name(self) -> str:
        """Name of the current optimization run."""
        return self.config_run.name

    @property
    def results_path(self) -> pathlib.Path:
        """Path for storing results."""
        return self.config_run.results_path

    @property
    def scenarios_path(self) -> pathlib.Path | None:
        """Path for the scenario data."""
        return self.config_run.scenarios_path

    @property
    def series_results_path(self) -> pathlib.Path:
        """Path for storing results of series of runs."""
        return self.config_run.series_results_path

    @abc.abstractmethod
    def _step(self) -> tuple[float, bool, bool, dict]:
        """Abstract method to perform one internal time step.

        This private method must be implemented by subclasses to update the internal state
        dictionary and return step results. It should work with the internal state rather than
        returning observations directly.

        :return: Tuple of (reward, terminated, truncated, info)

        :meta public:
        """

    def step(self, action: np.ndarray) -> StepResult:
        """Proceed one time step and return the reward for the action provided as well as the new observation.

        This method handles the public interface for the step operation. It validates actions, executes
        actions by calling the private _step method implemented by subclasses, increments n_steps, manages
        state updates, and returns the formatted results (reward of the previous action taken, new environment
        state). It also updates the state log and calls the state modification callback.

        :param action: Actions taken by the agent.
        :return: The return value represents the state of the environment after the step was performed:

            * **observations**: A dictionary with new observation values as defined by the observation space,
              automatically extracted from the internal state.
            * **reward**: The value of the reward function. This is just one floating point value.
            * **terminated (bool)**: Whether the agent reaches the terminal state (as defined under the MDP of
              the task) which can be positive or negative. An example is reaching the goal state or moving into
              the lava from the Sutton and Barto Gridworld. If true, the Vectorizer will call :meth:`reset`.
            * **truncated (bool)**: Whether the truncation condition outside the scope of the MDP is satisfied
              (i.e. the episode ended). Typically, this is a timelimit, but could also be used to indicate an
              agent physically going out of bounds. Can be used to end the episode prematurely before a
              terminal state is reached. If true, the Vectorizer will call :meth:`reset`.
            * **info**: Provide some additional info about the state of the environment. The contents of this
              may be used for logging purposes in the future but typically do not currently serve a purpose.
        :raises RuntimeError: When the action is not contained in the action space.
        """
        # Clear state
        self._reset_state()

        # Check actions
        self._actions_valid(action)
        self.set_action(action=action)

        # Load scenario data for current timestep, if present.
        # This is the same data which has been in the prior state,
        # but cleared because of _reset_state().
        self.set_scenario_state()

        # Perform the actual step in the environment
        reward, terminated, step_truncated, info = self._step()
        self.n_steps += 1

        # Call self._truncated() after incrementing n_steps
        truncated = step_truncated or self._truncated()

        # Load scenario data from next timestep for observations, if present
        self.set_scenario_state()

        # Execute optional state modification callback function
        if self.state_modification_callback:
            self.state_modification_callback(self)

        self.state_log.append(self.state)

        # Render the environment at each step
        if self.render_mode is not None:
            self.render()

        return self.get_observations(), reward, terminated, truncated, info

    def _actions_valid(self, action: np.ndarray | dict) -> None:
        """Check whether the actions are within the specified action space.

        :param action: Actions taken by the agent.
        :raise: RuntimeError, when the actions are not inside of the action space.
        """
        if not self.action_space.contains(action):
            error_msg = self._build_action_error_message(action)
            raise RuntimeError(error_msg)

    def _build_action_error_message(self, action: np.ndarray | dict) -> str:
        """Build a detailed error message explaining why the action is invalid.

        :param action: The invalid action that was provided.
        :return: Detailed error message string.
        """
        error_parts = ["Action validation failed!"]
        error_parts.append(f"\nReceived action: {self._format_array(action)}")
        error_parts.append(f"Action space: {self.action_space}")

        # Delegate to specific validators based on space type
        if isinstance(self.action_space, spaces.Box):
            error_parts.extend(self._validate_box_action(cast("np.ndarray", action), self.action_space))
        elif isinstance(self.action_space, spaces.Discrete):
            error_parts.extend(self._validate_discrete_action(cast("np.ndarray", action), self.action_space))
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            error_parts.extend(self._validate_multi_discrete_action(cast("np.ndarray", action), self.action_space))
        elif isinstance(self.action_space, spaces.Dict):
            error_parts.extend(self._validate_dict_action(action, self.action_space))
        else:
            error_parts.append("\nThe action does not match the expected action space type.")

        return "\n".join(error_parts)

    def _add_shape_error(self, errors: list[str], expected: tuple, received: tuple) -> None:
        """Add shape mismatch error details to error list.

        :param errors: List to append error messages to.
        :param expected: Expected shape.
        :param received: Received shape.
        """
        errors.append("\nShape mismatch:")
        errors.append(f" Expected: {expected}")
        errors.append(f" Received: {received}")
        if len(expected) == 1 and len(received) == 1:
            errors.append(f" → Expected {expected[0]} action(s), but received {received[0]} action(s)")

    def _add_violations(self, errors: list[str], violations: list[str], violation_type: str = "Bound") -> None:
        """Add violation details to error list, truncated to the first ten entries.

        :param errors: List to append error messages to.
        :param violations: List of violation messages.
        :param violation_type: Type of violation (e.g., "Bound", "Value").
        """
        if violations:
            errors.append(f"\n{violation_type} violations ({len(violations)} found):")
            errors.extend(violations[:10])
            if len(violations) > 10:
                errors.append(f" ... and {len(violations) - 10} more violation(s)")

    def _validate_box_action(self, action: np.ndarray, space: spaces.Box) -> list[str]:
        """Validate Box space action and return specific error details.

        :param action: The action to validate. Non-array inputs are converted to an array first so that
            building the error message cannot fail with an unrelated AttributeError.
        :param space: The Box space to validate against.
        :return: List of error message parts.
        """
        errors: list[str] = []

        if not isinstance(action, np.ndarray):
            try:
                action = np.asarray(action, dtype=space.dtype)
            except (ValueError, TypeError):
                errors.append(
                    f"\nAction of type {type(action).__name__} cannot be converted to an array of dtype {space.dtype}."
                )
                return errors

        # Check shape
        if action.shape != space.shape:
            self._add_shape_error(errors, space.shape, action.shape)
            return errors

        # Check dtype compatibility: only dtypes that cannot be safely cast to the space dtype are a problem
        # (e.g. float64 into a float32 space), widening casts such as float32 -> float64 are fine.
        if action.dtype != space.dtype and not np.can_cast(action.dtype, space.dtype):
            errors.append("\nData type mismatch:")
            errors.append(f" Expected: {space.dtype}")
            errors.append(f" Received: {action.dtype}")

        # Check bounds
        violations = []
        action_flat = action.flatten()
        low_flat = np.broadcast_to(space.low, space.shape).flatten()
        high_flat = np.broadcast_to(space.high, space.shape).flatten()

        for idx, (val, low, high) in enumerate(zip(action_flat, low_flat, high_flat, strict=False)):
            if val < low:
                violations.append(f" - action[{idx}] = {val:.6g} is below minimum bound of {low:.6g}")
            elif val > high:
                violations.append(f" - action[{idx}] = {val:.6g} exceeds maximum bound of {high:.6g}")
            elif val != val:
                # NaN never compares equal to itself and fails both bound checks above silently.
                violations.append(f" - action[{idx}] = {val:.6g} is not a number (NaN)")

        self._add_violations(errors, violations, "Bound")
        return errors

    def _validate_discrete_action(self, action: np.ndarray | int, space: spaces.Discrete) -> list[str]:
        """Validate Discrete space action and return specific error details.

        :param action: The action to validate.
        :param space: The Discrete space to validate against.
        :return: List of error message parts.
        """
        errors: list[str] = []

        # Convert to int if it's an array
        if isinstance(action, np.ndarray):
            if action.size != 1:
                self._add_shape_error(errors, (1,), action.shape)
                return errors
            action_val = int(action.item())
        else:
            action_val = int(action)

        # Check bounds
        space_start = int(space.start)
        space_n = int(space.n)
        if action_val < space_start or action_val >= space_start + space_n:
            errors.append("\nValue out of range:")
            errors.append(f" Valid range: [{space_start}, {space_start + space_n - 1}]")
            errors.append(f" Received: {action_val}")

        return errors

    def _validate_multi_discrete_action(self, action: np.ndarray, space: spaces.MultiDiscrete) -> list[str]:
        """Validate MultiDiscrete space action and return specific error details.

        :param action: The action to validate. Non-array inputs are converted to an array first so that
            building the error message cannot fail with an unrelated AttributeError.
        :param space: The MultiDiscrete space to validate against.
        :return: List of error message parts.
        """
        errors: list[str] = []

        if not isinstance(action, np.ndarray):
            try:
                action = np.asarray(action)
            except (ValueError, TypeError):
                errors.append(f"\nAction of type {type(action).__name__} cannot be converted to an array.")
                return errors

        # Check shape
        if action.shape != space.nvec.shape:
            self._add_shape_error(errors, space.nvec.shape, action.shape)
            return errors

        # Check individual action values
        violations = []
        action_flat = action.flatten()
        nvec_flat = space.nvec.flatten()
        start_flat = np.broadcast_to(space.start, space.nvec.shape).flatten()

        for idx, (val, n, start) in enumerate(zip(action_flat, nvec_flat, start_flat, strict=False)):
            if val < start or val >= start + n:
                violations.append(f" - action[{idx}] = {val} is outside valid range [{start}, {start + n - 1}]")

        self._add_violations(errors, violations, "Value")
        return errors

    def _validate_dict_action(self, action: np.ndarray | dict, space: spaces.Dict) -> list[str]:
        """Validate Dict space action and return specific error details.

        Only the type of the action and its set of keys are inspected here.

        :param action: The action to validate.
        :param space: The Dict space to validate against.
        :return: List of error message parts.
        """
        errors: list[str] = []
        errors.append("\nDict action space validation failed.")
        errors.append(f" Expected a dictionary with keys: {list(space.spaces.keys())}")

        if not isinstance(action, dict):
            errors.append(f" Received type: {type(action).__name__}")
            errors.append(" → Dict action spaces require actions to be dictionaries, not arrays")
        else:
            missing_keys = set(space.spaces.keys()) - set(action.keys())
            extra_keys = set(action.keys()) - set(space.spaces.keys())
            if missing_keys:
                errors.append(f" Missing keys: {list(missing_keys)}")
            if extra_keys:
                errors.append(f" Unexpected keys: {list(extra_keys)}")

        return errors

    def _format_array(self, arr: np.ndarray | dict, max_items: int = 10) -> str:
        """Format a numpy array or dict for display in error messages.

        :param arr: Array or dict to format. Any other object (plain int, list, ...) is rendered with
            ``str`` so that formatting can never hide the actual validation error.
        :param max_items: Maximum number of items to display before truncating.
        :return: Formatted string representation.
        """
        # Dict actions (for Dict action spaces) and any other non-array input are shown verbatim.
        if not isinstance(arr, np.ndarray):
            return str(arr)

        if arr.size <= max_items:
            return str(arr)

        # For large arrays, show first few and last few elements
        half = max_items // 2
        arr_flat = arr.flatten()
        first_items = arr_flat[:half]
        # Slice from an explicit start index: a negative slice of -0 would select the whole array.
        last_items = arr_flat[arr_flat.size - half :]
        formatted = f"[{' '.join(f'{x:.6g}' for x in first_items)} ... {' '.join(f'{x:.6g}' for x in last_items)}]"
        return f"{formatted} (shape: {arr.shape}, dtype: {arr.dtype})"

    def _reset_state(self) -> None:
        """Clear self.state and initialize with additional_state."""
        self.state = {}
        if self.additional_state is not None:
            # Every additional value is wrapped in a one-element array, like the other state entries.
            additional_state = {name: np.array([value]) for name, value in self.additional_state.items()}
            self.state.update(additional_state)

    def _truncated(self) -> bool:
        """Check if the episode is over using the number of steps (n_steps) and the total number of
        steps in an episode (n_episode_steps).

        :return: boolean showing, whether the episode is over (truncated by its time limit).
        """
        return self.n_steps >= self.n_episode_steps

    @abc.abstractmethod
    def _reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Abstract reset method that must be implemented by subclasses.

        :param seed: Seed value (optional). Note that the public :meth:`reset` only forwards ``options``.
        :param options: Additional information to specify how the environment is reset (optional).
        :return: Info dictionary which :meth:`reset` returns next to the observations.

        :meta public:
        """

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObservationType, dict[str, Any]]:
        """Reset the environment to an initial internal state, returning an initial observation and info.

        This method generates a new starting state often with some randomness to ensure that the agent explores
        the state space and learns a generalised policy about the environment. This randomness can be controlled
        with the ``seed`` parameter.

        :param seed: The seed for initializing any randomized components of the state. It is handed to
            ``Env.reset`` and used to (re-)create the random generator of the scenario manager.
        :param options: Additional information to specify how the environment is reset (optional,
            depending on the specific environment) (default: None)
        :return: Tuple of observation and info. The observation of the initial state will be an element of
            :attr:`observation_space` (typically a numpy array) and is analogous to the observation returned
            by :meth:`step`. Info is a dictionary containing auxiliary information complementing
            ``observation``. It should be analogous to the ``info`` returned by :meth:`step`.
        """
        # Reset state_log and counters
        if self.n_steps > 0:
            self._reset_episode()

        # Clear state
        self._reset_state()

        # Set rng seed (only has a value on first reset of first episode)
        Env.reset(self, seed=seed)

        # Create separate rng for the scenario manager to ensure deterministic values
        # The value for self.np_random_seed is set by Env.reset()
        if seed is not None or not hasattr(self, "_scenario_rng"):
            self._scenario_rng = np.random.default_rng(self.np_random_seed)

        # Update with scenario data, if present
        self.set_scenario_state(reset=True)

        # Set initial observations in child class (only the options are forwarded, not the seed)
        info = self._reset(options=options)

        # Execute optional state modification callback function
        if self.state_modification_callback:
            self.state_modification_callback(self)

        # Log state
        self.state_log.append(self.state)

        # Render the environment
        if self.render_mode is not None:
            self.render()

        return self.get_observations(), info

    def _reduce_state_log(self) -> list[dict[str, np.ndarray]]:
        """Remove unwanted parameters from state_log before storing in state_log_longtime.

        :return: The return value is a list of dictionaries, where the parameters that should not be stored
            were removed
        """
        allowed_keys = set(self.state_config.add_to_state_log)
        return [{k: v for k, v in entry.items() if k in allowed_keys} for entry in self.state_log]

    def _reset_episode(self) -> None:
        """Store episode statistics and reset episode counters."""
        if self.callback is not None:
            self.callback(self)

        # Store some logging data
        self.n_episodes += 1

        # store reduced_state_log in state_log_longtime
        self.state_log_longtime.append(self._reduce_state_log())
        self.n_steps_longtime += self.n_steps

        # Reset episode variables
        self.n_steps = 0
        self.episode_timer = time.time()
        self.state_log = []

    @abc.abstractmethod
    def close(self) -> None:
        """Close the environment. This should always be called when an entire run is finished. It should be
        used to close any resources (i.e. simulation models) used by the environment.
        """
        msg = "Cannot close an abstract Environment."
        raise NotImplementedError(msg)

    @abc.abstractmethod
    def render(self) -> None:
        """Render the environment.

        The set of supported modes varies per environment. Some environments do not support rendering at
        all. By convention in Farama *gymnasium*, if mode is:

            * human: render to the current display or terminal and return nothing. Usually for human
              consumption.
            * rgb_array: Return a numpy.ndarray with shape (x, y, 3), representing RGB values for an x-by-y
              pixel image, suitable for turning into a video.
            * ansi: Return a string (str) or StringIO.StringIO containing a terminal-style text
              representation. The text can include newlines and ANSI escape sequences (e.g. for colors).
        """
        msg = "Cannot render an abstract Environment."
        raise NotImplementedError(msg)

    @classmethod
    def get_info(cls) -> tuple[str, str]:
        """Get info about environment.

        :return: Tuple of version and description.
        """
        return cls.version, cls.description  # type: ignore[return-value]

    def __str__(self) -> str:
        """Human-readable string representation of BaseEnv."""
        env_class = self.__class__.__name__
        n_actions = len(self.state_config.actions)
        n_observations = len(self.state_config.observations)
        status = f"Episode {self.n_episodes}, Step {self.n_steps}/{self.n_episode_steps}"
        return f"{env_class}(id={self.env_id}, {n_actions} actions, {n_observations} observations, {status})"

    def __repr__(self) -> str:
        """Developer-friendly string representation of BaseEnv."""
        env_class = self.__class__.__name__
        return (
            f"{env_class}(env_id={self.env_id}, run_name='{self.run_name}', "
            f"n_episodes={self.n_episodes}, n_steps={self.n_steps}, "
            f"episode_duration={self.episode_duration}, sampling_time={self.sampling_time})"
        )

    def export_state_log(
        self,
        path: Path,
        names: Sequence[str] | None = None,
        *,
        sep: str = ";",
        decimal: str = ".",
    ) -> None:
        """Extension of csv_export to include timeseries on the data.

        The time index starts at the episode timer and advances by
        ``sampling_time / sim_steps_per_sample`` seconds for every entry of the state log.

        :param path: Target path, forwarded to csv_export.
        :param names: Field names used when data is a Matrix without column names.
        :param sep: Separator to use between the fields.
        :param decimal: Sign to use for decimal points.
        """
        start_time = datetime.fromtimestamp(self.episode_timer)
        step = self.sampling_time / self.sim_steps_per_sample
        timerange = [start_time + timedelta(seconds=(k * step)) for k in range(len(self.state_log))]
        csv_export(path=path, data=self.state_log, index=timerange, names=names, sep=sep, decimal=decimal)

    def get_observations(self) -> dict[str, np.ndarray]:
        """Gather observations from the state.

        :raises KeyError: Observation is not available in state
        :return: Filtered observations as a dictionary.
        :rtype: dict[str, np.ndarray]
        """
        observations = {}
        for name in self.state_config.observations:
            try:
                observations[name] = self.state[name]
            except KeyError as e:
                msg = f"Observation {e!s} is unavailable in environment state."
                raise KeyError(msg) from e
        return observations

    def get_external_inputs(self) -> dict[str, int | float | bool | str]:
        """Gather external inputs from the state.

        Uses scalar values instead of numpy arrays for values.

        :raises KeyError: External input is not available in state
        :raises ValueError: External input value is not scalar
        :return: Filtered external inputs with external id as keys.
        :rtype: dict[str, int | float | bool | str]
        """
        external_inputs: dict[str, int | float | bool | str] = {}
        for name in self.state_config.ext_inputs:
            ext_id = self.state_config.map_ext_ids[name]
            state_var = self.state_config.vars[name]
            try:
                scaled_value = self.state[name].item()
            except KeyError as e:
                msg = f"{e!s} is unavailable in environment state."
                raise KeyError(msg) from e
            except ValueError as e:
                msg = "External Inputs can't have multiple values"
                raise ValueError(msg) from e

            # Check for boolean FIRST (since bool is a subclass of int in Python)
            if isinstance(scaled_value, (bool, np.bool_)):
                # Preserve non-numeric values as-is, cast np.bool_ to Python bool
                external_inputs[ext_id] = bool(scaled_value)
            elif isinstance(scaled_value, (int, float, np.integer, np.floating)):
                # Only scale numeric values (int and float), cast to Python float.
                # This is the inverse of the scaling applied in set_external_outputs.
                external_inputs[ext_id] = float(scaled_value / state_var.ext_scale_mult - state_var.ext_scale_add)
            else:
                # Preserve other non-numeric values as-is (string, etc.)
                external_inputs[ext_id] = scaled_value
        return external_inputs

    def set_action(self, action: np.ndarray | dict[str, np.ndarray]) -> None:
        """Set action values in the state.

        :param action: Actions to be set. Arrays are matched position by position with the configured
            action names, dictionaries are used with their own keys.
        :type action: np.ndarray | dict[str, np.ndarray]
        :raises ValueError: When an array action has a different length than the configured actions.
        """
        iterator: Iterator
        if isinstance(action, np.ndarray):
            iterator = zip(self.state_config.actions, action, strict=True)
        else:
            iterator = iter(action.items())

        for name, value in iterator:
            # State entries are stored as arrays, scalars are wrapped in a one-element array.
            val = value if isinstance(value, np.ndarray) else np.array([value])
            self.state[name] = val

    def set_external_outputs(self, external_outputs: Mapping[str, int | float | bool | str]) -> None:
        """Set external outputs in the state.

        Accepts scalars instead of numpy arrays as values.

        :param external_outputs: Dict of external outputs with external_ids as keys.
        :type external_outputs: Mapping[str, int | float | bool | str]
        :raises KeyError: No value was provided for one of the configured external outputs.
        """
        for name in self.state_config.ext_outputs:
            state_var = self.state_config.vars[name]
            try:
                unscaled_value = external_outputs[state_var.ext_id]  # type: ignore[index]
            except KeyError as e:
                msg = f"Missing value for external output: {name}"
                raise KeyError(msg) from e

            # Check for boolean FIRST (since bool is a subclass of int in Python)
            if isinstance(unscaled_value, (bool, np.bool_)):
                # Preserve boolean with explicit dtype to avoid conversion to float
                self.state[name] = np.array([unscaled_value], dtype=bool)
            elif isinstance(unscaled_value, (int, float, np.integer, np.floating)):
                # Only scale numeric values (int and float)
                scaled_value = (unscaled_value + state_var.ext_scale_add) * state_var.ext_scale_mult
                self.state[name] = np.array([scaled_value])
            else:
                # Preserve other non-numeric values as-is (string, etc.)
                self.state[name] = np.array([unscaled_value])

    def set_scenario_state(self, reset: bool = False) -> None:
        """Set scenario output values for the current timestep in the state.

        Does nothing when no scenario manager is configured.

        :param reset: Indicator whether this was called from the reset method
        """
        if self.scenario_manager is None:
            return

        # Compute new offset after reset
        if reset:
            self._scenario_offset = self.scenario_manager.compute_episode_offset(self._scenario_rng)

        for state_name in self.state_config.scenario_outputs:
            state_var = self.state_config.vars[state_name]
            unscaled_data = self.scenario_manager.get_scenario_state_var(
                n_step=self.n_steps + self._scenario_offset, state_var=state_var
            )
            scaled_data = (unscaled_data + state_var.ext_scale_add) * state_var.ext_scale_mult
            self.state[state_name] = scaled_data