graphnet.models.task.classification module

Classification-specific Model class(es).

class graphnet.models.task.classification.MulticlassClassificationTask(*args, **kwargs)[source]

Bases: IdentityTask

General task for classifying any number of classes.

Requires the same number of input features as the number of classes being predicted. Returns the untransformed latent features, which are interpreted as the logits for each class being classified.

Construct IdentityTask.

A task that does not apply a learned embedding to the input. It returns the direct inputs from Model.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

object
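
Because the task's outputs are interpreted as per-class logits, a caller typically converts them to probabilities or a class index afterwards. The following is a minimal sketch of that step in plain Python, not GraphNeT code; the helper names softmax and predict_class and the example values are hypothetical.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtracting the largest logit avoids overflow and leaves the result unchanged.
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(logits):
    """Return the index of the largest logit, i.e. the predicted class."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical logits for one example and three classes.
logits = [2.0, 0.5, -1.0]
probs = softmax(logits)
predicted = predict_class(logits)
```

Since softmax preserves ordering, the predicted class can be read directly from the logits without computing probabilities.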

class graphnet.models.task.classification.BinaryClassificationTask(*args, **kwargs)[source]

Bases: StandardLearnedTask

Performs binary classification.

Construct StandardLearnedTask.

Parameters:
  • hidden_size (int) – The number of columns in the output of the last latent layer of Model using this Task. Available through Model.nb_outputs.

  • args (Any)

  • kwargs (Any)

Return type:

object

nb_inputs = 1
default_target_labels = ['target']
default_prediction_labels = ['target_pred']
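
With nb_inputs = 1, the task works with a single value per example. As an illustration only, the sketch below assumes that value is a raw score mapped to a probability with a sigmoid and then thresholded; this page does not state how the task transforms its input, and the names sigmoid and classify are hypothetical.

```python
import math


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def classify(score, threshold=0.5):
    """Return 1 if the sigmoid of the score reaches the threshold, else 0."""
    return int(sigmoid(score) >= threshold)


# Hypothetical raw scores for two examples.
positive = classify(2.0)
negative = classify(-2.0)
```

The default threshold of 0.5 is an assumption for the sketch; in practice it is chosen to balance false positives against false negatives.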

class graphnet.models.task.classification.BinaryClassificationTaskLogits(*args, **kwargs)[source]

Bases: StandardLearnedTask

Performs binary classification from logits.

Construct StandardLearnedTask.

Parameters:
  • hidden_size (int) – The number of columns in the output of the last latent layer of Model using this Task. Available through Model.nb_outputs.

  • args (Any)

  • kwargs (Any)

Return type:

object

nb_inputs = 1
default_target_labels = ['target']
default_prediction_labels = ['target_pred']
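
One common reason to work with logits rather than probabilities is that binary cross-entropy can then be evaluated in a numerically stable form. The sketch below shows that form in plain Python; it assumes a cross-entropy loss is used, which this page does not specify, and bce_with_logits is a hypothetical helper, not part of GraphNeT.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy computed directly from a logit.

    Equivalent to -[t*log(s) + (1-t)*log(1-s)] with s = sigmoid(logit),
    but rearranged so that exp() never receives a large positive argument.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Hypothetical logits paired with their true labels (1.0 or 0.0).
loss_uncertain = bce_with_logits(0.0, 1.0)       # sigmoid(0) = 0.5, so loss = log(2)
loss_confident = bce_with_logits(1000.0, 1.0)    # very confident and correct
loss_wrong = bce_with_logits(1000.0, 0.0)        # very confident and wrong
```

A naive implementation that first applies a sigmoid and then takes the logarithm would fail on the last case, because 1 - sigmoid(1000) rounds to zero in floating point.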