Source code for graphnet.models.easy_model

"""Suggested Model subclass that enables simple user syntax."""

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union, Type

import numpy as np
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader, SequentialSampler
from torch_geometric.data import Data
import pandas as pd
from pytorch_lightning.loggers import Logger as LightningLogger

from graphnet.training.callbacks import ProgressBar
from graphnet.models.model import Model
from graphnet.models.task import StandardLearnedTask


class EasySyntax(Model):
    """A suggested Model class that comes with simple user syntax.

    This class delivers simple user syntax for training and prediction,
    while imposing minimal constraints on structure.
    """

    def __init__(
        self,
        *,
        tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None,
    ) -> None:
        """Construct `EasySyntax`."""
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Check(s)
        if not isinstance(tasks, (list, tuple)):
            tasks = [tasks]

        # Member variable(s)
        self._tasks = ModuleList(tasks)
        self._optimizer_class = optimizer_class
        self._optimizer_kwargs = optimizer_kwargs or dict()
        self._scheduler_class = scheduler_class
        self._scheduler_kwargs = scheduler_kwargs or dict()
        self._scheduler_config = scheduler_config or dict()

        self.validate_tasks()

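    # Illustrative sketch (hypothetical names, not part of this module):
    # a concrete subclass implements the four abstract hooks declared below.
    #
    #     class MyModel(EasySyntax):
    #         def forward(self, data): ...
    #         def shared_step(self, batch, batch_idx): ...
    #         def compute_loss(self, preds, data, verbose=False): ...
    #         def validate_tasks(self): ...
    #
    #     model = MyModel(
    #         tasks=my_task,  # an instance of `StandardLearnedTask`
    #         optimizer_kwargs={"lr": 1e-3},
    #     )
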
    def compute_loss(
        self, preds: Tensor, data: List[Data], verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        raise NotImplementedError

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""
        raise NotImplementedError

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        raise NotImplementedError

    def validate_tasks(self) -> None:
        """Verify that `self._tasks` contains compatible elements."""
        raise NotImplementedError

    @staticmethod
    def _construct_trainer(
        max_epochs: int = 10,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> Trainer:
        # Map the `gpus` argument onto Lightning's accelerator/devices pair.
        if gpus:
            accelerator = "gpu"
            devices = gpus
        else:
            accelerator = "cpu"
            devices = 1

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            log_every_n_steps=log_every_n_steps,
            logger=logger,
            gradient_clip_val=gradient_clip_val,
            strategy=distribution_strategy,
            **trainer_kwargs,
        )

        return trainer

    def fit(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        *,
        max_epochs: int = 10,
        early_stopping_patience: int = 5,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> None:
        """Fit `EasySyntax` using `pytorch_lightning.Trainer`."""
        # Checks
        if callbacks is None:
            # We create the bare-minimum callbacks for you.
            callbacks = self._create_default_callbacks(
                val_dataloader=val_dataloader,
                early_stopping_patience=early_stopping_patience,
            )
            self.debug("No callbacks specified. Default callbacks added.")
        else:
            # You are on your own!
            self.debug("Initializing training with user-provided callbacks.")
        self._print_callbacks(callbacks)
        has_early_stopping = self._contains_callback(callbacks, EarlyStopping)
        has_model_checkpoint = self._contains_callback(
            callbacks, ModelCheckpoint
        )

        if has_early_stopping and not has_model_checkpoint:
            self.warning(
                "No ModelCheckpoint found in callbacks. Best-fit model will"
                " not automatically be loaded after training!"
            )

        self.train(mode=True)
        trainer = self._construct_trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            distribution_strategy=distribution_strategy,
            **trainer_kwargs,
        )

        try:
            trainer.fit(
                self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )
        except KeyboardInterrupt:
            self.warning("[ctrl+c] Exiting gracefully.")

        # Load weights from best-fit model after training if possible
        if has_early_stopping and has_model_checkpoint:
            for callback in callbacks:
                if isinstance(callback, ModelCheckpoint):
                    checkpoint_callback = callback
            self.load_state_dict(
                torch.load(checkpoint_callback.best_model_path)["state_dict"]
            )
            self.info("Best-fit weights from EarlyStopping loaded.")

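    # Illustrative usage sketch (assumes `model`, `train_loader`, and
    # `val_loader` already exist; not part of this module):
    #
    #     model.fit(
    #         train_loader,
    #         val_loader,
    #         max_epochs=20,
    #         early_stopping_patience=5,
    #         gpus=[0],
    #     )
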
    def _print_callbacks(self, callbacks: List[Callback]) -> None:
        callback_names = []
        for cbck in callbacks:
            callback_names.append(cbck.__class__.__name__)
        self.info(
            f"Training initiated with callbacks: {', '.join(callback_names)}"
        )

    def _contains_callback(
        self, callbacks: List[Callback], callback: Type[Callback]
    ) -> bool:
        """Check if an instance of `callback` is in `callbacks`."""
        for cbck in callbacks:
            if isinstance(cbck, callback):
                return True
        return False

    @property
    def target_labels(self) -> List[str]:
        """Return target labels."""
        return [
            label for task in self._tasks for label in task._target_labels
        ]

    @property
    def prediction_labels(self) -> List[str]:
        """Return prediction labels."""
        return [
            label for task in self._tasks for label in task._prediction_labels
        ]

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure the model's optimizer(s)."""
        optimizer = self._optimizer_class(
            self.parameters(), **self._optimizer_kwargs
        )
        config = {
            "optimizer": optimizer,
        }
        if self._scheduler_class is not None:
            scheduler = self._scheduler_class(
                optimizer, **self._scheduler_kwargs
            )
            config.update(
                {
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        **self._scheduler_config,
                    },
                }
            )
        return config

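    # Illustrative sketch (hypothetical values): `scheduler_class`,
    # `scheduler_kwargs`, and `scheduler_config` map directly onto the
    # `lr_scheduler` entry of the dict returned above, e.g.
    #
    #     model = MyModel(
    #         tasks=my_task,
    #         scheduler_class=torch.optim.lr_scheduler.StepLR,
    #         scheduler_kwargs={"step_size": 10, "gamma": 0.1},
    #         scheduler_config={"interval": "epoch"},
    #     )
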
    def training_step(
        self, train_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform training step."""
        if isinstance(train_batch, Data):
            train_batch = [train_batch]
        loss = self.shared_step(train_batch, batch_idx)
        self.log(
            "train_loss",
            loss,
            batch_size=self._get_batch_size(train_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )

        current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("lr", current_lr, prog_bar=True, on_step=True)
        return loss

    def validation_step(
        self, val_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform validation step."""
        if isinstance(val_batch, Data):
            val_batch = [val_batch]
        loss = self.shared_step(val_batch, batch_idx)
        self.log(
            "val_loss",
            loss,
            batch_size=self._get_batch_size(val_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        return loss

    def inference(self) -> None:
        """Activate inference mode."""
        for task in self._tasks:
            task.inference()

    def train(self, mode: bool = True) -> "Model":
        """Deactivate inference mode."""
        super().train(mode)
        if mode:
            for task in self._tasks:
                task.train_eval()
        return self

    def predict(
        self,
        dataloader: DataLoader,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
        **trainer_kwargs: Any,
    ) -> List[Tensor]:
        """Return predictions for `dataloader`."""
        self.inference()
        self.train(mode=False)

        callbacks = self._create_default_callbacks(
            val_dataloader=None,
        )

        inference_trainer = self._construct_trainer(
            gpus=gpus,
            distribution_strategy=distribution_strategy,
            callbacks=callbacks,
            **trainer_kwargs,
        )

        predictions_list = inference_trainer.predict(self, dataloader)
        assert len(predictions_list), "Got no predictions"

        nb_outputs = len(predictions_list[0])
        predictions: List[Tensor] = [
            torch.cat([preds[ix] for preds in predictions_list], dim=0)
            for ix in range(nb_outputs)
        ]
        return predictions

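    # Illustrative usage sketch (assumes `model` and `test_loader` exist):
    # returns one tensor per model output, concatenated over batches.
    #
    #     preds = model.predict(test_loader, gpus=[0])
    #     print(len(preds), preds[0].shape)
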
    def predict_as_dataframe(
        self,
        dataloader: DataLoader,
        prediction_columns: Optional[List[str]] = None,
        *,
        additional_attributes: Optional[List[str]] = None,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
        **trainer_kwargs: Any,
    ) -> pd.DataFrame:
        """Return predictions for `dataloader` as a DataFrame.

        Include `additional_attributes` as additional columns in the output
        DataFrame.
        """
        if prediction_columns is None:
            prediction_columns = self.prediction_labels

        if additional_attributes is None:
            additional_attributes = []
        assert isinstance(additional_attributes, list)

        if (
            not isinstance(dataloader.sampler, SequentialSampler)
            and additional_attributes
        ):
            raise UserWarning(
                "DataLoader has a `sampler` that is not `SequentialSampler`, "
                "indicating that shuffling is enabled. Using "
                "`predict_as_dataframe` with `additional_attributes` assumes "
                "that the sequence of batches in `dataloader` is "
                "deterministic. Either call this method with a `dataloader` "
                "that doesn't resample batches, or do not request "
                "`additional_attributes`."
            )
        self.info(
            f"Column names for predictions are: \n {prediction_columns}"
        )
        predictions_torch = self.predict(
            dataloader=dataloader,
            gpus=gpus,
            distribution_strategy=distribution_strategy,
            **trainer_kwargs,
        )
        predictions = (
            torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
        )
        assert len(prediction_columns) == predictions.shape[1], (
            f"Number of provided column names ({len(prediction_columns)}) "
            f"and number of output columns ({predictions.shape[1]}) don't "
            "match."
        )

        # Check if predictions are on event- or pulse-level
        pulse_level_predictions = len(predictions) > len(dataloader.dataset)

        # Get additional attributes
        attributes: Dict[str, List[np.ndarray]] = OrderedDict(
            [(attr, []) for attr in additional_attributes]
        )
        for batch in dataloader:
            for attr in attributes:
                attribute = batch[attr]
                if isinstance(attribute, torch.Tensor):
                    attribute = attribute.detach().cpu().numpy()

                # If predictions are pulse-level, event-level attributes are
                # repeated to make dimensions fit.
                if pulse_level_predictions:
                    if len(attribute) < np.sum(
                        batch.n_pulses.detach().cpu().numpy()
                    ):
                        attribute = np.repeat(
                            attribute, batch.n_pulses.detach().cpu().numpy()
                        )
                attributes[attr].extend(attribute)

        # Confirm that attributes match length of predictions
        skip_attributes = []
        for attr in attributes.keys():
            try:
                assert len(attributes[attr]) == len(predictions)
            except AssertionError:
                self.warning_once(
                    "Could not automatically adjust the length of additional"
                    f" attribute '{attr}' to match the length of predictions."
                    " This error can be caused by heavy disagreement between"
                    " the number of examples in the dataset vs. the actual"
                    " number of events in the dataloader, e.g. heavy"
                    " filtering of events in a `collate_fn` passed to the"
                    " `dataloader`. It can also be caused by requesting"
                    " pulse-level attributes for `Task`s that produce"
                    " event-level predictions. Attribute skipped."
                )
                skip_attributes.append(attr)

        # Remove bad attributes
        for attr in skip_attributes:
            attributes.pop(attr)
            additional_attributes.remove(attr)

        data = np.concatenate(
            [predictions]
            + [
                np.asarray(values)[:, np.newaxis]
                for values in attributes.values()
            ],
            axis=1,
        )

        results = pd.DataFrame(
            data, columns=prediction_columns + additional_attributes
        )
        return results

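    # Illustrative usage sketch (attribute names are hypothetical and must
    # exist in the dataset; assumes `model` and `test_loader` exist):
    #
    #     df = model.predict_as_dataframe(
    #         test_loader,
    #         additional_attributes=["event_no"],
    #     )
    #     df.to_csv("predictions.csv")
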
    def _create_default_callbacks(
        self,
        val_dataloader: DataLoader,
        early_stopping_patience: Optional[int] = None,
    ) -> List:
        """Create default callbacks.

        Used in cases where no callbacks are specified by the user in `.fit`.
        """
        callbacks = [ProgressBar()]
        if val_dataloader is not None:
            assert early_stopping_patience is not None
            # Add Early Stopping
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=early_stopping_patience,
                )
            )
            # Add Model Checkpoint
            callbacks.append(
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="val_loss",
                    mode="min",
                    filename=f"{self.backbone.__class__.__name__}"
                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
                )
            )
            self.info(
                "EarlyStopping has been added"
                f" with a patience of {early_stopping_patience}."
            )
        return callbacks

    def _add_early_stopping(
        self, val_dataloader: DataLoader, callbacks: List
    ) -> List:
        if val_dataloader is None:
            return callbacks
        has_early_stopping = False
        assert isinstance(callbacks, list)
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                has_early_stopping = True

        if not has_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                )
            )
            self.warning_once(
                "Got validation dataloader but no EarlyStopping callback. An "
                "EarlyStopping callback has been added automatically with "
                "patience=5 and monitor='val_loss'."
            )
        return callbacks