graphnet.data.dataset.samplers module

Sampler and BatchSampler objects for graphnet.

MIT License

Copyright (c) 2023 DrHB

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

class graphnet.data.dataset.samplers.RandomChunkSampler(data_source, num_samples, generator)[source]

Bases: Sampler[int]

A Sampler that randomly selects chunks.

Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py

Construct RandomChunkSampler.

Parameters:
  • data_source (Dataset) – The dataset to sample from.

  • num_samples (int | None) – Optional number of samples to draw.

  • generator (Generator | None) – Optional random-number generator used for selecting chunks.

property data_source: Sequence[Any]

Return the data source.

property num_samples: int

Return the number of samples in the data source.

property chunks: List[int]

Return the list of chunks.
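The chunk-wise sampling idea can be illustrated with a minimal, self-contained sketch. The function below is hypothetical and is not the library's implementation: given a list of chunk sizes, it yields event indices chunk by chunk while visiting the chunks in random order, which is what "randomly selects chunks" amounts to.

import torch
from typing import Iterator, List

def iter_random_chunks(chunk_sizes: List[int], generator: torch.Generator) -> Iterator[int]:
    """Yield event indices chunk by chunk, visiting chunks in random order (illustrative sketch)."""
    # Starting offset of each chunk in the flattened index space.
    offsets = [0]
    for size in chunk_sizes[:-1]:
        offsets.append(offsets[-1] + size)
    # Shuffle the order in which chunks are visited.
    for chunk_idx in torch.randperm(len(chunk_sizes), generator=generator).tolist():
        start = offsets[chunk_idx]
        for i in range(chunk_sizes[chunk_idx]):
            yield start + i

generator = torch.Generator().manual_seed(0)
print(list(iter_random_chunks([3, 2, 4], generator)))

Indices within a chunk are yielded consecutively, so data stored chunk-wise on disk is read sequentially while the order of chunks is still randomised.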

graphnet.data.dataset.samplers.gather_len_matched_buckets(params)[source]

Gather length-matched buckets of events.

This function is used to gather batches of events for the LenMatchBatchSampler. When multiprocessing is used, each worker calls this function. Given a set of indices, it groups events by length: an event of length N goes into bucket (N // bucket_width). It returns the completed batches together with a list of incomplete batches that did not fill to batch_size by the end.

Parameters:
  • params (Tuple[range, Sequence[Any], int, int]) – A tuple containing the range of indices to process, the data_source, the batch size, and the bucket width.

Returns:

A tuple (batches, remaining_batches): batches is a list of completed batches, and remaining_batches contains the incomplete batches that did not fill to batch_size.
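The bucketing rule can be illustrated with a small, hypothetical sketch (not the library's implementation): events are assigned to buckets by length // bucket_width, a batch is emitted each time a bucket reaches batch_size, and whatever is left in the buckets afterwards corresponds to the incomplete batches.

from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

def bucket_by_length(
    lengths: Sequence[int], batch_size: int, bucket_width: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Group event indices into length-matched batches (illustrative sketch)."""
    buckets: Dict[int, List[int]] = defaultdict(list)
    batches: List[List[int]] = []
    for idx, length in enumerate(lengths):
        # An event of length N goes into bucket N // bucket_width.
        bucket = buckets[length // bucket_width]
        bucket.append(idx)
        if len(bucket) == batch_size:
            # Emit a full batch and empty the bucket.
            batches.append(list(bucket))
            bucket.clear()
    # Buckets that never filled up become the remaining (incomplete) batches.
    remaining_batches = [b for b in buckets.values() if b]
    return batches, remaining_batches

lengths = [5, 12, 30, 18, 40, 7]  # e.g. pulse counts per event
print(bucket_by_length(lengths, batch_size=2, bucket_width=16))
# ([[0, 1], [2, 3]], [[5], [4]])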

class graphnet.data.dataset.samplers.LenMatchBatchSampler(sampler, batch_size, num_workers, bucket_width, chunks_per_segment, multiprocessing_context, drop_last)[source]

Bases: BatchSampler, Logger

A BatchSampler that batches similar length events.

Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py

Construct LenMatchBatchSampler.

This BatchSampler groups data of similar length so that operations such as masking for MultiHeadAttention are more efficient. Since batch samplers run on the main process and can become a CPU bottleneck, num_workers can be specified to use multiprocessing when creating batches. The bucket_width argument controls how wide the length bins are: with bucket_width=16, events whose lengths fall within the same window of 16 (bucket index length // bucket_width) are grouped into the same bucket.

Parameters:
  • sampler (Sampler) – A Sampler object that selects/draws data in some way.

  • batch_size (int, default: 1) – Batch size.

  • num_workers (int, default: 1) – Number of workers to spawn to create batches.

  • bucket_width (int, default: 16) – Size of length buckets for grouping data.

  • chunks_per_segment (int, default: 4) – Number of chunks to group together.

  • multiprocessing_context (str, default: 'spawn') – Start method for multiprocessing.

  • drop_last (Optional[bool], default: False) – Drop the last incomplete batch.
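The motivation for length-matched batching can be seen with plain PyTorch tensors. The snippet below does not use the graphnet classes; it only shows that a batch of similar-length events pads to a much smaller tensor than a batch of mixed lengths, which is what makes attention masking cheaper.

import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
# Toy variable-length "events": each is a (num_pulses, num_features) tensor.
events = [torch.randn(n, 4) for n in [3, 5, 62, 60, 4, 61, 6, 59]]

def padded_shape(batch_indices):
    # Pad a batch of events to the longest event in the batch.
    batch = [events[i] for i in batch_indices]
    return pad_sequence(batch, batch_first=True).shape  # (batch, max_len, features)

print(padded_shape([0, 2, 4, 6]))  # mixed lengths -> torch.Size([4, 62, 4])
print(padded_shape([0, 1, 4, 6]))  # similar lengths -> torch.Size([4, 6, 4])

In practice the sampler passed to LenMatchBatchSampler (for example a RandomChunkSampler over a graphnet Dataset) provides the indices, and the batch sampler performs this grouping before the DataLoader collates each batch.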