# ordinal_classifier.py
from sklearn.base import clone, BaseEstimator, ClassifierMixin
import numpy as np
class OrdinalClassifier(BaseEstimator, ClassifierMixin):
    """Ordinal classifier built from a chain of binary classifiers.

    For K sorted classes c_0 < c_1 < ... < c_{K-1}, ``fit`` trains K-1 clones
    of ``clf``; clone ``i`` learns the binary target ``y > c_i``. Per-class
    probabilities are then recovered from those cumulative estimates:

        P(y = c_0)     = 1 - P(y > c_0)
        P(y = c_i)     = P(y > c_{i-1}) - P(y > c_i)      for 0 < i < K-1
        P(y = c_{K-1}) = P(y > c_{K-2})

    Technique from:
    https://towardsdatascience.com/simple-trick-to-train-an-ordinal-regression-with-any-classifier-6911183d2a3c

    Parameters
    ----------
    clf : estimator
        Unfitted scikit-learn style classifier exposing ``fit`` and
        ``predict_proba``. It is cloned once per threshold in ``fit``; this
        instance itself is never fitted.

    Attributes
    ----------
    clfs : dict[int, estimator]
        Fitted binary classifiers keyed by threshold index ``i``.
    uniques_class : numpy.ndarray or None
        Sorted unique labels seen by ``fit`` (None before fitting).
    classes_ : numpy.ndarray
        Same array as ``uniques_class``, set by ``fit`` under the
        scikit-learn attribute name.
    """

    def __init__(self, clf):
        self.clf = clf
        self.clfs = {}
        self.uniques_class = None

    def fit(self, X, y):
        """Fit one binary classifier per ordinal threshold.

        Parameters
        ----------
        X : array-like
            Training features, passed unchanged to every cloned classifier.
        y : array-like
            Ordinal targets; labels must be orderable with ``>``.

        Returns
        -------
        self : OrdinalClassifier
            The fitted estimator (scikit-learn convention; enables chaining).

        Raises
        ------
        ValueError
            If ``y`` contains fewer than 3 distinct classes.
        """
        # asarray so plain lists work with the element-wise ``y > threshold``.
        y = np.asarray(y)
        classes = np.unique(y)  # np.unique already returns sorted labels
        if classes.shape[0] < 3:
            # Raise instead of assert: asserts are stripped under ``python -O``.
            raise ValueError(
                f'OrdinalClassifier needs at least 3 classes, only {classes.shape[0]} found'
            )
        self.uniques_class = classes
        self.classes_ = classes
        # Start from an empty dict so a refit with fewer classes cannot leave
        # stale threshold classifiers behind (they would add extra columns
        # to predict_proba).
        self.clfs = {}
        # The highest class has no threshold above it, hence classes[:-1].
        for i, threshold in enumerate(classes[:-1]):
            binary_y = (y > threshold).astype(np.uint8)
            fitted = clone(self.clf)
            fitted.fit(X, binary_y)
            self.clfs[i] = fitted
        return self

    def predict(self, X):
        """Predict the most probable class label for each sample.

        Returns labels taken from ``uniques_class``, not column indices, so the
        predictions are comparable with the ``y`` passed to ``fit`` even when
        the labels are not ``0..K-1``.
        """
        best = np.argmax(self.predict_proba(X), axis=1)
        return self.uniques_class[best]

    def predict_proba(self, X):
        """Estimate per-class probabilities for each sample.

        Returns
        -------
        numpy.ndarray of shape (n_samples, K)
            Column ``j`` corresponds to ``uniques_class[j]``. By construction
            the columns of each row telescope to a sum of 1.

        Notes
        -----
        The K-1 binary models are trained independently, so their cumulative
        estimates need not be monotone; a middle column may therefore come out
        negative. Values are returned as computed, without clipping.
        """
        # Column i holds P(y > uniques_class[i]) from the i-th binary model.
        # The binary targets are {0, 1}, so predict_proba column 1 is assumed
        # to be the positive class (sorted classes_, as in scikit-learn).
        greater = np.column_stack(
            [self.clfs[i].predict_proba(X)[:, 1] for i in sorted(self.clfs)]
        )
        return np.hstack([
            1 - greater[:, :1],                 # lowest class
            greater[:, :-1] - greater[:, 1:],   # interior classes
            greater[:, -1:],                    # highest class
        ])

    def set_params(self, **params):
        """Set parameters on this estimator and the wrapped classifier(s).

        The key ``clf`` replaces the wrapped estimator itself. Every other key,
        given either bare (``max_depth``) or in scikit-learn's nested form
        (``clf__max_depth``), is set on ``self.clf`` and on every already
        fitted threshold classifier. Fitted classifiers are not refitted.

        Returns
        -------
        self : OrdinalClassifier
            Returned because scikit-learn utilities such as ``GridSearchCV``
            use the return value of ``set_params``.
        """
        if 'clf' in params:
            self.clf = params.pop('clf')
        prefix = 'clf__'
        inner = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in params.items()
        }
        if inner:
            self.clf.set_params(**inner)
            for fitted in self.clfs.values():
                fitted.set_params(**inner)
        return self