grit¶
Implementation of GRIT, a graph transformer model.
Original author: Liheng Ma
Original code: https://github.com/LiamMa/GRIT
Paper: “Graph Inductive Biases in Transformers without Message Passing”
Adapted by: Philip Weigel
- class graphnet.models.gnn.grit.GRIT(*args, **kwargs)[source]¶
Bases: GNN
GRIT is a graph transformer model.
Original code: https://github.com/LiamMa/GRIT/blob/main/grit/network/grit_model.py
Construct GRIT model.
- Parameters:
  - nb_inputs (int) – Number of input features.
  - hidden_dim (int) – Size of hidden dimension.
  - nb_outputs (int, default: 1) – Size of output dimension.
  - ksteps (int, default: 21) – Number of random walk steps.
  - n_layers (int, default: 10) – Number of GRIT layers.
  - n_heads (int, default: 8) – Number of heads in MHA.
  - pad_to_full_graph (bool, default: True) – Pad to form a fully-connected graph.
  - add_node_attr_as_self_loop (bool, default: False) – Add node attributes as a self-edge.
  - dropout (float, default: 0.0) – Dropout probability.
  - fill_value (float, default: 0.0) – Padding value.
  - norm (Module, default: torch.nn.BatchNorm1d) – Uninstantiated normalization layer. Either torch.nn.BatchNorm1d or torch.nn.LayerNorm.
  - attn_dropout (float, default: 0.2) – Attention dropout probability.
  - edge_enhance (bool, default: True) – Apply a learnable weight matrix to node pairs when computing the MHA output nodes.
  - update_edges (bool, default: True) – Update edge values after each GRIT layer.
  - attn_clamp (float, default: 5.0) – Clamp the absolute value of attention scores to this value.
  - activation (Module, default: torch.nn.ReLU) – Uninstantiated activation function, e.g. torch.nn.ReLU.
  - attn_activation (Module, default: torch.nn.ReLU) – Uninstantiated attention activation function, e.g. torch.nn.ReLU.
  - norm_edges (bool, default: True) – Apply normalization layer to edges.
  - enable_edge_transform (bool, default: True) – Apply transformation to edges.
  - pred_head_layers (int, default: 2) – Number of layers in the prediction head.
  - pred_head_activation (Module, default: torch.nn.ReLU) – Uninstantiated prediction head activation function, e.g. torch.nn.ReLU.
  - pred_head_pooling (str, default: 'mean') – Pooling function to use for the prediction head, either “mean” (default) or “add”.
  - position_encoding (str, default: 'NoPE') – Method of position encoding.
  - args (Any)
  - kwargs (Any)
- Return type: object
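A minimal instantiation sketch, assuming the constructor signature documented above; the feature count and hyperparameter values shown here are illustrative, not prescriptive:

```python
from torch.nn import LayerNorm

from graphnet.models.gnn.grit import GRIT

# Hypothetical number of input node features; in practice this should match
# the feature set produced by your graph definition / detector.
NB_FEATURES = 7

gnn = GRIT(
    nb_inputs=NB_FEATURES,  # number of input node features
    hidden_dim=64,          # size of hidden dimension
    nb_outputs=1,           # size of output dimension
    n_layers=10,            # number of GRIT layers
    n_heads=8,              # number of heads in MHA
    ksteps=21,              # number of random walk steps
    norm=LayerNorm,         # uninstantiated normalization layer
    dropout=0.0,
    attn_dropout=0.2,
)
```

In a typical graphnet pipeline, the resulting module would then serve as the GNN backbone of a full model alongside a graph definition and task; consult the graphnet documentation for that wiring.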