Skip to content

Commit

Permalink
Fix for ALS calculate_loss on the cpu (#662)
Browse files Browse the repository at this point in the history
We weren't calculating the loss correctly for CPU ALS models,
when regularization was non-zero. In addition to incorrect results,
this also could cause a segfault in certain conditions.

Fix.
  • Loading branch information
benfred authored Jun 6, 2023
1 parent ec36f33 commit 6c9db63
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion implicit/cpu/_als.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _calculate_loss(Cui, integral[:] indptr, integral[:] indices, float[:] data,
loss += dot(&N, r, &one, &X[u, 0], &one)
user_norm += dot(&N, &X[u, 0], &one, &X[u, 0], &one)

for u in prange(users, schedule='dynamic', chunksize=8):
for i in prange(items, schedule='dynamic', chunksize=8):
item_norm += dot(&N, &Y[i, 0], &one, &Y[i, 0], &one)

finally:
Expand Down
15 changes: 15 additions & 0 deletions tests/als_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from recommender_base_test import RecommenderBaseTestMixin, get_checker_board
from scipy.sparse import coo_matrix, csr_matrix, random

import implicit
from implicit.als import AlternatingLeastSquares
from implicit.gpu import HAS_CUDA

Expand Down Expand Up @@ -298,3 +299,17 @@ def test_incremental_retrain(use_gpu):
model.partial_fit_items([101], likes[1])
ids, _ = model.recommend(101, likes[1], N=3)
assert set(ids) == {1, 100, 101}


def test_calculate_loss_segfault():
# this code used to segfault, because of a bug in calculate_loss
factors = 1
regularization = 0
n_users, n_items = 4, 4

item_factors = np.random.random((n_items, factors)).astype("float32")
user_factors = np.random.random((n_users, factors)).astype("float32")
c_ui = coo_matrix(([1.0, 1.0], ([0, 1], [0, 1])), shape=(n_users, n_items)).tocsr()

loss = implicit.cpu._als.calculate_loss(c_ui, user_factors, item_factors, regularization)
assert loss > 0

0 comments on commit 6c9db63

Please sign in to comment.