From 7682b4c3668a73a40231a5085398a4953498f0c0 Mon Sep 17 00:00:00 2001
From: Andy Craze
Date: Mon, 4 Dec 2017 20:49:08 -0800
Subject: [PATCH] Return correct dtype from unitvec

matutils.unitvec currently returns a unit vector whose dtype differs from
that of the input vector whenever the input dtype isn't np.float. The
return dtype should be consistent with the input dtype.

Fixes #1722
---
 gensim/matutils.py           |  4 ++--
 gensim/test/test_matutils.py | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 gensim/test/test_matutils.py

diff --git a/gensim/matutils.py b/gensim/matutils.py
index b7ade1be9f..1913b3fea0 100644
--- a/gensim/matutils.py
+++ b/gensim/matutils.py
@@ -420,13 +420,13 @@ def unitvec(vec, norm='l2'):
         return vec
 
     if isinstance(vec, np.ndarray):
-        vec = np.asarray(vec, dtype=float)
+        vec = np.asarray(vec, dtype=vec.dtype)
         if norm == 'l1':
             veclen = np.sum(np.abs(vec))
         if norm == 'l2':
             veclen = blas_nrm2(vec)
         if veclen > 0.0:
-            return blas_scal(1.0 / veclen, vec)
+            return blas_scal(1.0 / veclen, vec).astype(vec.dtype)
         else:
             return vec
 
diff --git a/gensim/test/test_matutils.py b/gensim/test/test_matutils.py
new file mode 100644
index 0000000000..fe35ae3f52
--- /dev/null
+++ b/gensim/test/test_matutils.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+"""
+Automated tests for checking various matutils functions.
+"""
+
+import logging
+import unittest
+
+import numpy as np
+
+from gensim import matutils
+
+
+class TestMatutils(unittest.TestCase):
+    def test_unitvec(self):
+        input_vector = np.random.uniform(size=(100,)).astype(np.float32)
+        unit_vector = matutils.unitvec(input_vector)
+        self.assertEqual(input_vector.dtype, unit_vector.dtype)
+
+
+if __name__ == '__main__':
+    logging.root.setLevel(logging.WARNING)
+    unittest.main()
\ No newline at end of file