From ada86ff5656dbfa2e7e8595615e0b468fb52d72f Mon Sep 17 00:00:00 2001
From: minhtrung23
Date: Fri, 20 Sep 2024 18:07:14 +0700
Subject: [PATCH] Update __reasoning.py

src\melt\tools\pipelines\__reasoning.py:5:0: R0914: Too many local variables (24/15) (too-many-locals)
---
 src/melt/tools/pipelines/__reasoning.py | 272 ++++++++++--------------
 1 file changed, 112 insertions(+), 160 deletions(-)

diff --git a/src/melt/tools/pipelines/__reasoning.py b/src/melt/tools/pipelines/__reasoning.py
index 97e3f8f..d36a749 100644
--- a/src/melt/tools/pipelines/__reasoning.py
+++ b/src/melt/tools/pipelines/__reasoning.py
@@ -1,184 +1,136 @@
-" _reasoning"
+"reasoning"
 import random
-from dataclasses import dataclass
 from tqdm import tqdm
-from utils.utils import format_fewshot
-
-@dataclass
-class ReasoningConfig:
-    "class"
-    config: any
-    task_name: str
-    continue_infer_data: dict = None
-
-class FewShotManager:
-    "class"
-    def additional_method(self):
-        """
-        Another public method to satisfy the two-method requirement.
-        """
-        print("This is an additional public method.")
-    def __init__(self, ds_wrapper, config):
-        self.ds_wrapper = ds_wrapper
-        self.config = config
-        self.selected_sample = []
-        self.original_few_shot = []
-        self.calib_few_shot = []
-    def prepare_few_shot(self):
-        "pre"
-        if not self.config.few_shot:
-            return
+from melt.tools.utils.utils import format_fewshot
+def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
+    predictions = []
+    references = []
+    generation_probs = []
+    calib_probs = []
+    idx = 0
+    original_few_shot = []
+    calib_few_shot = []
+    selected_sample = []
+
+    if self.continue_infer_data is not None:
+        predictions.extend(self.continue_infer_data["predictions"])
+        references.extend(self.continue_infer_data["references"])
+        generation_probs.extend(self.continue_infer_data["generation_probs"])
+        calib_probs.extend(self.continue_infer_data["calibration_probs"])
+
+    if self.few_shot:
         def preprocessing_a_record(rec):
             return [
-                rec[self.ds_wrapper.dataset_info.query],
-                rec[self.ds_wrapper.dataset_info.answer],
+                rec[ds_wrapper.dataset_info.query],
+                rec[ds_wrapper.dataset_info.answer],
             ]
-        self.selected_sample = [
+        selected_sample = [
             preprocessing_a_record(s)
-            for s in random.sample(list(self.ds_wrapper.dataset_training), self.config.num_fs)
+            for s in list(
+                random.sample(
+                    list(ds_wrapper.dataset_training), self.config.num_fs
+                )
+            )
         ]
-        self.original_few_shot = format_fewshot(
-            self.selected_sample,
-            query_format=self.ds_wrapper.prompt["prompt"],
-            answer_format=self.ds_wrapper.prompt["answer_format"],
+        original_few_shot = format_fewshot(
+            selected_sample,
+            query_format=ds_wrapper.prompt["prompt"],
+            answer_format=ds_wrapper.prompt["answer_format"],
         )
-        self.calib_few_shot = format_fewshot(
-            self.selected_sample,
-            query_format=self.ds_wrapper.calibration_prompt["prompt"],
-            answer_format=self.ds_wrapper.prompt["answer_format"],
+        calib_few_shot = format_fewshot(
+            selected_sample,
+            query_format=ds_wrapper.calibration_prompt["prompt"],
+            answer_format=ds_wrapper.prompt["answer_format"],
         )
-class ResultsManager:
-    "class"
-    def __init__(self, continue_infer_data=None):
-        self.predictions = []
-        self.references = []
-        self.generation_probs = []
-        self.calib_probs = []
-
-        if continue_infer_data:
-            self.predictions.extend(continue_infer_data["predictions"])
-            self.references.extend(continue_infer_data["references"])
-            self.generation_probs.extend(continue_infer_data["generation_probs"])
-            self.calib_probs.extend(continue_infer_data["calibration_probs"])
-
-    def extend_results(self, batch_results, batch_references, batch_logprobs, batch_calibprobs):
-        "extend"
-        self.predictions.extend(batch_results)
-        self.references.extend(batch_references)
-        self.generation_probs.extend(batch_logprobs)
-        self.calib_probs.extend(batch_calibprobs)
-
-    def get_generations(self, few_shot_sample):
-        "get"
-        return {
-            "predictions": self.predictions,
-            "references": self.references,
-            "generation_probs": self.generation_probs,
-            "calibration_probs": self.calib_probs,
-            "fewshot": few_shot_sample,
-        }
+    for batch in tqdm(ds_loader):
+        if idx < start_idx:
+            idx += 1
+            continue
-
-class ReasoningPipeline:
-    "class"
-    def additional_method2(self):
-        """
-        Another public method to satisfy the two-method requirement.
-        """
-        print("This is an additional public method.")
-    def additional_method3(self):
-        """
-        Another public method to satisfy the two-method requirement.
-        """
-        print("This is an additional public method.")
-    def __init__(self, reasoning_config: ReasoningConfig, infer_pipeline, metric_pipeline):
-        self.config = reasoning_config.config
-        self.task_name = reasoning_config.task_name
-        self.infer_pipeline = infer_pipeline
-        self.metric_pipeline = metric_pipeline
-        self.continue_infer_data = reasoning_config.continue_infer_data
-
-    def _reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
-        few_shot_manager = FewShotManager(ds_wrapper, self.config)
-        few_shot_manager.prepare_few_shot()
-
-        results_manager = ResultsManager(self.continue_infer_data)
-
-        for idx, batch in enumerate(tqdm(ds_loader)):
-            if idx < start_idx:
-                continue
-
-            prompts = self._create_prompts(batch, ds_wrapper, few_shot_manager.original_few_shot)
-            calib_prompts = self._create_calib_prompts(batch,
-                ds_wrapper, few_shot_manager.calib_few_shot)
-
-            results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
-            calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
-                calib_prompts, batch[ds_wrapper.dataset_info.answer]
-            )
-
-            results_manager.extend_results(
-                results,
-                batch[ds_wrapper.dataset_info.answer],
-                logprobs,
-                calibprob_batch
-            )
-
-            if (idx + 1) % 100 == 0:
-                self._save_intermediate_results(idx + 1, results_manager, ds_wrapper, saving_fn)
-
-        self._save_final_results(results_manager, ds_wrapper, saving_fn)
-
-    def _create_prompts(self, batch, ds_wrapper, few_shot):
-        return [
+        prompts = [
             [
-                {"role": "system", "content": ds_wrapper.prompt["system_prompt"]},
-                *few_shot,
-                {"role": "user", "content": ds_wrapper.prompt["prompt"].format(rule)},
+                {
+                    "role": "system",
+                    "content": ds_wrapper.prompt["system_prompt"],
+                },
+                *original_few_shot,
+                {
+                    "role": "user",
+                    "content": ds_wrapper.prompt["prompt"].format(rule),
+                },
             ]
             for rule in batch[ds_wrapper.dataset_info.query]
         ]
-
-    def _create_calib_prompts(self, batch, ds_wrapper, calib_few_shot):
-        return [
+        calib_prompts = [
             [
-                {"role": "system", "content": ds_wrapper.calibration_prompt["system_prompt"]},
+                {
+                    "role": "system",
+                    "content": ds_wrapper.calibration_prompt["system_prompt"],
+                },
                 *calib_few_shot,
-                {"role": "user", "content": ds_wrapper.calibration_prompt["prompt"].format(rule)},
+                {
+                    "role": "user",
+                    "content": ds_wrapper.calibration_prompt["prompt"].format(rule),
+                },
             ]
             for rule in batch[ds_wrapper.dataset_info.query]
         ]
-    def _save_intermediate_results(self, batch_count, results_manager, ds_wrapper, saving_fn):
-        print(f"Saving results of {batch_count} batches")
-        generations = results_manager.get_generations(results_manager.selected_sample)
-        saving_fn(generations)
-        mean_result = self._calculate_mean_result(generations, ds_wrapper)
-        print(f"Results of {batch_count} batches: ", mean_result)
-
-    def _save_final_results(self, results_manager, ds_wrapper, saving_fn):
-        generations = results_manager.get_generations(results_manager.selected_sample)
-        mean_result = self._calculate_mean_result(generations, ds_wrapper)
-        std_result = self._calculate_std_result(generations, ds_wrapper)
-        final_result = {"mean": mean_result, "std": std_result}
-        saving_fn(generations, final_result)
-
-    def _calculate_mean_result(self, generations, ds_wrapper):
-        return self.metric_pipeline.run_mean(
-            generations,
-            self.task_name,
-            ds_wrapper.prompt["answer_key"],
-            ds_wrapper.dataset_info.label,
-            self.config,
-        )
-
-    def _calculate_std_result(self, generations, ds_wrapper):
-        return self.metric_pipeline.run_std(
-            generations,
-            self.task_name,
-            ds_wrapper.prompt["answer_key"],
-            ds_wrapper.dataset_info.label,
-            self.config,
+        results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
+        calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
+            calib_prompts, batch[ds_wrapper.dataset_info.answer]
         )
+        predictions.extend(results)
+        references.extend(list(batch[ds_wrapper.dataset_info.answer]))
+        generation_probs.extend(logprobs)
+        calib_probs.extend(calibprob_batch)
+
+        idx += 1
+        if idx % 100 == 0:
+            print(f"Saving results of {idx} batches")
+            generations = {
+                "predictions": predictions,
+                "references": references,
+                "generation_probs": generation_probs,
+                "calibration_probs": calib_probs,
+                "fewshot": selected_sample,
+            }
+
+            saving_fn(generations)
+            mean_result = self.metric_pipeline.run_mean(
+                generations,
+                self.task_name,
+                ds_wrapper.prompt["answer_key"],
+                ds_wrapper.dataset_info.label,
+                self.config,
+            )
+            print(f"Results of {idx} batches: ", mean_result)
+
+    generations = {
+        "predictions": predictions,
+        "references": references,
+        "generation_probs": generation_probs,
+        "calibration_probs": calib_probs,
+        "fewshot": selected_sample,
+    }
+
+    mean_result = self.metric_pipeline.run_mean(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+    std_result = self.metric_pipeline.run_std(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+
+    final_result = {"mean": mean_result, "std": std_result}
+    saving_fn(generations, final_result)
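
Note on the R0914 warning quoted in the commit message: pylint flags __reasoning because the single function keeps many parallel accumulator lists (predictions, references, generation_probs, calib_probs, and so on) as locals. One low-churn way to get under the limit, without reinstating the full class hierarchy this patch removes, is to bundle the accumulators into one container. The sketch below is illustrative only: it assumes the surrounding pipeline API shown in the diff, and the name GenerationResults does not exist in the codebase.

    from dataclasses import dataclass, field

    @dataclass
    class GenerationResults:
        """Bundles the per-run accumulators that __reasoning keeps as locals."""
        predictions: list = field(default_factory=list)
        references: list = field(default_factory=list)
        generation_probs: list = field(default_factory=list)
        calib_probs: list = field(default_factory=list)

        def extend_batch(self, results, refs, logprobs, calibprobs):
            """Append one batch of pipeline outputs to every accumulator."""
            self.predictions.extend(results)
            self.references.extend(refs)
            self.generation_probs.extend(logprobs)
            self.calib_probs.extend(calibprobs)

        def as_generations(self, fewshot):
            """Build the dict handed to saving_fn and the metric pipeline."""
            return {
                "predictions": self.predictions,
                "references": self.references,
                "generation_probs": self.generation_probs,
                "calibration_probs": self.calib_probs,
                "fewshot": fewshot,
            }

With something like this, the loop body collapses to a single extend_batch(...) call and both saving sites to saving_fn(results.as_generations(selected_sample)), which trims the local count while keeping the single-function shape this patch adopts.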