from __future__ import annotations
import abc
import inspect
import pathlib
import time
from datetime import datetime, timedelta
from logging import getLogger
from typing import TYPE_CHECKING, cast
import numpy as np
from gymnasium import Env, spaces
from eta_ctrl.util import csv_export
from eta_ctrl.util.utils import timestep_to_seconds
if TYPE_CHECKING:
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import Any
from eta_ctrl.config import ConfigRun
from eta_ctrl.envs.state import StateConfig
from eta_ctrl.timeseries.scenario_manager import ScenarioManager
from eta_ctrl.util.type_annotations import ObservationType, Path, StepResult, TimeStep
log = getLogger(__name__)
[docs]
class BaseEnv(Env, abc.ABC):
"""Abstract environment definition, providing some basic functionality for concrete environments to use.
The class implements and adapts functions from gymnasium.Env. It provides additional functionality as required by
the ETA Ctrl framework and should be used as the starting point for new environments.
The initialization of this superclass performs many of the necessary tasks, required to specify a concrete
environment. Read the documentation carefully to understand, how new environments can be developed, building on
this starting point.
There are some class attributes that must be set and some methods that must be implemented to satisfy the interface.
This is required to create concrete environments.
The required class attributes are:
- **version**: Version number of the environment.
- **description**: Short description string of the environment.
The gymnasium interface requires the following methods for the environment to work correctly within the framework.
Consult the documentation of each method for more detail.
- **step()**
- **reset()**
- **close()**
- **render()**
.. note::
Subclasses should implement the private _step and _reset methods rather than
overriding the public step and reset methods. The public methods handle the
Gymnasium interface and state management automatically.
:param env_id: Identification for the environment, useful when creating multiple environments.
:param config_run: Configuration of the optimization run.
:param state_config: State configuration defining the state variables. The action and observation spaces
are derived from it.
:param verbose: Verbosity to use for logging.
:param callback: Callback that should be called after each episode.
:param state_modification_callback: Callback that should be called after state setup, before logging the state.
:param seed: Optional seed for the action and observation spaces.
:param episode_duration: Duration of the episode in seconds.
:param sampling_time: Duration of a single time sample / time step in seconds.
:param sim_steps_per_sample: Number of simulation steps to be taken for each sample.
:param scenario_manager: Optional manager used to load scenario data into the state.
:param render_mode: Render mode used to help visualise what the agent sees. Example
modes are "human", "rgb_array" and "ansi" for text.
:param path_env: Explicit path to the environment directory. If not provided, the path will be
automatically detected from the call stack. If detection fails, falls back to current working directory.
:param kwargs: Other keyword arguments (for subclasses).
"""
@property
@abc.abstractmethod
def version(self) -> str:
    """Version identifier of the concrete environment.

    Abstract: every subclass has to provide it, usually as a plain class attribute.
    """
    raise NotImplementedError
@property
@abc.abstractmethod
def description(self) -> str:
    """Description text of the concrete environment.

    Abstract: every subclass has to provide it, usually as a plain class attribute.
    """
    raise NotImplementedError
def __init__(
    self,
    env_id: int,
    config_run: ConfigRun,
    state_config: StateConfig,
    verbose: int = 2,
    callback: Callable | None = None,
    state_modification_callback: Callable | None = None,
    seed: int | None = None,
    *,
    episode_duration: TimeStep | str,
    sampling_time: TimeStep | str,
    sim_steps_per_sample: int | str = 1,
    scenario_manager: ScenarioManager | None = None,
    render_mode: str | None = None,
    path_env: Path | None = None,
    **kwargs: Any,
) -> None:
    """Store the configuration, derive the spaces from the state config and initialize the counters.

    ``episode_duration`` and ``sampling_time`` are converted to seconds with ``timestep_to_seconds``;
    the number of steps per episode is the number of whole sampling intervals that fit into the episode.
    If ``seed`` is given, it is used to seed the action and observation spaces.
    ``kwargs`` is accepted for subclasses and not used here.
    """
    super().__init__()
    #: Verbosity level used for logging.
    self.verbose: int = verbose
    log.setLevel(int(verbose * 10))
    # Set some standard path settings
    #: Information about the optimization run and information about the paths.
    #: For example, it defines results_path and scenarios_path.
    self.config_run: ConfigRun = config_run
    #: Callback can be used for logging and plotting.
    self.callback: Callable | None = callback
    #: Callback can be used for modifying the state at each time step.
    self.state_modification_callback: Callable | None = state_modification_callback
    #: ID of the environment (useful for vectorized environments).
    self.env_id: int = int(env_id)
    #: Render mode for rendering the environment
    self.render_mode: str | None = render_mode
    #: Duration of one episode in seconds.
    self.episode_duration: float = timestep_to_seconds(episode_duration)
    #: Sampling time (interval between optimization time steps) in seconds.
    self.sampling_time: float = timestep_to_seconds(sampling_time)
    #: Number of time steps (of width sampling_time) in each episode.
    # Float floor division undercounts exactly divisible values (1.0 // 0.1 == 9.0), so the
    # quotient is rounded to 9 decimals first and only then truncated to whole steps.
    self.n_episode_steps: int = int(round(self.episode_duration / self.sampling_time, 9))
    #: Number of simulation steps to be taken for each sample. This must be a divisor of 'sampling_time'.
    self.sim_steps_per_sample: int = int(sim_steps_per_sample)
    #: State Configuration for defining State Variables.
    self.state_config: StateConfig = state_config
    self.action_space, self.observation_space = self.state_config.continuous_spaces()
    #: Manager to load scenario data into the state
    self.scenario_manager = scenario_manager
    #: Explicit path override for environment directory (used if auto-detection fails)
    self._path_env_override: pathlib.Path | None = pathlib.Path(path_env) if path_env is not None else None
    if seed is not None:
        # Initialize random generator
        self.action_space.seed(seed=seed)
        self.observation_space.seed(seed=seed)
    # Must run last: it reads self._path_env_override set above.
    self._init_attributes()
def _init_attributes(self) -> None:
    """Set up the path of the environment plus all logging and counter attributes.

    The environment path is the explicit override if one was given, otherwise the directory of the
    file defining the concrete class. If that lookup fails, the current working directory is used
    and a warning is logged.
    """
    env_dir: pathlib.Path | None = self._path_env_override
    if env_dir is None:
        try:
            env_dir = pathlib.Path(inspect.getfile(type(self))).resolve().parent
        except (TypeError, OSError):
            # Handled right below by the working-directory fallback.
            pass
    if env_dir is None:
        env_dir = pathlib.Path.cwd()
        log.warning(
            f"Could not automatically detect environment path for {type(self).__name__}. "
            f"Falling back to current working directory: {env_dir}. "
            "Consider passing 'path_env' explicitly to the constructor if this is incorrect."
        )
    #: Path of the environment file.
    self.path_env: pathlib.Path = env_dir

    # Episode bookkeeping
    #: Number of completed episodes.
    self.n_episodes: int = 0
    #: Current step of the model (number of completed steps) in the current episode.
    self.n_steps: int = 0
    #: Current step of the model (total over all episodes).
    self.n_steps_longtime: int = 0
    #: Episode timer (stores the start time of the episode).
    self.episode_timer: float = time.time()

    # State and state logs
    #: Current state of the environment (declared only, no value is assigned here).
    self.state: dict[str, np.ndarray]
    #: Additional state information to append to the state during stepping and reset
    self.additional_state: dict[str, float] | None = None
    #: Log of the environment state.
    self.state_log: list[dict[str, np.ndarray]] = []
    #: Log of the environment state over multiple episodes.
    self.state_log_longtime: list[list[dict[str, np.ndarray]]] = []
@property
def run_name(self) -> str:
    """Name of the current optimization run, read from the run configuration."""
    config = self.config_run
    return config.name
@property
def results_path(self) -> pathlib.Path:
    """Path for storing results, read from the run configuration."""
    config = self.config_run
    return config.results_path
@property
def scenarios_path(self) -> pathlib.Path | None:
#: Path for the scenario data.
return self.config_run.scenarios_path
@property
def series_results_path(self) -> pathlib.Path:
    """Path for storing results of series of runs, read from the run configuration."""
    config = self.config_run
    return config.series_results_path
[docs]
@abc.abstractmethod
def _step(self) -> tuple[float, bool, bool, dict]:
    """Perform one internal time step (to be implemented by subclasses).

    Implementations update the internal state dictionary instead of returning observations;
    the public :meth:`step` extracts the observations from that state afterwards.

    :return: Tuple of (reward, terminated, truncated, info)

    :meta public:
    """
[docs]
def step(self, action: np.ndarray) -> StepResult:
    """Advance the environment by one time step and report the outcome.

    This is the public entry point for stepping. In order, it clears the state, validates the
    action and writes it into the state, loads the scenario data of the current step, delegates
    to the private :meth:`_step` of the subclass, increments ``n_steps``, evaluates truncation,
    loads the scenario data of the next step, runs the optional state modification callback,
    appends the state to the state log and renders if a render mode is set.

    :param action: Actions taken by the agent.
    :return: The state of the environment after the step:

        * **observations**: A dictionary with new observation values as defined by the
          observation space, automatically extracted from the internal state.
        * **reward**: The value of the reward function. This is just one floating point value.
        * **terminated (bool)**: Whether the agent reaches the terminal state (as defined under the MDP
          of the task) which can be positive or negative. An example is reaching the goal state or
          moving into the lava from the Sutton and Barto Gridworld. If true, the Vectorizer will
          call :meth:`reset`.
        * **truncated (bool)**: Whether the truncation condition outside the scope of the MDP is
          satisfied (i.e. the episode ended). Typically, this is a timelimit, but could also be used
          to indicate an agent physically going out of bounds. Can be used to end the episode
          prematurely before a terminal state is reached. If true, the Vectorizer will
          call :meth:`reset`.
        * **info**: Additional info about the state of the environment. The contents may be used
          for logging purposes in the future but typically do not currently serve a purpose.
    """
    # Start from a fresh state, then validate and store the action.
    self._reset_state()
    self._actions_valid(action)
    self.set_action(action=action)

    # Scenario data of the current step: identical to what the previous state held,
    # but it has to be loaded again because _reset_state() wiped it.
    self.set_scenario_state()

    # Let the subclass perform the actual step.
    reward, terminated, step_truncated, info = self._step()
    self.n_steps += 1

    # The time-limit check must see the incremented n_steps.
    truncated = step_truncated or self._truncated()

    # Scenario data of the following step, used for the observations.
    self.set_scenario_state()

    # Optional hook to modify the state before it is logged.
    modify_state = self.state_modification_callback
    if modify_state:
        modify_state(self)
    self.state_log.append(self.state)

    # Render after every step if a render mode is configured.
    if self.render_mode is not None:
        self.render()

    observations = self.get_observations()
    return observations, reward, terminated, truncated, info
def _actions_valid(self, action: np.ndarray | dict) -> None:
"""Check whether the actions are within the specified action space.
:param action: Actions taken by the agent.
:raise: RuntimeError, when the actions are not inside of the action space.
"""
if not self.action_space.contains(action):
error_msg = self._build_action_error_message(action)
raise RuntimeError(error_msg)
def _build_action_error_message(self, action: np.ndarray | dict) -> str:
    """Assemble a detailed explanation of why an action was rejected.

    The message starts with the received action and the action space, followed by the details
    produced by the validator matching the type of the action space.

    :param action: The invalid action that was provided.
    :return: Detailed error message string.
    """
    space = self.action_space
    lines = [
        "Action validation failed!",
        f"\nReceived action: {self._format_array(action)}",
        f"Action space: {space}",
    ]
    # Space type -> validator producing the detail lines for that type.
    validators = (
        (spaces.Box, self._validate_box_action),
        (spaces.Discrete, self._validate_discrete_action),
        (spaces.MultiDiscrete, self._validate_multi_discrete_action),
        (spaces.Dict, self._validate_dict_action),
    )
    for space_type, validate in validators:
        if isinstance(space, space_type):
            lines.extend(validate(cast("Any", action), cast("Any", space)))
            break
    else:
        lines.append("\nThe action does not match the expected action space type.")
    return "\n".join(lines)
def _add_shape_error(self, errors: list[str], expected: tuple, received: tuple) -> None:
    """Append the description of a shape mismatch to the error list.

    :param errors: List to append error messages to.
    :param expected: Expected shape.
    :param received: Received shape.
    """
    errors.extend(
        [
            "\nShape mismatch:",
            f" Expected: {expected}",
            f" Received: {received}",
        ]
    )
    # For one-dimensional shapes the mismatch can be phrased as a number of actions.
    both_one_dimensional = len(expected) == 1 and len(received) == 1
    if both_one_dimensional:
        errors.append(f" → Expected {expected[0]} action(s), but received {received[0]} action(s)")
def _add_violations(self, errors: list[str], violations: list[str], violation_type: str = "Bound") -> None:
    """Append violation details to the error list, truncated after ten entries.

    Nothing is appended when there are no violations.

    :param errors: List to append error messages to.
    :param violations: List of violation messages.
    :param violation_type: Type of violation (e.g., "Bound", "Value").
    """
    if not violations:
        return
    total = len(violations)
    errors.append(f"\n{violation_type} violations ({total} found):")
    errors.extend(violations[:10])
    hidden = total - 10
    if hidden > 0:
        errors.append(f" ... and {hidden} more violation(s)")
def _validate_box_action(self, action: np.ndarray, space: spaces.Box) -> list[str]:
"""Validate Box space action and return specific error details.
:param action: The action to validate.
:param space: The Box space to validate against.
:return: List of error message parts.
"""
errors: list[str] = []
# Check shape
if action.shape != space.shape:
self._add_shape_error(errors, space.shape, action.shape)
return errors
# Check dtype compatibility (warn but don't fail for float32 vs float64)
if action.dtype != space.dtype and not (
np.issubdtype(action.dtype, np.floating) and np.issubdtype(space.dtype, np.floating)
):
errors.append("\nData type mismatch:")
errors.append(f" Expected: {space.dtype}")
errors.append(f" Received: {action.dtype}")
# Check bounds
violations = []
action_flat = action.flatten()
low_flat = np.broadcast_to(space.low, space.shape).flatten()
high_flat = np.broadcast_to(space.high, space.shape).flatten()
for idx, (val, low, high) in enumerate(zip(action_flat, low_flat, high_flat, strict=False)):
if val < low:
violations.append(f" - action[{idx}] = {val:.6g} is below minimum bound of {low:.6g}")
elif val > high:
violations.append(f" - action[{idx}] = {val:.6g} exceeds maximum bound of {high:.6g}")
self._add_violations(errors, violations, "Bound")
return errors
def _validate_discrete_action(self, action: np.ndarray | int, space: spaces.Discrete) -> list[str]:
    """Check a Discrete space action and describe what is wrong with it.

    Arrays must hold exactly one element; the value is converted to ``int`` and compared
    against the range ``[space.start, space.start + space.n - 1]``.

    :param action: The action to validate.
    :param space: The Discrete space to validate against.
    :return: List of error message parts.
    """
    problems: list[str] = []
    is_array = isinstance(action, np.ndarray)
    # An array with more (or fewer) than one element cannot be a single discrete action.
    if is_array and action.size != 1:
        self._add_shape_error(problems, (1,), action.shape)
        return problems
    value = int(action.item()) if is_array else int(action)
    lowest = int(space.start)
    highest = lowest + int(space.n) - 1
    if not lowest <= value <= highest:
        problems.extend(
            [
                "\nValue out of range:",
                f" Valid range: [{lowest}, {highest}]",
                f" Received: {value}",
            ]
        )
    return problems
def _validate_multi_discrete_action(self, action: np.ndarray, space: spaces.MultiDiscrete) -> list[str]:
"""Validate MultiDiscrete space action and return specific error details.
:param action: The action to validate.
:param space: The MultiDiscrete space to validate against.
:return: List of error message parts.
"""
errors: list[str] = []
# Check shape
if action.shape != space.nvec.shape:
self._add_shape_error(errors, space.nvec.shape, action.shape)
return errors
# Check individual action values
violations = []
action_flat = action.flatten()
nvec_flat = space.nvec.flatten()
start_flat = np.broadcast_to(space.start, space.nvec.shape).flatten()
for idx, (val, n, start) in enumerate(zip(action_flat, nvec_flat, start_flat, strict=False)):
if val < start or val >= start + n:
violations.append(f" - action[{idx}] = {val} is outside valid range [{start}, {start + n - 1}]")
self._add_violations(errors, violations, "Value")
return errors
def _validate_dict_action(self, action: np.ndarray | dict, space: spaces.Dict) -> list[str]:
    """Validate Dict space action and return specific error details.

    Reports the expected keys and either the wrong type of the action or the keys that are
    missing / unexpected. Missing keys are listed in the order of the space, unexpected keys in
    the order of the action, so the message is identical between runs.

    :param action: The action to validate.
    :param space: The Dict space to validate against.
    :return: List of error message parts.
    """
    expected = space.spaces
    errors: list[str] = [
        "\nDict action space validation failed.",
        f" Expected a dictionary with keys: {list(expected.keys())}",
    ]
    if not isinstance(action, dict):
        errors.append(f" Received type: {type(action).__name__}")
        errors.append(" → Dict action spaces require actions to be dictionaries, not arrays")
        return errors
    # Ordered lists instead of set differences: set iteration order is not stable for strings.
    missing_keys = [key for key in expected if key not in action]
    extra_keys = [key for key in action if key not in expected]
    if missing_keys:
        errors.append(f" Missing keys: {missing_keys}")
    if extra_keys:
        errors.append(f" Unexpected keys: {extra_keys}")
    return errors
def _format_array(self, arr: np.ndarray | dict, max_items: int = 10) -> str:
"""Format a numpy array or dict for display in error messages.
:param arr: Array or dict to format.
:param max_items: Maximum number of items to display before truncating.
:return: Formatted string representation.
"""
# Handle dict actions (for Dict action spaces)
if isinstance(arr, dict):
return str(arr)
# Handle numpy arrays - cast needed for mypy type narrowing
arr = cast("np.ndarray", arr)
if arr.size <= max_items:
return str(arr)
# For large arrays, show first few and last few elements
arr_flat = arr.flatten()
first_items = arr_flat[: max_items // 2]
last_items = arr_flat[-(max_items // 2) :]
formatted = f"[{' '.join(f'{x:.6g}' for x in first_items)} ... {' '.join(f'{x:.6g}' for x in last_items)}]"
return f"{formatted} (shape: {arr.shape}, dtype: {arr.dtype})"
def _reset_state(self) -> None:
    """Replace self.state with a new dict that only holds the additional_state values.

    Every additional state value is wrapped into a one-element numpy array.
    """
    fresh_state: dict[str, np.ndarray] = {}
    self.state = fresh_state
    extras = self.additional_state
    if extras is None:
        return
    for name, value in extras.items():
        fresh_state[name] = np.array([value])
def _truncated(self) -> bool:
    """Tell whether the episode has reached its time limit.

    Compares the number of steps taken in the current episode (n_steps) with the total number
    of steps in an episode (n_episode_steps).

    :return: boolean showing, whether the episode is over (truncated by its time limit).
    """
    steps_done, step_limit = self.n_steps, self.n_episode_steps
    return steps_done >= step_limit
[docs]
@abc.abstractmethod
def _reset(
    self,
    *,
    seed: int | None = None,
    options: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Set the initial state of a new episode (to be implemented by subclasses).

    Called by the public :meth:`reset`, which uses the returned dictionary as ``info``.

    :param seed: Seed for randomized components.
    :param options: Additional information to specify how the environment is reset.
    :return: Info dictionary.

    :meta public:
    """
[docs]
def reset(
    self,
    *,
    seed: int | None = None,
    options: dict[str, Any] | None = None,
) -> tuple[ObservationType, dict[str, Any]]:
    """Reset the environment to an initial internal state, returning an initial observation and info.

    A new starting state is generated, often with some randomness so that the agent explores the
    state space and learns a generalised policy about the environment. The randomness can be
    controlled with the ``seed`` parameter.

    :param seed: The seed for initializing any randomized components of the state.
        Subclasses should use this for reproducible randomness in their state init
    :param options: Additional information to specify how the environment is reset (optional,
        depending on the specific environment) (default: None)
    :return: Tuple of observation and info. The observation of the initial state will be an element of
        :attr:`observation_space` (typically a numpy array) and is analogous to the observation returned by
        :meth:`step`. Info is a dictionary containing auxiliary information complementing ``observation``. It
        should be analogous to the ``info`` returned by :meth:`step`.
    """
    # Archive the finished episode (state_log and counters) if any step was taken.
    # NOTE(review): when reset() is called again without a step in between, state_log keeps
    # the earlier initial state - confirm whether that is intended.
    if self.n_steps > 0:
        self._reset_episode()

    # Start from an empty state.
    self._reset_state()

    # Seed the environment rng (the seed only has a value on first reset of first episode).
    Env.reset(self, seed=seed)

    # The scenario manager gets its own rng to ensure deterministic values;
    # self.np_random_seed has been set by Env.reset() above.
    has_scenario_rng = hasattr(self, "_scenario_rng")
    if not (seed is None and has_scenario_rng):
        self._scenario_rng = np.random.default_rng(self.np_random_seed)

    # Load scenario data into the state, if present.
    self.set_scenario_state(reset=True)

    # The child class sets the initial observations.
    # NOTE(review): ``seed`` is not forwarded to _reset() - confirm that this is intended.
    info = self._reset(options=options)

    # Optional hook to modify the state before it is logged.
    modify_state = self.state_modification_callback
    if modify_state:
        modify_state(self)
    self.state_log.append(self.state)

    # Render the initial state if a render mode is configured.
    if self.render_mode is not None:
        self.render()

    observations = self.get_observations()
    return observations, info
def _reduce_state_log(self) -> list[dict[str, np.ndarray]]:
    """Build a copy of state_log that only keeps the variables meant for long-term storage.

    Only the names listed in ``state_config.add_to_state_log`` are kept; the result is what
    gets stored in state_log_longtime.

    :return: List of dictionaries, where the parameters that should not be stored were removed
    """
    keep = set(self.state_config.add_to_state_log)
    reduced: list[dict[str, np.ndarray]] = []
    for entry in self.state_log:
        reduced.append({name: entry[name] for name in entry if name in keep})
    return reduced
def _reset_episode(self) -> None:
    """Archive the statistics of the finished episode and reset the episode counters.

    The episode callback (if any) runs first, before any counter is changed. Then the reduced
    state log is moved to state_log_longtime and the per-episode values are reset.
    """
    on_episode_end = self.callback
    if on_episode_end is not None:
        on_episode_end(self)

    # Long-term bookkeeping
    self.n_episodes += 1
    reduced_log = self._reduce_state_log()
    self.state_log_longtime.append(reduced_log)
    finished_steps = self.n_steps
    self.n_steps_longtime += finished_steps

    # Per-episode values start over
    self.n_steps = 0
    self.state_log = []
    self.episode_timer = time.time()
[docs]
@abc.abstractmethod
def close(self) -> None:
    """Close the environment.

    This should always be called when an entire run is finished. Implementations should use it
    to close any resources (i.e. simulation models) used by the environment.

    :raises NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError("Cannot close an abstract Environment.")
[docs]
@abc.abstractmethod
def render(self) -> None:
    """Render the environment.

    Which modes are supported depends on the environment; some environments do not support
    rendering at all. By convention in Farama *gymnasium*, if mode is:

    * human: render to the current display or terminal and return nothing. Usually for human consumption.
    * rgb_array: Return a numpy.ndarray with shape (x, y, 3), representing RGB values for an x-by-y pixel
      image, suitable for turning into a video.
    * ansi: Return a string (str) or StringIO.StringIO containing a terminal-style text representation.
      The text can include newlines and ANSI escape sequences (e.g. for colors).

    :raises NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError("Cannot render an abstract Environment.")
[docs]
@classmethod
def get_info(cls) -> tuple[str, str]:
    """Return the ``version`` and ``description`` attributes of the environment class.

    :return: Tuple of version and description.
    """
    version, description = cls.version, cls.description
    return version, description  # type: ignore[return-value]
def __str__(self) -> str:
    """Short, human-readable summary: class, id, number of actions/observations and progress."""
    config = self.state_config
    progress = f"Episode {self.n_episodes}, Step {self.n_steps}/{self.n_episode_steps}"
    parts = (
        f"id={self.env_id}",
        f"{len(config.actions)} actions",
        f"{len(config.observations)} observations",
        progress,
    )
    return f"{self.__class__.__name__}({', '.join(parts)})"
def __repr__(self) -> str:
    """Detailed representation for developers: id, run name, counters and timing settings."""
    fields = ", ".join(
        (
            f"env_id={self.env_id}",
            f"run_name='{self.run_name}'",
            f"n_episodes={self.n_episodes}",
            f"n_steps={self.n_steps}",
            f"episode_duration={self.episode_duration}",
            f"sampling_time={self.sampling_time}",
        )
    )
    return f"{self.__class__.__name__}({fields})"
[docs]
def export_state_log(
    self,
    path: Path,
    names: Sequence[str] | None = None,
    *,
    sep: str = ";",
    decimal: str = ".",
) -> None:
    """Export the state log of the current episode through csv_export, indexed by timestamps.

    The index starts at the episode start time (episode_timer) and advances by
    ``sampling_time / sim_steps_per_sample`` seconds for every entry of the state log.

    :param path: Passed on to csv_export as ``path``.
    :param names: Field names used when data is a Matrix without column names.
    :param sep: Separator to use between the fields.
    :param decimal: Sign to use for decimal points.
    """
    episode_start = datetime.fromtimestamp(self.episode_timer)
    # NOTE(review): the base class appends one state_log entry per step()/reset() call, yet the
    # spacing used here is sampling_time / sim_steps_per_sample - confirm this is intended.
    seconds_per_entry = self.sampling_time / self.sim_steps_per_sample
    timestamps = []
    for k in range(len(self.state_log)):
        timestamps.append(episode_start + timedelta(seconds=(k * seconds_per_entry)))
    csv_export(path=path, data=self.state_log, index=timestamps, names=names, sep=sep, decimal=decimal)
[docs]
def get_observations(self) -> dict[str, np.ndarray]:
    """Collect the values of all configured observations from the current state.

    :raises KeyError: Observation is not available in state
    :return: Filtered observations as a dictionary.
    :rtype: dict[str, np.ndarray]
    """
    state = self.state
    try:
        return {name: state[name] for name in self.state_config.observations}
    except KeyError as e:
        msg = f"Observation {e!s} is unavailable in environment state."
        raise KeyError(msg) from e
[docs]
def set_action(self, action: np.ndarray | dict[str, np.ndarray]) -> None:
"""Set action values in the state.
:param action: Actions to be set.
:type action: np.ndarray | dict[str, np.ndarray]
"""
iterator: Iterator
if isinstance(action, np.ndarray):
iterator = zip(self.state_config.actions, action, strict=True)
else:
iterator = iter(action.items())
for name, value in iterator:
val = value if isinstance(value, np.ndarray) else np.array([value])
self.state[name] = val
[docs]
def set_external_outputs(self, external_outputs: Mapping[str, int | float | bool | str]) -> None:
    """Write external output values into the state.

    Accepts scalars instead of numpy arrays as values. Numeric values are scaled with
    ``(value + ext_scale_add) * ext_scale_mult``; booleans and all other values are stored unscaled.

    :param external_outputs: Dict of external outputs with external_ids as keys.
    :type external_outputs: Mapping[str, int | float | bool | str]
    :raises KeyError: No value was supplied for the external id of an external output
    """
    for name in self.state_config.ext_outputs:
        state_var = self.state_config.vars[name]
        try:
            raw = external_outputs[state_var.ext_id]  # type: ignore[index]
        except KeyError as e:
            raise KeyError(f"Missing value for external output: {name}") from e

        # Booleans are tested first, because bool is a subclass of int in Python.
        if isinstance(raw, (bool, np.bool_)):
            # Explicit dtype keeps the value boolean.
            entry = np.array([raw], dtype=bool)
        elif isinstance(raw, (int, float, np.integer, np.floating)):
            # Scaling applies to numeric values only.
            entry = np.array([(raw + state_var.ext_scale_add) * state_var.ext_scale_mult])
        else:
            # Everything else (strings, etc.) is stored as received.
            entry = np.array([raw])
        self.state[name] = entry
[docs]
def set_scenario_state(self, reset: bool = False) -> None:
    """Write the scenario output values of the current timestep into the state.

    Does nothing when no scenario manager is configured. The scenario data is read at
    ``n_steps + _scenario_offset`` and scaled with ``(value + ext_scale_add) * ext_scale_mult``.

    :param reset: Indicator whether this was called from the reset method
    """
    manager = self.scenario_manager
    if manager is None:
        return

    # A reset starts a new episode, which gets a newly computed offset.
    if reset:
        self._scenario_offset = manager.compute_episode_offset(self._scenario_rng)

    for state_name in self.state_config.scenario_outputs:
        state_var = self.state_config.vars[state_name]
        raw = manager.get_scenario_state_var(n_step=self.n_steps + self._scenario_offset, state_var=state_var)
        self.state[state_name] = (raw + state_var.ext_scale_add) * state_var.ext_scale_mult