diff --git a/qunfold/methods/likelihood.py b/qunfold/methods/likelihood.py index 2b64fae..ee94eb7 100644 --- a/qunfold/methods/likelihood.py +++ b/qunfold/methods/likelihood.py @@ -96,23 +96,21 @@ def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8): pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`. Multiple bags, with shape `(n_bags, n_items_per_bag, n_classes)` are also supported. p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`. max_iter (optional): The maximum number of iterations. Defaults to `100`. - tol (optional): The convergence tolerance for the L2 norm between iterations. Defaults to `1e-8`. + tol (optional): The convergence tolerance for the L2 norm between iterations or None to disable convergence checks. Defaults to `1e-8`. """ pYX_pY = pYX / p_trn # P(Y|X) / P_trn(Y) - p_prev = jnp.expand_dims( # reshape to (n_bags, [1], n_classes), where n_bags might be 1 - p_trn, - list(jnp.arange(len(pYX.shape)-1)) - ) + p_prev = p_trn for n_iter in range(max_iter): pYX = pYX_pY * p_prev pYX = pYX / pYX.sum(axis=-1, keepdims=True) # normalize to posterior probabilities p_next = pYX.mean(axis=-2, keepdims=True) # shape (n_bags, [1], n_classes) - if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol): - return Result( - jnp.squeeze(p_next, axis=-2), - n_iter+1, - "Optimization terminated successfully." - ) + if tol is not None: + if jnp.all(jnp.linalg.norm(p_next - p_prev, axis=-1) < tol): + return Result( + jnp.squeeze(p_next, axis=-2), + n_iter+1, + "Optimization terminated successfully." + ) p_prev = p_next return Result( jnp.squeeze(p_prev, axis=-2),