Source code for eta_ctrl.timeseries.scenario_manager

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path  # noqa: TC003, pydantic needs this for type validation at runtime
from typing import TYPE_CHECKING

import numpy as np
from pydantic import BaseModel, ConfigDict, Field

from eta_ctrl import timeseries
from eta_ctrl.util.type_annotations import FillMethod, InferDatetimeType  # noqa: TC001

if TYPE_CHECKING:
    from datetime import datetime
    from typing import Any

    import pandas as pd


[docs] class ConfigCsvScenario(BaseModel): #: :meta private: model_config = ConfigDict(extra="forbid", frozen=True) #: Relative path to the scenario path: str #: Method of interpolation interpolation_method: FillMethod | None = None #: Scale factors for each columns scale_factors: dict[str, float] | None = None #: Prefix for all column names prefix: str | None = None #: Setting how the datetime values should be converted. #: When set to string it uses the format from ``time_conversion_str``, #: when set to 'dates' it will use pandas to determine the datetime. #: If a two-tuple (row, col) is given, data from the specified field in the data files #: will be used to determine the date format infer_datetime_cols: InferDatetimeType | tuple[int, int] = "dates" #: Time conversion string used when ``infer_datetime_cols`` is set to 'string' #: Should specify the format for Python ``strptime`` time_conversion_str: str = "%Y-%m-%d %H:%M" #: Dictionary for renaming column names #: #: .. note:: #: #: The column names are stripped of illegal characters and underscores are added in place of spaces. #: "Water Temperature #2 [°C]" becomes "Water_Temperature_2_C". If you want to rename the column, #: you need to specify the processed name, for example: {"Water_Temperature_2_C": "T_W"}. rename_cols: dict[str, str] | None = None #: Directory for the scenarios. #: Not included in config declaration, passed by main Config object scenarios_path: Path = Field(exclude=True)
[docs] def model_post_init(self, _: Any) -> None: """Ensure that the CSV file exists. :raises FileNotFoundError: If file does not exist. """ if not self.abs_path.exists(): msg = "Scenario file does not exist" raise FileNotFoundError(msg)
@property def abs_path(self) -> Path: """Absolute file path of the scenario.""" return (self.scenarios_path / self.path).resolve()
[docs] class ScenarioManager(ABC):
[docs] @abstractmethod def get_scenario_state(self, n_steps: int, columns: list[str] | None = None) -> dict[str, np.ndarray]: """Get scenario values for the current time step. :param n_steps: Current time step. :param columns: Optional list of column names to return. If None, returns all columns. :return: Dictionary with scenario data for the requested columns. """ raise NotImplementedError
[docs] @abstractmethod def get_scenario_state_with_duration( self, n_step: int, duration: int, columns: list[str] | None = None ) -> dict[str, np.ndarray]: """Get scenario values for the interval [n_step, n_step+duration]. :param n_step: Current time step. :param duration: Additional amount of steps in interval. :param columns: Optional list of column names to return. If None, returns all columns. :return: Dictionary with scenario data for the requested columns. """ raise NotImplementedError
[docs] class CsvScenarioManager(ScenarioManager): """ScenarioManager class for loading scenario data from CSV files.""" def __init__( self, scenario_configs: list[ConfigCsvScenario], start_time: datetime, end_time: datetime, total_time: float, resample_time: float = 0.0, seed: int | None = None, ) -> None: self._data: pd.DataFrame self.scenario_configs: list[ConfigCsvScenario] = scenario_configs self.start_time = start_time self.end_time = end_time self.total_time = total_time self.resample_time = resample_time self.seed = seed self.load_data()
[docs] def load_data(self) -> None: """Load scenario data by calling 'scenario_from_csv' with the ConfigCsvScenario objects""" random = np.random.default_rng(self.seed) if self.seed is not None else None self.scenarios = timeseries.scenario_from_csv( scenario_configs=self.scenario_configs, start_time=self.start_time, end_time=self.end_time, total_time=self.total_time, random=random, resample_time=self.resample_time, prefix_renamed=True, ) self.total_length = len(self.scenarios)
def _validate_columns(self, columns: list[str] | None) -> list[str]: """Validate and return the list of columns to retrieve. :param columns: Requested column names, or None for all columns. :return: List of valid column names to retrieve. :raises KeyError: If any requested column is not found in the scenario data. """ if columns is None: return list(self.scenarios.columns) missing_cols = set(columns) - set(self.scenarios.columns) if missing_cols: available_cols = list(self.scenarios.columns) msg = ( f"Requested scenario columns {sorted(missing_cols)} not found in loaded scenario data. " f"Available columns: {available_cols}" ) raise KeyError(msg) return columns
[docs] def get_scenario_state(self, n_steps: int, columns: list[str] | None = None) -> dict[str, np.ndarray]: if n_steps >= self.total_length: msg = f"n_steps {n_steps} is out of bounds for scenarios with length {len(self.scenarios)}" raise IndexError(msg) valid_columns = self._validate_columns(columns) vals = self.scenarios.loc[self.scenarios.index[n_steps], valid_columns] if isinstance(vals, (int, float)): # Single column case - vals is a scalar return {valid_columns[0]: np.array([vals])} return {name: np.array([val]) for name, val in vals.to_dict().items()}
[docs] def get_scenario_state_with_duration( self, n_step: int, duration: int, columns: list[str] | None = None ) -> dict[str, np.ndarray]: end_index = n_step + duration if end_index > self.total_length: msg = ( f"Requested data from {n_step} to {end_index} ({duration} steps) " f"but only {self.total_length} steps available. " f"Shortfall: {end_index - self.total_length} steps." ) raise IndexError(msg) valid_columns = self._validate_columns(columns) return self.scenarios.loc[self.scenarios.index[n_step:end_index], valid_columns].to_dict()