graphnet.models.graphs.graph_definition module

Modules for defining graphs.

These are self-contained graph definitions that hold all the graph-altering code in graphnet. These modules define what the GNN sees as input and can be passed to dataloaders during training and deployment.

class graphnet.models.graphs.graph_definition.GraphDefinition(*args, **kwargs)

Bases: Model

An abstract class to create graph definitions from.

Construct `GraphDefinition`. The `detector` holds `Detector`-specific code, e.g. scaling/standardization and geometry tables.

`node_definition` defines the nodes in the graph.

`edge_definition` defines the connectivity of the nodes in the graph.
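To make this division of labour concrete, here is a minimal, self-contained sketch in plain Python of how a detector, a node definition and an optional edge definition can be chained. The names `toy_detector`, `toy_nodes_as_pulses`, `toy_knn_edges` and `ToyGraphDefinition` are hypothetical stand-ins invented for illustration, not graphnet classes, and the fixed 1/100 coordinate scaling is an assumption; the real components work on tensors and their own configuration.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Row = List[float]
Edge = Tuple[int, int]


def toy_detector(rows: Sequence[Row]) -> List[Row]:
    """Hypothetical detector standardization: scale x, y, z by 1/100, keep t."""
    return [[x / 100.0, y / 100.0, z / 100.0, t] for x, y, z, t in rows]


def toy_nodes_as_pulses(rows: Sequence[Row]) -> List[Row]:
    """Hypothetical node definition: one node per input row (pulse)."""
    return [list(r) for r in rows]


def toy_knn_edges(nodes: Sequence[Row], k: int = 1) -> List[Edge]:
    """Hypothetical edge definition: link each node to its k nearest
    neighbours in (x, y, z). Edges are (source, target) pairs pointing
    from neighbour j to node i."""
    edges: List[Edge] = []
    for i, a in enumerate(nodes):
        # Squared Euclidean distance to every other node, paired with its index.
        dists = sorted(
            (sum((a[d] - b[d]) ** 2 for d in range(3)), j)
            for j, b in enumerate(nodes)
            if j != i
        )
        edges.extend((j, i) for _, j in dists[:k])
    return edges


class ToyGraphDefinition:
    """Chains detector -> node definition -> optional edge definition."""

    def __init__(
        self,
        detector: Callable[[Sequence[Row]], List[Row]],
        node_definition: Callable[[Sequence[Row]], List[Row]],
        edge_definition: Optional[Callable[[Sequence[Row]], List[Edge]]] = None,
    ) -> None:
        self.detector = detector
        self.node_definition = node_definition
        self.edge_definition = edge_definition

    def __call__(self, rows: Sequence[Row]) -> Tuple[List[Row], List[Edge]]:
        nodes = self.node_definition(self.detector(rows))
        # Without an edge definition the sketch returns nodes only.
        edges = self.edge_definition(nodes) if self.edge_definition is not None else []
        return nodes, edges


# Usage: three pulses with columns (x, y, z, t).
rows = [[0.0, 0.0, 0.0, 1.0], [100.0, 0.0, 0.0, 2.0], [0.0, 300.0, 0.0, 3.0]]
nodes, edges = ToyGraphDefinition(toy_detector, toy_nodes_as_pulses, toy_knn_edges)(rows)
```

Keeping the three stages as separate, swappable callables is the point of the sketch: changing how nodes or edges are formed does not require touching the detector-specific scaling.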

Parameters:
  • detector (Detector) – The corresponding `Detector` representing the data.

  • node_definition (Optional[NodeDefinition], default: None) – Definition of nodes. Defaults to NodesAsPulses.

  • edge_definition (Optional[EdgeDefinition], default: None) – Definition of edges. Defaults to None.

  • input_feature_names (Optional[List[str]], default: None) – Names of each column in the expected input data that will be built into a graph. If not provided, all features in `Detector` are assumed to be used.

  • dtype (Optional[dtype], default: torch.float32) – Data type used for node features, e.g. `torch.float`.

  • perturbation_dict (Optional[Dict[str, float]], default: None) – Dictionary mapping a feature name to a standard deviation according to which the values for this feature should be randomly perturbed. Defaults to None.

  • seed (Union[int, Generator, None], default: None) – Seed or Generator used to randomly sample perturbations. Defaults to None.

  • add_inactive_sensors (bool, default: False) – If True, inactive sensors will be appended to the graph with padded pulse information. Defaults to False.

  • sensor_mask (Optional[List[int]], default: None) – A list of sensor IDs to be masked from the graph. Any sensor listed here will be removed from the graph. Defaults to None.

  • string_mask (Optional[List[int]], default: None) – A list of string IDs to be masked from the graph. Defaults to None.

  • sort_by (Optional[str], default: None) – Name of node feature to sort by. Defaults to None.

  • repeat_labels (bool, default: False) – If True, labels will be repeated to match the number of rows in the output of the GraphDefinition. Defaults to False.

  • args (Any)

  • kwargs (Any)

Return type:

object
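Several constructor options describe row-level transformations of the input table. The sketch below approximates four of them in plain Python on a list-of-rows table: Gaussian perturbation (`perturbation_dict`, `seed`), ID masking (`sensor_mask`, `string_mask`), sorting (`sort_by`) and label repetition (`repeat_labels`). It is an illustrative reimplementation, not graphnet's code. The helper names, the `sensor_id` and `string` column names, and the use of Python's `random.Random` in place of a NumPy `Generator` are all assumptions, and the order in which graphnet applies these steps is not modelled.

```python
import random
from typing import Any, Dict, List, Sequence, Union

Rows = List[List[Any]]


def perturb(
    rows: Rows,
    feature_names: Sequence[str],
    perturbation_dict: Dict[str, float],
    seed: Union[int, random.Random, None] = None,
) -> Rows:
    """Add zero-mean Gaussian noise with the given std to each listed feature."""
    # Accept either a seed or a ready-made generator, echoing the `seed` option.
    rng = seed if isinstance(seed, random.Random) else random.Random(seed)
    out = [list(r) for r in rows]  # copy so the caller's rows are untouched
    for name, std in perturbation_dict.items():
        col = feature_names.index(name)
        for r in out:
            r[col] += rng.gauss(0.0, std)
    return out


def mask_ids(
    rows: Rows, feature_names: Sequence[str], masked_ids: Sequence[int], id_column: str
) -> Rows:
    """Drop every row whose `id_column` value is in `masked_ids`
    (hypothetical analogue of sensor_mask / string_mask)."""
    col = feature_names.index(id_column)
    masked = set(masked_ids)
    return [r for r in rows if r[col] not in masked]


def sort_rows(rows: Rows, feature_names: Sequence[str], sort_by: str) -> Rows:
    """Return the rows ordered by the named feature (analogue of sort_by)."""
    col = feature_names.index(sort_by)
    return sorted(rows, key=lambda r: r[col])


def repeat_labels(label: Any, n_rows: int) -> List[Any]:
    """Repeat an event-level label once per output row (analogue of repeat_labels)."""
    return [label] * n_rows


# Usage on a hypothetical table with columns (sensor_id, string, t, charge).
names = ["sensor_id", "string", "t", "charge"]
rows = [[3, 1, 20.0, 1.0], [1, 2, 5.0, 2.0], [2, 1, 10.0, 0.5]]
kept = mask_ids(rows, names, [2], "sensor_id")
ordered = sort_rows(kept, names, "t")
labels = repeat_labels(0.7, len(ordered))
noisy = perturb(rows, names, {"t": 1.0}, seed=42)
```

Passing the same integer seed twice yields identical perturbations, which is the reproducibility the `seed` parameter is meant to provide.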

forward(input_features, input_feature_names, truth_dicts, custom_label_functions, loss_weight_column, loss_weight, loss_weight_default_value, data_path)

Construct the graph as a `Data` object.

Parameters:
  • input_features (ndarray) – Input features for graph construction. Shape `[num_rows, d]`.

  • input_feature_names (List[str]) – Name of each column. Shape `[,d]`.

  • truth_dicts (Optional[List[Dict[str, Any]]], default: None) – List of dictionaries containing truth labels.

  • custom_label_functions (Optional[Dict[str, Callable[..., Any]]], default: None) – Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.

  • loss_weight_column (Optional[str], default: None) – Name of column that holds loss weight. Defaults to None.

  • loss_weight (Optional[float], default: None) – Loss weight associated with event. Defaults to None.

  • loss_weight_default_value (Optional[float], default: None) – Default value for the loss weight, used when some events have no pre-defined loss weight. Defaults to None.

  • data_path (Optional[str], default: None) – Path to dataset data files. Defaults to None.

Return type:

Data

Returns:

graph
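The label-related arguments of `forward` can be illustrated with a plain dictionary standing in for the `Data` object. This is a hypothetical sketch, not graphnet's implementation: `attach_labels` is an invented helper, a missing per-event weight is assumed to be represented by `None`, and the exact signature graphnet expects for custom label functions is described in the linked GETTING_STARTED guide rather than reproduced here.

```python
from typing import Any, Callable, Dict, List, Optional


def attach_labels(
    graph: Dict[str, Any],
    truth_dicts: Optional[List[Dict[str, Any]]] = None,
    custom_label_functions: Optional[Dict[str, Callable[[Dict[str, Any]], Any]]] = None,
    loss_weight_column: Optional[str] = None,
    loss_weight: Optional[float] = None,
    loss_weight_default_value: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: attach truth labels, derived labels and a loss weight."""
    graph = dict(graph)  # shallow copy; a plain dict stands in for `Data`
    # 1. Copy every truth entry onto the graph.
    for truth in truth_dicts or []:
        graph.update(truth)
    # 2. Derive extra labels from what is already on the graph.
    for key, fn in (custom_label_functions or {}).items():
        graph[key] = fn(graph)
    # 3. Attach the per-event loss weight, falling back to the default value.
    if loss_weight_column is not None:
        if loss_weight is None:
            if loss_weight_default_value is None:
                raise ValueError(
                    f"Event has no value in {loss_weight_column!r} "
                    "and loss_weight_default_value is None."
                )
            loss_weight = loss_weight_default_value
        graph[loss_weight_column] = loss_weight
    return graph


# Usage: one truth dict, one derived label, and a missing weight that
# falls back to the default.
g = attach_labels(
    {"x": [[1.0, 2.0]]},
    truth_dicts=[{"energy": 10.0}],
    custom_label_functions={"double_energy": lambda g: 2 * g["energy"]},
    loss_weight_column="weight",
    loss_weight_default_value=1.0,
)
```

Raising when both the event weight and the default are missing surfaces incomplete weight columns early instead of silently training with an undefined weight.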