Skip to content

Commit

Permalink
bugfixes and sklearn updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Gscorreia89 committed Jul 22, 2024
1 parent dd1e9aa commit bd2f9e0
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pyChemometrics/ChemometricsPLSDA.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ def fit(self, x, y, **fit_params):
# Scaling for the classifier setting proceeds as usual for the X block
xscaled = self.x_scaler.fit_transform(x)

# For this "classifier" PLS objects, the yscaler is not used, as we are not interesting in decentering and
# scaling class labels and dummy matrices.

# Instead, we just do some on the fly detection of binary vs multiclass classification
# On the fly detection of binary vs multiclass classification
# Verify number of classes in the provided class label y vector so the algorithm can adjust accordingly
# detect dummy vector
if (np.unique(y).size == 2) and (np.all(np.isin(np.unique(y), np.array([0, 1]), assume_unique=True))):
n_classes = y.shape[1]
if y.ndim == 1:
n_classes = 2
else:
n_classes = y.shape[1]
isDummy = True
else:
n_classes = np.unique(y).size
Expand Down Expand Up @@ -178,8 +178,10 @@ def fit(self, x, y, **fit_params):
for curr_class in range(self.n_classes):
if not isDummy:
curr_class_idx = np.where(y == curr_class)
else:
elif isDummy and n_classes > 2:
curr_class_idx = np.where(y[:, curr_class] == 1)
else:
curr_class_idx = np.where(y[curr_class] == 1)

self.class_means[curr_class, :] = np.mean(self.scores_t[curr_class_idx])

Expand Down

0 comments on commit bd2f9e0

Please sign in to comment.