Remove outdated TODOs in LikelihoodMaximization
mirkobunse committed Sep 20, 2024
1 parent a3741d0 commit f159a88
Showing 1 changed file with 5 additions and 18 deletions.
qunfold/methods/likelihood.py: 23 changes (5 additions & 18 deletions)
@@ -22,7 +22,6 @@ def __init__(
       solver_options = {"gtol": 1e-8, "maxiter": 1000}, # , "disp": True
       tau_0 = 0,
       tau_1 = 0,
-      epsilon = 0,
       fit_classifier = True,
       seed = None,
     ):
@@ -31,7 +30,6 @@ def __init__(
     self.solver_options = solver_options
     self.tau_0 = tau_0
     self.tau_1 = tau_1
-    self.epsilon = epsilon
     self.fit_classifier = fit_classifier
     self.seed = seed
   def fit(self, X, y, n_classes=None):
@@ -43,21 +41,10 @@ def fit(self, X, y, n_classes=None):
   def predict(self, X):
     pXY = jnp.array(self.classifier.predict_proba(X) / self.p_trn) # proportional to P(X|Y)
     pXY = pXY / pXY.sum(axis=1, keepdims=True) # normalize to P(X|Y)
-
-    # TODO 1) filter out all rows from pXY that contain zeros or ones, or values close to zero or one up to some self.epsilon. Goal: to reduce thrown errors / warnings and to replace the corresponding estimates with proper ones.
-    # epsilon_filtered_rows = pXY[jnp.any(pXY <= self.epsilon, axis=1), :]
-    # pXY = pXY[jnp.all(pXY > self.epsilon, axis=1),:]
-
-    # TODO 2) "side-chain" those rows that have contained values close to one, by setting up a classify-and-count estimate that is later added to opt.x. Appropriately weight both the CC estimate and opt.x by the fraction of rows that has led to each of these estimates. Goal: to further improve the estimation (see also TODO 3).
-
-    # TODO 3) consider self.epsilon as a hyper-parameter, assuming that all fairly confident predictions are probably correct, not only the extremely confident exceptions.
-
-    # minimize the negative log-likelihood loss
-    def loss(p):
-      # xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
-      # xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
-      # return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
-      return -jnp.log(pXY @ p).mean()
+    def loss(p): # the (regularized) negative log-likelihood loss
+      xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
+      xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
+      return -jnp.log(pXY @ p).mean() + self.tau_0 * xi_0 + self.tau_1 * xi_1
     return minimize(
       loss,
       len(self.p_trn),
@@ -67,7 +54,7 @@ def loss(p):
     )

 class ExpectationMaximizer(AbstractMethod):
-  """The expectation maximization-based method by Saerens et al. (2002).
+  """The expectation maximization-based method by Saerens et al. (2002).
   This method is proven to be asymptotically equivalent to the `LikelihoodMaximizer` by Alexandari et al. (2020).
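
For context, the regularized loss introduced above can be exercised outside of the class. The following is a minimal, self-contained sketch, not the package's own `minimize` helper: it assumes a row-normalized matrix pXY proportional to P(x_i | Y) and treats tau_0 and tau_1 as the two regularization strengths from the diff. A softmax parametrization keeps the prevalence estimate on the probability simplex.

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize as sp_minimize

def solve_regularized_nll(pXY, tau_0=0., tau_1=0.):
  pXY = jnp.asarray(pXY) # (N, C), rows proportional to P(x_i | Y)
  def loss(l): # l are latent logits; p = softmax(l) lies on the simplex
    p = jax.nn.softmax(l)
    xi_0 = jnp.sum((p[1:] - p[:-1])**2) / 2 # deviation from a uniform prediction
    xi_1 = jnp.sum((-p[:-2] + 2 * p[1:-1] - p[2:])**2) / 2 # deviation from non-ordinal
    return -jnp.log(pXY @ p).mean() + tau_0 * xi_0 + tau_1 * xi_1
  grad = jax.grad(loss)
  res = sp_minimize(
    lambda l: float(loss(jnp.asarray(l))),
    np.zeros(pXY.shape[1]), # start from uniform logits
    jac = lambda l: np.asarray(grad(jnp.asarray(l)), dtype=np.float64),
    method = "L-BFGS-B",
  )
  return np.asarray(jax.nn.softmax(jnp.asarray(res.x)))

For instance, with pXY = classifier.predict_proba(X) / p_trn (row-normalized), solve_regularized_nll(pXY) returns a prevalence estimate that sums to one.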
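For comparison, the EM loop of Saerens et al. (2002) admits an equally compact sketch. The function name and the convergence check below are illustrative, not the class's actual implementation; pYX holds the classifier's posteriors P(Y | x_i) obtained under the training prevalences p_trn, which are assumed to be strictly positive.

import jax.numpy as jnp

def saerens_em(pYX, p_trn, n_iter=100, tol=1e-8):
  pYX = jnp.asarray(pYX) # (N, C) posteriors P(Y | x_i) under p_trn
  p_trn = jnp.asarray(p_trn) # training prevalences, all positive
  p = p_trn # start from the training prevalences
  for _ in range(n_iter):
    w = pYX * (p / p_trn) # re-weight posteriors with the current prior
    w = w / w.sum(axis=1, keepdims=True) # E step: re-normalize per sample
    p_next = w.mean(axis=0) # M step: update the prevalence estimate
    if jnp.max(jnp.abs(p_next - p)) < tol:
      return p_next
    p = p_next
  return p

Each iteration is a standard EM step and never decreases the likelihood, consistent with the asymptotic equivalence to the `LikelihoodMaximizer` noted in the docstring.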
