From 3d4ef9e2d73474eb4c2a29065c1a9a32b29f79e4 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:33:46 +0100 Subject: [PATCH 01/11] ADD: possibility to add val weights in .fit() --- choice_learn/models/base_model.py | 60 ++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index bd2fcb89..90888080 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -12,6 +12,7 @@ import tqdm import choice_learn.tf_ops as tf_ops +from choice_learn.data import ChoiceDataset class ChoiceModel: @@ -407,22 +408,57 @@ def fit( # Test on val_dataset if provided if val_dataset is not None: test_losses = [] - for batch_nb, ( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - ) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)): - self.callbacks.on_batch_begin(batch_nb) - self.callbacks.on_test_batch_begin(batch_nb) - test_losses.append( - self.batch_predict( + + if isinstance(val_dataset, tuple): + if not len(val_dataset) == 2: + raise ValueError( + """if argument val_dataset is a tuple, + should in the form (ChoiceDataset, weights)""" + ) + validation_dataset, val_samples_weight = val_dataset + for batch_nb, ( + ( shared_features_batch, items_features_batch, available_items_batch, choices_batch, - )[0]["optimized_loss"] - ) + ), + weight_batch, + ) in enumerate( + validation_dataset.iter_batch( + shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size + ) + ): + self.callbacks.on_batch_begin(batch_nb) + self.callbacks.on_test_batch_begin(batch_nb) + test_losses.append( + self.batch_predict( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + sample_weight=weight_batch, + )[0]["optimized_loss"] + ) + else: + if not isinstance(val_dataset, ChoiceDataset): + raise ValueError("val_dataset should be a ChoiceDataset object.") + for batch_nb, ( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + ) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)): + self.callbacks.on_batch_begin(batch_nb) + self.callbacks.on_test_batch_begin(batch_nb) + test_losses.append( + self.batch_predict( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + )[0]["optimized_loss"] + ) val_logs["val_loss"].append(test_losses[-1]) temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()} self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs) From 2f8aa7b18a63ed4506172433d1e68218901ff9be Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:33:50 +0100 Subject: [PATCH 02/11] ADD: possibility to add val weights in .fit() --- choice_learn/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index 90888080..ed8b443f 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -265,7 +265,7 @@ def fit( Input data in the form of a ChoiceDataset sample_weight : np.ndarray, optional Sample weight to apply, by default None - val_dataset : ChoiceDataset, optional + val_dataset : ChoiceDataset or (ChoiceDataset, samples_weight), optional Test ChoiceDataset to evaluate performances on test at each epoch, by default None verbose : int, optional print level, for debugging, by default 0 From 43b9dcc457d4608c029e3561db387873274d7188 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:34:08 +0100 Subject: [PATCH 03/11] ADD: corresponding tests --- tests/unit_tests/models/test_simplemnl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit_tests/models/test_simplemnl.py b/tests/unit_tests/models/test_simplemnl.py index 253f4f3d..aa336d47 100644 --- a/tests/unit_tests/models/test_simplemnl.py +++ b/tests/unit_tests/models/test_simplemnl.py @@ -153,3 +153,18 @@ def test_save_load(): assert nll_a == nll_b shutil.rmtree("test_save") + + +def test_weighted_val_dataset(): + """Tests instantiation with item and fit with Adam.""" + tf.config.run_functions_eagerly(True) + model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1) + model.instantiate(n_items=3, n_items_features=2, n_shared_features=3) + nll_b = model.evaluate(test_dataset) + model.fit( + test_dataset, get_report=True, val_dataset=(test_dataset, np.ones((len(test_dataset),))) + ) + nll_a = model.evaluate(test_dataset, batch_size=-1) + assert nll_a < nll_b + + assert model.report.to_numpy().shape == (7, 5) From d0b97e13d712b64378b08251916a4c1e0e1e36ae Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:34:56 +0100 Subject: [PATCH 04/11] update notebook --- notebooks/models/simple_mnl.ipynb | 50 +++++++++++++++---------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/notebooks/models/simple_mnl.ipynb b/notebooks/models/simple_mnl.ipynb index 67653242..6a9ea190 100644 --- a/notebooks/models/simple_mnl.ipynb +++ b/notebooks/models/simple_mnl.ipynb @@ -307,24 +307,24 @@ " Weights_items_features_0\n", " -0.001533\n", " 0.000621\n", - " -2.469423\n", - " 1.353312e-02\n", + " -2.469422\n", + " 1.353315e-02\n", " \n", " \n", " 1\n", " Weights_items_features_1\n", " -0.006996\n", " 0.001554\n", - " -4.501964\n", - " 6.675720e-06\n", + " -4.501969\n", + " 6.732662e-06\n", " \n", " \n", " 2\n", " Intercept_0\n", " 1.710969\n", - " 0.226741\n", - " 7.545904\n", - " 0.000000e+00\n", + " 0.226742\n", + " 7.545903\n", + " 4.485301e-14\n", " \n", " \n", " 3\n", @@ -338,17 +338,17 @@ " 4\n", " Intercept_2\n", " 1.658846\n", - " 0.448417\n", - " 3.699342\n", - " 2.161264e-04\n", + " 0.448416\n", + " 3.699345\n", + " 2.161564e-04\n", " \n", " \n", " 5\n", " Intercept_3\n", " 1.853437\n", - " 0.361953\n", - " 5.120663\n", - " 3.576279e-07\n", + " 0.361952\n", + " 5.120667\n", + " 3.044562e-07\n", " \n", " \n", "\n", @@ -356,20 +356,20 @@ ], "text/plain": [ " Coefficient Name Coefficient Estimation Std. Err z_value \\\n", - "0 Weights_items_features_0 -0.001533 0.000621 -2.469423 \n", - "1 Weights_items_features_1 -0.006996 0.001554 -4.501964 \n", - "2 Intercept_0 1.710969 0.226741 7.545904 \n", + "0 Weights_items_features_0 -0.001533 0.000621 -2.469422 \n", + "1 Weights_items_features_1 -0.006996 0.001554 -4.501969 \n", + "2 Intercept_0 1.710969 0.226742 7.545903 \n", "3 Intercept_1 0.308263 0.206591 1.492140 \n", - "4 Intercept_2 1.658846 0.448417 3.699342 \n", - "5 Intercept_3 1.853437 0.361953 5.120663 \n", + "4 Intercept_2 1.658846 0.448416 3.699345 \n", + "5 Intercept_3 1.853437 0.361952 5.120667 \n", "\n", " P(.>z) \n", - "0 1.353312e-02 \n", - "1 6.675720e-06 \n", - "2 0.000000e+00 \n", + "0 1.353315e-02 \n", + "1 6.732662e-06 \n", + "2 4.485301e-14 \n", "3 1.356624e-01 \n", - "4 2.161264e-04 \n", - "5 3.576279e-07 " + "4 2.161564e-04 \n", + "5 3.044562e-07 " ] }, "execution_count": null, @@ -391,7 +391,7 @@ ], "metadata": { "kernelspec": { - "display_name": "tf_env", + "display_name": "choice_learn", "language": "python", "name": "python3" }, @@ -405,7 +405,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.12.11" } }, "nbformat": 4, From 84d5fb5dff676640c188fbd250c30b268f32fe68 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:51:34 +0100 Subject: [PATCH 05/11] ADD: validation freq parameter in model.fit --- choice_learn/models/base_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index ed8b443f..f3ff83c4 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -255,6 +255,7 @@ def fit( choice_dataset, sample_weight=None, val_dataset=None, + validation_freq=1, verbose=0, ): """Train the model with a ChoiceDataset. @@ -273,6 +274,10 @@ def fit( Number of epochs, default is None, meaning we use self.epochs batch_size : int, optional Batch size, default is None, meaning we use self.batch_size + validation_freq: int, optional + Only relevant if validation data is provided. Specifies how many training epochs + to run before a new validation run is performed, e.g. validation_freq=2 runs validation + every 2 epochs. Returns ------- @@ -406,7 +411,7 @@ def fit( ) # Test on val_dataset if provided - if val_dataset is not None: + if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0: test_losses = [] if isinstance(val_dataset, tuple): From b58b5f181a64b3e9c945f033fcba807460087a77 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Sat, 13 Dec 2025 15:59:17 +0100 Subject: [PATCH 06/11] ENH: minimized code --- choice_learn/models/base_model.py | 76 +++++++++++++++---------------- 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index f3ff83c4..de577224 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -414,56 +414,52 @@ def fit( if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0: test_losses = [] + val_samples_weight = None if isinstance(val_dataset, tuple): if not len(val_dataset) == 2: raise ValueError( - """if argument val_dataset is a tuple, - should in the form (ChoiceDataset, weights)""" + """if argument val_dataset is a tuple, it should be + in the form (ChoiceDataset, weights)""" ) validation_dataset, val_samples_weight = val_dataset - for batch_nb, ( - ( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - ), - weight_batch, - ) in enumerate( - validation_dataset.iter_batch( - shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size - ) - ): - self.callbacks.on_batch_begin(batch_nb) - self.callbacks.on_test_batch_begin(batch_nb) - test_losses.append( - self.batch_predict( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - sample_weight=weight_batch, - )[0]["optimized_loss"] - ) + elif isinstance(val_dataset, ChoiceDataset): + validation_dataset = val_dataset else: - if not isinstance(val_dataset, ChoiceDataset): - raise ValueError("val_dataset should be a ChoiceDataset object.") - for batch_nb, ( + raise ValueError( + """val_dataset should be a ChoiceDataset or + a tuple of (ChoiceDataset, weights).""" + ) + + val_iterator = validation_dataset.iter_batch( + shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size + ) + + for batch_nb, batch_data in enumerate(val_iterator): + weight_batch = None + if val_samples_weight is not None: + batch_features, weight_batch = batch_data + else: + batch_features = batch_data + + ( shared_features_batch, items_features_batch, available_items_batch, choices_batch, - ) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)): - self.callbacks.on_batch_begin(batch_nb) - self.callbacks.on_test_batch_begin(batch_nb) - test_losses.append( - self.batch_predict( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - )[0]["optimized_loss"] - ) + ) = batch_features + + self.callbacks.on_batch_begin(batch_nb) + self.callbacks.on_test_batch_begin(batch_nb) + + loss = self.batch_predict( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + sample_weight=weight_batch, + )[0]["optimized_loss"] + test_losses.append(loss) + val_logs["val_loss"].append(test_losses[-1]) temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()} self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs) From 03be565739e3f496b01d804af383be721002e8f7 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Mon, 15 Dec 2025 17:35:42 +0100 Subject: [PATCH 07/11] ENH: tripwise metrics handled with weights --- .../basket_models/base_basket_model.py | 6 ++++-- .../basket_models/data/basket_dataset.py | 21 ++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/choice_learn/basket_models/base_basket_model.py b/choice_learn/basket_models/base_basket_model.py index 31f3755e..e9597313 100644 --- a/choice_learn/basket_models/base_basket_model.py +++ b/choice_learn/basket_models/base_basket_model.py @@ -897,7 +897,7 @@ def evaluate( metric.reset_state() # for trip in trip_dataset.trips: - for data_batch, identifier_batch in trip_dataset.iter_batch_evaluate( + for data_batch, weights_batch in trip_dataset.iter_batch_evaluate( trip_batch_size=trip_batch_size ): # Sum of the log-likelihoods of all the baskets in the batch @@ -914,7 +914,9 @@ def evaluate( for metric in exec_metrics: # Use update_state, not append(metric(...)) metric.update_state( - y_true=data_batch[0], y_pred=predicted_probabilities, batch=identifier_batch + y_true=data_batch[0], + y_pred=predicted_probabilities, + sample_weight=weights_batch, ) # After the loops, get the final results diff --git a/choice_learn/basket_models/data/basket_dataset.py b/choice_learn/basket_models/data/basket_dataset.py index 47b07ecf..0568635b 100644 --- a/choice_learn/basket_models/data/basket_dataset.py +++ b/choice_learn/basket_models/data/basket_dataset.py @@ -783,11 +783,12 @@ def iter_batch_evaluate( np.empty(0, dtype=int), # Weeks np.empty((0, self.n_items), dtype=int), # Prices np.empty((0, self.n_items), dtype=int), # Available items + np.empty(0, dtype=int), # Users ) if trip_batch_size == -1: # Get the whole dataset in one batch - identifiers = [] + weights = [] for trip_index in trip_indexes: additional_trip_data = self.get_one_vs_all_augmented_data_from_trip_index( trip_index @@ -795,17 +796,18 @@ def iter_batch_evaluate( buffer = tuple( np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer)) ) - identifiers.extend([trip_index] * len(additional_trip_data[0])) + weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0])) # Yield the whole dataset - yield buffer, np.array(identifiers) + yield buffer, np.array(weights) else: # Yield batches of size batch_size while going through all the trips index = 0 outer_break = False while index < num_trips: - trip_identifier = [] + weights = [] + trip_count = 0 buffer = ( np.empty(0, dtype=int), # Items np.empty((0, self.max_length), dtype=int), # Baskets @@ -816,11 +818,11 @@ def iter_batch_evaluate( np.empty((0, self.n_items), dtype=int), # Available items np.empty(0, dtype=int), # Users ) - while np.max(trip_identifier, initial=-1) + 1 < trip_batch_size: + while trip_count + 1 < trip_batch_size: if index >= num_trips: # Then the buffer is not full but there are no more trips to consider # Yield the batch partially filled - yield buffer, np.array(trip_identifier) + yield buffer, np.array(weights) # Exit the TWO while loops when all trips have been considered outer_break = True @@ -832,18 +834,17 @@ def iter_batch_evaluate( trip_indexes[index] ) index += 1 + trip_count += 1 # Fill the buffer with the new trip buffer = tuple( np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer)) ) - trip_identifier.extend( - [np.max(trip_identifier, initial=-1) + 1] * len(additional_trip_data[0]) - ) + weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0])) if outer_break: break # Yield the batch - yield buffer, np.array(trip_identifier) + yield buffer, np.array(weights) From 55ab10c1d6683110f61b4eff4eb91e0793e0100f Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Mon, 15 Dec 2025 17:39:55 +0100 Subject: [PATCH 08/11] fix --- choice_learn/basket_models/data/basket_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/choice_learn/basket_models/data/basket_dataset.py b/choice_learn/basket_models/data/basket_dataset.py index 0568635b..0a9916cb 100644 --- a/choice_learn/basket_models/data/basket_dataset.py +++ b/choice_learn/basket_models/data/basket_dataset.py @@ -841,7 +841,9 @@ def iter_batch_evaluate( np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer)) ) - weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0])) + weights.extend( + [1 / len(additional_trip_data[0])] * len(additional_trip_data[0]) + ) if outer_break: break From 89608c325054ed06f51e9cf7a1f7e99d887dd49b Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Mon, 15 Dec 2025 23:15:14 +0100 Subject: [PATCH 09/11] fix: weights type --- choice_learn/basket_models/data/basket_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/choice_learn/basket_models/data/basket_dataset.py b/choice_learn/basket_models/data/basket_dataset.py index 0a9916cb..482daa73 100644 --- a/choice_learn/basket_models/data/basket_dataset.py +++ b/choice_learn/basket_models/data/basket_dataset.py @@ -799,7 +799,7 @@ def iter_batch_evaluate( weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0])) # Yield the whole dataset - yield buffer, np.array(weights) + yield buffer, np.array(weights).astype("float32") else: # Yield batches of size batch_size while going through all the trips @@ -822,7 +822,7 @@ def iter_batch_evaluate( if index >= num_trips: # Then the buffer is not full but there are no more trips to consider # Yield the batch partially filled - yield buffer, np.array(weights) + yield buffer, np.array(weights).astype("float32") # Exit the TWO while loops when all trips have been considered outer_break = True @@ -849,4 +849,4 @@ def iter_batch_evaluate( break # Yield the batch - yield buffer, np.array(weights) + yield buffer, np.array(weights).astype("float32") From c1e30dbacc1e82835c329a9476ad84396dd8f3cf Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Tue, 16 Dec 2025 10:30:11 +0100 Subject: [PATCH 10/11] ADD: metrics updated w/ sample_weight --- choice_learn/utils/metrics.py | 54 +++++++++++++++-------------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/choice_learn/utils/metrics.py b/choice_learn/utils/metrics.py index a477bc30..8f5067bf 100644 --- a/choice_learn/utils/metrics.py +++ b/choice_learn/utils/metrics.py @@ -24,7 +24,6 @@ def __init__( self, from_logits=False, sparse=False, - average_on_batch=False, epsilon=1e-10, name="negative_log_likelihood", axis=-1, @@ -40,8 +39,6 @@ def __init__( Whether y_true is given as an index or a one-hot, by default False epsilon : float, optional Lower bound for log(.), by default 1e-10 - average_on_batch: bool, optional - Whether the metric should be averaged over each batch. Typically used to get metrics averaged by Trip, by default False name : str, optional Name of operation, by default "negative_log_likelihood" @@ -53,11 +50,10 @@ def __init__( self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") self.from_logits = from_logits self.sparse = sparse - self.average_on_batch = average_on_batch self.epsilon = epsilon self.axis = axis - def update_state(self, y_true, y_pred, batch=None, sample_weight=None): + def update_state(self, y_true, y_pred, sample_weight=None): """Accumulate statistics for the metric. Parameters @@ -91,16 +87,11 @@ def update_state(self, y_true, y_pred, batch=None, sample_weight=None): axis=self.axis, ) - if batch is not None and self.average_on_batch: - for _, idx in zip(*tf.unique(batch)): - self.nll.assign(self.nll + tf.reduce_mean(nll_value[idx])) - self.n_evals.assign(self.n_evals + 1) + self.nll.assign(self.nll + tf.reduce_sum(nll_value)) + if sample_weight is None: + self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) else: - self.nll.assign(self.nll + tf.reduce_sum(nll_value)) - if sample_weight is None: - self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) - else: - self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) def result(self): """Compute the current metric value. @@ -118,7 +109,6 @@ class MRR(tf.keras.metrics.Metric): def __init__( self, - average_on_batch=False, name="mean_reciprocal_rank", axis=-1, **kwargs, @@ -126,14 +116,13 @@ def __init__( super().__init__(name=name, **kwargs) self.mrr = self.add_variable(shape=(), initializer="zeros", name="mrr") self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") - self.average_on_batch = average_on_batch self.axis = axis def update_state( self, y_true, y_pred, - batch=None, + sample_weight=None, ): """Accumulate statistics for the metric. @@ -156,15 +145,17 @@ def update_state( [tf.range(len(y_true)), y_true], axis=1 ) # Shape: (batch_size, 2) item_ranks = tf.gather_nd(ranks, item_batch_indices) # Shape: (batch_size,) - mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) - - if batch is not None and self.average_on_batch: - self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) - self.n_evals.assign(self.n_evals + 1) + if sample_weight is not None: + mean_rank = tf.reduce_sum( + tf.cast(1 / item_ranks, dtype=tf.float32) * sample_weight, axis=self.axis + ) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) else: - self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank)) + mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) + self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) + def result(self): """Compute the current metric value. @@ -181,7 +172,6 @@ class HitRate(tf.keras.metrics.Metric): def __init__( self, - average_on_batch=False, top_k: int = 10, name=None, axis=-1, @@ -195,10 +185,9 @@ def __init__( shape=(), initializer="zeros", name=f"hit_rate_at_{self.top_k}" ) self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") - self.average_on_batch = average_on_batch self.axis = axis - def update_state(self, y_true, y_pred, batch=None): + def update_state(self, y_true, y_pred, sample_weight=None): """Accumulate statistics for the metric. Parameters @@ -223,14 +212,17 @@ def update_state(self, y_true, y_pred, batch=None): ), axis=1, ) - hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis) - if batch is not None and self.average_on_batch: - self.hit_rate.assign(self.hit_rate + tf.reduce_mean(hits)) - self.n_evals.assign(self.n_evals + 1) + if sample_weight is not None: + hits = tf.reduce_sum( + tf.cast(hits_per_batch, tf.float32) * sample_weight, axis=self.axis + ) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) else: - self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits)) + hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis) self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) + self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits)) + def result(self): """Compute the current metric value. From 91f89e4ca310c5545f99ab9e0b26c92e1473854a Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Tue, 16 Dec 2025 11:27:35 +0100 Subject: [PATCH 11/11] FIX: handling of bastwise / samplewise metrics in model.evaluate() --- choice_learn/basket_models/base_basket_model.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/choice_learn/basket_models/base_basket_model.py b/choice_learn/basket_models/base_basket_model.py index e9597313..1ea69b1a 100644 --- a/choice_learn/basket_models/base_basket_model.py +++ b/choice_learn/basket_models/base_basket_model.py @@ -884,7 +884,6 @@ def evaluate( sparse=True, from_logits=False, epsilon=epsilon_eval, - average_on_batch=True, name="basketwise-nll", ) ) @@ -913,11 +912,17 @@ def evaluate( for metric in exec_metrics: # Use update_state, not append(metric(...)) - metric.update_state( - y_true=data_batch[0], - y_pred=predicted_probabilities, - sample_weight=weights_batch, - ) + if "basketwise" in metric.name: + metric.update_state( + y_true=data_batch[0], + y_pred=predicted_probabilities, + sample_weight=weights_batch, + ) + else: + metric.update_state( + y_true=data_batch[0], + y_pred=predicted_probabilities, + ) # After the loops, get the final results return {metric.name: metric.result() for metric in exec_metrics}