standard_averaged_model
Averaged Standard model class(es).
- class graphnet.models.standard_averaged_model.StandardAveragedModel(*args, **kwargs)
Bases: StandardModel
Class for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) models in graphnet.
Construct StandardAveragedModel.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
object
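SWA and EMA differ only in how each new set of weights is folded into the running average. The sketch below shows both update rules on plain floats. It is a conceptual illustration, not graphnet's implementation: the helper names `swa_update` and `ema_update` are hypothetical, and real models average tensors (PyTorch offers `torch.optim.swa_utils.AveragedModel` for this).

```python
from typing import Dict

# Hypothetical parameter container: name -> scalar weight.
Params = Dict[str, float]


def swa_update(avg: Params, current: Params, n_averaged: int) -> Params:
    """SWA: equal-weight running mean, avg + (current - avg) / (n_averaged + 1)."""
    return {k: avg[k] + (current[k] - avg[k]) / (n_averaged + 1) for k in avg}


def ema_update(avg: Params, current: Params, decay: float) -> Params:
    """EMA: decay * avg + (1 - decay) * current; recent weights count more."""
    return {k: decay * avg[k] + (1.0 - decay) * current[k] for k in avg}


# Toy trajectory of one weight "w" across three optimizer steps.
snapshots = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]

swa = dict(snapshots[0])  # the first snapshot initializes both averages
ema = dict(snapshots[0])
for n, snap in enumerate(snapshots[1:], start=1):
    swa = swa_update(swa, snap, n)  # n snapshots are already in the average
    ema = ema_update(ema, snap, decay=0.5)

print(swa["w"])  # 2.0  (plain mean of 1, 2, 3)
print(ema["w"])  # 2.25 (biased toward the latest weights)
```

SWA treats every snapshot equally, while EMA discounts older snapshots geometrically by `decay`.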
- training_step(train_batch, batch_idx)
Perform a training step on train_batch.
- Return type:
Tensor
- Parameters:
train_batch (Data | List[Data])
batch_idx (int)
- validation_step(val_batch, batch_idx)
Perform a validation step on val_batch.
- Return type:
Tensor
- Parameters:
val_batch (Data | List[Data])
batch_idx (int)
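Both step methods accept either a single `Data` object or a list of them and return a single `Tensor`. The sketch below illustrates that input contract with plain Python: `FakeData` is a hypothetical stand-in for `torch_geometric.data.Data`, `compute_loss` is a hypothetical per-batch loss, and reducing a list with a mean is an assumption made for this example, not necessarily what graphnet does internally.

```python
from dataclasses import dataclass
from typing import Callable, List, Union


@dataclass
class FakeData:
    """Hypothetical stand-in for torch_geometric.data.Data; carries only a loss value."""
    loss: float


Batch = Union[FakeData, List[FakeData]]


def as_list(batch: Batch) -> List[FakeData]:
    """Normalize a single batch or a list of batches into a list."""
    return batch if isinstance(batch, list) else [batch]


def step(batch: Batch, compute_loss: Callable[[FakeData], float]) -> float:
    """Return one scalar loss, averaged (assumption) over every batch in the input."""
    batches = as_list(batch)
    return sum(compute_loss(b) for b in batches) / len(batches)


def compute_loss(d: FakeData) -> float:
    """Hypothetical per-batch loss: just reads the stored value."""
    return d.loss


print(step(FakeData(0.5), compute_loss))                   # 0.5
print(step([FakeData(0.5), FakeData(1.5)], compute_loss))  # 1.0
```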
- optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
Perform an optimizer step.
- Return type:
None
- Parameters:
epoch (int)
batch_idx (int)
optimizer (Type[Optimizer])
optimizer_closure (Callable[[], Any] | None)
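In an averaged model, the natural place to refresh the averaged weights is right after the regular optimizer update. The toy class below sketches that ordering on a single scalar weight. Everything in it is hypothetical: `ToyAveragedModel`, `start_epoch`, and `ema_decay` are illustrative names, and the `optimizer` argument is replaced by explicit gradients with a plain SGD update so the example runs without PyTorch.

```python
from typing import Dict, Optional


class ToyAveragedModel:
    """Hypothetical toy: one weight trained by SGD, plus an averaged copy."""

    def __init__(self, lr: float, start_epoch: int, ema_decay: Optional[float] = None) -> None:
        self.weights: Dict[str, float] = {"w": 0.0}
        self.averaged: Dict[str, float] = dict(self.weights)
        self.n_averaged = 0
        self.lr = lr
        self.start_epoch = start_epoch  # hypothetical: begin averaging at this epoch
        self.ema_decay = ema_decay      # None -> SWA (equal weights), else EMA

    def optimizer_step(self, epoch: int, batch_idx: int, grads: Dict[str, float]) -> None:
        # 1) Ordinary SGD update of the live weights.
        for k, g in grads.items():
            self.weights[k] -= self.lr * g
        # 2) Fold the freshly updated weights into the averaged copy.
        if epoch < self.start_epoch:
            return
        for k, v in self.weights.items():
            if self.n_averaged == 0:
                self.averaged[k] = v  # first snapshot initializes the average
            elif self.ema_decay is None:
                self.averaged[k] += (v - self.averaged[k]) / (self.n_averaged + 1)
            else:
                self.averaged[k] = self.ema_decay * self.averaged[k] + (1 - self.ema_decay) * v
        self.n_averaged += 1


model = ToyAveragedModel(lr=1.0, start_epoch=1)
for epoch in range(3):
    model.optimizer_step(epoch, batch_idx=0, grads={"w": -1.0})
print(model.weights["w"], model.averaged["w"])  # 3.0 2.5
```

Here the weight moves 1.0 → 2.0 → 3.0; averaging starts at epoch 1, so the SWA copy is the mean of 2.0 and 3.0. Delaying the start lets the noisy early weights stay out of the average.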