Unit-test and correct the LikelihoodMaximizer
mirkobunse committed Sep 20, 2024
1 parent 9b6849b commit 04889cc
Showing 4 changed files with 50 additions and 56 deletions.
18 changes: 12 additions & 6 deletions qunfold/methods/__init__.py
@@ -31,25 +31,31 @@ def predict(self, X):
"""
pass

def minimize(fun, n_classes, rng, solver, solver_options):
def minimize(
fun,
n_classes,
solver = "trust-ncg",
solver_options = {"gtol": 1e-8, "maxiter": 1000},
seed = None,
):
"""Numerically minimize a function to predict the most likely class prevalences.
This implementation makes use of a soft-max "trick" by Bunse (2022) and uses the auto-differentiation of JAX for second-order optimization.
Args:
fun: The function to minimize. Has to be implemented in JAX and has to have the signature `p -> loss`.
n_classes: The number of classes.
rng: A random number generator.
solver: The `method` argument in `scipy.optimize.minimize`.
solver_options: The `options` argument in `scipy.optimize.minimize`.
solver (optional): The `method` argument in `scipy.optimize.minimize`. Defaults to `"trust-ncg"`.
solver_options (optional): The `options` argument in `scipy.optimize.minimize`. Defaults to `{"gtol": 1e-8, "maxiter": 1000}`.
seed (optional): A seed for random number generation. Defaults to `None`.
Returns:
A solution vector `p`.
"""
fun_l = lambda l: fun(_jnp_softmax(l))
jac = jax.grad(fun_l) # Jacobian
hess = jax.jacfwd(jac) # Hessian through forward-mode AD
x0 = _rand_x0(rng, n_classes) # random starting point
x0 = _rand_x0(np.random.RandomState(seed), n_classes) # random starting point
state = _CallbackState(x0)
try:
opt = optimize.minimize(
@@ -59,7 +65,7 @@ def minimize(fun, n_classes, rng, solver, solver_options):
hess = _check_derivative(hess, "hess"),
method = solver,
options = solver_options,
callback = state.callback()
callback = state.callback(),
)
except (DerivativeError, ValueError):
traceback.print_exc()
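For orientation, a minimal sketch of how the revised `minimize` signature might be called. The toy loss and the concrete argument values below are illustrative assumptions, not part of this commit; the call relies only on the signature and defaults shown above.

import jax.numpy as jnp
from qunfold.methods import minimize

# toy loss in JAX with signature p -> loss: squared distance to a fixed prevalence vector
p_target = jnp.array([0.2, 0.3, 0.5])
loss = lambda p: jnp.sum((p - p_target) ** 2)

p_est = minimize(
  loss,
  n_classes = 3,
  solver = "trust-ncg",                             # default
  solver_options = {"gtol": 1e-8, "maxiter": 1000}, # default
  seed = 25,                                        # replaces the former rng argument
)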
17 changes: 9 additions & 8 deletions qunfold/methods/likelihood.py
@@ -41,28 +41,29 @@ def fit(self, X, y, n_classes=None):
self.classifier.fit(X, y)
return self
def predict(self, X):
pXY = classifier.predict_proba(X) / self.p_trn # proportional to P(X|Y)
pXY = self.classifier.predict_proba(X) / self.p_trn # proportional to P(X|Y)
pXY = pXY / pXY.sum(axis=1, keepdims=True) # normalize to P(X|Y)

# TODO 1) filter out all rows from pXY that contain zeros or ones, or values close to zero or one up to some self.epsilon. Goal: to reduce thrown errors / warnings and to replace the corresponding estimates with proper ones.
#epsilon_filtered_rows = pXY[np.any(pXY <= self.epsilon, axis=1), :]
#pXY = pXY[np.all(pXY > self.epsilon, axis=1),:]
# epsilon_filtered_rows = pXY[jnp.any(pXY <= self.epsilon, axis=1), :]
# pXY = pXY[jnp.all(pXY > self.epsilon, axis=1),:]

# TODO 2) "side-chain" those rows that contain values close to one, by setting up a classify-and-count estimate that is later added to opt.x. Appropriately weight both the CC estimate and opt.x by the fraction of rows that has led to each of these estimates. Goal: to further improve the estimation (see also TODO 3).

# TODO 3) consider self.epsilon as a hyper-parameter, assuming that all fairly confident predictions are probably correct, not only the extremely confident exceptions.

# minimize the negative log-likelihood loss
def loss(p):
xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
# xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
# xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
# return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
return -jnp.log(pXY @ p).mean()
return minimize(
loss,
len(self.p_trn),
np.random.RandomState(self.seed), # = rng
self.solver,
self.solver_options
self.solver_options,
self.seed,
)

class ExpectationMaximizer(AbstractMethod):
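As a stand-alone illustration of the negative log-likelihood loss that `predict` now minimizes, together with the epsilon filtering sketched in TODO 1, consider the following hedged example. The probability matrix and the epsilon value are fabricated for demonstration; in the commit, `pXY` comes from `self.classifier.predict_proba(X) / self.p_trn`.

import jax.numpy as jnp
import numpy as np

def neg_log_likelihood(pXY, p):
  # mean negative log-likelihood of prevalences p under rows proportional to P(X|Y)
  return -jnp.log(pXY @ p).mean()

pXY = np.array([
  [0.70, 0.20, 0.10],
  [0.05, 0.90, 0.05],
  [0.25, 0.30, 0.45],
])
pXY = pXY / pXY.sum(axis=1, keepdims=True) # normalize to P(X|Y)

# TODO 1, sketched: drop rows with entries close to zero or one
epsilon = 1e-6 # hypothetical value of self.epsilon
pXY = pXY[np.all((pXY > epsilon) & (pXY < 1 - epsilon), axis=1), :]

print(neg_log_likelihood(jnp.asarray(pXY), jnp.full(3, 1 / 3)))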
4 changes: 2 additions & 2 deletions qunfold/methods/linear/__init__.py
@@ -53,9 +53,9 @@ def solve(self, q, M, N=None): # TODO add argument p_trn=self.p_trn
return minimize(
self.loss.instantiate(q, M, N),
M.shape[1], # = n_classes
np.random.RandomState(self.seed), # = rng
self.solver,
self.solver_options
self.solver_options,
self.seed,
)
@property
def p_trn(self):
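The net effect here, as in the files above, is that callers now pass a plain `seed` instead of constructing a `np.random.RandomState` themselves. A hedged sketch of the reproducibility this buys; the data are random, and it is assumed (not shown in this diff) that `seed` is exposed as a constructor argument of the method:

import numpy as np
import qunfold
from sklearn.ensemble import RandomForestClassifier

X = np.random.rand(200, 4)
y = np.random.randint(0, 3, 200)

# assumption: the constructor accepts seed and forwards it to minimize()
method = qunfold.LikelihoodMaximizer(RandomForestClassifier(oob_score=True), seed=25)
p_1 = method.fit(X, y).predict(X)
p_2 = method.fit(X, y).predict(X)
# with a fixed seed, both calls start the solver from the same random x0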
67 changes: 27 additions & 40 deletions qunfold/tests/__init__.py
@@ -1,10 +1,8 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
import quapy as qp
import qunfold
import time
from quapy.data import LabelledCollection
from quapy.model_selection import GridSearchQ
from quapy.protocol import AbstractProtocol
from qunfold.quapy import QuaPyWrapper
from qunfold.sklearn import CVClassifier
from scipy.spatial.distance import cdist
@@ -54,47 +52,36 @@ def test_methods(self):
oob_score = True,
random_state = RNG.randint(np.iinfo("uint16").max),
)
p_acc = qunfold.ACC(rf).fit(X_trn, y_trn).predict(X_tst)
p_pacc = qunfold.PACC(rf).fit(X_trn, y_trn).predict(X_tst)
p_run = qunfold.RUN(qunfold.ClassRepresentation(rf), tau=1e6).fit(X_trn, y_trn).predict(X_tst)
p_hdx = qunfold.HDx(3).fit(X_trn, y_trn).predict(X_tst)
p_hdy = qunfold.HDy(rf, 3).fit(X_trn, y_trn).predict(X_tst)
p_edx = qunfold.EDx().fit(X_trn, y_trn).predict(X_tst)
p_edy = qunfold.EDy(rf).fit(X_trn, y_trn).predict(X_tst)
p_kmme = qunfold.KMM('energy').fit(X_trn, y_trn).predict(X_tst)
p_kmmg = qunfold.KMM('gaussian').fit(X_trn, y_trn).predict(X_tst)
p_kmml = qunfold.KMM('laplacian').fit(X_trn, y_trn).predict(X_tst)
p_rff = qunfold.KMM('rff').fit(X_trn, y_trn).predict(X_tst)
p_custom = qunfold.LinearMethod( # a custom method
p_cstm = qunfold.LinearMethod( # a custom method
qunfold.LeastSquaresLoss(),
qunfold.HistogramRepresentation(3)
).fit(X_trn, y_trn, n_classes).predict(X_tst)
p_kmme = qunfold.KMM('energy').fit(X_trn, y_trn).predict(X_tst)
p_rff = qunfold.KMM('rff').fit(X_trn, y_trn).predict(X_tst)
p_maxl = qunfold.LikelihoodMaximizer(rf).fit(X_trn, y_trn).predict(X_tst)
qp.environ["SAMPLE_SIZE"] = len(X_tst) # needed to compute the RAE
print(
f"LSq: p_acc = {p_acc}",
f" {p_acc.nit} it.; {p_acc.message}",
f" p_pacc = {p_pacc}",
f" {p_pacc.nit} it.; {p_pacc.message}",
f" p_run = {p_run}",
f" {p_run.nit} it.; {p_run.message}",
f" p_hdx = {p_hdx}",
f" {p_hdx.nit} it.; {p_hdx.message}",
f" p_hdy = {p_hdy}",
f" {p_hdy.nit} it.; {p_hdy.message}",
f" p_edx = {p_edx}",
f" {p_edx.nit} it.; {p_edx.message}",
f" p_edy = {p_edy}",
f" {p_edy.nit} it.; {p_edy.message}",
f" p_custom = {p_custom}",
f" {p_custom.nit} it.; {p_custom.message}",
f" p_kmme = {p_kmme}",
f" {p_kmme.nit} it.; {p_kmme.message}",
f" p_kmmg = {p_kmmg}",
f" {p_kmmg.nit} it.; {p_kmmg.message}",
f" p_kmml = {p_kmml}",
f" {p_kmml.nit} it.; {p_kmml.message}",
f" p_rff = {p_rff}",
f" {p_rff.nit} it.; {p_rff.message}",
f" p_tst = {p_tst}",
f" p_pacc = {p_pacc} (RAE {qp.error.rae(p_pacc, p_tst):.4f})",
f" {p_pacc.nit} it.; {p_pacc.message}",
f" p_run = {p_run} (RAE {qp.error.rae(p_run, p_tst):.4f})",
f" {p_run.nit} it.; {p_run.message}",
f" p_hdy = {p_hdy} (RAE {qp.error.rae(p_hdy, p_tst):.4f})",
f" {p_hdy.nit} it.; {p_hdy.message}",
f" p_edy = {p_edy} (RAE {qp.error.rae(p_edy, p_tst):.4f})",
f" {p_edy.nit} it.; {p_edy.message}",
f" p_cstm = {p_cstm} (RAE {qp.error.rae(p_cstm, p_tst):.4f})",
f" {p_cstm.nit} it.; {p_cstm.message}",
f" p_kmme = {p_kmme} (RAE {qp.error.rae(p_kmme, p_tst):.4f})",
f" {p_kmme.nit} it.; {p_kmme.message}",
f" p_rff = {p_rff} (RAE {qp.error.rae(p_rff, p_tst):.4f})",
f" {p_rff.nit} it.; {p_rff.message}",
f" p_maxl = {p_maxl} (RAE {qp.error.rae(p_maxl, p_tst):.4f})",
f" {p_maxl.nit} it.; {p_maxl.message}",
f" p_tst = {p_tst}",
sep = "\n",
end = "\n"*2
)
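The RAE values added to this output depend on `qp.environ["SAMPLE_SIZE"]` being set, as the comment above notes. A minimal stand-alone illustration with made-up prevalence vectors, under the assumption that quapy expects the true prevalences first and the estimate second:

import numpy as np
import quapy as qp

qp.environ["SAMPLE_SIZE"] = 1000 # RAE needs a sample size for its smoothing term
p_true = np.array([0.1, 0.3, 0.6])
p_hat = np.array([0.2, 0.3, 0.5])
print(qp.error.rae(p_true, p_hat)) # assumption: (true, predicted) argument order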
@@ -132,7 +119,7 @@ def test_methods(self):
# self.assertTrue(...)
print(f"Spent {time.time() - start}s")

class SingleSampleProtocol(AbstractProtocol):
class SingleSampleProtocol(qp.protocol.AbstractProtocol):
def __init__(self, X, p):
self.X = X
self.p = p
@@ -156,7 +143,7 @@ def test_methods(self):
p_acc.get_params(deep=True)["representation__classifier__estimator__C"],
1e-2
)
quapy_method = GridSearchQ(
quapy_method = qp.model_selection.GridSearchQ(
model = p_acc,
param_grid = {
"representation__classifier__estimator__C": [1e-1, 1e0, 1e1, 1e2],
@@ -165,7 +152,7 @@
error = "mae",
refit = False,
verbose = True,
).fit(LabelledCollection(X_trn, y_trn))
).fit(qp.data.LabelledCollection(X_trn, y_trn))
self.assertEqual( # check that best parameters are actually used
quapy_method.best_params_["representation__classifier__estimator__C"],
quapy_method.best_model_.generic_method.representation.classifier.estimator.C
