-
Notifications
You must be signed in to change notification settings - Fork 0
/
askllm_groq.py
156 lines (134 loc) · 7.11 KB
/
askllm_groq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import argparse
import json
import time
import jsonlines
import logging
from tqdm import tqdm
import requests
from google.auth import default
from google.auth.transport.requests import Request
# Configure logging. Default to INFO so that the --verbose flag (which raises
# the level to DEBUG in process_json_lines) actually has an effect; the
# previous hard-coded DEBUG default made that flag a no-op.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Vertex AI API details used to build the generateContent endpoint URL.
PROJECT_ID = "north-390910"
LOCATION = "us-central1"
MODEL_ID = "gemini-1.5-flash-001"
ENDPOINT = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:generateContent"
# Load the per-language prompt templates from template.json.
# Keys are language codes (see --language choices); values are prompt strings
# with a {content} placeholder filled in by process_batch.
# Fix: read the file explicitly as UTF-8 instead of the platform locale
# default, which can break on non-ASCII template text.
with open('template.json', 'r', encoding='utf-8') as f:
    templates = json.load(f)
# Generation parameters forwarded verbatim to the generateContent API.
generation_config = {
    "temperature": 0.5,
    "top_p": 0.95,
    "top_k": 40,  # Adjusted value
    "max_output_tokens": 8192,
    # Ask the model to emit JSON so the response text can be parsed directly.
    "response_mime_type": "application/json",
}
# Empty list = use the API's default safety behavior; the explicit overrides
# that caused problems were deliberately removed.
safety_settings = []
# Get an authentication token
def get_auth_token():
    """Return a fresh OAuth2 bearer token via Application Default Credentials."""
    creds, _project = default()
    creds.refresh(Request())
    return creds.token
def send_request(prompt):
    """POST a single generateContent request to the Vertex AI endpoint.

    Args:
        prompt: Fully formatted user prompt text.

    Returns:
        The decoded JSON response body from the API.

    Raises:
        requests.HTTPError: If the API returns a non-2xx status.
        requests.Timeout: If the request does not complete within the timeout.
    """
    headers = {
        "Authorization": f"Bearer {get_auth_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "contents": [{
            "role": "user",
            "parts": [{"text": prompt}]
        }],
        "generation_config": generation_config,
        "safety_settings": safety_settings
    }
    # Fix: the original call had no timeout, so a stalled connection could hang
    # the entire run forever. 120 s comfortably covers an 8k-token generation.
    response = requests.post(ENDPOINT, headers=headers, json=payload, timeout=120)
    response.raise_for_status()
    return response.json()
def process_batch(chat_prompt, batch, text_field, wait_time):
    """Send one API request per record in *batch* and collect the raw responses.

    Args:
        chat_prompt: Prompt template with a {content} placeholder.
        batch: Iterable of JSON records to score.
        text_field: Key in each record holding the text to insert into the prompt.
        wait_time: Seconds to sleep after each individual request.

    Returns:
        List of (record, response) pairs in input order.
    """
    collected = []
    for record in batch:
        formatted = chat_prompt.format(content=record[text_field])
        collected.append((record, send_request(formatted)))
        # Throttle between individual requests.
        time.sleep(wait_time)
    return collected
def process_json_lines(json_lines_file, output_file, num_examples, max_requests_per_minute, batch_size, language, text_field, verbose, wait_time):
    """Score records from a JSONL file with the LLM, resuming where a prior run stopped.

    Args:
        json_lines_file: Path to the input JSONL file.
        output_file: Path to the output JSONL file; existing lines count as done.
        num_examples: Maximum number of new records to process in this run.
        max_requests_per_minute: Boundary at which a 60 s rate-limit pause is taken.
        batch_size: Number of records per batch.
        language: Key into the loaded prompt templates.
        text_field: Name of the field holding the text to score.
        verbose: When True, enable DEBUG logging and hide the progress bar.
        wait_time: Seconds to sleep between individual requests.
    """
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    chat_prompt = templates[language]
    try:
        # Resume support: every line already present in the output file is done.
        if os.path.exists(output_file):
            with jsonlines.open(output_file, mode='r') as reader:
                processed_lines_count = sum(1 for _ in reader)
        else:
            processed_lines_count = 0
        logging.info(f"Starting from line: {processed_lines_count}")
        with jsonlines.open(json_lines_file, mode='r') as reader:
            lines = list(reader)
        retries = 0
        total_lines = min(processed_lines_count + num_examples, len(lines))
        with jsonlines.open(output_file, mode='a') as writer:
            with tqdm(total=total_lines - processed_lines_count, desc="Processing lines", disable=verbose) as pbar:
                for idx in range(processed_lines_count, total_lines, batch_size):
                    batch = lines[idx:idx + batch_size]
                    # Bug fix: the previous version skipped the WHOLE batch when
                    # any single line was already processed or missing the text
                    # field, silently dropping the batch's remaining lines.
                    # Filter per line instead and only send the rest to the API.
                    pending = []
                    for line in batch:
                        if "educational score" in line:
                            writer.write(line)
                            logging.debug(f"Line already processed, skipping: {line}")
                            pbar.update(1)
                        elif text_field not in line:
                            logging.error(f"Field '{text_field}' not found in line {idx+1}. Make sure the input JSONL file contains this field.")
                            writer.write(line)
                            pbar.update(1)
                        else:
                            pending.append(line)
                    if not pending:
                        continue
                    while retries < 5:
                        try:
                            responses = process_batch(chat_prompt, pending, text_field, wait_time)
                            for line, response in responses:
                                # The model is asked for JSON (response_mime_type),
                                # so parse the text of the first candidate directly.
                                response_json_str = response['candidates'][0]['content']['parts'][0]['text']
                                response_json = json.loads(response_json_str)
                                logging.debug(f"Response JSON: {response_json}")
                                line["justification"] = response_json.get("reason", "No justification found")
                                line["educational score"] = response_json.get("educational score", 0)
                                writer.write(line)
                                logging.debug(f"Written line: {line}")
                                pbar.update(1)
                            retries = 0  # Reset retries after a successful batch
                            if (idx + batch_size) % max_requests_per_minute == 0:
                                logging.info(f"Processed {idx + batch_size} entries. Waiting for a minute to respect rate limit.")
                                time.sleep(60)  # Wait for a minute to respect rate limit
                            break
                        except Exception as e:
                            retries += 1
                            logging.error(f"An error occurred while processing batch starting at line {idx+1} (attempt {retries}): {e}")
                            if retries >= 5:
                                logging.error("Maximum retry limit reached. Exiting the script.")
                                # Fix: raise SystemExit instead of the site-module
                                # builtin exit(). SystemExit subclasses BaseException,
                                # so the broad `except Exception` below cannot
                                # swallow it.
                                raise SystemExit(1)
    except Exception as e:
        logging.error(f"An error occurred: {e}")
def main():
    """Entry point: parse command-line options and kick off processing."""
    parser = argparse.ArgumentParser(description="Process a JSONLines file with the Vertex AI API.")
    parser.add_argument('--json_lines_file', type=str, required=True,
                        help='Path to the JSONLines file.')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Path to the output JSONLines file.')
    parser.add_argument('--num_examples', type=int, default=100,
                        help='Number of requests to process (default: 100).')
    parser.add_argument('--max_requests_per_minute', type=int, default=1000,
                        help='Maximum number of requests per minute (default: 1000).')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Number of requests to process in each batch (default: 10).')
    parser.add_argument('--language', type=str, choices=['en', 'sv', 'da', 'nb', 'nn'], default='en',
                        help='Language for the prompt (default: en).')
    parser.add_argument('--text_field', type=str, default='text',
                        help='Field in JSON lines containing the text (default: text).')
    parser.add_argument('--wait_time', type=float, default=0,
                        help='Time to wait between requests in seconds (default: 0).')
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging.')
    opts = parser.parse_args()
    process_json_lines(
        opts.json_lines_file,
        opts.output_file,
        opts.num_examples,
        opts.max_requests_per_minute,
        opts.batch_size,
        opts.language,
        opts.text_field,
        opts.verbose,
        opts.wait_time,
    )


if __name__ == "__main__":
    main()