graphnet.models.standard_averaged_model module

Averaged Standard model class(es).

class graphnet.models.standard_averaged_model.StandardAveragedModel(*args, **kwargs)

Bases: StandardModel

Class for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) models in graphnet.

Construct StandardAveragedModel.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

object
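SWA and EMA differ only in how each new snapshot of the weights is folded into the running average. The sketch below is not graphnet code. It uses plain Python floats in place of tensors, and the names `swa_update` and `ema_update` are hypothetical. SWA keeps an equal-weight running mean, while EMA gives each new snapshot a weight of `1 - decay`. These are the default and EMA averaging rules of PyTorch's `torch.optim.swa_utils.AveragedModel`, which this class may or may not use internally.

```python
from typing import Dict

# Hypothetical stand-in: parameter name -> scalar value (a tensor in real code).
Params = Dict[str, float]


def swa_update(avg: Params, current: Params, n_averaged: int) -> Params:
    """SWA: equal-weight running mean, given `n_averaged` snapshots so far."""
    return {k: avg[k] + (current[k] - avg[k]) / (n_averaged + 1) for k in avg}


def ema_update(avg: Params, current: Params, decay: float) -> Params:
    """EMA: keep `decay` of the old average, blend in `1 - decay` of the new weights."""
    return {k: decay * avg[k] + (1.0 - decay) * current[k] for k in avg}


snapshots = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]

swa = dict(snapshots[0])  # the first snapshot initialises the average
for n, snap in enumerate(snapshots[1:], start=1):
    swa = swa_update(swa, snap, n)

ema = dict(snapshots[0])
for snap in snapshots[1:]:
    ema = ema_update(ema, snap, decay=0.5)

print(swa, ema)  # {'w': 2.0} {'w': 2.25}
```

With three snapshots, SWA returns their plain mean, whereas EMA leans toward the most recent weights. In practice `decay` is usually chosen close to 1, so the average changes slowly.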

training_step(train_batch, batch_idx)

Perform a training step.

Return type:

Tensor

Parameters:
  • train_batch (Data | List[Data])

  • batch_idx (int)

validation_step(val_batch, batch_idx)

Perform a validation step.

Return type:

Tensor

Parameters:
  • val_batch (Data | List[Data])

  • batch_idx (int)

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

Perform an optimizer step.

Return type:

None

Parameters:
  • epoch (int)

  • batch_idx (int)

  • optimizer (Type[Optimizer])

  • optimizer_closure (Callable[[], Any] | None)
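This page says only that `optimizer_step` performs the step. A common pattern for averaged models, assumed here rather than confirmed for graphnet, has two parts. First the base optimizer updates the live weights. Then those new weights are folded into the averaged copy, optionally only from some starting step onward. The sketch uses plain SGD on float parameters, and `AveragedCopy`, `start_step` and `sgd_then_average` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Dict

Params = Dict[str, float]  # hypothetical stand-in for model parameters


@dataclass
class AveragedCopy:
    """Hypothetical EMA container updated after each optimizer step."""

    decay: float
    start_step: int = 0  # hypothetical: steps before this are not averaged
    params: Params = field(default_factory=dict)
    n_updates: int = 0

    def update(self, current: Params, step: int) -> None:
        if step < self.start_step:
            return  # averaging has not started yet
        if self.n_updates == 0:
            self.params = dict(current)  # first averaged snapshot is a plain copy
        else:
            self.params = {
                k: self.decay * self.params[k] + (1.0 - self.decay) * current[k]
                for k in current
            }
        self.n_updates += 1


def sgd_then_average(
    params: Params, grads: Params, lr: float, averaged: AveragedCopy, step: int
) -> None:
    # 1. Ordinary optimizer step on the live weights (plain SGD here).
    for k in params:
        params[k] -= lr * grads[k]
    # 2. Only afterwards, fold the freshly updated weights into the average.
    averaged.update(params, step)


live = {"w": 1.0}
avg = AveragedCopy(decay=0.5, start_step=1)
for step in range(3):
    sgd_then_average(live, {"w": 1.0}, lr=0.5, averaged=avg, step=step)

print(live, avg.params, avg.n_updates)  # {'w': -0.5} {'w': -0.25} 2
```

Updating after the step means the average only ever contains weights the optimizer has actually produced. In the real signature, `optimizer_closure` is the PyTorch Lightning closure that runs the training step, zeroes gradients and calls backward. This sketch replaces it with fixed gradients.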

load_state_dict(path, **kargs)

Load the model state_dict, either from a file path or from an already-loaded dictionary.

Return type:

StandardAveragedModel

Parameters:
  • path (str | Dict)

  • kargs (Any | None)
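Because `path` is typed `str | Dict`, the method accepts either a file location or an already-loaded state dict. Below is a minimal sketch of that dispatch. `resolve_state_dict` and its `loader` argument are hypothetical. The default loader uses `pickle` only to keep the sketch dependency-free; a PyTorch checkpoint would normally be read with `torch.load`. How graphnet maps keys between the live and averaged models is not documented here and is not shown.

```python
import pickle
from typing import Any, Callable, Dict, Optional, Union

StateDict = Dict[str, Any]


def _pickle_loader(path: str) -> StateDict:
    # Stand-in for torch.load in this dependency-free sketch.
    with open(path, "rb") as f:
        return pickle.load(f)


def resolve_state_dict(
    path: Union[str, StateDict],
    loader: Optional[Callable[[str], StateDict]] = None,
) -> StateDict:
    """Return a state dict whether `path` is a file path or an in-memory dict."""
    if isinstance(path, dict):
        return path  # already loaded: use as-is
    if isinstance(path, str):
        return (loader or _pickle_loader)(path)  # read from disk
    raise TypeError(f"expected str or dict, got {type(path).__name__}")


in_memory = {"w": 1.0}
print(resolve_state_dict(in_memory) is in_memory)  # True
```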

on_train_end()

At the end of training, update the model parameters with the averaged ones.

Return type:

None
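When training ends, the averaged weights are written back into the model, presumably so that anything saved or used for inference afterwards carries the averaged parameters. The sketch below shows an in-place copy with a name check on plain dicts, and `copy_averaged_into_model` is a hypothetical helper. In PyTorch the per-parameter copy would typically be an in-place `Tensor.copy_` under `torch.no_grad()`, though this page does not confirm the exact mechanism.

```python
from typing import Dict

Params = Dict[str, float]  # hypothetical stand-in for named parameters


def copy_averaged_into_model(model_params: Params, averaged_params: Params) -> None:
    """Overwrite live parameters in place with their averaged counterparts."""
    mismatch = model_params.keys() ^ averaged_params.keys()
    if mismatch:
        raise KeyError(f"parameter names differ: {sorted(mismatch)}")
    for name, value in averaged_params.items():
        model_params[name] = value  # in place, so existing references see the update


model = {"w": 0.0, "b": 1.0}
averaged = {"w": 2.0, "b": 3.0}
handle = model  # e.g. a reference held elsewhere, such as by a trainer
copy_averaged_into_model(model, averaged)
print(handle)  # {'w': 2.0, 'b': 3.0}
```

Copying in place, rather than rebinding the model to a new object, keeps any existing references to the model's parameters valid.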