graphnet.training.callbacks module

Callback class(es) for using during model training.

class graphnet.training.callbacks.PiecewiseLinearLR(optimizer, milestones, factors, last_epoch, verbose)[source]

Bases: _LRScheduler

Interpolate learning rate linearly between milestones.

Construct PiecewiseLinearLR.

Each milestone denotes a step index, and for each milestone a factor multiplying the base learning rate is specified. For steps between two milestones, the learning rate is interpolated linearly between the factors of the two nearest milestones. For steps before the first milestone, the factor of the first milestone is used; for steps after the last milestone, the factor of the last milestone is used.

Parameters:
  • optimizer (Optimizer) – Wrapped optimizer.

  • milestones (List[int]) – List of step indices. Must be increasing.

  • factors (List[float]) – List of multiplicative factors. Must be same length as milestones.

  • last_epoch (int, default: -1) – The index of the last epoch.

  • verbose (bool, default: False) – If True, prints a message to stdout for each update.
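The interpolation rule described above can be sketched in plain Python. The `piecewise_linear_factor` helper below is hypothetical (it is not part of graphnet's API) and only illustrates how the multiplicative factor is computed for a given step:

```python
import bisect

def piecewise_linear_factor(step, milestones, factors):
    """Multiplicative LR factor at `step`.

    Clamps to the first/last factor outside the milestone range and
    interpolates linearly between the two nearest milestones otherwise.
    """
    if step <= milestones[0]:
        return factors[0]
    if step >= milestones[-1]:
        return factors[-1]
    # Index of the first milestone strictly greater than `step`.
    i = bisect.bisect_right(milestones, step)
    m0, m1 = milestones[i - 1], milestones[i]
    f0, f1 = factors[i - 1], factors[i]
    return f0 + (f1 - f0) * (step - m0) / (m1 - m0)

# E.g. warm up from 1% to 100% of the base LR over the first 1000
# steps, then decay back to 1% by step 10000.
milestones, factors = [0, 1000, 10000], [0.01, 1.0, 0.01]
```

With these values, the scheduler itself would be constructed as `PiecewiseLinearLR(optimizer, milestones=[0, 1000, 10000], factors=[0.01, 1.0, 0.01])` and stepped once per training step.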

get_lr()[source]

Get effective learning rate(s) for each parameter group of the wrapped optimizer.

Return type:

List[float]

class graphnet.training.callbacks.ProgressBar(refresh_rate, process_position, leave)[source]

Bases: TQDMProgressBar

Custom progress bar for graphnet.

Customises the default progress bar in pytorch-lightning.

Parameters:
  • refresh_rate (int)

  • process_position (int)

  • leave (bool)

init_validation_tqdm()[source]

Override for customisation.

Return type:

Bar

init_predict_tqdm()[source]

Override for customisation.

Return type:

Bar

init_test_tqdm()[source]

Override for customisation.

Return type:

Bar

init_train_tqdm()[source]

Override for customisation.

Return type:

Bar

get_metrics(trainer, model)[source]

Override to not show the version number in the logging.

Return type:

Dict

Parameters:
  • trainer (Trainer)

  • model (LightningModule)

on_train_epoch_start(trainer, model)[source]

Print the results of the previous epoch on a separate line.

This allows the user to see the losses/metrics for previous epochs while the current one is training. The default behaviour in pytorch-lightning is to overwrite the progress bar from previous epochs.

Return type:

None

Parameters:
  • trainer (Trainer)

  • model (LightningModule)

on_train_epoch_end(trainer, model)[source]

Log the final progress bar for the epoch to file.

Don’t duplicate it to stdout.

Return type:

None

Parameters:
  • trainer (Trainer)

  • model (LightningModule)

class graphnet.training.callbacks.GraphnetEarlyStopping(save_dir, **kwargs)[source]

Bases: EarlyStopping

Early stopping callback for graphnet.

Construct GraphnetEarlyStopping Callback.

Parameters:
  • save_dir (str) – Path to directory to save best model and config.

  • **kwargs (Dict[str, Any]) – Keyword arguments to pass to EarlyStopping. See pytorch_lightning.callbacks.EarlyStopping for details.
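The patience logic inherited from pytorch_lightning's EarlyStopping can be sketched in plain Python. The `should_stop` helper below is hypothetical (not part of graphnet's API) and assumes a minimised metric: training stops once the metric has failed to improve by more than `min_delta` for `patience` consecutive validation checks:

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True once the monitored loss has not improved by more
    than `min_delta` for `patience` consecutive checks."""
    best = float("inf")
    bad_checks = 0
    for loss in val_losses:
        if loss < best - min_delta:
            # Genuine improvement: record it and reset the counter.
            best = loss
            bad_checks = 0
        else:
            bad_checks += 1
            if bad_checks >= patience:
                return True
    return False
```

GraphnetEarlyStopping adds to this the saving of the best model and its config under `save_dir`.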

setup(trainer, graphnet_model, stage)[source]

Call at setup stage of training.

Parameters:
  • trainer (Trainer) – The trainer.

  • graphnet_model (Model) – The model.

  • stage (Optional[str], default: None) – The stage of training.

Return type:

None

on_train_epoch_end(trainer, graphnet_model)[source]

Call after each train epoch.

Parameters:
  • trainer (Trainer) – Trainer object.

  • graphnet_model (Model) – Graphnet Model.

Return type:

None

on_validation_end(trainer, graphnet_model)[source]

Call after each validation epoch.

Parameters:
  • trainer (Trainer) – Trainer object.

  • graphnet_model (Model) – Graphnet Model.

Return type:

None

on_fit_end(trainer, graphnet_model)[source]

Call at the end of training.

Parameters:
  • trainer (Trainer) – Trainer object.

  • graphnet_model (Model) – Graphnet Model.

Return type:

None
