# Source code for graphnet.utilities.config.base_config
"""Base config class(es)."""
from abc import abstractmethod
from collections import OrderedDict
import inspect
import io
import sys
from typing import Any, Callable, Dict, Optional

from pydantic import BaseModel
import ruamel.yaml as yaml
CONFIG_FILES_SUFFIXES = (".yml", ".yaml")
[docs]
class BaseConfig(BaseModel):
    """Base class for Configs."""

    @classmethod
    def load(cls, path: str) -> "BaseConfig":
        """Load BaseConfig from `path`.

        Args:
            path: Path to a YAML file; must end with one of
                `CONFIG_FILES_SUFFIXES`.

        Returns:
            An instance of `cls` built by passing the file's top-level
            mapping as keyword arguments.

        Raises:
            AssertionError: If `path` does not have a YAML suffix.
        """
        assert path.endswith(
            CONFIG_FILES_SUFFIXES
        ), "Please specify YAML config file."
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            # `typ="safe"` restricts loading to plain YAML types.
            yaml_ = yaml.YAML(typ="safe", pure=True)
            config_dict = yaml_.load(f)
        return cls(**config_dict)

    def dump(self, path: Optional[str] = None) -> Optional[str]:
        """Save BaseConfig to `path` as YAML file, or return as string.

        Args:
            path: Destination file. If it lacks a YAML suffix, the first
                entry of `CONFIG_FILES_SUFFIXES` is appended. If falsy
                (`None` or `""`), nothing is written to disk.

        Returns:
            `None` when the config was written to `path`; otherwise the
            YAML document as a string.
        """
        config_dict = self.as_dict()[self.__class__.__name__]
        yaml_ = yaml.YAML(typ="safe", pure=True)
        if path:
            if not path.endswith(CONFIG_FILES_SUFFIXES):
                path += CONFIG_FILES_SUFFIXES[0]
            with open(path, "w", encoding="utf-8") as f:
                yaml_.dump(config_dict, f)
            return None
        # ruamel's `YAML.dump` always writes to a stream and returns None,
        # so render into an in-memory buffer to return the documented str.
        buffer = io.StringIO()
        yaml_.dump(config_dict, buffer)
        return buffer.getvalue()

    def as_dict(self) -> Dict[str, Dict[str, Any]]:
        """Represent BaseConfig as a dict.

        This builds on `BaseModel.dict()` but can be overwritten.

        Returns:
            A single-entry dict mapping the concrete class name to the
            field values returned by `self.dict()`.
        """
        return {self.__class__.__name__: self.dict()}
[docs]
def get_all_argument_values(
    fn: Callable, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
    """Return dict of all argument values to `fn`, including defaults.

    Args:
        fn: Callable whose signature is inspected.
        *args: Positional arguments as they would be passed to `fn`
            (excluding `self`).
        **kwargs: Keyword arguments as they would be passed to `fn`.

    Returns:
        An `OrderedDict` mapping each named parameter of `fn` (except
        `self`, `*args` and `**kwargs`) to its effective value. Parameters
        that have no default and were not supplied map to
        `inspect.Parameter.empty`. Keyword arguments not in the signature
        are appended as-is.
    """
    # Start from the default value of every named parameter.
    cfg = OrderedDict()
    positional_keys = []
    for key, param in inspect.signature(fn).parameters.items():
        # Don't save `self`, `*args`, or `**kwargs`
        if key == "self" or param.kind in (
            param.VAR_POSITIONAL,
            param.VAR_KEYWORD,
        ):
            continue
        cfg[key] = param.default
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            positional_keys.append(key)

    # Positional values can only bind to positional parameters; any surplus
    # goes to `fn`'s `*args` and must never leak into keyword-only params.
    for key, val in zip(positional_keys, args):
        cfg[key] = val

    # Keyword arguments override defaults.
    cfg.update(kwargs)
    return cfg