Skip to content
This repository has been archived by the owner on Oct 17, 2024. It is now read-only.

Commit

Permalink
feat: add gemini, ci
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Jul 7, 2024
1 parent 1c71cb4 commit 59ef4f4
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Lint-only pipeline: on every push, install the pinned formatter/linter
# requirements with uv and run `make check`.
name: CI

on: [push]

env:
  # Quoted so the values are unambiguously strings, as environment
  # variables always are.
  OMP_NUM_THREADS: "2"
  MKL_NUM_THREADS: "2"
  PIP_DISABLE_PIP_VERSION_CHECK: "1"

jobs:
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      # The cached directory is uv's download cache. The key includes a hash
      # of the requirements file so a changed pin produces a fresh cache
      # entry; restore-keys falls back to the most recent cache for the same
      # OS / Python version so unchanged packages are still reused.
      - uses: actions/cache@v4
        name: Cache pip packages
        with:
          path: ~/.cache/uv
          key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements-format.txt') }}
          restore-keys: |
            ${{ runner.os }}-python-${{ matrix.python-version }}

      - name: Install uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Install dependencies
        run: uv pip install --system -r requirements-format.txt

      - name: Check lint
        run: make check
69 changes: 69 additions & 0 deletions gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os  # noqa: I001

import google.generativeai as genai  # noqa: I001
import pandas as pd
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tqdm import tqdm

# Constants
# The key is taken from the GEMINI_API_KEY environment variable so a real
# credential never has to be written into this file; when the variable is
# unset the original placeholder value is used, as before.
API_KEY = os.environ.get("GEMINI_API_KEY", "...")
MODEL_NAME = "gemini-1.5-pro-001"

# Configure the Gemini API
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

# Safety settings
# Every listed harm category is set to BLOCK_NONE; this mapping is passed
# along with each generate_content call.
safety_settings = {
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
}

# Load questions
# One JSON object per line; the rest of the script reads the "id",
# "category", "questions" and "references" columns.
df_questions = pd.read_json("questions.jsonl", lines=True)


@retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception))
def call_gemini_api(input_text):
    """Send one prompt to the Gemini model and return the generated text.

    The call is retried on any exception, for at most 10 attempts with a
    fixed 1-second wait between attempts (configured by the decorator).

    Args:
        input_text: Prompt forwarded to ``model.generate_content``.

    Returns:
        The text of the first candidate, with the text of all of its parts
        concatenated in order. For a single-part reply this is exactly that
        part's text.

    Raises:
        ValueError: Raised inside each attempt when the response has no
            candidates or the first candidate has no parts. Because these
            are retried too, the caller only sees a failure once all
            attempts are used up (tenacity wraps it in ``RetryError`` by
            default).
    """
    response = model.generate_content(
        [input_text],
        safety_settings=safety_settings,
    )

    if not response.candidates:
        raise ValueError("Invalid operation: No candidates returned in the response.")

    candidate = response.candidates[0]
    parts = candidate.content.parts
    if not parts:
        # Dump the candidate so the reason for the empty reply is visible
        # in the console before the retry kicks in.
        print(candidate)
        raise ValueError("Invalid operation: No parts found in the candidate.")

    # Join every part rather than returning only the first one, so a reply
    # split across several parts is not silently truncated.
    return "".join(part.text for part in parts)


# Generate single-turn outputs
# Each entry of the "questions" column holds the turns of one conversation;
# the first turn is sent on its own.
first_turn_questions = df_questions["questions"].map(lambda x: x[0])
single_turn_outputs = []
for question in tqdm(first_turn_questions, desc="Generating single-turn outputs"):
    single_turn_outputs.append(call_gemini_api(question))

# Generate multi-turn outputs
# The second turn is asked with the first question and its generated answer
# prepended as plain text. Rows are paired with their first-turn answers by
# position (zip) rather than by DataFrame index label, so the pairing stays
# correct even if the frame does not carry a default 0..n-1 index.
multi_turn_outputs = []
turn_pairs = zip(df_questions["questions"], single_turn_outputs)
for turns, first_answer in tqdm(
    turn_pairs, total=df_questions.shape[0], desc="Generating multi-turn outputs"
):
    question_format = f"{turns[0]} {first_answer} {turns[1]}"
    multi_turn_outputs.append(call_gemini_api(question_format))

# Save outputs
# "outputs" holds one (single-turn, multi-turn) pair per question; the other
# columns are copied through from the input file unchanged.
df_output = pd.DataFrame(
    {
        "id": df_questions["id"],
        "category": df_questions["category"],
        "questions": df_questions["questions"],
        "outputs": list(zip(single_turn_outputs, multi_turn_outputs)),
        "references": df_questions["references"],
    }
)

# force_ascii=False writes non-ASCII characters as-is instead of \u escapes.
df_output.to_json("gemini_pro_outputs.jsonl", orient="records", lines=True, force_ascii=False)
1 change: 1 addition & 0 deletions requirements-format.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Lint tooling installed by the CI lint job; the environment marker restricts this pin to Python 3.11.x.
ruff==0.4.9 ; python_version >= "3.11" and python_version < "3.12"

0 comments on commit 59ef4f4

Please sign in to comment.