Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion choice_learn/basket_models/base_basket_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def compute_batch_utility(
"""
return

# Not clear
def compute_item_likelihood(
self,
basket: Union[None, np.ndarray] = None,
Expand Down
68 changes: 47 additions & 21 deletions choice_learn/basket_models/data/basket_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def get_items_up_to_index(self, i: int) -> np.ndarray:
class TripDataset:
"""Class for a dataset of trips."""

def __init__(self, trips: list[Trip], available_items: np.ndarray) -> None:
def __init__(
self, trips: list[Trip], available_items: np.ndarray = None, prices: np.ndarray = None
) -> None:
"""Initialize the dataset.

Parameters
Expand All @@ -112,6 +114,7 @@ def __init__(self, trips: list[Trip], available_items: np.ndarray) -> None:
self.max_length = max([trip.trip_length for trip in self.trips])
self.n_samples = len(self.get_transactions())
self.available_items = available_items
self.prices = prices

def __len__(self) -> int:
"""Return the number of trips in the dataset.
Expand Down Expand Up @@ -308,7 +311,7 @@ def get_one_vs_all_augmented_data_from_trip_index(
self,
trip_index: int,
) -> tuple[np.ndarray]:
"""Get augmented data from a trip index.
"""Get augmented data from a trip index - following AleaCarta method.

Augmented data consists in removing one item from the basket that will be used
as a target from the remaining items. It is done for all items, leading to returning:
Expand Down Expand Up @@ -386,7 +389,7 @@ def get_subbaskets_augmented_data_from_trip_index(
self,
trip_index: int,
) -> tuple[np.ndarray]:
"""Get augmented data from a trip index.
"""Get augmented data from a trip index - following Shopper method.

Augmented data includes all the transactions obtained sequentially from the trip.
In particular, items in the basket are shuffled and sub-baskets are built iteratively
Expand Down Expand Up @@ -417,11 +420,11 @@ def get_subbaskets_augmented_data_from_trip_index(

# Draw a random permutation of the items in the basket without the checkout item 0
# TODO at a later stage: improve by sampling several permutations here
permutation_list = list(permutations(range(length_trip - 1)))
permutation_list = list(permutations(range(length_trip)))
permutation = random.sample(permutation_list, 1)[0] # nosec

# Permute the basket while keeping the checkout item 0 at the end
permuted_purchases = np.array([trip.purchases[j] for j in permutation] + [0])
permuted_purchases = np.array([trip.purchases[j] for j in permutation] + [self.n_items])

# Truncate the baskets: for each batch sample, we consider the truncation possibilities
# ranging from an empty basket to the basket with all the elements except the checkout item
Expand All @@ -430,7 +433,7 @@ def get_subbaskets_augmented_data_from_trip_index(
padded_truncated_purchases = np.array(
[
np.concatenate((permuted_purchases[:i], -1 * np.ones(self.max_length - i)))
for i in range(0, length_trip)
for i in range(0, length_trip + 1)
],
dtype=int,
)
Expand All @@ -447,7 +450,7 @@ def get_subbaskets_augmented_data_from_trip_index(
-1 * np.ones(self.max_length - len(permuted_purchases) + i + 1),
)
)
for i in range(0, length_trip)
for i in range(0, length_trip + 1)
],
dtype=int,
)
Expand All @@ -458,17 +461,28 @@ def get_subbaskets_augmented_data_from_trip_index(
else: # np.ndarray
# Then it is directly the availability matrix
assortment = trip.assortment
# end-of-basket item always available
assortment = np.concatenate([assortment, [1.0]])

if not (isinstance(trip.prices, np.ndarray) or isinstance(trip.prices, list)):
# Then it is the prices ID (i.e. its index in self.prices)
prices = self.prices[trip.prices]
else: # np.ndarray
# Then it is directly the prices array
prices = trip.prices
# The end-of-basket item always has a price of 0.
prices = np.concatenate([prices, [0.0]])

# Each item is linked to a basket, the future purchases,
# a store, a week, prices and an assortment
return (
permuted_purchases, # Items
padded_truncated_purchases, # Baskets
padded_future_purchases, # Future purchases
np.full(length_trip, trip.store), # Stores
np.full(length_trip, trip.week), # Weeks
np.tile(trip.prices, (length_trip, 1)), # Prices
np.tile(assortment, (length_trip, 1)), # Available items
np.full(length_trip + 1, trip.store), # Stores
np.full(length_trip + 1, trip.week), # Weeks
np.tile(prices, (length_trip + 1, 1)), # Prices
np.tile(assortment, (length_trip + 1, 1)), # Available items
)

def iter_batch(
Expand Down Expand Up @@ -509,16 +523,28 @@ def iter_batch(
trip_indexes = np.random.default_rng().permutation(trip_indexes)

# Initialize the buffer
buffer = (
np.empty(0, dtype=int), # Items
np.empty((0, self.max_length), dtype=int), # Baskets
np.empty((0, self.max_length), dtype=int), # Future purchases
np.empty(0, dtype=int), # Stores
np.empty(0, dtype=int), # Weeks
np.empty((0, self.n_items), dtype=int), # Prices
np.empty((0, self.n_items), dtype=int), # Available items
)

if data_method == "shopper":
buffer = (
np.empty(0, dtype=int), # Items
np.empty((0, self.max_length), dtype=int), # Baskets
np.empty((0, self.max_length), dtype=int), # Future purchases
np.empty(0, dtype=int), # Stores
np.empty(0, dtype=int), # Weeks
np.empty((0, self.n_items + 1), dtype=int), # Prices
np.empty((0, self.n_items + 1), dtype=int), # Available items
)
elif data_method == "aleacarta":
buffer = (
np.empty(0, dtype=int), # Items
np.empty((0, self.max_length), dtype=int), # Baskets
np.empty((0, self.max_length), dtype=int), # Future purchases
np.empty(0, dtype=int), # Stores
np.empty(0, dtype=int), # Weeks
np.empty((0, self.n_items), dtype=int), # Prices
np.empty((0, self.n_items), dtype=int), # Available items
)
else:
raise ValueError(f"Unknown data method: {data_method}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's some code duplication in the initialization of `buffer`. The first five elements of the tuple are identical for both the `shopper` and `aleacarta` data methods; only the width of the last two arrays (prices and available items) differs: `self.n_items + 1` for `shopper` versus `self.n_items` for `aleacarta`. Consider refactoring this to improve maintainability by defining the common part of the buffer first, and then appending the method-specific parts based on `data_method`.

if batch_size == -1:
# Get the whole dataset in one batch
for trip_index in trip_indexes:
Expand Down
Loading