"""Class(es) for coarsening operations (i.e., clustering, or local pooling)."""
from abc import abstractmethod
from typing import List, Optional, Union
from copy import deepcopy
import torch
from torch import LongTensor, Tensor
from torch_geometric.data import Data, Batch
from sklearn.cluster import DBSCAN
from graphnet.models.components.pool import (
    group_by,
    avg_pool,
    max_pool,
    min_pool,
    sum_pool,
    avg_pool_x,
    max_pool_x,
    min_pool_x,
    sum_pool_x,
    std_pool_x,
)
from graphnet.models import Model
# Utility method(s)
from torch_geometric.utils import degree
# NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]
# TODO:  Remove once bumping to torch_geometric>=2.1.0
#       See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md] # noqa: E501
def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
    # noqa: D401
    r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
    Args:
        edge_index (Tensor): The edge_index tensor. Must be ordered.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered.
    :rtype: :class:`List[Tensor]`
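
    Example (illustrative; two graphs with 3 and 2 nodes, respectively):
        >>> edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
        >>> batch = torch.tensor([0, 0, 0, 1, 1])
        >>> unbatch_edge_index(edge_index, batch)  # doctest: +SKIP
        (tensor([[0, 1], [1, 2]]), tensor([[0], [1]]))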
    """
    deg = degree(batch, dtype=torch.int64)
    ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
    edge_batch = batch[edge_index[0]]
    edge_index = edge_index - ptr[edge_batch]
    sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
    return edge_index.split(sizes, dim=1)
class Coarsening(Model):
    """Base class for coarsening operations."""
    # Class variables
    reduce_options = {
        "avg": (avg_pool, avg_pool_x),
        "min": (min_pool, min_pool_x),
        "max": (max_pool, max_pool_x),
        "sum": (sum_pool, sum_pool_x),
    }
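    # Each entry maps a reduce name to a pair of pooling functions: the first
    # pools an entire `Data`/`Batch` object (nodes, `batch` vector, and
    # edges), the second pools a bare feature tensor per cluster.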
    def __init__(
        self,
        reduce: str = "avg",
        transfer_attributes: bool = True,
    ):
        """Construct `Coarsening`."""
        assert reduce in self.reduce_options
        (
            self._reduce_method,
            self._attribute_reduce_method,
        ) = self.reduce_options[reduce]
        self._do_transfer_attributes = transfer_attributes
        # Base class constructor
        super().__init__()
    @abstractmethod
    def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
        """Cluster nodes in `data` by assigning a cluster index to each."""
    def _additional_features(
        self, cluster: LongTensor, data: Batch
    ) -> Optional[Tensor]:
        """Perform additional poolings of feature tensor `x` on `data`.

        By default no additional features are computed, and only the nominal
        reduce method is applied to `x`. This method can be overridden for
        bespoke coarsening operations.
        """
    def _transfer_attributes(
        self, cluster: LongTensor, original_data: Batch, pooled_data: Batch
    ) -> Batch:
        """Transfer attributes on `original_data` to `pooled_data`."""
        # Check(s)
        if not self._do_transfer_attributes:
            return pooled_data
        attributes = list(original_data._store.keys())
        batch: Optional[LongTensor] = original_data.batch
        for attr in attributes:
            if attr not in pooled_data._store:
                values: Tensor = getattr(original_data, attr)
                # Infer whether the attribute is a node-level tensor (one
                # entry per node) as opposed to a graph-level attribute.
                attr_is_node_level_tensor = False
                if isinstance(values, Tensor):
                    if batch is None:
                        attr_is_node_level_tensor = (
                            values.dim() > 1 or values.size(dim=0) > 1
                        )
                    else:
                        attr_is_node_level_tensor = (
                            values.size() == batch.size()
                        )
                # Node-level tensors are pooled cluster-wise; cluster indices
                # are globally unique, so a single dummy graph suffices.
                if attr_is_node_level_tensor:
                    values = self._attribute_reduce_method(
                        cluster,
                        values,
                        batch=torch.zeros_like(values, dtype=torch.long),
                    )[0]
                setattr(pooled_data, attr, values)
        return pooled_data
    def forward(self, data: Union[Data, Batch]) -> Union[Data, Batch]:
        """Perform coarsening operation."""
        # Get tensor of cluster indices for each node.
        cluster: LongTensor = self._perform_clustering(data)
        # Check whether a graph has already been built. Otherwise, set a dummy
        # connectivity, as this is required by pooling functions.
        edge_index = data.edge_index
        if edge_index is None:
            data.edge_index = torch.empty((2, 0), dtype=torch.int64)
        # Pool `data` object, including `x`, `batch`, and `edge_index`.
        pooled_data: Batch = self._reduce_method(cluster, data)
        # Optionally overwrite feature tensor
        x = self._additional_features(cluster, data)
        if x is not None:
            pooled_data.x = torch.cat(
                (
                    pooled_data.x,
                    x,
                ),
                dim=1,
            )
        # Reset `edge_index` if necessary.
        if edge_index is None:
            data.edge_index = edge_index
            pooled_data.edge_index = edge_index
        # Transfer attributes on `data`, pooling as required.
        pooled_data = self._transfer_attributes(cluster, data, pooled_data)
        # Reconstruct batch attributes (`_slice_dict` and `_inc_dict`).
        if isinstance(data, Batch):  # if a Batch object
            pooled_data = self._reconstruct_batch(data, pooled_data)
        return pooled_data 
    def _reconstruct_batch(self, original: Data, pooled: Data) -> Data:
        pooled = self._add_slice_dict(original, pooled)
        pooled = self._add_inc_dict(original, pooled)
        return pooled
    def _add_slice_dict(self, original: Data, pooled: Data) -> Data:
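        # `_slice_dict` is PyG's internal bookkeeping used by
        # `Batch.to_data_list` to split attributes back into per-graph
        # tensors; node-level entries must reflect the pooled node counts.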
        # Copy original slice_dict and count nodes in each
        # graph in pooled batch
        slice_dict = deepcopy(original._slice_dict)
        _, counts = torch.unique_consecutive(pooled.batch, return_counts=True)
        # Reconstruct the entry in `slice_dict` for pulsemaps; only these are
        # affected by pooling.
        pulsemap_slice = [0]
        for i in range(len(counts)):
            pulsemap_slice.append(pulsemap_slice[i] + counts[i].item())
        # Identify pulsemap entries in `slice_dict` and point them at
        # `pulsemap_slice`. Graph-level fields have one entry per graph, so
        # their slice vector ends at `_num_graphs` and is left unchanged.
        for field in slice_dict.keys():
            if slice_dict[field][-1] != original._num_graphs:
                slice_dict[field] = pulsemap_slice
        pooled._slice_dict = slice_dict
        return pooled
    def _add_inc_dict(self, original: Data, pooled: Data) -> Data:
        # not changed by coarsening
        pooled._inc_dict = deepcopy(original._inc_dict)
        return pooled 
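
# Subclasses only need to implement `_perform_clustering`. As a hypothetical
# sketch (not part of graphnet), a coarsening that merges each graph into a
# single node could simply reuse the batch vector as the cluster index:
#
#     class GraphLevelCoarsening(Coarsening):  # hypothetical
#         def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
#             if data.batch is not None:
#                 return data.batch
#             return torch.zeros(data.num_nodes, dtype=torch.long)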
class AttributeCoarsening(Coarsening):
    """Coarsen pulses based on specified attributes."""
    def __init__(
        self,
        attributes: List[str],
        reduce: str = "avg",
        transfer_attributes: bool = True,
    ):
        """Construct `SimpleCoarsening`."""
        self._attributes = attributes
        # Base class constructor
        super().__init__(reduce, transfer_attributes)
    def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
        """Cluster nodes in `data` by assigning a cluster index to each."""
        dom_index = group_by(data, self._attributes)
        return dom_index 
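
# Usage sketch (illustrative; assumes the attribute names appear in the
# `features` list carried by the `Data`/`Batch` object):
#
#     coarsening = AttributeCoarsening(
#         attributes=["dom_x", "dom_y", "dom_z"],
#         reduce="max",
#     )
#     pooled = coarsening(data)  # one node per unique (dom_x, dom_y, dom_z)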
class DOMCoarsening(Coarsening):
    """Coarsen pulses to DOM-level."""
    def __init__(
        self,
        reduce: str = "avg",
        transfer_attributes: bool = True,
        keys: Optional[List[str]] = None,
    ):
        """Cluster pulses on the same DOM."""
        super().__init__(reduce, transfer_attributes)
        if keys is None:
            self._keys = [
                "dom_x",
                "dom_y",
                "dom_z",
                "rde",
                "pmt_area",
            ]
        else:
            self._keys = keys
    def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
        """Cluster nodes in `data` by assigning a cluster index to each."""
        dom_index = group_by(data, self._keys)
        return dom_index 
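
# Usage sketch (illustrative): with the default keys, pulses sharing a DOM
# position (and PMT properties) collapse to a single node whose features are
# combined with the chosen reduce method:
#
#     coarsening = DOMCoarsening(reduce="avg")
#     pooled = coarsening(data)  # one node per DOM, features averaged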
class CustomDOMCoarsening(DOMCoarsening):
    """Coarsen pulses to DOM-level with additional attributes."""
    def _additional_features(self, cluster: LongTensor, data: Data) -> Tensor:
        """Perform Additional poolings of feature tensor `x` on `data`."""
        batch = data.batch
        features = data.features
        if batch is not None:
            # For a `Batch`, `features` is a list of per-graph feature lists,
            # all identical; use the first graph's list.
            features = features[0]
        ix_time = features.index("dom_time")
        ix_charge = features.index("charge")
        time = data.x[:, ix_time]
        charge = data.x[:, ix_charge]
        x = torch.stack(
            (
                min_pool_x(cluster, time, batch)[0],
                max_pool_x(cluster, time, batch)[0],
                std_pool_x(cluster, time, batch)[0],
                min_pool_x(cluster, charge, batch)[0],
                max_pool_x(cluster, charge, batch)[0],
                std_pool_x(cluster, charge, batch)[0],
                sum_pool_x(cluster, torch.ones_like(charge), batch)[
                    0
                ],  # Num. nodes (pulses) per cluster (DOM)
            ),
            dim=1,
        )
        return x 
class DOMAndTimeWindowCoarsening(Coarsening):
    """Coarsen pulses to DOM-level, with additional time-window clustering."""
    def __init__(
        self,
        time_window: float,
        reduce: str = "avg",
        transfer_attributes: bool = True,
        keys: Optional[List[str]] = None,
        time_key: str = "dom_time",
    ):
        """Cluster pulses on the same DOM within `time_window`."""
        super().__init__(reduce, transfer_attributes)
        # Avoid a mutable default argument; fall back to the standard DOM
        # keys.
        if keys is None:
            keys = [
                "dom_x",
                "dom_y",
                "dom_z",
                "rde",
                "pmt_area",
            ]
        self._time_window = time_window
        self._cluster_method = DBSCAN(eps=self._time_window, min_samples=1)
        self._keys = keys
        self._time_key = time_key
    def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
        """Cluster nodes in `data` by assigning a cluster index to each."""
        dom_index = group_by(data, self._keys)
        if data.batch is not None:
            features = data.features[0]
        else:
            features = data.features
        ix_time = features.index(self._time_key)
        hit_times = data.x[:, ix_time]
        # Scale up dom_index to make sure clusters are well separated
        times_and_domids = torch.stack(
            [
                hit_times,
                dom_index * self._time_window * 10,
            ]
        ).T
        clusters = torch.tensor(
            self._cluster_method.fit_predict(times_and_domids.cpu()),
            device=hit_times.device,
        )
        return clusters
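
# Usage sketch (illustrative; `time_window` is in the same units as the
# `dom_time` feature):
#
#     coarsening = DOMAndTimeWindowCoarsening(time_window=10.0)
#     pooled = coarsening(data)  # merges pulses on the same DOM whose times
#                                # fall in the same DBSCAN cluster (eps=10.0)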