weight_fitting
Classes for fitting per-event weights for training.
- class graphnet.training.weight_fitting.WeightFitter(database_path, truth_table, index_column)
Bases: ABC, Logger
Produces per-event weights.
Weights are returned by the public method fit(), and can optionally be saved as a table in the database.
Construct WeightFitter.
- Parameters:
database_path (str)
truth_table (str)
index_column (str)
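The split between the public fit() method and the subclass-specific private _fit_weights() method follows the template-method pattern: shared bookkeeping lives in the base class, and each subclass only supplies the weighting scheme. The sketch below illustrates that pattern with a hypothetical, in-memory class (SketchWeightFitter) that reads truth information from a pandas.DataFrame instead of an SQL database. The class names, the default weight-column name and the toy InverseVariable subclass are assumptions for illustration, not graphnet's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import numpy as np
import pandas as pd


class SketchWeightFitter(ABC):
    """Hypothetical in-memory stand-in for WeightFitter (no SQL database)."""

    def __init__(self, truth: pd.DataFrame, index_column: str = "event_no") -> None:
        self._truth = truth
        self._index_column = index_column

    def fit(
        self,
        bins: np.ndarray,
        variable: str,
        weight_name: Optional[str] = None,
        selection: Optional[List[int]] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        # Shared bookkeeping lives in the public method ...
        self._bins = bins
        self._variable = variable
        truth = self._truth
        if selection is not None:
            # Keep only the requested event numbers.
            truth = truth[truth[self._index_column].isin(selection)]
        # ... while the weighting scheme itself is delegated to the subclass.
        weights = self._fit_weights(truth, **kwargs)
        # Hypothetical default column name used when weight_name is not given.
        name = weight_name if weight_name is not None else f"{variable}_weight"
        return pd.DataFrame(
            {
                self._index_column: truth[self._index_column].to_numpy(),
                name: weights,
            }
        )

    @abstractmethod
    def _fit_weights(self, truth: pd.DataFrame, **kwargs: Any) -> np.ndarray:
        """Return one weight per row of `truth`."""


class InverseVariable(SketchWeightFitter):
    """Toy subclass: weight each event by 1 / variable."""

    def _fit_weights(self, truth: pd.DataFrame, **kwargs: Any) -> np.ndarray:
        return 1.0 / truth[self._variable].to_numpy()


truth = pd.DataFrame({"event_no": [0, 1, 2], "energy": [1.0, 2.0, 4.0]})
out = InverseVariable(truth).fit(
    bins=np.array([0.0, 5.0]), variable="energy", selection=[0, 2]
)
print(out)
```

With selection=[0, 2], only events 0 and 2 are weighted, giving weights 1.0 and 0.25 in a column named energy_weight.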
- fit(bins, variable, weight_name, add_to_database, selection, transform, db_count_norm, automatic_log_bins, max_weight, **kwargs)
Fit weights.
Calls the private _fit_weights method. The output is returned as a pandas.DataFrame and optionally saved to the SQL database.
- Parameters:
bins (ndarray) – Desired bins used for fitting.
variable (str) – The name of the variable. Must match the corresponding column name in the truth table.
weight_name (Optional[str], default: None) – Name of the weights.
add_to_database (bool, default: False) – If True, the weights are saved to the SQL database in a table named weight_name.
selection (Optional[List[int]], default: None) – A list of event_no values. If given, only events in the selection are used for fitting.
transform (Optional[Callable], default: None) – A callable that transforms the variable into a desired space, e.g. np.log10 for energy. If given, fitting happens in this space.
db_count_norm (Optional[int], default: None) – If given, the weights are rescaled so that their total sum over the database equals this number.
automatic_log_bins (bool, default: False) – If True, the bins are generated as log10-spaced edges between the minimum and maximum of the variable.
max_weight (Optional[float], default: None) – If given, the weights are capped such that a single event weight cannot exceed this number times the sum of all weights.
**kwargs (Any) – Additional arguments passed to _fit_weights.
- Return type:
DataFrame
- Returns:
A DataFrame containing the weights and the corresponding event_no values.
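The options automatic_log_bins, max_weight and db_count_norm can be illustrated with plain NumPy. The helpers below (make_log_bins, cap_weights, postprocess) are hypothetical and only sketch one way to meet the behaviour described above. In particular, cap_weights reads max_weight as a constraint on the final weights and solves for a common ceiling c satisfying c = max_weight * sum(clipped weights); graphnet's actual capping and normalisation steps may differ.

```python
from typing import Optional

import numpy as np


def make_log_bins(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Return n_bins + 1 log10-spaced edges from min(values) to max(values).

    Assumes all values are strictly positive.
    """
    return np.logspace(np.log10(values.min()), np.log10(values.max()), n_bins + 1)


def cap_weights(weights: np.ndarray, max_weight: float) -> np.ndarray:
    """Clip weights to a common ceiling c with c == max_weight * sum(clipped).

    After clipping, no weight exceeds max_weight times the total weight.
    """
    w = np.asarray(weights, dtype=float)
    desc = np.sort(w)[::-1]  # weights in descending order
    if desc[0] <= max_weight * desc.sum():
        return w.copy()  # constraint already satisfied
    for k in range(1, len(desc)):
        # Try clipping the k largest weights to c; the rest stay unchanged.
        if k * max_weight >= 1.0:
            break
        c = max_weight * desc[k:].sum() / (1.0 - k * max_weight)
        if desc[k] <= c <= desc[k - 1]:
            return np.minimum(w, c)
    raise ValueError("max_weight is too small for this number of events")


def postprocess(
    weights: np.ndarray,
    db_count_norm: Optional[int] = None,
    max_weight: Optional[float] = None,
) -> np.ndarray:
    """Cap first, then rescale so the weights sum to db_count_norm."""
    w = np.asarray(weights, dtype=float)
    if max_weight is not None:
        w = cap_weights(w, max_weight)
    if db_count_norm is not None:
        # Rescaling leaves each weight's share of the total unchanged,
        # so the max_weight constraint still holds afterwards.
        w = w * (db_count_norm / w.sum())
    return w


print(make_log_bins(np.array([1.0, 10.0, 100.0]), 2))
print(postprocess(np.array([1.0, 1.0, 8.0]), db_count_norm=100, max_weight=0.5))
```

Capping before normalising is convenient because rescaling does not change any weight's share of the total, so both conditions hold at the end. For weights [1, 1, 8] with max_weight=0.5, the ceiling is 2 (2 = 0.5 * (1 + 1 + 2)), and normalising to 100 gives [25, 25, 50].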
- class graphnet.training.weight_fitting.Uniform(database_path, truth_table, index_column)
Bases: WeightFitter
Produces per-event weights that make the distribution of the variable uniform.
Construct Uniform.
- Parameters:
database_path (str)
truth_table (str)
index_column (str)
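A uniform target distribution can be approximated by weighting every event with the inverse of the number of events in its histogram bin, so that each occupied bin contributes the same total weight. The function uniform_weights below is a hypothetical NumPy sketch of this idea, not the Uniform class itself; how graphnet normalises the result (for example via db_count_norm) is not shown.

```python
import numpy as np


def uniform_weights(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """Weight each event by 1 / (number of events in its bin).

    Events outside [bins[0], bins[-1]] get weight 0.
    """
    values = np.asarray(values, dtype=float)
    counts, edges = np.histogram(values, bins=bins)
    # Map each value to its bin index, matching np.histogram's convention
    # (half-open bins, with the last bin also including its right edge).
    idx = np.clip(np.searchsorted(edges, values, side="right") - 1, 0, len(counts) - 1)
    in_range = (values >= edges[0]) & (values <= edges[-1])
    weights = np.zeros_like(values)
    # Every in-range event is counted in its own bin, so counts[...] >= 1 here.
    weights[in_range] = 1.0 / counts[idx[in_range]]
    return weights


values = np.array([0.1, 0.2, 0.3, 0.9])
bins = np.array([0.0, 0.5, 1.0])
w = uniform_weights(values, bins)
print(w)
print(np.histogram(values, bins=bins, weights=w)[0])
```

For these values the first bin holds three events, so the weights are [1/3, 1/3, 1/3, 1] and the weighted histogram is [1, 1].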
- class graphnet.training.weight_fitting.BjoernLow(database_path, truth_table, index_column)
Bases: WeightFitter
Produces per-event weights.
Events below x_low are weighted to be uniform, whereas events above x_low are weighted to follow a 1/(1 + a*(x_low - x)) curve.
Construct BjoernLow.
- Parameters:
database_path (str)
truth_table (str)
index_column (str)