Skip to content

Commit

Permalink
feat: store different versions of the same file
Browse files Browse the repository at this point in the history
fixes #226
  • Loading branch information
kantord committed Dec 8, 2023
1 parent d29f989 commit 49bba31
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 98 deletions.
14 changes: 7 additions & 7 deletions seagoat/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing_extensions import TypedDict

from seagoat.cache import Cache
from seagoat.file import File
from seagoat.gitfile import GitFile
from seagoat.repository import Repository
from seagoat.result import get_best_score
from seagoat.sources import chroma, ripgrep
Expand All @@ -24,7 +24,7 @@ class RepositoryData(TypedDict):
last_analyzed_version_of_branch: Dict[str, str]
required_commits: Set[str]
commits_already_analyzed: Set[str]
file_data: Dict[str, File]
file_data: Dict[str, GitFile]
sorted_files: List[str]
chunks_already_analyzed: Set[str]
chunks_not_yet_analyzed: Set[str]
Expand Down Expand Up @@ -188,14 +188,14 @@ def _format_results(self, query: str, hard_count_limit: int = 1000):
merged_results = {}

for result_item in self._results:
if self._is_file_ignored(result_item.path):
if self._is_file_ignored(result_item.gitfile.path):
continue

if result_item.path not in merged_results:
merged_results[result_item.path] = result_item
if result_item.gitfile.path not in merged_results:
merged_results[result_item.gitfile.path] = result_item
continue

merged_results[result_item.path].extend(result_item)
merged_results[result_item.gitfile.path].extend(result_item)

results_to_sort = list(merged_results.values())

Expand Down Expand Up @@ -225,7 +225,7 @@ def get_file_position(path: str):
results_to_sort,
key=lambda x: (
0.7 * normalize_score(get_best_score(x))
+ 0.3 * normalize_file_position(get_file_position(x.path))
+ 0.3 * normalize_file_position(get_file_position(x.gitfile.path))
),
)
)[:hard_count_limit]
28 changes: 21 additions & 7 deletions seagoat/file.py → seagoat/gitfile.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
import functools
import hashlib
from typing import Dict, List, Literal

from seagoat.utils.file_reader import read_file_with_correct_encoding


class File:
class GitFile:
"""
Represents a specific version of a file in a Git repository.
The object_id is the Git object id of the file, which is basically
its SHA1 hash.
"""

def __init__(
self,
path: str,
absolute_path: str,
object_id: str,
score: float,
commit_messages: list[str],
):
self.path = path
self.absolute_path = absolute_path
self.object_id = object_id
self.commit_hashes = set()
self.score = score
self.commit_messages = commit_messages
Expand All @@ -31,7 +41,9 @@ def get_metadata(self):
Commits:
{commit_messages}"""

def _get_file_lines(self) -> Dict[int, str]:
@property
@functools.lru_cache(5000)
def lines(self) -> Dict[int, str]:
lines = {
(i + 1): line
for i, line in enumerate(
Expand Down Expand Up @@ -90,24 +102,26 @@ def _line_has_relevant_data(self, line: str):
return False

def get_chunks(self):
lines = self._get_file_lines()
# TODO: should be turned into a class method on FileChunk
return [
self._get_chunk_for_line(line_number, lines)
for line_number in lines.keys()
if self._line_has_relevant_data(lines[line_number])
self._get_chunk_for_line(line_number, self.lines)
for line_number in self.lines.keys()
if self._line_has_relevant_data(self.lines[line_number])
]


class FileChunk:
def __init__(self, parent: File, codeline: int, chunk: str):
def __init__(self, parent: GitFile, codeline: int, chunk: str):
self.path = parent.path
self.object_id = parent.object_id
self.codeline = codeline
self.chunk = chunk
self.chunk_id = self._get_id()

def _get_id(self):
text = f"""
Path: {self.path}
Object ID: {self.object_id}
Code line: {self.codeline}
Chunk: {self.chunk}
"""
Expand Down
32 changes: 30 additions & 2 deletions seagoat/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from pathlib import Path

from seagoat.file import File
from seagoat.gitfile import GitFile
from seagoat.utils.file_types import is_file_type_supported


Expand Down Expand Up @@ -35,6 +35,29 @@ def _get_working_tree_diff(self):
["git", "-C", str(self.path), "diff"], text=True
).strip()

def get_file_object_id(self, file_path: str) -> str:
    """
    Returns the git object id for the current version
    of a file

    The id is looked up in the tree of HEAD, so it describes the last
    committed version of the file rather than any uncommitted changes.

    Args:
        file_path: Path of the file, relative to the repository root
            (git is run with ``-C <repository path>``).

    Returns:
        The stripped output of ``git ls-tree --object-only``. This is
        presumably an empty string when the path is not present in HEAD,
        since git prints nothing for paths that do not match.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    command = [
        "git",
        "-C",
        str(self.path),
        "ls-tree",
        # Options go before the tree-ish, as in the documented synopsis.
        "--object-only",
        "HEAD",
        # "--" ends option parsing, so a file whose name starts with "-"
        # is always treated as a path and never as a git option.
        "--",
        str(file_path),
    ]

    return subprocess.check_output(command, text=True).strip()

def is_up_to_date_git_object(self, file_path: str, git_object_id: str):
    """
    Tells whether git_object_id is the object id that
    get_file_object_id currently reports for file_path
    """
    current_object_id = self.get_file_object_id(file_path)
    return current_object_id == git_object_id

def get_status_hash(self):
    """
    Returns the SHA-256 hex digest of the HEAD hash followed by
    the working tree diff
    """
    head_hash = self._get_head_hash()
    working_tree_diff = self._get_working_tree_diff()
    payload = (head_hash + working_tree_diff).encode()
    return hashlib.sha256(payload).hexdigest()
Expand Down Expand Up @@ -90,9 +113,14 @@ def top_files(self):
]

def get_file(self, filename: str):
return File(
"""
Returns a GitFile object with the current version of the file
"""

return GitFile(
filename,
str(self.path / filename),
self.get_file_object_id(filename),
self.frecency_scores[filename],
[commit[3] for commit in self.file_changes[filename]],
)
25 changes: 11 additions & 14 deletions seagoat/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Set
from seagoat.gitfile import GitFile

from seagoat.utils.file_reader import read_file_with_correct_encoding
from seagoat.utils.file_types import get_file_penalty_factor

SPLITTER_PATTERN = re.compile(r"\s+")
Expand Down Expand Up @@ -35,7 +34,7 @@ def get_best_score(result) -> float:
key=lambda item: item.get_score(),
).get_score()

best_score *= get_file_penalty_factor(result.full_path)
best_score *= get_file_penalty_factor(result.gitfile.absolute_path)

return best_score

Expand Down Expand Up @@ -103,15 +102,13 @@ def to_json(self):


class Result:
def __init__(self, query_text: str, path: str, full_path: Path) -> None:
self.path: str = path
def __init__(self, query_text: str, gitfile: GitFile) -> None:
self.gitfile: GitFile = gitfile
self.query_text: str = query_text
self.full_path: Path = full_path
self.lines: Dict[int, ResultLine] = {}
self.line_texts = read_file_with_correct_encoding(self.full_path).splitlines()

def __repr__(self) -> str:
return f"Result(path={self.path})"
return f"Result(path={self.gitfile.path})"

def extend(self, other) -> None:
self.lines.update(other.lines)
Expand All @@ -130,7 +127,7 @@ def add_line(self, line: int, vector_distance: float) -> None:
self,
line,
vector_distance,
self.line_texts[line - 1],
self.gitfile.lines[line - 1],
types,
)

Expand Down Expand Up @@ -167,7 +164,7 @@ def _merge_almost_touching_blocks(self, blocks):
self,
line,
0.0,
self.line_texts[line - 1],
self.gitfile.lines[line - 1],
{ResultLineType.BRIDGE},
)
last_block.lines.append(bridge_line)
Expand Down Expand Up @@ -202,8 +199,8 @@ def get_result_blocks(self):

def to_json(self):
return {
"path": self.path,
"fullPath": str(self.full_path),
"path": self.gitfile.path,
"fullPath": str(self.gitfile.absolute_path),
"score": round(get_best_score(self), 4),
"blocks": [block.to_json() for block in self.get_result_blocks()],
}
Expand All @@ -217,15 +214,15 @@ def add_context_lines(self, lines: int):
for offset in range(abs(lines)):
new_line = result_line + (offset + 1) * direction

if (new_line) not in range(len(self.line_texts)):
if (new_line) not in range(len(self.gitfile.lines)):
continue

if new_line not in self.lines:
self.lines[new_line] = ResultLine(
self,
line=new_line,
vector_distance=0.0,
line_text=self.line_texts[new_line - 1],
line_text=self.gitfile.lines[new_line - 1],
types={ResultLineType.CONTEXT},
)

Expand Down
16 changes: 14 additions & 2 deletions seagoat/sources/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from chromadb.utils import embedding_functions

from seagoat.cache import Cache
from seagoat.gitfile import GitFile
from seagoat.repository import Repository
from seagoat.result import Result
from seagoat.utils.config import get_config_values
Expand Down Expand Up @@ -36,13 +37,18 @@ def format_results(query_text: str, repository, chromadb_results):
break
path = str(metadata["path"])
line = int(metadata["line"])
git_object_id = str(metadata["git_object_id"])
full_path = Path(repository.path) / path
gitfile = GitFile(path, str(full_path), git_object_id, 0, [])

if not full_path.exists():
continue

if not repository.is_up_to_date_git_object(path, git_object_id):
continue

if path not in files:
files[path] = Result(query_text, path, full_path)
files[path] = Result(query_text, gitfile)
files[path].add_line(line, distance)

return files.values()
Expand Down Expand Up @@ -86,7 +92,13 @@ def cache_chunk(chunk):
chroma_collection.add(
ids=[chunk.chunk_id],
documents=[chunk.chunk],
metadatas=[{"path": chunk.path, "line": chunk.codeline}],
metadatas=[
{
"path": chunk.path,
"line": chunk.codeline,
"git_object_id": chunk.object_id,
}
],
)
except IDAlreadyExistsError:
pass
Expand Down
Loading

0 comments on commit 49bba31

Please sign in to comment.