from __future__ import annotations
import abc
from collections.abc import Sequence
from logging import getLogger
from typing import TYPE_CHECKING
from eta_nexus.connection_manager import ConnectionManager
from eta_ctrl.envs import BaseEnv
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from eta_ctrl.config import ConfigRun
from eta_ctrl.util.type_annotations import Path, TimeStep
log = getLogger(__name__)
[docs]
class LiveEnv(BaseEnv, abc.ABC):
"""Base class for Live environments.
The class will create an ETA Nexus ConnectionManager instance
and provide facilities to automatically read step results and reset the connection.
Additionally to required class attribute from `BaseEnv`, `LiveEnv` requires
the name of the connection manager configuration file as a class attribute:
- **config_name**: Name of the connection manager configuration.
:param env_id: Identification for the environment, useful when creating multiple environments.
:param config_run: Configuration of the optimization run.
:param verbose: Verbosity to use for logging.
:param callback: callback which should be called after each episode.
:param episode_duration: Duration of the episode in seconds.
:param sampling_time: Duration of a single time sample / time step in seconds.
:param max_errors: Maximum number of connection errors before interrupting the optimization process.
:param render_mode: Renders the environments to help visualise what the agent see, examples
modes are "human", "rgb_array", "ansi" for text.
:param kwargs: Other keyword arguments (for subclasses).
"""
@property
@abc.abstractmethod
def config_name(self) -> str:
"""Name of the connection manager configuration.
Needs to be implemented for each subclass as a class attribute.
"""
raise NotImplementedError
def __init__(
self,
env_id: int,
config_run: ConfigRun,
verbose: int = 2,
callback: Callable | None = None,
*,
episode_duration: TimeStep | str,
sampling_time: TimeStep | str,
max_errors: int = 10,
render_mode: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(
env_id=env_id,
config_run=config_run,
verbose=verbose,
callback=callback,
episode_duration=episode_duration,
sampling_time=sampling_time,
render_mode=render_mode,
**kwargs,
)
#: Instance of the Live Connector.
self.connection_manager: ConnectionManager
#: Path or Dict to initialize the live connector.
self.connection_manager_config: Path | Sequence[Path] | dict[str, Any] | None = (
self.path_env / f"{self.config_name}.json"
)
#: Maximum error count when connections in live connector are aborted.
self.max_error_count: int = max_errors
def __str__(self) -> str:
"""Human-readable string representation of LiveEnv."""
base_str = super().__str__()
config_name = self.config_name
return f"{base_str}, Live config: {config_name}"
def __repr__(self) -> str:
"""Developer-friendly string representation of LiveEnv."""
base_repr = super().__repr__()
# Remove the closing parenthesis to add our info
base_repr = base_repr.rstrip(")")
return f"{base_repr}, config_name='{self.config_name}', max_error_count={self.max_error_count})"
def _init_connection_manager(self, files: Path | Sequence[Path] | dict[str, Any] | None = None) -> None:
"""Initialize the live connector object. Make sure to call _names_from_state before this or to otherwise
initialize the names array.
:param files: Path or Dict to initialize the connection directly from JSON configuration files or a config
dictionary.
"""
_files = self.connection_manager_config if files is None else files
self.connection_manager_config = _files
if _files is None:
msg = "Configuration files or a dictionary must be specified before the connector can be initialized."
raise TypeError(msg)
if isinstance(_files, dict):
self.connection_manager = ConnectionManager.from_dict(
step_size=self.sampling_time,
max_error_count=self.max_error_count,
**_files,
)
elif isinstance(_files, Sequence):
self.connection_manager = ConnectionManager.from_config(
*_files, step_size=self.sampling_time, max_error_count=self.max_error_count
)
else:
self.connection_manager = ConnectionManager.from_config(
_files, step_size=self.sampling_time, max_error_count=self.max_error_count
)
[docs]
def _step(self) -> tuple[float, bool, bool, dict]:
"""Perform one internal time step and return core step results.
This is called for every event or for every time step during the simulation/optimization run.
It should utilize the actions as supplied by the agent to determine the new
state of the environment, which are available in the state dictionary.
This also updates self.state and self.state_log to store current state information.
.. note::
This function always returns 0 reward. Therefore, it must be extended if it is to be used with reinforcement
learning agents. If you need to manipulate actions (discretization, policy shaping, ...) do this before
calling this function. If you need to manipulate observations and rewards, do this after calling this
function.
:return: A tuple containing:
* **reward**: The value of the reward function. This is just one floating point value.
* **terminated (bool)**: Whether the agent reaches the terminal state (as defined under the MDP of the task)
which can be positive or negative. An example is reaching the goal state or moving into the lava from
the Sutton and Barto Gridworld. If true, the Vectorizer will call :meth:`reset`.
* **truncated (bool)**: Whether the truncation condition outside the scope of the MDP is satisfied
(i.e. the episode ended). Typically, this is a timelimit, but could also be used to indicate an agent
physically going out of bounds. Can be used to end the episode prematurely before a terminal state is
reached. If true, the Vectorizer will call :meth:`reset`.
* **info**: Provide some additional info about the state of the environment. The contents of this may
be used for logging purposes in the future but typically do not currently serve a purpose.
.. note::
Stable Baselines3 combines terminated and truncated with a logical OR to trigger
the automatic environment reset. Implement both flags for compatibility.
:meta public:
"""
# Set the external inputs in the live connector and read out the external outputs
results = self.connection_manager.step(value=self.get_external_inputs())
self.set_external_outputs(external_outputs=results)
return 0, False, False, {}
[docs]
def _reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Reset the environment to an initial internal state, returning an initial observation and info.
This method generates a new starting state often with some randomness to ensure that the agent explores the
state space and learns a generalised policy about the environment. This randomness can be controlled
with the ``seed`` parameter otherwise if the environment already has a random number generator and
:meth:`reset` is called with ``seed=None``, the RNG is not reset. When using the environment in conjunction with
*stable_baselines3*, the vectorized environment will take care of seeding your custom environment automatically.
:param seed: The seed for initializing any randomized components of the state.
Subclasses should use this for reproducible randomness in their state init
:param options: Additional information to specify how the environment is reset (optional,
depending on the specific environment) (default: None)
:return: Info dictionary containing information about the initial state.
The initial observations are automatically filtered from the internal state
by the public reset method and must not be returned here.
.. note::
The base implementation initializes external outputs from the live connector
without using the seed. Subclasses should use the seed parameter
for any additional randomized state observations they implement.
:meta public:
"""
self._init_connection_manager()
# Read out the start conditions from LiveConnect and store the results
start_obs_names = [self.state_config.map_ext_ids[name] for name in self.state_config.ext_outputs]
results = self.connection_manager.read(*start_obs_names)
self.set_external_outputs(external_outputs=results)
return {}
[docs]
def close(self) -> None:
"""Close the environment. This should always be called when an entire run is finished. It should be used to
close any resources (i.e. simulation models) used by the environment.
Default behavior for the connection_manager environment is to do nothing.
"""
if hasattr(self, "connection_manager"):
self.connection_manager.close()