-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dtype of matutils.unitvec
. Fix #1722
#1761
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -420,13 +420,13 @@ def unitvec(vec, norm='l2'): | |
return vec | ||
|
||
if isinstance(vec, np.ndarray): | ||
vec = np.asarray(vec, dtype=float) | ||
vec = np.asarray(vec, dtype=vec.dtype) | ||
if norm == 'l1': | ||
veclen = np.sum(np.abs(vec)) | ||
if norm == 'l2': | ||
veclen = blas_nrm2(vec) | ||
if veclen > 0.0: | ||
return blas_scal(1.0 / veclen, vec) | ||
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this cast at the end There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have issued a new pull request addressing this. |
||
else: | ||
return vec | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for checking various matutils functions. | ||
""" | ||
|
||
import logging | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from gensim import matutils | ||
|
||
|
||
class TestMatutils(unittest.TestCase): | ||
def test_unitvec(self): | ||
input_vector = np.random.uniform(size=(100,)).astype(np.float32) | ||
unit_vector = matutils.unitvec(input_vector) | ||
self.assertEqual(input_vector.dtype, unit_vector.dtype) | ||
|
||
|
||
if __name__ == '__main__': | ||
logging.root.setLevel(logging.WARNING) | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @accraze , thanks for the fix! A better solution here might be to modify the hardcoded dtype in line 423 above, it simplifies the logic, and also ensures that the dtype is consistent for vectors with all zeros too (a rather trivial and probably uncommon case, of course)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @jayantj, I looked into this, however
blas_scal
returns an array of typefloat
(see line 398). Not sure if there is a better way to handle this...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jayantj What else needs to be done?
blas_scal
is not being used anywhere else. So, should I defineblas_scal
before line 429 and remove hardcoded float from its the definition?