task

Base physics task-specific Model class(es).

class graphnet.models.task.task.Task(*args, **kwargs)

Bases: Model

Base class for Tasks in GraphNeT.

Construct Task.

Parameters:
  • target_labels (Union[List[str], str, None], default: None) – Name(s) of the quantity/-ies being predicted, used to extract the target tensor(s) from the Data object in .compute_loss(…).

  • prediction_labels (Union[List[str], str, None], default: None) – The name(s) of each column that is predicted by the model during inference. If not given, the names are automatically set to target_label + _pred for each target label.

  • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform both the predicted and target tensors before passing them to the loss function. Useful e.g. for having the model predict quantities on a physical scale, but transforming this scale to O(1) for a numerically stable loss computation.

  • transform_target (Optional[Callable], default: None) – Optional function to transform only the target tensor before passing it, together with the predicted tensor, to the loss function. Useful e.g. for having the model predict a transformed version of the target quantity, such as the log10-scaled energy, rather than the physical quantity itself. Used in conjunction with transform_inference, which performs the inverse transform on the predicted quantity to recover the physical scale.

  • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the model prediction to recover a physical scale. Used in conjunction with transform_target.

  • transform_support (Optional[Tuple], default: None) – Optional tuple specifying the minimum and maximum of the range of validity for the inverse transforms transform_target and transform_inference, in case this range is restricted. By default, the invertibility of transform_target is tested on the range [-1e6, 1e6].

  • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event loss weights.

  • args (Any)

  • kwargs (Any)

Return type:

object
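The pairing of transform_target and transform_inference, and the role of transform_support, can be illustrated with a minimal sketch. The helper check_inverse_pair below is hypothetical and not part of GraphNeT; the evenly spaced grid, point count, and tolerance are illustrative assumptions rather than how the library actually tests invertibility.

```python
import math
from typing import Callable, Optional, Tuple


def check_inverse_pair(
    transform_target: Callable[[float], float],
    transform_inference: Callable[[float], float],
    transform_support: Optional[Tuple[float, float]] = None,
    n_points: int = 101,
    rel_tol: float = 1e-6,
) -> bool:
    """Return True if transform_inference undoes transform_target on a grid.

    Hypothetical helper: grid, point count and tolerance are illustrative
    assumptions, not GraphNeT's actual invertibility test.
    """
    # Fall back to the default range quoted in the documentation.
    lo, hi = transform_support if transform_support is not None else (-1e6, 1e6)
    step = (hi - lo) / (n_points - 1)
    for i in range(n_points):
        x = lo + i * step
        # Round trip: physical scale -> transformed scale -> physical scale.
        recovered = transform_inference(transform_target(x))
        if not math.isclose(recovered, x, rel_tol=rel_tol, abs_tol=1e-9):
            return False
    return True


# log10-scaled energy: only valid for positive values, so restrict the support.
print(check_inverse_pair(math.log10, lambda y: 10 ** y, transform_support=(1.0, 1e6)))
```

With a log10 target transform, the support has to be restricted to positive values, for example (1.0, 1e6), since math.log10 raises ValueError on the default range [-1e6, 1e6].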

abstract property nb_inputs: int

Return number of inputs assumed by task.

property default_target_labels: List[str]

Return default target labels.

property default_prediction_labels: List[str]

Return default prediction labels.
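A minimal sketch of the default naming rule described for prediction_labels (target_label + _pred when no names are given). resolve_prediction_labels is a hypothetical helper, not a GraphNeT function.

```python
from typing import List, Optional, Union


def resolve_prediction_labels(
    target_labels: Union[List[str], str],
    prediction_labels: Optional[Union[List[str], str]] = None,
) -> List[str]:
    """Hypothetical helper mirroring the documented naming rule."""
    # Accept a single label as a string, like the Task parameters do.
    if isinstance(target_labels, str):
        target_labels = [target_labels]
    # No explicit names given: append "_pred" to every target label.
    if prediction_labels is None:
        return [label + "_pred" for label in target_labels]
    if isinstance(prediction_labels, str):
        return [prediction_labels]
    return list(prediction_labels)


print(resolve_prediction_labels(["zenith", "azimuth"]))
```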

inference()

Activate inference mode.

Return type:

None

train_eval()

Deactivate inference mode.

Return type:

None
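The effect of inference() and train_eval() can be pictured as a mode flag. In this sketch, ToyTask and its postprocess method are hypothetical, and the assumption is that the flag only decides whether transform_inference is applied to the raw predictions.

```python
from typing import Callable, List, Optional


class ToyTask:
    """Hypothetical stand-in for a Task; only the mode switch is modelled."""

    def __init__(self, transform_inference: Optional[Callable[[float], float]] = None) -> None:
        # Identity if no inverse transform is given.
        self._transform_inference = transform_inference or (lambda x: x)
        self._inference = False

    def inference(self) -> None:
        """Activate inference mode."""
        self._inference = True

    def train_eval(self) -> None:
        """Deactivate inference mode."""
        self._inference = False

    def postprocess(self, raw: List[float]) -> List[float]:
        # Inference mode: map predictions back to the physical scale.
        if self._inference:
            return [self._transform_inference(v) for v in raw]
        # Training/validation: keep the transformed scale the loss sees.
        return list(raw)


task = ToyTask(transform_inference=lambda y: 10 ** y)  # model predicts log10(energy)
task.inference()
print(task.postprocess([2.0]))
```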

class graphnet.models.task.task.LearnedTask(*args, **kwargs)

Bases: Task

Task class with a learned mapping.

Applies a learned mapping between the last latent layer of Model and the target space. That is, the LearnedTask contains learnable parameters that act like a prediction head.

Construct LearnedTask.

Parameters:
  • hidden_size (int) – The number of columns in the output of the last latent layer of Model using this Task. Available through Model.nb_outputs.

  • loss_function (LossFunction) – Loss function appropriate to the task.

  • args (Any)

  • kwargs (Any)

Return type:

object

abstract compute_loss(pred, data)

Compute loss of pred with respect to the target labels in data.

Return type:

Tensor

Parameters:
  • pred (Tensor | Data)

  • data (Data)
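A minimal sketch of the contract implied by the abstract members above, using Python's abc module. ToyLearnedTask and ToyEnergyTask are hypothetical and use plain lists and dicts in place of Tensor and Data.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyLearnedTask(ABC):
    """Hypothetical mirror of the abstract members a LearnedTask subclass fills in."""

    @property
    @abstractmethod
    def nb_inputs(self) -> int:
        """Number of columns the prediction head must output."""

    @abstractmethod
    def compute_loss(self, pred: List[float], data: Dict[str, float]) -> float:
        """Loss of pred with respect to the target labels in data."""


class ToyEnergyTask(ToyLearnedTask):
    """Concrete subclass: one output column, squared-error loss."""

    @property
    def nb_inputs(self) -> int:
        return 1

    def compute_loss(self, pred: List[float], data: Dict[str, float]) -> float:
        # Squared error against the single target label "energy".
        return (pred[0] - data["energy"]) ** 2


print(ToyEnergyTask().compute_loss([3.0], {"energy": 1.0}))
```

Instantiating the abstract base directly raises TypeError, which is how the abstract nb_inputs and compute_loss force subclasses to provide them.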

abstract property nb_inputs: int

Return number of inputs assumed by task.

forward(x)

Forward call for LearnedTask.

The learned embedding transforms the last latent layer of Model to match the target dimensions.

Return type:

Union[Tensor, Data]

Parameters:
  • x (Tensor | Data)
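The learned mapping in forward can be sketched as an affine layer from hidden_size latent columns to nb_inputs output columns. AffineHead is a hypothetical pure-Python stand-in; assuming an affine head is a simplification, and a real model would use trainable layers from its deep-learning framework.

```python
import random
from typing import List


class AffineHead:
    """Hypothetical prediction head: y = W x + b.

    Maps a latent vector of length hidden_size to nb_inputs outputs.
    """

    def __init__(self, hidden_size: int, nb_inputs: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Small random weights and zero bias play the role of learnable parameters.
        self.weight = [
            [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)] for _ in range(nb_inputs)
        ]
        self.bias = [0.0] * nb_inputs

    def __call__(self, x: List[float]) -> List[float]:
        if len(x) != len(self.weight[0]):
            raise ValueError("latent vector length must equal hidden_size")
        # One dot product per output column, plus its bias.
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(self.weight, self.bias)]


head = AffineHead(hidden_size=4, nb_inputs=2)
print(head([0.5, -1.0, 0.0, 2.0]))  # two output columns
```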

class graphnet.models.task.task.StandardLearnedTask(*args, **kwargs)

Bases: LearnedTask

Standard class for classification and reconstruction in GraphNeT.

This class comes with a definition of compute_loss that is compatible with the vast majority of supervised learning tasks.

Construct StandardLearnedTask.

Parameters:
  • hidden_size (int) – The number of columns in the output of the last latent layer of Model using this Task. Available through Model.nb_outputs.

  • args (Any)

  • kwargs (Any)

Return type:

object

abstract property nb_inputs: int

Return number of inputs assumed by task.

compute_loss(pred, data)

Compute supervised learning loss.

Return type:

Tensor

Parameters:
  • pred (Tensor | Data)

  • data (Data)

Grabs the truth labels from data and passes both pred and target to the loss function for evaluation. Suits most supervised learning `Task`s.
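A sketch of the steps compute_loss is described as performing: gather the target columns, optionally transform them, evaluate the loss per event, and optionally weight it with the attribute named by loss_weight. standard_compute_loss and mse are hypothetical; plain lists stand in for tensors, and averaging the (weighted) per-event losses is an assumption.

```python
import math
from typing import Callable, Dict, List, Optional


def mse(p: List[float], t: List[float]) -> float:
    """Mean squared error for one event."""
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


def standard_compute_loss(
    pred: List[List[float]],
    data: Dict[str, List[float]],
    target_labels: List[str],
    loss_fn: Callable[[List[float], List[float]], float],
    transform_target: Callable[[float], float] = lambda v: v,
    loss_weight: Optional[str] = None,
) -> float:
    """Hypothetical supervised loss; pred has one row per event, one column per label."""
    # Gather the truth columns per event, in the same column order as pred.
    target = [
        [transform_target(data[label][i]) for label in target_labels]
        for i in range(len(pred))
    ]
    per_event = [loss_fn(p, t) for p, t in zip(pred, target)]
    # Optional per-event weights stored under the attribute named by loss_weight.
    if loss_weight is not None:
        per_event = [loss * w for loss, w in zip(per_event, data[loss_weight])]
    return sum(per_event) / len(per_event)


# The model predicts log10(energy); the target is transformed to match.
data = {"energy": [10.0, 100.0]}
print(standard_compute_loss([[1.0], [2.0]], data, ["energy"], mse, math.log10))
```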

class graphnet.models.task.task.IdentityTask(*args, **kwargs)

Bases: StandardLearnedTask

Identity, or trivial, task.

Construct IdentityTask.

A task that does not apply a learned embedding to the input. It returns the inputs from Model unchanged.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

object
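A minimal sketch of an identity task. ToyIdentityTask is hypothetical, and tying nb_inputs to the number of target labels is an assumption made for illustration.

```python
from typing import List


class ToyIdentityTask:
    """Hypothetical identity task: no learnable parameters, output is the input."""

    def __init__(self, target_labels: List[str]) -> None:
        self.target_labels = list(target_labels)

    @property
    def nb_inputs(self) -> int:
        # Assumption: one input column per target label.
        return len(self.target_labels)

    def forward(self, x: List[List[float]]) -> List[List[float]]:
        # Only check the shape; the values pass through untouched.
        for row in x:
            if len(row) != self.nb_inputs:
                raise ValueError("each row must have nb_inputs columns")
        return x


task = ToyIdentityTask(["energy", "zenith"])
print(task.forward([[1.5, 0.3]]))
```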

property default_target_labels: List[str]

Return default target labels.

property default_prediction_labels: List[str]

Return default prediction labels.

property nb_inputs: int

Return number of inputs assumed by task.

class graphnet.models.task.task.StandardFlowTask(*args, **kwargs)

Bases: Task

A Task for `NormalizingFlow`s in GraphNeT.

This Task requires the support package `jammy_flows` for constructing and evaluating normalizing flows.

Construct StandardFlowTask.

Parameters:
  • target_labels – A list of names for the targets of this Task.

  • flow_layers (str, default: 'gggt') – A string indicating the flow layer types. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details.

  • target_norm (float, default: 1000.0) – A normalization constant used to divide the target values. The value is applied to all targets. Defaults to 1000.

  • hidden_size (Optional[int]) – The number of columns on which the normalizing flow is conditioned. May be None, indicating a non-conditional flow.

  • args (Any)

  • kwargs (Any)

Return type:

object
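target_norm divides every target component by one constant. The sketch below shows that step and, as an assumption about how one might use it, the change-of-variables correction needed to express a log-density evaluated on normalized targets in the original units: for y = x / c in d dimensions, log p_x(x) = log p_y(x / c) - d log c. Both helpers are hypothetical, not GraphNeT functions.

```python
import math
from typing import List


def normalize_targets(values: List[float], target_norm: float = 1000.0) -> List[float]:
    """Divide every target component by the same constant."""
    return [v / target_norm for v in values]


def log_density_original_units(
    log_density_normalized: float, dim: int, target_norm: float = 1000.0
) -> float:
    """Change of variables for y = x / target_norm applied to `dim` components.

    p_x(x) = p_y(x / c) * c**(-dim), hence log p_x = log p_y - dim * log(c).
    """
    return log_density_normalized - dim * math.log(target_norm)


print(normalize_targets([500.0, 2000.0]))
```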

property default_prediction_labels: List[str]

Return default prediction labels.

nb_inputs()

Return number of conditional inputs assumed by task.

Return type:

Optional[int]
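Because nb_inputs may return None, callers need to handle both the conditional and the unconditional case. describe_flow is a hypothetical helper sketching that distinction, based on the hidden_size description above.

```python
from typing import Optional


def describe_flow(hidden_size: Optional[int]) -> str:
    """Hypothetical helper: interpret nb_inputs for a flow task."""
    # None means the flow models p(y) without any conditioning input.
    if hidden_size is None:
        return "unconditional flow: density p(y)"
    if hidden_size <= 0:
        raise ValueError("hidden_size must be a positive integer or None")
    # Otherwise the flow is conditioned on a latent vector of that length.
    return f"conditional flow: density p(y | z) with z of length {hidden_size}"


print(describe_flow(None))
print(describe_flow(8))
```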

forward(x, data)

Forward pass.

Return type:

Union[Tensor, Data]

Parameters:
  • x (Tensor | Data)

  • data (List[Data])