Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename symbols to ensure consistency #565

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ class BasketDataset(Dataset):
uir_tuple: tuple, required
Tuple of 3 numpy arrays (user_indices, item_indices, rating_values).

basket_ids: numpy.array, required
basket_indices: numpy.array, required
Array of basket indices corresponding to observation in `uir_tuple`.

timestamps: numpy.array, optional, default: None
Expand Down Expand Up @@ -677,7 +677,7 @@ def __init__(
bid_map,
iid_map,
uir_tuple,
basket_ids=None,
basket_indices=None,
timestamps=None,
extra_data=None,
seed=None,
Expand All @@ -693,23 +693,31 @@ def __init__(
)
self.num_baskets = num_baskets
self.bid_map = bid_map
self.basket_ids = basket_ids
self.basket_indices = basket_indices
self.extra_data = extra_data
basket_sizes = list(Counter(basket_ids).values())
basket_sizes = list(Counter(basket_indices).values())
self.max_basket_size = np.max(basket_sizes)
self.min_basket_size = np.min(basket_sizes)
self.avg_basket_size = np.mean(basket_sizes)

self.__baskets = None
self.__basket_ids = None
self.__user_basket_data = None
self.__chrono_user_basket_data = None

@property
def basket_ids(self):
"""Return the list of raw basket ids"""
if self.__basket_ids is None:
self.__basket_ids = list(self.bid_map.keys())
return self.__basket_ids

@property
def baskets(self):
"""A dictionary to store indices where basket ID appears in the data."""
if self.__baskets is None:
self.__baskets = defaultdict(list)
for idx, bid in enumerate(self.basket_ids):
for idx, bid in enumerate(self.basket_indices):
self.__baskets[bid].append(idx)
return self.__baskets

Expand Down Expand Up @@ -836,7 +844,7 @@ def build(
np.ones(len(u_indices), dtype="float"),
)

basket_ids = np.asarray(b_indices, dtype="int")
basket_indices = np.asarray(b_indices, dtype="int")

timestamps = (
np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int")
Expand All @@ -854,7 +862,7 @@ def build(
bid_map=global_bid_map,
iid_map=global_iid_map,
uir_tuple=uir_tuple,
basket_ids=basket_ids,
basket_indices=basket_indices,
timestamps=timestamps,
extra_data=extra_data,
seed=seed,
Expand Down
4 changes: 2 additions & 2 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
user_idx,
item_indices,
history_baskets=history_baskets,
history_basket_ids=bids[:-1],
history_bids=bids[:-1],
uir_tuple=test_set.uir_tuple,
baskets=test_set.baskets,
basket_ids=test_set.basket_ids,
basket_indices=test_set.basket_indices,
extra_data=test_set.extra_data,
)

Expand Down
10 changes: 4 additions & 6 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,13 @@ def score(self, user_idx, history_baskets, **kwargs):

if self.use_personalized_popularity:
if self.use_quantity:
history_basket_bids = kwargs.get("history_basket_ids")
history_bids = kwargs.get("history_bids")
baskets = kwargs.get("baskets")
p_item_freq = Counter()
(_, item_ids, _) = kwargs.get("uir_tuple")
extra_data = kwargs.get("extra_data")
for bid in history_basket_bids:
ids = baskets[bid]
for idx in ids:
p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0)
for bid, iids in zip(history_bids, history_baskets):
for idx, iid in zip(baskets[bid], iids):
p_item_freq[iid] += extra_data[idx].get("quantity", 0)
else:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
for iid, cnt in p_item_freq.most_common():
Expand Down
Loading