Skip to content

Commit

Permalink
Update pipelines.py
Browse files (browse the repository at this point in the history)
src\melt\tools\pipelines\pipelines.py:88:31: E1120: No value for argument 'config' in constructor call (no-value-for-parameter)
src\melt\tools\pipelines\pipelines.py:112:4: R0913: Too many arguments (8/5) (too-many-arguments)
  • Loading branch information
minhtrung23 authored Sep 20, 2024
1 parent ffe2e69 commit cb97c86
Showing 1 changed file with 37 additions and 70 deletions.
107 changes: 37 additions & 70 deletions src/melt/tools/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,118 +31,84 @@ def __init__(self, task, config):
# Load generation configuration
with open(
os.path.join(
config.config_dir, config.lang, "generation_config.json"
),
"r",
config.config_dir, config.lang, "generation_config.json"), "r", encoding="utf-8"
) as f:
GenerationConfig = json.load(f)
generation_config = json.load(f)

with open(
os.path.join(config.config_dir, "llm_template.json"), "r"
os.path.join(config.config_dir, "llm_template.json"), "r", encoding="utf-8"
) as f:
LLM_TEMPLATE = json.load(f)
llm_template = json.load(f)

with open(
os.path.join(
config.config_dir, config.lang, "metric_configuration.json"
),
"r",
config.config_dir, config.lang, "metric_configuration.json"), "r", encoding="utf-8"
) as f:
METRIC_CONFIG = json.load(f)
metric_config = json.load(f)

# Load task
self.task_name = task

# Load pipelines
# print(config.tgi)
if config.wtype == "tgi":
self.infer_pipeline = TGIWrapper(
generation_config=GenerationConfig[self.task_name],
template=LLM_TEMPLATE[config.ptemplate],
generation_config=generation_config[self.task_name],
template=llm_template[config.ptemplate],
)
elif config.wtype == "hf":
self.infer_pipeline = HFWrapper(
config=config,
generation_config=GenerationConfig[self.task_name],
template=LLM_TEMPLATE[config.ptemplate],
generation_config=generation_config[self.task_name],
template=llm_template[config.ptemplate],
)
elif config.wtype == "vllm":
self.infer_pipeline = VLLMWrapper(
config=config,
generation_config=GenerationConfig[self.task_name],
template=LLM_TEMPLATE[config.ptemplate],
generation_config=generation_config[self.task_name],
template=llm_template[config.ptemplate],
)
elif config.wtype == "openai":
self.infer_pipeline = OpenAIWrapper(
engine=config.model_name,
generation_config=GenerationConfig[self.task_name],
generation_config=generation_config[self.task_name],
)
elif config.wtype == "gemini":
self.infer_pipeline = GeminiWrapper(
model_name=config.model_name,
generation_config=GenerationConfig[self.task_name],
generation_config=generation_config[self.task_name],
)
else:
raise ValueError("Invalid wrapper type")

self.config = config
self.config.task = self.task_name
self.config.metric_config = METRIC_CONFIG
self.config.metric_config = metric_config
self.few_shot = False
self.continue_infer_data = None
# Metric pipeline configuration
self.metric_pipeline = MetricPipeline()
self.config.filepath = None
self.generation_results_file = None # Initialize in __init__

def __call__(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
task = self.task_name
task_mapping = {
"question-answering": __question_answering,
"summarization": __summarization,
"translation": __translation,
"language-modeling": __language_modeling,
"text-classification": __multiple_choice_text_classification,
"sentiment-analysis": __multiple_choice_sentiment,
"toxicity-detection": __multiple_choice_toxicity,
"knowledge-mtpchoice": __multiple_choice,
"knowledge-openended": __question_answering_without_context,
"information-retrieval": __information_retrieval,
"reasoning": __reasoning,
"math": __math,
}

if task == "question-answering":
return __question_answering(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "summarization":
return __summarization(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif "translation" in task:
return __translation(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif "language-modeling" in task:
return __language_modeling(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif "text-classification" in task:
return __multiple_choice_text_classification(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "sentiment-analysis":
return __multiple_choice_sentiment(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "toxicity-detection":
return __multiple_choice_toxicity(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "knowledge-mtpchoice":
return __multiple_choice(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "knowledge-openended":
return __question_answering_without_context(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "information-retrieval":
return __information_retrieval(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "reasoning":
return __reasoning(
ds_wrapper, ds_loader, saving_fn, start_idx
)
elif task == "math":
return __math(ds_wrapper, ds_loader, saving_fn, start_idx)
else:
raise NotImplementedError
if self.task_name in task_mapping:
return task_mapping[self.task_name](ds_wrapper, ds_loader, saving_fn, start_idx)

raise NotImplementedError # Removed unnecessary "else"
def run(
self,
ds_wrapper,
Expand All @@ -153,6 +119,7 @@ def run(
few_shot=False,
continue_infer=None,
):
"run"
self.generation_results_file = generation_results_file
self.config.filepath = generation_results_file
self.continue_infer_data = continue_infer
Expand Down

0 comments on commit cb97c86

Please sign in to comment.