forked from stair-lab/melt
Commit
src\melt\tools\pipelines\__reasoning.py:5:0: R0914: Too many local variables (24/15) (too-many-locals)
1 parent 6b2f1b1 · commit ada86ff
Showing 1 changed file with 112 additions and 160 deletions.
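For context, pylint's R0914 fires when a single function defines more local variables than the max-locals limit (15 by default; pylint counts 24 in __reasoning here). The usual fix is to group or extract locals, but the check can also be relaxed per function. A minimal sketch of the inline suppression, purely illustrative and not part of this commit:

    # Illustrative only: silence R0914 for this one function.
    # The inline pragma applies to the enclosing function scope, not the module.
    def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):  # pylint: disable=too-many-locals
        ...

Project-wide, the same threshold is the max-locals option in the [DESIGN] section of .pylintrc.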
src/melt/tools/pipelines/__reasoning.py
@@ -1,184 +1,136 @@
" _reasoning" | ||
"reasoning" | ||
import random | ||
from dataclasses import dataclass | ||
from tqdm import tqdm | ||
from utils.utils import format_fewshot | ||
|
||
@dataclass | ||
class ReasoningConfig: | ||
"class" | ||
config: any | ||
task_name: str | ||
continue_infer_data: dict = None | ||
|
||
class FewShotManager: | ||
"class" | ||
def additional_method(self): | ||
""" | ||
Another public method to satisfy the two-method requirement. | ||
""" | ||
print("This is an additional public method.") | ||
def __init__(self, ds_wrapper, config): | ||
self.ds_wrapper = ds_wrapper | ||
self.config = config | ||
self.selected_sample = [] | ||
self.original_few_shot = [] | ||
self.calib_few_shot = [] | ||
def prepare_few_shot(self): | ||
"pre" | ||
if not self.config.few_shot: | ||
return | ||
from melt.tools.utils.utils import format_fewshot | ||
def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): | ||
predictions = [] | ||
references = [] | ||
generation_probs = [] | ||
calib_probs = [] | ||
idx = 0 | ||
original_few_shot = [] | ||
calib_few_shot = [] | ||
selected_sample = [] | ||
+
+    if self.continue_infer_data is not None:
+        predictions.extend(self.continue_infer_data["predictions"])
+        references.extend(self.continue_infer_data["references"])
+        generation_probs.extend(self.continue_infer_data["generation_probs"])
+        calib_probs.extend(self.continue_infer_data["calibration_probs"])
+
+    if self.few_shot:
 
         def preprocessing_a_record(rec):
            return [
-                rec[self.ds_wrapper.dataset_info.query],
-                rec[self.ds_wrapper.dataset_info.answer],
+                rec[ds_wrapper.dataset_info.query],
+                rec[ds_wrapper.dataset_info.answer],
             ]
 
-        self.selected_sample = [
+        selected_sample = [
             preprocessing_a_record(s)
-            for s in random.sample(list(self.ds_wrapper.dataset_training), self.config.num_fs)
+            for s in list(
+                random.sample(
+                    list(ds_wrapper.dataset_training), self.config.num_fs
+                )
+            )
         ]
-        self.original_few_shot = format_fewshot(
-            self.selected_sample,
-            query_format=self.ds_wrapper.prompt["prompt"],
-            answer_format=self.ds_wrapper.prompt["answer_format"],
+        original_few_shot = format_fewshot(
+            selected_sample,
+            query_format=ds_wrapper.prompt["prompt"],
+            answer_format=ds_wrapper.prompt["answer_format"],
         )
-        self.calib_few_shot = format_fewshot(
-            self.selected_sample,
-            query_format=self.ds_wrapper.calibration_prompt["prompt"],
-            answer_format=self.ds_wrapper.prompt["answer_format"],
+        calib_few_shot = format_fewshot(
+            selected_sample,
+            query_format=ds_wrapper.calibration_prompt["prompt"],
+            answer_format=ds_wrapper.prompt["answer_format"],
         )
 
-
-class ResultsManager:
-    "class"
-    def __init__(self, continue_infer_data=None):
-        self.predictions = []
-        self.references = []
-        self.generation_probs = []
-        self.calib_probs = []
-
-        if continue_infer_data:
-            self.predictions.extend(continue_infer_data["predictions"])
-            self.references.extend(continue_infer_data["references"])
-            self.generation_probs.extend(continue_infer_data["generation_probs"])
-            self.calib_probs.extend(continue_infer_data["calibration_probs"])
-
-    def extend_results(self, batch_results, batch_references, batch_logprobs, batch_calibprobs):
-        "extend"
-        self.predictions.extend(batch_results)
-        self.references.extend(batch_references)
-        self.generation_probs.extend(batch_logprobs)
-        self.calib_probs.extend(batch_calibprobs)
-
-    def get_generations(self, few_shot_sample):
-        "get"
-        return {
-            "predictions": self.predictions,
-            "references": self.references,
-            "generation_probs": self.generation_probs,
-            "calibration_probs": self.calib_probs,
-            "fewshot": few_shot_sample,
-        }
+    for batch in tqdm(ds_loader):
+        if idx < start_idx:
+            idx += 1
+            continue
 
-
-class ReasoningPipeline:
-    "class"
-    def additional_method2(self):
-        """
-        Another public method to satisfy the two-method requirement.
-        """
-        print("This is an additional public method.")
-    def additional_method3(self):
-        """
-        Another public method to satisfy the two-method requirement.
-        """
-        print("This is an additional public method.")
-    def __init__(self, reasoning_config: ReasoningConfig, infer_pipeline, metric_pipeline):
-        self.config = reasoning_config.config
-        self.task_name = reasoning_config.task_name
-        self.infer_pipeline = infer_pipeline
-        self.metric_pipeline = metric_pipeline
-        self.continue_infer_data = reasoning_config.continue_infer_data
-
-    def _reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
-        few_shot_manager = FewShotManager(ds_wrapper, self.config)
-        few_shot_manager.prepare_few_shot()
-
-        results_manager = ResultsManager(self.continue_infer_data)
-
-        for idx, batch in enumerate(tqdm(ds_loader)):
-            if idx < start_idx:
-                continue
-
-            prompts = self._create_prompts(batch, ds_wrapper, few_shot_manager.original_few_shot)
-            calib_prompts = self._create_calib_prompts(batch,
-                ds_wrapper, few_shot_manager.calib_few_shot)
-
-            results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
-            calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
-                calib_prompts, batch[ds_wrapper.dataset_info.answer]
-            )
-
-            results_manager.extend_results(
-                results,
-                batch[ds_wrapper.dataset_info.answer],
-                logprobs,
-                calibprob_batch
-            )
-
-            if (idx + 1) % 100 == 0:
-                self._save_intermediate_results(idx + 1, results_manager, ds_wrapper, saving_fn)
-
-        self._save_final_results(results_manager, ds_wrapper, saving_fn)
 
-    def _create_prompts(self, batch, ds_wrapper, few_shot):
-        return [
+        prompts = [
             [
-                {"role": "system", "content": ds_wrapper.prompt["system_prompt"]},
-                *few_shot,
-                {"role": "user", "content": ds_wrapper.prompt["prompt"].format(rule)},
+                {
+                    "role": "system",
+                    "content": ds_wrapper.prompt["system_prompt"],
+                },
+                *original_few_shot,
+                {
+                    "role": "user",
+                    "content": ds_wrapper.prompt["prompt"].format(rule),
+                },
             ]
             for rule in batch[ds_wrapper.dataset_info.query]
         ]
 
-    def _create_calib_prompts(self, batch, ds_wrapper, calib_few_shot):
-        return [
+        calib_prompts = [
             [
-                {"role": "system", "content": ds_wrapper.calibration_prompt["system_prompt"]},
+                {
+                    "role": "system",
+                    "content": ds_wrapper.calibration_prompt["system_prompt"],
+                },
                 *calib_few_shot,
-                {"role": "user", "content": ds_wrapper.calibration_prompt["prompt"].format(rule)},
+                {
+                    "role": "user",
+                    "content": ds_wrapper.calibration_prompt["prompt"].format(rule),
+                },
             ]
             for rule in batch[ds_wrapper.dataset_info.query]
         ]
 
-    def _save_intermediate_results(self, batch_count, results_manager, ds_wrapper, saving_fn):
-        print(f"Saving results of {batch_count} batches")
-        generations = results_manager.get_generations(results_manager.selected_sample)
-        saving_fn(generations)
-        mean_result = self._calculate_mean_result(generations, ds_wrapper)
-        print(f"Results of {batch_count} batches: ", mean_result)
-
-    def _save_final_results(self, results_manager, ds_wrapper, saving_fn):
-        generations = results_manager.get_generations(results_manager.selected_sample)
-        mean_result = self._calculate_mean_result(generations, ds_wrapper)
-        std_result = self._calculate_std_result(generations, ds_wrapper)
-        final_result = {"mean": mean_result, "std": std_result}
-        saving_fn(generations, final_result)
-
-    def _calculate_mean_result(self, generations, ds_wrapper):
-        return self.metric_pipeline.run_mean(
-            generations,
-            self.task_name,
-            ds_wrapper.prompt["answer_key"],
-            ds_wrapper.dataset_info.label,
-            self.config,
-        )
-
-    def _calculate_std_result(self, generations, ds_wrapper):
-        return self.metric_pipeline.run_std(
-            generations,
-            self.task_name,
-            ds_wrapper.prompt["answer_key"],
-            ds_wrapper.dataset_info.label,
-            self.config,
+        results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
+        calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
+            calib_prompts, batch[ds_wrapper.dataset_info.answer]
         )
+        predictions.extend(results)
+        references.extend(list(batch[ds_wrapper.dataset_info.answer]))
+        generation_probs.extend(logprobs)
+        calib_probs.extend(calibprob_batch)
+
+        idx += 1
+        if idx % 100 == 0:
+            print(f"Saving results of {idx} batches")
+            generations = {
+                "predictions": predictions,
+                "references": references,
+                "generation_probs": generation_probs,
+                "calibration_probs": calib_probs,
+                "fewshot": selected_sample,
+            }
+
+            saving_fn(generations)
+            mean_result = self.metric_pipeline.run_mean(
+                generations,
+                self.task_name,
+                ds_wrapper.prompt["answer_key"],
+                ds_wrapper.dataset_info.label,
+                self.config,
+            )
+            print(f"Results of {idx} batches: ", mean_result)
+
+    generations = {
+        "predictions": predictions,
+        "references": references,
+        "generation_probs": generation_probs,
+        "calibration_probs": calib_probs,
+        "fewshot": selected_sample,
+    }
+
+    mean_result = self.metric_pipeline.run_mean(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+    std_result = self.metric_pipeline.run_std(
+        generations,
+        self.task_name,
+        ds_wrapper.prompt["answer_key"],
+        ds_wrapper.dataset_info.label,
+        self.config,
+    )
+
+    final_result = {"mean": mean_result, "std": std_result}
+    saving_fn(generations, final_result)
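The deletions above remove the class-based layout in favor of a single flat function, which is exactly the shape that trips R0914. If the aim were to keep the flat control flow and still satisfy pylint, the standard refactor is to fold the four parallel accumulator lists into one container, much as the removed ResultsManager did. A minimal sketch under that assumption (GenerationRecord is a hypothetical name, not part of this commit):

    from dataclasses import dataclass, field

    @dataclass
    class GenerationRecord:
        # Replaces four parallel local lists with one named object,
        # bringing the local-variable count back under pylint's limit.
        predictions: list = field(default_factory=list)
        references: list = field(default_factory=list)
        generation_probs: list = field(default_factory=list)
        calib_probs: list = field(default_factory=list)

        def extend(self, results, refs, logprobs, calibprobs):
            # Mirrors ResultsManager.extend_results from the deleted code.
            self.predictions.extend(results)
            self.references.extend(refs)
            self.generation_probs.extend(logprobs)
            self.calib_probs.extend(calibprobs)

        def as_generations(self, fewshot):
            # Same dict shape the function above passes to saving_fn.
            return {
                "predictions": self.predictions,
                "references": self.references,
                "generation_probs": self.generation_probs,
                "calibration_probs": self.calib_probs,
                "fewshot": fewshot,
            }

Inside the batch loop, a single record.extend(results, batch[ds_wrapper.dataset_info.answer], logprobs, calibprob_batch) would then replace the four .extend(...) calls, and record.as_generations(selected_sample) would build the payload for saving_fn.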