Skip to content

Commit

Permalink
Merge pull request #11 from kjhall01/kyle/fix_poelm
Browse files Browse the repository at this point in the history
lapack datatype error - coercing to float64
  • Loading branch information
kjhall01 authored Jan 22, 2022
2 parents ed74f80 + d4f5bcb commit 5f0e794
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/flat_estimators/classifiers/poelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, activation='sigm', hidden_layer_size=5, initialization='rando
self.preprocessing = preprocessing

def fit(self, x, y):
x, y = x.astype(np.float64), y.astype(np.float64)
assert len(y.shape) == 2, 'POELM expects Y to be of shape (n_samples, n_classes) even if it is a binary classification - got {}'.format(y.shape)
assert len(x.shape) == 2, 'POELM expects X to be of shape (n_samples, n_features) - got {}'.format(x.shape)
assert 0.0 <= np.min(y) and np.max(y) <= 1.0, 'POELM expects that no value in Y is ever greater than 1 or less than 0'
Expand Down Expand Up @@ -195,6 +196,7 @@ def fit(self, x, y):


def predict(self, x, preprocessing='asis'):
x = x.astype(np.float64)
# first, take care of preprocessing
ret = self.predict_proba(x, preprocessing=preprocessing)
if self.using_multiclass:
Expand Down

0 comments on commit 5f0e794

Please sign in to comment.