From cb97c86d37a37dd1673e931e8d78a78e5eab15f9 Mon Sep 17 00:00:00 2001 From: minhtrung23 Date: Fri, 20 Sep 2024 20:07:39 +0700 Subject: [PATCH] Update pipelines.py src\melt\tools\pipelines\pipelines.py:88:31: E1120: No value for argument 'config' in constructor call (no-value-for-parameter) src\melt\tools\pipelines\pipelines.py:112:4: R0913: Too many arguments (8/5) (too-many-arguments) --- src/melt/tools/pipelines/pipelines.py | 107 +++++++++----------------- 1 file changed, 37 insertions(+), 70 deletions(-) diff --git a/src/melt/tools/pipelines/pipelines.py b/src/melt/tools/pipelines/pipelines.py index c87365a..87213a5 100644 --- a/src/melt/tools/pipelines/pipelines.py +++ b/src/melt/tools/pipelines/pipelines.py @@ -31,118 +31,84 @@ def __init__(self, task, config): # Load generation configuration with open( os.path.join( - config.config_dir, config.lang, "generation_config.json" - ), - "r", + config.config_dir, config.lang, "generation_config.json"), "r", encoding="utf-8" ) as f: - GenerationConfig = json.load(f) + generation_config = json.load(f) with open( - os.path.join(config.config_dir, "llm_template.json"), "r" + os.path.join(config.config_dir, "llm_template.json"), "r", encoding="utf-8" ) as f: - LLM_TEMPLATE = json.load(f) + llm_template = json.load(f) with open( os.path.join( - config.config_dir, config.lang, "metric_configuration.json" - ), - "r", + config.config_dir, config.lang, "metric_configuration.json"), "r", encoding="utf-8" ) as f: - METRIC_CONFIG = json.load(f) + metric_config = json.load(f) + # Load task self.task_name = task # Load pipelines - # print(config.tgi) if config.wtype == "tgi": self.infer_pipeline = TGIWrapper( - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "hf": self.infer_pipeline = HFWrapper( config=config, - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "vllm": self.infer_pipeline = VLLMWrapper( config=config, - generation_config=GenerationConfig[self.task_name], - template=LLM_TEMPLATE[config.ptemplate], + generation_config=generation_config[self.task_name], + template=llm_template[config.ptemplate], ) elif config.wtype == "openai": self.infer_pipeline = OpenAIWrapper( engine=config.model_name, - generation_config=GenerationConfig[self.task_name], + generation_config=generation_config[self.task_name], ) elif config.wtype == "gemini": self.infer_pipeline = GeminiWrapper( model_name=config.model_name, - generation_config=GenerationConfig[self.task_name], + generation_config=generation_config[self.task_name], ) else: raise ValueError("Invalid wrapper type") self.config = config self.config.task = self.task_name - self.config.metric_config = METRIC_CONFIG + self.config.metric_config = metric_config self.few_shot = False self.continue_infer_data = None - # Metric pipeline configuration self.metric_pipeline = MetricPipeline() self.config.filepath = None + self.generation_results_file = None # Initialize in __init__ + def __call__(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): - task = self.task_name + task_mapping = { + "question-answering": __question_answering, + "summarization": __summarization, + "translation": __translation, + "language-modeling": __language_modeling, + "text-classification": __multiple_choice_text_classification, + "sentiment-analysis": __multiple_choice_sentiment, + "toxicity-detection": __multiple_choice_toxicity, + "knowledge-mtpchoice": __multiple_choice, + "knowledge-openended": __question_answering_without_context, + "information-retrieval": __information_retrieval, + "reasoning": __reasoning, + "math": __math, + } - if task == "question-answering": - return __question_answering( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "summarization": - return __summarization( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "translation" in task: - return __translation( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "language-modeling" in task: - return __language_modeling( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif "text-classification" in task: - return __multiple_choice_text_classification( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "sentiment-analysis": - return __multiple_choice_sentiment( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "toxicity-detection": - return __multiple_choice_toxicity( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "knowledge-mtpchoice": - return __multiple_choice( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "knowledge-openended": - return __question_answering_without_context( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "information-retrieval": - return __information_retrieval( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "reasoning": - return __reasoning( - ds_wrapper, ds_loader, saving_fn, start_idx - ) - elif task == "math": - return __math(ds_wrapper, ds_loader, saving_fn, start_idx) - else: - raise NotImplementedError + if self.task_name in task_mapping: + return task_mapping[self.task_name](ds_wrapper, ds_loader, saving_fn, start_idx) + + raise NotImplementedError # Removed unnecessary "else" def run( self, ds_wrapper, @@ -153,6 +119,7 @@ def run( few_shot=False, continue_infer=None, ): + "run" self.generation_results_file = generation_results_file self.config.filepath = generation_results_file self.continue_infer_data = continue_infer