Updated completion for VTIT_chatbot
truonggiang1757 committed Feb 7, 2025
1 parent 86754a8 commit 1258584
Showing 1 changed file with 35 additions and 33 deletions.
evals/completion_fns/custom_completion.py: 68 changes (35 additions & 33 deletions)
@@ -12,6 +12,13 @@
)
from evals.record import record_sampling

import ssl
import certifi
import json

ssl_context = ssl.create_default_context(cafile=certifi.where())


class CustomCompletionResult(CompletionResult):
    def __init__(self, raw_data: Any, prompt: Any):
        self.raw_data = raw_data
@@ -20,20 +27,23 @@ def __init__(self, raw_data: Any, prompt: Any):
    def get_completions(self) -> list[str]:
        completions = []
        if self.raw_data:
            completions.append(self.raw_data.get("response", ""))
        return completions
            response_text = self.raw_data.get("response", "")
            note_text = self.raw_data.get("note", "")
            full_response = response_text
            if note_text:
                full_response += f"\n{note_text}"

            completions.append(full_response)
        return completions
class CustomCompletionFn(CompletionFnSpec):
    def __init__(
        self,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = "gpt-4",
        extra_options: Optional[dict] = {},
        registry: Optional[Any] = None,
    ):
        self.api_base = api_base
        self.api_key = api_key
        self.model = model
        self.extra_options = extra_options
        self.registry = registry
@@ -42,36 +52,28 @@ def __call__(
        prompt: Union[str, OpenAICreateChatPrompt],
        **kwargs,
    ) -> CustomCompletionResult:
        if not isinstance(prompt, Prompt):
            assert (
                isinstance(prompt, str)
                or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
            ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]"

            prompt = ChatCompletionPrompt(
                raw_prompt=prompt,
            )

        openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()

        payload = {
            "prompt": openai_create_prompt,
            "model": self.model,
            **self.extra_options,
        }

        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        response = requests.post(self.api_base, json=payload, headers=headers)

        if response.status_code != 200:
            logging.warning(f"API request failed with status code {response.status_code}: {response.text}")
            raise Exception(f"API request failed: {response.text}")

        result = response.json()
        result = CustomCompletionResult(raw_data=result, prompt=openai_create_prompt)
        if isinstance(prompt, str):
            query_text = prompt
        elif isinstance(prompt, ChatCompletionPrompt):
            query_text = prompt.raw_prompt
        elif isinstance(prompt, list):  # Handle list format
            if all(isinstance(msg, dict) and "content" in msg for msg in prompt):
                query_text = prompt[-1]["content"]  # Extract last message
            else:
                raise ValueError(f"Unsupported prompt format: {type(prompt)}")
        else:
            raise ValueError(f"Unsupported prompt format: {type(prompt)}")
        payload = {"question": query_text}
        logging.info(f"Sending API request with payload: {json.dumps(payload, indent=4)}")
        try:
            response = requests.post(self.api_base, json=payload, proxies={"http": None, "https": None}, timeout=40000, verify=False)
            response.raise_for_status()  # Raises error for non-200 responses
            result_data = response.json()
        except requests.exceptions.RequestException as e:
            logging.error(f"API request failed: {e}")
            raise Exception(f"API request failed: {e}")

        result = CustomCompletionResult(raw_data=result_data, prompt=query_text)
        record_sampling(
            prompt=result.prompt,
            sampled=result.get_completions(),
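Context for the diff above (not part of the commit): after this change, __call__ POSTs a JSON body of the form {"question": ...} to api_base and expects the endpoint to reply with JSON containing "response" and, optionally, "note", which get_completions() joins with a newline. A minimal local stub satisfying that contract might look like the sketch below; Flask, the /ask route, and port 8000 are illustrative assumptions, not anything this commit defines.

# Illustrative stub only: Flask, the /ask route, and port 8000 are assumptions.
# The JSON contract mirrors the updated completion function above:
#   request -> {"question": "<prompt text>"}
#   reply   -> {"response": "<answer>", "note": "<optional extra text>"}
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.post("/ask")
def ask():
    question = request.get_json(force=True).get("question", "")
    return jsonify({
        "response": f"Echo: {question}",     # placeholder answer
        "note": "served by the local stub",  # appended to the completion by get_completions()
    })

if __name__ == "__main__":
    app.run(port=8000)

With such a stub running, CustomCompletionFn(api_base="http://localhost:8000/ask") would record a single completion of the form "Echo: <question>" followed by the note on the next line.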
