from __future__ import annotations
import pathlib
from csv import DictWriter
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict
from eta_ctrl.util.io_utils import load_config
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Any, Self
from eta_ctrl.util.type_annotations import Path
from logging import getLogger
# Module-level logger, named after this module.
log = getLogger(__name__)
#: Largest finite float32 value, used as default bound for non-action state variables.
_FMAX = float(np.finfo(np.float32).max)
class StateVar(BaseModel):
    """A variable in the state of an environment.

    Instances are immutable (``frozen=True``) and reject unknown fields (``extra="forbid"``).
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    #: Name of the state variable (This must always be specified).
    name: str
    #: Should the agent specify actions for this variable? (default: False).
    is_agent_action: bool = False
    #: Should the agent be allowed to observe the value of this variable? (default: False).
    is_agent_observation: bool = False
    #: Should the state log of this episode be added to state_log_longtime? (default: True).
    add_to_state_log: bool = True

    #: Name of the variable in the external model
    #: (e.g.: environment or FMU) (default: StateVar.name if (is_ext_input or is_ext_output) else None).
    ext_id: str | None = None
    #: Should this variable be passed to the external model as an input? (default: False).
    is_ext_input: bool = False
    #: Should this variable be parsed from the external model output? (default: False).
    is_ext_output: bool = False
    #: Value to add to the output from an external model (default: 0.0).
    ext_scale_add: float = 0.0
    #: Value to multiply to the output from an external model (default: 1.0).
    ext_scale_mult: float = 1.0

    #: Name of the scenario variable, this value should be read from (default: None).
    scenario_id: str | None = None
    #: Should this variable be read from imported timeseries data? (default: False).
    from_scenario: bool = False
    #: Value to add to the value read from a scenario file (default: 0.0).
    scenario_scale_add: float = 0.0
    #: Value to multiply to the value read from a scenario file (default: 1.0).
    scenario_scale_mult: float = 1.0

    #: Lowest possible value of the state variable (default: -np.finfo(np.float32).max).
    low_value: float = -_FMAX
    #: Highest possible value of the state variable (default: np.finfo(np.float32).max).
    high_value: float = _FMAX
    #: If the value of the variable dips below this, the episode should be aborted (default: -np.inf).
    abort_condition_min: float = -np.inf
    #: If the value of the variable rises above this, the episode should be aborted (default: np.inf).
    abort_condition_max: float = np.inf

    #: Determine the index, where to look (useful for mathematical optimization, where multiple time steps could be
    #: returned). In this case, the index values might be different for actions and observations.
    index: int = 0
    #: For scenario StateVars: Length of StateVars horizon in state, e.g. the prediction horizon length (unit: steps).
    duration: int = 1

    def model_post_init(self, context: Any) -> None:
        """Fill in missing ids and validate the combination of flags and bounds.

        :param context: Validation context handed over by pydantic (unused).
        :raises ValueError: If an action variable has default or non-finite bounds, or if more than one of
            ``from_scenario``, ``is_ext_output`` and ``is_agent_action`` is set.
        """
        for flag, id_name in (
            (self.is_ext_input, "ext_id"),
            (self.is_ext_output, "ext_id"),
            (self.from_scenario, "scenario_id"),
        ):
            # Look the id up at iteration time: a variable that is both ext input and ext output must only
            # be defaulted (and logged) once.
            if flag and getattr(self, id_name) is None:
                # The model is frozen, so bypass pydantic's __setattr__ to fill in the missing id.
                object.__setattr__(self, id_name, self.name)
                log.info(f"Using name as {id_name} for variable {self.name}")

        # Require explicit finite bounds for action variables
        if self.is_agent_action:
            uses_default_bound = self.low_value == -_FMAX or self.high_value == _FMAX
            # inf and NaN are not equal to the default sentinel, so they need their own check.
            has_non_finite_bound = not (np.isfinite(self.low_value) and np.isfinite(self.high_value))
            if uses_default_bound or has_non_finite_bound:
                msg = (
                    f"Action variable '{self.name}' requires explicit finite bounds. "
                    f"Set both 'low_value' and 'high_value' in the state config."
                )
                raise ValueError(msg)

        # Validate mutual exclusivity of from_scenario, is_ext_output and is_agent_action
        data_sources = {
            "from_scenario": self.from_scenario,
            "is_ext_output": self.is_ext_output,
            "is_agent_action": self.is_agent_action,
        }
        if sum(data_sources.values()) > 1:
            # Find out which flags are set and include their names in the error message
            sources_set = [name for name, flag in data_sources.items() if flag]
            msg = f"Variable {self.name} cannot be {', '.join(sources_set)} at the same time."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, mapping: Mapping[str, Any] | pd.Series) -> StateVar:
        """Initialize a state var from a dictionary or pandas Series.

        :param mapping: dictionary or pandas Series to initialize from.
        :return: Initialized StateVar object
        """
        return cls(**dict(mapping))

    def __getitem__(self, name: str) -> Any:
        """Return the attribute called ``name`` (dictionary-style access)."""
        return getattr(self, name)

    def __str__(self) -> str:
        """Human-readable string representation of StateVar."""
        var_type = []
        if self.is_agent_action:
            var_type.append("action")
        if self.is_agent_observation:
            var_type.append("observation")
        if not var_type:
            var_type.append("variable")
        type_str = "/".join(var_type)

        # Only show the range when at least one bound differs from its default.
        has_range = self.low_value != -_FMAX or self.high_value != _FMAX
        range_str = f"[{self.low_value}, {self.high_value}]" if has_range else ""
        return f"StateVar '{self.name}' ({type_str}){' ' + range_str if range_str else ''}"

    def __repr__(self) -> str:
        """Developer-friendly string representation of StateVar listing only non-default key attributes."""
        key_attrs = []
        if self.is_agent_action:
            key_attrs.append("is_agent_action=True")
        if self.is_agent_observation:
            key_attrs.append("is_agent_observation=True")
        if self.low_value != -_FMAX:
            key_attrs.append(f"low_value={self.low_value}")
        if self.high_value != _FMAX:
            key_attrs.append(f"high_value={self.high_value}")
        attrs_str = ", ".join(key_attrs)
        return f"StateVar(name='{self.name}'{', ' + attrs_str if attrs_str else ''})"
class StateStructure(BaseModel):
    """Schema used for parsing the state structure from a config file.

    Instances are immutable and reject keys that are not declared below.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    #: Optional mapping of state parameter names to float or bool values (default: None).
    state_parameters: dict[str, float | bool] | None = None
    #: State variables listed under ``actions``.
    actions: list[StateVar]
    #: State variables listed under ``observations``.
    observations: list[StateVar]
class StateConfig:
    """The configuration for the action and observation spaces. The values are used to control which variables are
    part of the action space and observation space. Therefore, the *StateConfig* is very important
    for the functionality of EtaCtrl.
    """

    def __init__(self, *state_vars: StateVar, source_file: pathlib.Path | None = None) -> None:
        """Build the lookup structures for the given state variables.

        :param state_vars: State variables making up the configuration.
        :param source_file: Path of the file the configuration was loaded from (if any).
        :raises ValueError: If two state variables share the same name.
        """
        #: Mapping of the variables names to their StateVar instance with all associated information.
        self.vars = {var.name: var for var in state_vars}
        #: Attribute to store the source file path (if loaded from file).
        self.source_file: pathlib.Path | None = source_file

        # Additional Dataframe for easier access
        if state_vars:
            self.df_vars: pd.DataFrame = pd.DataFrame([var.model_dump() for var in state_vars]).set_index("name")
            if not self.df_vars.index.is_unique:
                duplicates = self.df_vars.index[self.df_vars.index.duplicated()].unique().tolist()
                msg = f"Duplicate variable names in StateConfig: {duplicates}"
                raise ValueError(msg)
        else:
            # Handle empty case - create empty DataFrame with expected columns
            self.df_vars = pd.DataFrame(columns=list(StateVar.model_fields.keys())).set_index("name")

        #: List of variables that are agent actions. Needs to be ordered.
        self.actions: list[str] = self._names_where("is_agent_action")
        #: List of variables that are agent observations.
        self.observations: list[str] = self._names_where("is_agent_observation")
        #: List of variables that should be logged.
        self.add_to_state_log: list[str] = self._names_where("add_to_state_log")

        #: List of variables that should be provided to an external source (such as an FMU).
        self.ext_inputs: list[str] = self._names_where("is_ext_input")
        #: List of variables that can be received from an external source (such as an FMU).
        self.ext_outputs: list[str] = self._names_where("is_ext_output")
        #: Mapping of variable names to their external IDs.
        self.map_ext_ids: dict[str, str] = self.df_vars.loc[self.ext_inputs + self.ext_outputs, "ext_id"].to_dict()
        #: Reverse mapping of external IDs to their corresponding variable names.
        self.rev_ext_ids: dict[str, str] = {v: k for k, v in self.map_ext_ids.items()}

        #: List of variables which are loaded from scenario files.
        self.scenario_outputs: list[str] = self._names_where("from_scenario")
        #: Mapping of internal environment names to scenario IDs.
        self.map_scenario_ids: dict[str, str] = self.df_vars.loc[self.scenario_outputs, "scenario_id"].to_dict()

        # Infinite limits mean "no abort condition": turn them into NaN so dropna() removes them.
        _abort_condition_df = self.df_vars.loc[:, ["abort_condition_min", "abort_condition_max"]]
        _abort_condition_df = _abort_condition_df.replace([np.inf, -np.inf], np.nan)
        #: List of variables that have minimum values for an abort condition.
        self.abort_conditions_min: list[str] = _abort_condition_df["abort_condition_min"].dropna().index.tolist()
        #: List of variables that have maximum values for an abort condition.
        self.abort_conditions_max: list[str] = _abort_condition_df["abort_condition_max"].dropna().index.tolist()

    def _names_where(self, flag_column: str) -> list[str]:
        """Return the names of all variables whose boolean column ``flag_column`` is True, in DataFrame order."""
        return self.df_vars.query(f"{flag_column} == True").index.tolist()

    @classmethod
    def from_file(
        cls, root_path: pathlib.Path, filename: Path, extra_params: Mapping[str, float] | None = None
    ) -> Self:
        """Load a StateConfig from a config file.

        The file is looked up at ``root_path / filename`` first and at ``root_path / "environments" / filename``
        as a fallback.

        :param root_path: Directory the lookup starts from.
        :param filename: Name (or relative path) of the config file.
        :param extra_params: Additional state parameters. Entries of the ``state_parameters`` section in the
            config file take precedence over these values.
        :return: StateConfig object.
        :raises FileNotFoundError: If the file exists at neither location.
        :raises ValueError: If the file does not define any state variables.
        """
        for state_folder_relpath in ("", "environments/"):
            file = root_path / state_folder_relpath / filename
            try:
                raw_dict = load_config(file=file)
            except FileNotFoundError:
                continue
            if state_folder_relpath:
                log.info("Using default state_folder_relpath 'environments/'")
            break
        else:
            msg = f"StateConfig file not found at {root_path / filename} or {root_path / 'environments' / filename}"
            raise FileNotFoundError(msg) from None

        log.info(f"Loading StateConfig from file at {file}.")

        actions: list[dict[str, Any]] = raw_dict.get("actions") or []
        observations: list[dict[str, Any]] = raw_dict.get("observations") or []
        state_vars: list[dict[str, Any]] = raw_dict.get("state_vars") or []

        # The section a variable is listed in determines its role.
        actions = [{**act, "is_agent_action": True} for act in actions]
        observations = [{**obs, "is_agent_observation": True} for obs in observations]

        all_states = actions + observations + state_vars
        if not all_states:
            msg = f"Invalid StateConfig at {file} with no StateVar's"
            raise ValueError(msg)

        state_params: dict[str, float] = {}
        if extra_params is not None:
            state_params.update(extra_params)

        # Defined by user in *_state_config.toml
        config_state_params: dict[str, float] | Any | None = raw_dict.get("state_parameters")
        if isinstance(config_state_params, dict):
            log.debug(f"Using State parameters {config_state_params} from {file} for StateConfig.")
            state_params.update(config_state_params)
        elif config_state_params is not None:
            log.warning(f"State parameters in {file} needs to be a dict! Ignoring.")

        return cls.from_dict(mapping=all_states, source_file=file, state_params=state_params)

    @classmethod
    def from_dict(
        cls,
        mapping: Sequence[dict[str, Any]] | pd.DataFrame,
        *,
        state_params: Mapping[str, float] | None = None,
        **kwargs: Any,
    ) -> Self:
        """Convert a potentially incomplete StateConfig DataFrame or a list of dictionaries to the
        standardized StateConfig format.

        NaN entries are dropped so that the StateVar defaults apply. String values in any field other than
        ``name``, ``ext_id`` and ``scenario_id`` are treated as names of state parameters (optionally prefixed
        with ``-`` to negate the value). Keys that are not StateVar fields are rejected by StateVar.

        :param mapping: Mapping to be converted to the StateConfig format.
        :param state_params: State parameter values for parameters supplied in mapping (e.g. {min_temp: 20})
        :param kwargs: Passed on to the StateConfig constructor (e.g. ``source_file``).
        :return: StateConfig object.
        :raises KeyError: If a referenced parameter is missing from ``state_params``.
        """
        params: Mapping[str, float] = state_params or {}

        # cast to list of dicts
        records: Sequence[dict[str, Any]] = (
            mapping.to_dict("records") if isinstance(mapping, pd.DataFrame) else mapping
        )
        # build a new list with NaN entries removed (the caller's dictionaries stay untouched)
        cleaned = [{k: v for k, v in statevar.items() if not pd.isna(v)} for statevar in records]

        string_fields = ("name", "ext_id", "scenario_id")  # Supposed to be strings
        for statevar in cleaned:
            for field_name, value in statevar.items():
                if field_name in string_fields or not isinstance(value, str):
                    continue

                parameter_name = value
                if is_negative := value.startswith("-"):
                    parameter_name = parameter_name[1:]  # strip minus sign

                new_value = params.get(parameter_name)
                if new_value is None:
                    msg = f"Parameter {parameter_name} needs to be specified in state_params."
                    raise KeyError(msg)

                # Only the value of an existing key is replaced, which is safe while iterating.
                statevar[field_name] = -new_value if is_negative else new_value

        return cls(*[StateVar.from_dict(col) for col in cleaned], **kwargs)

    def store_file(self, file: Path) -> None:
        """Save the StateConfig to a semicolon-separated file.

        :param file: Path to the file.
        """
        _file = file if isinstance(file, pathlib.Path) else pathlib.Path(file)
        _header = StateVar.model_fields.keys()

        # newline="" lets the csv writer control line endings (avoids blank rows on Windows).
        with _file.open("w", newline="") as f:
            writer = DictWriter(f, _header, restval="None", delimiter=";")
            writer.writeheader()
            for var in self.vars.values():
                writer.writerow(var.model_dump())

    def within_abort_conditions(self, state: Mapping[str, float]) -> bool:
        """Check whether the given state is within the abort conditions specified by the StateConfig instance.

        :param state: Mapping of variable names to the values to check for conformance.
        :return: Result of the check (False if the state does not conform to the required conditions).
        """
        # Only check abort conditions for numeric values (int and float), exclude bool since it's a subclass of int
        numeric_names = [
            name
            for name, value in state.items()
            if not isinstance(value, (bool, np.bool_)) and isinstance(value, (int, float, np.integer, np.floating))
        ]

        valid_min = all(state[name] >= self.vars[name].abort_condition_min for name in numeric_names)
        if not valid_min:
            log.warning("Minimum abort condition exceeded by at least one value.")

        valid_max = all(state[name] <= self.vars[name].abort_condition_max for name in numeric_names)
        if not valid_max:
            log.warning("Maximum abort condition exceeded by at least one value.")

        return valid_min and valid_max

    def continuous_action_space(self) -> spaces.Box:
        """Generate a numpy ndarray action space.

        :return: Action space.
        """
        actions = self.df_vars.query("is_agent_action == True")
        low_values = actions["low_value"].to_numpy(dtype=np.float32)
        high_values = actions["high_value"].to_numpy(dtype=np.float32)
        return spaces.Box(low_values, high_values)

    def continuous_observation_space(self) -> spaces.Dict:
        """Generate a dictionary observation space.

        :return: Observation Space.
        """
        observations: dict[str, spaces.Box] = {
            name: spaces.Box(
                low=row["low_value"], high=row["high_value"], shape=(int(row["duration"]),), dtype=np.float32
            )
            for name, row in self.df_vars.iterrows()
            # Truthiness instead of an identity check: the flag may be a numpy bool rather than the True singleton.
            if bool(row["is_agent_observation"])
        }
        return spaces.Dict(observations)  # type: ignore[arg-type]

    def continuous_spaces(self) -> tuple[spaces.Box, spaces.Dict]:
        """Generate continuous action and observation spaces according to the OpenAI specification.

        :return: Tuple of action space and observation space.
        """
        return self.continuous_action_space(), self.continuous_observation_space()

    def __str__(self) -> str:
        """Human-readable string representation of StateConfig."""
        n_actions = len(self.actions)
        n_observations = len(self.observations)
        n_total = len(self.vars)

        base_str = f"StateConfig with {n_actions} actions, {n_observations} observations ({n_total} total variables)"
        if self.source_file is not None:
            return f"{base_str} from '{self.source_file}'"
        return base_str

    @staticmethod
    def _preview_names(names: list[str], limit: int = 3) -> str:
        """Format the first ``limit`` names like a list literal, marking omitted entries with ``...``."""
        shown = ", ".join(repr(name) for name in names[:limit])
        return f"[{shown}{', ...' if len(names) > limit else ''}]"

    def __repr__(self) -> str:
        """Developer-friendly string representation of StateConfig."""
        # Show first few variables for context
        actions_str = self._preview_names(self.actions)
        observations_str = self._preview_names(sorted(self.observations))
        return f"StateConfig(actions={actions_str}, observations={observations_str})"

    def __getitem__(self, name: str) -> Any:
        """Return the attribute called ``name`` (dictionary-style access)."""
        return getattr(self, name)

    @property
    def loc(self) -> dict[str, StateVar]:
        """Return the name-to-StateVar mapping so that ``config.loc[name]`` works like DataFrame indexing."""
        return self.vars