diff --git a/gensim/matutils.py b/gensim/matutils.py index b7ade1be9f..1913b3fea0 100644 --- a/gensim/matutils.py +++ b/gensim/matutils.py @@ -420,13 +420,13 @@ def unitvec(vec, norm='l2'): return vec if isinstance(vec, np.ndarray): - vec = np.asarray(vec, dtype=float) + vec = np.asarray(vec, dtype=vec.dtype) if norm == 'l1': veclen = np.sum(np.abs(vec)) if norm == 'l2': veclen = blas_nrm2(vec) if veclen > 0.0: - return blas_scal(1.0 / veclen, vec) + return blas_scal(1.0 / veclen, vec).astype(vec.dtype) else: return vec diff --git a/gensim/test/test_matutils.py b/gensim/test/test_matutils.py new file mode 100644 index 0000000000..1eb75437fc --- /dev/null +++ b/gensim/test/test_matutils.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking various matutils functions. +""" + +import logging +import unittest + +import numpy as np + +from gensim import matutils + + +class TestMatutils(unittest.TestCase): + def test_unitvec(self): + input_vector = np.random.uniform(size=(100,)).astype(np.float32) + unit_vector = matutils.unitvec(input_vector) + self.assertEqual(input_vector.dtype, unit_vector.dtype) + + +if __name__ == '__main__': + logging.root.setLevel(logging.WARNING) + unittest.main()