From 1258584e13b4b6dd3e8bdc3fd823a585493d97c6 Mon Sep 17 00:00:00 2001
From: truonggiang1757
Date: Fri, 7 Feb 2025 09:06:30 +0700
Subject: [PATCH] Updated completion for VTIT_chatbot

---
 evals/completion_fns/custom_completion.py | 68 ++++++++++++-----------
 1 file changed, 35 insertions(+), 33 deletions(-)

diff --git a/evals/completion_fns/custom_completion.py b/evals/completion_fns/custom_completion.py
index b5df585ad1..98756895aa 100644
--- a/evals/completion_fns/custom_completion.py
+++ b/evals/completion_fns/custom_completion.py
@@ -12,6 +12,13 @@
 )
 from evals.record import record_sampling
 
+import ssl
+import certifi
+import json
+
+ssl_context = ssl.create_default_context(cafile=certifi.where())
+
+
 class CustomCompletionResult(CompletionResult):
     def __init__(self, raw_data: Any, prompt: Any):
         self.raw_data = raw_data
@@ -20,20 +27,23 @@ def __init__(self, raw_data: Any, prompt: Any):
     def get_completions(self) -> list[str]:
         completions = []
         if self.raw_data:
-            completions.append(self.raw_data.get("response", ""))
-        return completions
+            response_text = self.raw_data.get("response", "")
+            note_text = self.raw_data.get("note", "")
+            full_response = response_text
+            if note_text:
+                full_response += f"\n{note_text}"
+            completions.append(full_response)
+        return completions
 
 class CustomCompletionFn(CompletionFnSpec):
     def __init__(
         self,
         api_base: str,
-        api_key: Optional[str] = None,
         model: Optional[str] = "gpt-4",
         extra_options: Optional[dict] = {},
         registry: Optional[Any] = None,
     ):
         self.api_base = api_base
-        self.api_key = api_key
         self.model = model
         self.extra_options = extra_options
         self.registry = registry
@@ -42,36 +52,28 @@ def __call__(
         prompt: Union[str, OpenAICreateChatPrompt],
         **kwargs,
     ) -> CustomCompletionResult:
-        if not isinstance(prompt, Prompt):
-            assert (
-                isinstance(prompt, str)
-                or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
-                or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
-                or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
-            ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]"
-
-            prompt = ChatCompletionPrompt(
-                raw_prompt=prompt,
-            )
-
-        openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()
-
-        payload = {
-            "prompt": openai_create_prompt,
-            "model": self.model,
-            **self.extra_options,
-        }
-
-        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
-        response = requests.post(self.api_base, json=payload, headers=headers)
-
-        if response.status_code != 200:
-            logging.warning(f"API request failed with status code {response.status_code}: {response.text}")
-            raise Exception(f"API request failed: {response.text}")
-
-        result = response.json()
-        result = CustomCompletionResult(raw_data=result, prompt=openai_create_prompt)
+        if isinstance(prompt, str):
+            query_text = prompt
+        elif isinstance(prompt, ChatCompletionPrompt):
+            query_text = prompt.raw_prompt
+        elif isinstance(prompt, list):  # Handle list format
+            if all(isinstance(msg, dict) and "content" in msg for msg in prompt):
+                query_text = prompt[-1]["content"]  # Extract last message
+            else:
+                raise ValueError(f"Unsupported prompt format: {type(prompt)}")
+        else:
+            raise ValueError(f"Unsupported prompt format: {type(prompt)}")
+        payload = {"question": query_text}
+        logging.info(f"Sending API request with payload: {json.dumps(payload, indent=4)}")
+        try:
+            response = requests.post(self.api_base, json=payload, proxies={"http": None, "https": None}, timeout=40000, verify=False)
+            response.raise_for_status()  # Raises error for non-200 responses
+            result_data = response.json()
+        except requests.exceptions.RequestException as e:
+            logging.error(f"API request failed: {e}")
+            raise Exception(f"API request failed: {e}")
+        result = CustomCompletionResult(raw_data=result_data, prompt=query_text)
         record_sampling(
             prompt=result.prompt,
             sampled=result.get_completions(),