normalizing_flow
Standard model class(es).
- class graphnet.models.normalizing_flow.NormalizingFlow(*args, **kwargs)
Bases: EasySyntax
A model for building (conditional) normalizing flows in GraphNeT.
This model relies on jammy_flows for building and evaluating normalizing flows. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details.
Build NormalizingFlow to learn (conditional) normalizing flows.
NormalizingFlow is able to build, train and evaluate a wide suite of normalizing flows. Instead of optimizing a conventional loss function, a flow is trained by minimizing the negative log-likelihood of the targets under its learned pdf, which yields a full posterior distribution for every example rather than a point-like prediction.
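To make this training objective concrete, here is a minimal one-dimensional sketch in plain Python; it is not jammy_flows or GraphNeT code. An affine flow y = mu + sigma * z maps a standard normal base variable z to the target, the change-of-variables formula gives log p(y) = log N(z; 0, 1) - log sigma, and training minimizes the mean negative log-likelihood (NLL). The parameter names mu and sigma and the toy data are illustrative assumptions.

```python
import math


def log_standard_normal(z: float) -> float:
    """Log density of N(0, 1) at z."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def affine_flow_log_pdf(y: float, mu: float, sigma: float) -> float:
    """Log density of y under the affine flow y = mu + sigma * z, z ~ N(0, 1).

    Change of variables: log p(y) = log p_z(z) + log|dz/dy|, with dz/dy = 1 / sigma.
    """
    z = (y - mu) / sigma
    return log_standard_normal(z) - math.log(sigma)


def mean_nll(ys, mu, sigma):
    """Mean negative log-likelihood: the quantity the flow is trained to minimize."""
    return -sum(affine_flow_log_pdf(y, mu, sigma) for y in ys) / len(ys)


# Toy targets (illustrative only).
ys = [1.2, 0.8, 1.5, 0.9, 1.1]

# For this simple affine flow the maximum-likelihood parameters have a closed
# form: the sample mean and the (biased) sample standard deviation.
mu_hat = sum(ys) / len(ys)
sigma_hat = math.sqrt(sum((y - mu_hat) ** 2 for y in ys) / len(ys))

print(mean_nll(ys, mu_hat, sigma_hat))  # lowest achievable NLL for this model
print(mean_nll(ys, 0.0, 1.0))           # a worse fit gives a higher NLL
```

The result is the whole density p(y), which can be evaluated at any y, rather than only the single estimate mu_hat; richer flows replace the affine map with a stack of learnable invertible layers.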
NormalizingFlow can be conditioned either on existing fields in the DataRepresentation or on latent representations produced by Models.
NormalizingFlow is built upon https://github.com/thoglu/jammy_flows, and we refer to their documentation for details on the flows.
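Conditioning means that the parameters of the flow become functions of a per-example input, which may be raw Data fields or the latent output of a backbone. The sketch below extends a one-dimensional affine flow so that its location depends linearly on a scalar condition x. The linear map and its weights w, b and log_sigma are hypothetical stand-ins for whatever produces the conditioning; jammy_flows uses far more expressive layers.

```python
import math
import random


def conditional_log_pdf(y, x, w, b, log_sigma):
    """log p(y | x) for y = (w * x + b) + exp(log_sigma) * z, with z ~ N(0, 1)."""
    mu = w * x + b
    sigma = math.exp(log_sigma)
    z = (y - mu) / sigma
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - log_sigma


def sample_posterior(x, w, b, log_sigma, n, rng):
    """Draw n samples from p(y | x) by pushing base samples through the flow."""
    mu = w * x + b
    sigma = math.exp(log_sigma)
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


# Hypothetical parameters; in practice they would be learned by minimizing the NLL.
w, b, log_sigma = 2.0, 0.5, math.log(0.1)
rng = random.Random(0)

for x in (0.0, 1.0):
    samples = sample_posterior(x, w, b, log_sigma, 1000, rng)
    print(x, sum(samples) / len(samples))  # centered near w * x + b
```

Each example therefore gets its own distribution over the target, shifted according to its conditioning input.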
- Parameters:
graph_definition (GraphDefinition) – The GraphDefinition to train the model on.
target_labels (str) – Name of target(s) to learn the pdf of.
backbone (Optional[GNN], default: None) – Architecture used to produce latent representations of the input data on which the pdf will be conditioned. Defaults to None.
condition_on (Union[List[str], str, None], default: None) – List of fields in Data objects to condition the pdf on. Defaults to None.
flow_layers (str, default: 'gggt') – A string defining the flow layers. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details. Defaults to "gggt".
optimizer_class (Type[Optimizer], default: <class 'torch.optim.adam.Adam'>) – Optimizer to use. Defaults to Adam.
optimizer_kwargs (Optional[Dict], default: None) – Optimizer arguments. Defaults to None.
scheduler_class (Optional[type], default: None) – Learning rate scheduler to use. Defaults to None.
scheduler_kwargs (Optional[Dict], default: None) – Arguments to the learning rate scheduler. Defaults to None.
scheduler_config (Optional[Dict], default: None) – Defaults to None.
args (Any)
kwargs (Any)
- Raises:
ValueError – if both backbone and condition_on are specified.
- Return type:
object
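The Raises entry and the Union type of condition_on suggest some argument handling at construction time. A hypothetical helper, resolve_conditioning (not part of GraphNeT), that enforces the documented rule and turns a single field name into a list might look like the sketch below. Treating "neither backbone nor condition_on given" as an unconditional flow is an assumption, and the field names used are illustrative.

```python
from typing import Any, List, Optional, Union


def resolve_conditioning(
    backbone: Optional[Any],
    condition_on: Union[List[str], str, None],
) -> Optional[List[str]]:
    """Hypothetical mirror of the documented constructor check.

    Returns the list of Data fields to condition on, or None when the flow is
    conditioned on a backbone (or, by assumption, is unconditional).
    """
    if backbone is not None and condition_on is not None:
        # Documented behavior: specifying both is an error.
        raise ValueError("Specify either `backbone` or `condition_on`, not both.")
    if isinstance(condition_on, str):
        # A single field name is normalized to a one-element list.
        return [condition_on]
    return condition_on


print(resolve_conditioning(None, "field_a"))             # ['field_a']
print(resolve_conditioning(None, ["field_a", "field_b"]))  # ['field_a', 'field_b']
```

Raising early keeps the two conditioning sources mutually exclusive, so the flow never has to decide which input should take precedence.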
- forward(data)
Forward pass, chaining model components.
- Return type:
Tensor
- Parameters:
data (Data | List[Data])
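The forward pass chains a conditioning step with the flow evaluation. The plain-Python sketch below shows one plausible order of operations and returns one negative log-likelihood per example. The toy backbone (a sum of input features), the affine flow and the dictionary-based events are hypothetical stand-ins. They are not GraphNeT's Data objects or the jammy_flows API.

```python
import math
from typing import Dict, List


def toy_backbone(features: List[float]) -> float:
    """Hypothetical stand-in for a GNN: compress an event to one latent value."""
    return sum(features)


def affine_log_pdf(y: float, mu: float, sigma: float) -> float:
    """log N(y; mu, sigma^2), i.e. an affine flow on a standard normal base."""
    z = (y - mu) / sigma
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - math.log(sigma)


def forward(events: List[Dict], sigma: float = 1.0) -> List[float]:
    """Chain backbone -> conditional flow and return per-event NLL values."""
    nll = []
    for event in events:
        latent = toy_backbone(event["features"])  # conditioning input
        log_p = affine_log_pdf(event["target"], mu=latent, sigma=sigma)
        nll.append(-log_p)
    return nll


events = [
    {"features": [0.2, 0.3], "target": 0.5},  # target matches the latent: low NLL
    {"features": [1.0, 1.0], "target": 0.0},  # target far from the latent: high NLL
]
print(forward(events))
```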
Perform shared step.
Applies the forward pass and the subsequent loss calculation, shared between the training and validation steps.
- Return type:
Tensor
- Parameters:
batch (List[Data])
batch_idx (int)
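Because training and validation evaluate the same objective, a shared step avoids duplicating the forward pass and the loss reduction. The minimal sketch below assumes, for illustration only, that the loss is the batch mean of the per-example NLL values returned by the forward pass. The class SharedStepSketch and its method names, including shared_step, are hypothetical and do not reproduce GraphNeT's implementation.

```python
from typing import Callable, Dict, List


class SharedStepSketch:
    """Hypothetical model showing one loss computation reused by two steps."""

    def __init__(self, forward: Callable[[List[float]], List[float]]):
        self.forward = forward
        self.logged: Dict[str, float] = {}

    def shared_step(self, batch: List[float], batch_idx: int) -> float:
        """Forward pass followed by reduction to a scalar loss (mean NLL)."""
        per_example_nll = self.forward(batch)
        return sum(per_example_nll) / len(per_example_nll)

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        loss = self.shared_step(batch, batch_idx)
        self.logged["train_loss"] = loss
        return loss

    def validation_step(self, batch: List[float], batch_idx: int) -> float:
        loss = self.shared_step(batch, batch_idx)
        self.logged["val_loss"] = loss
        return loss


# Toy forward: treat each value in the batch as an already-computed NLL.
model = SharedStepSketch(forward=lambda batch: list(batch))
print(model.training_step([1.0, 3.0], batch_idx=0))    # 2.0
print(model.validation_step([2.0, 2.0], batch_idx=0))  # 2.0
```

Only the logging key differs between the two steps; any change to the objective is made once, in the shared step.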