diff --git a/tests/test_plsobjconsistency.py b/tests/test_plsobjconsistency.py index 0b46718..d21834a 100644 --- a/tests/test_plsobjconsistency.py +++ b/tests/test_plsobjconsistency.py @@ -40,6 +40,7 @@ def setUp(self): finally: # check this self.da_mat = multiclass['Class_Vector'].values + self.da_mat_dummy = multiclass.iloc[:, 1:5].values self.da = twoclass['Class'].values self.xmat_multi = multiclass.iloc[:, 5::].values self.xmat = twoclass.iloc[:, 1::].values @@ -100,7 +101,31 @@ def test_multi_y(self): assert_allclose(self.plsreg.modelParameters['SSY'], self.plsda.modelParameters['PLS']['SSY']) assert_allclose(self.plsreg.modelParameters['SSXcomp'], self.plsda.modelParameters['PLS']['SSXcomp']) assert_allclose(self.plsreg.modelParameters['SSYcomp'], self.plsda.modelParameters['PLS']['SSYcomp']) - + + def test_multi_y_dummy(self): + """ + + :return: + """ + self.plsreg.fit(self.xmat_multi, self.dummy_y) + self.plsda.fit(self.xmat_multi, self.da_mat_dummy) + + assert_allclose(self.plsreg.scores_t, self.plsda.scores_t) + assert_allclose(self.plsreg.scores_u, self.plsda.scores_u) + assert_allclose(self.plsreg.rotations_cs, self.plsda.rotations_cs) + assert_allclose(self.plsreg.rotations_ws, self.plsda.rotations_ws) + assert_allclose(self.plsreg.weights_w, self.plsda.weights_w) + assert_allclose(self.plsreg.weights_c, self.plsda.weights_c) + assert_allclose(self.plsreg.loadings_p, self.plsda.loadings_p) + assert_allclose(self.plsreg.loadings_q, self.plsda.loadings_q) + assert_allclose(self.plsreg.beta_coeffs, self.plsda.beta_coeffs) + assert_allclose(self.plsreg.modelParameters['R2Y'], self.plsda.modelParameters['PLS']['R2Y']) + assert_allclose(self.plsreg.modelParameters['R2X'], self.plsda.modelParameters['PLS']['R2X']) + assert_allclose(self.plsreg.modelParameters['SSX'], self.plsda.modelParameters['PLS']['SSX']) + assert_allclose(self.plsreg.modelParameters['SSY'], self.plsda.modelParameters['PLS']['SSY']) + assert_allclose(self.plsreg.modelParameters['SSXcomp'], self.plsda.modelParameters['PLS']['SSXcomp']) + assert_allclose(self.plsreg.modelParameters['SSYcomp'], self.plsda.modelParameters['PLS']['SSYcomp']) + if __name__ == '__main__': unittest.main()