callbacks
Callback class(es) for use during model training.
- class graphnet.training.callbacks.PiecewiseLinearLR(optimizer, milestones, factors, last_epoch, verbose)
Bases: _LRScheduler
Interpolate the learning rate linearly between milestones.
Construct PiecewiseLinearLR.
Each milestone is a step index paired with a factor that multiplies the base learning rate. For steps between two milestones, the factor is interpolated linearly between the factors of the two enclosing milestones. For steps before the first milestone, the first milestone's factor is used; for steps after the last milestone, the last milestone's factor is used.
- Parameters:
optimizer (Optimizer) – Wrapped optimizer.
milestones (List[int]) – List of step indices. Must be increasing.
factors (List[float]) – List of multiplicative factors. Must be the same length as milestones.
last_epoch (int, default: -1) – The index of the last epoch.
verbose (bool, default: False) – If True, prints a message to stdout for each update.
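The interpolation rule above can be sketched in plain Python as a function that returns the factor for a given step; the learning rate at that step is the base learning rate times the factor. This is a minimal sketch, not the graphnet implementation: the helper name `piecewise_linear_factor` is hypothetical, and it assumes milestones are strictly increasing (the documented requirement is only "increasing"), which rules out a zero-width interval.

```python
from bisect import bisect_right
from typing import Sequence


def piecewise_linear_factor(
    step: int, milestones: Sequence[int], factors: Sequence[float]
) -> float:
    """Return the learning-rate factor at `step` (hypothetical helper)."""
    if not milestones or len(milestones) != len(factors):
        raise ValueError("milestones and factors must be non-empty and equal in length")
    # Strictly increasing milestones avoid dividing by zero below.
    if any(b <= a for a, b in zip(milestones, milestones[1:])):
        raise ValueError("milestones must be strictly increasing")
    # Before the first milestone, use the first factor.
    if step <= milestones[0]:
        return factors[0]
    # After the last milestone, use the last factor.
    if step >= milestones[-1]:
        return factors[-1]
    # i is the index of the first milestone strictly greater than `step`,
    # so milestones[i - 1] <= step < milestones[i].
    i = bisect_right(milestones, step)
    x0, x1 = milestones[i - 1], milestones[i]
    y0, y1 = factors[i - 1], factors[i]
    # Linear interpolation between the two enclosing milestones.
    return y0 + (y1 - y0) * (step - x0) / (x1 - x0)


if __name__ == "__main__":
    base_lr = 1e-3
    milestones, factors = [0, 100, 200], [1e-2, 1.0, 1e-2]
    for step in (0, 50, 100, 150, 250):
        print(step, base_lr * piecewise_linear_factor(step, milestones, factors))
```

With the documented class, a warm-up followed by a decay of this shape would be requested as `PiecewiseLinearLR(optimizer, milestones=[0, 100, 200], factors=[1e-2, 1.0, 1e-2])`. Whether a "step" corresponds to a batch or an epoch depends on how often the scheduler's `step()` is called.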
- class graphnet.training.callbacks.ProgressBar(refresh_rate, process_position, leave)
Bases: TQDMProgressBar
Custom progress bar for graphnet.
Customises the default progress bar in pytorch-lightning.
- Parameters:
refresh_rate (int)
process_position (int)
leave (bool)
- get_metrics(trainer, model)
Override to omit the version number from the metrics shown in the progress bar.
- Return type:
Dict
- Parameters:
trainer (Trainer)
model (LightningModule)
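The override described above amounts to removing one entry from the dictionary returned by the base class. The sketch below uses a hypothetical stand-in base class so that it runs without pytorch-lightning; it assumes the version number is stored under the key `"v_num"`, as in Lightning's default progress-bar metrics, and it is not the graphnet source.

```python
from typing import Any, Dict


class BaseProgressBarStandIn:
    """Hypothetical stand-in for TQDMProgressBar so the sketch runs anywhere."""

    def get_metrics(self, trainer: Any, model: Any) -> Dict[str, Any]:
        # Assumed default: logged metrics plus the logger version under "v_num".
        return {"loss": 0.25, "v_num": 3}


class ProgressBarSketch(BaseProgressBarStandIn):
    def get_metrics(self, trainer: Any, model: Any) -> Dict[str, Any]:
        items = super().get_metrics(trainer, model)
        # Drop the version number; the default avoids a KeyError if it is absent.
        items.pop("v_num", None)
        return items


if __name__ == "__main__":
    print(ProgressBarSketch().get_metrics(trainer=None, model=None))
```

Calling `super()` first and then removing a single key keeps every other default metric, so the override stays small if the base class adds new entries.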
- on_train_epoch_start(trainer, model)
Print the results of the previous epoch on a separate line.
This lets the user see the losses and metrics of previous epochs while the current epoch is training. By default, pytorch-lightning overwrites the progress bar of the previous epoch.
- Return type:
None
- Parameters:
trainer (Trainer)
model (LightningModule)
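The idea of keeping the previous epoch visible can be illustrated without a real progress bar: a bar redraws its line in place, so anything worth keeping must be written as its own newline-terminated line before the next epoch starts. The class below is a hypothetical, simplified stand-in (the real hook receives `trainer` and `model`, and the summary format here is an assumption).

```python
import io
from typing import Dict, Optional, TextIO


class EpochSummaryPrinter:
    """Hypothetical stand-in illustrating the on_train_epoch_start behaviour."""

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream
        self.last_epoch: Optional[int] = None
        self.last_metrics: Dict[str, float] = {}

    def on_train_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        # Remember the finished epoch so it can be reported when the next starts.
        self.last_epoch = epoch
        self.last_metrics = dict(metrics)

    def on_train_epoch_start(self) -> None:
        # Before the first epoch has finished there is nothing to report.
        if self.last_epoch is None:
            return
        summary = ", ".join(f"{k}={v:.3f}" for k, v in self.last_metrics.items())
        # A newline-terminated line is not overwritten by later in-place redraws.
        self.stream.write(f"Epoch {self.last_epoch}: {summary}\n")


if __name__ == "__main__":
    buffer = io.StringIO()
    printer = EpochSummaryPrinter(buffer)
    printer.on_train_epoch_end(0, {"train_loss": 0.5, "val_loss": 0.625})
    printer.on_train_epoch_start()
    print(buffer.getvalue(), end="")
```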
- class graphnet.training.callbacks.GraphnetEarlyStopping(save_dir, **kwargs)
Bases: EarlyStopping
Early stopping callback for graphnet.
Construct GraphnetEarlyStopping Callback.
- Parameters:
save_dir (str) – Path to the directory in which the best model and config are saved.
**kwargs (Dict[str, Any]) – Keyword arguments to pass to EarlyStopping. See pytorch_lightning.callbacks.EarlyStopping for details.
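The stopping criterion itself comes from `EarlyStopping`, configured through `**kwargs` (for example `monitor` and `patience`). A simplified, hypothetical version of its patience rule is sketched below; it assumes a metric where lower is better and ignores options such as `min_delta`, so Lightning's own implementation handles more cases than this.

```python
from typing import Optional


class PatienceTracker:
    """Hypothetical, simplified patience rule (lower scores are better)."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best: Optional[float] = None
        self.wait_count = 0

    def update(self, score: float) -> bool:
        """Record a new score and return True if training should stop."""
        if self.best is None or score < self.best:
            # Improvement: remember it and reset the counter.
            self.best = score
            self.wait_count = 0
        else:
            # No improvement: count one more check without progress.
            self.wait_count += 1
        return self.wait_count >= self.patience


if __name__ == "__main__":
    tracker = PatienceTracker(patience=2)
    for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.85]):
        if tracker.update(val_loss):
            print(f"stopping after epoch {epoch}; best={tracker.best}")
            break
```

With the documented class, such options would be forwarded as, e.g., `GraphnetEarlyStopping(save_dir="runs/example", monitor="val_loss", patience=5)`, where the directory and metric names are placeholders.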
- setup(trainer, graphnet_model, stage)
Called at the setup stage of training.
- Parameters:
trainer (Trainer) – The trainer.
graphnet_model (Model) – The model.
stage (Optional[str], default: None) – The stage of training.
- Return type:
None
- on_train_epoch_end(trainer, graphnet_model)
Called after each training epoch.
- Parameters:
trainer (Trainer) – Trainer object.
graphnet_model (Model) – Graphnet Model.
- Return type:
None
Returns: None.
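Since `save_dir` is documented as the place where the best model is saved, one plausible pattern for an end-of-epoch hook is to write a checkpoint only when the monitored value improves. This is a hypothetical sketch, not graphnet's code: it writes JSON in place of a PyTorch state dict so that it runs with the standard library alone, and the file name `best_model.json` and the method signature are assumptions.

```python
import json
import os
import tempfile
from typing import Dict, Optional


class BestCheckpointSaver:
    """Hypothetical sketch: save a checkpoint whenever the loss improves."""

    def __init__(self, save_dir: str, filename: str = "best_model.json") -> None:
        self.save_dir = save_dir
        self.path = os.path.join(save_dir, filename)
        self.best: Optional[float] = None

    def on_train_epoch_end(self, epoch: int, val_loss: float, state: Dict[str, float]) -> bool:
        """Return True if a new best checkpoint was written."""
        if self.best is not None and val_loss >= self.best:
            return False
        self.best = val_loss
        # Create the directory lazily, then overwrite the previous best.
        os.makedirs(self.save_dir, exist_ok=True)
        with open(self.path, "w") as f:
            json.dump({"epoch": epoch, "val_loss": val_loss, "state": state}, f)
        return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        saver = BestCheckpointSaver(os.path.join(tmp, "run"))
        for epoch, loss in enumerate([0.9, 0.7, 0.8]):
            print(epoch, loss, saver.on_train_epoch_end(epoch, loss, {"w": 0.1 * epoch}))
```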