diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py
index 0f80100bd27..c413c70258b 100644
--- a/intel_extension_for_transformers/neural_chat/models/base_model.py
+++ b/intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -189,7 +189,8 @@ def predict(self, query, config=None):
             if plugin_name == "safety_checker" and response:
                 return "Your query contains sensitive words, please try another query."
             else:
-                pass
+                if response is not None and response is not False:
+                    query = response
         assert query is not None, "Query cannot be None."
 
         # LLM inference
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
index dcc7fb55e9b..8987ca63200 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
@@ -24,7 +24,7 @@ class Agent_QA():
     def __init__(self, persist_dir="./output", process=True, input_path=None,
-                 embedding_model="hkunlp/instructor-large", max_length=512, retrieval_type="dense",
+                 embedding_model="hkunlp/instructor-large", max_length=2048, retrieval_type="dense",
                  document_store=None, top_k=1, search_type="mmr", search_kwargs={"k": 1, "fetch_k": 5}):
         self.model = None
         self.tokenizer = None
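
Below is a minimal, self-contained Python sketch of the behavior the first hunk introduces: when a pre-processing plugin returns a usable value (neither None nor False), that value replaces the query before LLM inference, so for example a transcription or rewrite produced by a plugin becomes the prompt. The plugin classes, the plugins registry dict, and the pre_llm_inference_actions hook name below are illustrative stand-ins, not the project's actual plugin API. The second hunk only raises Agent_QA's default max_length from 512 to 2048.

class QueryRewriterPlugin:
    """Hypothetical stand-in for a pre-processing plugin (e.g. ASR) whose
    output should become the new query."""
    def pre_llm_inference_actions(self, query):
        return query.strip().lower()

class SafetyCheckerPlugin:
    """Hypothetical stand-in safety checker: a truthy return blocks the query."""
    def pre_llm_inference_actions(self, query):
        return "forbidden" in query

plugins = {
    "safety_checker": SafetyCheckerPlugin(),
    "rewriter": QueryRewriterPlugin(),
}

def predict(query, config=None):
    for plugin_name, plugin in plugins.items():
        response = plugin.pre_llm_inference_actions(query)
        if plugin_name == "safety_checker" and response:
            return "Your query contains sensitive words, please try another query."
        # New behavior from the patch: a usable plugin result replaces the query.
        if response is not None and response is not False:
            query = response
    assert query is not None, "Query cannot be None."
    return "LLM answer for: " + query  # placeholder for the real inference call

print(predict("  Hello World  "))  # -> LLM answer for: hello world
print(predict("forbidden topic"))  # -> sensitive-words message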