From 0da27c0679a15324c785210a1c478e0ff2976396 Mon Sep 17 00:00:00 2001 From: Shing Lyu Date: Wed, 5 Jul 2023 14:29:36 +0000 Subject: [PATCH] Fix: fix FLAN-XXL input/output format --- kendra_retriever_samples/kendra_chat_flan_xxl.py | 5 +++-- kendra_retriever_samples/kendra_retriever_flan_xxl.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/kendra_retriever_samples/kendra_chat_flan_xxl.py b/kendra_retriever_samples/kendra_chat_flan_xxl.py index dd5d237..aff9a9e 100644 --- a/kendra_retriever_samples/kendra_chat_flan_xxl.py +++ b/kendra_retriever_samples/kendra_chat_flan_xxl.py @@ -30,12 +30,13 @@ class ContentHandler(LLMContentHandler): accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs}) + input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) return input_str.encode('utf-8') def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) - return response_json[0]["generated_text"] + print(response_json) + return response_json["generated_texts"][0] content_handler = ContentHandler() diff --git a/kendra_retriever_samples/kendra_retriever_flan_xxl.py b/kendra_retriever_samples/kendra_retriever_flan_xxl.py index e10693c..8d61f8b 100644 --- a/kendra_retriever_samples/kendra_retriever_flan_xxl.py +++ b/kendra_retriever_samples/kendra_retriever_flan_xxl.py @@ -18,12 +18,12 @@ class ContentHandler(LLMContentHandler): accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs}) + input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) return input_str.encode('utf-8') def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) - return response_json[0]["generated_text"] + return response_json["generated_texts"][0] content_handler = ContentHandler()