Commit f9297df

fix: linter issues
DriesSmit committed May 7, 2024
1 parent cf09330 commit f9297df
Showing 7 changed files with 37 additions and 28 deletions.
32 changes: 18 additions & 14 deletions debatellm/agents.py
@@ -14,18 +14,16 @@

 import abc
 import datetime
+import logging
 import time
 from importlib import import_module
 from typing import Any, Callable, Dict, Optional, Tuple
-import logging

-# Set the logging level for 'httpx' to 'WARNING' to suppress info and debug messages
-logging.getLogger('httpx').setLevel(logging.WARNING)
 import google
 import numpy as np
 import openai
-from openai import OpenAI
 import vertexai
+from openai import OpenAI
 from vertexai.preview.language_models import (
     ChatModel,
     InputOutputTextPair,
@@ -41,6 +39,10 @@
 from debatellm.utils.gcloud import load_gcloud_credentials
 from debatellm.utils.openai import load_openai_api_key

+# Set the logging level for 'httpx' to 'WARNING' to suppress info and debug messages
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+
 # Try except decorator
 def try_except_decorator(func: Callable) -> Callable:
     def func_wrapper(*args: Any, **kwargs: Any) -> Callable:
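
The body of try_except_decorator is truncated by the diff view. For orientation only, here is a minimal hypothetical sketch of the pattern its name and signature suggest (retrying the wrapped call so a transient API error does not kill an eval run); this is an assumption, not the repository's actual implementation:

import time
from typing import Any, Callable


def try_except_decorator(func: Callable) -> Callable:
    # Hypothetical body: retry on any exception with a short fixed backoff.
    def func_wrapper(*args: Any, **kwargs: Any) -> Any:
        while True:
            try:
                return func(*args, **kwargs)
            except Exception as error:  # deliberately broad: flaky LLM APIs
                print(f"Caught {error!r}; retrying in 5 seconds...")
                time.sleep(5)

    return func_wrapper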
@@ -240,7 +242,9 @@ def __init__(

         if engine == "mixtral-8x7b-instruct":
             openai.api_key = api_key = load_openai_api_key(path="pplx_api_key.txt")
-            self._client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
+            self._client = OpenAI(
+                api_key=api_key, base_url="https://api.perplexity.ai"
+            )
         else:
             openai.api_key = api_key = load_openai_api_key()
             self._client = OpenAI(api_key=api_key)
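
The reformatted constructor above relies on the OpenAI v1 client accepting a custom base_url, which is how the same SDK can talk to Perplexity's OpenAI-compatible API. A minimal usage sketch (the API key is a placeholder, not a real credential):

from openai import OpenAI

client = OpenAI(api_key="pplx-...", base_url="https://api.perplexity.ai")
response = client.chat.completions.create(
    model="mixtral-8x7b-instruct",
    messages=[{"role": "user", "content": "Name one legal first move in chess."}],
)
print(response.choices[0].message.content)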
@@ -327,29 +331,29 @@ def _infer(

             response = self._client.chat.completions.create(
                 model=self._engine,
-                messages=remove_spaces_in_name(messages),
+                messages=remove_spaces_in_name(messages),  # type: ignore
                 **self._sampling,
             )

             prompt_cost = (
-                np.ceil(response.usage.prompt_tokens / 1000)
+                np.ceil(response.usage.prompt_tokens / 1000)  # type: ignore
                 * self._cost_per_prompt_token
             )
             response_cost = (
-                np.ceil(response.usage.completion_tokens / 1000)
+                np.ceil(response.usage.completion_tokens / 1000)  # type: ignore
                 * self._cost_per_response_token
             )
             usage_info = {
-                "prompt_tokens": int(response.usage.prompt_tokens),
-                "response_tokens": int(response.usage.completion_tokens),
+                "prompt_tokens": int(response.usage.prompt_tokens),  # type: ignore
+                "response_tokens": int(response.usage.completion_tokens),  # type: ignore
                 "cost": prompt_cost + response_cost,
                 "num_messages_removed": history_counter,
             }
             response = response.choices[0].message.content  # type: ignore
         else:
-            response = "This is a mock output."
-            usage_info = {"prompt_tokens": 0, "response_tokens": 0, "cost": 0}
-        return response, usage_info
+            response = "This is a mock output."  # type: ignore
+            usage_info = {"prompt_tokens": 0, "response_tokens": 0, "cost": 0}  # type: ignore
+        return str(response), usage_info


 class PaLM(BaseAgent):
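
The billing arithmetic in this hunk rounds token counts up to the nearest 1,000 before applying the per-1K-token rate. A worked example using the mixtral rate from experiments/conf/system/single_agent.yaml (the token count itself is made up):

import numpy as np

cost_per_prompt_token = 0.0006  # dollars per 1K tokens
prompt_tokens = 2345
# np.ceil(2345 / 1000) = 3.0, i.e. three 1K-token billing blocks
prompt_cost = np.ceil(prompt_tokens / 1000) * cost_per_prompt_token
print(prompt_cost)  # 3 * 0.0006 = 0.0018 dollars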
@@ -486,4 +490,4 @@ def _infer(
         else:
             output = "This is a mock output."
             usage_info = {"prompt_tokens": 0, "response_tokens": 0, "cost": 0}
-        return output, usage_info
\ No newline at end of file
+        return output, usage_info
19 changes: 13 additions & 6 deletions debatellm/eval/load_datasets.py
@@ -561,7 +561,10 @@ def chess_questions(

     questions = []
     for i, q in enumerate(json_obj["examples"]):
-        q_text = "For the following (in-progress) chess games, please select the option that shows all the legal destination squares that completes the notation: " + q["input"]
+        q_text = (
+            "For the following (in-progress) chess games, please select the option that shows all the legal destination squares that completes the notation: "
+            + q["input"]
+        )
         correct_answer = " ".join(q["target"])

         # Generate distractors from other example targets
@@ -571,7 +574,9 @@
         # Loop until we find three unique distractors that are not the correct answer
         while not unique_distractors_found:
             random.shuffle(distractors)  # Shuffle distractors to randomize selection
-            selected_distractors = distractors[:3]  # Select first three shuffled as distractors
+            selected_distractors = distractors[
+                :3
+            ]  # Select first three shuffled as distractors
             unique_distractors_found = all(
                 [distractor != correct_answer for distractor in selected_distractors]
             )
@@ -581,15 +586,17 @@

         # Identify the correct solution after shuffling
         solution_index = options.index(correct_answer)
-        solution = chr(65 + solution_index)  # Map index to A, B, C, D for multiple choice
-
+        solution = chr(
+            65 + solution_index
+        )  # Map index to A, B, C, D for multiple choice
+
         question = {
             "no": i,
             "question": q_text,
             "options": {chr(65 + j): opt for j, opt in enumerate(options)},
             "solution": solution,
             "subcategory": "Chess Moves",
-            "category": "Game Analysis"
+            "category": "Game Analysis",
         }

         questions.append(question)
@@ -604,4 +611,4 @@ def format_question(d: dict) -> str:
             question += f"\n{k}: {v}"
         return question

-    return questions, format_question
\ No newline at end of file
+    return questions, format_question
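
Taken together, the chess_questions hunks implement a simple distractor scheme: borrow other examples' target strings, reshuffle until the three selected distractors avoid the correct answer, then shuffle all four options and record the correct letter. A self-contained sketch of that logic (the move strings are invented for illustration):

import random

correct_answer = "d5 e5"
distractors = ["a3 a4", "h3 h4", "d5 e5", "b3 c3", "f4 g4"]  # other examples' targets

selected = distractors[:3]
while any(d == correct_answer for d in selected):
    random.shuffle(distractors)  # reshuffle until no selected distractor is the answer
    selected = distractors[:3]

options = selected + [correct_answer]
random.shuffle(options)
solution = chr(65 + options.index(correct_answer))  # 0 -> "A", 1 -> "B", ...
print({chr(65 + j): opt for j, opt in enumerate(options)}, "answer:", solution)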
4 changes: 2 additions & 2 deletions debatellm/utils/s3_io.py
@@ -225,9 +225,9 @@ def read_json_lines(self, path: str) -> List[Dict]:
             data.append(json.loads(line))
         return data

-    def read_json_file(self, path: str):
+    def read_json_file(self, path: str) -> Dict:
         full_path = os.path.join(self.bucket_path, path)
-        with open(full_path, 'r') as file:
+        with open(full_path, "r") as file:
             data = json.load(file)
         return data

2 changes: 1 addition & 1 deletion experiments/conf/system/single_agent.yaml
@@ -24,4 +24,4 @@ agents: # options: [gpt, palm]
 - engine: "mixtral-8x7b-instruct" # mixtral 8x7b instruct engine
 - prompt: "${system.agent_prompts.simple}"
 - cost_per_prompt_token: 0.0006 # 0.6 # dollar costs per million prompt token
-- cost_per_response_token: 0.0006 # 0.6 # dollar costs per million response token
\ No newline at end of file
+- cost_per_response_token: 0.0006 # 0.6 # dollar costs per million response token
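
A note on units here: each value is in dollars per 1K tokens, which matches the np.ceil(tokens / 1000) billing in debatellm/agents.py, and the 0.6 in each comment restates the same rate per million tokens. A quick unit check:

rate_per_1k = 0.0006  # config value: dollars per 1K tokens
print(rate_per_1k * 1000)  # 0.6 dollars per million tokens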
2 changes: 1 addition & 1 deletion scripts/eval_datasets/download_chess.sh
@@ -12,7 +12,7 @@ if [ -e "${FOLDER}task.json" ]; then
   echo "CHESS dataset already exists in $FOLDER"
 else
   echo "Downloading CHESS dataset to $FOLDER"
-
+
   # Ensure the target directory exists
   if [ ! -d "${FOLDER}" ]; then
     echo "Creating folder ${FOLDER}"
2 changes: 1 addition & 1 deletion scripts/experiments_utils.py
@@ -141,7 +141,7 @@ def run_experiments(

     if verbose:
         print(f"Launching {len(experiments)} experiments...")
-
+
     with ThreadPoolExecutor(max_workers=parallel_workers) as executor:
         list(tqdm(executor.map(run_experiment, experiments), total=len(experiments)))
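
The list(tqdm(executor.map(...))) idiom in this hunk does two jobs: executor.map is lazy, so wrapping it in list(...) drives every experiment to completion, while tqdm (given an explicit total) renders a progress bar as results arrive. A generic, self-contained sketch of the same pattern; the experiment stub and configs are illustrative, not the repository's:

import time
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm


def run_experiment(config: dict) -> None:
    time.sleep(0.1)  # stand-in for launching one experiment


experiments = [{"seed": seed} for seed in range(8)]
with ThreadPoolExecutor(max_workers=4) as executor:
    list(tqdm(executor.map(run_experiment, experiments), total=len(experiments)))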

4 changes: 1 addition & 3 deletions scripts/launch_experiments.py
@@ -20,9 +20,7 @@
 exp_table.append(
     {
         "system": "single_agent",
-        "system.agents": gen_agent_config(
-            1, use_gpt=True, prompt=["simple", "cot"]
-        ),
+        "system.agents": gen_agent_config(1, use_gpt=True, prompt=["simple", "cot"]),
     }
 )

