graphnet.training.weight_fitting module

Classes for fitting per-event weights for training.

class graphnet.training.weight_fitting.WeightFitter(database_path, truth_table, index_column)

Bases: ABC, Logger

Produces per-event weights.

Weights are returned by the public method fit() and can optionally be saved as a table in the database.
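The split between a public fit() and a private _fit_weights() that subclasses implement is a template-method pattern. Below is a minimal sketch of that pattern. It assumes nothing about graphnet's internals: WeightFitterSketch and ConstantWeights are hypothetical names, and database access is left out so that fit() works directly on a NumPy array.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np


class WeightFitterSketch(ABC):
    """Hypothetical stand-in for WeightFitter (no database access)."""

    def fit(
        self,
        values: np.ndarray,
        bins: np.ndarray,
        transform: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    ) -> np.ndarray:
        # Public entry point: apply the optional transform, then delegate to the subclass.
        x = np.asarray(values, dtype=float)
        if transform is not None:
            x = transform(x)  # e.g. np.log10 for energy
        return self._fit_weights(x, np.asarray(bins, dtype=float))

    @abstractmethod
    def _fit_weights(self, x: np.ndarray, bins: np.ndarray) -> np.ndarray:
        """Return one weight per entry of x; implemented by each subclass."""


class ConstantWeights(WeightFitterSketch):
    """Hypothetical subclass that gives every event weight 1."""

    def _fit_weights(self, x: np.ndarray, bins: np.ndarray) -> np.ndarray:
        return np.ones_like(x)
```

For example, ConstantWeights().fit([1.0, 10.0, 100.0], bins=[0.0, 1.0, 2.0], transform=np.log10) returns three ones. Instantiating WeightFitterSketch directly raises TypeError because _fit_weights is abstract. This is why concrete fitters such as Uniform only need to supply the fitting logic.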

Construct WeightFitter.

Parameters:
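  • database_path (str) – Path to the database that contains the truth table.

  • truth_table (str) – Name of the truth table in the database.

  • index_column (str) – Name of the column that identifies each event, e.g. event_no.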

fit(bins, variable, weight_name, add_to_database, selection, transform, db_count_norm, automatic_log_bins, max_weight, **kwargs)

Fit weights.

Calls the private _fit_weights method. The output is returned as a pandas.DataFrame and can optionally be saved to the SQL database.

Parameters:
  • bins (ndarray) – Desired bins used for fitting.

  • variable (str) – Name of the variable to weight. Must match the corresponding column name in the truth table.

  • weight_name (Optional[str], default: None) – Name of the weights.

  • add_to_database (bool, default: False) – If True, the weights are saved to the SQL database in a table named weight_name.

  • selection (Optional[List[int]], default: None) – A list of event_no values. If given, only the events in the selection are used for fitting.

  • transform (Optional[Callable], default: None) – A callable that transforms the variable into the desired space, e.g. np.log10 for energy. If given, fitting happens in this space.

  • db_count_norm (Optional[int], default: None) – If given, the weights are normalized so that their total sum over the database equals this number.

  • automatic_log_bins (bool, default: False) – If True, the bins are generated with log10 spacing between the minimum and maximum of the variable.

  • max_weight (Optional[float], default: None) – If given, the weights are capped such that a single event weight cannot exceed this number times the sum of all weights.

  • **kwargs (Any) – Additional arguments passed to _fit_weights.
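With automatic_log_bins=True, the bin edges are log10-spaced between the variable's minimum and maximum. The NumPy sketch below shows one way to build such edges. It assumes the number of edges is given as an integer, and log_bin_edges is a hypothetical helper rather than part of graphnet. Log-spaced edges in linear space describe the same binning as evenly spaced edges after transform=np.log10, because np.logspace(a, b, n) equals 10 ** np.linspace(a, b, n).

```python
import numpy as np


def log_bin_edges(values: np.ndarray, n_edges: int) -> np.ndarray:
    """Hypothetical helper: n_edges edges evenly spaced in log10 from min to max."""
    v = np.asarray(values, dtype=float)
    if np.any(v <= 0):
        raise ValueError("log10 binning requires strictly positive values")
    return np.logspace(np.log10(v.min()), np.log10(v.max()), n_edges)


energies = np.array([1.0, 3.0, 40.0, 250.0, 1000.0])
edges = log_bin_edges(energies, n_edges=4)  # approximately [1, 10, 100, 1000]
# The same binning expressed in log10 space is evenly spaced:
log_edges = np.log10(edges)  # approximately [0, 1, 2, 3]
```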

Return type:

DataFrame

Returns:

DataFrame containing the weights and the corresponding event_no values.
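The max_weight and db_count_norm options act on the weights once they have been fitted. The sketch below shows one plausible order of operations: cap first, then normalize, so that the final sum is exactly db_count_norm. The function postprocess_weights is a hypothetical name, and graphnet's actual order and capping rule may differ.

```python
from typing import Optional

import numpy as np


def postprocess_weights(
    weights: np.ndarray,
    db_count_norm: Optional[int] = None,
    max_weight: Optional[float] = None,
) -> np.ndarray:
    """Hypothetical helper mirroring the documented max_weight / db_count_norm options."""
    w = np.asarray(weights, dtype=float).copy()
    if max_weight is not None:
        # Clip each weight to max_weight * (sum of all weights). This is a single pass:
        # clipping lowers the sum, so the bound holds only approximately afterwards.
        w = np.minimum(w, max_weight * w.sum())
    if db_count_norm is not None:
        # Rescale so that the weights sum exactly to db_count_norm.
        w *= db_count_norm / w.sum()
    return w


weights = postprocess_weights(np.array([1.0, 1.0, 1.0, 7.0]), db_count_norm=4, max_weight=0.5)
# cap = 0.5 * 10 = 5 -> [1, 1, 1, 5]; rescale by 4 / 8 -> [0.5, 0.5, 0.5, 2.5]
```

In this example the largest weight ends up at 2.5 / 4 = 62.5% of the total, which is above max_weight = 0.5. This is the single-pass caveat noted in the comment. An exact cap would need the clipping to be repeated until it converges.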

class graphnet.training.weight_fitting.Uniform(database_path, truth_table, index_column)

Bases: WeightFitter

Produces per-event weights that make the distribution of the variable uniform.
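One standard way to flatten a distribution is inverse-bin-count weighting. Each event gets the weight 1 / (number of events in its bin), so every populated bin carries the same total weight. The NumPy sketch below assumes all values lie inside the bin range. uniform_weights is a hypothetical helper, not necessarily how Uniform computes its weights.

```python
import numpy as np


def uniform_weights(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """Hypothetical helper: weight = 1 / (event count of the event's bin)."""
    x = np.asarray(x, dtype=float)
    counts, edges = np.histogram(x, bins=bins)
    # np.digitize gives 1..len(edges)-1 for in-range values; shift to 0-based and
    # clip so that x == edges[-1] falls in the last bin, as in np.histogram.
    idx = np.clip(np.digitize(x, edges) - 1, 0, len(counts) - 1)
    return 1.0 / counts[idx]


x = np.array([0.5, 0.6, 1.5, 2.5, 2.7, 2.9])
w = uniform_weights(x, bins=np.array([0.0, 1.0, 2.0, 3.0]))
# -> [0.5, 0.5, 1.0, 0.333..., 0.333..., 0.333...]; each bin now sums to 1 (up to rounding)
```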

Construct Uniform.

Parameters:
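  • database_path (str) – Path to the database that contains the truth table.

  • truth_table (str) – Name of the truth table in the database.

  • index_column (str) – Name of the column that identifies each event, e.g. event_no.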

class graphnet.training.weight_fitting.BjoernLow(database_path, truth_table, index_column)

Bases: WeightFitter

Produces per-event weights.

Events below x_low are weighted to be uniform, whereas events above x_low are weighted to follow a 1/(1 + a*(x - x_low)) curve. This curve equals 1 at x_low and decreases for larger x when a > 0.
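The sketch below illustrates the described scheme. It uses inverse-bin-count weights for the uniform part and multiplies the weights of events above x_low by a fall-off factor. The factor is taken to be 1/(1 + a·(x − x_low)) with a > 0, so that it equals 1 at x_low and decreases above it. That sign convention, the parameter name a, and the helper bjoern_low_weights are assumptions for illustration, not graphnet's code.

```python
import numpy as np


def bjoern_low_weights(x: np.ndarray, bins: np.ndarray, x_low: float, a: float) -> np.ndarray:
    """Hypothetical helper: uniform weights, damped above x_low."""
    x = np.asarray(x, dtype=float)
    counts, edges = np.histogram(x, bins=bins)
    idx = np.clip(np.digitize(x, edges) - 1, 0, len(counts) - 1)
    w = 1.0 / counts[idx]  # inverse bin count: flat distribution (values assumed in range)
    above = x > x_low
    # Assumed fall-off: equals 1 at x_low and decreases with x for a > 0.
    w[above] *= 1.0 / (1.0 + a * (x[above] - x_low))
    return w


w = bjoern_low_weights(np.array([0.5, 1.5, 2.5, 3.5]), bins=np.arange(5.0), x_low=2.0, a=1.0)
# -> [1.0, 1.0, 0.666..., 0.4]
```

Because the factor is 1 at x_low, the weights are continuous across the threshold. a sets how quickly events above x_low are down-weighted.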

Construct BjoernLow.

Parameters:
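  • database_path (str) – Path to the database that contains the truth table.

  • truth_table (str) – Name of the truth table in the database.

  • index_column (str) – Name of the column that identifies each event, e.g. event_no.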