utils
Utility functions for graphnet.training.
- graphnet.training.utils.collate_fn(graphs)
Remove graphs with fewer than two DOM hits and collate the remaining graphs into a Batch.
Graphs with fewer than two DOM hits should not occur in “production”.
- Return type:
Batch
- Parameters:
graphs (List[Data])
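The filtering step can be illustrated with a minimal, dependency-free sketch. It assumes each graph exposes its number of DOM hits as an attribute, called `n_pulses` here for illustration, and uses a plain list in place of the `Batch` that `collate_fn` returns. `ToyGraph` and `collate_sketch` are hypothetical names; this is not the library's implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a torch_geometric Data object."""

    event_no: int
    n_pulses: int  # assumed attribute: number of DOM hits in the event


def collate_sketch(graphs: List[ToyGraph]) -> List[ToyGraph]:
    """Drop graphs with fewer than two DOM hits and keep the rest in order.

    The real collate_fn returns a Batch; a plain list stands in for it here.
    """
    return [g for g in graphs if g.n_pulses >= 2]


batch = collate_sketch([ToyGraph(0, 1), ToyGraph(1, 5), ToyGraph(2, 2)])
print([g.event_no for g in batch])  # [1, 2]
```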
- class graphnet.training.utils.collator_sequence_buckleting(batch_splits=[0.8])
Bases: object
Perform sequence bucketing on the graphs in the batch.
Set the cutting points of the different mini-batches.
batch_splits: list of floats, where each element is a fraction of the total number of graphs. The list should not explicitly include the first and last cutting points, which are always 0 and 1, respectively.
- Parameters:
batch_splits (List[float])
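The following sketch shows how `batch_splits` could translate into cutting points. It assumes, as the term sequence bucketing suggests, that graphs are sorted by sequence length before being split, so that each mini-batch holds sequences of similar length. It works on plain lengths rather than graph objects, skips empty mini-batches by choice, and the function name `bucket_by_length` is hypothetical; it is not the class's actual code.

```python
from typing import List, Sequence


def bucket_by_length(lengths: Sequence[int], batch_splits: List[float]) -> List[List[int]]:
    """Split sequence lengths into length-sorted mini-batches.

    Cutting points are [0, *batch_splits, 1], each a fraction of the number
    of sequences. Empty mini-batches are skipped.
    """
    ordered = sorted(lengths)  # group similar lengths together to reduce padding
    n = len(ordered)
    cuts = [0.0, *batch_splits, 1.0]  # first and last points are implicit
    buckets = []
    for lo, hi in zip(cuts[:-1], cuts[1:]):
        chunk = ordered[int(lo * n):int(hi * n)]
        if chunk:
            buckets.append(chunk)
    return buckets


print(bucket_by_length([7, 3, 12, 5, 9, 4, 30, 8, 6, 10], [0.8]))
# [[3, 4, 5, 6, 7, 8, 9, 10], [12, 30]]
```

With the default `batch_splits=[0.8]`, the shortest 80% of sequences form one mini-batch and the longest 20% another, so the long outliers do not force padding onto the short ones.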
- graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)
Construct a DataLoader instance.
- Return type:
DataLoader
- Parameters:
db (str)
pulsemaps (str | List[str])
graph_definition (GraphDefinition)
features (List[str])
truth (List[str])
batch_size (int)
shuffle (bool)
selection (List[int] | None)
num_workers (int)
persistent_workers (bool)
node_truth (List[str] | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_table (str | None)
loss_weight_column (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
- graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)
Construct training and validation DataLoader instances.
- Return type:
Tuple[DataLoader, DataLoader]
- Parameters:
db (str)
graph_definition (GraphDefinition)
selection (List[int] | None)
pulsemaps (str | List[str])
features (List[str])
truth (List[str])
batch_size (int)
database_indices (List[int] | None)
seed (int)
test_size (float)
num_workers (int)
persistent_workers (bool)
node_truth (str | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_column (str | None)
loss_weight_table (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
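The roles of `seed` and `test_size` can be illustrated with a minimal standard-library sketch of a reproducible random split over a list of event indices. How the library performs the split internally is not stated here, so `split_indices` is a hypothetical helper showing the general technique, not the function's implementation.

```python
import random
from typing import List, Tuple


def split_indices(
    indices: List[int], test_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Shuffle event indices reproducibly and hold out a validation fraction."""
    shuffled = list(indices)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # same seed -> same split
    n_val = round(len(shuffled) * test_size)
    return shuffled[n_val:], shuffled[:n_val]


train, val = split_indices(list(range(10)), test_size=0.33, seed=42)
print(len(train), len(val))  # 7 3
```

Fixing the seed makes the split repeatable across runs, which keeps validation metrics comparable between experiments.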
- graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)
Get model predictions on dataloader.
- Return type:
DataFrame
- Parameters:
trainer (Trainer)
model (Model)
dataloader (DataLoader)
prediction_columns (List[str])
node_level (bool)
additional_attributes (List[str] | None)
- graphnet.training.utils.save_results(db, tag, results, archive, model)
Save the trained model and the prediction results in db.
- Return type:
None
- Parameters:
db (str)
tag (str)
results (DataFrame)
archive (str)
model (Model)
- graphnet.training.utils.save_selection(selection, file_path)
Save the list of event numbers to a CSV file.
- Parameters:
selection (List[int]) – List of event IDs.
file_path (str) – File path to save the selection.
- Return type:
None
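A minimal sketch of writing and reading back a selection with Python's `csv` module. The on-disk layout (one event ID per row, no header) is an assumption; the file `save_selection` writes may be laid out differently. `save_selection_sketch` and `load_selection_sketch` are hypothetical names.

```python
import csv
import os
import tempfile
from typing import List


def save_selection_sketch(selection: List[int], file_path: str) -> None:
    """Write event IDs to a CSV file, one ID per row (assumed layout)."""
    with open(file_path, "w", newline="") as f:
        writer = csv.writer(f)
        for event_id in selection:
            writer.writerow([event_id])


def load_selection_sketch(file_path: str) -> List[int]:
    """Read back event IDs written by save_selection_sketch."""
    with open(file_path, newline="") as f:
        return [int(row[0]) for row in csv.reader(f) if row]


path = os.path.join(tempfile.mkdtemp(), "selection.csv")
save_selection_sketch([11, 42, 7], path)
print(load_selection_sketch(path))  # [11, 42, 7]
```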
- graphnet.training.utils.add_custom_labels(data, custom_label_functions, repeat_labels_by)
Add custom labels to the data.
- Parameters:
data (Data) – Data where the labels will be stored.
custom_label_functions (Dict[str, Callable[..., Any]]) – Dictionary containing the custom label functions.
repeat_labels_by (Optional[int], default: None) – If specified, repeats the labels along the specified dimension.
- Return type:
Data
- Returns:
data with labels
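A minimal sketch of the labelling pattern, using a plain `dict` as a stand-in for `Data`. It assumes each custom label function is called with the data object and its return value is stored under the corresponding key; repeating is simplified to list repetition in place of repeating a tensor along a dimension. `add_custom_labels_sketch` is a hypothetical name.

```python
from typing import Any, Callable, Dict, Optional


def add_custom_labels_sketch(
    data: Dict[str, Any],
    custom_label_functions: Dict[str, Callable[..., Any]],
    repeat_labels_by: Optional[int] = None,
) -> Dict[str, Any]:
    """Store fn(data) under each key, optionally repeating each label."""
    for key, fn in custom_label_functions.items():
        label = fn(data)  # assumed: each function computes its label from the event
        if repeat_labels_by is not None:
            # Simplified stand-in for repeating a tensor along a dimension.
            label = [label] * repeat_labels_by
        data[key] = label
    return data


event = {"energy": 120.0, "zenith": 0.4}
labels = {"is_high_energy": lambda d: d["energy"] > 100.0}
print(add_custom_labels_sketch(event, labels))
# {'energy': 120.0, 'zenith': 0.4, 'is_high_energy': True}
```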
- graphnet.training.utils.add_truth(data, truth_dicts, dtype, repeat_labels_by)
Add truth labels from `truth_dicts` to `data`.
That is, `data[key] = truth_dict[key]` for each key of each truth dictionary.
- Parameters:
data (Data) – Data where the labels will be stored.
truth_dicts (Union[Dict[str, Any], List[Dict[str, Any]]]) – Dictionary, or list of dictionaries, containing the labels.
dtype (Optional[dtype], default: torch.float32) – dtype of the truth labels.
repeat_labels_by (Optional[int], default: None) – If specified, repeats the labels along the specified dimension.
- Return type:
Data
- Returns:
data with labels
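A minimal sketch of the `data[key] = truth_dict[key]` rule, with a plain `dict` standing in for `Data` and a Python callable such as `float` standing in for a torch dtype. Accepting either a single dictionary or a list follows the `truth_dicts` type above; list repetition is a simplified stand-in for repeating a tensor, and `add_truth_sketch` is a hypothetical name.

```python
from typing import Any, Callable, Dict, List, Optional, Union


def add_truth_sketch(
    data: Dict[str, Any],
    truth_dicts: Union[Dict[str, Any], List[Dict[str, Any]]],
    dtype: Callable[[Any], Any] = float,
    repeat_labels_by: Optional[int] = None,
) -> Dict[str, Any]:
    """Copy truth_dict[key] into data[key] for every key of every dict."""
    if isinstance(truth_dicts, dict):
        truth_dicts = [truth_dicts]  # accept a single dict as well as a list
    for truth_dict in truth_dicts:
        for key, value in truth_dict.items():
            label = dtype(value)  # stand-in for building a tensor of this dtype
            if repeat_labels_by is not None:
                label = [label] * repeat_labels_by
            data[key] = label  # a later dict overwrites an earlier one on key clashes
    return data


print(add_truth_sketch({}, [{"energy": 12}, {"pid": 13}]))
# {'energy': 12.0, 'pid': 13.0}
```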