Skip to content

Commit

Permalink
bugfixes and sklearn updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Gscorreia89 committed Jul 22, 2024
1 parent bd2f9e0 commit f04d63e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyChemometrics/ChemometricsPLSDA.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def fit(self, x, y, **fit_params):
y_scaled = self.y_scaler.fit_transform(dummy_mat)
elif self.n_classes > 2 and isDummy is True:
y_scaled = self.y_scaler.fit_transform(y)
dummy_mat = y
else:
if y.ndim == 1:
y = y.reshape(-1, 1)
Expand Down Expand Up @@ -183,7 +184,7 @@ def fit(self, x, y, **fit_params):
else:
curr_class_idx = np.where(y[curr_class] == 1)

self.class_means[curr_class, :] = np.mean(self.scores_t[curr_class_idx])
self.class_means[curr_class, :] = np.mean(self.scores_t[curr_class_idx], axis=0)

# Needs to come here for the method shortcuts down the line to work...
self._isfitted = True
Expand Down Expand Up @@ -452,8 +453,7 @@ def predict(self, x):
class_pred = np.argmin(np.abs(y_pred - np.array([0, 1])), axis=1)

else:
# euclidean distance to mean of class for multiclass PLS-DA
# probably better to use a Logistic/Multinomial or PLS-LDA anyway...
# Euclidean distance to mean of class for multiclass PLS-DA
# project X onto T - so then we can get
pred_scores = self.transform(x)
# prediction rule - find the closest class mean (centroid) for each sample in the score space
Expand Down

0 comments on commit f04d63e

Please sign in to comment.