diff --git a/src/flat_estimators/classifiers/poelm.py b/src/flat_estimators/classifiers/poelm.py index 95c3e37..ad8ff31 100644 --- a/src/flat_estimators/classifiers/poelm.py +++ b/src/flat_estimators/classifiers/poelm.py @@ -35,6 +35,7 @@ def __init__(self, activation='sigm', hidden_layer_size=5, initialization='rando self.preprocessing = preprocessing def fit(self, x, y): + x, y = x.astype(np.float64), y.astype(np.float64) assert len(y.shape) == 2, 'POELM expects Y to be of shape (n_samples, n_classes) even if it is a binary classification - got {}'.format(y.shape) assert len(x.shape) == 2, 'POELM expects X to be of shape (n_samples, n_features) - got {}'.format(x.shape) assert 0.0 <= np.min(y) and np.max(y) <= 1.0, 'POELM expects that no value in Y is ever greater than 1 or less than 0' @@ -195,6 +196,7 @@ def fit(self, x, y): def predict(self, x, preprocessing='asis'): + x = x.astype(np.float64) # first, take care of preprocessing ret = self.predict_proba(x, preprocessing=preprocessing) if self.using_multiclass: