graphnet.training.utils module

Utility functions for graphnet.training.

graphnet.training.utils.collate_fn(graphs)

Collate a list of graphs into a Batch, removing graphs with fewer than two DOM hits.

Such graphs should not occur in “production”.

Return type:

Batch

Parameters:

graphs (List[Data])
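The filtering step can be sketched in plain Python. This is a minimal sketch, not graphnet's code: StandInGraph is a hypothetical stand-in for a torch_geometric Data object, and the sketch assumes each graph exposes its number of hits as an n_pulses attribute.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StandInGraph:
    """Hypothetical stand-in for a graph; only the hit count is modelled."""

    n_pulses: int


def drop_small_graphs(graphs: List[StandInGraph]) -> List[StandInGraph]:
    """Keep only graphs with at least two DOM hits (assumed: n_pulses)."""
    return [g for g in graphs if g.n_pulses > 1]
```

With torch_geometric installed, the surviving Data objects could then be merged into a single Batch with Batch.from_data_list(...).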

class graphnet.training.utils.collator_sequence_buckleting(batch_splits=[0.8])

Bases: object

Perform sequence bucketing on the graphs in the batch.

Set the cutting points of the different mini-batches.

batch_splits: list of floats, where each element is a cutting point expressed as a fraction of the total number of graphs. Do not include the first and last cutting points in this list; they are always 0 and 1, respectively.

Parameters:

batch_splits (List[float])
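The bucketing idea can be sketched in plain Python: sort the graphs by size, then cut the sorted list at the fractions in batch_splits, with 0 and 1 added implicitly at either end. Grouping graphs of similar size is useful when a model pads every mini-batch to its longest sequence. This is an illustrative sketch under assumptions, not graphnet's implementation; StandInGraph and bucket_graphs are hypothetical names, and graph size is assumed to be available as n_pulses.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class StandInGraph:
    """Hypothetical stand-in for a graph; only the hit count is modelled."""

    n_pulses: int


def bucket_graphs(
    graphs: Sequence[StandInGraph], batch_splits: Sequence[float] = (0.8,)
) -> List[List[StandInGraph]]:
    """Sort graphs by size, then cut them into mini-batches at batch_splits."""
    # Sorting puts graphs of similar size next to each other.
    ordered = sorted(graphs, key=lambda g: g.n_pulses)
    # The first and last cutting points are implicit: 0 and 1.
    fractions = [0.0, *batch_splits, 1.0]
    cuts = [int(f * len(ordered)) for f in fractions]
    # Consecutive cutting points delimit one mini-batch each.
    return [ordered[start:stop] for start, stop in zip(cuts[:-1], cuts[1:])]
```

With the default batch_splits of [0.8], ten graphs are split into a bucket of the eight smallest and a bucket of the two largest. In a real collator each bucket would then become its own batch.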

graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)

Construct a DataLoader instance.

Return type:

DataLoader

Parameters:
  • db (str)

  • pulsemaps (str | List[str])

  • graph_definition (GraphDefinition)

  • features (List[str])

  • truth (List[str])

  • batch_size (int)

  • shuffle (bool)

  • selection (List[int] | None)

  • num_workers (int)

  • persistent_workers (bool)

  • node_truth (List[str] | None)

  • truth_table (str)

  • node_truth_table (str | None)

  • string_selection (List[int] | None)

  • loss_weight_table (str | None)

  • loss_weight_column (str | None)

  • index_column (str)

  • labels (Dict[str, Callable] | None)

graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)

Construct training and validation DataLoader instances.

Return type:

Tuple[DataLoader, DataLoader]

Parameters:
  • db (str)

  • graph_definition (GraphDefinition)

  • selection (List[int] | None)

  • pulsemaps (str | List[str])

  • features (List[str])

  • truth (List[str])

  • batch_size (int)

  • database_indices (List[int] | None)

  • seed (int)

  • test_size (float)

  • num_workers (int)

  • persistent_workers (bool)

  • node_truth (str | None)

  • truth_table (str)

  • node_truth_table (str | None)

  • string_selection (List[int] | None)

  • loss_weight_column (str | None)

  • loss_weight_table (str | None)

  • index_column (str)

  • labels (Dict[str, Callable] | None)
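The split behind this function can be sketched as a reproducible shuffle of the event ids in selection, with a test_size fraction held out for validation. This is a minimal sketch under assumptions, not graphnet's implementation, which may rely on a library routine such as scikit-learn's train_test_split instead; the helper name split_selection is hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_selection(
    selection: Sequence[int], test_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Return (train_ids, validation_ids) from a reproducible shuffle."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must lie strictly between 0 and 1")
    shuffled = list(selection)
    # A dedicated, seeded generator makes the split reproducible without
    # touching Python's global random state.
    random.Random(seed).shuffle(shuffled)
    n_validation = round(test_size * len(shuffled))
    return shuffled[n_validation:], shuffled[:n_validation]
```

Each list would then feed its own DataLoader; a common choice is to shuffle only the training loader.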

graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)

Get model predictions on the given dataloader.

Return type:

DataFrame

Parameters:
  • trainer (Trainer)

  • model (Model)

  • dataloader (DataLoader)

  • prediction_columns (List[str])

  • node_level (bool)

  • additional_attributes (List[str] | None)
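The layout of the returned table can be illustrated without Lightning or pandas. In this minimal sketch (a hypothetical helper, not graphnet's code), each batch is assumed to yield one row of model outputs per event, or per node when node_level is true; prediction_columns names those outputs in order, and the values for additional_attributes are assumed to be already collected per row.

```python
from typing import Any, Dict, List, Optional, Sequence


def predictions_to_columns(
    batch_outputs: Sequence[Sequence[Sequence[float]]],
    prediction_columns: List[str],
    attributes: Optional[Dict[str, Sequence[Any]]] = None,
) -> Dict[str, List[Any]]:
    """Stack per-batch prediction rows and name each output dimension."""
    # Flatten the batches into one list of rows, preserving order.
    rows = [list(row) for batch in batch_outputs for row in batch]
    for row in rows:
        if len(row) != len(prediction_columns):
            raise ValueError(
                f"expected {len(prediction_columns)} outputs per row, got {len(row)}"
            )
    # Column i of the table holds output dimension i of every row.
    columns: Dict[str, List[Any]] = {
        name: [row[i] for row in rows] for i, name in enumerate(prediction_columns)
    }
    # Extra per-row attributes (e.g. an event id) are added alongside.
    for name, values in (attributes or {}).items():
        if len(values) != len(rows):
            raise ValueError(
                f"attribute {name!r} has {len(values)} values, expected {len(rows)}"
            )
        columns[name] = list(values)
    return columns
```

Passing the resulting dict to pandas.DataFrame(...) would give one column per prediction and per attribute.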

graphnet.training.utils.save_results(db, tag, results, archive, model)

Save the trained model and prediction results in db.

Return type:

None

Parameters:
  • db (str)

  • tag (str)

  • results (DataFrame)

  • archive (str)

  • model (Model)

graphnet.training.utils.save_selection(selection, file_path)

Save the list of event numbers to a CSV file.

Parameters:
  • selection (List[int]) – List of event ids.

  • file_path (str) – File path to save the selection.

Return type:

None