diff --git a/pyChemometrics/ChemometricsPLSDA.py b/pyChemometrics/ChemometricsPLSDA.py
index 35fd766..cc82cfe 100644
--- a/pyChemometrics/ChemometricsPLSDA.py
+++ b/pyChemometrics/ChemometricsPLSDA.py
@@ -144,6 +144,7 @@ def fit(self, x, y, **fit_params):
                 y_scaled = self.y_scaler.fit_transform(dummy_mat)
             elif self.n_classes > 2 and isDummy is True:
                 y_scaled = self.y_scaler.fit_transform(y)
+                dummy_mat = y
             else:
                 if y.ndim == 1:
                     y = y.reshape(-1, 1)
@@ -183,7 +184,7 @@ def fit(self, x, y, **fit_params):
                 else:
                     curr_class_idx = np.where(y[curr_class] == 1)
-                self.class_means[curr_class, :] = np.mean(self.scores_t[curr_class_idx])
+                self.class_means[curr_class, :] = np.mean(self.scores_t[curr_class_idx], axis=0)
 
             # Needs to come here for the method shortcuts down the line to work...
             self._isfitted = True
@@ -452,8 +453,7 @@ def predict(self, x):
                 class_pred = np.argmin(np.abs(y_pred - np.array([0, 1])), axis=1)
             else:
-                # euclidean distance to mean of class for multiclass PLS-DA
-                # probably better to use a Logistic/Multinomial or PLS-LDA anyway...
+                # Euclidean distance to mean of class for multiclass PLS-DA
                 # project X onto T - so then we can get
                 pred_scores = self.transform(x)
                 # prediction rule - find the closest class mean (centroid) for each sample in the score space
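For context on the `axis=0` fix and the centroid prediction rule, here is a minimal standalone sketch (not the repository's code) of nearest-centroid assignment in PLS score space. It assumes `pred_scores` has shape `(n_samples, n_components)` and `class_means` has shape `(n_classes, n_components)`, which is what the per-component means produced by the `axis=0` change imply; the helper name `nearest_centroid` is hypothetical.

```python
import numpy as np

def nearest_centroid(pred_scores, class_means):
    # Euclidean distance from every sample score to every class centroid:
    # broadcasting gives an (n_samples, n_classes, n_components) difference array,
    # and the norm over the last axis yields an (n_samples, n_classes) distance matrix.
    dists = np.linalg.norm(pred_scores[:, None, :] - class_means[None, :, :], axis=2)
    # Predicted class = index of the closest centroid for each sample
    return np.argmin(dists, axis=1)

# Toy usage with 3 classes and 2 latent variables
scores = np.array([[0.1, 0.2], [2.0, 1.9], [-1.1, -0.9]])
means = np.array([[0.0, 0.0], [2.0, 2.0], [-1.0, -1.0]])
print(nearest_centroid(scores, means))  # -> [0 1 2]
```

Without `axis=0`, `np.mean` collapses each class's scores to a single scalar, so the centroids above would be wrong whenever more than one latent variable is fitted; the fix keeps one mean per component.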