[MRG] Implement predict_proba() for OutputCodeClassifier #25148

Open · wants to merge 18 commits into main

Conversation

@dx2-66 commented Dec 9, 2022

Reference Issues/PRs

None found, although it was briefly mentioned in #389.

What does this implement/fix? Explain your changes.

In its present state, sklearn.multiclass.OutputCodeClassifier() can only predict(). Access to the underlying class scores/probabilities is required for a good number of validation metrics, and can also be used for model calibration and improved decision making (though that might be considered redundant for typical ECOC classifier use cases).

What does it do / how does it work?

This PR implements (faux) probability estimates via inverse distance weighting, based on the same Euclidean distances between the base estimators' outputs and the codebook that the predict() method already uses. The distance calculation is moved to the new predict_proba() method, and predict() is adjusted accordingly.
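
For illustration, a minimal standalone sketch of the weighting step (the helper name is illustrative; in the PR the computation lives inside OutputCodeClassifier.predict_proba()):

import numpy as np

def idw_proba(dist, eps=1e-16):
    """Turn an (n_samples, n_classes) matrix of code distances into
    probability-like scores via inverse distance weighting."""
    weights = 1.0 / (dist + eps)  # eps guards against division by zero
    return weights / weights.sum(axis=1, keepdims=True)

# the closest code (smallest distance) gets the largest score
print(idw_proba(np.array([[0.2, 1.5, 2.0]])))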

Any other comments?

Possible cons:

The probabilistic interpretation of said distances might be questionable. However, plenty of other classification algorithms that do implement predict_proba() are not guaranteed to have well-calibrated output either.

QA:

  • Precautions are taken to avoid division by zero in the IDW calculation.
  • A non-regression test ensures the predict() results are the same as before, and a sanity test checks that the class scores sum to 1.
  • pytest/black/flake8/mypy checks pass.
  • The scores returned by the new predict_proba() appear to yield sane OvR ROC and PR curves, as well as calibration curves of an acceptable shape.

@glemaitre (Member) left a comment

We will also need an entry in the changelog for the latest failure to go away; it can be put in the doc/whats_new/v1.3.rst file.

Comment on lines 1048 to 1051
# Inverse distance weighting:
eps = 1e-16
proba = (1.0 / (dist + eps)) / np.sum(1.0 / (dist + eps), axis=1)[:, np.newaxis]
return proba
Member:

Do you have any scientific literature linked to this approach for expressing probabilities? I would have naively thought to use a softmax instead of a linear mapping.

Member:

I assume that we could invert the distance because we should have an upper bound on the distance, since we deal with binary codes (I did not look at the algorithm yet).

Author:

Turning a distance into a probability-like value is basically spatial interpolation. IDW was described, among other methods, in the classical 'A two-dimensional interpolation function for irregularly-spaced data' (Shepard, 1968).

Softmax is an option, of course. I considered a selectable method, but an extra predict_proba() argument looked like a bad idea to me.

Practically, softmax decays somewhat faster, but that's about it. Both approaches yield non-negative scores that sum to 1. Either method can come out slightly ahead of the other, depending on the data, in terms of ROC AUC, AP and the calibration curve (the average difference in metrics is within 0.01).
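
To make the comparison concrete, a quick toy sketch (not part of the PR) contrasting the two mappings on a single row of distances:

import numpy as np
from scipy.special import softmax

dist = np.array([[0.2, 1.5, 2.0]])
eps = 1e-16

inv = 1.0 / (dist + eps)
idw = inv / inv.sum(axis=1, keepdims=True)   # the mapping used in this PR
soft = softmax(-dist, axis=1)                # softmax over negative distances

# Both rows are non-negative, sum to 1 and preserve the distance ordering;
# they differ only in how quickly the score falls off with distance.
print(idw)
print(soft)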

@@ -704,6 +706,19 @@ def test_ecoc_fit_predict():
assert len(ecoc.estimators_) == n_classes * 2


def test_ecoc_predict_proba():
Member:

We should try binary and multiclass here.

# Test that the scores sum to 1
assert_almost_equal(np.sum(proba, axis=1), np.ones(proba.shape[0]))

# Regression test for the new proba-based predict against the old one
Member:

I don't think we need this. Basically, this non-regression test only checks a "general" case; if something goes wrong, it would be in a corner case.

At least we now have the case where y_prob.argmax(axis=1) == y_pred, which is something that we intend to check in the common tests.

Author:

Looking at OneVsRestClassifier(), it only tests the output shape (which is reasonable, I guess) and the said y_prob.argmax(axis=1) == y_pred equality. The latter seems redundant, since predict() now basically does nothing but argmax(axis=1).

As an alternative, ensuring the shortest distance yields the highest probability for a few cases comes to mind:

  • code size much smaller than the number of classes;
  • code size much bigger than the number of classes;
  • all equal distances.

Should we include those (or, perhaps, other possible cases)?
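
For what it's worth, a rough sketch of how the first two cases could be exercised (assuming the predict_proba() added by this PR; the data, parameters and helper name are illustrative only):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.multiclass import OutputCodeClassifier
from sklearn.tree import DecisionTreeClassifier

def check_proba(code_size):
    X, y = make_classification(
        n_samples=200, n_classes=4, n_clusters_per_class=1,
        n_informative=6, random_state=0,
    )
    ecoc = OutputCodeClassifier(
        DecisionTreeClassifier(random_state=0),
        code_size=code_size,
        random_state=0,
    ).fit(X, y)
    proba = ecoc.predict_proba(X)
    # scores sum to 1 and (ties aside) the most probable class matches predict()
    np.testing.assert_allclose(proba.sum(axis=1), 1.0)
    assert np.array_equal(ecoc.classes_[proba.argmax(axis=1)], ecoc.predict(X))

check_proba(code_size=0.5)  # code much shorter than the number of classes
check_proba(code_size=4.0)  # code much longer than the number of classes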

@glemaitre added the Needs Triage label on Dec 15, 2022
@glemaitre (Member):

Adding the "Needs Triage" label to discuss the inclusion during the triage meeting.

@glemaitre removed the Needs Triage label on Dec 16, 2022
@glemaitre (Member):

We discussed the implementation during the triage meeting.

Basically, we are not against adding predict_proba. However, we have to go back to the literature to be sure we implement what was proposed initially.

In this regard, I think there are several bugs to correct before going further here.

The original paper [1] uses an L1 (city-block) distance. I have to check [2] to be sure there is no extension or theory behind using the L2 distance here. At least, L1 is more intuitive to me than L2 when computing distances between binary codes.

A second potential bug is linked to the binary classifiers' predictions. The original paper states that the probabilities of each binary classifier are concatenated. In our code, we first try to concatenate the predictions obtained through decision_function and, only if that is not available, the output of the predict_proba method. So we do not always do what is stated in the paper. Since we compute a distance between these aggregated predictions and the codes in the codebook, I would really think that the concatenated values should be bounded between 0 and 1 and represent probabilities. Thus, I don't think that we should rely on decision_function at all.

Regarding the predict_proba implemented here, a section in [1] discusses it; I think we should figure out how they build this probability estimate. @adrinjalali has intuitions (and previous experience) that IDW is not the right thing to do here. In the paper, if I understand correctly, they create a confidence score based on the difference between the nearest and second nearest code to the predicted values. This does not look like a probability to me yet, but they might normalize it; it needs to be investigated (a small sketch follows the references).

[1] Dietterich, T. G., & Bakiri, G. (1994). Solving multiclass learning problems via error-correcting output codes. Journal of Artificial Intelligence Research, 2, 263-286.

[2] James, G., & Hastie, T. (1998). The error coding method and PICTs. Journal of Computational and Graphical Statistics, 7(3), 377-387.
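
For reference, a small sketch of the two decodings being contrasted here, plus the nearest-versus-second-nearest gap from [1] (the arrays are made up; whether and how the paper normalizes that gap still needs to be checked):

import numpy as np
from scipy.spatial.distance import cdist

# made-up concatenated binary-classifier probabilities for two samples
Y = np.array([[0.9, 0.2, 0.7],
              [0.4, 0.6, 0.5]])
# one row of {0, 1} codes per class
code_book = np.array([[1, 0, 1],
                      [0, 1, 1],
                      [1, 1, 0]])

dist_l1 = cdist(Y, code_book, metric="cityblock")  # decoding used in [1]
dist_l2 = cdist(Y, code_book, metric="euclidean")  # what main currently does

pred = dist_l1.argmin(axis=1)
# gap between the nearest and second nearest code as a confidence score
two_nearest = np.sort(dist_l1, axis=1)[:, :2]
confidence = two_nearest[:, 1] - two_nearest[:, 0]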

@glemaitre (Member):

I have not finished yet, but for the probability/confidence score I have the following feeling:

  • the original paper (Dietterich et al.) provides a confidence score, and it could easily be used as the output of decision_function;
  • I would explore the following paper for probability estimates. It looks at calibrated estimates, which is interesting, and it is also something used in the ECOC implementation of Matlab, from what I could read in its documentation.

@glemaitre (Member):

I wanted to understand exactly what was in the Zadrozny et al. paper, so I made a barely tested implementation. Below is the code that I could come up with. The naming is terrible, but it follows the formulas in the paper.

I saw that we mention the paper in the calibration section. I assume that we use the non-iterative formulation proposed by Hastie et al., which holds only in the one-vs-rest case.

# %%
from sklearn.datasets import make_classification

n_classes = 3
X, y = make_classification(
    n_samples=1_000,
    n_features=4,
    n_clusters_per_class=1,
    n_classes=n_classes,
    random_state=0,
)

# %%
from sklearn.multiclass import OutputCodeClassifier
from sklearn.tree import DecisionTreeClassifier

model = OutputCodeClassifier(
    estimator=DecisionTreeClassifier(max_depth=2, random_state=0),
    code_size=2,
    random_state=0
).fit(X, y)

# %%
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning

# n_samples should be computed during training and stored
n_samples, n_estimators = X.shape[0], len(model.estimators_)

rng = np.random.default_rng(0)

indicator = model.code_book_.astype(bool)
r = np.transpose([
    est.predict_proba(X[:n_samples])[:, 1]
    for est in model.estimators_
])
not_r = 1 - r

# initialization
p_hat = rng.uniform(size=(n_samples, n_classes))
p_hat /= p_hat.sum(axis=1)[:, np.newaxis]
r_hat = np.zeros_like(r)

max_iter, n_iter, tol, err = 100, 0, 1e-3, np.inf

for _ in range(max_iter):
    # compute r_hat
    for learner_idx in range(n_estimators):
        r_hat[:, learner_idx] = p_hat[:, indicator[:, learner_idx]].sum(axis=1)
    not_r_hat = 1 - r_hat

    # estimate p_hat
    p_hat_previous = p_hat.copy()
    for klass in range(n_classes):
        klass_indicator = indicator[klass]
        numerator = n_samples * (
            r[:, klass_indicator].sum(axis=1) +
            not_r[:, ~klass_indicator].sum(axis=1)
        )
        denominator = n_samples * (
            r_hat[:, klass_indicator].sum(axis=1) +
            not_r_hat[:, ~klass_indicator].sum(axis=1)
        )

        p_hat[:, klass] *= numerator / denominator
    p_hat /= p_hat.sum(axis=1)[:, np.newaxis]

    n_iter += 1
    err = np.linalg.norm(p_hat_previous - p_hat, np.inf)
    if err < tol:
        break
else:
    warnings.warn(
        "Did not converge, decrease tol or increase max_iter.",
        ConvergenceWarning,
    )

# %%
np.mean(p_hat.argmax(axis=1) == model.predict(X))

@dx2-66 (Author) commented Dec 17, 2022

np.mean(p_hat.argmax(axis=1) == model.predict(X))

As far as I can see, it tries to derive the model scores using the individual estimators' scores only.

However, changing the distance calculation method can change the predictions (hence your proposal to switch from Euclidean to Manhattan), while the individual scores remain the same. So, logically, there is no way to ensure the argmax of the probabilities equals the argmin of the distances unless the distance calculation is included in the scoring. Am I missing something here?

@glemaitre (Member):

I would look at the theory in the paper cited above, in section 4. I have not yet had time to investigate more closely.

@glemaitre (Member) commented Dec 22, 2022

So I intend to solve the couple of bugs that we had: #25217

From the article above, we can ensure the equivalence by setting decoding="loss" and loss="linear".

@dx2-66 (Author) commented Dec 23, 2022

I've run some tests against #25217.
Your iterative estimator appears to follow the paper closely enough (at least I observe no significant divergences), yet it still somehow converges to the wrong values, even with decoding='loss'. I tried adjusting the priors ('the number of training examples used to train the binary classifier that corresponds to column b of the code matrix' sounds a bit ambiguous; the current n_samples/n_samples is a no-op), but the improvement is not big enough yet. Perhaps there is some error in the r_hat calculation that I failed to notice.

The non-iterative estimator suggested in the same chapter (i.e. basically r[:, klass_indicator].sum(axis=1) + not_r[:, ~klass_indicator].sum(axis=1), normalized, as sketched below) tends to do much better, though: it agrees with the predict() results almost perfectly for both loss and cityblock decoding (except in some cases where the scores are equal and argmax selects the first index while predict prefers the other); I have yet to find a counterexample. It is also suspiciously well calibrated.
It performs much worse with hamming decoding, but still better than the iterative one.
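
For concreteness, the non-iterative estimate amounts to roughly the following (a sketch of the formula above, not the #25217 implementation):

import numpy as np

def non_iterative_proba(r, indicator):
    # r: (n_samples, n_estimators) concatenated P(code == 1) per estimator
    # indicator: boolean (n_classes, n_estimators) code book
    not_r = 1 - r
    scores = np.column_stack([
        r[:, row].sum(axis=1) + not_r[:, ~row].sum(axis=1)
        for row in indicator
    ])
    return scores / scores.sum(axis=1, keepdims=True)

Called with the r and indicator arrays from the snippet earlier in the thread, the argmax of this estimate is what I compared against predict() above.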
