Source code for eta_ctrl.envs.state

from __future__ import annotations

import pathlib
from csv import DictWriter
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict

from eta_ctrl.util.io_utils import load_config

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence
    from typing import Any, Self

    from eta_ctrl.util.type_annotations import Path
from logging import getLogger

# Module-level logger, named after this module.
log = getLogger(name=__name__)

#: Largest finite float32 value, used as default bound for non-action state variables.
_FMAX: float = np.finfo(np.float32).max.item()


class StateVar(BaseModel):
    """A variable in the state of an environment."""

    model_config = ConfigDict(frozen=True, extra="forbid")

    #: Name of the state variable (This must always be specified).
    name: str
    #: Should the agent specify actions for this variable? (default: False).
    is_agent_action: bool = False
    #: Should the agent be allowed to observe the value of this variable? (default: False).
    is_agent_observation: bool = False
    #: Should the state log of this episode be added to state_log_longtime? (default: True).
    add_to_state_log: bool = True

    #: Name of the variable in the external model
    #: (e.g.: environment or FMU) (default: StateVar.name if (is_ext_input or is_ext_output) else None).
    ext_id: str | None = None
    #: Should this variable be passed to the external model as an input? (default: False).
    is_ext_input: bool = False
    #: Should this variable be parsed from the external model output? (default: False).
    is_ext_output: bool = False
    #: Value to add to the output from an external model (default: 0.0).
    ext_scale_add: float = 0.0
    #: Value to multiply to the output from an external model (default: 1.0).
    ext_scale_mult: float = 1.0

    #: Name of the scenario variable, this value should be read from
    #: (default: StateVar.name if from_scenario else None).
    scenario_id: str | None = None
    #: Should this variable be read from imported timeseries data? (default: False).
    from_scenario: bool = False
    #: Value to add to the value read from a scenario file (default: 0.0).
    scenario_scale_add: float = 0.0
    #: Value to multiply to the value read from a scenario file (default: 1.0).
    scenario_scale_mult: float = 1.0

    #: Lowest possible value of the state variable (default: -np.finfo(np.float32).max).
    low_value: float = -_FMAX
    #: Highest possible value of the state variable (default: np.finfo(np.float32).max).
    high_value: float = _FMAX
    #: If the value of the variable dips below this, the episode should be aborted (default: -np.inf).
    abort_condition_min: float = -np.inf
    #: If the value of the variable rises above this, the episode should be aborted (default: np.inf).
    abort_condition_max: float = np.inf

    #: Determine the index, where to look (useful for mathematical optimization, where multiple time steps could be
    #: returned). In this case, the index values might be different for actions and observations.
    index: int = 0
    #: For scenario StateVars: Length of StateVars horizon in state, e.g. the prediction horizon length (unit: steps).
    duration: int = 1

    def model_post_init(self, context: Any) -> None:
        """Fill in missing identifiers and validate the variable after pydantic initialization.

        ``ext_id`` defaults to ``name`` for external inputs/outputs, and ``scenario_id`` defaults to ``name``
        for scenario variables.

        :param context: Pydantic post-init context (not used).
        :raises ValueError: If an action variable does not have explicit, finite bounds, or if more than one of
            ``from_scenario``, ``is_ext_output`` and ``is_agent_action`` is set.
        """
        # The model is frozen, so missing ids can only be filled in via object.__setattr__.
        # A variable that is both an external input and output only needs its ext_id filled once.
        if (self.is_ext_input or self.is_ext_output) and self.ext_id is None:
            object.__setattr__(self, "ext_id", self.name)
            log.info(f"Using name as ext_id for variable {self.name}")
        if self.from_scenario and self.scenario_id is None:
            object.__setattr__(self, "scenario_id", self.name)
            log.info(f"Using name as scenario_id for variable {self.name}")

        # Require explicit finite bounds for action variables. The float32 defaults count as "not set";
        # infinite or NaN bounds are rejected as well because they cannot describe a bounded action range.
        bounds_explicit_and_finite = (
            bool(np.isfinite(self.low_value))
            and bool(np.isfinite(self.high_value))
            and self.low_value != -_FMAX
            and self.high_value != _FMAX
        )
        if self.is_agent_action and not bounds_explicit_and_finite:
            msg = (
                f"Action variable '{self.name}' requires explicit finite bounds. "
                f"Set both 'low_value' and 'high_value' in the state config."
            )
            raise ValueError(msg)

        # Validate mutual exclusivity of from_scenario, is_ext_output and is_agent_action
        data_sources = {
            "from_scenario": self.from_scenario,
            "is_ext_output": self.is_ext_output,
            "is_agent_action": self.is_agent_action,
        }
        if sum(data_sources.values()) > 1:
            # Find out which flags are set and include their names in the error message
            sources_set = [name for name, flag in data_sources.items() if flag]
            msg = f"Variable {self.name} cannot be {', '.join(sources_set)} at the same time."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, mapping: Mapping[str, Any] | pd.Series) -> StateVar:
        """Initialize a state var from a dictionary or pandas Series.

        :param mapping: dictionary or pandas Series to initialize from.
        :return: Initialized StateVar object
        """
        return cls(**dict(mapping))

    def __getitem__(self, name: str) -> Any:
        """Return the attribute *name* (dict-style access to the fields)."""
        return getattr(self, name)

    def __str__(self) -> str:
        """Human-readable string representation of StateVar."""
        var_type = []
        if self.is_agent_action:
            var_type.append("action")
        if self.is_agent_observation:
            var_type.append("observation")
        if not var_type:
            var_type.append("variable")
        type_str = "/".join(var_type)

        # Only show the range when at least one bound differs from the float32 defaults.
        has_range = self.low_value != -_FMAX or self.high_value != _FMAX
        range_str = f"[{self.low_value}, {self.high_value}]" if has_range else ""
        return f"StateVar '{self.name}' ({type_str}){' ' + range_str if range_str else ''}"

    def __repr__(self) -> str:
        """Developer-friendly string representation of StateVar."""
        key_attrs = []
        if self.is_agent_action:
            key_attrs.append("is_agent_action=True")
        if self.is_agent_observation:
            key_attrs.append("is_agent_observation=True")
        if self.low_value != -_FMAX:
            key_attrs.append(f"low_value={self.low_value}")
        if self.high_value != _FMAX:
            key_attrs.append(f"high_value={self.high_value}")
        attrs_str = ", ".join(key_attrs)
        return f"StateVar(name='{self.name}'{', ' + attrs_str if attrs_str else ''})"
class StateStructure(BaseModel):
    """Used for parsing the state structure from a config file.

    The model is immutable and rejects unknown top-level keys.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    #: Optional ``state_parameters`` section of the config file (parameter name -> value).
    state_parameters: dict[str, float | bool] | None = None
    #: Entries of the ``actions`` section (required).
    actions: list[StateVar]
    #: Entries of the ``observations`` section (required).
    observations: list[StateVar]
class StateConfig:
    """The configuration for the action and observation spaces.

    The values are used to control which variables are part of the action space and observation space.
    Therefore, the *StateConfig* is very important for the functionality of EtaCtrl.
    """

    def __init__(self, *state_vars: StateVar, source_file: pathlib.Path | None = None) -> None:
        """Create a StateConfig from state variables.

        :param state_vars: The state variables that make up this configuration.
        :param source_file: Path of the file the configuration was loaded from (if any).
        :raises ValueError: If multiple state variables share the same name.
        """
        #: Mapping of the variables names to their StateVar instance with all associated information.
        self.vars = {var.name: var for var in state_vars}
        #: Attribute to store the source file path (if loaded from file).
        self.source_file: pathlib.Path | None = source_file

        # Additional DataFrame (indexed by variable name) for easier access.
        if state_vars:
            self.df_vars: pd.DataFrame = pd.DataFrame([var.model_dump() for var in state_vars]).set_index("name")
            # self.vars silently keeps only the last duplicate, so duplicates are detected on the DataFrame index.
            if not self.df_vars.index.is_unique:
                duplicates = self.df_vars.index[self.df_vars.index.duplicated()].unique().tolist()
                msg = f"Duplicate variable names in StateConfig: {duplicates}"
                raise ValueError(msg)
        else:
            # Handle empty case - create empty DataFrame with expected columns
            self.df_vars = pd.DataFrame(columns=list(StateVar.model_fields.keys())).set_index("name")

        #: List of variables that are agent actions. Needs to be ordered.
        self.actions: list[str] = self.df_vars.query("is_agent_action == True").index.tolist()
        #: List of variables that are agent observations.
        self.observations: list[str] = self.df_vars.query("is_agent_observation == True").index.tolist()
        #: List of variables that should be logged.
        self.add_to_state_log: list[str] = self.df_vars.query("add_to_state_log == True").index.tolist()
        #: List of variables that should be provided to an external source (such as an FMU).
        self.ext_inputs: list[str] = self.df_vars.query("is_ext_input == True").index.tolist()
        #: List of variables that can be received from an external source (such as an FMU).
        self.ext_outputs: list[str] = self.df_vars.query("is_ext_output == True").index.tolist()

        # A variable may be both an external input and output; deduplicate (keeping order) before the lookup.
        ext_vars = list(dict.fromkeys(self.ext_inputs + self.ext_outputs))
        #: Mapping of variable names to their external IDs.
        self.map_ext_ids: dict[str, str] = self.df_vars.loc[ext_vars, "ext_id"].to_dict()
        #: Reverse mapping of external IDs to their corresponding variable names.
        self.rev_ext_ids: dict[str, str] = {v: k for k, v in self.map_ext_ids.items()}

        #: List of variables which are loaded from scenario files.
        self.scenario_outputs: list[str] = self.df_vars.query("from_scenario == True").index.tolist()
        #: Mapping of internal environment names to scenario IDs.
        self.map_scenario_ids: dict[str, str] = self.df_vars.loc[self.scenario_outputs, "scenario_id"].to_dict()

        # Infinite limits mean "no abort condition"; turn them into NaN so dropna() keeps only real limits.
        _abort_condition_df = self.df_vars.loc[:, ["abort_condition_min", "abort_condition_max"]]
        _abort_condition_df = _abort_condition_df.replace([np.inf, -np.inf], np.nan)
        #: List of variables that have minimum values for an abort condition.
        self.abort_conditions_min: list[str] = _abort_condition_df["abort_condition_min"].dropna().index.tolist()
        #: List of variables that have maximum values for an abort condition.
        self.abort_conditions_max: list[str] = _abort_condition_df["abort_condition_max"].dropna().index.tolist()

    @classmethod
    def from_file(
        cls, root_path: pathlib.Path, filename: Path, extra_params: Mapping[str, float] | None = None
    ) -> Self:
        """Load a StateConfig from a config file.

        The file is looked up at ``root_path / filename`` first and, if not found there, at
        ``root_path / "environments" / filename``. The sections ``actions``, ``observations`` and ``state_vars``
        are read; entries of ``actions``/``observations`` get ``is_agent_action``/``is_agent_observation`` set.

        :param root_path: Root directory to search for the config file.
        :param filename: Name (or relative path) of the config file.
        :param extra_params: Additional state parameter values. Values from the file's ``state_parameters``
            section override entries with the same name.
        :raises FileNotFoundError: If the file exists in neither location.
        :raises ValueError: If the file does not define any state variables.
        :return: StateConfig object.
        """
        state_folder_relpath = ""
        try:
            raw_dict = load_config(file=root_path / state_folder_relpath / filename)
        except FileNotFoundError:
            state_folder_relpath = "environments/"
            try:
                raw_dict = load_config(file=root_path / state_folder_relpath / filename)
                log.info("Using default state_folder_relpath 'environments/'")
            except FileNotFoundError:
                msg = f"StateConfig file not found at {root_path / filename} or {root_path / 'environments' / filename}"
                raise FileNotFoundError(msg) from None

        file = root_path / state_folder_relpath / filename
        log.info(f"Loading StateConfig from file at {file}.")

        actions: list[dict[str, Any]] = raw_dict.get("actions") or []
        observations: list[dict[str, Any]] = raw_dict.get("observations") or []
        state_vars: list[dict[str, Any]] = raw_dict.get("state_vars") or []

        actions = [{**act, "is_agent_action": True} for act in actions]
        observations = [{**obs, "is_agent_observation": True} for obs in observations]
        all_states = actions + observations + state_vars

        if len(all_states) == 0:
            msg = f"Invalid StateConfig at {file} with no StateVar's"
            raise ValueError(msg)

        state_params: dict[str, float] = {}
        if extra_params is not None:
            state_params.update(extra_params)

        # Defined by user in *_state_config.toml; applied after extra_params so the file takes precedence.
        config_state_params: dict[str, float] | Any | None = raw_dict.get("state_parameters")
        if isinstance(config_state_params, dict):
            log.debug(f"Using State parameters {config_state_params} from {file} for StateConfig.")
            state_params.update(config_state_params)
        elif config_state_params is not None:
            log.warning(f"State parameters in {file} needs to be a dict! Ignoring.")

        return cls.from_dict(mapping=all_states, source_file=file, state_params=state_params)

    @classmethod
    def from_dict(
        cls,
        mapping: Sequence[dict[str, Any]] | pd.DataFrame,
        *,
        state_params: Mapping[str, float] | None = None,
        **kwargs: Any,
    ) -> Self:
        """Convert a potentially incomplete StateConfig DataFrame or a list of dictionaries to a StateConfig.

        Entries whose value is NaN/None are dropped, so the StateVar defaults apply to them. Unknown fields are
        NOT ignored: StateVar forbids extra fields and will reject them.

        String values (except in ``name``, ``ext_id`` and ``scenario_id``) are treated as references to
        ``state_params``; a leading ``-`` negates the referenced value (e.g. ``"-max_temp"``).

        :param mapping: Mapping to be converted to the StateConfig format.
        :param state_params: State parameter values for parameters supplied in mapping (e.g. {min_temp: 20})
        :param kwargs: Forwarded to the StateConfig constructor (e.g. ``source_file``).
        :raises KeyError: If a referenced parameter is missing from ``state_params``.
        :return: StateConfig object.
        """
        if not state_params:
            state_params = {}

        # cast to list of dicts
        _mapping: Sequence[dict[str, Any]] = (
            mapping.to_dict("records") if isinstance(mapping, pd.DataFrame) else mapping
        )
        # build a new list with NaN entries removed (this also avoids mutating the caller's dicts below)
        _mapping = [{k: v for k, v in statevar.items() if not pd.isna(v)} for statevar in _mapping]

        for statevar in _mapping:
            for field_name, value in statevar.items():
                # Identifier fields are supposed to be strings; only other string values are parameter references.
                if field_name in ("name", "ext_id", "scenario_id") or not isinstance(value, str):
                    continue
                parameter_name = value
                if is_negative := value.startswith("-"):
                    parameter_name = parameter_name[1:]  # strip minus sign
                new_value = state_params.get(parameter_name)
                if new_value is None:
                    msg = f"Parameter {parameter_name} needs to be specified in state_params."
                    raise KeyError(msg)
                if is_negative:
                    new_value = -new_value
                # Replacing the value of an existing key is safe while iterating over items().
                statevar[field_name] = new_value

        return cls(*[StateVar.from_dict(col) for col in _mapping], **kwargs)

    def store_file(self, file: Path) -> None:
        """Save the StateConfig to a semicolon-separated (CSV-style) file.

        The header row contains all StateVar field names; each following row holds one state variable.

        :param file: Path to the file. Existing content is overwritten.
        """
        _file = file if isinstance(file, pathlib.Path) else pathlib.Path(file)
        _header = StateVar.model_fields.keys()
        # newline="" lets the csv module control line endings (otherwise rows are separated by blank lines on Windows).
        with _file.open("w", newline="") as f:
            writer = DictWriter(f, _header, restval="None", delimiter=";")
            writer.writeheader()
            for var in self.vars.values():
                writer.writerow(var.model_dump())

    @staticmethod
    def _is_numeric(value: object) -> bool:
        """Return True for int/float scalars (Python or numpy), excluding booleans."""
        # bool is a subclass of int, so it has to be excluded explicitly.
        return not isinstance(value, (bool, np.bool_)) and isinstance(
            value, (int, float, np.integer, np.floating)
        )

    def within_abort_conditions(self, state: Mapping[str, float]) -> bool:
        """Check whether the given state is within the abort conditions specified by the StateConfig instance.

        Only numeric (non-bool) scalar values are checked; all other values are skipped.

        :param state: The state array to check for conformance.
        :return: Result of the check (False if the state does not conform to the required conditions).
        """
        numeric_names = [name for name in state if self._is_numeric(state[name])]

        valid_min = all(state[name] >= self.vars[name].abort_condition_min for name in numeric_names)
        if not valid_min:
            log.warning("Minimum abort condition exceeded by at least one value.")

        valid_max = all(state[name] <= self.vars[name].abort_condition_max for name in numeric_names)
        if not valid_max:
            log.warning("Maximum abort condition exceeded by at least one value.")

        return valid_min and valid_max

    def continuous_action_space(self) -> spaces.Box:
        """Generate a numpy ndarray action space.

        The bounds are ordered like :attr:`actions`.

        :return: Action space.
        """
        actions = self.df_vars.query("is_agent_action == True")
        low_values = actions["low_value"].to_numpy(dtype=np.float32)
        high_values = actions["high_value"].to_numpy(dtype=np.float32)
        return spaces.Box(low_values, high_values)

    def continuous_observation_space(self) -> spaces.Dict:
        """Generate a dictionary observation space.

        Each observation variable gets a float32 Box of shape ``(duration,)`` bounded by its low/high values.

        :return: Observation Space.
        """
        # Use the StateVar objects directly instead of DataFrame rows, so the flag check does not depend on
        # identity comparisons (``is True``) against values pandas may have converted.
        observations: dict[str, spaces.Box] = {}
        for name in self.observations:
            var = self.vars[name]
            observations[name] = spaces.Box(
                low=var.low_value, high=var.high_value, shape=(var.duration,), dtype=np.float32
            )
        return spaces.Dict(observations)  # type: ignore[arg-type]

    def continuous_spaces(self) -> tuple[spaces.Box, spaces.Dict]:
        """Generate continuous action and observation spaces according to the OpenAI specification.

        :return: Tuple of action space and observation space.
        """
        action_space = self.continuous_action_space()
        observation_space = self.continuous_observation_space()
        return action_space, observation_space

    def __str__(self) -> str:
        """Human-readable string representation of StateConfig."""
        n_actions = len(self.actions)
        n_observations = len(self.observations)
        n_total = len(self.vars)
        base_str = f"StateConfig with {n_actions} actions, {n_observations} observations ({n_total} total variables)"
        if self.source_file is not None:
            return f"{base_str} from '{self.source_file}'"
        return base_str

    @staticmethod
    def _preview(names: Sequence[str], limit: int = 3) -> str:
        """Format the first *limit* names like a list literal, appending ``...`` when names were omitted."""
        shown = ", ".join(repr(name) for name in names[:limit])
        suffix = ", ..." if len(names) > limit else ""
        return f"[{shown}{suffix}]"

    def __repr__(self) -> str:
        """Developer-friendly string representation of StateConfig."""
        # Show first few variables for context (actions keep their order, observations are sorted).
        actions_str = self._preview(self.actions)
        observations_str = self._preview(sorted(self.observations))
        return f"StateConfig(actions={actions_str}, observations={observations_str})"

    def __getitem__(self, name: str) -> Any:
        """Return the attribute *name* (dict-style access, e.g. ``config["actions"]``)."""
        return getattr(self, name)

    @property
    def loc(self) -> dict[str, StateVar]:
        """Behave like dataframe (enable indexing via loc) for compatibility.

        Returns the name-to-StateVar mapping, so ``config.loc[name]`` yields the StateVar and
        ``config.loc[name][field]`` one of its attributes.
        """
        return self.vars