-
-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from truongnh1992/truong-1
Refactor code
- Loading branch information
Showing 2 changed files with 207 additions and 109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ google-ai-generativelanguage==0.6.10 | |
google-generativeai | ||
PyGitHub | ||
github3.py==1.3.0 | ||
wcmatch | ||
unidiff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,157 +1,255 @@ | ||
import os | ||
import json | ||
from typing import List, Dict | ||
import google.generativeai as client | ||
import os | ||
from typing import List, Dict, Any | ||
import google.generativeai as Client | ||
from github import Github | ||
from difflib import unified_diff | ||
import difflib | ||
import fnmatch | ||
from unidiff import Hunk, PatchedFile, PatchSet | ||
|
||
# Get input values from environment variables | ||
GITHUB_TOKEN = os.environ.get('GH_TOKEN') | ||
GEMINI_MODEL_NAME = "models/code-bison-001" # Or another Gemini model | ||
# Get input values from environment variables (GitHub Actions) | ||
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN") | ||
|
||
# Initialize GitHub and Gemini clients | ||
gh = Github(GITHUB_TOKEN) | ||
glm_client = client.configure(api_key=os.environ.get('GEMINI_API_KEY')) | ||
gemini_client = Client.configure(api_key=os.environ.get('GEMINI_API_KEY')) | ||
|
||
|
||
def get_pr_details() -> Dict: | ||
"""Retrieves details of the pull request.""" | ||
event_path = os.environ["GITHUB_EVENT_PATH"] | ||
with open(event_path, "r") as f: | ||
event_data = json.load(f) | ||
class PRDetails: | ||
def __init__(self, owner: str, repo: str, pull_number: int, title: str, description: str): | ||
self.owner = owner | ||
self.repo = repo | ||
self.pull_number = pull_number | ||
self.title = title | ||
self.description = description | ||
|
||
print(f"Raw event data: {event_data}") # Print the raw data | ||
|
||
repo_name = event_data["repository"]["full_name"] | ||
pr_number = event_data["number"] | ||
def get_pr_details() -> PRDetails: | ||
"""Retrieves details of the pull request from GitHub Actions event payload.""" | ||
with open(os.environ["GITHUB_EVENT_PATH"], "r") as f: | ||
event_data = json.load(f) | ||
repo_full_name = event_data["repository"]["full_name"] | ||
owner, repo = repo_full_name.split("/") | ||
pull_number = event_data["number"] | ||
|
||
print(f"Repository name: {repo_name}") # Print repo_name | ||
print(f"PR number: {pr_number}") # Print pr_number | ||
repo = gh.get_repo(repo_name) | ||
pr = repo.get_pull(pr_number) | ||
repo = gh.get_repo(repo_full_name) | ||
pr = repo.get_pull(pull_number) | ||
|
||
return { | ||
"owner": repo.owner.login, | ||
"repo": repo.name, | ||
"pull_number": pr_number, | ||
"title": pr.title, | ||
"description": pr.body, | ||
} | ||
return PRDetails(owner, repo.name, pull_number, pr.title, pr.body) | ||
|
||
|
||
def get_diff(pr_details: Dict) -> str: | ||
"""Fetches the diff of the pull request.""" | ||
repo = gh.get_repo(f"{pr_details['owner']}/{pr_details['repo']}") | ||
pr = repo.get_pull(pr_details["pull_number"]) | ||
return pr.get_commits().reversed[0].files[0].raw_data["patch"] | ||
def get_diff(owner: str, repo: str, pull_number: int) -> str: | ||
"""Fetches the diff of the pull request from GitHub API.""" | ||
repo = gh.get_repo(f"{owner}/{repo}") | ||
pr = repo.get_pull(pull_number) | ||
commit = pr.get_commits().reversed[0] | ||
diff = "" | ||
for file in commit.files: | ||
try: | ||
# Try accessing 'content' first | ||
current_content = file.raw_data["content"] | ||
except KeyError: | ||
try: | ||
# If 'content' is missing, use 'blob_url' | ||
from urllib.request import urlopen | ||
with urlopen(file.raw_data["blob_url"]) as f: | ||
current_content = f.read().decode('utf-8') | ||
except Exception as e: | ||
print(f"Error fetching content for {file.filename}: {e}") | ||
continue # Skip this file if content retrieval fails | ||
|
||
# Generate the diff | ||
diff_lines = difflib.unified_diff( | ||
file.raw_data.get("content", "").splitlines(keepends=True), # Handle potential missing 'content' | ||
current_content.splitlines(keepends=True), | ||
fromfile=file.raw_data.get("filename", "old_file"), | ||
tofile=file.filename | ||
) | ||
diff += ''.join(diff_lines) + "\n" | ||
return diff | ||
|
||
|
||
def analyze_code( | ||
diff: str, pr_details: Dict | ||
) -> List[Dict[str, str]]: | ||
"""Analyzes the code diff using Gemini and generates review comments.""" | ||
def analyze_code(parsed_diff: List[Dict[str, Any]], pr_details: PRDetails) -> List[Dict[str, Any]]: | ||
"""Analyzes the code changes using Gemini and generates review comments.""" | ||
comments = [] | ||
diff_lines = diff.splitlines() | ||
|
||
# Extract changed lines for analysis | ||
changed_lines = [] | ||
for line in diff_lines: | ||
if line.startswith('+') or line.startswith('-'): | ||
changed_lines.append(line) | ||
|
||
if not changed_lines: | ||
return comments | ||
|
||
prompt = create_prompt("\n".join(changed_lines), pr_details) | ||
ai_response = get_gemini_response(prompt) | ||
if ai_response: | ||
comments = create_comment(diff_lines, ai_response) | ||
for file_data in parsed_diff: | ||
file_path = file_data["path"] | ||
if file_path == "/dev/null": | ||
continue # Ignore deleted files | ||
for hunk_data in file_data["hunks"]: | ||
hunk_content = "\n".join(hunk_data["lines"]) | ||
prompt = create_prompt(file_path, hunk_content, pr_details) # Adjust create_prompt accordingly | ||
ai_response = get_ai_response(prompt) | ||
if ai_response: | ||
# Adjust create_comment to use file_path and line numbers from hunk_data["lines"] | ||
new_comments = create_comment(file_path, hunk_data, ai_response) | ||
if new_comments: | ||
comments.extend(new_comments) | ||
return comments | ||
|
||
|
||
def create_prompt(diff: str, pr_details: Dict) -> str: | ||
def create_prompt(file: PatchedFile, hunk: Hunk, pr_details: PRDetails) -> str: | ||
"""Creates the prompt for the Gemini model.""" | ||
return f"""Your task is to review pull requests. Instructions: | ||
- Provide the response in following JSON format: {{"reviews": [{{"lineNumber": <line_number>, "reviewComment": "<review comment>"}}]}} | ||
- Do not give positive comments or compliments. | ||
- Provide comments and suggestions ONLY if there is something to improve, otherwise "reviews" should be an empty array. | ||
- Write the comment in GitHub Markdown format. | ||
- Use the given description only for the overall context and only comment the code. | ||
- IMPORTANT: NEVER suggest adding comments to the code. | ||
Review the following code diff and take the pull request title and description into account when writing the response. | ||
Pull request title: {pr_details['title']} | ||
- Provide the response in following JSON format: {{"reviews": [{{"lineNumber": <line_number>, "reviewComment": "<review comment>"}}]}} | ||
- Do not give positive comments or compliments. | ||
- Provide comments and suggestions ONLY if there is something to improve, otherwise "reviews" should be an empty array. | ||
- Write the comment in GitHub Markdown format. | ||
- Use the given description only for the overall context and only comment the code. | ||
- IMPORTANT: NEVER suggest adding comments to the code. | ||
Review the following code diff in the file "{file.path}" and take the pull request title and description into account when writing the response. | ||
Pull request title: {pr_details.title} | ||
Pull request description: | ||
--- | ||
{pr_details['description']} | ||
{pr_details.description} | ||
--- | ||
Git diff to review: | ||
{diff}""" | ||
```diff | ||
{hunk.content} | ||
{chr(10).join([f"{c.ln if c.ln else c.ln2} {c.content}" for c in hunk.changes])} | ||
``` | ||
""" | ||
|
||
def get_gemini_response(prompt: str) -> List[Dict[str, str]] | None: | ||
"""Gets the AI response from the Gemini API.""" | ||
def get_ai_response(prompt: str) -> List[Dict[str, str]]: | ||
"""Sends the prompt to Gemini API and retrieves the response.""" | ||
try: | ||
response = glm_client.generate_text( | ||
model=GEMINI_MODEL_NAME, | ||
prompt=client.TextPrompt(text=prompt), | ||
response = gemini_client.generate_text( | ||
prompt=prompt, | ||
model="gemini-1.5-pro-002", | ||
temperature=0.2, | ||
max_output_tokens=700, | ||
top_p=1.0, | ||
) | ||
res = response.candidates[0].output.strip() or "{}" | ||
return json.loads(res).get("reviews", []) | ||
print(f"Raw Gemini response: {response.result}") # Print raw response | ||
try: | ||
data = json.loads(response.result.strip()) | ||
if "reviews" in data and isinstance(data["reviews"], list): | ||
reviews = data["reviews"] | ||
# Check if each review item has the required keys | ||
for review in reviews: | ||
if not ("lineNumber" in review and "reviewComment" in review): | ||
print(f"Incomplete review item: {review}") | ||
return [] | ||
return reviews | ||
else: | ||
print("Error: 'reviews' key not found or is not a list in Gemini response") | ||
return [] | ||
except json.JSONDecodeError as e: | ||
print(f"Error decoding Gemini response: {e}") | ||
return [] | ||
except Exception as e: | ||
print(f"Error: {e}") | ||
return None | ||
|
||
print(f"Error during Gemini API call: {e}") | ||
return [] | ||
|
||
def create_comment( | ||
diff_lines: List[str], ai_responses: List[Dict[str, str]] | ||
) -> List[Dict[str, str]]: | ||
"""Creates comments for the GitHub PR.""" | ||
def create_comment(file: PatchedFile, hunk: Hunk, ai_responses: List[Dict[str, str]]) -> List[Dict[str, Any]]: | ||
"""Creates comment objects from AI responses.""" | ||
comments = [] | ||
for ai_response in ai_responses: | ||
try: | ||
line_number = int(ai_response["lineNumber"]) | ||
# Adjust line number for added lines | ||
for i, line in enumerate(diff_lines): | ||
if line.startswith('+') and i < line_number: | ||
line_number += 1 | ||
comments.append( | ||
{ | ||
"body": ai_response["reviewComment"], | ||
"path": "src/review_code_gemini.py", # Replace with actual file path | ||
"line": line_number, | ||
} | ||
) | ||
except ValueError: | ||
print(f"Invalid line number: {ai_response['lineNumber']}") | ||
line_number = hunk.source_start + int(ai_response["lineNumber"]) - 1 | ||
print(f"Creating comment for line: {line_number}") # Debugging print | ||
comments.append({ | ||
"body": ai_response["reviewComment"], | ||
"path": file.path, | ||
"line": line_number, | ||
}) | ||
except (KeyError, TypeError, ValueError) as e: # Catch ValueError for line number conversion | ||
print(f"Error creating comment from AI response: {e}, Response: {ai_response}") | ||
return comments | ||
|
||
|
||
def create_review_comment(pr_details: Dict, comments: List[Dict[str, str]]): | ||
"""Creates a review comment on the GitHub PR.""" | ||
repo = gh.get_repo(f"{pr_details['owner']}/{pr_details['repo']}") | ||
pr = repo.get_pull(pr_details["pull_number"]) | ||
def create_review_comment( | ||
owner: str, | ||
repo: str, | ||
pull_number: int, | ||
comments: List[Dict[str, Any]], | ||
): | ||
"""Submits the review comments to the GitHub API.""" | ||
repo = gh.get_repo(f"{owner}/{repo}") | ||
pr = repo.get_pull(pull_number) | ||
pr.create_review(comments=comments, event="COMMENT") | ||
|
||
def parse_diff(diff_str: str) -> List[Dict[str, Any]]: | ||
"""Parses the diff string using difflib and returns a list of file changes.""" | ||
files = [] | ||
current_file = None | ||
diff_lines = diff_str.splitlines() | ||
for line in diff_lines: | ||
if line.startswith("--- a/"): | ||
current_file = {"path": line[6:], "hunks": []} | ||
files.append(current_file) | ||
elif line.startswith("+++ b/"): | ||
if current_file is not None: # Check if current_file is initialized | ||
current_file["path"] = line[6:] | ||
elif line.startswith("@@"): | ||
if current_file is not None: # Check if current_file is initialized | ||
hunk_header = line | ||
hunk_lines = [] | ||
current_file["hunks"].append({"header": hunk_header, "lines": hunk_lines}) | ||
elif current_file is not None and current_file["hunks"]: # Check for both conditions | ||
current_file["hunks"][-1]["lines"].append(line) | ||
return files | ||
|
||
|
||
|
||
def main(): | ||
"""Main function to run the code review process.""" | ||
print("This is main function") | ||
"""Main function to execute the code review process.""" | ||
pr_details = get_pr_details() | ||
diff = get_diff(pr_details) | ||
|
||
if not diff: | ||
print("No diff found") | ||
event_data = json.load(open(os.environ["GITHUB_EVENT_PATH"], "r")) | ||
if event_data["action"] == "opened": | ||
diff = get_diff(pr_details.owner, pr_details.repo, pr_details.pull_number) | ||
if not diff: | ||
print("No diff found") | ||
return | ||
|
||
parsed_diff = parse_diff(diff) | ||
|
||
exclude_patterns = os.environ.get("INPUT_EXCLUDE", "").split(",") | ||
exclude_patterns = [s.strip() for s in exclude_patterns] | ||
|
||
filtered_diff = [ | ||
file | ||
for file in parsed_diff | ||
if not any(fnmatch.fnmatch(file.path or "", pattern) for pattern in exclude_patterns) | ||
] | ||
|
||
comments = analyze_code(filtered_diff, pr_details) | ||
if comments: | ||
create_review_comment( | ||
pr_details.owner, pr_details.repo, pr_details.pull_number, comments | ||
) | ||
elif event_data["action"] == "synchronize": | ||
diff = get_diff(pr_details.owner, pr_details.repo, pr_details.pull_number) | ||
if not diff: | ||
print("No diff found") | ||
return | ||
|
||
parsed_diff = parse_diff(diff) | ||
|
||
exclude_patterns = os.environ.get("INPUT_EXCLUDE", "").split(",") | ||
exclude_patterns = [s.strip() for s in exclude_patterns] | ||
|
||
filtered_diff = [ | ||
file | ||
for file in parsed_diff | ||
if not any(fnmatch.fnmatch(file.path or "", pattern) for pattern in exclude_patterns) | ||
] | ||
|
||
comments = analyze_code(filtered_diff, pr_details) | ||
if comments: | ||
create_review_comment( | ||
pr_details.owner, pr_details.repo, pr_details.pull_number, comments | ||
) | ||
else: | ||
print("Unsupported event:", os.environ.get("GITHUB_EVENT_NAME")) | ||
return | ||
|
||
comments = analyze_code(diff, pr_details) | ||
if comments: | ||
create_review_comment(pr_details, comments) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
try: | ||
main() | ||
except Exception as error: | ||
print("Error:", error) |