Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update next-basket evaluation #559

Merged
merged 15 commits into from
Dec 8, 2023
70 changes: 43 additions & 27 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,14 @@ def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.uir_tuple[0]), batch_size)

def num_user_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over users"""
return estimate_batches(self.num_users, batch_size)

def num_item_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over items"""
return estimate_batches(self.num_items, batch_size)

def idx_iter(self, idx_range, batch_size=1, shuffle=False):
"""Create an iterator over batch of indices

Expand Down Expand Up @@ -700,9 +708,8 @@ def __init__(
def baskets(self):
"""A dictionary to store indices where basket ID appears in the data."""
if self.__baskets is None:
self.__baskets = OrderedDict()
self.__baskets = defaultdict(list)
for idx, bid in enumerate(self.basket_ids):
self.__baskets.setdefault(bid, [])
self.__baskets[bid].append(idx)
return self.__baskets

Expand All @@ -712,10 +719,9 @@ def user_basket_data(self):
values are list of baskets purchased by corresponding users.
"""
if self.__user_basket_data is None:
self.__user_basket_data = defaultdict()
self.__user_basket_data = defaultdict(list)
for bid, ids in self.baskets.items():
u = self.uir_tuple[0][ids[0]]
self.__user_basket_data.setdefault(u, [])
self.__user_basket_data[u].append(bid)
return self.__user_basket_data

Expand Down Expand Up @@ -916,37 +922,50 @@ def from_ubitjson(cls, data, seed=None):
"""
return cls.build(data, fmt="UBITJson", seed=seed)

def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.user_data), batch_size)
def ub_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users and batch of baskets

def user_basket_data_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of basket indices and batch of baskets
Parameters
----------
batch_size: int, optional, default = 1

shuffle: bool, optional, default: False
If `True`, orders of users will be randomized. If `False`, default orders kept.

Returns
-------
iterator : batch of user indices, batch of baskets corresponding to user indices

"""
for batch_users in self.user_iter(batch_size, shuffle):
batch_baskets = [self.user_basket_data[uid] for uid in batch_users]
yield batch_users, batch_baskets

def ubi_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users, basket ids, and batch of the corresponding items

Parameters
----------
batch_size: int, optional, default = 1

shuffle: bool, optional, default: False
If `True`, orders of triplets will be randomized. If `False`, default orders kept.
If `True`, orders of users will be randomized. If `False`, default orders kept.

Returns
-------
iterator : batch of user indices, batch of user data corresponding to user indices
iterator : batch of user indices, batch of baskets corresponding to user indices, and batch of items corresponding to baskets

"""
user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int")
for batch_ids in self.idx_iter(
len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle
):
batch_users = user_indices[batch_ids]
batch_basket_ids = np.asarray(
[self.user_basket_data[uid] for uid in batch_users], dtype="int"
)
yield batch_users, batch_basket_ids
_, item_indices, _ = self.uir_tuple
for batch_users, batch_baskets in self.ub_iter(batch_size, shuffle):
batch_basket_items = [
[item_indices[self.baskets[bid]] for bid in user_baskets]
for user_baskets in batch_baskets
]
yield batch_users, batch_baskets, batch_basket_items

def basket_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of basket indices and batch of baskets
"""Create an iterator over data yielding batch of basket indices

Parameters
----------
Expand All @@ -957,12 +976,9 @@ def basket_iter(self, batch_size=1, shuffle=False):

Returns
-------
iterator : batch of basket indices, batch of baskets (list of list)
iterator : batch of basket indices (array of 'int')

"""
basket_indices = np.array(list(self.baskets.keys()))
baskets = list(self.baskets.values())
basket_indices = np.fromiter(set(self.baskets.keys()), dtype="int")
for batch_ids in self.idx_iter(len(basket_indices), batch_size, shuffle):
batch_basket_indices = basket_indices[batch_ids]
batch_baskets = [baskets[idx] for idx in batch_ids]
yield batch_basket_indices, batch_baskets
yield basket_indices[batch_ids]
67 changes: 28 additions & 39 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,15 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]
return item_indices, u_gt_pos_items, u_gt_neg_items

(test_user_indices, test_item_indices, _) = test_set.uir_tuple
for user_idx in tqdm(
set(test_user_indices), desc="Ranking", disable=not verbose, miniters=100
(test_user_indices, *_) = test_set.uir_tuple
for [user_idx], [bids], [(*history_baskets, gt_basket)] in tqdm(
test_set.ubi_iter(batch_size=1, shuffle=False),
total=len(set(test_user_indices)),
desc="Ranking",
disable=not verbose,
miniters=100,
):
[*history_bids, gt_bid] = test_set.user_basket_data[user_idx]
test_pos_items = pos_items(
[[test_item_indices[idx] for idx in test_set.baskets[gt_bid]]]
)
test_pos_items = pos_items([gt_basket])
if len(test_pos_items) == 0:
continue

Expand All @@ -126,10 +127,9 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_baskets=[
[test_item_indices[idx] for idx in test_set.baskets[bid]]
for bid in history_bids
],
history_baskets=history_baskets,
history_basket_ids=bids[:-1],
uir_tuple=test_set.uir_tuple,
baskets=test_set.baskets,
basket_ids=test_set.basket_ids,
extra_data=test_set.extra_data,
Expand All @@ -146,19 +146,11 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
user_results["conventional"][i][user_idx] = mt_score

history_items = set(
test_item_indices[idx]
for bid in history_bids
for idx in test_set.baskets[bid]
item_idx for basket in history_baskets for item_idx in basket
)
if repetition_eval:
test_repetition_pos_items = pos_items(
[
[
test_item_indices[idx]
for idx in test_set.baskets[gt_bid]
if test_item_indices[idx] in history_items
]
]
[[iid for iid in gt_basket if iid in history_items]]
)
if len(test_repetition_pos_items) > 0:
_, u_gt_pos_items, u_gt_neg_items = get_gt_items(
Expand All @@ -176,13 +168,7 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):

if exploration_eval:
test_exploration_pos_items = pos_items(
[
[
test_item_indices[idx]
for idx in test_set.baskets[gt_bid]
if test_item_indices[idx] not in history_items
]
]
[[iid for iid in gt_basket if iid not in history_items]]
)
if len(test_exploration_pos_items) > 0:
_, u_gt_pos_items, u_gt_neg_items = get_gt_items(
Expand All @@ -200,18 +186,21 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
# avg results of ranking metrics
for i, mt in enumerate(metrics):
avg_results["conventional"].append(
sum(user_results["conventional"][i].values())
/ len(user_results["conventional"][i])
np.mean(list(user_results["conventional"][i].values()))
if len(user_results["conventional"][i]) > 0
else 0
)
if repetition_eval:
avg_results["repetition"].append(
sum(user_results["repetition"][i].values())
/ len(user_results["repetition"][i])
np.mean(list(user_results["repetition"][i].values()))
if len(user_results["repetition"][i]) > 0
else 0
)
if exploration_eval:
avg_results["exploration"].append(
sum(user_results["exploration"][i].values())
/ len(user_results["exploration"][i])
np.mean(list(user_results["exploration"][i].values()))
if len(user_results["repetition"][i]) > 0
else 0
)

return avg_results, user_results
Expand Down Expand Up @@ -365,13 +354,13 @@ def _build_datasets(self, train_data, test_data, val_data=None):
print("Total items = {}".format(self.total_items))
print("Total baskets = {}".format(self.total_baskets))

def _eval(self, model, test_set, **kwargs):
def eval(self, model, test_set, ranking_metrics, **kwargs):
metric_avg_results = OrderedDict()
metric_user_results = OrderedDict()

avg_results, user_results = ranking_eval(
model=model,
metrics=self.ranking_metrics,
metrics=ranking_metrics,
train_set=self.train_set,
test_set=test_set,
repetition_eval=self.repetition_eval,
Expand All @@ -380,12 +369,12 @@ def _eval(self, model, test_set, **kwargs):
verbose=self.verbose,
)

for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results[mt.name] = avg_results["conventional"][i]
metric_user_results[mt.name] = user_results["conventional"][i]

if self.repetition_eval:
for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results["{}-rep".format(mt.name)] = avg_results[
"repetition"
][i]
Expand All @@ -394,7 +383,7 @@ def _eval(self, model, test_set, **kwargs):
][i]

if self.repetition_eval:
for i, mt in enumerate(self.ranking_metrics):
for i, mt in enumerate(ranking_metrics):
metric_avg_results["{}-expl".format(mt.name)] = avg_results[
"exploration"
][i]
Expand Down
46 changes: 36 additions & 10 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class GPTop(NextBasketRecommender):
use_personalized_popularity: boolean, optional, default: True
When False, item frequencies from history baskets are not used.

use_quantity: boolean, optional, default: False
When True, item frequency is computed from each item's quantity (read from extra_data).
The data must be in fmt 'UBITJson'.

References
----------
Ming Li, Sami Jullien, Mozhdeh Ariannezhad, and Maarten de Rijke. 2023.
Expand All @@ -42,31 +46,53 @@ class GPTop(NextBasketRecommender):
"""

def __init__(
self, name="GPTop", use_global_popularity=True, use_personalized_popularity=True
self,
name="GPTop",
use_global_popularity=True,
use_personalized_popularity=True,
use_quantity=False,
):
super().__init__(name=name, trainable=False)
self.use_global_popularity = use_global_popularity
self.use_personalized_popularity = use_personalized_popularity
self.use_quantity = use_quantity
self.item_freq = Counter()

def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
if self.use_global_popularity:
self.item_freq = Counter(self.train_set.uir_tuple[1])
if self.use_quantity:
self.item_freq = Counter()
for idx, iid in enumerate(self.train_set.uir_tuple[1]):
self.item_freq[iid] += self.train_set.extra_data[idx].get(
"quantity", 0
)
else:
self.item_freq = Counter(self.train_set.uir_tuple[1])
return self

def score(self, user_idx, history_baskets, **kwargs):
item_scores = np.ones(self.total_items)
item_scores = np.zeros(self.total_items, dtype=np.float32)
if self.use_global_popularity:
for iid, freq in self.item_freq.items():
item_scores[iid] = freq

if self.use_personalized_popularity:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])

max_item_freq = (
max(self.item_freq.values()) if len(self.item_freq) > 0 else 1
)
for iid, freq in self.item_freq.items():
item_scores[iid] = freq / max_item_freq

if self.use_personalized_popularity:
if self.use_quantity:
history_basket_bids = kwargs.get("history_basket_ids")
baskets = kwargs.get("baskets")
p_item_freq = Counter()
(_, item_ids, _) = kwargs.get("uir_tuple")
extra_data = kwargs.get("extra_data")
for bid in history_basket_bids:
ids = baskets[bid]
for idx in ids:
p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0)
else:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
for iid, cnt in p_item_freq.most_common():
item_scores[iid] = max_item_freq + cnt
item_scores[iid] += cnt
return item_scores
3 changes: 2 additions & 1 deletion examples/gp_top_tafeng.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
)

models = [
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GTop", use_personalized_popularity=False),
GPTop(name="PTop", use_global_popularity=False),
GPTop(name="GPTop-quantity", use_quantity=True),
GPTop(),
]

Expand Down
Loading
Loading