"""Config classes for the `graphnet.data.dataset` module."""
import warnings
from abc import ABCMeta
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Union,
)
from graphnet.utilities.config.base_config import (
    BaseConfig,
    get_all_argument_values,
)
from graphnet.utilities.config.parsing import traverse_and_apply
from .model_config import ModelConfig
if TYPE_CHECKING:
    from graphnet.models import Model
# File suffix (the text after the last "." in a dataset path) -> name of the
# backend used to read it. Consulted by `DatasetConfig._backend`.
BACKEND_LOOKUP: Dict[str, str] = dict(
    db="sqlite",
    parquet="parquet",
)
class DatasetConfig(BaseConfig):
    """Configuration for all `Dataset`s."""

    # Fields
    path: Union[str, List[str]]
    pulsemaps: Union[str, List[str]]
    features: List[str]
    truth: List[str]
    node_truth: Optional[List[str]] = None
    index_column: str = "event_no"
    truth_table: str = "truth"
    node_truth_table: Optional[str] = None
    string_selection: Optional[List[int]] = None
    selection: Optional[
        Union[
            str,
            List[str],
            List[Union[int, List[int]]],
            Dict[str, Union[str, List[str]]],
        ]
    ] = None
    loss_weight_table: Optional[str] = None
    loss_weight_column: Optional[str] = None
    loss_weight_default_value: Optional[float] = None
    seed: Optional[int] = None
    # DEPRECATION FIELD: REMOVE AT 2.0 LAUNCH
    # See https://github.com/graphnet-team/graphnet/issues/647
    graph_definition: Any = None
    data_representation: Any = None
    labels: Optional[Dict[str, Any]] = None

    def __init__(self, **data: Any) -> None:
        """Construct `DatasetConfig`.

        Can be used for dataset configuration as code, thereby making dataset
        construction more transparent and reproducible.

        Args:
            **data: Field values for the configuration. If `selection` is a
                dictionary with exactly one key, it is replaced by that key's
                value before validation.

        Examples:
            In one session, do:

            >>> dataset = Dataset(...)
            >>> dataset.config.dump()
            path: (...)
            pulsemaps:
                - (...)
            (...)
            >>> dataset.config.dump("dataset.yml")

            In another session, you can then do:

            >>> dataset = Dataset.from_config("dataset.yml")

            # Uniquely for `DatasetConfig`, you can also define and load
            # multiple datasets
            >>> dataset.config.selection = {
                "train": "event_no % 2 == 0",
                "test": "event_no % 2 == 1",
            }
            >>> dataset.config.dump("dataset.yml")
            >>> datasets: Dict[str, Dataset] = Dataset.from_config(
                "dataset.yml"
            )
            >>> datasets
            {
                "train": Dataset(...),
                "test": Dataset(...),
            }

            # You can also combine multiple selections into a single, named
            # dataset
            >>> dataset.config.selection = {
                "train": [
                    "event_no % 2 == 0 & abs(pid) == 12",
                    "event_no % 2 == 0 & abs(pid) == 14",
                    "event_no % 2 == 0 & abs(pid) == 16",
                ],
                (...)
            }
            >>> dataset.config.dump("dataset.yml")
            >>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config(
                "dataset.yml"
            )
            >>> datasets
            {
                "train": EnsembleDataset(...),
                (...)
            }

            # Finally, you can still reference existing selection files in CSV
            # or JSON formats:
            >>> dataset.config.selection = {
                "train": "50000 random events ~ train_selection.csv",
                "test": "test_selection.csv",
            }
        """
        # Single-key dictionaries are unpacked to their only value. `selection`
        # is an optional field, so it may be absent from `data`; use `.get`
        # rather than indexing to avoid a spurious `KeyError`.
        selection = data.get("selection")
        if isinstance(selection, dict) and len(selection) == 1:
            data["selection"] = next(iter(selection.values()))

        # Base class constructor
        super().__init__(**data)

    @property
    def _backend(self) -> str:
        """Return the backend name implied by the suffix of `path`.

        When `path` is a list, only its first entry is inspected. The suffix
        is the text after the last "." and is looked up in `BACKEND_LOOKUP`.

        Raises:
            KeyError: If the suffix is not a key of `BACKEND_LOOKUP`.
        """
        path: str
        if isinstance(self.path, list):
            path = self.path[0]
        else:
            assert isinstance(self.path, str)
            path = self.path

        suffix = path.split(".")[-1]
        try:
            return BACKEND_LOOKUP[suffix]
        except KeyError:
            # Raise the descriptive message directly instead of routing it
            # through a logging method that is not defined on this class.
            raise KeyError(
                f"Dataset at `path` {self.path} with suffix {suffix} not "
                "supported."
            ) from None

    @property
    def _dataset_class(self) -> type:
        """Return the `Dataset` class implementation for this configuration.

        Raises:
            KeyError: If the backend of `path` is not supported.
        """
        # Imported here so that loading this module does not import the
        # dataset implementations.
        from graphnet.data.dataset.sqlite import SQLiteDataset
        from graphnet.data.dataset.parquet import ParquetDataset

        dataset_class = {
            "sqlite": SQLiteDataset,
            "parquet": ParquetDataset,
        }[self._backend]
        return dataset_class

    def as_dict(self) -> Dict[str, Dict[str, Any]]:
        """Represent DatasetConfig as a dict.

        This builds on `BaseModel.dict()`, converts any `torch.dtype` values
        to strings, and wraps the output in a single-key dictionary keyed by
        the class name, making it unambiguous to identify the config type.

        Returns:
            `{<class name>: <field dictionary>}`.
        """
        config_dict = self.dict()
        config_dict = traverse_and_apply(
            obj=dict(config_dict), fn=self._parse_torch
        )
        return {self.__class__.__name__: config_dict}

    def _parse_torch(self, obj: Any) -> Any:
        """Return `str(obj)` if `obj` is a `torch.dtype`, else `obj` unchanged."""
        import torch

        if isinstance(obj, torch.dtype):
            return str(obj)
        return obj
def save_dataset_config(init_fn: Callable) -> Callable:
    """Save the arguments to `__init__` functions as member `DatasetConfig`.

    Deprecated: config saving is now done automatically for all classes
    inheriting from `Dataset`. A `DeprecationWarning` is emitted each time
    this decorator is applied.

    Args:
        init_fn: The `__init__` method to wrap.

    Returns:
        A wrapper around `init_fn`. After calling `init_fn`, the wrapper
        collects all argument values (including defaults), replaces nested
        `Model` instances with their configs and `torch.dtype` values with
        strings, and stores the result as `self._config`, a `DatasetConfig`.
    """
    warnings.warn(
        "Warning: `save_dataset_config` is deprecated. Config saving "
        "is now done automatically, for all classes inheriting from Dataset",
        DeprecationWarning,
        # Attribute the warning to the code applying the decorator rather
        # than to this module.
        stacklevel=2,
    )

    def _replace_model_instance_with_config(
        obj: Union["Model", Any]
    ) -> Union[ModelConfig, Any]:
        """Replace `Model` instances in `obj` with their `ModelConfig`.

        `torch.dtype` values are converted to their string representation;
        every other object is returned unchanged.
        """
        from graphnet.models import Model
        import torch

        if isinstance(obj, Model):
            return obj.config
        if isinstance(obj, torch.dtype):
            return str(obj)
        return obj

    @wraps(init_fn)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        """Set `DatasetConfig` after calling `init_fn`."""
        # Call wrapped method
        ret = init_fn(self, *args, **kwargs)

        # Get all argument values, including defaults
        cfg = get_all_argument_values(init_fn, *args, **kwargs)

        # Handle nested `Model`s, etc.
        cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)

        # Add `DatasetConfig` as member variables
        self._config = DatasetConfig(**cfg)
        return ret

    return wrapper