Skip to content

Commit

Permalink
src\alchemlyb\workflows\abfe.py too-many-branches
Browse files Browse the repository at this point in the history
Both methods generate_result and estimate had a bit too many branches.
I extracted methods to reduce complexity and make code clearer.
  • Loading branch information
evidencebp committed Oct 10, 2024
1 parent 3a311d4 commit ea1ecca
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions src/alchemlyb/workflows/abfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs):
logger.warning("u_nk has not been preprocessed.")
logger.info(f"A total {len(u_nk)} lines of u_nk is used.")

self._fit_estimators(dHdl, estimators, kwargs, u_nk)

def _fit_estimators(self, dHdl, estimators, kwargs, u_nk):
for estimator in estimators:
if estimator == "MBAR":
logger.info("Run MBAR estimator.")
Expand Down Expand Up @@ -553,11 +556,7 @@ def generate_result(self):
stages = dHdl.reset_index("time").index.names
logger.info("use the stage name from dHdl")

for stage in stages:
data_dict["name"].append(stage.split("-")[0])
data_dict["state"].append("Stages")
data_dict["name"].append("TOTAL")
data_dict["state"].append("Stages")
self.handle_stages(data_dict, stages)

col_names = []
for estimator_name, estimator in self.estimator.items():
Expand All @@ -572,11 +571,7 @@ def generate_result(self):
col_names.append(estimator_name + "_Error")
data_dict[estimator_name] = []
data_dict[estimator_name + "_Error"] = []
for index in range(1, num_states):
data_dict[estimator_name].append(delta_f_.iloc[index - 1, index])
data_dict[estimator_name + "_Error"].append(
d_delta_f_.iloc[index - 1, index]
)
self.handle_states(d_delta_f_, data_dict, delta_f_, estimator_name, num_states)

logger.info(f"Generate the staged result from estimator {estimator_name}")
for index, stage in enumerate(stages):
Expand Down Expand Up @@ -642,6 +637,20 @@ def generate_result(self):
logger.info(f"Write results:\n{summary.to_string()}")
return summary

def handle_states(self, d_delta_f_, data_dict, delta_f_, estimator_name, num_states):
for index in range(1, num_states):
data_dict[estimator_name].append(delta_f_.iloc[index - 1, index])
data_dict[estimator_name + "_Error"].append(
d_delta_f_.iloc[index - 1, index]
)

def handle_stages(self, data_dict, stages):
for stage in stages:
data_dict["name"].append(stage.split("-")[0])
data_dict["state"].append("Stages")
data_dict["name"].append("TOTAL")
data_dict["state"].append("Stages")

def plot_overlap_matrix(self, overlap="O_MBAR.pdf", ax=None):
"""Plot the overlap matrix for MBAR estimator using
:func:`~alchemlyb.visualisation.plot_mbar_overlap_matrix`.
Expand Down

0 comments on commit ea1ecca

Please sign in to comment.