From 04889cc180974f8a44afeddd6cfb81d935f73593 Mon Sep 17 00:00:00 2001
From: Mirko Bunse
Date: Fri, 20 Sep 2024 09:25:34 +0200
Subject: [PATCH] Unit-test and correct the LikelihoodMaximizer

---
 qunfold/methods/__init__.py        | 18 +++++---
 qunfold/methods/likelihood.py      | 17 ++++----
 qunfold/methods/linear/__init__.py |  4 +-
 qunfold/tests/__init__.py          | 67 ++++++++++++------------------
 4 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/qunfold/methods/__init__.py b/qunfold/methods/__init__.py
index d71e234..5529716 100644
--- a/qunfold/methods/__init__.py
+++ b/qunfold/methods/__init__.py
@@ -31,7 +31,13 @@ def predict(self, X):
     """
     pass
 
-def minimize(fun, n_classes, rng, solver, solver_options):
+def minimize(
+    fun,
+    n_classes,
+    solver = "trust-ncg",
+    solver_options = {"gtol": 1e-8, "maxiter": 1000},
+    seed = None,
+  ):
   """Numerically minimize a function to predict the most likely class prevalences.
 
   This implementation makes use of a soft-max "trick" by Bunse (2022) and uses the auto-differentiation of JAX for second-order optimization.
@@ -39,9 +45,9 @@ def minimize(fun, n_classes, rng, solver, solver_options):
   Args:
     fun: The function to minimize. Has to be implemented in JAX and has to have the signature `p -> loss`.
     n_classes: The number of classes.
-    rng: A random number generator.
-    solver: The `method` argument in `scipy.optimize.minimize`.
-    solver_options: The `options` argument in `scipy.optimize.minimize`.
+    solver (optional): The `method` argument in `scipy.optimize.minimize`. Defaults to `"trust-ncg"`.
+    solver_options (optional): The `options` argument in `scipy.optimize.minimize`. Defaults to `{"gtol": 1e-8, "maxiter": 1000}`.
+    seed (optional): A seed for random number generation. Defaults to `None`.
 
   Returns:
     A solution vector `p`.
@@ -49,7 +55,7 @@
   fun_l = lambda l: fun(_jnp_softmax(l))
   jac = jax.grad(fun_l) # Jacobian
   hess = jax.jacfwd(jac) # Hessian through forward-mode AD
-  x0 = _rand_x0(rng, n_classes) # random starting point
+  x0 = _rand_x0(np.random.RandomState(seed), n_classes) # random starting point
   state = _CallbackState(x0)
   try:
     opt = optimize.minimize(
@@ -59,7 +65,7 @@
       hess = _check_derivative(hess, "hess"),
       method = solver,
       options = solver_options,
-      callback = state.callback()
+      callback = state.callback(),
     )
   except (DerivativeError, ValueError):
     traceback.print_exc()
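
The following standalone sketch illustrates the soft-max reparameterization that the reworked `minimize` continues to rely on, together with its new default arguments: the loss is optimized over an unconstrained latent vector `l`, a soft-max maps `l` back onto the probability simplex, and JAX supplies the first- and second-order derivatives for `scipy.optimize.minimize`. It is a minimal sketch only; `toy_loss` and the inline `softmax` stand in for qunfold's internal `_jnp_softmax` and `_rand_x0` helpers and are not part of this patch.

import jax
import jax.numpy as jnp
import numpy as np
from scipy import optimize

def toy_loss(p): # any JAX-implemented function with the signature p -> loss
  return jnp.sum((p - jnp.array([0.7, 0.2, 0.1]))**2)

softmax = lambda l: jnp.exp(l) / jnp.exp(l).sum() # map a latent vector l onto the simplex
fun_l = lambda l: toy_loss(softmax(l)) # unconstrained objective over l
jac = jax.grad(fun_l) # first derivative via JAX
hess = jax.jacfwd(jac) # second derivative through forward-mode AD

n_classes = 3
x0 = np.random.RandomState(25).rand(n_classes) # a random starting point
opt = optimize.minimize(
  fun_l,
  x0,
  jac = lambda l: np.asarray(jac(l), dtype=float),
  hess = lambda l: np.asarray(hess(l), dtype=float),
  method = "trust-ncg", # the new default solver
  options = {"gtol": 1e-8, "maxiter": 1000}, # the new default solver_options
)
print(softmax(opt.x)) # close to the target prevalences [0.7, 0.2, 0.1]
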
diff --git a/qunfold/methods/likelihood.py b/qunfold/methods/likelihood.py
index d77cd55..e60623d 100644
--- a/qunfold/methods/likelihood.py
+++ b/qunfold/methods/likelihood.py
@@ -41,12 +41,12 @@ def fit(self, X, y, n_classes=None):
     self.classifier.fit(X, y)
     return self
   def predict(self, X):
-    pXY = classifier.predict_proba(X) / self.p_trn # proportional to P(X|Y)
+    pXY = self.classifier.predict_proba(X) / self.p_trn # proportional to P(X|Y)
     pXY = pXY / pXY.sum(axis=1, keepdims=True) # normalize to P(X|Y)
 
     # TODO 1) filter out all rows from pXY that contain zeros or ones, or values close to zero or one up to some self.epsilon. Goal: to reduce thrown errors / warnings and to replace the corresponding estimates with proper ones.
-    #epsilon_filtered_rows = pXY[np.any(pXY <= self.epsilon, axis=1), :]
-    #pXY = pXY[np.all(pXY > self.epsilon, axis=1),:]
+    # epsilon_filtered_rows = pXY[jnp.any(pXY <= self.epsilon, axis=1), :]
+    # pXY = pXY[jnp.all(pXY > self.epsilon, axis=1),:]
 
     # TODO 2) "side-chain" those rows that have contained values close to one, by setting up a classify-and-count estimate that is later added to opt.x. Appropriately weight both the CC estimate and opt.x by the fraction of rows that has lead to each of these estimates. Goal: to further improve the estimation (see also the todo 3).
 
@@ -54,15 +54,16 @@ def predict(self, X):
 
     # minimize the negative log-likelihood loss
     def loss(p):
-      xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
-      xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
-      return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
+      # xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
+      # xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
+      # return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
+      return -jnp.log(pXY @ p).mean()
     return minimize(
       loss,
       len(self.p_trn),
-      np.random.RandomState(self.seed), # = rng
       self.solver,
-      self.solver_options
+      self.solver_options,
+      self.seed,
     )
 
 class ExpectationMaximizer(AbstractMethod):
diff --git a/qunfold/methods/linear/__init__.py b/qunfold/methods/linear/__init__.py
index 36902c4..dd73403 100644
--- a/qunfold/methods/linear/__init__.py
+++ b/qunfold/methods/linear/__init__.py
@@ -53,9 +53,9 @@ def solve(self, q, M, N=None): # TODO add argument p_trn=self.p_trn
     return minimize(
       self.loss.instantiate(q, M, N),
       M.shape[1], # = n_classes
-      np.random.RandomState(self.seed), # = rng
       self.solver,
-      self.solver_options
+      self.solver_options,
+      self.seed,
     )
   @property
   def p_trn(self):
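
To make the corrected objective concrete, here is a small self-contained sketch of what the simplified `loss` inside `LikelihoodMaximizer.predict` computes after this patch, with the ordinal regularizers `xi_0` and `xi_1` commented out. The posteriors and training prevalences below are made-up toy values, not numbers from the test suite.

import jax.numpy as jnp

p_trn = jnp.array([0.5, 0.3, 0.2]) # training-set prevalences (toy values)
posteriors = jnp.array([ # self.classifier.predict_proba(X) on three items (toy values)
  [0.7, 0.2, 0.1],
  [0.2, 0.5, 0.3],
  [0.1, 0.3, 0.6],
])
pXY = posteriors / p_trn # proportional to P(X|Y)
pXY = pXY / pXY.sum(axis=1, keepdims=True) # normalize to P(X|Y)

def loss(p): # the negative log-likelihood that is handed to minimize()
  return -jnp.log(pXY @ p).mean()

print(loss(jnp.array([1/3, 1/3, 1/3]))) # loss of a uniform prevalence estimate
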
diff --git a/qunfold/tests/__init__.py b/qunfold/tests/__init__.py
index f9541d4..d338457 100644
--- a/qunfold/tests/__init__.py
+++ b/qunfold/tests/__init__.py
@@ -1,10 +1,8 @@
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
+import quapy as qp
 import qunfold
 import time
-from quapy.data import LabelledCollection
-from quapy.model_selection import GridSearchQ
-from quapy.protocol import AbstractProtocol
 from qunfold.quapy import QuaPyWrapper
 from qunfold.sklearn import CVClassifier
 from scipy.spatial.distance import cdist
@@ -54,47 +52,36 @@ def test_methods(self):
       oob_score = True,
       random_state = RNG.randint(np.iinfo("uint16").max),
     )
-    p_acc = qunfold.ACC(rf).fit(X_trn, y_trn).predict(X_tst)
     p_pacc = qunfold.PACC(rf).fit(X_trn, y_trn).predict(X_tst)
     p_run = qunfold.RUN(qunfold.ClassRepresentation(rf), tau=1e6).fit(X_trn, y_trn).predict(X_tst)
-    p_hdx = qunfold.HDx(3).fit(X_trn, y_trn).predict(X_tst)
     p_hdy = qunfold.HDy(rf, 3).fit(X_trn, y_trn).predict(X_tst)
-    p_edx = qunfold.EDx().fit(X_trn, y_trn).predict(X_tst)
     p_edy = qunfold.EDy(rf).fit(X_trn, y_trn).predict(X_tst)
-    p_kmme = qunfold.KMM('energy').fit(X_trn, y_trn).predict(X_tst)
-    p_kmmg = qunfold.KMM('gaussian').fit(X_trn, y_trn).predict(X_tst)
-    p_kmml = qunfold.KMM('laplacian').fit(X_trn, y_trn).predict(X_tst)
-    p_rff = qunfold.KMM('rff').fit(X_trn, y_trn).predict(X_tst)
-    p_custom = qunfold.LinearMethod( # a custom method
+    p_cstm = qunfold.LinearMethod( # a custom method
       qunfold.LeastSquaresLoss(), qunfold.HistogramRepresentation(3)
     ).fit(X_trn, y_trn, n_classes).predict(X_tst)
+    p_kmme = qunfold.KMM('energy').fit(X_trn, y_trn).predict(X_tst)
+    p_rff = qunfold.KMM('rff').fit(X_trn, y_trn).predict(X_tst)
+    p_maxl = qunfold.LikelihoodMaximizer(rf).fit(X_trn, y_trn).predict(X_tst)
+    qp.environ["SAMPLE_SIZE"] = len(X_tst) # needed to compute the RAE
     print(
-      f"LSq: p_acc = {p_acc}",
-      f"     {p_acc.nit} it.; {p_acc.message}",
-      f"  p_pacc = {p_pacc}",
-      f"     {p_pacc.nit} it.; {p_pacc.message}",
-      f"   p_run = {p_run}",
-      f"     {p_run.nit} it.; {p_run.message}",
-      f"   p_hdx = {p_hdx}",
-      f"     {p_hdx.nit} it.; {p_hdx.message}",
-      f"   p_hdy = {p_hdy}",
-      f"     {p_hdy.nit} it.; {p_hdy.message}",
-      f"   p_edx = {p_edx}",
-      f"     {p_edx.nit} it.; {p_edx.message}",
-      f"   p_edy = {p_edy}",
-      f"     {p_edy.nit} it.; {p_edy.message}",
-      f"  p_custom = {p_custom}",
-      f"     {p_custom.nit} it.; {p_custom.message}",
-      f"  p_kmme = {p_kmme}",
-      f"     {p_kmme.nit} it.; {p_kmme.message}",
-      f"  p_kmmg = {p_kmmg}",
-      f"     {p_kmmg.nit} it.; {p_kmmg.message}",
-      f"  p_kmml = {p_kmml}",
-      f"     {p_kmml.nit} it.; {p_kmml.message}",
-      f"   p_rff = {p_rff}",
-      f"     {p_rff.nit} it.; {p_rff.message}",
-      f"   p_tst = {p_tst}",
+      f"  p_pacc = {p_pacc} (RAE {qp.error.rae(p_pacc, p_tst):.4f})",
+      f"     {p_pacc.nit} it.; {p_pacc.message}",
+      f"   p_run = {p_run} (RAE {qp.error.rae(p_run, p_tst):.4f})",
+      f"     {p_run.nit} it.; {p_run.message}",
+      f"   p_hdy = {p_hdy} (RAE {qp.error.rae(p_hdy, p_tst):.4f})",
+      f"     {p_hdy.nit} it.; {p_hdy.message}",
+      f"   p_edy = {p_edy} (RAE {qp.error.rae(p_edy, p_tst):.4f})",
+      f"     {p_edy.nit} it.; {p_edy.message}",
+      f"  p_cstm = {p_cstm} (RAE {qp.error.rae(p_cstm, p_tst):.4f})",
+      f"     {p_cstm.nit} it.; {p_cstm.message}",
+      f"  p_kmme = {p_kmme} (RAE {qp.error.rae(p_kmme, p_tst):.4f})",
+      f"     {p_kmme.nit} it.; {p_kmme.message}",
+      f"   p_rff = {p_rff} (RAE {qp.error.rae(p_rff, p_tst):.4f})",
+      f"     {p_rff.nit} it.; {p_rff.message}",
+      f"  p_maxl = {p_maxl} (RAE {qp.error.rae(p_maxl, p_tst):.4f})",
+      f"     {p_maxl.nit} it.; {p_maxl.message}",
+      f"   p_tst = {p_tst}",
       sep = "\n",
       end = "\n"*2
     )
@@ -132,7 +119,7 @@ def test_methods(self):
     # self.assertTrue(...)
     print(f"Spent {time.time() - start}s")
 
-class SingleSampleProtocol(AbstractProtocol):
+class SingleSampleProtocol(qp.protocol.AbstractProtocol):
  def __init__(self, X, p):
    self.X = X
    self.p = p
@@ -156,7 +143,7 @@ def test_methods(self):
       p_acc.get_params(deep=True)["representation__classifier__estimator__C"],
       1e-2
     )
-    quapy_method = GridSearchQ(
+    quapy_method = qp.model_selection.GridSearchQ(
       model = p_acc,
       param_grid = {
         "representation__classifier__estimator__C": [1e-1, 1e0, 1e1, 1e2],
@@ -165,7 +152,7 @@
       error = "mae",
       refit = False,
       verbose = True,
-    ).fit(LabelledCollection(X_trn, y_trn))
+    ).fit(qp.data.LabelledCollection(X_trn, y_trn))
     self.assertEqual( # check that best parameters are actually used
       quapy_method.best_params_["representation__classifier__estimator__C"],
       quapy_method.best_model_.generic_method.representation.classifier.estimator.C
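
A quick way to exercise the behavior that the updated test covers is sketched below. The toy dataset and the forest configuration are illustrative assumptions; only the calls to `qunfold.LikelihoodMaximizer`, `qp.environ["SAMPLE_SIZE"]`, and `qp.error.rae` mirror the test above.

import numpy as np
import quapy as qp
import qunfold
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# toy data with three classes (illustrative, not the data used by the unit tests)
X, y = make_classification(n_samples=2000, n_classes=3, n_informative=6, random_state=25)
X_trn, y_trn, X_tst, y_tst = X[:1000], y[:1000], X[1000:], y[1000:]

rf = RandomForestClassifier(oob_score=True, random_state=25) # out-of-bag decisions, as in the tests
p_maxl = qunfold.LikelihoodMaximizer(rf).fit(X_trn, y_trn).predict(X_tst) # prevalence estimate
p_tst = np.bincount(y_tst, minlength=3) / len(y_tst) # true prevalences of the test split

qp.environ["SAMPLE_SIZE"] = len(X_tst) # needed to compute the RAE
print(p_maxl, qp.error.rae(p_maxl, p_tst)) # estimate and its RAE, as in the test above
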