Skip to content

Commit

Permalink
Update __reasoning.py
Browse files Browse the repository at this point in the history
src\melt\tools\pipelines\__reasoning.py:5:0: R0914: Too many local variables (24/15) (too-many-locals)
  • Loading branch information
minhtrung23 authored Sep 20, 2024
1 parent 6b2f1b1 commit ada86ff
Showing 1 changed file with 112 additions and 160 deletions.
272 changes: 112 additions & 160 deletions src/melt/tools/pipelines/__reasoning.py
Original file line number Diff line number Diff line change
@@ -1,184 +1,136 @@
" _reasoning"
"reasoning"
import random
from dataclasses import dataclass
from tqdm import tqdm
from utils.utils import format_fewshot

@dataclass
class ReasoningConfig:
"class"
config: any
task_name: str
continue_infer_data: dict = None

class FewShotManager:
"class"
def additional_method(self):
"""
Another public method to satisfy the two-method requirement.
"""
print("This is an additional public method.")
def __init__(self, ds_wrapper, config):
self.ds_wrapper = ds_wrapper
self.config = config
self.selected_sample = []
self.original_few_shot = []
self.calib_few_shot = []
def prepare_few_shot(self):
"pre"
if not self.config.few_shot:
return
from melt.tools.utils.utils import format_fewshot
def __reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
predictions = []
references = []
generation_probs = []
calib_probs = []
idx = 0
original_few_shot = []
calib_few_shot = []
selected_sample = []

if self.continue_infer_data is not None:
predictions.extend(self.continue_infer_data["predictions"])
references.extend(self.continue_infer_data["references"])
generation_probs.extend(self.continue_infer_data["generation_probs"])
calib_probs.extend(self.continue_infer_data["calibration_probs"])

if self.few_shot:

def preprocessing_a_record(rec):
return [
rec[self.ds_wrapper.dataset_info.query],
rec[self.ds_wrapper.dataset_info.answer],
rec[ds_wrapper.dataset_info.query],
rec[ds_wrapper.dataset_info.answer],
]

self.selected_sample = [
selected_sample = [
preprocessing_a_record(s)
for s in random.sample(list(self.ds_wrapper.dataset_training), self.config.num_fs)
for s in list(
random.sample(
list(ds_wrapper.dataset_training), self.config.num_fs
)
)
]
self.original_few_shot = format_fewshot(
self.selected_sample,
query_format=self.ds_wrapper.prompt["prompt"],
answer_format=self.ds_wrapper.prompt["answer_format"],
original_few_shot = format_fewshot(
selected_sample,
query_format=ds_wrapper.prompt["prompt"],
answer_format=ds_wrapper.prompt["answer_format"],
)
self.calib_few_shot = format_fewshot(
self.selected_sample,
query_format=self.ds_wrapper.calibration_prompt["prompt"],
answer_format=self.ds_wrapper.prompt["answer_format"],
calib_few_shot = format_fewshot(
selected_sample,
query_format=ds_wrapper.calibration_prompt["prompt"],
answer_format=ds_wrapper.prompt["answer_format"],
)

class ResultsManager:
"class"
def __init__(self, continue_infer_data=None):
self.predictions = []
self.references = []
self.generation_probs = []
self.calib_probs = []

if continue_infer_data:
self.predictions.extend(continue_infer_data["predictions"])
self.references.extend(continue_infer_data["references"])
self.generation_probs.extend(continue_infer_data["generation_probs"])
self.calib_probs.extend(continue_infer_data["calibration_probs"])

def extend_results(self, batch_results, batch_references, batch_logprobs, batch_calibprobs):
"extend"
self.predictions.extend(batch_results)
self.references.extend(batch_references)
self.generation_probs.extend(batch_logprobs)
self.calib_probs.extend(batch_calibprobs)

def get_generations(self, few_shot_sample):
"get"
return {
"predictions": self.predictions,
"references": self.references,
"generation_probs": self.generation_probs,
"calibration_probs": self.calib_probs,
"fewshot": few_shot_sample,
}
for batch in tqdm(ds_loader):
if idx < start_idx:
idx += 1
continue

class ReasoningPipeline:
"class"
def additional_method2(self):
"""
Another public method to satisfy the two-method requirement.
"""
print("This is an additional public method.")
def additional_method3(self):
"""
Another public method to satisfy the two-method requirement.
"""
print("This is an additional public method.")
def __init__(self, reasoning_config: ReasoningConfig, infer_pipeline, metric_pipeline):
self.config = reasoning_config.config
self.task_name = reasoning_config.task_name
self.infer_pipeline = infer_pipeline
self.metric_pipeline = metric_pipeline
self.continue_infer_data = reasoning_config.continue_infer_data

def _reasoning(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
few_shot_manager = FewShotManager(ds_wrapper, self.config)
few_shot_manager.prepare_few_shot()

results_manager = ResultsManager(self.continue_infer_data)

for idx, batch in enumerate(tqdm(ds_loader)):
if idx < start_idx:
continue

prompts = self._create_prompts(batch, ds_wrapper, few_shot_manager.original_few_shot)
calib_prompts = self._create_calib_prompts(batch,
ds_wrapper, few_shot_manager.calib_few_shot)

results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
calib_prompts, batch[ds_wrapper.dataset_info.answer]
)

results_manager.extend_results(
results,
batch[ds_wrapper.dataset_info.answer],
logprobs,
calibprob_batch
)

if (idx + 1) % 100 == 0:
self._save_intermediate_results(idx + 1, results_manager, ds_wrapper, saving_fn)

self._save_final_results(results_manager, ds_wrapper, saving_fn)

def _create_prompts(self, batch, ds_wrapper, few_shot):
return [
prompts = [
[
{"role": "system", "content": ds_wrapper.prompt["system_prompt"]},
*few_shot,
{"role": "user", "content": ds_wrapper.prompt["prompt"].format(rule)},
{
"role": "system",
"content": ds_wrapper.prompt["system_prompt"],
},
*original_few_shot,
{
"role": "user",
"content": ds_wrapper.prompt["prompt"].format(rule),
},
]
for rule in batch[ds_wrapper.dataset_info.query]
]

def _create_calib_prompts(self, batch, ds_wrapper, calib_few_shot):
return [
calib_prompts = [
[
{"role": "system", "content": ds_wrapper.calibration_prompt["system_prompt"]},
{
"role": "system",
"content": ds_wrapper.calibration_prompt["system_prompt"],
},
*calib_few_shot,
{"role": "user", "content": ds_wrapper.calibration_prompt["prompt"].format(rule)},
{
"role": "user",
"content": ds_wrapper.calibration_prompt["prompt"].format(rule),
},
]
for rule in batch[ds_wrapper.dataset_info.query]
]

def _save_intermediate_results(self, batch_count, results_manager, ds_wrapper, saving_fn):
print(f"Saving results of {batch_count} batches")
generations = results_manager.get_generations(results_manager.selected_sample)
saving_fn(generations)
mean_result = self._calculate_mean_result(generations, ds_wrapper)
print(f"Results of {batch_count} batches: ", mean_result)

def _save_final_results(self, results_manager, ds_wrapper, saving_fn):
generations = results_manager.get_generations(results_manager.selected_sample)
mean_result = self._calculate_mean_result(generations, ds_wrapper)
std_result = self._calculate_std_result(generations, ds_wrapper)
final_result = {"mean": mean_result, "std": std_result}
saving_fn(generations, final_result)

def _calculate_mean_result(self, generations, ds_wrapper):
return self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)

def _calculate_std_result(self, generations, ds_wrapper):
return self.metric_pipeline.run_std(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
calibprob_batch, _ = self.infer_pipeline.compute_logprob_and_length(
calib_prompts, batch[ds_wrapper.dataset_info.answer]
)
predictions.extend(results)
references.extend(list(batch[ds_wrapper.dataset_info.answer]))
generation_probs.extend(logprobs)
calib_probs.extend(calibprob_batch)

idx += 1
if idx % 100 == 0:
print(f"Saving results of {idx} batches")
generations = {
"predictions": predictions,
"references": references,
"generation_probs": generation_probs,
"calibration_probs": calib_probs,
"fewshot": selected_sample,
}

saving_fn(generations)
mean_result = self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
print(f"Results of {idx} batches: ", mean_result)

generations = {
"predictions": predictions,
"references": references,
"generation_probs": generation_probs,
"calibration_probs": calib_probs,
"fewshot": selected_sample,
}

mean_result = self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
std_result = self.metric_pipeline.run_std(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)

final_result = {"mean": mean_result, "std": std_result}
saving_fn(generations, final_result)

0 comments on commit ada86ff

Please sign in to comment.