Skip to content

Commit

Permalink
Fix reshape interoperability test (apache#17155)
Browse files Browse the repository at this point in the history
* fix reshape interoperability test

* fix for scipy import
  • Loading branch information
haojin2 authored and ChaiBapchya committed Jan 16, 2020
1 parent 3f449aa commit 8407f93
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
4 changes: 2 additions & 2 deletions ci/docker/install/docs_requirements
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ h5py==2.8.0rc1
mock==2.0.0
nose==1.3.7
nose-timer==0.7.3
numpy>1.16.0,<1.18.0
pylint==2.3.1; python_version >= '3.0'
pypandoc==1.4
recommonmark==0.4.0
requests<2.19.0,>=2.18.4
scipy==1.2.1
six==1.11.0
sphinx==1.5.6
43 changes: 36 additions & 7 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import mxnet as mx
import numpy as np
import scipy
from scipy.stats import pearsonr
import json
from common import with_seed
from copy import deepcopy
Expand Down Expand Up @@ -247,13 +249,40 @@ def test_perplexity():
assert perplexity == perplexity_expected

def test_pearsonr():
    """Test the 'pearsonr' metric in both averaging modes.

    Checks that:
    * an un-updated metric reports NaN (correlation is undefined with no data),
    * after one update, both 'macro' and 'micro' averages match the reference
      Pearson correlation from numpy (corrcoef) and scipy (pearsonr),
    * after a second batch, 'macro' (reset, single combined batch) and 'micro'
      (incremental updates over both batches) agree with the reference values
      computed on the concatenated data.
    """
    pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
    label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
    # Reference values from two independent implementations.
    pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1]
    pearsonr_expected_scipy, _ = pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
    macro_pr = mx.metric.create('pearsonr', average='macro')
    micro_pr = mx.metric.create('pearsonr', average='micro')

    # With no samples seen yet, the correlation must be NaN, not a number.
    assert np.isnan(macro_pr.get()[1])
    assert np.isnan(micro_pr.get()[1])

    macro_pr.update([label1], [pred1])
    micro_pr.update([label1], [pred1])

    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)

    pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]])
    label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
    # pred12/label12 are the row-wise concatenations of (pred1, pred2) and
    # (label1, label2) respectively, i.e. both batches seen as one dataset.
    pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6], [1, 2], [3, 2], [4, 6]])
    label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]])

    pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1]
    pearsonr_expected_scipy, _ = pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())

    # macro: reset and feed the combined batch in one update;
    # micro: keep accumulated state from batch 1 and add batch 2 incrementally.
    # Both paths must agree with the reference on the concatenated data.
    macro_pr.reset()
    micro_pr.update([label2], [pred2])
    macro_pr.update([label12], [pred12])
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)

def cm_batch(cm):
# generate a batch yielding a given confusion matrix
Expand Down

0 comments on commit 8407f93

Please sign in to comment.