graphnet.models.gnn.RNN_tito module

RNN_TITO model implementation.

class graphnet.models.gnn.RNN_tito.RNN_TITO(*args, **kwargs)

Bases: GNN

The RNN_TITO model class.

Combines the Node_RNN and DynEdgeTITO models and is intended for data with a large number of DOM activations per event. This model works only with a non-standard dataset specific to the Node_RNN model; see Node_RNN for more details.
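To illustrate why a recurrent stage helps when events contain many DOM activations, here is a minimal pure-Python sketch. It is not graphnet's Node_RNN implementation. It assumes a hypothetical pulse layout `(dom_id, charge, time)`, groups pulses into one time-ordered sequence per DOM (charge first), and runs a toy recurrence with made-up fixed weights so that each DOM, however many pulses it recorded, becomes a single fixed-size node vector. The helper names `group_by_dom` and `toy_rnn` are hypothetical, and the hidden size of 4 is chosen only for readability.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# Hypothetical pulse layout: (dom_id, charge, time). Real graphnet inputs differ.
Pulse = Tuple[int, float, float]


def group_by_dom(pulses: Sequence[Pulse]) -> Dict[int, List[Tuple[float, float]]]:
    """Collect each DOM's pulses into a time-ordered (charge, time) sequence."""
    series: Dict[int, List[Tuple[float, float]]] = defaultdict(list)
    for dom_id, charge, time in pulses:
        series[dom_id].append((charge, time))  # charge first, then time
    for seq in series.values():
        seq.sort(key=lambda step: step[1])  # order pulses by arrival time
    return dict(series)


def toy_rnn(seq: List[Tuple[float, float]], hidden_size: int = 4) -> List[float]:
    """Toy recurrence h_t = tanh(W x_t + u * h_{t-1}) with made-up fixed weights.

    Only the final hidden state is returned, so every DOM maps to a vector of
    length hidden_size no matter how many pulses it recorded.
    """
    h = [0.0] * hidden_size
    for charge, time in seq:
        h = [
            math.tanh(0.1 * (i + 1) * charge + 0.01 * time + 0.5 * h[i])
            for i in range(hidden_size)
        ]
    return h


pulses = [(7, 1.2, 10.0), (7, 0.4, 12.5), (7, 2.0, 11.0), (3, 0.9, 15.0)]
nodes = {dom: toy_rnn(seq) for dom, seq in group_by_dom(pulses).items()}
print(len(pulses), "pulses ->", len(nodes), "nodes")  # 4 pulses -> 2 nodes
```

The graph built afterwards then has one node per DOM rather than one per pulse, which is the motivation for placing an RNN in front of the graph layers.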

Initialize the RNN_TITO model.

Parameters:
  • nb_inputs (int) – Number of input features.

  • time_series_columns (List[int]) – The indices of the input data that should be treated as time series data. The first index should be the charge column.

  • nb_neighbours (int, optional) – Number of neighbours to connect each node to when building the k-nearest-neighbours graph. Defaults to 8.

  • rnn_layers (int, optional) – Number of RNN layers. Defaults to 1.

  • rnn_hidden_size (int, optional) – Size of the hidden state of the RNN. Also determines the size of the output of the RNN. Defaults to 64.

  • rnn_dropout (float, optional) – Dropout to use in the RNN. Defaults to 0.5.

  • features_subset (List[int], optional) – The subset of latent features on each node that are used as metric dimensions when performing the k-nearest-neighbours clustering. Defaults to [0, 1, 2, 3].

  • dyntrans_layer_sizes (List[Tuple[int, ...]], optional) – List of tuples representing the sizes of the hidden layers of the DynTrans model.

  • post_processing_layer_sizes (List[int], optional) – List of integers representing the sizes of the hidden layers of the post-processing model.

  • readout_layer_sizes (List[int], optional) – List of integers representing the sizes of the hidden layers of the readout model.

  • global_pooling_schemes (Union[str, List[str]], optional) – Pooling schemes to use. Defaults to None.

  • embedding_dim (int, optional) – Embedding dimension of the RNN. Defaults to None, i.e. no embedding.

  • n_head (int, optional) – Number of heads to use in the DynTrans model. Defaults to 16.

  • use_global_features (bool, optional) – Whether to use global features after pooling. Defaults to True.

  • use_post_processing_layers (bool, optional) – Whether to use post-processing layers after the DynTrans layers. Defaults to True.

  • args (Any)

  • kwargs (Any)
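The roles of nb_neighbours and features_subset can be illustrated with a plain k-nearest-neighbours graph construction. The following is a minimal pure-Python sketch, not graphnet's implementation, which works on latent tensors and is typically backed by an optimized kNN routine. The function name `knn_edges` and the `(j, i)` edge orientation (neighbour j to node i) are assumptions for illustration. Distances are computed only over the columns listed in `features_subset`, so any other feature has no influence on which nodes get connected.

```python
import math
from typing import List, Sequence, Tuple


def knn_edges(
    x: Sequence[Sequence[float]],
    nb_neighbours: int = 8,
    features_subset: Sequence[int] = (0, 1, 2, 3),
) -> List[Tuple[int, int]]:
    """Connect each node i to its nb_neighbours nearest nodes j as edges (j, i).

    Euclidean distances use only the columns listed in features_subset;
    self-loops are excluded and ties are broken by node index.
    """
    points = [[row[c] for c in features_subset] for row in x]
    edges: List[Tuple[int, int]] = []
    for i, p in enumerate(points):
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        edges.extend((j, i) for _, j in dists[:nb_neighbours])
    return edges


# Column 4 holds large values but is not in features_subset, so it is ignored.
x = [
    [0.0, 0.0, 0.0, 0.0, 100.0],
    [1.0, 0.0, 0.0, 0.0, -50.0],
    [0.0, 5.0, 0.0, 0.0, 0.0],
    [1.0, 2.0, 0.0, 0.0, 7.0],
]
print(knn_edges(x, nb_neighbours=1))  # [(1, 0), (0, 1), (3, 2), (1, 3)]
```

Nodes 0 and 1 end up as each other's nearest neighbours even though their fifth features differ by 150, because that column is excluded from the metric. With the default `nb_neighbours=8` and only four nodes, each node is simply connected to all three others.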

Return type:

object

forward(data)

Apply the learnable forward pass of the combined RNN and DynEdgeTITO model.

Return type:

Tensor

Parameters:

data (Data)
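Before the readout layers, the forward pass has to reduce a variable number of node vectors per event to a fixed-size vector, which is what global_pooling_schemes controls. The following is a minimal pure-Python sketch of per-event global pooling, not graphnet's code. It assumes the scheme names "min", "max", "sum" and "mean", and a `batch` list that maps each node to its event index in the style of PyTorch Geometric's batch vector. The names `POOLINGS` and `global_pool` are hypothetical. The results of the chosen schemes are concatenated, so the pooled width is the number of node features times the number of schemes.

```python
from typing import Dict, List, Sequence

# Per-feature reductions; the scheme names are an assumption for illustration.
POOLINGS = {
    "min": min,
    "max": max,
    "sum": sum,
    "mean": lambda vals: sum(vals) / len(vals),
}


def global_pool(
    x: Sequence[Sequence[float]],
    batch: Sequence[int],
    schemes: Sequence[str],
) -> Dict[int, List[float]]:
    """Pool node features per event and concatenate the results of each scheme."""
    events: Dict[int, List[Sequence[float]]] = {}
    for row, event in zip(x, batch):
        events.setdefault(event, []).append(row)
    pooled: Dict[int, List[float]] = {}
    for event, rows in events.items():
        columns = list(zip(*rows))  # one tuple of values per feature
        # Scheme-major order: all features for scheme 1, then scheme 2, ...
        pooled[event] = [POOLINGS[name](col) for name in schemes for col in columns]
    return pooled


x = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]  # three nodes, two features each
batch = [0, 0, 1]  # nodes 0 and 1 belong to event 0, node 2 to event 1
print(global_pool(x, batch, ["min", "max", "mean"]))
# {0: [1.0, 0.0, 3.0, 4.0, 2.0, 2.0], 1: [5.0, 2.0, 5.0, 2.0, 5.0, 2.0]}
```

Each event yields a vector of length 2 × 3 = 6 regardless of how many nodes it contained, which is what lets events of different sizes share the same readout layers.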