Skip to content

Commit

Permalink
Merge pull request #8 from truongnh1992/truong-1
Browse files Browse the repository at this point in the history
Refactor code
  • Loading branch information
truongnh1992 authored Oct 29, 2024
2 parents d0f46d6 + 0e32094 commit e3c1436
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 109 deletions.
2 changes: 1 addition & 1 deletion src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ google-ai-generativelanguage==0.6.10
google-generativeai
PyGitHub
github3.py==1.3.0
wcmatch
unidiff
314 changes: 206 additions & 108 deletions src/review_code_gemini.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,255 @@
import os
import json
from typing import List, Dict
import google.generativeai as client
import os
from typing import List, Dict, Any
import google.generativeai as Client
from github import Github
from difflib import unified_diff
import difflib
import fnmatch
from unidiff import Hunk, PatchedFile, PatchSet

# Get input values from environment variables
GITHUB_TOKEN = os.environ.get('GH_TOKEN')
GEMINI_MODEL_NAME = "models/code-bison-001" # Or another Gemini model
# Get input values from environment variables (GitHub Actions)
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")

# Initialize GitHub and Gemini clients
gh = Github(GITHUB_TOKEN)
glm_client = client.configure(api_key=os.environ.get('GEMINI_API_KEY'))
gemini_client = Client.configure(api_key=os.environ.get('GEMINI_API_KEY'))


def get_pr_details() -> Dict:
"""Retrieves details of the pull request."""
event_path = os.environ["GITHUB_EVENT_PATH"]
with open(event_path, "r") as f:
event_data = json.load(f)
class PRDetails:
    """Plain container for the pull-request fields used by the reviewer.

    Args:
        owner: Account or organisation that owns the repository.
        repo: Repository name (without the owner part).
        pull_number: Number of the pull request.
        title: Pull request title.
        description: Pull request body. ``None`` is accepted and stored as
            an empty string.
    """

    def __init__(self, owner: str, repo: str, pull_number: int, title: str, description: str):
        self.owner = owner
        self.repo = repo
        self.pull_number = pull_number
        self.title = title
        # The caller passes ``pr.body`` straight through; a pull request
        # without a description yields None there, which would otherwise be
        # interpolated into the prompt as the literal text "None".
        self.description = description or ""

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(owner={self.owner!r}, repo={self.repo!r}, "
            f"pull_number={self.pull_number!r}, title={self.title!r})"
        )

print(f"Raw event data: {event_data}") # Print the raw data

repo_name = event_data["repository"]["full_name"]
pr_number = event_data["number"]
def get_pr_details() -> PRDetails:
"""Retrieves details of the pull request from GitHub Actions event payload."""
with open(os.environ["GITHUB_EVENT_PATH"], "r") as f:
event_data = json.load(f)
repo_full_name = event_data["repository"]["full_name"]
owner, repo = repo_full_name.split("/")
pull_number = event_data["number"]

print(f"Repository name: {repo_name}") # Print repo_name
print(f"PR number: {pr_number}") # Print pr_number
repo = gh.get_repo(repo_name)
pr = repo.get_pull(pr_number)
repo = gh.get_repo(repo_full_name)
pr = repo.get_pull(pull_number)

return {
"owner": repo.owner.login,
"repo": repo.name,
"pull_number": pr_number,
"title": pr.title,
"description": pr.body,
}
return PRDetails(owner, repo.name, pull_number, pr.title, pr.body)


def get_diff(pr_details: Dict) -> str:
"""Fetches the diff of the pull request."""
repo = gh.get_repo(f"{pr_details['owner']}/{pr_details['repo']}")
pr = repo.get_pull(pr_details["pull_number"])
return pr.get_commits().reversed[0].files[0].raw_data["patch"]
def get_diff(owner: str, repo: str, pull_number: int) -> str:
    """Build a unified diff for the most recent commit of a pull request.

    The text is assembled from the ``patch`` field GitHub returns for every
    file of the commit. Each file section starts with ``--- a/<old path>``
    and ``+++ b/<new path>`` header lines, which is the form ``parse_diff``
    in this module looks for, followed by the file's hunks.

    Args:
        owner: Account or organisation that owns the repository.
        repo: Repository name.
        pull_number: Number of the pull request.

    Returns:
        The concatenated per-file diffs, or an empty string when no file of
        the commit carries a textual patch.
    """
    repository = gh.get_repo(f"{owner}/{repo}")
    pr = repository.get_pull(pull_number)
    # NOTE(review): only the newest commit is inspected, as before; switch to
    # pr.get_files() if the whole pull request should be reviewed.
    commit = pr.get_commits().reversed[0]

    sections = []
    for file in commit.files:
        patch = file.raw_data.get("patch")
        if not patch:
            # No patch text is available for this entry (presumably binary,
            # oversized, or a pure rename), so there is nothing to review.
            print(f"No patch available for {file.filename}, skipping")
            continue
        old_path = file.raw_data.get("previous_filename") or file.filename
        sections.append(f"--- a/{old_path}\n+++ b/{file.filename}\n{patch}\n")
    return "".join(sections)


def analyze_code(
diff: str, pr_details: Dict
) -> List[Dict[str, str]]:
"""Analyzes the code diff using Gemini and generates review comments."""
def analyze_code(parsed_diff: List[Dict[str, Any]], pr_details: PRDetails) -> List[Dict[str, Any]]:
"""Analyzes the code changes using Gemini and generates review comments."""
comments = []
diff_lines = diff.splitlines()

# Extract changed lines for analysis
changed_lines = []
for line in diff_lines:
if line.startswith('+') or line.startswith('-'):
changed_lines.append(line)

if not changed_lines:
return comments

prompt = create_prompt("\n".join(changed_lines), pr_details)
ai_response = get_gemini_response(prompt)
if ai_response:
comments = create_comment(diff_lines, ai_response)
for file_data in parsed_diff:
file_path = file_data["path"]
if file_path == "/dev/null":
continue # Ignore deleted files
for hunk_data in file_data["hunks"]:
hunk_content = "\n".join(hunk_data["lines"])
prompt = create_prompt(file_path, hunk_content, pr_details) # Adjust create_prompt accordingly
ai_response = get_ai_response(prompt)
if ai_response:
# Adjust create_comment to use file_path and line numbers from hunk_data["lines"]
new_comments = create_comment(file_path, hunk_data, ai_response)
if new_comments:
comments.extend(new_comments)
return comments


def create_prompt(diff: str, pr_details: Dict) -> str:
def create_prompt(file: PatchedFile, hunk: Hunk, pr_details: PRDetails) -> str:
"""Creates the prompt for the Gemini model."""
return f"""Your task is to review pull requests. Instructions:
- Provide the response in following JSON format: {{"reviews": [{{"lineNumber": <line_number>, "reviewComment": "<review comment>"}}]}}
- Do not give positive comments or compliments.
- Provide comments and suggestions ONLY if there is something to improve, otherwise "reviews" should be an empty array.
- Write the comment in GitHub Markdown format.
- Use the given description only for the overall context and only comment the code.
- IMPORTANT: NEVER suggest adding comments to the code.
Review the following code diff and take the pull request title and description into account when writing the response.
Pull request title: {pr_details['title']}
- Provide the response in following JSON format: {{"reviews": [{{"lineNumber": <line_number>, "reviewComment": "<review comment>"}}]}}
- Do not give positive comments or compliments.
- Provide comments and suggestions ONLY if there is something to improve, otherwise "reviews" should be an empty array.
- Write the comment in GitHub Markdown format.
- Use the given description only for the overall context and only comment the code.
- IMPORTANT: NEVER suggest adding comments to the code.
Review the following code diff in the file "{file.path}" and take the pull request title and description into account when writing the response.
Pull request title: {pr_details.title}
Pull request description:
---
{pr_details['description']}
{pr_details.description}
---
Git diff to review:
{diff}"""
```diff
{hunk.content}
{chr(10).join([f"{c.ln if c.ln else c.ln2} {c.content}" for c in hunk.changes])}
```
"""

def get_gemini_response(prompt: str) -> List[Dict[str, str]] | None:
"""Gets the AI response from the Gemini API."""
def get_ai_response(prompt: str) -> List[Dict[str, str]]:
"""Sends the prompt to Gemini API and retrieves the response."""
try:
response = glm_client.generate_text(
model=GEMINI_MODEL_NAME,
prompt=client.TextPrompt(text=prompt),
response = gemini_client.generate_text(
prompt=prompt,
model="gemini-1.5-pro-002",
temperature=0.2,
max_output_tokens=700,
top_p=1.0,
)
res = response.candidates[0].output.strip() or "{}"
return json.loads(res).get("reviews", [])
print(f"Raw Gemini response: {response.result}") # Print raw response
try:
data = json.loads(response.result.strip())
if "reviews" in data and isinstance(data["reviews"], list):
reviews = data["reviews"]
# Check if each review item has the required keys
for review in reviews:
if not ("lineNumber" in review and "reviewComment" in review):
print(f"Incomplete review item: {review}")
return []
return reviews
else:
print("Error: 'reviews' key not found or is not a list in Gemini response")
return []
except json.JSONDecodeError as e:
print(f"Error decoding Gemini response: {e}")
return []
except Exception as e:
print(f"Error: {e}")
return None

print(f"Error during Gemini API call: {e}")
return []

def create_comment(
diff_lines: List[str], ai_responses: List[Dict[str, str]]
) -> List[Dict[str, str]]:
"""Creates comments for the GitHub PR."""
def create_comment(file: PatchedFile, hunk: Hunk, ai_responses: List[Dict[str, str]]) -> List[Dict[str, Any]]:
"""Creates comment objects from AI responses."""
comments = []
for ai_response in ai_responses:
try:
line_number = int(ai_response["lineNumber"])
# Adjust line number for added lines
for i, line in enumerate(diff_lines):
if line.startswith('+') and i < line_number:
line_number += 1
comments.append(
{
"body": ai_response["reviewComment"],
"path": "src/review_code_gemini.py", # Replace with actual file path
"line": line_number,
}
)
except ValueError:
print(f"Invalid line number: {ai_response['lineNumber']}")
line_number = hunk.source_start + int(ai_response["lineNumber"]) - 1
print(f"Creating comment for line: {line_number}") # Debugging print
comments.append({
"body": ai_response["reviewComment"],
"path": file.path,
"line": line_number,
})
except (KeyError, TypeError, ValueError) as e: # Catch ValueError for line number conversion
print(f"Error creating comment from AI response: {e}, Response: {ai_response}")
return comments


def create_review_comment(pr_details: Dict, comments: List[Dict[str, str]]):
"""Creates a review comment on the GitHub PR."""
repo = gh.get_repo(f"{pr_details['owner']}/{pr_details['repo']}")
pr = repo.get_pull(pr_details["pull_number"])
def create_review_comment(
    owner: str,
    repo: str,
    pull_number: int,
    comments: List[Dict[str, Any]],
):
    """Submit the review comments to the pull request as one COMMENT review.

    Args:
        owner: Account or organisation that owns the repository.
        repo: Repository name.
        pull_number: Number of the pull request.
        comments: Comment dicts; the ones built in this module carry
            ``body``, ``path`` and ``line`` keys.

    Returns:
        None. Nothing is sent when ``comments`` is empty.
    """
    if not comments:
        # A review with no comments and no body has nothing to post, so skip
        # the API round trips entirely.
        return
    # Use distinct names so the ``repo`` string parameter is not rebound to
    # a repository object.
    repository = gh.get_repo(f"{owner}/{repo}")
    pull_request = repository.get_pull(pull_number)
    pull_request.create_review(comments=comments, event="COMMENT")

def _header_path(raw: str, prefix: str) -> str:
    """Return the path from a ``---``/``+++`` header with its prefix removed."""
    # Some diff producers append "\t<timestamp>" after the path.
    path = raw.split("\t", 1)[0].strip()
    if path.startswith(prefix):
        path = path[len(prefix):]
    return path


def parse_diff(diff_str: str) -> List[Dict[str, Any]]:
    """Split a unified diff string into per-file hunks.

    A file section starts at a ``--- <old>`` line that is immediately
    followed by a ``+++ <new>`` line. Requiring the pair lets added files
    (``--- /dev/null``) open their own section and keeps removed or added
    content lines that merely look like headers inside their hunk.

    Args:
        diff_str: Unified diff text; may be empty.

    Returns:
        One dict per file, in order of appearance, with keys ``path`` (the
        new-side path without the ``b/`` prefix, or ``"/dev/null"`` for a
        deleted file) and ``hunks``. Each hunk is a dict with ``header``
        (the ``@@`` line) and ``lines`` (the lines that follow it, verbatim).
    """
    files: List[Dict[str, Any]] = []
    current_file = None
    awaiting_target = False
    diff_lines = diff_str.splitlines()
    for index, line in enumerate(diff_lines):
        if awaiting_target:
            # The lookahead below guarantees this is the "+++ " header.
            awaiting_target = False
            current_file["path"] = _header_path(line[4:], "b/")
            continue
        following = diff_lines[index + 1] if index + 1 < len(diff_lines) else ""
        if line.startswith("--- ") and following.startswith("+++ "):
            current_file = {"path": _header_path(line[4:], "a/"), "hunks": []}
            files.append(current_file)
            awaiting_target = True
        elif line.startswith("@@"):
            if current_file is not None:
                current_file["hunks"].append({"header": line, "lines": []})
        elif current_file is not None and current_file["hunks"]:
            current_file["hunks"][-1]["lines"].append(line)
    return files



def main():
"""Main function to run the code review process."""
print("This is main function")
"""Main function to execute the code review process."""
pr_details = get_pr_details()
diff = get_diff(pr_details)

if not diff:
print("No diff found")
event_data = json.load(open(os.environ["GITHUB_EVENT_PATH"], "r"))
if event_data["action"] == "opened":
diff = get_diff(pr_details.owner, pr_details.repo, pr_details.pull_number)
if not diff:
print("No diff found")
return

parsed_diff = parse_diff(diff)

exclude_patterns = os.environ.get("INPUT_EXCLUDE", "").split(",")
exclude_patterns = [s.strip() for s in exclude_patterns]

filtered_diff = [
file
for file in parsed_diff
if not any(fnmatch.fnmatch(file.path or "", pattern) for pattern in exclude_patterns)
]

comments = analyze_code(filtered_diff, pr_details)
if comments:
create_review_comment(
pr_details.owner, pr_details.repo, pr_details.pull_number, comments
)
elif event_data["action"] == "synchronize":
diff = get_diff(pr_details.owner, pr_details.repo, pr_details.pull_number)
if not diff:
print("No diff found")
return

parsed_diff = parse_diff(diff)

exclude_patterns = os.environ.get("INPUT_EXCLUDE", "").split(",")
exclude_patterns = [s.strip() for s in exclude_patterns]

filtered_diff = [
file
for file in parsed_diff
if not any(fnmatch.fnmatch(file.path or "", pattern) for pattern in exclude_patterns)
]

comments = analyze_code(filtered_diff, pr_details)
if comments:
create_review_comment(
pr_details.owner, pr_details.repo, pr_details.pull_number, comments
)
else:
print("Unsupported event:", os.environ.get("GITHUB_EVENT_NAME"))
return

comments = analyze_code(diff, pr_details)
if comments:
create_review_comment(pr_details, comments)

if __name__ == "__main__":
main()
try:
main()
except Exception as error:
print("Error:", error)

0 comments on commit e3c1436

Please sign in to comment.