from __future__ import annotations
import abc
import json
import pathlib
from logging import getLogger
from typing import TYPE_CHECKING
from attrs import asdict
from eta_ctrl.util import log_add_filehandler
if TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm
from eta_ctrl.config import Config, ConfigRun
# Module-level logger used by the helpers in this file.
log = getLogger(__name__)
[docs]
def log_to_file(config: Config, config_run: ConfigRun) -> None:
    """Register a log file handler for ``config_run.log_output_path`` if enabled.

    The handler is added through ``log_add_filehandler`` only when
    ``config.settings.log_to_file`` is truthy; presumably this makes terminal
    log output also land in that file. If adding the handler raises, the
    error is logged with its traceback instead of being propagated.

    :param config: Configuration to figure out the logging settings.
    :param config_run: Configuration for this optimization run.
    """
    target = config_run.log_output_path
    if not config.settings.log_to_file:
        return

    try:
        log_add_filehandler(filename=target)
    except Exception:
        # Best effort: failing to set up file logging must not abort the run.
        log.exception("Log file could not be created.")
[docs]
def log_run_info(config: Config, config_run: ConfigRun) -> None:
"""Save run configuration to the run_info file.
:param config: Configuration for the framework.
:param config_run: Configuration for this optimization run.
"""
with config_run.run_info_path.open("w") as f:
class Encoder(json.JSONEncoder):
def default(self, o: object) -> object:
if isinstance(o, pathlib.Path):
return str(o)
if isinstance(o, abc.ABCMeta):
return None
return repr(o)
try:
json.dump({**asdict(config_run), **config.model_dump()}, f, indent=4, cls=Encoder)
log.info("Log file successfully created.")
except TypeError:
log.warning("Log file could not be created because of non-serializable input in config.")
[docs]
def log_net_arch(model: BaseAlgorithm, config_run: ConfigRun) -> None:
    """Store network architecture or policy information in a file.

    The string representation of ``model.policy`` is written to
    ``config_run.net_arch_path`` unless that file already exists (in which case
    this is only logged). If ``model.policy`` is ``None`` (e.g. the model has not
    been initialized) or its class is exactly ``NoPolicy``, nothing is written
    and the function returns without raising.

    :param model: The algorithm whose network architecture is stored.
    :param config_run: Optimization run configuration (which contains info about the file to store info in).
    """
    from .sb3_extensions.policies import NoPolicy  # noqa: PLC0415

    net_arch_path = pathlib.Path(config_run.net_arch_path)
    if net_arch_path.exists():
        log.info(f"Net arch / Policy information already exists in {config_run.net_arch_path}")
        return

    policy = model.policy
    # Exact class comparison (not isinstance): subclasses of NoPolicy are still written.
    if policy is None or policy.__class__ is NoPolicy:
        return

    # Explicit encoding so the file content does not depend on the platform locale.
    with net_arch_path.open("w", encoding="utf-8") as f:
        f.write(str(policy))
    log.info(f"Net arch / Policy information store successfully in: {config_run.net_arch_path}.")