diff --git a/skorch/tests/test_classifier.py b/skorch/tests/test_classifier.py index b8f6b0235..03b728678 100644 --- a/skorch/tests/test_classifier.py +++ b/skorch/tests/test_classifier.py @@ -134,7 +134,7 @@ def test_classes_with_gaps(self, net_cls, module_cls, data): def test_pass_classes_explicitly_overrides(self, net_cls, module_cls, data): net = net_cls(module_cls, max_epochs=0, classes=['foo', 'bar']).fit(*data) - assert net.classes_ == ['foo', 'bar'] + assert (net.classes_ == np.array(['foo', 'bar'])).all() def test_classes_are_set_with_tensordataset_explicit_y( self, net_cls, module_cls, data