Skip to content

Commit

Permalink
adds LLM statistics to statistics object and
Browse files Browse the repository at this point in the history
Prospector reports

Captures execution time at the level above the
LLMService function, so that even if LLM function
doesn't get executed anymore (because the
information is found in db), the time of the db
retrieval is still measured.
  • Loading branch information
lauraschauer authored and copernico committed Aug 16, 2024
1 parent 058cd25 commit 51bf343
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 71 deletions.
40 changes: 21 additions & 19 deletions prospector/core/prospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@


# Named sub-collections hanging off the global execution statistics:
# one for the core pipeline, one for LLM-related measurements.
core_statistics = execution_statistics.sub_collection("core")
llm_statistics = execution_statistics.sub_collection("LLM")


# @profile
Expand Down Expand Up @@ -91,25 +92,26 @@ def prospector( # noqa: C901
return None, -1

if use_llm_repository_url:
with ConsoleWriter("LLM Usage (Repo URL)") as console:
try:
repository_url = LLMService().get_repository_url(
advisory_record.description, advisory_record.references
)
console.print(
f"\n Repository URL: {repository_url}",
status=MessageStatus.OK,
)
except Exception as e:
logger.error(
e,
exc_info=get_level() < logging.INFO,
)
console.print(
e,
status=MessageStatus.ERROR,
)
sys.exit(1)
with ExecutionTimer(llm_statistics.sub_collection("repository_url")):
with ConsoleWriter("LLM Usage (Repo URL)") as console:
try:
repository_url = LLMService().get_repository_url(
advisory_record.description, advisory_record.references
)
console.print(
f"\n Repository URL: {repository_url}",
status=MessageStatus.OK,
)
except Exception as e:
logger.error(
e,
exc_info=get_level() < logging.INFO,
)
console.print(
e,
status=MessageStatus.ERROR,
)
sys.exit(1)

fixing_commit = advisory_record.get_fixing_commit()
# print(advisory_record.references)
Expand Down
6 changes: 2 additions & 4 deletions prospector/core/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ def json_(
data = {
"parameters": params,
"advisory_record": advisory_record.__dict__,
"commits": [
r.as_dict(no_hash=True, no_rules=False, no_diff=no_diff)
for r in results
],
"commits": [r.as_dict(no_hash=True, no_rules=False) for r in results],
"processing_statistics": execution_statistics,
}
logger.info(f"Writing results to {fn}")
file = Path(fn)
Expand Down
10 changes: 8 additions & 2 deletions prospector/llm/llm_service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import re

import validators
Expand Down Expand Up @@ -39,7 +40,9 @@ def __init__(self, config: LLMServiceConfig = None):
except Exception:
raise

def get_repository_url(self, advisory_description, advisory_references) -> str:
def get_repository_url(
self, advisory_description, advisory_references
) -> str:
"""Ask an LLM to obtain the repository URL given the advisory description and references.
Args:
Expand All @@ -52,6 +55,7 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
Raises:
ValueError if advisory information cannot be obtained or there is an error in the model invocation.
"""

try:
chain = prompt_best_guess | self.model | StrOutputParser()

Expand Down Expand Up @@ -125,4 +129,6 @@ def classify_commit(
]:
return False
else:
raise RuntimeError(f"The model returned an invalid response: {is_relevant}")
raise RuntimeError(
f"The model returned an invalid response: {is_relevant}"
)
103 changes: 61 additions & 42 deletions prospector/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from datamodel.commit import Commit, apply_ranking
from llm.llm_service import LLMService
from rules.helpers import extract_security_keywords
from stats.execution import Counter, execution_statistics
from stats.execution import (
Counter,
ExecutionTimer,
execution_statistics,
measure_execution_time,
)
from util.lsh import build_lsh_index, decode_minhash

NUM_COMMITS_PHASE_2 = (
Expand All @@ -17,6 +22,7 @@


# Statistics sub-collections for this module: rule-application counters
# and LLM timing data (shares the "LLM" sub-collection name used elsewhere).
rule_statistics = execution_statistics.sub_collection("rules")
llm_statistics = execution_statistics.sub_collection("LLM")


class Rule:
Expand Down Expand Up @@ -57,8 +63,12 @@ def apply_rules(
) -> List[Commit]:
"""Applies the selected set of rules and returns the ranked list of commits."""

phase_1_rules = [rule for rule in RULES_PHASE_1 if rule.get_id() in enabled_rules]
phase_2_rules = [rule for rule in RULES_PHASE_2 if rule.get_id() in enabled_rules]
phase_1_rules = [
rule for rule in RULES_PHASE_1 if rule.get_id() in enabled_rules
]
phase_2_rules = [
rule for rule in RULES_PHASE_2 if rule.get_id() in enabled_rules
]

if phase_2_rules:
Rule.llm_service = LLMService()
Expand All @@ -69,7 +79,9 @@ def apply_rules(

Rule.lsh_index = build_lsh_index()
for candidate in candidates:
Rule.lsh_index.insert(candidate.commit_id, decode_minhash(candidate.minhash))
Rule.lsh_index.insert(
candidate.commit_id, decode_minhash(candidate.minhash)
)

with Counter(rule_statistics) as counter:
counter.initialize("matches", unit="matches")
Expand Down Expand Up @@ -157,9 +169,7 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
]
)
if len(relevant_files) > 0:
self.message = (
f"The commit changes some relevant files: {', '.join(relevant_files)}"
)
self.message = f"The commit changes some relevant files: {', '.join(relevant_files)}"
return True
return False

Expand Down Expand Up @@ -266,9 +276,7 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
for ref in advisory_record.references:
for twin in candidate.twins:
if twin[1][:8] in ref:
self.message = (
"A twin of this commit is mentioned in the advisory page"
)
self.message = "A twin of this commit is mentioned in the advisory page"
return True
return False

Expand All @@ -283,7 +291,9 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
advisory_record.cve_id in content
and len(re.findall(r"CVE-\d{4}-\d{4,8}", content)) == 1
):
self.message = f"Issue {id} linked to the commit mentions the Vuln ID. "
self.message = (
f"Issue {id} linked to the commit mentions the Vuln ID. "
)
return True

for id, content in candidate.jira_refs.items():
Expand Down Expand Up @@ -366,7 +376,9 @@ class CommitMentionedInReference(Rule):
def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
for ref, n in advisory_record.references.items():
if candidate.commit_id[:8] in ref:
self.message = f"This commit is mentioned {n} times in the references."
self.message = (
f"This commit is mentioned {n} times in the references."
)
return True
return False

Expand All @@ -379,7 +391,9 @@ def apply(self, candidate: Commit, _: AdvisoryRecord) -> bool:
twin_list = Rule.lsh_index.query(decode_minhash(candidate.minhash))
# twin_list.remove(candidate.commit_id)
candidate.twins = [
["no-tag", twin] for twin in twin_list if twin != candidate.commit_id
["no-tag", twin]
for twin in twin_list
if twin != candidate.commit_id
]
# self.lsh_index.insert(candidate.commit_id, decode_minhash(candidate.minhash))
if len(candidate.twins) > 0:
Expand Down Expand Up @@ -423,35 +437,40 @@ def apply(
backend_address: str,
) -> bool:

# Check if this commit is already in the database
try:
r = requests.get(
f"{backend_address}/commits/{candidate.repository}",
params={"commit_id": candidate.commit_id},
timeout=10
)
r.raise_for_status()
commit_data = r.json()[0]

is_security_relevant = commit_data.get('security_relevant')
if is_security_relevant is not None:
candidate.security_relevant = is_security_relevant
return is_security_relevant

candidate.security_relevant = LLMService().classify_commit(
candidate.diff, candidate.repository, candidate.message
)

update_response = requests.post(
backend_address + "/commits/",
json=[candidate.to_dict()],
headers={"content-type": "application/json"},
)
update_response.raise_for_status()

except requests.exceptions.RequestException as e:
error_type = type(e).__name__
print(f"Error communicating with backend: {error_type} - {str(e)}")
with ExecutionTimer(
llm_statistics.sub_collection("commit_classification")
):
# Check if this commit is already in the database
try:
r = requests.get(
f"{backend_address}/commits/{candidate.repository}",
params={"commit_id": candidate.commit_id},
timeout=10,
)
r.raise_for_status()
commit_data = r.json()[0]

is_security_relevant = commit_data.get("security_relevant")
if is_security_relevant is not None:
candidate.security_relevant = is_security_relevant
return is_security_relevant

candidate.security_relevant = LLMService().classify_commit(
candidate.diff, candidate.repository, candidate.message
)

update_response = requests.post(
backend_address + "/commits/",
json=[candidate.to_dict()],
headers={"content-type": "application/json"},
)
update_response.raise_for_status()

except requests.exceptions.RequestException as e:
error_type = type(e).__name__
print(
f"Error communicating with backend: {error_type} - {str(e)}"
)


RULES_PHASE_1: List[Rule] = [
Expand Down
20 changes: 18 additions & 2 deletions prospector/stats/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def record(
unit: Optional[str] = None,
overwrite=False,
):
"""Adds a new statistic to the collection."""
if isinstance(name, str):
if not overwrite and name in self:
raise ForbiddenDuplication(f"{name} already added")
Expand Down Expand Up @@ -112,6 +113,9 @@ def sub_collection(
self,
name: Optional[Union[str, Tuple[str, ...]]] = None,
) -> StatisticCollection:
"""Creates a nested `StatisticCollection` as the value of `name`. Returns
an existing collection if there already exists one under `name`.
"""
if name is None:
name = caller_name()

Expand Down Expand Up @@ -164,8 +168,12 @@ def __contains__(self, key):
raise KeyError("only string or tuple keys allowed")

def collect(
self, name: Union[str, Tuple[str, ...]], value, unit: Optional[str] = None
self,
name: Union[str, Tuple[str, ...]],
value,
unit: Optional[str] = None,
):
"""Adds a value to the list at key `name`."""
if name not in self:
self.record(name, [], unit=unit)

Expand All @@ -182,7 +190,10 @@ def collect(
raise KeyError(f"can not collect into {name}, because it is not a list")

def collect_unique(
self, name: Union[str, Tuple[str, ...]], value, ensure_uniqueness: bool = False
self,
name: Union[str, Tuple[str, ...]],
value,
ensure_uniqueness: bool = False,
):
if name not in self:
self.record(name, set())
Expand All @@ -205,6 +216,7 @@ def get_descants(self, leaf_only=False, ascents=()):
yield ascents + (child_key,), child, unit

def generate_console_tree(self) -> str:
"""Generate a visual representation of the collection."""
descants = sorted(
list(self.get_descants()), key=lambda e: LEVEL_DELIMITER.join(e[0])
)
Expand Down Expand Up @@ -270,3 +282,7 @@ def as_html_ul(self) -> str:
ul += "</li>"
ul += "</ul>"
return ul

def as_json(self) -> dict:
    """Return the collection as a plain, JSON-serialisable nested dict.

    Nested sub-collections (anything exposing its own ``as_json``) are
    converted recursively; leaf values are included unchanged.

    NOTE(review): replaces a debug stub that only printed each item and
    implicitly returned ``None`` despite the ``-> dict`` annotation.
    # assumes self.items() yields (key, value) pairs — TODO confirm the
    # internal storage format (e.g. whether units are stored alongside values)
    """
    result: dict = {}
    for key, child in self.items():
        # Recurse into nested collections; pass leaf values through as-is.
        result[key] = child.as_json() if hasattr(child, "as_json") else child
    return result
19 changes: 17 additions & 2 deletions prospector/stats/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from stats.collection import StatisticCollection, SubCollectionWrapper

# Global root collection; all modules record their execution statistics
# into named sub-collections of this single object.
execution_statistics = StatisticCollection()


Expand All @@ -18,6 +19,8 @@ class TimerError(Exception):


class Timer:
"""A simple timer to measure elapsed time."""

def __init__(self):
    # No start timestamp yet: None means the timer is not running.
    self._start_time = None

Expand All @@ -36,8 +39,13 @@ def stop(self):


def measure_execution_time(
collection: StatisticCollection, name: Optional[Union[str, Tuple[str, ...]]] = None
collection: StatisticCollection,
name: Optional[Union[str, Tuple[str, ...]]] = None,
):
"""A function decorator that measures and records the execution time of the
decorated function.
"""

def _measure(function):
nonlocal name
if name is None:
Expand All @@ -56,6 +64,8 @@ def _wrapper(*args, **kwargs):


class ExecutionTimer(SubCollectionWrapper):
"""Allows measuring time within the context of a `StatisticCollection`."""

def __init__(self, collection, name: Optional[Union[str, Tuple[str, ...]]] = None):
super().__init__(collection)
self.timer = Timer()
Expand All @@ -79,6 +89,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class Counter(SubCollectionWrapper):
"""Allows incrementing counts within the context of a `StatisticCollection`."""

def __enter__(self) -> Counter:
    # Context-manager entry needs no setup; expose the counter itself.
    return self

Expand All @@ -98,7 +110,10 @@ def increment(self, name: Union[str, Tuple[str, ...]], by: Union[int, float] = 1
ValueError(f"can not increment {name}")

def initialize(
self, *keys: Union[str, Tuple[str, ...]], value=0, unit: Optional[str] = None
self,
*keys: Union[str, Tuple[str, ...]],
value=0,
unit: Optional[str] = None,
):
for key in keys:
self.collection.collect(key, value, unit=unit)

0 comments on commit 51bf343

Please sign in to comment.