Amendment to the last commit; allow tolerance checks to be entirely disabled
mirkobunse committed Sep 21, 2024
1 parent 88212f7 commit ad63445
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions qunfold/methods/likelihood.py
@@ -96,23 +96,21 @@ def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8):
         pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. Multiple bags, with shape `(n_bags, n_items_per_bag, n_classes)`, are also supported.
         p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`.
         max_iter (optional): The maximum number of iterations. Defaults to `100`.
-        tol (optional): The convergence tolerance for the L2 norm between iterations. Defaults to `1e-8`.
+        tol (optional): The convergence tolerance for the L2 norm between iterations, or `None` to disable convergence checks. Defaults to `1e-8`.
     """
     pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y)
-    p_prev = jnp.expand_dims( # reshape to (n_bags, [1], n_classes), where n_bags might be 1
-        p_trn,
-        list(jnp.arange(len(pYX.shape)-1))
-    )
+    p_prev = p_trn
     for n_iter in range(max_iter):
         pYX = pYX_pY * p_prev
         pYX = pYX / pYX.sum(axis=-1, keepdims=True) # normalize to posterior probabilities
         p_next = pYX.mean(axis=-2, keepdims=True) # shape (n_bags, [1], n_classes)
-        if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol):
-            return Result(
-                jnp.squeeze(p_next, axis=-2),
-                n_iter+1,
-                "Optimization terminated successfully."
-            )
+        if tol is not None:
+            if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol):
+                return Result(
+                    jnp.squeeze(p_next, axis=-2),
+                    n_iter+1,
+                    "Optimization terminated successfully."
+                )
         p_prev = p_next
     return Result(
         jnp.squeeze(p_prev, axis=-2),
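For readers who want to try the new behavior, here is a minimal usage sketch. The import path is assumed from the file changed in this commit (`qunfold/methods/likelihood.py`), and the toy data is illustrative, not taken from the diff:

```python
import jax.numpy as jnp

# assumed import path, based on the file changed in this commit
from qunfold.methods.likelihood import maximize_expectation

# toy posteriors for one bag of 4 items over 3 classes (illustrative data)
pYX = jnp.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
p_trn = jnp.array([1/3, 1/3, 1/3]) # uniform training-set priors

# default behavior: stop early once the L2 norm between iterations is < tol
result_checked = maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8)

# new behavior of this commit: tol=None disables the convergence check,
# so exactly max_iter expectation maximization steps are performed
result_fixed = maximize_expectation(pYX, p_trn, max_iter=50, tol=None)
```

Disabling the check presumably guarantees a fixed iteration count, which can be handy for reproducible runtime comparisons across methods.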
