Skip to content

Commit

Permalink
Test functions with suggested changes (#36)
Browse files Browse the repository at this point in the history
* added tests

* added tests

Co-authored-by: shikha <shikha@shikhas-MacBook-Air.local>
  • Loading branch information
shikhareddy and shikha authored Feb 3, 2021
1 parent 6dc87f0 commit e607f21
Showing 1 changed file with 69 additions and 12 deletions.
81 changes: 69 additions & 12 deletions tests/test_dsm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""This module contains test functions to
test the accuracy of Deep Survival Machines
models on certain standard datasets.
"""
import unittest

from dsm import DeepSurvivalMachines
Expand All @@ -7,21 +11,74 @@
import numpy as np

class TestDSM(unittest.TestCase):
"""Base Class for all test functions"""
def test_support_dataset(self):
"""Test function to load and test the SUPPORT dataset.
"""

def test_dsm(self):
x, t, e = datasets.load_dataset('SUPPORT')
t_median = np.median(t[e==1])

x, t, e = datasets.load_dataset('SUPPORT')
self.assertIsInstance(x, np.ndarray)
self.assertIsInstance(t, np.ndarray)
self.assertIsInstance(e, np.ndarray)

self.assertIsInstance(x, np.ndarray)
self.assertIsInstance(t, np.ndarray)
self.assertIsInstance(e, np.ndarray)
self.assertEqual(x.shape, (9105, 44))
self.assertEqual(t.shape, (9105,))
self.assertEqual(e.shape, (9105,))

self.assertEqual(x.shape, (9105, 44))
self.assertEqual(t.shape, (9105,))
self.assertEqual(e.shape, (9105,))
model = DeepSurvivalMachines()
self.assertIsInstance(model, DeepSurvivalMachines)
model.fit(x, t, e, iters=10)
self.assertIsInstance(model.torch_model,
DeepSurvivalMachinesTorch)
risk_score = model.predict_risk(x, t_median)
survival_probability = model.predict_survival(x, t_median)
np.testing.assert_equal((risk_score+survival_probability).all(), 1.0)

def test_pbc_dataset(self):
"""Test function to load and test the PBC dataset.
"""

x, t, e = datasets.load_dataset('PBC')
t_median = np.median(t[e==1])

self.assertIsInstance(x, np.ndarray)
self.assertIsInstance(t, np.ndarray)
self.assertIsInstance(e, np.ndarray)

self.assertEqual(x.shape, (1945, 25))
self.assertEqual(t.shape, (1945,))
self.assertEqual(e.shape, (1945,))

model = DeepSurvivalMachines()
self.assertIsInstance(model, DeepSurvivalMachines)
model.fit(x, t, e, iters=10)
self.assertIsInstance(model.torch_model,
DeepSurvivalMachinesTorch)
risk_score = model.predict_risk(x, t_median)
survival_probability = model.predict_survival(x, t_median)
np.testing.assert_equal((risk_score+survival_probability).all(), 1.0)

def test_framingham_dataset(self):
"""Test function to load and test the Framingham dataset.
"""
x, t, e = datasets.load_dataset('FRAMINGHAM')
t_median = np.median(t)

self.assertIsInstance(x, np.ndarray)
self.assertIsInstance(t, np.ndarray)
self.assertIsInstance(e, np.ndarray)

self.assertEqual(x.shape, (11627, 18))
self.assertEqual(t.shape, (11627,))
self.assertEqual(e.shape, (11627,))

model = DeepSurvivalMachines()
self.assertIsInstance(model, DeepSurvivalMachines)
model.fit(x, t, e, iters=10)
self.assertIsInstance(model.torch_model,
model = DeepSurvivalMachines()
self.assertIsInstance(model, DeepSurvivalMachines)
model.fit(x, t, e, iters=10)
self.assertIsInstance(model.torch_model,
DeepSurvivalMachinesTorch)
risk_score = model.predict_risk(x, t_median)
survival_probability = model.predict_survival(x, t_median)
np.testing.assert_equal((risk_score+survival_probability).all(), 1.0)

0 comments on commit e607f21

Please sign in to comment.