Skip to content

Commit

Permalink
Revert "Implement maximize_expectation with a JIT-able jax.lax.while_…
Browse files Browse the repository at this point in the history
…loop; support multiple bags only through wrapping this function in a vmap"

This reverts commit cd94bd5.
  • Loading branch information
mirkobunse committed Oct 24, 2024
1 parent cd94bd5 commit 7aba0e4
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
2 changes: 1 addition & 1 deletion qunfold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.5-rc8"
__version__ = "0.1.5-rc7"

from .methods.linear.losses import (
LeastSquaresLoss,
Expand Down
49 changes: 21 additions & 28 deletions qunfold/methods/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
import jax.numpy as jnp
from . import AbstractMethod, check_y, class_prevalences, minimize, Result

Expand Down Expand Up @@ -94,38 +93,32 @@ def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8, omit_result_convers
"""The expectation maximization routine that is part of the `ExpectationMaximizer` by Saerens et al. (2002).
Args:
pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`.
pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. Multiple bags, with shape `(n_bags, n_items_per_bag, n_classes)` are also supported.
p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`.
max_iter (optional): The maximum number of iterations. Defaults to `100`.
tol (optional): The convergence tolerance for the L2 norm between iterations or None to disable convergence checks. Defaults to `1e-8`.
omit_result_conversion (optional): Whether to omit the conversion into a `Result` type.
"""
pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y)

# A JIT-able jax.lax.while_loop has the following semantics:
#
# def while_loop(cond_fun, body_fun, init_val):
# val = init_val
# while cond_fun(val):
# val = body_fun(val)
# return val
def cond_fn(val): # val = (p_next, p_prev, n_iter)
return jnp.logical_and(
jnp.logical_and(val[2] > 0, val[2] < max_iter),
jnp.linalg.norm(val[0] - val[1]) >= tol,
)
def body_fn(val):
pYX = pYX_pY * val[0] # p_next=val[0] takes the role of p_prev
p_prev = p_trn
for n_iter in range(max_iter):
pYX = pYX_pY * p_prev
pYX = pYX / pYX.sum(axis=-1, keepdims=True) # normalize to posterior probabilities
p_next = pYX.mean(axis=0) # shape (n_classes,)
return (p_next, val[0], val[2]+1)
p_est, _, n_iter = jax.lax.while_loop(cond_fn, body_fn, init_val=(p_trn, p_trn, 0))

# convert to a Result type with meta-data
p_next = pYX.mean(axis=-2, keepdims=True) # shape (n_bags, [1], n_classes)
if tol is not None:
if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol):
if omit_result_conversion:
return jnp.squeeze(p_next, axis=-2)
return Result(
jnp.squeeze(p_next, axis=-2),
n_iter+1,
"Optimization terminated successfully."
)
p_prev = p_next
if omit_result_conversion:
return p_est
if n_iter < max_iter:
msg = "Optimization terminated successfully."
else:
msg = "Maximum number of iterations reached."
return Result(p_est, n_iter, msg)
return jnp.squeeze(p_prev, axis=-2)
return Result(
jnp.squeeze(p_prev, axis=-2),
max_iter,
"Maximum number of iterations reached."
)
30 changes: 28 additions & 2 deletions qunfold/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,34 @@ def __init__(self, X, p):
def __call__(self):
yield self.X, self.p

class TestExpectationMaximizer(TestCase):
  """Tests for the `maximize_expectation` routine of the likelihood methods."""

  def test_maximize_expectation(self):
    """Compare one joint call on two stacked bags with two single-bag calls.

    For five random problems, a random forest is fitted on training data
    and used to predict posteriors for two test bags whose prevalences are
    permutations of the training prevalences. The estimates obtained from
    the stacked `(2, n_items, n_classes)` input must be exactly equal to
    the estimates obtained from the two bags individually.
    """
    em = qunfold.methods.likelihood.maximize_expectation
    for _trial in range(5):
      # The RNG-consuming calls below run in a fixed order (problem, training
      # data, bag A, bag B, forest seed) so that the random stream is stable.
      _q, M, p_trn = make_problem()
      X_trn, y_trn = generate_data(M, p_trn)

      # draw the two test bags, each with shuffled training prevalences
      bags = []
      for _bag in range(2):
        p_tst = RNG.permutation(p_trn)
        X_tst, _y_tst = generate_data(M, p_tst)
        bags.append(X_tst)

      seed = RNG.randint(np.iinfo("uint16").max)
      forest = RandomForestClassifier(oob_score=True, random_state=seed)
      forest = forest.fit(X_trn, y_trn)
      posteriors = [forest.predict_proba(X_tst) for X_tst in bags]

      # tol=None is passed so that both variants run exactly max_iter=10
      # iterations (no convergence check can stop one of them early)
      separate = np.array([
        em(pYX, p_trn, max_iter=10, tol=None)
        for pYX in posteriors
      ])
      joint = em(np.array(posteriors), p_trn, max_iter=10, tol=None)
      np.testing.assert_array_equal(joint, separate)

class TestQuaPyWrapper(TestCase):
def test_methods(self):
for _ in range(5):
Expand Down Expand Up @@ -160,7 +188,6 @@ def test_methods(self):
error = "mae",
refit = False,
verbose = True,
raise_errors = True,
).fit(qp.data.LabelledCollection(X_trn, y_trn))
self.assertEqual( # check that best parameters are actually used
cv_acc.best_params_["representation__classifier__estimator__C"],
Expand All @@ -175,7 +202,6 @@ def test_methods(self):
error = "mae",
refit = False,
verbose = True,
raise_errors = True,
).fit(qp.data.LabelledCollection(X_trn, y_trn))
self.assertEqual( # check that best parameters are actually used
cv_sld.best_params_["classifier__C"],
Expand Down

0 comments on commit 7aba0e4

Please sign in to comment.