From f017972e44f97e9119c407d16db5d3122df02f3f Mon Sep 17 00:00:00 2001 From: evidencebp Date: Thu, 10 Oct 2024 15:28:54 +0300 Subject: [PATCH] Fix _fit_estimators extraction _fit_estimators fails on E UnboundLocalError: cannot access local variable 'u_nk' where it is not associated with a value Since the goal is to reduce branches I revert that and extract the more isolated check_estimators_availability instead --- src/alchemlyb/workflows/abfe.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 9ee1f32b..187b8d06 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -444,11 +444,7 @@ def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs): if isinstance(estimators, str): estimators = (estimators,) - for estimator in estimators: - if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): - msg = f"Estimator {estimator} is not available in {FEP_ESTIMATORS + TI_ESTIMATORS}." - logger.error(msg) - raise ValueError(msg) + self.check_estimators_availability(estimators) logger.info(f"Start running estimator: {','.join(estimators)}.") self.estimator = {} @@ -469,9 +465,6 @@ def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs): logger.warning("u_nk has not been preprocessed.") logger.info(f"A total {len(u_nk)} lines of u_nk is used.") - self._fit_estimators(dHdl, estimators, kwargs, u_nk) - - def _fit_estimators(self, dHdl, estimators, kwargs, u_nk): for estimator in estimators: if estimator == "MBAR": logger.info("Run MBAR estimator.") @@ -487,6 +480,13 @@ def _fit_estimators(self, dHdl, estimators, kwargs, u_nk): logger.info("Run TI estimator.") self.estimator[estimator] = TI(**kwargs).fit(dHdl) + def check_estimators_availability(self, estimators): + for estimator in estimators: + if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): + msg = f"Estimator {estimator} is not available in {FEP_ESTIMATORS + TI_ESTIMATORS}." + logger.error(msg) + raise ValueError(msg) + def generate_result(self): """Summarise the result into a dataframe.