-
Notifications
You must be signed in to change notification settings - Fork 15
ENH: better handling of DF -> ChoiceDataset parametrization #303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
fe57a1f
7588834
1253df8
5e81a6c
83f9289
e6bf4b7
64b53c8
4603909
2483024
0311533
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -884,11 +884,8 @@ def from_single_wide_df( | |||||||||||||||||||||||||||||||||||||
| df, | ||||||||||||||||||||||||||||||||||||||
| items_id, | ||||||||||||||||||||||||||||||||||||||
| shared_features_columns=None, | ||||||||||||||||||||||||||||||||||||||
| items_features_suffixes=None, | ||||||||||||||||||||||||||||||||||||||
| items_features_prefixes=None, | ||||||||||||||||||||||||||||||||||||||
| available_items_suffix=None, | ||||||||||||||||||||||||||||||||||||||
| available_items_prefix=None, | ||||||||||||||||||||||||||||||||||||||
| delimiter="_", | ||||||||||||||||||||||||||||||||||||||
| items_features_patterns=None, | ||||||||||||||||||||||||||||||||||||||
| available_items_pattern=None, | ||||||||||||||||||||||||||||||||||||||
| choices_column="choice", | ||||||||||||||||||||||||||||||||||||||
| choice_format="items_id", | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -902,21 +899,14 @@ def from_single_wide_df( | |||||||||||||||||||||||||||||||||||||
| List of items ids | ||||||||||||||||||||||||||||||||||||||
| shared_features_columns : list, optional | ||||||||||||||||||||||||||||||||||||||
| List of columns of the dataframe that are shared_features_by_choice, default is None | ||||||||||||||||||||||||||||||||||||||
| items_features_prefixes : list, optional | ||||||||||||||||||||||||||||||||||||||
| Prefixes of the columns of the dataframe that are items_features_by_choice, | ||||||||||||||||||||||||||||||||||||||
| items_features_patterns : list of str, optional | ||||||||||||||||||||||||||||||||||||||
| Patterns of the columns of the dataframe that are items_features_by_choice, | ||||||||||||||||||||||||||||||||||||||
| given as "*suffix" or "prefix*" where "*" is replaced by items_id in df columns. | ||||||||||||||||||||||||||||||||||||||
| default is None | ||||||||||||||||||||||||||||||||||||||
| items_features_suffixes : list, optional | ||||||||||||||||||||||||||||||||||||||
| Suffixes of the columns of the dataframe that are items_features_by_choice, | ||||||||||||||||||||||||||||||||||||||
| available_items_pattern: str, optional | ||||||||||||||||||||||||||||||||||||||
| Pattern of the columns of the dataframe that are available_items_by_choice, | ||||||||||||||||||||||||||||||||||||||
| given as "*suffix" or "prefix*" where "*" is replaced by items_id in df columns. | ||||||||||||||||||||||||||||||||||||||
| default is None | ||||||||||||||||||||||||||||||||||||||
| available_items_prefix: str, optional | ||||||||||||||||||||||||||||||||||||||
| Prefix of the columns of the dataframe that precise available_items_by_choice, | ||||||||||||||||||||||||||||||||||||||
| default is None | ||||||||||||||||||||||||||||||||||||||
| available_items_suffix: str, optional | ||||||||||||||||||||||||||||||||||||||
| Suffix of the columns of the dataframe that precise available_items_by_choice, | ||||||||||||||||||||||||||||||||||||||
| default is None | ||||||||||||||||||||||||||||||||||||||
| delimiter: str, optional | ||||||||||||||||||||||||||||||||||||||
| Delimiter used to separate the given prefix or suffixes and the features names, | ||||||||||||||||||||||||||||||||||||||
| default is "_" | ||||||||||||||||||||||||||||||||||||||
| choice_column: str, optional | ||||||||||||||||||||||||||||||||||||||
| Name of the column containing the choices, default is "choice" | ||||||||||||||||||||||||||||||||||||||
| choice_format: str, optional | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -928,11 +918,6 @@ def from_single_wide_df( | |||||||||||||||||||||||||||||||||||||
| ChoiceDataset | ||||||||||||||||||||||||||||||||||||||
| corresponding ChoiceDataset | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| if available_items_prefix is not None and available_items_suffix is not None: | ||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||
| "You cannot give both available_items_prefix and\ | ||||||||||||||||||||||||||||||||||||||
| available_items_suffix." | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| if choice_format not in ["items_index", "items_id"]: | ||||||||||||||||||||||||||||||||||||||
| logging.warning("choice_format not understood, defaulting to 'items_index'") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -943,43 +928,13 @@ def from_single_wide_df( | |||||||||||||||||||||||||||||||||||||
| shared_features_by_choice = None | ||||||||||||||||||||||||||||||||||||||
| shared_features_by_choice_names = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if items_features_suffixes is not None and items_features_prefixes is not None: | ||||||||||||||||||||||||||||||||||||||
| # The list of features names is the concatenation of the two lists of | ||||||||||||||||||||||||||||||||||||||
| # prefixes and suffixes | ||||||||||||||||||||||||||||||||||||||
| items_features_names = items_features_prefixes + items_features_suffixes | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice = [] | ||||||||||||||||||||||||||||||||||||||
| for item in items_id: | ||||||||||||||||||||||||||||||||||||||
| columns = [f"{feature}{delimiter}{item}" for feature in items_features_prefixes] + [ | ||||||||||||||||||||||||||||||||||||||
| f"{item}{delimiter}{feature}" for feature in items_features_suffixes | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| for col in columns: | ||||||||||||||||||||||||||||||||||||||
| if col not in df.columns: | ||||||||||||||||||||||||||||||||||||||
| logging.warning( | ||||||||||||||||||||||||||||||||||||||
| f"Column {col} was not in DataFrame,\ | ||||||||||||||||||||||||||||||||||||||
| dummy creation of the feature with zeros." | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| df[col] = 0 | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice.append(df[columns].to_numpy()) | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice = np.stack(items_features_by_choice, axis=1) | ||||||||||||||||||||||||||||||||||||||
| elif items_features_suffixes is not None: | ||||||||||||||||||||||||||||||||||||||
| items_features_names = items_features_suffixes | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice = [] | ||||||||||||||||||||||||||||||||||||||
| for item in items_id: | ||||||||||||||||||||||||||||||||||||||
| columns = [f"{item}{delimiter}{feature}" for feature in items_features_suffixes] | ||||||||||||||||||||||||||||||||||||||
| for col in columns: | ||||||||||||||||||||||||||||||||||||||
| if col not in df.columns: | ||||||||||||||||||||||||||||||||||||||
| logging.warning( | ||||||||||||||||||||||||||||||||||||||
| f"Column {col} was not in DataFrame,\ | ||||||||||||||||||||||||||||||||||||||
| dummy creation of the feature with zeros." | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| df[col] = 0 | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice.append(df[columns].to_numpy()) | ||||||||||||||||||||||||||||||||||||||
| items_features_by_choice = np.stack(items_features_by_choice, axis=1) | ||||||||||||||||||||||||||||||||||||||
| elif items_features_prefixes is not None: | ||||||||||||||||||||||||||||||||||||||
| items_features_names = items_features_prefixes | ||||||||||||||||||||||||||||||||||||||
| if items_features_patterns is not None: | ||||||||||||||||||||||||||||||||||||||
| assert all(["*" in pattern for pattern in items_features_patterns]), ( | ||||||||||||||||||||||||||||||||||||||
| "items_features_patterns should all contain '*' character." | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable For example, you could add the following line before this one: |
||||||||||||||||||||||||||||||||||||||
| items_features_by_choice = [] | ||||||||||||||||||||||||||||||||||||||
| for item in items_id: | ||||||||||||||||||||||||||||||||||||||
| columns = [f"{feature}{delimiter}{item}" for feature in items_features_prefixes] | ||||||||||||||||||||||||||||||||||||||
| columns = [feature.replace("*", item) for feature in items_features_patterns] | ||||||||||||||||||||||||||||||||||||||
| for col in columns: | ||||||||||||||||||||||||||||||||||||||
| if col not in df.columns: | ||||||||||||||||||||||||||||||||||||||
| logging.warning( | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -993,31 +948,21 @@ def from_single_wide_df( | |||||||||||||||||||||||||||||||||||||
| items_features_by_choice = None | ||||||||||||||||||||||||||||||||||||||
| items_features_names = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if available_items_suffix is not None: | ||||||||||||||||||||||||||||||||||||||
| if isinstance(available_items_suffix, list): | ||||||||||||||||||||||||||||||||||||||
| if not len(available_items_suffix) == len(items_id): | ||||||||||||||||||||||||||||||||||||||
| if available_items_pattern is not None: | ||||||||||||||||||||||||||||||||||||||
| if isinstance(available_items_pattern, list): | ||||||||||||||||||||||||||||||||||||||
| if not len(available_items_pattern) == len(items_id): | ||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||
| "You have given a list of columns for availabilities." | ||||||||||||||||||||||||||||||||||||||
| "We consider that it is one for each item however lenghts do not match" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| logging.info("You have given a list of columns for availabilities.") | ||||||||||||||||||||||||||||||||||||||
| logging.info("Each column will be matched to an item, given their order") | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = df[available_items_suffix].to_numpy() | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = df[available_items_pattern].to_numpy() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| columns = [f"{item}{delimiter}{available_items_suffix}" for item in items_id] | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = df[columns].to_numpy() | ||||||||||||||||||||||||||||||||||||||
| elif available_items_prefix is not None: | ||||||||||||||||||||||||||||||||||||||
| if isinstance(available_items_prefix, list): | ||||||||||||||||||||||||||||||||||||||
| if not len(available_items_prefix) == len(items_id): | ||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||
| "You have given a list of columns for availabilities." | ||||||||||||||||||||||||||||||||||||||
| "We consider that it is one for each item however lenghts do not match" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| logging.info("You have given a list of columns for availabilities.") | ||||||||||||||||||||||||||||||||||||||
| logging.info("Each column will be matched to an item, given their order") | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = df[available_items_prefix].to_numpy() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| columns = [f"{available_items_prefix}{delimiter}{item}" for item in items_id] | ||||||||||||||||||||||||||||||||||||||
| assert "*" in available_items_pattern, ( | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||||||||||||||||
| "available_items_pattern should contain '*' character." | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| columns = [available_items_pattern.replace("*", item) for item in items_id] | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = df[columns].to_numpy() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| available_items_by_choice = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -338,9 +338,11 @@ def load_swissmetro(add_items_one_hot=False, as_frame=False, return_desc=False, | |||||||||||||||||
| df=swiss_df, | ||||||||||||||||||
| items_id=items, | ||||||||||||||||||
| shared_features_columns=shared_features_by_choice_names, | ||||||||||||||||||
| items_features_suffixes=items_features_by_choice_names | ||||||||||||||||||
| + ["ASC_TRAIN", "ASC_SM", "ASC_CAR"], | ||||||||||||||||||
| available_items_suffix=availabilities_column, | ||||||||||||||||||
| items_features_patterns=[ | ||||||||||||||||||
| "*_%s" % column | ||||||||||||||||||
| for column in (items_features_by_choice_names + ["ASC_TRAIN", "ASC_SM", "ASC_CAR"]) | ||||||||||||||||||
| ], | ||||||||||||||||||
| available_items_pattern="*_%s" % availabilities_column, | ||||||||||||||||||
| choices_column=choice_column, | ||||||||||||||||||
| choice_format="items_index", | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -532,8 +534,8 @@ def load_swissmetro(add_items_one_hot=False, as_frame=False, return_desc=False, | |||||||||||||||||
| df=swiss_df, | ||||||||||||||||||
| items_id=items, | ||||||||||||||||||
| shared_features_columns=shared_features_by_choice_names, | ||||||||||||||||||
| items_features_suffixes=items_features_by_choice_names, | ||||||||||||||||||
| available_items_suffix=availabilities_column, | ||||||||||||||||||
| items_features_patterns=["*_%s" % s for s in items_features_by_choice_names], | ||||||||||||||||||
| available_items_pattern="*_%s" % availabilities_column, | ||||||||||||||||||
| choices_column=choice_column, | ||||||||||||||||||
| choice_format="items_index", | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -927,9 +929,7 @@ def load_train( | |||||||||||||||||
| df=train_df, | ||||||||||||||||||
| items_id=["1", "2"], | ||||||||||||||||||
| shared_features_columns=["id"], | ||||||||||||||||||
| items_features_prefixes=["price", "time", "change", "comfort"], | ||||||||||||||||||
| delimiter="", | ||||||||||||||||||
| available_items_suffix=None, | ||||||||||||||||||
| items_features_patterns=["price*", "time*", "change*", "comfort*"], | ||||||||||||||||||
| choices_column="choice", | ||||||||||||||||||
| choice_format="items_id", | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -974,26 +974,25 @@ def load_car_preferences( | |||||||||||||||||
| cars_df["choice"] = cars_df.apply(lambda row: row.choice[-1], axis=1) | ||||||||||||||||||
| shared_features = ["college", "hsg2", "coml5"] | ||||||||||||||||||
| items_features = [ | ||||||||||||||||||
| "type", | ||||||||||||||||||
| "fuel", | ||||||||||||||||||
| "price", | ||||||||||||||||||
| "range", | ||||||||||||||||||
| "acc", | ||||||||||||||||||
| "speed", | ||||||||||||||||||
| "pollution", | ||||||||||||||||||
| "size", | ||||||||||||||||||
| "space", | ||||||||||||||||||
| "cost", | ||||||||||||||||||
| "station", | ||||||||||||||||||
| "type*", | ||||||||||||||||||
| "fuel*", | ||||||||||||||||||
| "price*", | ||||||||||||||||||
| "range*", | ||||||||||||||||||
| "acc*", | ||||||||||||||||||
| "speed*", | ||||||||||||||||||
| "pollution*", | ||||||||||||||||||
| "size*", | ||||||||||||||||||
| "space*", | ||||||||||||||||||
| "cost*", | ||||||||||||||||||
| "station*", | ||||||||||||||||||
| ] | ||||||||||||||||||
| items_id = [f"{i}" for i in range(1, 7)] | ||||||||||||||||||
|
|
||||||||||||||||||
| return ChoiceDataset.from_single_wide_df( | ||||||||||||||||||
| df=cars_df, | ||||||||||||||||||
| items_id=items_id, | ||||||||||||||||||
| shared_features_columns=shared_features, | ||||||||||||||||||
| items_features_prefixes=items_features, | ||||||||||||||||||
| delimiter="", | ||||||||||||||||||
| items_features_patterns=items_features, | ||||||||||||||||||
| choices_column="choice", | ||||||||||||||||||
| choice_format="items_id", | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -1060,8 +1059,7 @@ def load_hc( | |||||||||||||||||
| return ChoiceDataset.from_single_wide_df( | ||||||||||||||||||
| df=hc_df, | ||||||||||||||||||
| shared_features_columns=["income"], | ||||||||||||||||||
| items_features_prefixes=["ich", "och", "occa", "icca"], | ||||||||||||||||||
| delimiter=".", | ||||||||||||||||||
| items_features_patterns=["ich.*", "och.*", "occa.*", "icca.*"], | ||||||||||||||||||
| items_id=items_id, | ||||||||||||||||||
| choices_column="depvar", | ||||||||||||||||||
| choice_format="items_id", | ||||||||||||||||||
|
|
@@ -1206,7 +1204,7 @@ def load_londonpassenger( | |||||||||||||||||
| df=london_df, | ||||||||||||||||||
| items_id=items, | ||||||||||||||||||
| shared_features_columns=shared_features_by_choice_names, | ||||||||||||||||||
| items_features_suffixes=items_features_by_choice_names, | ||||||||||||||||||
| items_features_patterns=["*_%s" % s for s in items_features_by_choice_names], | ||||||||||||||||||
| delimiter="_", | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||
| choices_column=choice_column, | ||||||||||||||||||
| choice_format="items_index", | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
assertfor input validation is not recommended as assertions can be disabled (e.g., with Python's-Oflag), which would bypass this check. It's better to raise aValueErrorfor invalid user input.