utils
Utility functions for graphnet.training.
- graphnet.training.utils.collate_fn(graphs)
Collate a list of graphs into a single Batch, removing graphs with fewer than two DOM hits.
Such graphs should not occur in “production” data.
- Return type:
Batch
- Parameters:
graphs (List[Data])
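The filtering rule described above can be sketched in plain Python. This is a minimal illustration, not the library's code. The Graph dataclass is a hypothetical stand-in for a PyTorch Geometric Data object, and using an n_pulses attribute as the DOM-hit count is an assumption. The final step of merging the surviving graphs into a Batch (for example with Batch.from_data_list) is omitted so the snippet runs without PyTorch Geometric. The min_hits parameter is likewise hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Graph:
    """Hypothetical stand-in for a PyTorch Geometric Data object."""

    event_no: int
    n_pulses: int  # assumed: number of DOM hits (nodes) in the graph


def filter_graphs(graphs: List[Graph], min_hits: int = 2) -> List[Graph]:
    """Drop graphs with fewer than `min_hits` DOM hits, keeping input order."""
    return [g for g in graphs if g.n_pulses >= min_hits]


graphs = [Graph(0, 5), Graph(1, 1), Graph(2, 2), Graph(3, 0)]
kept = filter_graphs(graphs)
print([g.event_no for g in kept])  # [0, 2]
```

Filtering before collation means a degenerate single-hit event never reaches the model, at the cost of silently shrinking the batch.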
- class graphnet.training.utils.collator_sequence_buckleting(batch_splits=[0.8])
Bases: object
Perform sequence bucketing on the graphs in a batch.
The constructor sets the cutting points that divide the batch into mini-batches.
batch_splits: list of floats, each giving a cutting point as a fraction of the total number of graphs. The list should not explicitly include the first and last cutting points, which are always 0 and 1 respectively. For example, the default [0.8] yields the cutting points 0, 0.8, and 1, i.e. two mini-batches.
- Parameters:
batch_splits (List[float])
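Sequence bucketing usually means sorting inputs by length and grouping similar lengths together, which reduces padding inside each mini-batch. The sketch below applies the documented cutting points to a list of graph sizes. It assumes the graphs are sorted by size before cutting, and the function bucket_by_size is hypothetical; the real class works on graphs rather than integers and presumably returns one batch object per non-empty bucket.

```python
from typing import List, Sequence


def bucket_by_size(sizes: Sequence[int], batch_splits: List[float]) -> List[List[int]]:
    """Sort graph sizes and cut them into mini-batches at the given fractions.

    The cutting points are [0] + batch_splits + [1], so batch_splits must not
    contain 0 or 1 itself. Empty buckets are skipped.
    """
    ordered = sorted(sizes)  # assumed: graphs are ordered by size (pulse count)
    cuts = [0.0] + list(batch_splits) + [1.0]
    buckets = []
    for lo, hi in zip(cuts[:-1], cuts[1:]):
        # Convert fractional cutting points into list positions.
        start, stop = int(lo * len(ordered)), int(hi * len(ordered))
        if stop > start:
            buckets.append(ordered[start:stop])
    return buckets


sizes = [12, 3, 40, 7, 25, 9, 18, 5, 31, 14]
print(bucket_by_size(sizes, [0.8]))
# [[3, 5, 7, 9, 12, 14, 18, 25], [31, 40]]
```

With the default batch_splits=[0.8], the smallest 80% of graphs form the first mini-batch and the largest 20% form the second.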
- graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)
Construct a DataLoader instance.
- Return type:
DataLoader
- Parameters:
db (str)
pulsemaps (str | List[str])
graph_definition (GraphDefinition)
features (List[str])
truth (List[str])
batch_size (int)
shuffle (bool)
selection (List[int] | None)
num_workers (int)
persistent_workers (bool)
node_truth (List[str] | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_table (str | None)
loss_weight_column (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
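In the signature above, the bare * after truth makes every later parameter keyword-only, so batch_size, shuffle, and the rest must be passed by name. The stub below is hypothetical: it copies only the first few parameters of make_dataloader to show this calling convention in plain Python and does not build a DataLoader. The argument values ("events.db", "pulses", and so on) are placeholders.

```python
from typing import List, Union


def make_dataloader_stub(
    db: str,
    pulsemaps: Union[str, List[str]],
    graph_definition: object,
    features: List[str],
    truth: List[str],
    *,
    batch_size: int,
    shuffle: bool,
) -> dict:
    """Hypothetical stand-in that echoes some arguments instead of building a DataLoader."""
    return {"db": db, "batch_size": batch_size, "shuffle": shuffle}


# Parameters after `*` must be given by keyword.
ok = make_dataloader_stub(
    "events.db", "pulses", None, ["dom_x"], ["energy"], batch_size=128, shuffle=True
)

# Passing them positionally raises TypeError.
try:
    make_dataloader_stub("events.db", "pulses", None, ["dom_x"], ["energy"], 128, True)
except TypeError as err:
    print("positional call rejected:", err)
```

Keyword-only parameters keep long call sites readable and prevent, for example, a batch size from being silently bound to the wrong argument.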
- graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)
Construct training and validation DataLoader instances. The held-out fraction is set by test_size, and seed fixes the random split.
- Return type:
Tuple[DataLoader, DataLoader]
- Parameters:
db (str)
graph_definition (GraphDefinition)
selection (List[int] | None)
pulsemaps (str | List[str])
features (List[str])
truth (List[str])
batch_size (int)
database_indices (List[int] | None)
seed (int)
test_size (float)
num_workers (int)
persistent_workers (bool)
node_truth (str | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_column (str | None)
loss_weight_table (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
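The seed and test_size parameters point to a reproducible random split of event indices into training and validation sets. The sketch below illustrates that idea with Python's random module. It is an assumption about the behaviour, not the library's code, which may delegate to a helper such as scikit-learn's train_test_split and may round split sizes differently. The function name split_indices is hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def split_indices(
    indices: Sequence[int], test_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Shuffle `indices` reproducibly and hold out a `test_size` fraction."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction strictly between 0 and 1")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)  # same seed -> same permutation
    n_val = math.ceil(test_size * len(shuffled))  # round the held-out count up
    return shuffled[n_val:], shuffled[:n_val]  # (training, validation)


train, val = split_indices(range(100), test_size=0.25, seed=42)
print(len(train), len(val))  # 75 25
```

Using a dedicated random.Random(seed) instance keeps the split reproducible without touching the global random state.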
- graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)
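Get predictions from model on the events in dataloader, running inference through trainer. The returned DataFrame has one column per entry in prediction_columns, followed by any additional_attributes.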
- Return type:
DataFrame
- Parameters:
trainer (Trainer)
model (Model)
dataloader (DataLoader)
prediction_columns (List[str])
node_level (bool)
additional_attributes (List[str] | None)
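The DataFrame returned by get_predictions has to line up per-row predictions with attributes read from the same events. The plain-Python sketch below shows one way to do that under stated assumptions: predictions arrive per batch as lists of rows, each row holds one value per name in prediction_columns, every additional attribute is event-level, and with node_level=True each attribute value is repeated once per node (pulse) so it matches node-level predictions. The function assemble_columns, the per-batch input layout, and the n_pulses key are hypothetical. If predictions and attributes are gathered in separate passes over the dataloader, the passes must visit events in the same order, so a dataloader created with shuffle=False is the safe choice.

```python
from typing import Dict, List, Sequence


def assemble_columns(
    batch_predictions: Sequence[Sequence[Sequence[float]]],
    batch_attributes: Sequence[Dict[str, List]],
    prediction_columns: List[str],
    additional_attributes: List[str],
    node_level: bool = False,
) -> Dict[str, List]:
    """Flatten per-batch predictions and attributes into one list per column."""
    columns: Dict[str, List] = {
        name: [] for name in prediction_columns + additional_attributes
    }
    for preds, attrs in zip(batch_predictions, batch_attributes):
        for row in preds:
            # Each model output row needs exactly one column name per value.
            if len(row) != len(prediction_columns):
                raise ValueError("need one prediction column name per model output")
            for name, value in zip(prediction_columns, row):
                columns[name].append(value)
        for name in additional_attributes:
            values = attrs[name]
            if node_level:
                # Assumed: attributes are event-level, so repeat each value once
                # per pulse (node) to align with node-level predictions.
                values = [
                    v for v, n in zip(values, attrs["n_pulses"]) for _ in range(n)
                ]
            columns[name].extend(values)
    # Rows only line up if every column ended up with the same length.
    if len({len(values) for values in columns.values()}) > 1:
        raise ValueError("prediction and attribute columns differ in length")
    return columns


# One batch holding two events: event 10 has 2 pulses, event 11 has 3 pulses.
cols = assemble_columns(
    batch_predictions=[[[0.1], [0.2], [0.3], [0.4], [0.5]]],
    batch_attributes=[{"event_no": [10, 11], "n_pulses": [2, 3]}],
    prediction_columns=["score"],
    additional_attributes=["event_no"],
    node_level=True,
)
print(cols["event_no"])  # [10, 10, 11, 11, 11]
```

Passing the resulting dict to pandas.DataFrame would give one row per prediction, with the repeated event number identifying which event each node belongs to.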