forked from stair-lab/melt
Commit
src\melt\tools\pipelines\__summarization.py:6:0: R0914: Too many local variables (22/15) (too-many-locals)
1 parent ada86ff · commit fd584a5
Showing 1 changed file with 101 additions and 161 deletions.
src/melt/tools/pipelines/__summarization.py
@@ -1,178 +1,118 @@
""" | ||
This module contains the summarization pipeline for processing and evaluating | ||
text summarization tasks. | ||
It uses few-shot learning for prompt generation and handles the inference process | ||
using the provided model. Results are saved periodically and at the end. | ||
""" | ||
|
||
"__summarization" | ||
import random | ||
from typing import List, Dict, Any, Callable | ||
from dataclasses import dataclass | ||
from utils.utils import format_fewshot | ||
|
||
try: | ||
from tqdm import tqdm | ||
except ImportError: | ||
def tqdm(iterable): | ||
""" | ||
A simple replacement for tqdm if it's not installed. | ||
Args: | ||
iterable: The iterable to wrap. | ||
Returns: | ||
The original iterable. | ||
""" | ||
return iterable | ||
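    # The stub only needs to accept a bare iterable: tqdm is invoked in this
    # module solely as tqdm(ds_loader), with no keyword arguments.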

@dataclass
class SummarizationConfig:
    """Configuration for the summarization pipeline."""
    num_fs: int
    few_shot: bool
    continue_infer_data: Dict[str, List] = None
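# Illustrative sketch (an assumption, not part of this diff): a fresh 5-shot run
# might use SummarizationConfig(num_fs=5, few_shot=True), while a resumed run
# would pass the previously saved lists under the keys "original_documents",
# "predictions", "references" and "generation_probs" as continue_infer_data.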

class SummarizationPipeline:
    """
    A pipeline for summarizing documents and evaluating the performance.
    This class encapsulates the logic for document summarization, including
    few-shot learning, batch processing, and result evaluation.
    """

    def __init__(self, config: SummarizationConfig, metric_pipeline: Any,
                 infer_pipeline: Any, task_name: str):
        self.config = config
        self.metric_pipeline = metric_pipeline
        self.infer_pipeline = infer_pipeline
        self.task_name = task_name
        self.data = self._initialize_data()

    def _summarization(self, ds_wrapper: Any, ds_loader: Any,
                       saving_fn: Callable, start_idx: int = 0) -> None:
        """
        Run the summarization pipeline.
        Args:
            ds_wrapper: A wrapper for the dataset, providing information and prompts.
            ds_loader: DataLoader for loading batches of data.
            saving_fn: Function to save the results.
            start_idx: Index to start processing from.
        """
        selected_sample, original_few_shot = self._prepare_few_shot_data(ds_wrapper)

        for idx, batch in enumerate(tqdm(ds_loader)):
            if idx < start_idx:
                continue

            self._process_batch(batch, ds_wrapper, original_few_shot)

            if (idx + 1) % 100 == 0:
                self._save_intermediate_results(idx + 1, selected_sample, saving_fn, ds_wrapper)

        self._save_final_results(selected_sample, saving_fn, ds_wrapper)

    def get_results(self) -> Dict[str, List]:
        """
        Get the current results of the summarization pipeline.
        Returns:
            A dictionary containing the current results.
        """
        return self.data
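# Illustrative usage sketch (assumptions only, not part of this commit): with a
# metric pipeline, an inference pipeline and a dataset wrapper already built
# elsewhere in melt, the class-based variant above would be driven roughly as:
#     config = SummarizationConfig(num_fs=5, few_shot=True)
#     pipe = SummarizationPipeline(config, metric_pipeline, infer_pipeline, "summarization")
#     pipe._summarization(ds_wrapper, ds_loader, saving_fn)
#     results = pipe.get_results()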
from tqdm import tqdm
from melt.tools.utils.utils import format_fewshot

def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
    original_documents = []
    predictions = []
    original_few_shot = []
    selected_sample = []
    references = []
    generation_probs = []
    if self.continue_infer_data is not None:
        original_documents.extend(
            self.continue_infer_data["original_documents"]
        )
        predictions.extend(self.continue_infer_data["predictions"])
        references.extend(self.continue_infer_data["references"])
        generation_probs.extend(
            self.continue_infer_data["generation_probs"]
        )
    idx = 0
    if self.few_shot:

    def _initialize_data(self) -> Dict[str, List]:
        """Initialize data structures for storing results."""
        data = {
            "original_documents": [],
            "predictions": [],
            "references": [],
            "generation_probs": []
        }
        if self.config.continue_infer_data:
            for key, value in self.config.continue_infer_data.items():
                data[key].extend(value)
        return data
        def preprocessing_a_record(rec):
            return [
                rec[ds_wrapper.dataset_info.source],
                rec[ds_wrapper.dataset_info.target],
            ]

    def _prepare_few_shot_data(self, ds_wrapper: Any) -> tuple:
        """Prepare few-shot samples and format them."""
        if not self.config.few_shot:
            return [], []
        selected_sample_idx = list(
            random.sample(
                range(len(ds_wrapper.dataset_training)), self.config.num_fs
            )
        )
        selected_sample = [
            preprocessing_a_record(ds_wrapper.dataset_training[s])
            for s in selected_sample_idx
        ]

        selected_sample = self._select_few_shot_samples(ds_wrapper)
        original_few_shot = format_fewshot(
            selected_sample,
            query_format=ds_wrapper.prompt["prompt"],
            answer_format=ds_wrapper.prompt["answer_format"],
        )
        return selected_sample, original_few_shot
    for batch in tqdm(ds_loader):
        if idx < start_idx:
            idx += 1
            continue

    def _select_few_shot_samples(self, ds_wrapper: Any) -> List[List[str]]:
        """Select few-shot samples from the training dataset."""
        selected_sample_idx = random.sample(
            range(len(ds_wrapper.dataset_training)), self.config.num_fs
        )
        return [
        prompts = [
            [
                ds_wrapper.dataset_training[s][ds_wrapper.dataset_info.source],
                ds_wrapper.dataset_training[s][ds_wrapper.dataset_info.target]
            ]
            for s in selected_sample_idx
        ]
    def _process_batch(self, batch: Dict[str, Any], ds_wrapper: Any,
                       original_few_shot: List[Dict[str, str]]) -> None:
        """Process a single batch of data."""
        prompts = self._create_prompts(batch, ds_wrapper, original_few_shot)
        results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)

        self.data["original_documents"].extend(batch[ds_wrapper.dataset_info.source])
        self.data["predictions"].extend(results)
        self.data["references"].extend(batch[ds_wrapper.dataset_info.target])
        self.data["generation_probs"].extend(logprobs)
    def _create_prompts(self, batch: Dict[str, Any], ds_wrapper: Any,
                        original_few_shot: List[Dict[str, str]]) -> List[List[Dict[str, str]]]:
        """Create prompts for the current batch."""
        return [
            [
                {"role": "system", "content": ds_wrapper.prompt["system_prompt"]},
                {
                    "role": "system",
                    "content": ds_wrapper.prompt["system_prompt"],
                },
                *original_few_shot,
                {"role": "user", "content": ds_wrapper.prompt["prompt"].format(document)},
                {
                    "role": "user",
                    "content": ds_wrapper.prompt["prompt"].format(
                        document,
                    ),
                },
            ]
            for document in batch[ds_wrapper.dataset_info.source]
        ]
    def _save_intermediate_results(self, idx: int, selected_sample: List[List[str]],
                                   saving_fn: Callable, ds_wrapper: Any) -> None:
        """Save intermediate results and print mean results."""
        print(f"Saving results of {idx} batches")
        generations = {**self.data, "fewshot": selected_sample}
        saving_fn(generations)
        mean_result = self._calculate_mean_result(generations, ds_wrapper)
        print(f"Results of {idx} batches: ", mean_result)
    def _save_final_results(self, selected_sample: List[List[str]],
                            saving_fn: Callable, ds_wrapper: Any) -> None:
        """Save final results including mean and standard deviation."""
        generations = {**self.data, "fewshot": selected_sample}
        mean_result = self._calculate_mean_result(generations, ds_wrapper)
        std_result = self._calculate_std_result(generations, ds_wrapper)
        final_result = {"mean": mean_result, "std": std_result}
        saving_fn(generations, final_result)
    def _calculate_mean_result(self, generations: Dict[str, Any], ds_wrapper: Any) -> Dict[str, Any]:
        """Calculate mean results using the metric pipeline."""
        return self.metric_pipeline.run_mean(
            generations,
            self.task_name,
            ds_wrapper.prompt["answer_key"],
            ds_wrapper.dataset_info.label,
            self.config,
        )
        original_documents.extend(list(batch[ds_wrapper.dataset_info.source]))

    def _calculate_std_result(self, generations: Dict[str, Any], ds_wrapper: Any) -> Dict[str, Any]:
        """Calculate standard deviation of results using the metric pipeline."""
        return self.metric_pipeline.run_std(
            generations,
            self.task_name,
            ds_wrapper.prompt["answer_key"],
            ds_wrapper.dataset_info.label,
            self.config,
        results, logprobs, _ = self.infer_pipeline(
            prompts, return_probs=True
        )
        predictions.extend(results)
        references.extend(list(batch[ds_wrapper.dataset_info.target]))
        generation_probs.extend(logprobs)

        idx += 1
        if idx % 100 == 0:
            print(f"Saving results of {idx} batches")
            generations = {
                "original_documents": original_documents,
                "predictions": predictions,
                "references": references,
                "generation_probs": generation_probs,
                "fewshot": selected_sample,
            }
            saving_fn(generations)
            mean_result = self.metric_pipeline.run_mean(
                generations,
                self.task_name,
                ds_wrapper.prompt["answer_key"],
                ds_wrapper.dataset_info.label,
                self.config,
            )
            print(f"Results of {idx} batches: ", mean_result)

    generations = {
        "original_documents": original_documents,
        "predictions": predictions,
        "references": references,
        "generation_probs": generation_probs,
        "fewshot": selected_sample,
    }
    mean_result = self.metric_pipeline.run_mean(
        generations,
        self.task_name,
        ds_wrapper.prompt["answer_key"],
        ds_wrapper.dataset_info.label,
        self.config,
    )
    std_result = self.metric_pipeline.run_std(
        generations,
        self.task_name,
        ds_wrapper.prompt["answer_key"],
        ds_wrapper.dataset_info.label,
        self.config,
    )
    final_result = {"mean": mean_result, "std": std_result}
    saving_fn(generations, final_result)
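
A hedged sketch of how the rewritten __summarization is expected to be driven; the owning pipeline object, ds_wrapper, ds_loader and saving_fn are assumed to exist elsewhere in melt and are not part of this commit:

    # inside the owning pipeline class (assumed), e.g. from its run/evaluate entry point
    self.__summarization(ds_wrapper, ds_loader, saving_fn, start_idx=0)
    # per batch it builds chat prompts (system prompt + few-shot turns + document),
    # calls self.infer_pipeline(prompts, return_probs=True), accumulates
    # original_documents, predictions, references and generation_probs,
    # checkpoints via saving_fn(generations) every 100 batches, and finally saves
    # generations together with {"mean": ..., "std": ...} computed by
    # metric_pipeline.run_mean and metric_pipeline.run_std.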