node_rnn
Implementation of the NodeTimeRNN model (cannot be used as a standalone model).
- class graphnet.models.rnn.node_rnn.Node_RNN(*args, **kwargs)
Bases: GNN
Implementation of the Node RNN model architecture.
The model takes as input the typical DOM data format and transforms it into a time series of DOM activations per DOM, before applying an RNN layer and producing an RNN output for each DOM. In its current state, this model is not intended to be used as a standalone model. Furthermore, it must be used with a time-series dataset object in which the last column of x is a special column used to separate the activations into a time series per DOM per batch.
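The per-DOM splitting and recurrent step described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not GraphNeT's implementation: it assumes the last column of each row holds a numeric (DOM, event) separator id, that a hypothetical `time_column` index gives the pulse time used for ordering, and it replaces the model's learned RNN layer with a toy single-layer Elman cell with fixed weights. Both `split_time_series` and `elman_last_hidden` are hypothetical helpers.

```python
from collections import defaultdict
from math import tanh
from typing import Dict, List, Sequence

Row = List[float]


def split_time_series(
    x: Sequence[Row], time_series_columns: Sequence[int], time_column: int
) -> Dict[float, List[Row]]:
    """Group pulse rows into one time-ordered series per separator id.

    Assumes the last column of each row is the (DOM, event) separator id and
    that `time_column` indexes the pulse time (both are assumptions).
    """
    series: Dict[float, List[Row]] = defaultdict(list)
    # Sort all pulses by time once; grouping then preserves time order per DOM.
    for row in sorted(x, key=lambda r: r[time_column]):
        # Keep only the time-series columns (charge first, as the docs require).
        series[row[-1]].append([row[i] for i in time_series_columns])
    return dict(series)


def elman_last_hidden(seq: Sequence[Row], hidden_size: int) -> List[float]:
    """Toy single-layer Elman RNN that returns the final hidden state.

    The fixed weights below are arbitrary stand-ins for learned parameters.
    """
    n_in = len(seq[0])
    w_ih = [[0.1 * (j + k + 1) for k in range(n_in)] for j in range(hidden_size)]
    w_hh = [[0.05 if j == m else 0.0 for m in range(hidden_size)] for j in range(hidden_size)]
    h = [0.0] * hidden_size
    for x_t in seq:
        # h_t = tanh(W_ih x_t + W_hh h_{t-1}); bias terms omitted for brevity.
        h = [
            tanh(
                sum(w_ih[j][k] * x_t[k] for k in range(n_in))
                + sum(w_hh[j][m] * h[m] for m in range(hidden_size))
            )
            for j in range(hidden_size)
        ]
    return h


# Hypothetical column layout: charge, time, x-position, separator id.
x = [
    [1.2, 10.0, 0.5, 0.0],
    [0.8, 5.0, 0.5, 0.0],
    [2.0, 7.0, 1.5, 1.0],
]
per_dom = split_time_series(x, time_series_columns=[0, 1], time_column=1)
# per_dom == {0.0: [[0.8, 5.0], [1.2, 10.0]], 1.0: [[2.0, 7.0]]}
# One fixed-size output vector per DOM, regardless of its number of pulses.
outputs = {dom: elman_last_hidden(s, hidden_size=3) for dom, s in per_dom.items()}
```

Returning only the final hidden state is one way to turn a variable-length series into a fixed-size per-DOM vector; it is a simplification chosen for this sketch.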
Construct Node_RNN.
- Parameters:
  - nb_inputs (int) – Number of features in the input data.
  - hidden_size (int) – Number of features for the RNN output and hidden layers.
  - num_layers (int) – Number of layers in the RNN.
  - time_series_columns (List[int]) – The indices of the input data that should be treated as time series data. The first index should be the charge column.
  - nb_neighbours (int, default: 8) – Number of neighbours to use when reconstructing the graph representation. Defaults to 8.
  - features_subset (Optional[List[int]], default: None) – The subset of latent features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0, 1, 2, 3].
  - dropout (float, default: 0.5) – Dropout fraction to use in the RNN. Defaults to 0.5.
  - embedding_dim (int, default: 0) – Embedding dimension of the RNN. Defaults to 0 (no embedding).
  - args (Any)
  - kwargs (Any)
- Return type: object
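The graph reconstruction controlled by `nb_neighbours` and `features_subset` amounts to a k-nearest-neighbours search over selected node features. The sketch below shows that idea in plain Python; `knn_edges` is a hypothetical helper, Euclidean distance is an assumption, and the (source, target) edge direction is a choice made for this illustration rather than something the documentation specifies.

```python
from math import dist
from typing import List, Optional, Sequence, Tuple


def knn_edges(
    nodes: Sequence[Sequence[float]],
    nb_neighbours: int = 8,
    features_subset: Optional[Sequence[int]] = None,
) -> List[Tuple[int, int]]:
    """Link each node to its nb_neighbours nearest nodes as (source, target) edges.

    Distances use only the columns in features_subset, which falls back to
    the documented default of [0, 1, 2, 3].
    """
    if features_subset is None:
        features_subset = [0, 1, 2, 3]
    # Project every node onto the metric dimensions.
    points = [[node[i] for i in features_subset] for node in nodes]
    edges: List[Tuple[int, int]] = []
    for target, p in enumerate(points):
        # Sort candidate neighbours by distance; ties break on node index.
        others = sorted(
            (dist(p, q), source) for source, q in enumerate(points) if source != target
        )
        for _, source in others[:nb_neighbours]:
            edges.append((source, target))
    return edges


# Four nodes with latent features [x, y, z, t, extra]; "extra" is ignored
# because it is not in the default features_subset.
nodes = [
    [0.0, 0.0, 0.0, 0.0, 9.0],
    [1.0, 0.0, 0.0, 0.0, -9.0],
    [0.0, 2.0, 0.0, 0.0, 9.0],
    [5.0, 5.0, 5.0, 5.0, 0.0],
]
edges = knn_edges(nodes, nb_neighbours=2)
# edges == [(1, 0), (2, 0), (0, 1), (2, 1), (0, 2), (1, 2), (2, 3), (1, 3)]
```

With fewer candidate nodes than `nb_neighbours`, this sketch simply connects every other node, so `knn_edges(nodes)` with the default of 8 yields 3 edges per node here.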