Skip to content
6 changes: 4 additions & 2 deletions choice_learn/basket_models/base_basket_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def evaluate(
metric.reset_state()

# for trip in trip_dataset.trips:
for data_batch, identifier_batch in trip_dataset.iter_batch_evaluate(
for data_batch, weights_batch in trip_dataset.iter_batch_evaluate(
trip_batch_size=trip_batch_size
):
# Sum of the log-likelihoods of all the baskets in the batch
Expand All @@ -914,7 +914,9 @@ def evaluate(
for metric in exec_metrics:
# Use update_state, not append(metric(...))
metric.update_state(
y_true=data_batch[0], y_pred=predicted_probabilities, batch=identifier_batch
y_true=data_batch[0],
y_pred=predicted_probabilities,
sample_weight=weights_batch,
)

# After the loops, get the final results
Expand Down
21 changes: 11 additions & 10 deletions choice_learn/basket_models/data/basket_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,29 +783,31 @@ def iter_batch_evaluate(
np.empty(0, dtype=int), # Weeks
np.empty((0, self.n_items), dtype=int), # Prices
np.empty((0, self.n_items), dtype=int), # Available items
np.empty(0, dtype=int), # Users
)

if trip_batch_size == -1:
# Get the whole dataset in one batch
identifiers = []
weights = []
for trip_index in trip_indexes:
additional_trip_data = self.get_one_vs_all_augmented_data_from_trip_index(
trip_index
)
buffer = tuple(
np.concatenate((buffer[i], additional_trip_data[i])) for i in range(len(buffer))
)
identifiers.extend([trip_index] * len(additional_trip_data[0]))
weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0]))

# Yield the whole dataset
yield buffer, np.array(identifiers)
yield buffer, np.array(weights)

else:
# Yield batches of size batch_size while going through all the trips
index = 0
outer_break = False
while index < num_trips:
trip_identifier = []
weights = []
trip_count = 0
buffer = (
np.empty(0, dtype=int), # Items
np.empty((0, self.max_length), dtype=int), # Baskets
Expand All @@ -816,11 +818,11 @@ def iter_batch_evaluate(
np.empty((0, self.n_items), dtype=int), # Available items
np.empty(0, dtype=int), # Users
)
while np.max(trip_identifier, initial=-1) + 1 < trip_batch_size:
while trip_count + 1 < trip_batch_size:
if index >= num_trips:
# Then the buffer is not full but there are no more trips to consider
# Yield the batch partially filled
yield buffer, np.array(trip_identifier)
yield buffer, np.array(weights)

# Exit the TWO while loops when all trips have been considered
outer_break = True
Expand All @@ -832,18 +834,17 @@ def iter_batch_evaluate(
trip_indexes[index]
)
index += 1
trip_count += 1

# Fill the buffer with the new trip
buffer = tuple(
np.concatenate((buffer[i], additional_trip_data[i]))
for i in range(len(buffer))
)
trip_identifier.extend(
[np.max(trip_identifier, initial=-1) + 1] * len(additional_trip_data[0])
)
weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0]))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a potential division-by-zero error here. additional_trip_data[0] is a NumPy array of item IDs. The expression 1 / additional_trip_data[0] performs element-wise division. If any item ID in additional_trip_data[0] is 0 (which is common, especially for the checkout item), this will raise a ZeroDivisionError.

Based on the logic in the if trip_batch_size == -1: block (line 799), it seems the intention is to use the number of items to calculate the weight. The line should likely be weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0])).

Suggested change
weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0]))
weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0]))


if outer_break:
break

# Yield the batch
yield buffer, np.array(trip_identifier)
yield buffer, np.array(weights)
69 changes: 53 additions & 16 deletions choice_learn/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tqdm

import choice_learn.tf_ops as tf_ops
from choice_learn.data import ChoiceDataset


class ChoiceModel:
Expand Down Expand Up @@ -254,6 +255,7 @@ def fit(
choice_dataset,
sample_weight=None,
val_dataset=None,
validation_freq=1,
verbose=0,
):
"""Train the model with a ChoiceDataset.
Expand All @@ -264,14 +266,18 @@ def fit(
Input data in the form of a ChoiceDataset
sample_weight : np.ndarray, optional
Sample weight to apply, by default None
val_dataset : ChoiceDataset, optional
val_dataset : ChoiceDataset or (ChoiceDataset, samples_weight), optional
Test ChoiceDataset to evaluate performances on test at each epoch, by default None
verbose : int, optional
print level, for debugging, by default 0
epochs : int, optional
Number of epochs, default is None, meaning we use self.epochs
batch_size : int, optional
Batch size, default is None, meaning we use self.batch_size
validation_freq: int, optional
Only relevant if validation data is provided. Specifies how many training epochs
to run before a new validation run is performed, e.g. validation_freq=2 runs validation
every 2 epochs.

Returns
-------
Expand Down Expand Up @@ -405,24 +411,55 @@ def fit(
)

# Test on val_dataset if provided
if val_dataset is not None:
if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0:
test_losses = []
for batch_nb, (
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)):

val_samples_weight = None
if isinstance(val_dataset, tuple):
if not len(val_dataset) == 2:
raise ValueError(
"""if argument val_dataset is a tuple, it should be
in the form (ChoiceDataset, weights)"""
)
validation_dataset, val_samples_weight = val_dataset
elif isinstance(val_dataset, ChoiceDataset):
validation_dataset = val_dataset
else:
raise ValueError(
"""val_dataset should be a ChoiceDataset or
a tuple of (ChoiceDataset, weights)."""
)

val_iterator = validation_dataset.iter_batch(
shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size
)

for batch_nb, batch_data in enumerate(val_iterator):
weight_batch = None
if val_samples_weight is not None:
batch_features, weight_batch = batch_data
else:
batch_features = batch_data

(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
) = batch_features

self.callbacks.on_batch_begin(batch_nb)
self.callbacks.on_test_batch_begin(batch_nb)
test_losses.append(
self.batch_predict(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
)[0]["optimized_loss"]
)

loss = self.batch_predict(
shared_features_batch,
items_features_batch,
available_items_batch,
choices_batch,
sample_weight=weight_batch,
)[0]["optimized_loss"]
test_losses.append(loss)

val_logs["val_loss"].append(test_losses[-1])
temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()}
self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs)
Expand Down
50 changes: 25 additions & 25 deletions notebooks/models/simple_mnl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -307,24 +307,24 @@
" <td>Weights_items_features_0</td>\n",
" <td>-0.001533</td>\n",
" <td>0.000621</td>\n",
" <td>-2.469423</td>\n",
" <td>1.353312e-02</td>\n",
" <td>-2.469422</td>\n",
" <td>1.353315e-02</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Weights_items_features_1</td>\n",
" <td>-0.006996</td>\n",
" <td>0.001554</td>\n",
" <td>-4.501964</td>\n",
" <td>6.675720e-06</td>\n",
" <td>-4.501969</td>\n",
" <td>6.732662e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Intercept_0</td>\n",
" <td>1.710969</td>\n",
" <td>0.226741</td>\n",
" <td>7.545904</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.226742</td>\n",
" <td>7.545903</td>\n",
" <td>4.485301e-14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
Expand All @@ -338,38 +338,38 @@
" <th>4</th>\n",
" <td>Intercept_2</td>\n",
" <td>1.658846</td>\n",
" <td>0.448417</td>\n",
" <td>3.699342</td>\n",
" <td>2.161264e-04</td>\n",
" <td>0.448416</td>\n",
" <td>3.699345</td>\n",
" <td>2.161564e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Intercept_3</td>\n",
" <td>1.853437</td>\n",
" <td>0.361953</td>\n",
" <td>5.120663</td>\n",
" <td>3.576279e-07</td>\n",
" <td>0.361952</td>\n",
" <td>5.120667</td>\n",
" <td>3.044562e-07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Coefficient Name Coefficient Estimation Std. Err z_value \\\n",
"0 Weights_items_features_0 -0.001533 0.000621 -2.469423 \n",
"1 Weights_items_features_1 -0.006996 0.001554 -4.501964 \n",
"2 Intercept_0 1.710969 0.226741 7.545904 \n",
"0 Weights_items_features_0 -0.001533 0.000621 -2.469422 \n",
"1 Weights_items_features_1 -0.006996 0.001554 -4.501969 \n",
"2 Intercept_0 1.710969 0.226742 7.545903 \n",
"3 Intercept_1 0.308263 0.206591 1.492140 \n",
"4 Intercept_2 1.658846 0.448417 3.699342 \n",
"5 Intercept_3 1.853437 0.361953 5.120663 \n",
"4 Intercept_2 1.658846 0.448416 3.699345 \n",
"5 Intercept_3 1.853437 0.361952 5.120667 \n",
"\n",
" P(.>z) \n",
"0 1.353312e-02 \n",
"1 6.675720e-06 \n",
"2 0.000000e+00 \n",
"0 1.353315e-02 \n",
"1 6.732662e-06 \n",
"2 4.485301e-14 \n",
"3 1.356624e-01 \n",
"4 2.161264e-04 \n",
"5 3.576279e-07 "
"4 2.161564e-04 \n",
"5 3.044562e-07 "
]
},
"execution_count": null,
Expand All @@ -391,7 +391,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tf_env",
"display_name": "choice_learn",
"language": "python",
"name": "python3"
},
Expand All @@ -405,7 +405,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down
15 changes: 15 additions & 0 deletions tests/unit_tests/models/test_simplemnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,18 @@ def test_save_load():

assert nll_a == nll_b
shutil.rmtree("test_save")


def test_weighted_val_dataset():
"""Tests instantiation with item and fit with Adam."""
tf.config.run_functions_eagerly(True)
model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1)
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
nll_b = model.evaluate(test_dataset)
model.fit(
test_dataset, get_report=True, val_dataset=(test_dataset, np.ones((len(test_dataset),)))
)
nll_a = model.evaluate(test_dataset, batch_size=-1)
assert nll_a < nll_b

assert model.report.to_numpy().shape == (7, 5)