Skip to content
This repository has been archived by the owner on Oct 17, 2024. It is now read-only.

Commit

Permalink
refactor: generator-gemini.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Jul 8, 2024
1 parent 3ec6a20 commit 4cf575d
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions generator-gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@

from templates import PROMPT_STRATEGY

# Constants
# TODO: generator-gemini.py to converge with generator.py
API_KEY = "..."
MODEL_NAME = "gemini-1.5-pro-001"

# Configure the Gemini API
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

# Safety settings
safety_settings = {
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
Expand All @@ -28,63 +26,57 @@
parser.add_argument("-o", "--output_dir", help="Directory to save outputs", default="./generated")
args = parser.parse_args()

print(f"Args - {args}")

df_questions = pd.read_json("questions.jsonl", orient="records", encoding="utf-8-sig", lines=True)

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)


@retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception))
def call_gemini_api(messages):
def call_gemini_api(input_text):
"""Function to call the Gemini API and return the generated text."""
response = model.generate_content(messages, safety_settings=safety_settings)
response = model.generate_content([input_text], safety_settings=safety_settings)

if not response.candidates:
raise ValueError("Invalid operation: No candidates returned in the response.")

candidate = response.candidates[0]
if not candidate.messages:
if not candidate.content.parts:
print(candidate)
raise ValueError("Invalid operation: No messages found in the candidate.")
raise ValueError("Invalid operation: No parts found in the candidate.")

return candidate.messages[-1].content
return candidate.content.parts[0].text


for strategy_name, prompts in PROMPT_STRATEGY.items():

def format_single_turn_question(question):
# Make a deep copy of the prompts to avoid modifying the original
formatted_prompts = [dict(p) for p in prompts]
formatted_prompts.append({"role": "user", "content": question[0]})
return formatted_prompts
messages = prompts + [{"role": "user", "content": question[0]}]
formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages])
return formatted_text

single_turn_questions = df_questions["questions"].map(format_single_turn_question)
single_turn_outputs = []
for messages in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"):
generated_text = call_gemini_api(messages)
for formatted_text in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"):
generated_text = call_gemini_api(formatted_text)
single_turn_outputs.append(generated_text)

def format_double_turn_question(question, single_turn_output):
# Make a deep copy of the prompts to avoid modifying the original
formatted_prompts = [dict(p) for p in prompts]
formatted_prompts.extend(
[
{"role": "user", "content": question[0]},
{"role": "assistant", "content": single_turn_output},
{"role": "user", "content": question[1]},
]
)
return formatted_prompts
messages = prompts + [
{"role": "user", "content": question[0]},
{"role": "assistant", "content": single_turn_output},
{"role": "user", "content": question[1]},
]
formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages])
return formatted_text

multi_turn_questions = df_questions[["questions", "id"]].apply(
lambda x: format_double_turn_question(x["questions"], single_turn_outputs[x["id"] - 1]),
axis=1,
)
multi_turn_outputs = []
for messages in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"):
generated_text = call_gemini_api(messages)
for formatted_text in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"):
generated_text = call_gemini_api(formatted_text)
multi_turn_outputs.append(generated_text)

df_output = pd.DataFrame(
Expand Down

0 comments on commit 4cf575d

Please sign in to comment.