github2 agents
xvnpw committed Dec 27, 2024
1 parent 97d33b7 commit a422afa
Showing 8 changed files with 933 additions and 668 deletions.
4 changes: 2 additions & 2 deletions ai_security_analyzer/agent_builder.py
@@ -13,7 +13,7 @@
from ai_security_analyzer.llms import LLMProvider
from ai_security_analyzer.markdowns import MarkdownMermaidValidator
from ai_security_analyzer.prompts import DOC_TYPE_PROMPTS, get_agent_prompt
-from ai_security_analyzer.github_agents import GithubAgent
+from ai_security_analyzer.github2_agents import GithubAgent2
from ai_security_analyzer.file_agents import FileAgent
from ai_security_analyzer.base_agent import AgentType

@@ -28,7 +28,7 @@ def __init__(self, llm_provider: LLMProvider, config: AppConfig) -> None:
        self._agents: dict[AgentType, Type[BaseAgent]] = {
            AgentType.DIR: FullDirScanAgent,
            AgentType.DRY_RUN_DIR: DryRunFullDirScanAgent,
-           AgentType.GITHUB: GithubAgent,
+           AgentType.GITHUB: GithubAgent2,
            AgentType.FILE: FileAgent,
        }
        agent_type = AgentType(f"dry-run-{config.mode}") if config.dry_run else AgentType(config.mode)
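For context, a minimal sketch of what this lookup now resolves to for GitHub mode; the registry literal below merely mirrors the mapping above and is not code from this commit:

# Illustrative only: "github" mode (without dry-run) should now resolve to GithubAgent2.
from ai_security_analyzer.base_agent import AgentType
from ai_security_analyzer.github2_agents import GithubAgent2

registry = {AgentType.GITHUB: GithubAgent2}  # mirrors self._agents above
assert registry[AgentType("github")] is GithubAgent2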
135 changes: 135 additions & 0 deletions ai_security_analyzer/github2_agents.py
@@ -0,0 +1,135 @@
import logging
from dataclasses import dataclass
from typing import Any, List, Literal, Callable

from langchain_core.messages import HumanMessage
from langchain_text_splitters import CharacterTextSplitter
from langgraph.graph import START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from tiktoken import Encoding

from ai_security_analyzer.base_agent import BaseAgent
from ai_security_analyzer.documents import DocumentFilter, DocumentProcessor
from ai_security_analyzer.llms import LLMProvider
from ai_security_analyzer.markdowns import MarkdownMermaidValidator
from ai_security_analyzer.utils import get_response_content, get_total_tokens
from langgraph.graph import MessagesState


logger = logging.getLogger(__name__)


@dataclass
class AgentState(MessagesState):
    target_repo: str
    sec_repo_doc: str
    document_tokens: int
    step0: str
    step1: str
    step2: str
    step3: str
    step_index: int
    step_count: int
    step_prompts: List[Callable[[str], str]]


class GithubAgent2(BaseAgent):
    def __init__(
        self,
        llm_provider: LLMProvider,
        text_splitter: CharacterTextSplitter,
        tokenizer: Encoding,
        max_editor_turns_count: int,
        markdown_validator: MarkdownMermaidValidator,
        doc_processor: DocumentProcessor,
        doc_filter: DocumentFilter,
        agent_prompt: str,
        doc_type_prompt: str,
    ):
        super().__init__(
            llm_provider,
            text_splitter,
            tokenizer,
            max_editor_turns_count,
            markdown_validator,
            doc_processor,
            doc_filter,
            agent_prompt,
            doc_type_prompt,
        )

    def _internal_step(self, state: AgentState, llm: Any, use_system_message: bool): # type: ignore[no-untyped-def]
        logger.info(f"Internal step {state.get('step_index', 0)+1} of {state['step_count']}")
        try:
            target_repo = state["target_repo"]
            step_index = state.get("step_index", 0)
            step_prompts = state["step_prompts"]

            step_prompt = step_prompts[step_index](target_repo)

            step_msg = HumanMessage(content=step_prompt)

            response = llm.invoke(state["messages"] + [step_msg])
            document_tokens = get_total_tokens(response)
            return {
                "messages": state["messages"] + [step_msg, response],
                "document_tokens": state.get("document_tokens", 0) + document_tokens,
                "step_index": step_index + 1,
                f"step{step_index}": get_response_content(response),
            }
        except Exception as e:
            logger.error(f"Error on internal step {state['step_index']} of {state['step_count']}: {e}")
            raise ValueError(str(e))

    def _internal_step_condition(self, state: AgentState) -> Literal["internal_step", "final_response"]:
        current_step_index = state["step_index"]
        step_count = state["step_count"]

        if current_step_index < step_count:
            return "internal_step"
        else:
            return "final_response"

    def _final_response(self, state: AgentState): # type: ignore[no-untyped-def]
        logger.info("Getting final response")
        try:
            messages = state["messages"]
            last_message = messages[-1]
            final_response = get_response_content(last_message)

            if final_response.startswith("```markdown"):
                final_response = final_response.replace("```markdown", "")

            if final_response.endswith("```"):
                final_response = final_response[:-3]

            return {
                "sec_repo_doc": final_response,
            }
        except Exception as e:
            logger.error(f"Error on getting final response: {e}")
            raise ValueError(str(e))

    def build_graph(self) -> CompiledStateGraph:
        logger.debug(f"[{GithubAgent2.__name__}] building graph...")

        llm = self.llm_provider.create_agent_llm()

        def internal_step(state: AgentState): # type: ignore[no-untyped-def]
            return self._internal_step(state, llm.llm, llm.model_config.use_system_message)

        def internal_step_condition(state: AgentState) -> Literal["internal_step", "final_response"]:
            return self._internal_step_condition(state)

        def final_response(state: AgentState): # type: ignore[no-untyped-def]
            return self._final_response(state)

        builder = StateGraph(AgentState)
        builder.add_node("internal_step", internal_step)
        builder.add_node("final_response", final_response)
        builder.add_edge(START, "internal_step")
        builder.add_conditional_edges("internal_step", internal_step_condition)
        builder.add_edge("final_response", "__end__")
        graph = builder.compile()

        return graph
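The compiled graph above keeps calling internal_step until step_index reaches step_count, then routes to final_response, which strips an optional markdown fence from the last reply. Below is a minimal, self-contained sketch of the same loop pattern with the LLM call replaced by a stub, so it needs none of the provider wiring from this commit; the state fields mirror AgentState, and the repo URL and prompts are placeholders.

# Standalone sketch of the step-loop pattern; the "LLM" here just echoes the prompt.
from typing import Callable, List, Literal

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import START, MessagesState, StateGraph


class DemoState(MessagesState):
    target_repo: str
    step_index: int
    step_count: int
    step_prompts: List[Callable[[str], str]]
    sec_repo_doc: str


def internal_step(state: DemoState):
    i = state.get("step_index", 0)
    prompt = state["step_prompts"][i](state["target_repo"])
    # Stub reply; GithubAgent2 calls llm.invoke(state["messages"] + [step_msg]) here instead.
    reply = AIMessage(content=f"draft after step {i}: {prompt}")
    return {
        "messages": [HumanMessage(content=prompt), reply],
        "step_index": i + 1,
    }


def internal_step_condition(state: DemoState) -> Literal["internal_step", "final_response"]:
    return "internal_step" if state["step_index"] < state["step_count"] else "final_response"


def final_response(state: DemoState):
    return {"sec_repo_doc": state["messages"][-1].content}


builder = StateGraph(DemoState)
builder.add_node("internal_step", internal_step)
builder.add_node("final_response", final_response)
builder.add_edge(START, "internal_step")
builder.add_conditional_edges("internal_step", internal_step_condition)
builder.add_edge("final_response", "__end__")
graph = builder.compile()

result = graph.invoke(
    {
        "messages": [],
        "target_repo": "https://github.com/pallets/flask",  # illustrative target
        "step_index": 0,
        "step_count": 2,
        "step_prompts": [lambda r: f"Step prompt for {r}"] * 2,  # placeholder prompts
    }
)
print(result["sec_repo_doc"])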
17 changes: 16 additions & 1 deletion ai_security_analyzer/graphs.py
@@ -6,6 +6,7 @@
from ai_security_analyzer.config import AppConfig
from langchain_core.documents import Document
from ai_security_analyzer.base_agent import AgentType
from ai_security_analyzer.prompts import GITHUB2_CONFIGS

logger = logging.getLogger(__name__)

@@ -77,17 +78,31 @@ class GithubGraphExecutor(FullDirScanGraphExecutor):

    def execute(self, graph: CompiledStateGraph, target: str) -> None:
        try:
            config = GITHUB2_CONFIGS[self.config.agent_prompt_type]

            state = graph.invoke(
                {
                    "target_repo": target,
                    "refinement_count": self.config.refinement_count,
                    "step_count": config["steps"],
                    "step_index": 0,
                    "step_prompts": config["step_prompts"],
                }
            )
            self._write_output(state)
        except Exception as e:
            logger.error(f"Graph execution failed: {e}")
            raise

    def _write_output(self, state: dict[str, Any] | Any) -> None:
        actual_token_count = state.get("document_tokens", 0)
        logger.info(f"Actual token usage: {actual_token_count}")
        output_content = state.get("sec_repo_doc", "")

        if self.config.agent_preamble_enabled:
            output_content = f"{self.config.agent_preamble}\n\n{output_content}"

        self.config.output_file.write(output_content)


class FileGraphExecutor(FullDirScanGraphExecutor):

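For illustration, the executor's flow around the graph call can be reproduced outside the class roughly as follows; the graph invocation is elided, and the result state, preamble, and repo URL are made-up stand-ins rather than values from this commit.

# Sketch: GITHUB2_CONFIGS supplies steps/step_prompts keyed by agent_prompt_type,
# and _write_output prepends an optional preamble to sec_repo_doc.
from ai_security_analyzer.prompts import GITHUB2_CONFIGS

cfg = GITHUB2_CONFIGS["threat-modeling"]
initial_state = {
    "target_repo": "https://github.com/pallets/flask",  # illustrative target
    "step_count": cfg["steps"],
    "step_index": 0,
    "step_prompts": cfg["step_prompts"],
}
# state = graph.invoke(initial_state)  # would run the GithubAgent2 step loop
state = {"sec_repo_doc": "# Threat Model\n...", "document_tokens": 1234}  # faked result
preamble = "AI-generated content follows."  # stand-in for config.agent_preamble
output_content = f"{preamble}\n\n{state.get('sec_repo_doc', '')}"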
104 changes: 88 additions & 16 deletions ai_security_analyzer/prompts.py
@@ -13,7 +13,7 @@ def get_agent_prompt(prompt_type: str, mode: str) -> str:
    # Map modes to their base templates
    mode_templates = {
        "dir": (DIR_1, DIR_2, DIR_3, DIR_STEPS_2),
-       "github": (GITHUB_1, GITHUB_2, GITHUB_3, GITHUB_STEPS_2),
+       "github": ("GITHUB_1", "GITHUB_2", "GITHUB_3", "GITHUB_STEPS_2"),
        "file": (FILE_1, FILE_2, FILE_3, FILE_STEPS_2),
}

@@ -56,21 +56,6 @@ def get_agent_prompt(prompt_type: str, mode: str) -> str:
- Thoroughly review all provided files to identify components, configurations, and code relevant to the attack surface.
"""

GITHUB_1 = "GITHUB REPOSITORY"
GITHUB_2 = "- If CURRENT {} is not empty - it means that draft of this document was created in previous interactions with LLM. In such case update CURRENT {} with new information that you get from your knowledge base. In case CURRENT {} is empty it means that you get first iteration."
GITHUB_3 = "- CURRENT {} - document that was created in previous interactions with LLM based on knowledge base so far"

GITHUB_STEPS_2 = """1. Update the CURRENT {} (if applicable):
- When the `CURRENT {}` is not empty, it indicates that a draft of this document was created in previous interactions with LLM. In this case, integrate new findings from the latest knowledge base into the existing `CURRENT {}`. Ensure consistency and avoid duplication.
- If the `CURRENT {}` is empty, proceed to create a new threat model based on your knowledge base.
2. Analyze the Project Files:
- Thoroughly review all project files you have in your knowledge base to identify components, configurations, and code relevant to the attack surface.
"""

FILE_1 = "FILE"
FILE_2 = "- If CURRENT {} is not empty - it means that draft of this document was created in previous interactions with LLM using FILE content. In such case update CURRENT {} with new information that you get from FILE. In case CURRENT {} is empty it means that you first time get FILE content"
FILE_3 = """- You will get FILE content
@@ -768,3 +753,90 @@ def get_agent_prompt(prompt_type: str, mode: str) -> str:
"threat-scenarios": "THREAT MODEL",
"attack-tree": "ATTACK TREE",
}

GITHUB2_THREAT_MODELING_PROMPTS = [
    "You are a cybersecurity expert working with a development team. Your task is to create a threat model for an application that uses {}. Focus on threats introduced by {} and omit general, common web application threats.",
    "Create a threat list with: threat, description (describe what the attacker might do and how), impact (describe the impact of the threat), which {} component is affected (describe what component is affected, e.g. module, function, etc.), risk severity (critical, high, medium or low), and mitigation strategies (describe how developers or users can reduce the risk). Use valid markdown formatting, especially for tables.",
    "Update the threat list and return only threats that directly involve {}. Return high and critical threats only. Use valid markdown formatting, especially for tables.",
]

GITHUB2_ATTACK_TREE_PROMPTS = [
    """You are a cybersecurity expert working with a development team. Your task is to create a detailed threat model using attack tree analysis for an application that uses {}. Focus on threats introduced by {} and omit general, common web application threats. Identify how an attacker might compromise the application using {} by exploiting its weaknesses. Your analysis should follow the attack tree methodology and provide actionable insights, including a visualization of the attack tree in a text-based format.
Objective:
Attacker's Goal: To compromise an application that uses the given project by exploiting weaknesses or vulnerabilities within the project itself.
(Note: If you find a more precise or impactful goal during your analysis, feel free to refine it.)""",
    """For each attack step, estimate:
- Likelihood: How probable is it that the attack could occur?
- Impact: What would be the potential damage if the attack is successful?
- Effort: What resources or time would the attacker need?
- Skill Level: What level of expertise is required?
- Detection Difficulty: How easy would it be to detect the attack?""",
    "Update the attack tree and mark High-Risk Paths and Critical Nodes.",
    "Update the attack tree and return the sub-tree with only High-Risk Paths and Critical Nodes. Return the title, goal, sub-tree, and a detailed breakdown of attack vectors for High-Risk Paths and Critical Nodes.",
]

GITHUB2_SEC_DESIGN_PROMPTS = [
    "You are an expert in software, cloud and cybersecurity architecture. You specialize in creating clear, well-written design documents for systems, projects and components. Provide a well-written, detailed project design document that will be used later for threat modelling for project: {}",
    "Improve it. Return the improved version.",
]

GITHUB2_ATTACK_SURFACE_PROMPTS = [
    "You are a cybersecurity expert working with a development team. Your task is to create an attack surface analysis for an application that uses {}. Focus on the attack surface introduced by {} and omit the general, common attack surface.",
    "Create a key attack surface list with: description, how {} contributes to the attack surface, example, impact, risk severity (critical, high, medium or low), and mitigation strategies (describe how developers or users can reduce the risk). Use valid markdown formatting, especially for tables.",
    "Update the key attack surface list and return only elements that directly involve {}. Return high and critical elements only. Use valid markdown formatting, especially for tables.",
]

GITHUB2_PROMPTS: Dict[str, str] = {
    "sec-design": "DESIGN DOCUMENT",
    "threat-modeling": "THREAT MODEL",
    "attack-surface": "THREAT MODEL",
    "threat-scenarios": "THREAT MODEL",
    "attack-tree": "ATTACK TREE",
}

GITHUB2_THREAT_MODELING_CONFIG = {
    "steps": 3,
    "step_prompts": [
        lambda target_repo: GITHUB2_THREAT_MODELING_PROMPTS[0].format(target_repo, target_repo.split("/")[-1]),
        lambda target_repo: GITHUB2_THREAT_MODELING_PROMPTS[1].format(target_repo, target_repo.split("/")[-1]),
        lambda target_repo: GITHUB2_THREAT_MODELING_PROMPTS[2].format(target_repo, target_repo.split("/")[-1]),
    ],
}

GITHUB2_ATTACK_TREE_CONFIG = {
    "steps": 4,
    "step_prompts": [
        lambda target_repo: GITHUB2_ATTACK_TREE_PROMPTS[0].format(
            target_repo, target_repo.split("/")[-1], target_repo.split("/")[-1]
        ),
        lambda target_repo: GITHUB2_ATTACK_TREE_PROMPTS[1].format(target_repo),
        lambda target_repo: GITHUB2_ATTACK_TREE_PROMPTS[2].format(target_repo),
        lambda target_repo: GITHUB2_ATTACK_TREE_PROMPTS[3].format(target_repo),
    ],
}

GITHUB2_SEC_DESIGN_CONFIG = {
    "steps": 2,
    "step_prompts": [
        lambda target_repo: GITHUB2_SEC_DESIGN_PROMPTS[0].format(target_repo),
        lambda target_repo: GITHUB2_SEC_DESIGN_PROMPTS[1].format(target_repo),
    ],
}

GITHUB2_ATTACK_SURFACE_CONFIG = {
    "steps": 3,
    "step_prompts": [
        lambda target_repo: GITHUB2_ATTACK_SURFACE_PROMPTS[0].format(target_repo, target_repo.split("/")[-1]),
        lambda target_repo: GITHUB2_ATTACK_SURFACE_PROMPTS[1].format(target_repo.split("/")[-1]),
        lambda target_repo: GITHUB2_ATTACK_SURFACE_PROMPTS[2].format(target_repo.split("/")[-1]),
    ],
}

GITHUB2_CONFIGS = {
    "threat-modeling": GITHUB2_THREAT_MODELING_CONFIG,
    "attack-tree": GITHUB2_ATTACK_TREE_CONFIG,
    "sec-design": GITHUB2_SEC_DESIGN_CONFIG,
    "attack-surface": GITHUB2_ATTACK_SURFACE_CONFIG,
}
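As a quick check of how these lambdas render, the snippet below formats the first threat-modeling step prompt for an example repository URL; the URL is illustrative, and any repo would work the same way since target_repo.split("/")[-1] yields the project name.

# Illustration only: render the first threat-modeling step prompt.
from ai_security_analyzer.prompts import GITHUB2_THREAT_MODELING_CONFIG

target_repo = "https://github.com/pallets/flask"  # example URL, not from this commit
first_prompt = GITHUB2_THREAT_MODELING_CONFIG["step_prompts"][0](target_repo)
# The full URL fills the first "{}" and "flask" (the last path segment) fills the second:
# "...a threat model for an application that uses https://github.com/pallets/flask.
#  Focus on threats introduced by flask..."
print(first_prompt)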
11 changes: 10 additions & 1 deletion ai_security_analyzer/utils.py
@@ -2,7 +2,16 @@
import platform
from shutil import which
from typing import Union
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, AIMessage


def convert_to_ai_message(message: BaseMessage) -> AIMessage:
    content = message.content

    if isinstance(content, list):
        content = str(content[-1])

    return AIMessage(content=content)
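A small usage illustration of the new helper; the multi-part message content below is a made-up example, not data from this commit.

# Example only: collapse list-style message content into a plain AIMessage.
from langchain_core.messages import AIMessage
from ai_security_analyzer.utils import convert_to_ai_message

raw = AIMessage(content=["first text block", "second text block"])  # list-style content
msg = convert_to_ai_message(raw)
assert msg.content == "second text block"  # str() of the last list element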


def get_total_tokens(message: BaseMessage) -> int: