Skip to content
97 changes: 21 additions & 76 deletions choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,11 +884,8 @@ def from_single_wide_df(
df,
items_id,
shared_features_columns=None,
items_features_suffixes=None,
items_features_prefixes=None,
available_items_suffix=None,
available_items_prefix=None,
delimiter="_",
items_features_patterns=None,
available_items_pattern=None,
choices_column="choice",
choice_format="items_id",
):
Expand All @@ -902,21 +899,14 @@ def from_single_wide_df(
List of items ids
shared_features_columns : list, optional
List of columns of the dataframe that are shared_features_by_choice, default is None
items_features_prefixes : list, optional
Prefixes of the columns of the dataframe that are items_features_by_choice,
items_features_patterns : list of str, optional
Patterns of the columns of the dataframe that are items_features_by_choice,
given as "*suffix" or "prefix*" where "*" is replaced by items_id in df columns.
default is None
items_features_suffixes : list, optional
Suffixes of the columns of the dataframe that are items_features_by_choice,
available_items_pattern: str, optional
Pattern of the columns of the dataframe that are available_items_by_choice,
given as "*suffix" or "prefix*" where "*" is replaced by items_id in df columns.
default is None
available_items_prefix: str, optional
Prefix of the columns of the dataframe that precise available_items_by_choice,
default is None
available_items_suffix: str, optional
Suffix of the columns of the dataframe that precise available_items_by_choice,
default is None
delimiter: str, optional
Delimiter used to separate the given prefix or suffixes and the features names,
default is "_"
choice_column: str, optional
Name of the column containing the choices, default is "choice"
choice_format: str, optional
Expand All @@ -928,11 +918,6 @@ def from_single_wide_df(
ChoiceDataset
corresponding ChoiceDataset
"""
if available_items_prefix is not None and available_items_suffix is not None:
raise ValueError(
"You cannot give both available_items_prefix and\
available_items_suffix."
)
if choice_format not in ["items_index", "items_id"]:
logging.warning("choice_format not understood, defaulting to 'items_index'")

Expand All @@ -943,43 +928,13 @@ def from_single_wide_df(
shared_features_by_choice = None
shared_features_by_choice_names = None

if items_features_suffixes is not None and items_features_prefixes is not None:
# The list of features names is the concatenation of the two lists of
# prefixes and suffixes
items_features_names = items_features_prefixes + items_features_suffixes
items_features_by_choice = []
for item in items_id:
columns = [f"{feature}{delimiter}{item}" for feature in items_features_prefixes] + [
f"{item}{delimiter}{feature}" for feature in items_features_suffixes
]
for col in columns:
if col not in df.columns:
logging.warning(
f"Column {col} was not in DataFrame,\
dummy creation of the feature with zeros."
)
df[col] = 0
items_features_by_choice.append(df[columns].to_numpy())
items_features_by_choice = np.stack(items_features_by_choice, axis=1)
elif items_features_suffixes is not None:
items_features_names = items_features_suffixes
items_features_by_choice = []
for item in items_id:
columns = [f"{item}{delimiter}{feature}" for feature in items_features_suffixes]
for col in columns:
if col not in df.columns:
logging.warning(
f"Column {col} was not in DataFrame,\
dummy creation of the feature with zeros."
)
df[col] = 0
items_features_by_choice.append(df[columns].to_numpy())
items_features_by_choice = np.stack(items_features_by_choice, axis=1)
elif items_features_prefixes is not None:
items_features_names = items_features_prefixes
if items_features_patterns is not None:
assert all(["*" in pattern for pattern in items_features_patterns]), (
"items_features_patterns should all contain '*' character."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert for input validation is not recommended as assertions can be disabled (e.g., with Python's -O flag), which would bypass this check. It's better to raise a ValueError for invalid user input.

Suggested change
assert all(["*" in pattern for pattern in items_features_patterns]), (
"items_features_patterns should all contain '*' character."
if not all("*" in pattern for pattern in items_features_patterns):
raise ValueError("items_features_patterns should all contain '*' character.")

)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable items_features_names is never assigned inside the if items_features_patterns is not None: block. Because the name is still assigned in the else branch, Python treats it as a local variable, so this will raise an UnboundLocalError (a subclass of NameError) at line 981 when ChoiceDataset is constructed. This appears to be a regression from the previous implementation, which correctly set the feature names. You should derive items_features_names from items_features_patterns.

For example, you could add the following line before this one:
items_features_names = [p.strip("*_.") for p in items_features_patterns]

items_features_by_choice = []
for item in items_id:
columns = [f"{feature}{delimiter}{item}" for feature in items_features_prefixes]
columns = [feature.replace("*", item) for feature in items_features_patterns]
for col in columns:
if col not in df.columns:
logging.warning(
Expand All @@ -993,31 +948,21 @@ def from_single_wide_df(
items_features_by_choice = None
items_features_names = None

if available_items_suffix is not None:
if isinstance(available_items_suffix, list):
if not len(available_items_suffix) == len(items_id):
if available_items_pattern is not None:
if isinstance(available_items_pattern, list):
if not len(available_items_pattern) == len(items_id):
raise ValueError(
"You have given a list of columns for availabilities."
"We consider that it is one for each item however lenghts do not match"
)
logging.info("You have given a list of columns for availabilities.")
logging.info("Each column will be matched to an item, given their order")
available_items_by_choice = df[available_items_suffix].to_numpy()
available_items_by_choice = df[available_items_pattern].to_numpy()
else:
columns = [f"{item}{delimiter}{available_items_suffix}" for item in items_id]
available_items_by_choice = df[columns].to_numpy()
elif available_items_prefix is not None:
if isinstance(available_items_prefix, list):
if not len(available_items_prefix) == len(items_id):
raise ValueError(
"You have given a list of columns for availabilities."
"We consider that it is one for each item however lenghts do not match"
)
logging.info("You have given a list of columns for availabilities.")
logging.info("Each column will be matched to an item, given their order")
available_items_by_choice = df[available_items_prefix].to_numpy()
else:
columns = [f"{available_items_prefix}{delimiter}{item}" for item in items_id]
assert "*" in available_items_pattern, (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert for input validation is not recommended as assertions can be disabled (e.g., with Python's -O flag). It's better to raise a ValueError for invalid user input.

Suggested change
else:
columns = [f"{item}{delimiter}{available_items_suffix}" for item in items_id]
available_items_by_choice = df[columns].to_numpy()
elif available_items_prefix is not None:
if isinstance(available_items_prefix, list):
if not len(available_items_prefix) == len(items_id):
raise ValueError(
"You have given a list of columns for availabilities."
"We consider that it is one for each item however lenghts do not match"
)
logging.info("You have given a list of columns for availabilities.")
logging.info("Each column will be matched to an item, given their order")
available_items_by_choice = df[available_items_prefix].to_numpy()
else:
columns = [f"{available_items_prefix}{delimiter}{item}" for item in items_id]
assert "*" in available_items_pattern, (
if "*" not in available_items_pattern:
raise ValueError("available_items_pattern should contain '*' character.")

"available_items_pattern should contain '*' character."
)
columns = [available_items_pattern.replace("*", item) for item in items_id]
available_items_by_choice = df[columns].to_numpy()
else:
available_items_by_choice = None
Expand Down
46 changes: 22 additions & 24 deletions choice_learn/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,11 @@ def load_swissmetro(add_items_one_hot=False, as_frame=False, return_desc=False,
df=swiss_df,
items_id=items,
shared_features_columns=shared_features_by_choice_names,
items_features_suffixes=items_features_by_choice_names
+ ["ASC_TRAIN", "ASC_SM", "ASC_CAR"],
available_items_suffix=availabilities_column,
items_features_patterns=[
"*_%s" % column
for column in (items_features_by_choice_names + ["ASC_TRAIN", "ASC_SM", "ASC_CAR"])
],
available_items_pattern="*_%s" % availabilities_column,
choices_column=choice_column,
choice_format="items_index",
)
Expand Down Expand Up @@ -532,8 +534,8 @@ def load_swissmetro(add_items_one_hot=False, as_frame=False, return_desc=False,
df=swiss_df,
items_id=items,
shared_features_columns=shared_features_by_choice_names,
items_features_suffixes=items_features_by_choice_names,
available_items_suffix=availabilities_column,
items_features_patterns=["*_%s" % s for s in items_features_by_choice_names],
available_items_pattern="*_%s" % availabilities_column,
choices_column=choice_column,
choice_format="items_index",
)
Expand Down Expand Up @@ -927,9 +929,7 @@ def load_train(
df=train_df,
items_id=["1", "2"],
shared_features_columns=["id"],
items_features_prefixes=["price", "time", "change", "comfort"],
delimiter="",
available_items_suffix=None,
items_features_patterns=["price*", "time*", "change*", "comfort*"],
choices_column="choice",
choice_format="items_id",
)
Expand Down Expand Up @@ -974,26 +974,25 @@ def load_car_preferences(
cars_df["choice"] = cars_df.apply(lambda row: row.choice[-1], axis=1)
shared_features = ["college", "hsg2", "coml5"]
items_features = [
"type",
"fuel",
"price",
"range",
"acc",
"speed",
"pollution",
"size",
"space",
"cost",
"station",
"type*",
"fuel*",
"price*",
"range*",
"acc*",
"speed*",
"pollution*",
"size*",
"space*",
"cost*",
"station*",
]
items_id = [f"{i}" for i in range(1, 7)]

return ChoiceDataset.from_single_wide_df(
df=cars_df,
items_id=items_id,
shared_features_columns=shared_features,
items_features_prefixes=items_features,
delimiter="",
items_features_patterns=items_features,
choices_column="choice",
choice_format="items_id",
)
Expand Down Expand Up @@ -1060,8 +1059,7 @@ def load_hc(
return ChoiceDataset.from_single_wide_df(
df=hc_df,
shared_features_columns=["income"],
items_features_prefixes=["ich", "och", "occa", "icca"],
delimiter=".",
items_features_patterns=["ich.*", "och.*", "occa.*", "icca.*"],
items_id=items_id,
choices_column="depvar",
choice_format="items_id",
Expand Down Expand Up @@ -1206,7 +1204,7 @@ def load_londonpassenger(
df=london_df,
items_id=items,
shared_features_columns=shared_features_by_choice_names,
items_features_suffixes=items_features_by_choice_names,
items_features_patterns=["*_%s" % s for s in items_features_by_choice_names],
delimiter="_",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

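The delimiter parameter is no longer accepted by from_single_wide_df after your changes, so passing delimiter="_" here will raise a TypeError (unexpected keyword argument) as soon as load_londonpassenger is called. Please remove it; the delimiter is now encoded directly in each pattern (e.g. "*_%s").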

Suggested change
items_id=items,
shared_features_columns=shared_features_by_choice_names,
items_features_suffixes=items_features_by_choice_names,
items_features_patterns=["*_%s" % s for s in items_features_by_choice_names],
delimiter="_",
items_features_patterns=["*_%s" % s for s in items_features_by_choice_names],
choices_column=choice_column,
choice_format="items_index",

choices_column=choice_column,
choice_format="items_index",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/data/dataset_creation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@
" items_id=[\"TRAIN\", \"SM\", \"CAR\"],\n",
" shared_features_columns=[\"GROUP\", \"SURVEY\", \"SP\", \"PURPOSE\", \"FIRST\", \"TICKET\", \"WHO\", \"LUGGAGE\", \"AGE\",\n",
" \"MALE\", \"INCOME\", \"GA\", \"ORIGIN\", \"DEST\"],\n",
" items_features_suffixes=[\"CO\", \"TT\", \"HE\", \"SEATS\"],\n",
" available_items_suffix=\"AV\", # [\"TRAIN_AV\", \"SM_AV\", \"CAR_AV\"] also works\n",
" items_features_patterns=[\"*_CO\", \"*_TT\", \"*_HE\", \"*_SEATS\"],\n",
" available_items_pattern=\"*_AV\", # [\"TRAIN_AV\", \"SM_AV\", \"CAR_AV\"] also works\n",
" choices_column=\"CHOICE\",\n",
" choice_format=\"item_index\",\n",
")"
Expand Down
Loading