collect_responses.py
"""
Script to collect responses from various LLMs for a set of prompts.
"""
import os
import time
import traceback
from typing import Any, Dict, Optional, List
from concurrent.futures import ProcessPoolExecutor
import dirtyjson
import google.generativeai as genai
from openai import OpenAI
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
import anthropic
import typing_extensions
from datetime import datetime
from google.api_core.exceptions import ResourceExhausted
from response_utils import (
    PROMPTS_FILE,
    AnswerEnum,
    is_valid_answer_enum,
    load_prompts,
    save_response,
    load_responses,
)
# DEBUG = True
DEBUG = False
GEMINI_API_KEY = os.getenv("PERSONAL_GOOGLE_AISTUDIO_API_KEY")
assert GEMINI_API_KEY, "PERSONAL_GOOGLE_AISTUDIO_API_KEY is not set"
ANTHROPIC_API_KEY = os.getenv("PERSONAL_ANTHROPIC_API_KEY")
assert ANTHROPIC_API_KEY, "PERSONAL_ANTHROPIC_API_KEY is not set"
OPENAI_API_KEY = os.getenv("PERSONAL_OPENAI_KEY")
assert OPENAI_API_KEY, "PERSONAL_OPENAI_KEY is not set"
genai.configure(api_key=GEMINI_API_KEY)
# Configurable variables
class LLMConfig(typing_extensions.TypedDict):
    name: str
    rate_limit_delay: float  # Delay in seconds between requests
    model_provider: str  # "openai" | "anthropic" | "google"
# Conservative rate limits: Gemini Flash cleans every response, so if its quota is exhausted,
# every provider's pipeline stalls. Throughput matters less than reliability here.
LLMS: List[LLMConfig] = [
    # OpenAI line
    # {"name": "gpt-4o-2024-11-20", "rate_limit_delay": 30.0, "model_provider": "openai"},
    # {"name": "gpt-4o-2024-08-06", "rate_limit_delay": 30.0, "model_provider": "openai"},
    {
        "name": "gpt-4o-mini-2024-07-18",
        "rate_limit_delay": 30.0,
        "model_provider": "openai",
    },
    # {"name": "gpt-4-0613", "rate_limit_delay": 30.0, "model_provider": "openai"},
    # {"name": "gpt-4-0314", "rate_limit_delay": 30.0, "model_provider": "openai"},
    # {
    #     "name": "o1-preview-2024-09-12",
    #     "rate_limit_delay": 30.0,
    #     "model_provider": "openai",
    # },
    # {"name": "o1-mini-2024-09-12", "rate_limit_delay": 30.0, "model_provider": "openai"},
    # TODO: cover all of these models eventually; for now, run a single cheap model as a smoke test.
    # Claude line
    # {
    #     "name": "claude-3-opus-20240229",
    #     "rate_limit_delay": 30.0,
    #     "model_provider": "anthropic",
    # },
    # {
    #     "name": "claude-3-5-sonnet-20241022",
    #     "rate_limit_delay": 30.0,
    #     "model_provider": "anthropic",
    # },
    # {
    #     "name": "claude-3-5-haiku-20241022",
    #     "rate_limit_delay": 30.0,
    #     "model_provider": "anthropic",
    # },
    # Gemini line
    # TODO: gemini-exp-1121 does not appear to be available to us yet.
    # {"name": "gemini-exp-1121", "rate_limit_delay": 30.0, "model_provider": "google"},
    # {
    #     "name": "gemini-1.5-flash-002",
    #     "rate_limit_delay": 30.0,
    #     "model_provider": "google",
    # },
]
NUM_RESPONSES_PER_LLM = 30
# region response cleaning
RESPONSE_CLEANING_SYSTEM_PROMPT = """
You are a helpful assistant that extracts the final choice and reasoning from an LLM's response text.
The response will be about choosing between pills (either red vs blue).
Extract:
1. The final stated choice
2. The explanation/reasoning given
Format the output as JSON with this structure:
{
"answer": "blue" | "red" | "refusal" | "other",
"explanation": "1-2 sentence summary of why the choice was made, based on input text."
}
If no clear choice is stated or the response refuses to choose, use "refusal".
If the choice doesn't fit the above categories, use "other".
"""
class GeminiCleanedResponse(typing_extensions.TypedDict):
    answer: AnswerEnum
    explanation: str
class FullCleanedResponse(typing_extensions.TypedDict):
    gemini_cleaned_response: GeminiCleanedResponse
    raw_response: str  # The raw response text
def is_gemini_cleaned_response(response: Dict[str, Any]) -> bool:
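    """Return True if the parsed response has a valid answer enum and a string explanation."""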
print(f"Checking GeminiCleanedResponse of type {type(response)}: {response}")
conditions_and_error_messages = [
(is_valid_answer_enum(response["answer"]), "answer is not an AnswerEnum"),
(isinstance(response["explanation"], str), "explanation is not a string"),
]
errors = [msg for cond, msg in conditions_and_error_messages if not cond]
if errors:
print(f"Invalid GeminiCleanedResponse: {', '.join(errors)}")
return False
return True
def clean_llm_response(raw_response: str) -> Optional[FullCleanedResponse]:
    """
    Clean a raw response from an LLM using Gemini's cleaning API.
    Returns a FullCleanedResponse or None if an error occurs.
    """
    print(f"Cleaning response of type {type(raw_response)}: {raw_response}")
    if DEBUG:
        return {
            "gemini_cleaned_response": {
                "answer": AnswerEnum.OTHER.value,
                "explanation": "dummy explanation",
            },
            "raw_response": raw_response,
        }
    try:
        model = genai.GenerativeModel(
            "gemini-1.5-flash-8b",
            system_instruction=RESPONSE_CLEANING_SYSTEM_PROMPT,
        )
        response = model.generate_content(
            raw_response,
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            },
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json",
                response_schema=GeminiCleanedResponse,
                temperature=0.0,
            ),
        )
        response_text: str = response.text
        cleaned_response = dirtyjson.loads(response_text)
        if not is_gemini_cleaned_response(cleaned_response):
            raise ValueError("Cleaned response is not a GeminiCleanedResponse")
        return {
            "gemini_cleaned_response": cleaned_response,
            "raw_response": raw_response,
        }
    except ResourceExhausted as e:
        print(f"Rate limit exceeded: {e}")
        time.sleep(60.0)  # Likely HTTP 429; back off for a minute so the quota can refresh.
        return None
    except Exception:
        print(f"Failed to clean response: {traceback.format_exc()}")
        return None
# endregion response cleaning
# region LLM API calls
class LLMResponse(typing_extensions.TypedDict):
    content: str  # The actual response text
    model: str  # The specific model version used
    provider: str  # The provider (openai, anthropic, google)
    metadata: Dict[str, Any]  # Usage stats, model info, etc. (optional)
def call_openai(prompt_text: str, model_version: str = "gpt-4o") -> Optional[LLMResponse]:
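    """Call the OpenAI chat completions API; return a normalized LLMResponse or None on error."""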
    client = OpenAI(api_key=OPENAI_API_KEY)
    try:
        response = client.chat.completions.create(
            model=model_version,
            max_tokens=4096,
            messages=[{"role": "user", "content": prompt_text}],
            temperature=0.0,
        )
        return {
            "content": response.choices[0].message.content,
            "model": model_version,
            "provider": "openai",
            "metadata": {
                "usage": {
                    "total_tokens": response.usage.total_tokens,
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                }
            },
        }
    except Exception:
        print(f"Error calling OpenAI API: {traceback.format_exc()}")
        return None
def call_anthropic(prompt_text: str, model_version: str = "claude-3-5-sonnet-20241022") -> Optional[LLMResponse]:
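    """Call the Anthropic messages API; return a normalized LLMResponse or None on error."""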
    try:
        client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
        response = client.messages.create(
            model=model_version,
            max_tokens=4096,
            temperature=0,
            messages=[{"role": "user", "content": [{"type": "text", "text": prompt_text}]}],
        )
        return {
            "content": response.content[0].text,
            "model": model_version,
            "provider": "anthropic",
            "metadata": {},
        }
    except Exception:
        print(f"Error calling Anthropic: {traceback.format_exc()}")
        return None
def call_gemini(prompt_text: str, gemini_model_version: str = "gemini-exp-1121") -> Optional[LLMResponse]:
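    """Call the Gemini generate_content API; return a normalized LLMResponse or None on error."""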
    try:
        model = genai.GenerativeModel(gemini_model_version)
        response = model.generate_content(
            prompt_text,
            safety_settings={
                # Note: BLOCK_NONE is the least restrictive value accepted here; the OFF enum
                # exists in the library but is rejected by this safety_settings dict.
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            },
            generation_config=genai.types.GenerationConfig(temperature=0.0),
        )
        return {
            "content": response.text,
            "model": gemini_model_version,
            "provider": "google",
            "metadata": {},
        }
    except Exception:
        print(f"Error calling Gemini: {traceback.format_exc()}")
        return None
def placeholder_call_llm_api(llm_config: LLMConfig, prompt_text: str) -> Optional[LLMResponse]:
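    """Return a canned LLMResponse without calling any external API (used when DEBUG is set)."""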
    return {
        "content": "dummy response",
        "model": llm_config["name"],
        "provider": llm_config["model_provider"],
        "metadata": {},
    }
def call_llm_api(llm_config: LLMConfig, prompt_text: str) -> Optional[LLMResponse]:
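    """Dispatch the prompt to the configured provider; return an LLMResponse or None on failure."""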
    try:
        llm_config_name = llm_config["name"]
        llm_config_model_provider = llm_config["model_provider"]
        if DEBUG:
            return placeholder_call_llm_api(llm_config, prompt_text)
        elif llm_config_model_provider == "openai":
            return call_openai(prompt_text, model_version=llm_config_name)
        elif llm_config_model_provider == "anthropic":
            return call_anthropic(prompt_text, model_version=llm_config_name)
        elif llm_config_model_provider == "google":
            return call_gemini(prompt_text, gemini_model_version=llm_config_name)
        else:
            raise ValueError(f"LLM '{llm_config['name']}' is not supported.")
    except Exception:
        print(f"Error calling {llm_config['name']}: {traceback.format_exc()}")
        return None
def collect_responses_by_provider(provider: str, llm_configs: List[LLMConfig], prompts, existing_responses):
"""Collect responses for all models from a single provider, respecting rate limits"""
    provider_configs = [cfg for cfg in llm_configs if cfg["model_provider"] == provider]
    for llm_config in provider_configs:
        for prompt_id, prompt_data in prompts.items():
            variation = prompt_data["variation"]
            existing_count = len([
                r for r in existing_responses
                if r["llm"] == llm_config["name"] and r["prompt_id"] == prompt_id
            ])
            for i in range(existing_count, NUM_RESPONSES_PER_LLM):
                print(f"Collecting response {i+1} for LLM '{llm_config['name']}' on prompt '{prompt_id}'")
                raw_response = call_llm_api(llm_config, variation["prompt"])
                if raw_response:
                    cleaned_response = clean_llm_response(raw_response["content"])
                    if cleaned_response:
                        response_obj = {
                            "llm": llm_config["name"],
                            "prompt_id": prompt_id,
                            "response_number": i + 1,
                            "response": cleaned_response,
                            "metadata": raw_response["metadata"],
                            "timestamp": datetime.now().isoformat(),
                        }
                        save_response(response_obj)
                # Sleep for this model's configured delay between requests.
                if not DEBUG:
                    time.sleep(llm_config["rate_limit_delay"])
def main():
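    """Load prompts and existing responses, then collect new responses with one worker per provider."""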
    try:
        prompts = load_prompts(PROMPTS_FILE)
        responses = load_responses()
        # One worker process per provider so each provider's rate limit is handled independently.
        providers = ["openai", "anthropic", "google"]
        with ProcessPoolExecutor(max_workers=len(providers)) as executor:
            futures = [
                executor.submit(collect_responses_by_provider, provider, LLMS, prompts, responses)
                for provider in providers
            ]
            for future in futures:
                future.result()
    except Exception:
        print(f"Error in main: {traceback.format_exc()}")
if __name__ == "__main__":
    main()