graphnet.models.rnn.node_rnn module

Implementation of the NodeTimeRNN model.

(cannot be used as a standalone model)

class graphnet.models.rnn.node_rnn.Node_RNN(*args, **kwargs)

Bases: GNN

Implementation of the Node RNN model architecture.

The model takes as input the typical DOM data format and transforms it into a time series of activations per DOM, before applying an RNN layer and producing an RNN output for each DOM. In its current state, this model is not intended to be used as a standalone model. Furthermore, it must be used with a time-series dataset object, in which the last column of x is a special column used to separate the activations into time series per DOM per batch.
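The per-DOM splitting described above can be sketched in plain Python. This is a hypothetical illustration, not graphnet's code. It assumes the special last column of x holds an integer identifier that is unique per DOM per batch, that a hypothetical time_column index gives the pulse time used for ordering, and that time_series_columns selects the features kept in each series (charge first).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Row = Sequence[float]


def split_into_dom_series(
    x: List[Row],
    time_series_columns: Sequence[int],
    time_column: int,
) -> Dict[int, List[Tuple[float, ...]]]:
    """Group pulse rows into one time-ordered series per DOM (hypothetical helper)."""
    groups: Dict[int, List[Row]] = defaultdict(list)
    for row in x:
        dom_id = int(row[-1])  # assumed: last column of x is the DOM/batch separator
        groups[dom_id].append(row)

    series: Dict[int, List[Tuple[float, ...]]] = {}
    for dom_id, rows in groups.items():
        rows.sort(key=lambda r: r[time_column])  # chronological order within the DOM
        # Keep only the time-series features, in the order given (charge first).
        series[dom_id] = [tuple(r[c] for c in time_series_columns) for r in rows]
    return series


if __name__ == "__main__":
    # Columns: charge, time, DOM identifier (separator column).
    x = [[1.2, 10.0, 0], [0.7, 5.0, 0], [2.0, 3.0, 1]]
    print(split_into_dom_series(x, time_series_columns=[0, 1], time_column=1))
    # prints {0: [(0.7, 5.0), (1.2, 10.0)], 1: [(2.0, 3.0)]}
```

Each DOM ends up with a variable-length sequence, which is why a recurrent layer is a natural fit for the next step.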

Construct Node_RNN.

Parameters:
  • nb_inputs (int) – Number of features in the input data.

  • hidden_size (int) – Number of features for the RNN output and hidden layers.

  • num_layers (int) – Number of layers in the RNN.

  • time_series_columns (List[int]) – The indices of the input data that should be treated as time series data. The first index should be the charge column.

  • nb_neighbours (int, default: 8) – Number of neighbours to use when reconstructing the graph representation. Defaults to 8.

  • features_subset (Optional[List[int]], default: None) – The subset of latent features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. If None (the default), [0,1,2,3] is used.

  • dropout (float, default: 0.5) – Dropout fraction to use in the RNN. Defaults to 0.5.

  • embedding_dim (int, default: 0) – Embedding dimension of the RNN. Defaults to no embedding.

  • args (Any)

  • kwargs (Any)

Return type:

object
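The nb_neighbours and features_subset parameters describe a k-nearest-neighbour reconstruction of the graph over the latent features of each node. A plain-Python sketch of that step follows, assuming Euclidean distance and edges stored as (neighbour, node) pairs; knn_edges is a hypothetical helper and not part of graphnet.

```python
import math
from typing import List, Optional, Sequence, Tuple


def knn_edges(
    nodes: List[Sequence[float]],
    nb_neighbours: int = 8,
    features_subset: Optional[Sequence[int]] = None,
) -> List[Tuple[int, int]]:
    """Link each node to its nb_neighbours nearest nodes (hypothetical helper)."""
    if features_subset is None:
        features_subset = [0, 1, 2, 3]  # documented default metric dimensions
    # Only the selected columns act as coordinates for the distance metric.
    coords = [[node[i] for i in features_subset] for node in nodes]
    edges: List[Tuple[int, int]] = []
    for i, ci in enumerate(coords):
        # Euclidean distance from node i to every other node, nearest first.
        dists = sorted(
            (math.dist(ci, cj), j) for j, cj in enumerate(coords) if j != i
        )
        for _, j in dists[:nb_neighbours]:
            edges.append((j, i))  # assumed direction: neighbour j -> node i
    return edges


if __name__ == "__main__":
    # The fifth column is ignored because it is outside features_subset.
    nodes = [[0, 0, 0, 0, 9], [1, 0, 0, 0, 9], [5, 0, 0, 0, 9]]
    print(knn_edges(nodes, nb_neighbours=1))
    # prints [(1, 0), (0, 1), (1, 2)]
```

This brute-force version is quadratic in the number of nodes; it is meant to show the role of the two parameters, not an efficient implementation.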

clean_up_data_object(data)

Update the feature names of the data object.

Parameters:

data (Data) – The input data object.

Return type:

Data

forward(data)

Apply learnable forward pass to the GNN.

Return type:

Tensor

Parameters:

data (Data)
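The class description states that an RNN is run over each DOM's time series and yields one output per DOM. The sketch below shows that per-DOM recurrence in plain Python. It assumes a single-layer Elman (tanh) cell with hand-written weights and takes the final hidden state as each DOM's output; the recurrent cell Node_RNN actually uses, its handling of num_layers, dropout and embedding_dim, and the subsequent graph reconstruction are not shown. rnn_last_hidden and dom_node_features are hypothetical names.

```python
import math
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def rnn_last_hidden(
    series: Sequence[Sequence[float]],
    w_in: Matrix,  # shape [hidden_size][n_features]
    w_hh: Matrix,  # shape [hidden_size][hidden_size]
    bias: List[float],  # shape [hidden_size]
) -> List[float]:
    """Elman recurrence h_t = tanh(W_in x_t + W_hh h_{t-1} + b); returns the last h_t."""
    hidden_size = len(bias)
    h = [0.0] * hidden_size  # zero initial hidden state
    for x_t in series:
        # The comprehension reads the previous h before h is rebound.
        h = [
            math.tanh(
                sum(w * v for w, v in zip(w_in[k], x_t))
                + sum(w * v for w, v in zip(w_hh[k], h))
                + bias[k]
            )
            for k in range(hidden_size)
        ]
    return h


def dom_node_features(
    dom_series: Dict[int, Sequence[Sequence[float]]],
    w_in: Matrix,
    w_hh: Matrix,
    bias: List[float],
) -> Dict[int, List[float]]:
    """Map each DOM's time series to one hidden_size-dimensional feature vector."""
    return {
        dom_id: rnn_last_hidden(series, w_in, w_hh, bias)
        for dom_id, series in dom_series.items()
    }


if __name__ == "__main__":
    # Hypothetical weights: hidden_size = 2, two input features (charge, time).
    w_in = [[0.5, -0.1], [0.2, 0.05]]
    w_hh = [[0.3, 0.0], [0.0, 0.3]]
    bias = [0.0, 0.1]
    dom_series = {0: [(0.7, 5.0), (1.2, 10.0)], 1: [(2.0, 3.0)]}
    print(dom_node_features(dom_series, w_in, w_hh, bias))  # one 2-vector per DOM
```

Keeping only the last hidden state is one common way to compress a variable-length series into a fixed hidden_size vector, so that each DOM can afterwards be treated as an ordinary graph node.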