Update __summarization.py
src\melt\tools\pipelines\__summarization.py:6:0: R0914: Too many local variables (22/15) (too-many-locals)
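(Pylint's R0914 fires when a single scope declares more local variables than max-locals, 15 by default; the change below breaks the monolithic function into a class with small helper methods so that no single scope exceeds the limit.)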
minhtrung23 authored Sep 20, 2024
1 parent ada86ff commit fd584a5
Showing 1 changed file with 101 additions and 161 deletions: src/melt/tools/pipelines/__summarization.py
@@ -1,178 +1,118 @@
"""
This module contains the summarization pipeline for processing and evaluating
text summarization tasks.
It uses few-shot learning for prompt generation and handles the inference process
using the provided model. Results are saved periodically and at the end.
"""

"__summarization"
import random
from typing import List, Dict, Any, Callable
from dataclasses import dataclass
from utils.utils import format_fewshot

try:
from tqdm import tqdm
except ImportError:
def tqdm(iterable):
"""
A simple replacement for tqdm if it's not installed.
Args:
iterable: The iterable to wrap.
Returns:
The original iterable.
"""
return iterable

@dataclass
class SummarizationConfig:
"""Configuration for the summarization pipeline."""
num_fs: int
few_shot: bool
continue_infer_data: Dict[str, List] = None

class SummarizationPipeline:
"""
A pipeline for summarizing documents and evaluating the performance.
This class encapsulates the logic for document summarization, including
few-shot learning, batch processing, and result evaluation.
"""

def __init__(self, config: SummarizationConfig, metric_pipeline:
Any, infer_pipeline: Any, task_name: str):
self.config = config
self.metric_pipeline = metric_pipeline
self.infer_pipeline = infer_pipeline
self.task_name = task_name
self.data = self._initialize_data()

def _summarization(self, ds_wrapper: Any, ds_loader:
Any, saving_fn: Callable, start_idx: int = 0) -> None:
"""
Run the summarization pipeline.
Args:
ds_wrapper: A wrapper for the dataset, providing information and prompts.
ds_loader: DataLoader for loading batches of data.
saving_fn: Function to save the results.
start_idx: Index to start processing from.
"""
selected_sample, original_few_shot = self._prepare_few_shot_data(ds_wrapper)

for idx, batch in enumerate(tqdm(ds_loader)):
if idx < start_idx:
continue

self._process_batch(batch, ds_wrapper, original_few_shot)

if (idx + 1) % 100 == 0:
self._save_intermediate_results(idx + 1, selected_sample, saving_fn, ds_wrapper)

self._save_final_results(selected_sample, saving_fn, ds_wrapper)

def get_results(self) -> Dict[str, List]:
"""
Get the current results of the summarization pipeline.
Returns:
A dictionary containing the current results.
"""
return self.data
from tqdm import tqdm
from melt.tools.utils.utils import format_fewshot

def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
original_documents = []
predictions = []
original_few_shot = []
selected_sample = []
references = []
generation_probs = []
if self.continue_infer_data is not None:
original_documents.extend(
self.continue_infer_data["original_documents"]
)
predictions.extend(self.continue_infer_data["predictions"])
references.extend(self.continue_infer_data["references"])
generation_probs.extend(
self.continue_infer_data["generation_probs"]
)
idx = 0
if self.few_shot:

def _initialize_data(self) -> Dict[str, List]:
"""Initialize data structures for storing results."""
data = {
"original_documents": [],
"predictions": [],
"references": [],
"generation_probs": []
}
if self.config.continue_infer_data:
for key, value in self.config.continue_infer_data.items():
data[key].extend(value)
return data
def preprocessing_a_record(rec):
return [
rec[ds_wrapper.dataset_info.source],
rec[ds_wrapper.dataset_info.target],
]

def _prepare_few_shot_data(self, ds_wrapper: Any) -> tuple:
"""Prepare few-shot samples and format them."""
if not self.config.few_shot:
return [], []
selected_sample_idx = list(
random.sample(
range(len(ds_wrapper.dataset_training)), self.config.num_fs
)
)
selected_sample = [
preprocessing_a_record(ds_wrapper.dataset_training[s])
for s in selected_sample_idx
]

selected_sample = self._select_few_shot_samples(ds_wrapper)
original_few_shot = format_fewshot(
selected_sample,
query_format=ds_wrapper.prompt["prompt"],
answer_format=ds_wrapper.prompt["answer_format"],
)
return selected_sample, original_few_shot
for batch in tqdm(ds_loader):
if idx < start_idx:
idx += 1
continue

def _select_few_shot_samples(self, ds_wrapper: Any) -> List[List[str]]:
"""Select few-shot samples from the training dataset."""
selected_sample_idx = random.sample(
range(len(ds_wrapper.dataset_training)), self.config.num_fs
)
return [
prompts = [
[
ds_wrapper.dataset_training[s][ds_wrapper.dataset_info.source],
ds_wrapper.dataset_training[s][ds_wrapper.dataset_info.target]
]
for s in selected_sample_idx
]
def _process_batch(self, batch: Dict[str, Any], ds_wrapper: Any,
original_few_shot: List[Dict[str, str]]) -> None:
"""Process a single batch of data."""
prompts = self._create_prompts(batch, ds_wrapper, original_few_shot)
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)

self.data["original_documents"].extend(batch[ds_wrapper.dataset_info.source])
self.data["predictions"].extend(results)
self.data["references"].extend(batch[ds_wrapper.dataset_info.target])
self.data["generation_probs"].extend(logprobs)
def _create_prompts(self, batch: Dict[str, Any], ds_wrapper: Any,
original_few_shot: List[Dict[str, str]]) -> List[List[Dict[str, str]]]:
"""Create prompts for the current batch."""
return [
[
{"role": "system", "content": ds_wrapper.prompt["system_prompt"]},
{
"role": "system",
"content": ds_wrapper.prompt["system_prompt"],
},
*original_few_shot,
{"role": "user", "content": ds_wrapper.prompt["prompt"].format(document)},
{
"role": "user",
"content": ds_wrapper.prompt["prompt"].format(
document,
),
},
]
for document in batch[ds_wrapper.dataset_info.source]
]
def _save_intermediate_results(self, idx: int, selected_sample: List[List[str]],
saving_fn: Callable, ds_wrapper: Any) -> None:
"""Save intermediate results and print mean results."""
print(f"Saving results of {idx} batches")
generations = {**self.data, "fewshot": selected_sample}
saving_fn(generations)
mean_result = self._calculate_mean_result(generations, ds_wrapper)
print(f"Results of {idx} batches: ", mean_result)
def _save_final_results(self, selected_sample: List[List[str]],
saving_fn: Callable, ds_wrapper: Any) -> None:
"""Save final results including mean and standard deviation."""
generations = {**self.data, "fewshot": selected_sample}
mean_result = self._calculate_mean_result(generations, ds_wrapper)
std_result = self._calculate_std_result(generations, ds_wrapper)
final_result = {"mean": mean_result, "std": std_result}
saving_fn(generations, final_result)
def _calculate_mean_result(self, generations: Dict[str, Any],ds_wrapper: Any) -> Dict[str, Any]:
"""Calculate mean results using the metric pipeline."""
return self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
original_documents.extend(list(batch[ds_wrapper.dataset_info.source]))

def _calculate_std_result(self, generations: Dict[str, Any], ds_wrapper: Any) -> Dict[str, Any]:
"""Calculate standard deviation of results using the metric pipeline."""
return self.metric_pipeline.run_std(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
results, logprobs, _ = self.infer_pipeline(
prompts, return_probs=True
)
predictions.extend(results)
references.extend(list(batch[ds_wrapper.dataset_info.target]))
generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
print(f"Saving results of {idx} batches")
generations = {
"original_documents": original_documents,
"predictions": predictions,
"references": references,
"generation_probs": generation_probs,
"fewshot": selected_sample,
}
saving_fn(generations)
mean_result = self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
print(f"Results of {idx} batches: ", mean_result)

generations = {
"original_documents": original_documents,
"predictions": predictions,
"references": references,
"generation_probs": generation_probs,
"fewshot": selected_sample,
}
mean_result = self.metric_pipeline.run_mean(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
std_result = self.metric_pipeline.run_std(
generations,
self.task_name,
ds_wrapper.prompt["answer_key"],
ds_wrapper.dataset_info.label,
self.config,
)
final_result = {"mean": mean_result, "std": std_result}
saving_fn(generations, final_result)
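For context, a minimal sketch of how the refactored class might be driven. Only SummarizationConfig and SummarizationPipeline come from the commit above; the stub objects, their method bodies, and the task name are hypothetical stand-ins for what the surrounding MELT codebase provides, and the import path is assumed from the file touched by this commit.

# Hypothetical usage sketch; not part of the commit.
from melt.tools.pipelines.__summarization import (
    SummarizationConfig,
    SummarizationPipeline,
)

class StubInferPipeline:
    """Stand-in for MELT's inference pipeline (assumed call signature)."""
    def __call__(self, prompts, return_probs=True):
        # Return one fixed summary and log-prob per prompt: the pipeline
        # expects a (results, logprobs, _) triple.
        return ["stub summary"] * len(prompts), [0.0] * len(prompts), None

class StubMetricPipeline:
    """Stand-in for MELT's metric pipeline (assumed run_mean/run_std API)."""
    def run_mean(self, generations, task_name, answer_key, label, config):
        return {"rouge1": 0.0}  # placeholder metric values
    def run_std(self, generations, task_name, answer_key, label, config):
        return {"rouge1": 0.0}

config = SummarizationConfig(num_fs=3, few_shot=False)
pipeline = SummarizationPipeline(
    config,
    StubMetricPipeline(),
    StubInferPipeline(),
    task_name="summarization",
)
# ds_wrapper and ds_loader would come from MELT's dataset tooling:
# pipeline._summarization(ds_wrapper, ds_loader, saving_fn=print)
print(pipeline.get_results())
# -> {'original_documents': [], 'predictions': [], 'references': [], 'generation_probs': []}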
