
feat: add mixtral agent
DriesSmit committed Apr 20, 2024
1 parent 62171c1 commit 893e513
Showing 15 changed files with 53 additions and 28 deletions.
43 changes: 27 additions & 16 deletions debatellm/agents.py
@@ -21,6 +21,7 @@
 import google
 import numpy as np
 import openai
+from openai import OpenAI
 import vertexai
 from vertexai.preview.language_models import (
     ChatModel,
@@ -37,7 +38,6 @@
 from debatellm.utils.gcloud import load_gcloud_credentials
 from debatellm.utils.openai import load_openai_api_key
-

 # Try except decorator
 def try_except_decorator(func: Callable) -> Callable:
     def func_wrapper(*args: Any, **kwargs: Any) -> Callable:
@@ -48,18 +48,20 @@ def func_wrapper(*args: Any, **kwargs: Any) -> Callable:
            except (
                google.api_core.exceptions.InternalServerError,
                google.api_core.exceptions.ResourceExhausted,
-                openai.error.RateLimitError,
-                openai.error.ServiceUnavailableError,
-                openai.error.APIError,
-                openai.error.Timeout,
-                openai.error.APIConnectionError,
-                openai.error.InvalidRequestError,
+                openai.RateLimitError,
+                # openai.ServiceUnavailableError,
+                openai.APIError,
+                openai.APIResponseValidationError,
+                # openai.Timeout,
+                # openai.APIConnectionError,
+                # openai.OpenAIError,
+                # openai.InvalidRequestError,
            ) as e:
                print("API error occurred:", str(e), ". Retrying in 1 second...")

                # If the error is not an invalid request error, then wait for 1 second;
                # otherwise, increment the history counter to discard the oldest message.
-                if type(e) == openai.error.InvalidRequestError:
+                if type(e) == openai.APIResponseValidationError:
                    history_counter += 1
                else:
                    time.sleep(1)
@@ -211,6 +213,7 @@ def __init__(
             "gpt-4": 0.03,
             "gpt-4-1106-preview": 0.01,
             "gpt-4-32k": 0.06,
+            "mixtral-8x7b-instruct": 0.0006,
         }

         cost_per_response_token_dict = {
@@ -219,6 +222,7 @@ def __init__(
             "gpt-4": 0.06,
             "gpt-4-1106-preview": 0.03,
             "gpt-4-32k": 0.12,
+            "mixtral-8x7b-instruct": 0.0006,
         }

         assert (
@@ -230,7 +234,13 @@ def __init__(

         self._mock = mock
         if not self._mock:
-            openai.api_key = load_openai_api_key()
+
+            if engine == "mixtral-8x7b-instruct":
+                openai.api_key = api_key = load_openai_api_key(path="pplx_api_key.txt")
+                self._client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
+            else:
+                openai.api_key = api_key = load_openai_api_key()
+                self._client = OpenAI(api_key=api_key)
         self._engine = engine

         if sampling is None:
@@ -309,29 +319,30 @@ def _infer(
                "gpt-3.5-turbo-0613",
                "gpt-4-1106-preview",
                "gpt-4-32k",
+                "mixtral-8x7b-instruct",
            ]

-            response = openai.ChatCompletion.create(
+            response = self._client.chat.completions.create(
                model=self._engine,
                messages=remove_spaces_in_name(messages),
                **self._sampling,
            )

            prompt_cost = (
-                np.ceil(response["usage"]["prompt_tokens"] / 1000)
+                np.ceil(response.usage.prompt_tokens / 1000)
                * self._cost_per_prompt_token
            )
            response_cost = (
-                np.ceil(response["usage"]["completion_tokens"] / 1000)
+                np.ceil(response.usage.completion_tokens / 1000)
                * self._cost_per_response_token
            )
            usage_info = {
-                "prompt_tokens": int(response["usage"]["prompt_tokens"]),
-                "response_tokens": int(response["usage"]["completion_tokens"]),
+                "prompt_tokens": int(response.usage.prompt_tokens),
+                "response_tokens": int(response.usage.completion_tokens),
                "cost": prompt_cost + response_cost,
                "num_messages_removed": history_counter,
            }
-            response = response["choices"][0]["message"]["content"]  # type: ignore
+            response = response.choices[0].message.content  # type: ignore
        else:
            response = "This is a mock output."
            usage_info = {"prompt_tokens": 0, "response_tokens": 0, "cost": 0}
@@ -472,4 +483,4 @@ def _infer(
        else:
            output = "This is a mock output."
            usage_info = {"promt_tokens": 0, "response_tokens": 0, "cost": 0}
-        return output, usage_info
+        return output, usage_info
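
Note: the agents.py changes above track the openai 0.x to 1.x migration: module-level openai.ChatCompletion.create calls become methods on an OpenAI client instance, responses change from dicts to typed objects, and the exception classes move from openai.error.* to the top-level openai namespace. The same client class reaches Perplexity's OpenAI-compatible endpoint for Mixtral by overriding base_url. A minimal standalone sketch of the pattern, where the environment-variable key loading stands in for the repo's load_openai_api_key helper (an assumption, as is the prompt text):

import os

from openai import OpenAI, APIError, RateLimitError

def make_client(engine: str) -> OpenAI:
    # Perplexity serves Mixtral behind an OpenAI-compatible API, so the
    # stock client works once base_url points at api.perplexity.ai.
    if engine == "mixtral-8x7b-instruct":
        return OpenAI(
            api_key=os.environ["PPLX_API_KEY"],  # hypothetical env var
            base_url="https://api.perplexity.ai",
        )
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

engine = "mixtral-8x7b-instruct"
client = make_client(engine)
try:
    response = client.chat.completions.create(
        model=engine,
        messages=[{"role": "user", "content": "Answer in one word: 2 + 2?"}],
    )
    # openai>=1.0 returns typed objects: attribute access replaces the old
    # response["usage"]["prompt_tokens"] dict style used before this commit.
    print(response.choices[0].message.content)
    print(response.usage.prompt_tokens, response.usage.completion_tokens)
except (RateLimitError, APIError) as err:
    # Exception classes now live at the top level (openai.RateLimitError),
    # not under the removed openai.error module.
    print("API error occurred:", err)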
2 changes: 1 addition & 1 deletion experiments/conf/config.yaml
@@ -4,7 +4,7 @@ defaults:
   - dataset: medqa # [options: usmle, medmcqa, mmlu, pubmedqa, medqa, ciar, cosmosqa, gpqa]
   - _self_

-max_eval_count: None
+max_eval_count: 100
 num_eval_workers: 1 # Each worker receives a full batch of questions.
 eval_batch_size: 10 # Defaults to batch_size=1.
 verbose: False
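
For reference, max_eval_count now caps a run at 100 questions instead of evaluating the full dataset. How the cap is consumed is not part of this diff; the sketch below shows the usual pattern and is an assumption about evaluate.py (the select_eval_questions helper is hypothetical):

from typing import Optional, Sequence

def select_eval_questions(
    questions: Sequence[dict], max_eval_count: Optional[int]
) -> Sequence[dict]:
    # A null cap means "evaluate everything"; an integer cap truncates
    # the dataset, e.g. to the first 100 questions.
    if max_eval_count is None:
        return questions
    return questions[:max_eval_count]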
1 change: 1 addition & 0 deletions experiments/conf/system/chateval.yaml
@@ -1,5 +1,6 @@
 defaults:
   - gpt
+  - palm
   - debate_prompts: chateval_ma_debate

 _target_: debatellm.systems.ChatEvalDebate
1 change: 1 addition & 0 deletions experiments/conf/system/google_mad.yaml
@@ -1,5 +1,6 @@
 defaults:
   - gpt
+  - palm
   - debate_prompts: google_ma_debate

 _target_: debatellm.systems.MultiAgentDebateGoogle
6 changes: 4 additions & 2 deletions experiments/conf/system/gpt.yaml
@@ -24,5 +24,7 @@ gpt:
   max_tokens: 1000
   temperature: 0.5 # Taken from here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
   top_p: 0.5
-  cost_per_prompt_token: 0.001 # 0.03 # dollar costs per 1000 prompt token
-  cost_per_response_token: 0.002 # 0.06 # dollar costs per 1000 response token
+  # cost_per_prompt_token: 0.001 # 0.03 # dollar costs per 1000 prompt token
+  # cost_per_response_token: 0.002 # 0.06 # dollar costs per 1000 response token
+  cost_per_prompt_token: 0.0006 # 0.6 # dollar costs per million prompt token
+  cost_per_response_token: 0.0006 # 0.6 # dollar costs per million response token
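
Note on units: these rates are per 1,000 tokens, matching the np.ceil(tokens / 1000) arithmetic in agents.py, so 0.0006 per 1K is the same rate as the $0.60 per million tokens cited in the comments. A quick sketch of the resulting cost computation (token counts are illustrative):

import numpy as np

# Per-1000-token rates from gpt.yaml: $0.0006/1K == $0.60/1M tokens.
cost_per_prompt_token = 0.0006
cost_per_response_token = 0.0006

prompt_tokens, completion_tokens = 1_200, 300
prompt_cost = np.ceil(prompt_tokens / 1000) * cost_per_prompt_token  # 2 * 0.0006
response_cost = np.ceil(completion_tokens / 1000) * cost_per_response_token  # 1 * 0.0006
print(prompt_cost + response_cost)  # 0.0018; np.ceil rounds partial thousands up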
1 change: 1 addition & 0 deletions experiments/conf/system/medprompt.yaml
@@ -1,5 +1,6 @@
 defaults:
   - gpt
+  - palm
   - debate_prompts: medprompt
 _target_: debatellm.systems.Medprompt
 num_reasoning_steps: 5
1 change: 1 addition & 0 deletions experiments/conf/system/multi_agent_debate.yaml
@@ -1,5 +1,6 @@
 defaults:
   - gpt
+  - palm
   - debate_prompts: ma_debate

 _target_: debatellm.systems.MultiAgentDebate
11 changes: 9 additions & 2 deletions experiments/conf/system/single_agent.yaml
@@ -15,6 +15,13 @@ agents: # options: [gpt, palm]
   #   - prompt: "${system.agent_prompts.simple}"

   # PaLM agent
-  - - "${system.palm}" # palm uses default setup
+  # - - "${system.palm}" # palm uses default setup
+  #   - prompt: "${system.agent_prompts.simple}"
+  #   - engine: "text-bison@001"
+
+  # Mixtral agent
+  - - "${system.gpt}"
+    - engine: "mixtral-8x7b-instruct" # mixtral 8x7b instruct engine
     - prompt: "${system.agent_prompts.simple}"
-    - engine: "text-bison@001"
+    - cost_per_prompt_token: 0.0006 # 0.6 # dollar costs per million prompt token
+    - cost_per_response_token: 0.0006 # 0.6 # dollar costs per million response token
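
Each agent entry above is a base config followed by per-agent overrides: the new Mixtral agent starts from ${system.gpt} and overrides the engine, prompt, and per-token costs. A sketch of that merge semantics with OmegaConf, which the repo's Hydra configs build on (the base values here are illustrative, not the repo's resolved config):

from omegaconf import OmegaConf

base_gpt = OmegaConf.create(
    {"engine": "gpt-3.5-turbo", "cost_per_prompt_token": 0.0006}
)
overrides = OmegaConf.create(
    {"engine": "mixtral-8x7b-instruct", "cost_per_response_token": 0.0006}
)
# Later configs win: the override replaces the engine and adds the
# response-token cost while keeping the untouched base keys.
agent_cfg = OmegaConf.merge(base_gpt, overrides)
print(agent_cfg.engine)  # mixtral-8x7b-instruct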
1 change: 1 addition & 0 deletions experiments/conf/system/tsinghua_mad.yaml
@@ -1,5 +1,6 @@
 defaults:
   - gpt
+  - palm
   - debate_prompts: tsinghua_ma_debate

 _target_: debatellm.systems.MultiAgentDebateTsinghua
2 changes: 1 addition & 1 deletion experiments/evaluate.py
@@ -63,7 +63,7 @@ def evaluate(cfg: omegaconf.DictConfig):
     eval_dataset = cfg.dataset.eval_dataset
     if os.getenv("TM_NEPTUNE_API_TOKEN"):
         logger = neptune.init_run(
-            project="InstaDeep/debatellm",
+            project="InstaDeep/truemed",
             api_token=os.getenv("TM_NEPTUNE_API_TOKEN"),
             tags=[f"{eval_dataset}", f"{system_name}{agents}"],
         )
2 changes: 1 addition & 1 deletion manifest.yaml
@@ -14,7 +14,7 @@ builder:
 spec:
   operator: tf
   image: debatellm
-  command: python scripts/launch_experiments.py
+  command: python experiments/evaluate.py
   tensorboard:
     enabled: false

2 changes: 1 addition & 1 deletion requirements.txt
@@ -161,7 +161,7 @@ nvidia-nccl-cu11==2.14.3
 nvidia-nvtx-cu11==11.7.91
 oauthlib==3.2.2
 omegaconf==2.3.0
-openai==0.27.8
+openai==1.21.2
 opt-einsum==3.3.0
 orjson==3.9.1
 overrides==7.3.1
4 changes: 2 additions & 2 deletions scripts/full_results_table.ipynb
@@ -10,7 +10,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "https://app.neptune.ai/InstaDeep/debatellm/\n"
+     "https://app.neptune.ai/InstaDeep/truemed/\n"
     ]
    }
   ],
@@ -48,7 +48,7 @@
    "# Initialize Neptune\n",
    "API_TOKEN = os.environ[\"TM_NEPTUNE_API_TOKEN\"]\n",
    "project = neptune.init_project(\n",
-    "    project=\"InstaDeep/debatellm\",\n",
+    "    project=\"InstaDeep/truemed\",\n",
    "    mode=\"read-only\",\n",
    ")\n",
    "\n",
2 changes: 1 addition & 1 deletion scripts/visualise_results.py
@@ -51,7 +51,7 @@
 # Initialize Neptune
 API_TOKEN = os.environ["TM_NEPTUNE_API_TOKEN"]
 project = neptune.init_project(
-    project="InstaDeep/debatellm",
+    project="InstaDeep/truemed",
     mode="read-only",
 )

2 changes: 1 addition & 1 deletion scripts/visualise_utils.py
@@ -574,7 +574,7 @@ def get_dataset_runs(run_range: List[int], dataset: str = None, engine: str = None

     api_token = os.environ["TM_NEPTUNE_API_TOKEN"]
     project = neptune.init_project(
-        project="InstaDeep/debatellm",
+        project="InstaDeep/truemed",
         api_token=api_token,
         mode="read-only",
     )
