Skip to content

Commit

Permalink
Fix return dtype for matutils.unitvec according to input dtype. Fix p…
Browse files Browse the repository at this point in the history
…iskvorky#1722 (piskvorky#1992)

* matutils.unitvec bug

As requested, I have edited the fix to ignore dtype size. I use np.issubtype to check input type and handle appropriately before return to ensure non-integer output.

* matutils.unitvec fix tests

Tests to ensure float output for both float and integer inputs.

* unitvec equal input and return types

* Update and rename test_unitvec to test_unitvec.py

* Update matutils.py

* Update matutils.py

* Update test_unitvec.py

* Update and rename gensim/test_unitvec.py to gensim/test/test_matutils.py

* Update matutils.py

* Update test_matutils.py

* Update test_matutils.py

* Update following review

Removed leading spaces, which is the source of the PEP8/travis errors. Sorry, only just learnt from you what these actually are :)
Also updated the code to include 'if return_norm' statement from the sparse array case. (I can't remember why I actually removed this in the first place.)

* Update: attempt to solve Travis errors

* Update test_matutils.py

* Update matutils.py

* Update matutils.py

* Update test_matutils.py

* Addressing travis errors

* Remove unnecessary dtype assignment

* return_norm statements for array instance case

* Update test_matutils.py

* Reduce line repetition

* Reduce repeated lines

* Update test_matutils.py

* Remove some redundant code from unitvec

This is what I have done based on Jayanti's suggestion of redundant code. Let me know if I have misunderstood.

* UnitvecTestCase update

Simplified tha manual_unitvec method and created a separate test for each `arrtype, dtype` pair, as suggested.

* Small typo fix

* Trailing white-space fix for Travis

* Improve code quality and remove no-op
  • Loading branch information
o-P-o authored and darindf committed Apr 23, 2018
1 parent ddf7374 commit 89fb7ac
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 5 deletions.
15 changes: 10 additions & 5 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,34 +692,39 @@ def unitvec(vec, norm='l2', return_norm=False):
"""
if norm not in ('l1', 'l2'):
raise ValueError("'%s' is not a supported norm. Currently supported norms are 'l1' and 'l2'." % norm)

if scipy.sparse.issparse(vec):
vec = vec.tocsr()
if norm == 'l1':
veclen = np.sum(np.abs(vec.data))
if norm == 'l2':
veclen = np.sqrt(np.sum(vec.data ** 2))
if veclen > 0.0:
if np.issubdtype(vec.dtype, np.int):
vec = vec.astype(np.float)
vec /= veclen
if return_norm:
return vec / veclen, veclen
return vec, veclen
else:
return vec / veclen
return vec
else:
if return_norm:
return vec, 1.
else:
return vec

if isinstance(vec, np.ndarray):
vec = np.asarray(vec, dtype=float)
if norm == 'l1':
veclen = np.sum(np.abs(vec))
if norm == 'l2':
veclen = blas_nrm2(vec)
if veclen > 0.0:
if np.issubdtype(vec.dtype, np.int):
vec = vec.astype(np.float)
if return_norm:
return blas_scal(1.0 / veclen, vec), veclen
return blas_scal(1.0 / veclen, vec).astype(vec.dtype), veclen
else:
return blas_scal(1.0 / veclen, vec)
return blas_scal(1.0 / veclen, vec).astype(vec.dtype)
else:
if return_norm:
return vec, 1
Expand Down
101 changes: 101 additions & 0 deletions gensim/test/test_matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import unittest
import numpy as np
from scipy import sparse
from scipy.special import psi # gamma function utils

import gensim.matutils as matutils
Expand Down Expand Up @@ -141,6 +142,106 @@ def testDirichletExpectation(self):
self.assertTrue(np.allclose(known_good, test_values), msg)


def manual_unitvec(vec):
# manual unit vector calculation for UnitvecTestCase
vec = vec.astype(np.float)
if sparse.issparse(vec):
vec_sum_of_squares = vec.multiply(vec)
unit = 1. / np.sqrt(vec_sum_of_squares.sum())
return vec.multiply(unit)
elif not sparse.issparse(vec):
sum_vec_squared = np.sum(vec ** 2)
vec /= np.sqrt(sum_vec_squared)
return vec


class UnitvecTestCase(unittest.TestCase):
# test unitvec
def test_sparse_npfloat32(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(np.float32)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_sparse_npfloat64(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(np.float64)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_sparse_npint32(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(np.int32)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))

def test_sparse_npint64(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(np.int64)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))

def test_dense_npfloat32(self):
input_vector = np.random.uniform(size=(5,)).astype(np.float32)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_dense_npfloat64(self):
input_vector = np.random.uniform(size=(5,)).astype(np.float64)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_dense_npint32(self):
input_vector = np.random.randint(10, size=5).astype(np.int32)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))

def test_dense_npint64(self):
input_vector = np.random.randint(10, size=5).astype(np.int32)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))

def test_sparse_python_float(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(float)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_sparse_python_int(self):
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 3], [0, 0, 4, 3, 0]])).astype(int)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))

def test_dense_python_float(self):
input_vector = np.random.uniform(size=(5,)).astype(float)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertEqual(input_vector.dtype, unit_vector.dtype)

def test_dense_python_int(self):
input_vector = np.random.randint(10, size=5).astype(int)
unit_vector = matutils.unitvec(input_vector)
man_unit_vector = manual_unitvec(input_vector)
self.assertTrue(np.allclose(unit_vector, man_unit_vector))
self.assertTrue(np.issubdtype(unit_vector.dtype, float))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()

0 comments on commit 89fb7ac

Please sign in to comment.