From 4cf575d77ac357c3f66a8f3903b98c21bd1d7b2e Mon Sep 17 00:00:00 2001 From: sigridjineth Date: Mon, 8 Jul 2024 09:00:27 +0900 Subject: [PATCH] refactor: generator-gemini.py --- generator-gemini.py | 48 +++++++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/generator-gemini.py b/generator-gemini.py index 9aaedba..e4af2f0 100644 --- a/generator-gemini.py +++ b/generator-gemini.py @@ -8,15 +8,13 @@ from templates import PROMPT_STRATEGY -# Constants +# TODO: generator-gemini.py to converge with generator.py API_KEY = "..." MODEL_NAME = "gemini-1.5-pro-001" -# Configure the Gemini API genai.configure(api_key=API_KEY) model = genai.GenerativeModel(MODEL_NAME) -# Safety settings safety_settings = { "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE", "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE", @@ -28,8 +26,6 @@ parser.add_argument("-o", "--output_dir", help="Directory to save outputs", default="./generated") args = parser.parse_args() -print(f"Args - {args}") - df_questions = pd.read_json("questions.jsonl", orient="records", encoding="utf-8-sig", lines=True) if not os.path.exists(args.output_dir): @@ -37,54 +33,50 @@ @retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception)) -def call_gemini_api(messages): +def call_gemini_api(input_text): """Function to call the Gemini API and return the generated text.""" - response = model.generate_content(messages, safety_settings=safety_settings) + response = model.generate_content([input_text], safety_settings=safety_settings) if not response.candidates: raise ValueError("Invalid operation: No candidates returned in the response.") candidate = response.candidates[0] - if not candidate.messages: + if not candidate.content.parts: print(candidate) - raise ValueError("Invalid operation: No messages found in the candidate.") + raise ValueError("Invalid operation: No parts found in the candidate.") - return candidate.messages[-1].content + return candidate.content.parts[0].text for strategy_name, prompts in PROMPT_STRATEGY.items(): def format_single_turn_question(question): - # Make a deep copy of the prompts to avoid modifying the original - formatted_prompts = [dict(p) for p in prompts] - formatted_prompts.append({"role": "user", "content": question[0]}) - return formatted_prompts + messages = prompts + [{"role": "user", "content": question[0]}] + formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages]) + return formatted_text single_turn_questions = df_questions["questions"].map(format_single_turn_question) single_turn_outputs = [] - for messages in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"): - generated_text = call_gemini_api(messages) + for formatted_text in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"): + generated_text = call_gemini_api(formatted_text) single_turn_outputs.append(generated_text) def format_double_turn_question(question, single_turn_output): - # Make a deep copy of the prompts to avoid modifying the original - formatted_prompts = [dict(p) for p in prompts] - formatted_prompts.extend( - [ - {"role": "user", "content": question[0]}, - {"role": "assistant", "content": single_turn_output}, - {"role": "user", "content": question[1]}, - ] - ) - return formatted_prompts + messages = prompts + [ + {"role": "user", "content": question[0]}, + {"role": "assistant", "content": single_turn_output}, + {"role": "user", "content": question[1]}, + ] + formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages]) + return formatted_text multi_turn_questions = df_questions[["questions", "id"]].apply( lambda x: format_double_turn_question(x["questions"], single_turn_outputs[x["id"] - 1]), axis=1, ) multi_turn_outputs = [] - for messages in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"): - generated_text = call_gemini_api(messages) + for formatted_text in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"): + generated_text = call_gemini_api(formatted_text) multi_turn_outputs.append(generated_text) df_output = pd.DataFrame(