diff --git a/choice_learn/basket_models/base_basket_model.py b/choice_learn/basket_models/base_basket_model.py index 31f3755e..1ea69b1a 100644 --- a/choice_learn/basket_models/base_basket_model.py +++ b/choice_learn/basket_models/base_basket_model.py @@ -884,7 +884,6 @@ def evaluate( sparse=True, from_logits=False, epsilon=epsilon_eval, - average_on_batch=True, name="basketwise-nll", ) ) @@ -897,7 +896,7 @@ def evaluate( metric.reset_state() # for trip in trip_dataset.trips: - for data_batch, identifier_batch in trip_dataset.iter_batch_evaluate( + for data_batch, weights_batch in trip_dataset.iter_batch_evaluate( trip_batch_size=trip_batch_size ): # Sum of the log-likelihoods of all the baskets in the batch @@ -913,9 +912,17 @@ def evaluate( for metric in exec_metrics: # Use update_state, not append(metric(...)) - metric.update_state( - y_true=data_batch[0], y_pred=predicted_probabilities, batch=identifier_batch - ) + if "basketwise" in metric.name: + metric.update_state( + y_true=data_batch[0], + y_pred=predicted_probabilities, + sample_weight=weights_batch, + ) + else: + metric.update_state( + y_true=data_batch[0], + y_pred=predicted_probabilities, + ) # After the loops, get the final results return {metric.name: metric.result() for metric in exec_metrics} diff --git a/choice_learn/basket_models/data/basket_dataset.py b/choice_learn/basket_models/data/basket_dataset.py index 47b07ecf..482daa73 100644 --- a/choice_learn/basket_models/data/basket_dataset.py +++ b/choice_learn/basket_models/data/basket_dataset.py @@ -783,11 +783,12 @@ def iter_batch_evaluate( np.empty(0, dtype=int), # Weeks np.empty((0, self.n_items), dtype=int), # Prices np.empty((0, self.n_items), dtype=int), # Available items + np.empty(0, dtype=int), # Users ) if trip_batch_size == -1: # Get the whole dataset in one batch - identifiers = [] + weights = [] for trip_index in trip_indexes: additional_trip_data = self.get_one_vs_all_augmented_data_from_trip_index( trip_index @@ -795,17 +796,18 @@ def iter_batch_evaluate( buffer = tuple( np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer)) ) - identifiers.extend([trip_index] * len(additional_trip_data[0])) + weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0])) # Yield the whole dataset - yield buffer, np.array(identifiers) + yield buffer, np.array(weights).astype("float32") else: # Yield batches of size batch_size while going through all the trips index = 0 outer_break = False while index < num_trips: - trip_identifier = [] + weights = [] + trip_count = 0 buffer = ( np.empty(0, dtype=int), # Items np.empty((0, self.max_length), dtype=int), # Baskets @@ -816,11 +818,11 @@ def iter_batch_evaluate( np.empty((0, self.n_items), dtype=int), # Available items np.empty(0, dtype=int), # Users ) - while np.max(trip_identifier, initial=-1) + 1 < trip_batch_size: + while trip_count + 1 < trip_batch_size: if index >= num_trips: # Then the buffer is not full but there are no more trips to consider # Yield the batch partially filled - yield buffer, np.array(trip_identifier) + yield buffer, np.array(weights).astype("float32") # Exit the TWO while loops when all trips have been considered outer_break = True @@ -832,18 +834,19 @@ def iter_batch_evaluate( trip_indexes[index] ) index += 1 + trip_count += 1 # Fill the buffer with the new trip buffer = tuple( np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer)) ) - trip_identifier.extend( - [np.max(trip_identifier, initial=-1) + 1] * len(additional_trip_data[0]) + weights.extend( + [1 / len(additional_trip_data[0])] * len(additional_trip_data[0]) ) if outer_break: break # Yield the batch - yield buffer, np.array(trip_identifier) + yield buffer, np.array(weights).astype("float32") diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index bd2fcb89..de577224 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -12,6 +12,7 @@ import tqdm import choice_learn.tf_ops as tf_ops +from choice_learn.data import ChoiceDataset class ChoiceModel: @@ -254,6 +255,7 @@ def fit( choice_dataset, sample_weight=None, val_dataset=None, + validation_freq=1, verbose=0, ): """Train the model with a ChoiceDataset. @@ -264,7 +266,7 @@ def fit( Input data in the form of a ChoiceDataset sample_weight : np.ndarray, optional Sample weight to apply, by default None - val_dataset : ChoiceDataset, optional + val_dataset : ChoiceDataset or (ChoiceDataset, samples_weight), optional Test ChoiceDataset to evaluate performances on test at each epoch, by default None verbose : int, optional print level, for debugging, by default 0 @@ -272,6 +274,10 @@ def fit( Number of epochs, default is None, meaning we use self.epochs batch_size : int, optional Batch size, default is None, meaning we use self.batch_size + validation_freq: int, optional + Only relevant if validation data is provided. Specifies how many training epochs + to run before a new validation run is performed, e.g. validation_freq=2 runs validation + every 2 epochs. Returns ------- @@ -405,24 +411,55 @@ def fit( ) # Test on val_dataset if provided - if val_dataset is not None: + if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0: test_losses = [] - for batch_nb, ( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - ) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)): + + val_samples_weight = None + if isinstance(val_dataset, tuple): + if not len(val_dataset) == 2: + raise ValueError( + """if argument val_dataset is a tuple, it should be + in the form (ChoiceDataset, weights)""" + ) + validation_dataset, val_samples_weight = val_dataset + elif isinstance(val_dataset, ChoiceDataset): + validation_dataset = val_dataset + else: + raise ValueError( + """val_dataset should be a ChoiceDataset or + a tuple of (ChoiceDataset, weights).""" + ) + + val_iterator = validation_dataset.iter_batch( + shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size + ) + + for batch_nb, batch_data in enumerate(val_iterator): + weight_batch = None + if val_samples_weight is not None: + batch_features, weight_batch = batch_data + else: + batch_features = batch_data + + ( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + ) = batch_features + self.callbacks.on_batch_begin(batch_nb) self.callbacks.on_test_batch_begin(batch_nb) - test_losses.append( - self.batch_predict( - shared_features_batch, - items_features_batch, - available_items_batch, - choices_batch, - )[0]["optimized_loss"] - ) + + loss = self.batch_predict( + shared_features_batch, + items_features_batch, + available_items_batch, + choices_batch, + sample_weight=weight_batch, + )[0]["optimized_loss"] + test_losses.append(loss) + val_logs["val_loss"].append(test_losses[-1]) temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()} self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs) diff --git a/choice_learn/utils/metrics.py b/choice_learn/utils/metrics.py index a477bc30..8f5067bf 100644 --- a/choice_learn/utils/metrics.py +++ b/choice_learn/utils/metrics.py @@ -24,7 +24,6 @@ def __init__( self, from_logits=False, sparse=False, - average_on_batch=False, epsilon=1e-10, name="negative_log_likelihood", axis=-1, @@ -40,8 +39,6 @@ def __init__( Whether y_true is given as an index or a one-hot, by default False epsilon : float, optional Lower bound for log(.), by default 1e-10 - average_on_batch: bool, optional - Whether the metric should be averaged over each batch. Typically used to get metrics averaged by Trip, by default False name : str, optional Name of operation, by default "negative_log_likelihood" @@ -53,11 +50,10 @@ def __init__( self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") self.from_logits = from_logits self.sparse = sparse - self.average_on_batch = average_on_batch self.epsilon = epsilon self.axis = axis - def update_state(self, y_true, y_pred, batch=None, sample_weight=None): + def update_state(self, y_true, y_pred, sample_weight=None): """Accumulate statistics for the metric. Parameters @@ -91,16 +87,11 @@ def update_state(self, y_true, y_pred, batch=None, sample_weight=None): axis=self.axis, ) - if batch is not None and self.average_on_batch: - for _, idx in zip(*tf.unique(batch)): - self.nll.assign(self.nll + tf.reduce_mean(nll_value[idx])) - self.n_evals.assign(self.n_evals + 1) + self.nll.assign(self.nll + tf.reduce_sum(nll_value)) + if sample_weight is None: + self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) else: - self.nll.assign(self.nll + tf.reduce_sum(nll_value)) - if sample_weight is None: - self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) - else: - self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) def result(self): """Compute the current metric value. @@ -118,7 +109,6 @@ class MRR(tf.keras.metrics.Metric): def __init__( self, - average_on_batch=False, name="mean_reciprocal_rank", axis=-1, **kwargs, @@ -126,14 +116,13 @@ def __init__( super().__init__(name=name, **kwargs) self.mrr = self.add_variable(shape=(), initializer="zeros", name="mrr") self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") - self.average_on_batch = average_on_batch self.axis = axis def update_state( self, y_true, y_pred, - batch=None, + sample_weight=None, ): """Accumulate statistics for the metric. @@ -156,15 +145,17 @@ def update_state( [tf.range(len(y_true)), y_true], axis=1 ) # Shape: (batch_size, 2) item_ranks = tf.gather_nd(ranks, item_batch_indices) # Shape: (batch_size,) - mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) - - if batch is not None and self.average_on_batch: - self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) - self.n_evals.assign(self.n_evals + 1) + if sample_weight is not None: + mean_rank = tf.reduce_sum( + tf.cast(1 / item_ranks, dtype=tf.float32) * sample_weight, axis=self.axis + ) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) else: - self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank)) + mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) + self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) + def result(self): """Compute the current metric value. @@ -181,7 +172,6 @@ class HitRate(tf.keras.metrics.Metric): def __init__( self, - average_on_batch=False, top_k: int = 10, name=None, axis=-1, @@ -195,10 +185,9 @@ def __init__( shape=(), initializer="zeros", name=f"hit_rate_at_{self.top_k}" ) self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") - self.average_on_batch = average_on_batch self.axis = axis - def update_state(self, y_true, y_pred, batch=None): + def update_state(self, y_true, y_pred, sample_weight=None): """Accumulate statistics for the metric. Parameters @@ -223,14 +212,17 @@ def update_state(self, y_true, y_pred, batch=None): ), axis=1, ) - hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis) - if batch is not None and self.average_on_batch: - self.hit_rate.assign(self.hit_rate + tf.reduce_mean(hits)) - self.n_evals.assign(self.n_evals + 1) + if sample_weight is not None: + hits = tf.reduce_sum( + tf.cast(hits_per_batch, tf.float32) * sample_weight, axis=self.axis + ) + self.n_evals.assign(self.n_evals + tf.reduce_sum(sample_weight)) else: - self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits)) + hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis) self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) + self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits)) + def result(self): """Compute the current metric value. diff --git a/notebooks/models/simple_mnl.ipynb b/notebooks/models/simple_mnl.ipynb index 67653242..6a9ea190 100644 --- a/notebooks/models/simple_mnl.ipynb +++ b/notebooks/models/simple_mnl.ipynb @@ -307,24 +307,24 @@ "