From 7aba0e444ef53d0a4415a0d9b375dff4127d6daa Mon Sep 17 00:00:00 2001
From: Mirko Bunse
Date: Thu, 24 Oct 2024 12:19:58 +0200
Subject: [PATCH] Revert "Implement maximize_expectation with a JIT-able
 jax.lax.while_loop; support multiple bags only through wrapping this
 function in a vmap"

This reverts commit cd94bd5b8e545022dd20b5ebee4335a57e44f739.
---
 qunfold/__init__.py           |  2 +-
 qunfold/methods/likelihood.py | 49 +++++++++++++++--------------------
 qunfold/tests/__init__.py     | 30 +++++++++++++++++++--
 3 files changed, 50 insertions(+), 31 deletions(-)

diff --git a/qunfold/__init__.py b/qunfold/__init__.py
index 96de5ab..fec2095 100644
--- a/qunfold/__init__.py
+++ b/qunfold/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.5-rc8"
+__version__ = "0.1.5-rc7"
 
 from .methods.linear.losses import (
   LeastSquaresLoss,
diff --git a/qunfold/methods/likelihood.py b/qunfold/methods/likelihood.py
index 0d99ad3..1d03b3d 100644
--- a/qunfold/methods/likelihood.py
+++ b/qunfold/methods/likelihood.py
@@ -1,4 +1,3 @@
-import jax
 import jax.numpy as jnp
 
 from . import AbstractMethod, check_y, class_prevalences, minimize, Result
@@ -94,38 +93,32 @@ def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8, omit_result_convers
   """The expectation maximization routine that is part of the `ExpectationMaximizer` by Saerens et al. (2002).
 
   Args:
-    pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`.
+    pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. Multiple bags, with shape `(n_bags, n_items_per_bag, n_classes)` are also supported.
     p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`.
     max_iter (optional): The maximum number of iterations. Defaults to `100`.
     tol (optional): The convergence tolerance for the L2 norm between iterations or None to disable convergence checks. Defaults to `1e-8`.
     omit_result_conversion (optional): Whether to omit the conversion into a `Result` type.
   """
   pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y)
-
-  # A JIT-able jax.lax.while_loop has the following semantics:
-  #
-  # def while_loop(cond_fun, body_fun, init_val):
-  #   val = init_val
-  #   while cond_fun(val):
-  #     val = body_fun(val)
-  #   return val
-  def cond_fn(val): # val = (p_next, p_prev, n_iter)
-    return jnp.logical_and(
-      jnp.logical_and(val[2] > 0, val[2] < max_iter),
-      jnp.linalg.norm(val[0] - val[1]) >= tol,
-    )
-  def body_fn(val):
-    pYX = pYX_pY * val[0] # p_next=val[0] takes the role of p_prev
+  p_prev = p_trn
+  for n_iter in range(max_iter):
+    pYX = pYX_pY * p_prev
     pYX = pYX / pYX.sum(axis=-1, keepdims=True) # normalize to posterior probabilities
-    p_next = pYX.mean(axis=0) # shape (n_classes,)
-    return (p_next, val[0], val[2]+1)
-  p_est, _, n_iter = jax.lax.while_loop(cond_fn, body_fn, init_val=(p_trn, p_trn, 0))
-
-  # convert to a Result type with meta-data
+    p_next = pYX.mean(axis=-2, keepdims=True) # shape (n_bags, [1], n_classes)
+    if tol is not None:
+      if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol):
+        if omit_result_conversion:
+          return jnp.squeeze(p_next, axis=-2)
+        return Result(
+          jnp.squeeze(p_next, axis=-2),
+          n_iter+1,
+          "Optimization terminated successfully."
+        )
+    p_prev = p_next
   if omit_result_conversion:
-    return p_est
-  if n_iter < max_iter:
-    msg = "Optimization terminated successfully."
-  else:
-    msg = "Maximum number of iterations reached."
-  return Result(p_est, n_iter, msg)
+    return jnp.squeeze(p_prev, axis=-2)
+  return Result(
+    jnp.squeeze(p_prev, axis=-2),
+    max_iter,
+    "Maximum number of iterations reached."
+  )
diff --git a/qunfold/tests/__init__.py b/qunfold/tests/__init__.py
index b27204e..40ea80c 100644
--- a/qunfold/tests/__init__.py
+++ b/qunfold/tests/__init__.py
@@ -129,6 +129,34 @@ def __init__(self, X, p):
   def __call__(self):
     yield self.X, self.p
 
+class TestExpectationMaximizer(TestCase):
+  def test_maximize_expectation(self):
+    for _ in range(5):
+      q, M, p_trn = make_problem()
+      n_classes = len(p_trn)
+      X_trn, y_trn = generate_data(M, p_trn)
+      p_tst_a = RNG.permutation(p_trn)
+      X_tst_a, y_tst_a = generate_data(M, p_tst_a)
+      p_tst_b = RNG.permutation(p_trn)
+      X_tst_b, y_tst_b = generate_data(M, p_tst_b)
+      rf = RandomForestClassifier(
+        oob_score = True,
+        random_state = RNG.randint(np.iinfo("uint16").max),
+      ).fit(X_trn, y_trn)
+      pYX_a = rf.predict_proba(X_tst_a)
+      pYX_b = rf.predict_proba(X_tst_b)
+      params = { "max_iter": 10, "tol": None }
+      p_est_sep = np.array([ # separate optimization
+        qunfold.methods.likelihood.maximize_expectation(pYX_a, p_trn, **params),
+        qunfold.methods.likelihood.maximize_expectation(pYX_b, p_trn, **params)
+      ])
+      p_est_jnt = qunfold.methods.likelihood.maximize_expectation( # joint optimization
+        np.array([pYX_a, pYX_b]),
+        p_trn,
+        **params
+      )
+      np.testing.assert_array_equal(p_est_jnt, p_est_sep)
+
 class TestQuaPyWrapper(TestCase):
   def test_methods(self):
     for _ in range(5):
@@ -160,7 +188,6 @@ def test_methods(self):
         error = "mae",
         refit = False,
         verbose = True,
-        raise_errors = True,
       ).fit(qp.data.LabelledCollection(X_trn, y_trn))
       self.assertEqual( # check that best parameters are actually used
         cv_acc.best_params_["representation__classifier__estimator__C"],
@@ -175,7 +202,6 @@ def test_methods(self):
         error = "mae",
         refit = False,
         verbose = True,
-        raise_errors = True,
       ).fit(qp.data.LabelledCollection(X_trn, y_trn))
       self.assertEqual( # check that best parameters are actually used
        cv_sld.best_params_["classifier__C"],
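
Note (not part of the patch): a minimal sketch of the calling convention this revert restores. The import path matches the patch's own test suite; the posterior values are made up for illustration, and the commented-out vmap line shows the wrapping that the reverted while_loop variant required instead, per the commit subject.

import jax.numpy as jnp
from qunfold.methods.likelihood import maximize_expectation

# Two made-up bags of classifier posteriors P(Y|X),
# shape (n_bags, n_items_per_bag, n_classes).
pYX = jnp.array([
  [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],
  [[0.1, 0.9], [0.4, 0.6], [0.3, 0.7]],
])
p_trn = jnp.array([0.5, 0.5]) # training priors P_trn(Y), shape (n_classes,)

# After this revert, broadcasting over the leading bag axis
# estimates all bags in a single call:
p_est = maximize_expectation(pYX, p_trn, tol=None, omit_result_conversion=True)
print(p_est.shape) # (2, 2): one prevalence estimate per bag

# The reverted jax.lax.while_loop variant handled one bag per call;
# multiple bags required wrapping the single-bag call in a jax.vmap:
#   p_est = jax.vmap(
#     lambda pYX_i: maximize_expectation(pYX_i, p_trn, omit_result_conversion=True)
#   )(pYX)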