diff --git a/gensim/matutils.py b/gensim/matutils.py index b7ade1be9f..404e08c75a 100644 --- a/gensim/matutils.py +++ b/gensim/matutils.py @@ -420,7 +420,7 @@ def unitvec(vec, norm='l2'): return vec if isinstance(vec, np.ndarray): - vec = np.asarray(vec, dtype=float) + vec = np.asarray(vec, dtype=vec.dtype) if norm == 'l1': veclen = np.sum(np.abs(vec)) if norm == 'l2': diff --git a/gensim/test/test_matutils.py b/gensim/test/test_matutils.py new file mode 100644 index 0000000000..fe35ae3f52 --- /dev/null +++ b/gensim/test/test_matutils.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking various matutils functions. +""" + +import logging +import unittest + +import numpy as np + +from gensim import matutils + + +class TestMatutils(unittest.TestCase): + def test_unitvec(self): + input_vector = np.random.uniform(size=(100,)).astype(np.float32) + unit_vector = matutils.unitvec(input_vector) + self.assertEqual(input_vector.dtype, unit_vector.dtype) + + +if __name__ == '__main__': + logging.root.setLevel(logging.WARNING) + unittest.main() \ No newline at end of file