Skip to content

Commit

Permalink
Merge pull request #80 from ipa-lab/development_without_spacy
Browse files Browse the repository at this point in the history
Development without spacy
  • Loading branch information
andreashappe authored Aug 6, 2024
2 parents 70a9018 + 88fcf70 commit 033b598
Show file tree
Hide file tree
Showing 15 changed files with 731 additions and 97 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ wintermute.py: error: the following arguments are required: {linux_privesc,windo

# start wintermute, i.e., attack the configured virtual machine
$ python wintermute.py minimal_linux_privesc

# install dependencies for testing if you want to run the tests
$ pip install .[testing]
~~~

## Publications about hackingBuddyGPT
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ pythonpath = "src"
addopts = [
"--import-mode=importlib",
]
[project.optional-dependencies]
testing = [
'pytest',
'pytest-mock'
]

[project.scripts]
wintermute = "hackingBuddyGPT.cli.wintermute:main"
Expand Down
124 changes: 58 additions & 66 deletions src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from instructor.retry import InstructorRetryException


Expand Down Expand Up @@ -37,13 +36,11 @@ def __init__(self, strategy, llm_handler, history, schemas, response_handler):
self.found_endpoints = ["/"]
self.endpoint_methods = {}
self.endpoint_found_methods = {}
model_name = "en_core_web_sm"

# Check if the models are already installed
nltk.download('punkt')
nltk.download('stopwords')
self._prompt_history = history
self.prompt = self._prompt_history
self.prompt = {self.round: {"content": "initial_prompt"}}
self.previous_prompt = self._prompt_history[self.round]["content"]
self.schemas = schemas

Expand Down Expand Up @@ -77,10 +74,6 @@ def generate_prompt(self, doc=False):
self.round = self.round +1
return self._prompt_history





def in_context_learning(self, doc=False, hint=""):
"""
Generates a prompt for in-context learning.
Expand All @@ -91,7 +84,14 @@ def in_context_learning(self, doc=False, hint=""):
Returns:
str: The generated prompt.
"""
return str("\n".join(self._prompt_history[self.round]["content"] + [self.prompt]))
history_content = [entry["content"] for entry in self._prompt_history]
prompt_content = self.prompt.get(self.round, {}).get("content", "")

# Add hint if provided
if hint:
prompt_content += f"\n{hint}"

return "\n".join(history_content + [prompt_content])

def get_http_action_template(self, method):
"""Helper to construct a consistent HTTP action description."""
Expand All @@ -103,13 +103,45 @@ def get_http_action_template(self, method):
else:
return (
f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests.")


def get_initial_steps(self, common_steps):
return [
"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}",
"Note down the response structures, status codes, and headers for each endpoint.",
"For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps

def get_phase_steps(self, phase, common_steps):
if phase != "DELETE":
return [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
return [
"Check for all endpoints the DELETE method. Delete the first instance for all endpoints.",
self.get_http_action_template(phase)
] + common_steps

def get_endpoints_needing_help(self):
endpoints_needing_help = []
endpoints_and_needed_methods = {}
http_methods_set = {"GET", "POST", "PUT", "DELETE"}

for endpoint, methods in self.endpoint_methods.items():
missing_methods = http_methods_set - set(methods)
if len(methods) < 4:
endpoints_needing_help.append(endpoint)
endpoints_and_needed_methods[endpoint] = list(missing_methods)

if endpoints_needing_help:
first_endpoint = endpoints_needing_help[0]
needed_method = endpoints_and_needed_methods[first_endpoint][0]
return [
f"For endpoint {first_endpoint} find this missing method: {needed_method}. If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search."]
return []
def chain_of_thought(self, doc=False, hint=""):
"""
Generates a prompt using the chain-of-thought strategy.
If 'doc' is True, it follows a detailed documentation-oriented prompt strategy based on the round number.
If 'doc' is False, it provides general guidance for early round numbers and focuses on HTTP methods for later rounds.
Args:
doc (bool): Determines whether the documentation-oriented chain of thought should be used.
Expand All @@ -126,70 +158,30 @@ def chain_of_thought(self, doc=False, hint=""):
"Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes."
]

http_methods = [ "PUT", "DELETE"]
http_phase = {
5: http_methods[0],
10: http_methods[1]
}

http_methods = ["PUT", "DELETE"]
http_phase = {10: http_methods[0], 15: http_methods[1]}
if doc:
if self.round < 5:

chain_of_thought_steps = [
f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", f"Note down the response structures, status codes, and headers for each endpoint.",
f"For each endpoint, document the following details: URL, HTTP method, "
f"query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps
if self.round <= 5:
chain_of_thought_steps = self.get_initial_steps(common_steps)
elif self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
chain_of_thought_steps = self.get_phase_steps(phase, common_steps)
else:
if self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
print(f'phase:{phase}')
if phase != "DELETE":
chain_of_thought_steps = [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
chain_of_thought_steps = [
f"Check for all endpoints the DELETE method. Delete the first instance for all endpoints. ",
self.get_http_action_template(phase)
] + common_steps
else:
endpoints_needing_help = []
endpoints_and_needed_methods = {}

# Standard HTTP methods
http_methods = {"GET", "POST", "PUT", "DELETE"}

for endpoint in self.endpoint_methods:
# Calculate the missing methods for the current endpoint
missing_methods = http_methods - set(self.endpoint_methods[endpoint])

if len(self.endpoint_methods[endpoint]) < 4:
endpoints_needing_help.append(endpoint)
# Add the missing methods to the dictionary
endpoints_and_needed_methods[endpoint] = list(missing_methods)

print(f'endpoints_and_needed_methods: {endpoints_and_needed_methods}')
print(f'first endpoint in list: {endpoints_needing_help[0]}')
print(f'methods needed for first endpoint: {endpoints_and_needed_methods[endpoints_needing_help[0]][0]}')

chain_of_thought_steps = [f"For enpoint {endpoints_needing_help[0]} find this missing method :{endpoints_and_needed_methods[endpoints_needing_help[0]][0]} "
f"If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search. ",]

chain_of_thought_steps = self.get_endpoints_needing_help()
else:
if self.round == 0:
chain_of_thought_steps = ["Let's think step by step."] # Zero shot prompt
chain_of_thought_steps = ["Let's think step by step."]
elif self.round <= 20:
focus_phases = ["endpoints", "HTTP method GET", "HTTP method POST and PUT", "HTTP method DELETE"]
focus_phase = focus_phases[self.round // 5]
chain_of_thought_steps = [f"Just focus on the {focus_phase} for now."]
else:
chain_of_thought_steps = ["Look for exploits."]

print(f'chain of thought steps: {chain_of_thought_steps}')
prompt = self.check_prompt(self.previous_prompt,
chain_of_thought_steps + [hint] if hint else chain_of_thought_steps)
if hint:
chain_of_thought_steps.append(hint)

prompt = self.check_prompt(self.previous_prompt, chain_of_thought_steps)
return prompt

def token_count(self, text):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hackingBuddyGPT.capabilities.http_request import HTTPRequest
from hackingBuddyGPT.capabilities.record_note import RecordNote
from hackingBuddyGPT.usecases.agents import Agent
from hackingBuddyGPT.usecases.web_api_testing.utils.documentation_handler import DocumentationHandler
from hackingBuddyGPT.usecases.web_api_testing.utils.openapi_specification_manager import OpenAPISpecificationManager
from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler
from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptEngineer, PromptStrategy
from hackingBuddyGPT.usecases.web_api_testing.utils.response_handler import ResponseHandler
Expand Down Expand Up @@ -52,7 +52,7 @@ def init(self):
self.llm_handler = LLMHandler(self.llm, self._capabilities)
self.response_handler = ResponseHandler(self.llm_handler)
self._setup_initial_prompt()
self.documentation_handler = DocumentationHandler(self.llm_handler, self.response_handler)
self.documentation_handler = OpenAPISpecificationManager(self.llm_handler, self.response_handler)

def _setup_capabilities(self):
notes = self._context["notes"]
Expand All @@ -74,7 +74,7 @@ def _setup_initial_prompt(self):
response_handler=self.response_handler)


def all_http_methods_found(self):
def all_http_methods_found(self,turn):
print(f'found endpoints:{self.documentation_handler.endpoint_methods.items()}')
print(f'found endpoints values:{self.documentation_handler.endpoint_methods.values()}')

Expand All @@ -83,17 +83,20 @@ def all_http_methods_found(self):
print(f'found endpoints:{found_endpoints}')
print(f'expected endpoints:{expected_endpoints}')
print(f'correct? {found_endpoints== expected_endpoints}')
if found_endpoints== expected_endpoints or found_endpoints == expected_endpoints -1:
if found_endpoints > 0 and (found_endpoints== expected_endpoints) :
return True
else:
if turn == 20:
if found_endpoints > 0 and (found_endpoints == expected_endpoints):
return True
return False

def perform_round(self, turn: int):
prompt = self.prompt_engineer.generate_prompt(doc=True)
response, completion = self.llm_handler.call_llm(prompt)
return self._handle_response(completion, response)
return self._handle_response(completion, response, turn)

def _handle_response(self, completion, response):
def _handle_response(self, completion, response, turn):
message = completion.choices[0].message
tool_call_id = message.tool_calls[0].id
command = pydantic_core.to_json(response).decode()
Expand All @@ -106,7 +109,6 @@ def _handle_response(self, completion, response):
result_str = self.response_handler.parse_http_status_line(result)
self._prompt_history.append(tool_message(result_str, tool_call_id))
invalid_flags = ["recorded","Not a valid HTTP method", "404" ,"Client Error: Not Found"]
print(f'result_str:{result_str}')
if not result_str in invalid_flags or any(item in result_str for item in invalid_flags):
self.prompt_engineer.found_endpoints = self.documentation_handler.update_openapi_spec(response, result)
self.documentation_handler.write_openapi_to_yaml()
Expand All @@ -120,8 +122,7 @@ def _handle_response(self, completion, response):
http_methods_dict[method].append(endpoint)
self.prompt_engineer.endpoint_found_methods = http_methods_dict
self.prompt_engineer.endpoint_methods = self.documentation_handler.endpoint_methods
print(f'SCHEMAS:{self.prompt_engineer.schemas}')
return self.all_http_methods_found()
return self.all_http_methods_found(turn)



Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .documentation_handler import DocumentationHandler
from .openapi_specification_manager import OpenAPISpecificationManager
from .llm_handler import LLMHandler
from .response_handler import ResponseHandler
from .openapi_parser import OpenAPISpecificationParser
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from hackingBuddyGPT.capabilities.yamlFile import YAMLFile

class DocumentationHandler:
class OpenAPISpecificationManager:
"""
Handles the generation and updating of an OpenAPI specification document based on dynamic API responses.
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self, llm_handler, response_handler):
"yaml": YAMLFile()
}

def partial_match(self, element, string_list):
def is_partial_match(self, element, string_list):
return any(element in string or string in element for string in string_list)

def update_openapi_spec(self, resp, result):
Expand All @@ -66,7 +66,7 @@ def update_openapi_spec(self, resp, result):

if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work
self.check_openapi_spec(resp)
if request.__class__.__name__ == 'HTTPRequest':
elif request.__class__.__name__ == 'HTTPRequest':
path = request.path
method = request.method
print(f'method: {method}')
Expand Down Expand Up @@ -107,7 +107,7 @@ def update_openapi_spec(self, resp, result):

if '1' not in path and x != "":
endpoint_methods[path].append(method)
elif self.partial_match(x, endpoints.keys()):
elif self.is_partial_match(x, endpoints.keys()):
path = f"/{x}"
print(f'endpoint methods = {endpoint_methods}')
print(f'new path:{path}')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def parse_http_status_line(self, status_line):
"""
if status_line == "Not a valid HTTP method":
return status_line
if status_line and " " in status_line:
protocol, status_code, status_message = status_line.split(' ', 2)
status_message = status_message.split("\r\n")[0]
status_line = status_line.split('\r\n')[0]
# Regular expression to match valid HTTP status lines
match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line)
if match:
protocol, status_code, status_message = match.groups()
return f'{status_code} {status_message}'
raise ValueError("Invalid HTTP status line")
else:
raise ValueError("Invalid HTTP status line")

def extract_response_example(self, html_content):
"""
Expand Down
61 changes: 61 additions & 0 deletions tests/test_llm_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
from unittest.mock import MagicMock, patch
from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model
from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler


class TestLLMHandler(unittest.TestCase):
def setUp(self):
self.llm_mock = MagicMock()
self.capabilities = {'cap1': MagicMock(), 'cap2': MagicMock()}
self.llm_handler = LLMHandler(self.llm_mock, self.capabilities)

'''@patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model')
def test_call_llm(self, mock_capabilities_to_action_model):
prompt = [{'role': 'user', 'content': 'Hello, LLM!'}]
response_mock = MagicMock()
self.llm_mock.instructor.chat.completions.create_with_completion.return_value = response_mock
# Mock the capabilities_to_action_model to return a dummy Pydantic model
mock_model = MagicMock()
mock_capabilities_to_action_model.return_value = mock_model
response = self.llm_handler.call_llm(prompt)
self.llm_mock.instructor.chat.completions.create_with_completion.assert_called_once_with(
model=self.llm_mock.model,
messages=prompt,
response_model=mock_model
)
self.assertEqual(response, response_mock)'''
def test_add_created_object(self):
created_object = MagicMock()
object_type = 'test_type'

self.llm_handler.add_created_object(created_object, object_type)

self.assertIn(object_type, self.llm_handler.created_objects)
self.assertIn(created_object, self.llm_handler.created_objects[object_type])

def test_add_created_object_limit(self):
created_object = MagicMock()
object_type = 'test_type'

for _ in range(8): # Exceed the limit of 7 objects
self.llm_handler.add_created_object(created_object, object_type)

self.assertEqual(len(self.llm_handler.created_objects[object_type]), 7)

def test_get_created_objects(self):
created_object = MagicMock()
object_type = 'test_type'
self.llm_handler.add_created_object(created_object, object_type)

created_objects = self.llm_handler.get_created_objects()

self.assertIn(object_type, created_objects)
self.assertIn(created_object, created_objects[object_type])
self.assertEqual(created_objects, self.llm_handler.created_objects)

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 033b598

Please sign in to comment.