"""Standard model class(es)."""
from typing import Dict, List, Optional, Union, Type
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch.optim import Adam
from graphnet.models.gnn.gnn import GNN
from .easy_model import EasySyntax
from graphnet.models.task import StandardFlowTask
from graphnet.models.graphs import GraphDefinition
from graphnet.models.utils import get_fields
class NormalizingFlow(EasySyntax):
    """A model for building (conditional) normalizing flows in GraphNeT.

    This model relies on `jammy_flows` for building and evaluating
    normalizing flows. See
    https://thoglu.github.io/jammy_flows/usage/introduction.html
    for details.
    """

    def __init__(
        self,
        graph_definition: GraphDefinition,
        target_labels: str,
        backbone: Optional[GNN] = None,
        condition_on: Union[str, List[str], None] = None,
        flow_layers: str = "gggt",
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None,
    ) -> None:
        """Build NormalizingFlow to learn (conditional) normalizing flows.

        NormalizingFlow is able to build, train and evaluate a wide suite of
        normalizing flows. Instead of producing point-like predictions, flows
        learn a pdf of your data, providing you with a posterior
        distribution for every example.

        `NormalizingFlow` can be conditioned on existing fields in the
        DataRepresentation or latent representations from `Models`. If
        neither `backbone` nor `condition_on` is given, the flow is
        unconditional.

        NormalizingFlow is built upon https://github.com/thoglu/jammy_flows,
        and we refer to their documentation for details on the flows.

        Args:
            graph_definition: The `GraphDefinition` to train the model on.
            target_labels: Name of target(s) to learn the pdf of.
            backbone: Architecture used to produce latent representations of
                the input data on which the pdf will be conditioned.
                Defaults to None.
            condition_on: Field name, or list of field names, in Data
                objects to condition the pdf on. A single string is
                wrapped in a list. Defaults to None.
            flow_layers: A string defining the flow layers.
                See
                https://thoglu.github.io/jammy_flows/usage/introduction.html
                for details. Defaults to "gggt".
            optimizer_class: Optimizer to use. Defaults to Adam.
            optimizer_kwargs: Optimizer arguments. Defaults to None.
            scheduler_class: Learning rate scheduler to use.
                Defaults to None.
            scheduler_kwargs: Arguments to learning rate scheduler.
                Defaults to None.
            scheduler_config: Forwarded unchanged to the base class
                constructor. Defaults to None.

        Raises:
            ValueError: if both `backbone` and `condition_on` are specified.
        """
        # Checks
        if backbone is not None and condition_on is not None:
            # If user wants to condition on both
            raise ValueError(
                f"{self.__class__.__name__} got values for both "
                "`backbone` and `condition_on`, but can only "
                "condition on one of those. Please specify just "
                "one of these arguments."
            )

        # Handle args: the size of the conditioning input depends on where
        # the conditioning comes from.
        hidden_size: Optional[int]
        if backbone is not None:
            assert isinstance(backbone, GNN)
            hidden_size = backbone.nb_outputs
        elif condition_on is not None:
            if isinstance(condition_on, str):
                condition_on = [condition_on]
            hidden_size = len(condition_on)
        else:
            # Unconditional flow: there is nothing to condition on.
            hidden_size = None

        # Build Flow Task
        task = StandardFlowTask(
            hidden_size=hidden_size,
            flow_layers=flow_layers,
            target_labels=target_labels,
        )

        # Base class constructor
        super().__init__(
            tasks=task,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            scheduler_class=scheduler_class,
            scheduler_kwargs=scheduler_kwargs,
            scheduler_config=scheduler_config,
        )

        # Member variable(s)
        self._graph_definition = graph_definition
        self.backbone = backbone
        self._condition_on = condition_on
        # BatchNorm1d needs an integer feature count, so it cannot be built
        # for an unconditional flow (hidden_size is None). It is only applied
        # to backbone output in `forward`, but is still created whenever a
        # size is known so the set of registered submodules is unchanged for
        # the conditional configurations.
        self._norm: Optional[torch.nn.BatchNorm1d] = (
            torch.nn.BatchNorm1d(hidden_size)
            if hidden_size is not None
            else None
        )

    def forward(self, data: Union[Data, List[Data]]) -> List[Tensor]:
        """Forward pass, chaining model components.

        Args:
            data: A single `Data` object or a list of them. A single object
                is wrapped in a list.

        Returns:
            A single-element list holding the outputs of the flow task for
            every `Data` object, concatenated along dimension 0.
        """
        if isinstance(data, Data):
            data = [data]
        x_list = []
        for d in data:
            if self.backbone is not None:
                # A backbone implies a known hidden size, hence a norm layer.
                assert self._norm is not None
                x = self._backbone(d)
                x = self._norm(x)
            elif self._condition_on is not None:
                assert isinstance(self._condition_on, list)
                x = get_fields(data=d, fields=self._condition_on)
            else:
                # Unconditional flow
                x = None
            x = self._tasks[0](x, d)
            x_list.append(x)
        x = torch.cat(x_list, dim=0)
        return [x]

    def _backbone(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Apply the backbone to `data`; a backbone must have been given."""
        assert self.backbone is not None
        return self.backbone(data)

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.

        Args:
            batch: The `Data` objects to run the forward pass on.
            batch_idx: Index of the batch. Not used by this method.

        Returns:
            The mean over dimension 0 of the forward-pass output.
        """
        loss = self(batch)
        if isinstance(loss, list):
            # `forward` wraps its result in a single-element list.
            assert len(loss) == 1
            loss = loss[0]
        return torch.mean(loss, dim=0)

    def validate_tasks(self) -> None:
        """Verify that self._tasks contain compatible elements.

        Asserts that every task is a `StandardFlowTask`.
        """
        accepted_tasks = StandardFlowTask
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)