from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from pathlib import Path # noqa: TC003, pydantic needs this for type validation at runtime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field
from eta_ctrl import timeseries
from eta_ctrl.util.type_annotations import FillMethod, InferDatetimeType # noqa: TC001
if TYPE_CHECKING:
from datetime import datetime
from typing import Any
import numpy as np
import pandas as pd
from eta_ctrl.envs.state import StateVar
log = logging.getLogger(__name__)
# Settings for one CSV scenario file: where it lives and how its columns are
# parsed, filled, scaled, prefixed and renamed when the scenario is loaded.
class ConfigCsvScenario(BaseModel):
    model_config = ConfigDict(extra="forbid", frozen=True, use_attribute_docstrings=True)

    path: str
    """Relative path to the scenario."""

    interpolation_method: FillMethod | None = None
    """Pandas method to use for filling missing data ["ffill", "bfill", "interpolate", "asfreq"]."""

    scale_factors: dict[str, float] | None = None
    """Scale factors for each column."""

    prefix: str | None = None
    """Prefix for all column names."""

    # The explicit ``description`` below is the text attached to this field, so it
    # must be spelled correctly; the docstring that follows is reader documentation.
    infer_datetime_cols: InferDatetimeType | tuple[int, int] = Field(
        default="dates", description="Method of datetime parsing"
    )
    """Setting how the datetime values should be converted.

    When set to 'string' it uses the format from ``time_conversion_str``,
    when set to 'dates' it will use pandas to determine the datetime.
    If a two-tuple (row, col) is given, data from the specified field in the data files
    will be used to determine the date format.
    """

    time_conversion_str: str = "%Y-%m-%d %H:%M"
    """Time conversion string used when ``infer_datetime_cols`` is set to 'string'.
    Should specify the format for Python ``strptime``.
    """

    rename_cols: dict[str, str] | None = None
    """Dictionary for renaming column names.

    .. note::
        The column names are stripped of illegal characters and underscores are added in place of spaces.
        "Water Temperature #2 [°C]" becomes "Water_Temperature_2_C". If you want to rename the column,
        you need to specify the processed name, for example: {"Water_Temperature_2_C": "T_W"}.
    """

    scenarios_path: Path | None = Field(default=None, exclude=True)
    """Directory for the scenarios. Not included in config declaration, passed by main Config object."""

    def model_post_init(self, _: Any) -> None:
        """Ensure that the CSV file exists.

        The check is skipped while ``scenarios_path`` is unset, because the
        absolute path cannot be resolved without it.

        :param _: Unused argument supplied by the caller of this hook.
        :raises FileNotFoundError: If file does not exist.
        """
        if self.scenarios_path is None:
            return

        scenario_file = self.abs_path
        if not scenario_file.exists():
            # Name the resolved file so a bad entry can be found among several scenarios.
            msg = f"Scenario file does not exist: {scenario_file}"
            raise FileNotFoundError(msg)

    @property
    def abs_path(self) -> Path:
        """Absolute file path of the scenario.

        :return: ``scenarios_path`` joined with ``path`` and resolved.
        :raises AttributeError: If ``scenarios_path`` has not been set.
        """
        if self.scenarios_path is None:
            msg = f"Relative path is not set for file {self.path}"
            raise AttributeError(msg)
        return (self.scenarios_path / self.path).resolve()
class ScenarioManager(ABC):
    """Abstract base for objects that serve scenario values to an environment.

    Subclasses implement :meth:`_get_data`; the public helpers defined here are
    written in terms of it.
    """

    def compute_episode_offset(self, rng: np.random.Generator) -> int:
        """Return the row offset into the scenario data for the next episode.

        This base implementation performs no random slicing and always yields 0.
        Subclasses that support random time slicing override it.

        :param rng: Random number generator from the environment (unused here).
        :return: Integer row offset into the scenario data.
        """
        return 0

    def get_scenario_state_var(self, n_step: int, state_var: StateVar) -> np.ndarray:
        """Fetch the scenario values belonging to one state variable.

        :param n_step: Absolute row index into the scenario data (env step + episode offset).
        :param state_var: State variable configuration; its ``scenario_id`` and
            ``duration`` select the column and the number of steps.
        :return: Array of scenario values.
        """
        key = state_var.scenario_id
        window = state_var.duration
        # Ask only for the one column this state variable refers to.
        values_by_name = self._get_data(n_step=n_step, duration=window, names=[key])  # type: ignore[list-item]
        return values_by_name[key]  # type: ignore[index]

    @abstractmethod
    def _get_data(self, n_step: int, duration: int = 1, names: list[str] | None = None) -> dict[str, np.ndarray]:
        """Return scenario values for the half-open interval [n_step, n_step + duration).

        :param n_step: Absolute row index into the scenario data (env step + episode offset).
        :param duration: Number of steps to retrieve.
        :param names: Column names to retrieve. If None, all columns are returned.
        :return: Dictionary mapping column names to value arrays.
        """
        raise NotImplementedError
class CsvScenarioManager(ScenarioManager):
    """ScenarioManager class for loading scenario data from CSV files."""

    def __init__(
        self,
        scenario_configs: list[ConfigCsvScenario],
        start_time: datetime,
        end_time: datetime,
        total_time: float,
        resample_time: float,
        use_random_time_slice: bool = False,
    ) -> None:
        """Store the configuration and load the scenario data immediately.

        :param scenario_configs: Scenario file configurations handed to the CSV loader.
        :param start_time: Start of the time range handed to the CSV loader.
        :param end_time: End of the time range handed to the CSV loader.
        :param total_time: Total episode time; together with ``resample_time`` it
            determines ``scenario_steps``.
        :param resample_time: Resampling interval handed to the CSV loader.
        :param use_random_time_slice: If True, :meth:`compute_episode_offset`
            draws a random start row instead of returning 0.
        """
        super().__init__()
        # Both attributes are populated by load_data() at the end of this method.
        self.scenarios: pd.DataFrame
        self.total_df_length: int

        self.scenario_steps = int(total_time / resample_time)
        self.scenario_configs: list[ConfigCsvScenario] = scenario_configs
        self.start_time = start_time
        self.end_time = end_time
        self.total_time = total_time
        self.resample_time = resample_time
        self.use_random_time_slice = use_random_time_slice
        self.load_data()

    def compute_episode_offset(self, rng: np.random.Generator) -> int:
        """Compute the row offset into the scenario dataframe for the next episode.

        :param rng: Random number generator used to pick a random starting position.
        :return: Integer row index into ``self.scenarios`` representing the episode start.
        :raises ValueError: If random slicing is enabled but the loaded data holds
            fewer rows than one episode needs.
        """
        if not self.use_random_time_slice:
            return 0

        available_space = self.total_df_length - self.scenario_steps
        if available_space < 0:
            # Without this guard the draw below fails on an empty population with
            # an error message that does not mention the scenario data at all.
            msg = (
                f"Cannot pick a random time slice: scenario data has {self.total_df_length} steps "
                f"but an episode needs {self.scenario_steps} steps."
            )
            raise ValueError(msg)
        if available_space == 0:
            return 0
        return rng.choice(range(available_space)).item()

    def load_data(self) -> None:
        """Load scenario data by calling 'scenario_from_csv' with the ConfigCsvScenario objects.

        Stores the result in ``self.scenarios`` and its row count in ``self.total_df_length``.
        """
        self.scenarios = timeseries.scenario_from_csv(
            scenario_configs=self.scenario_configs,
            start_time=self.start_time,
            end_time=self.end_time,
            resample_time=self.resample_time,
            prefix_renamed=True,
        )
        self.total_df_length = len(self.scenarios)

    def __str__(self) -> str:
        """Human-readable string representation of CsvScenarioManager."""
        n_scenarios = len(self.scenario_configs)
        return f"CsvScenarioManager({n_scenarios} scenario(s), {self.start_time} to {self.end_time})"

    def __repr__(self) -> str:
        """Developer-friendly string representation of CsvScenarioManager."""
        return (
            f"CsvScenarioManager(start_time={self.start_time!r}, end_time={self.end_time!r}, "
            f"total_time={self.total_time}, n_scenarios={len(self.scenario_configs)}, "
            f"columns={list(self.scenarios.columns)})"
        )

    def _validate_columns(self, columns: list[str] | None) -> list[str]:
        """Validate and return the list of columns to retrieve.

        :param columns: Requested column names, or None for all columns.
        :return: List of valid column names to retrieve.
        :raises KeyError: If any requested column is not found in the scenario data.
        """
        if columns is None:
            return list(self.scenarios.columns)

        missing_cols = set(columns) - set(self.scenarios.columns)
        if missing_cols:
            available_cols = list(self.scenarios.columns)
            # key=str keeps the sort from failing on mixed entries (e.g. None next to
            # a string), which would otherwise hide the KeyError behind a TypeError.
            msg = (
                f"Requested scenario columns {sorted(missing_cols, key=str)} not found in loaded scenario data. "
                f"Available columns: {available_cols}"
            )
            raise KeyError(msg)
        return columns

    def _get_data(self, n_step: int, duration: int = 1, names: list[str] | None = None) -> dict[str, np.ndarray]:
        """Return scenario values for the rows ``n_step`` up to (excluding) ``n_step + duration``.

        :param n_step: Absolute row index into the scenario data (env step + episode offset).
        :param duration: Number of steps to retrieve.
        :param names: Column names to retrieve. If None, all columns are returned.
        :return: Dictionary mapping column names to value arrays.
        :raises IndexError: If ``n_step`` is negative or the requested range runs
            past the end of the loaded data.
        :raises KeyError: If any requested column is not found in the scenario data.
        """
        if n_step < 0:
            # A negative start would be interpreted relative to the end of the frame
            # by positional slicing and silently yield the wrong (or no) rows.
            msg = f"Requested data from negative step {n_step}; step index must be >= 0."
            raise IndexError(msg)

        end_index = n_step + duration
        if end_index > self.total_df_length:
            msg = (
                f"Requested data from {n_step} to {end_index} ({duration} steps) "
                f"but only {self.total_df_length} steps available. "
                f"Shortfall: {end_index - self.total_df_length} steps."
            )
            raise IndexError(msg)

        # Choose all columns if names are not supplied
        columns = self._validate_columns(columns=names)
        return {col: self.scenarios[col].iloc[n_step:end_index].to_numpy() for col in columns}