grit

Implementation of GRIT, a graph transformer model.

Original author: Liheng Ma
Original code: https://github.com/LiamMa/GRIT
Paper: “Graph Inductive Biases in Transformers without Message Passing”

Adapted by: Philip Weigel

class graphnet.models.gnn.grit.GRIT(*args, **kwargs)

Bases: GNN

GRIT is a graph transformer model.

Original code: https://github.com/LiamMa/GRIT/blob/main/grit/network/grit_model.py

Construct GRIT model.

Parameters:
  • nb_inputs (int) – Number of inputs.

  • hidden_dim (int) – Size of hidden dimension.

  • nb_outputs (int, default: 1) – Size of output dimension.

  • ksteps (int, default: 21) – Number of random walk steps.

  • n_layers (int, default: 10) – Number of GRIT layers.

  • n_heads (int, default: 8) – Number of heads in multi-head attention (MHA).

  • pad_to_full_graph (bool, default: True) – Pad edges to form a fully-connected graph.

  • add_node_attr_as_self_loop (bool, default: False) – Add node attributes as a self-loop edge.

  • dropout (float, default: 0.0) – Dropout probability.

  • fill_value (float, default: 0.0) – Padding value.

  • norm (Module, default: torch.nn.BatchNorm1d) – Uninstantiated normalization layer. Either torch.nn.BatchNorm1d or torch.nn.LayerNorm.

  • attn_dropout (float, default: 0.2) – Attention dropout probability.

  • edge_enhance (bool, default: True) – Apply a learnable weight matrix to node-pair features when computing the output nodes in MHA.

  • update_edges (bool, default: True) – Update edge values after each GRIT layer.

  • attn_clamp (float, default: 5.0) – Clamp the absolute value of attention scores to this value.

  • activation (Module, default: torch.nn.ReLU) – Uninstantiated activation function, e.g. torch.nn.ReLU.

  • attn_activation (Module, default: torch.nn.ReLU) – Uninstantiated attention activation function, e.g. torch.nn.ReLU.

  • norm_edges (bool, default: True) – Apply normalization layer to edges.

  • enable_edge_transform (bool, default: True) – Apply transformation to edges.

  • pred_head_layers (int, default: 2) – Number of layers in the prediction head.

  • pred_head_activation (Module, default: torch.nn.ReLU) – Uninstantiated prediction head activation function, e.g. torch.nn.ReLU.

  • pred_head_pooling (str, default: 'mean') – Pooling function to use for the prediction head, either “mean” (default) or “add”.

  • position_encoding (str, default: 'NoPE') – Method of position encoding.

  • args (Any)

  • kwargs (Any)

Return type:

object
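
A minimal construction sketch, assuming graphnet and torch are installed. The input width (nb_inputs=4) and the reduced layer count are illustrative choices, not recommendations:

    from torch.nn import LayerNorm, ReLU

    from graphnet.models.gnn.grit import GRIT

    # Instantiate GRIT with a handful of the parameters documented above.
    # Note that `norm` and `activation` take uninstantiated classes,
    # not instances.
    model = GRIT(
        nb_inputs=4,        # must match the number of node features
        hidden_dim=64,
        nb_outputs=1,
        n_layers=4,         # lighter than the default of 10
        n_heads=8,
        norm=LayerNorm,
        activation=ReLU,
        pred_head_pooling="mean",
    )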

forward(x)

Forward pass.

Parameters:

x (Data) – Input graph(s) as a PyTorch Geometric Data object.

Return type:

Tensor
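
A sketch of a forward pass on a toy graph, assuming the torch_geometric.data.Data input used throughout graphnet. The feature values and edges are arbitrary, and attributes that the graphnet pipeline normally attaches (e.g. the batch vector) are set by hand here:

    import torch
    from torch_geometric.data import Data

    # Toy graph: 3 nodes with 4 features each (matching nb_inputs=4
    # above), fully connected in both directions.
    x = torch.randn(3, 4)
    edge_index = torch.tensor(
        [[0, 0, 1, 1, 2, 2],
         [1, 2, 0, 2, 0, 1]],
        dtype=torch.long,
    )
    data = Data(x=x, edge_index=edge_index)
    data.batch = torch.zeros(3, dtype=torch.long)  # single-graph batch

    out = model(data)  # expected shape: [1, nb_outputs] with "mean" pooling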