From 05d4aac4601b166a8286f23dc32d5e083f89496e Mon Sep 17 00:00:00 2001 From: gd2212 Date: Mon, 22 Jul 2024 10:34:59 +0100 Subject: [PATCH] add fixes for multi-class --- pyChemometrics/ChemometricsPLSDA.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pyChemometrics/ChemometricsPLSDA.py b/pyChemometrics/ChemometricsPLSDA.py index cc82cfe..b3076a8 100644 --- a/pyChemometrics/ChemometricsPLSDA.py +++ b/pyChemometrics/ChemometricsPLSDA.py @@ -228,20 +228,28 @@ def fit(self, x, y, **fit_params): else: y_pred = self.predict(x) - accuracy = metrics.accuracy_score(y, y_pred) - precision = metrics.precision_score(y, y_pred, average='weighted') - recall = metrics.recall_score(y, y_pred, average='weighted') - misclassified_samples = np.where(y.ravel() != y_pred.ravel())[0] - f1_score = metrics.f1_score(y, y_pred, average='weighted') - conf_matrix = metrics.confusion_matrix(y, y_pred) - zero_oneloss = metrics.zero_one_loss(y, y_pred) + + # Make dummy matrix into a label vector for scoring + if isDummy: + # y = np.where(y == 1)[1] + y_vec = np.where(y == 1)[1] + else: + y_vec = y + + accuracy = metrics.accuracy_score(y_vec, y_pred) + precision = metrics.precision_score(y_vec, y_pred, average='weighted') + recall = metrics.recall_score(y_vec, y_pred, average='weighted') + misclassified_samples = np.where(y_vec.ravel() != y_pred.ravel())[0] + f1_score = metrics.f1_score(y_vec, y_pred, average='weighted') + conf_matrix = metrics.confusion_matrix(y_vec, y_pred) + zero_oneloss = metrics.zero_one_loss(y_vec, y_pred) matthews_mcc = np.nan roc_curve = list() auc_area = list() # Generate multiple ROC curves - one for each class the multiple class case for predclass in range(self.n_classes): - current_roc = metrics.roc_curve(y, class_score[:, predclass], pos_label=predclass) + current_roc = metrics.roc_curve(y_vec, class_score[:, predclass], pos_label=predclass) # Interpolate all ROC curves to a finite grid # Makes it easier to average and compare multiple models - with CV in mind tpr = current_roc[1]