Source code for graphnet.utilities.config.base_config
"""Base config class(es)."""
from collections import OrderedDict
import inspect
import io
import sys
from typing import Any, Callable, Dict, Optional

from pydantic import BaseModel
import ruamel.yaml as yaml
CONFIG_FILES_SUFFIXES = (".yml", ".yaml")
[docs]
class BaseConfig(BaseModel):
    """Base class for Configs.

    Adds YAML (de)serialisation on top of the `BaseModel` fields: `load`
    builds an instance from a YAML file, `dump` writes one (or returns the
    YAML text), and `as_dict` provides the dict representation that `dump`
    serialises.
    """

    @classmethod
    def load(cls, path: str) -> "BaseConfig":
        """Load BaseConfig from `path`.

        Args:
            path: Path to a YAML file; must end in one of
                `CONFIG_FILES_SUFFIXES`.

        Returns:
            An instance of `cls` constructed from the top-level mapping in
            the YAML file, passed as keyword arguments.

        Raises:
            AssertionError: If `path` does not have a YAML file suffix.
        """
        # NOTE(review): `assert` is stripped under `python -O`; it is kept
        # here so that callers still see the same exception type.
        assert path.endswith(
            CONFIG_FILES_SUFFIXES
        ), "Please specify YAML config file."

        yaml_ = yaml.YAML(typ="safe", pure=True)
        with open(path, "r") as f:
            config_dict = yaml_.load(f)

        return cls(**config_dict)

    def dump(self, path: Optional[str] = None) -> Optional[str]:
        """Save BaseConfig to `path` as YAML file, or return as string.

        Args:
            path: Destination file. If it does not end in one of
                `CONFIG_FILES_SUFFIXES`, the first suffix is appended. If
                omitted (or empty), nothing is written to disk.

        Returns:
            `None` when the config was written to `path`; otherwise the YAML
            document as a string.
        """
        # `as_dict` nests the fields under the class name; only the inner
        # mapping is serialised.
        config_dict = self.as_dict()[self.__class__.__name__]
        yaml_ = yaml.YAML(typ="safe", pure=True)

        if path:
            if not path.endswith(CONFIG_FILES_SUFFIXES):
                path += CONFIG_FILES_SUFFIXES[0]
            with open(path, "w") as f:
                yaml_.dump(config_dict, f)
            return None

        # `YAML.dump` writes to the given stream and returns `None`, so the
        # text has to be captured in an in-memory buffer in order to hand it
        # back to the caller as the docstring and return annotation promise.
        buffer = io.StringIO()
        yaml_.dump(config_dict, buffer)
        return buffer.getvalue()

    def as_dict(self) -> Dict[str, Dict[str, Any]]:
        """Represent BaseConfig as a dict.

        This builds on `BaseModel.dict()` but can be overwritten.

        Returns:
            A single-entry dict mapping the class name to the dict of fields
            produced by `self.dict()`.
        """
        return {self.__class__.__name__: self.dict()}
[docs]
def get_all_argument_values(
    fn: Callable, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
    """Return dict of all argument values to `fn`, including defaults.

    Args:
        fn: Callable whose signature is inspected.
        *args: Positional arguments as they would be passed to `fn`
            (excluding `self`).
        **kwargs: Keyword arguments as they would be passed to `fn`.

    Returns:
        Ordered mapping from parameter name to value. Values are taken, in
        increasing priority, from the signature defaults, `args`, and
        `kwargs`. Parameters without a default that are not supplied hold
        `inspect.Parameter.empty`. `self`, `*args` and `**kwargs` parameters
        are not recorded; surplus positional arguments (those that `fn`
        would collect in `*args`) are dropped, while every entry in `kwargs`
        is stored, including names absent from the signature.
    """
    skipped_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    # Get all default argument values
    cfg: Dict[str, Any] = OrderedDict()
    positional_names = []
    for key, param in inspect.signature(fn).parameters.items():
        # Don't save `self`, `*args`, or `**kwargs`
        if key == "self" or param.kind in skipped_kinds:
            continue
        cfg[key] = param.default
        if param.kind in positional_kinds:
            positional_names.append(key)

    # Add positional arguments. Only parameters that can actually be bound
    # positionally are candidates; otherwise surplus values destined for
    # `*args` would spill into keyword-only parameters.
    for key, val in zip(positional_names, args):
        cfg[key] = val

    # Add keyword arguments
    cfg.update(kwargs)
    return cfg