# Source code for graphnet.deployment.deployment_module

"""Class(es) for deploying GraphNeT models in icetray as I3Modules."""
from abc import abstractmethod
from typing import Any, List, Union, Dict

import numpy as np
from torch import Tensor, load
from torch_geometric.data import Data, Batch

from graphnet.models import Model
from graphnet.utilities.config import ModelConfig
from graphnet.utilities.logging import Logger


class DeploymentModule(Logger):
    """Base DeploymentModule for GraphNeT.

    Contains standard methods for loading models and doing inference with
    them. Experiment-specific implementations may overwrite methods and
    should define `__call__`.
    """

    def __init__(
        self,
        model_config: Union[ModelConfig, str],
        state_dict: Union[Dict[str, Tensor], str],
        device: str = "cpu",
        prediction_columns: Union[List[str], str, None] = None,
    ):
        """Construct DeploymentModule.

        Arguments:
            model_config: A model configuration, or a path to one.
            state_dict: A state dict for the model, or a path to one (a
                path ending in ".ckpt" is treated as a checkpoint whose
                weights live under its "state_dict" entry).
            device: The computational device to use. Defaults to "cpu".
            prediction_columns: Column names for each column in model output.
                A single string is accepted and wrapped in a list. Defaults
                to the model's `prediction_labels`.
        """
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Set Member Variables
        # Remember the device so `_inference` can put the input data on the
        # same device as the model parameters.
        self.device = device
        self.model = self._load_model(
            model_config=model_config, state_dict=state_dict
        )
        self.prediction_columns = self._resolve_prediction_columns(
            prediction_columns
        )

        # Set model to inference mode.
        self.model.inference()

        # Move model to device
        self.model.to(device)

    # NOTE(review): `@abstractmethod` is only enforced at instantiation if
    # `Logger` (or another base) uses `ABCMeta` — confirm.
    @abstractmethod
    def __call__(self, input_data: Any) -> Any:
        """Define here how the module acts on a file/data stream."""

    def _load_model(
        self,
        model_config: Union[ModelConfig, str],
        state_dict: Union[Dict[str, Tensor], str],
    ) -> Model:
        """Load `Model` from config and insert learned weights.

        Arguments:
            model_config: A `ModelConfig`, or a path to one.
            state_dict: An in-memory state dict, a path to a ".ckpt"
                checkpoint (weights are read from its "state_dict" entry),
                or any other value, which is passed unchanged to
                `Model.load_state_dict`.

        Returns:
            The model built from `model_config` with the weights loaded. It
            is still on the CPU when loaded from a ".ckpt" file; `__init__`
            moves it to the requested device afterwards.
        """
        # NOTE(review): `trust=True` presumably lets the config instantiate
        # arbitrary code — only deploy configs from trusted sources.
        model = Model.from_config(model_config, trust=True)
        if isinstance(state_dict, str) and state_dict.endswith(".ckpt"):
            # Map tensors to the CPU so checkpoints written on a GPU can be
            # loaded on CPU-only hosts; `__init__` moves the model to the
            # requested device after loading.
            # NOTE: `torch.load` unpickles the file; only load trusted
            # checkpoints.
            ckpt = load(state_dict, map_location="cpu")
            model.load_state_dict(ckpt["state_dict"])
        else:
            model.load_state_dict(state_dict)
        return model

    def _resolve_prediction_columns(
        self, prediction_columns: Union[List[str], str, None]
    ) -> List[str]:
        """Return the column names to use for the model output.

        Arguments:
            prediction_columns: Explicit column names, a single column name,
                or `None`.

        Returns:
            `[prediction_columns]` for a single string, `prediction_columns`
            itself for a list, or `self.model.prediction_labels` for `None`.
        """
        if prediction_columns is None:
            return self.model.prediction_labels
        if isinstance(prediction_columns, str):
            return [prediction_columns]
        return prediction_columns

    def _inference(self, data: Union[Data, Batch]) -> List[np.ndarray]:
        """Apply model to a single event or batch of events `data`.

        Args:
            data: A `Data` or `Batch` object - either a single output of a
                `GraphDefinition` or a batch of them.

        Returns:
            A List of numpy arrays, each representing the output from the
            `Task`s that the model contains.
        """
        # The model was moved to `self.device` in `__init__`; the input must
        # live on the same device or the forward pass fails.
        data = data.to(self.device)

        # Perform inference
        output = self.model(data=data)

        # `.numpy()` only works on CPU tensors, so copy each task output off
        # the device first (`.cpu()` is a no-op for tensors already there).
        return [tensor.detach().cpu().numpy() for tensor in output]