Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development without spacy #80

Merged
merged 9 commits into from
Aug 6, 2024
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ wintermute.py: error: the following arguments are required: {linux_privesc,windo

# start wintermute, i.e., attack the configured virtual machine
$ python wintermute.py minimal_linux_privesc

# install dependencies for testing if you want to run the tests
$ pip install .[testing]
~~~

## Publications about hackingBuddyGPT
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ pythonpath = "src"
addopts = [
"--import-mode=importlib",
]
[project.optional-dependencies]
testing = [
'pytest',
'pytest-mock'
]

[project.scripts]
wintermute = "hackingBuddyGPT.cli.wintermute:main"
Expand Down
124 changes: 58 additions & 66 deletions src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from instructor.retry import InstructorRetryException


Expand Down Expand Up @@ -37,13 +36,11 @@ def __init__(self, strategy, llm_handler, history, schemas, response_handler):
self.found_endpoints = ["/"]
self.endpoint_methods = {}
self.endpoint_found_methods = {}
model_name = "en_core_web_sm"

# Check if the models are already installed
nltk.download('punkt')
nltk.download('stopwords')
self._prompt_history = history
self.prompt = self._prompt_history
self.prompt = {self.round: {"content": "initial_prompt"}}
self.previous_prompt = self._prompt_history[self.round]["content"]
self.schemas = schemas

Expand Down Expand Up @@ -77,10 +74,6 @@ def generate_prompt(self, doc=False):
self.round = self.round +1
return self._prompt_history





def in_context_learning(self, doc=False, hint=""):
"""
Generates a prompt for in-context learning.
Expand All @@ -91,7 +84,14 @@ def in_context_learning(self, doc=False, hint=""):
Returns:
str: The generated prompt.
"""
return str("\n".join(self._prompt_history[self.round]["content"] + [self.prompt]))
history_content = [entry["content"] for entry in self._prompt_history]
prompt_content = self.prompt.get(self.round, {}).get("content", "")

# Add hint if provided
if hint:
prompt_content += f"\n{hint}"

return "\n".join(history_content + [prompt_content])

def get_http_action_template(self, method):
"""Helper to construct a consistent HTTP action description."""
Expand All @@ -103,13 +103,45 @@ def get_http_action_template(self, method):
else:
return (
f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests.")


def get_initial_steps(self, common_steps):
return [
"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}",
"Note down the response structures, status codes, and headers for each endpoint.",
"For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps

def get_phase_steps(self, phase, common_steps):
if phase != "DELETE":
return [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
return [
"Check for all endpoints the DELETE method. Delete the first instance for all endpoints.",
self.get_http_action_template(phase)
] + common_steps

def get_endpoints_needing_help(self):
endpoints_needing_help = []
endpoints_and_needed_methods = {}
http_methods_set = {"GET", "POST", "PUT", "DELETE"}

for endpoint, methods in self.endpoint_methods.items():
missing_methods = http_methods_set - set(methods)
if len(methods) < 4:
endpoints_needing_help.append(endpoint)
endpoints_and_needed_methods[endpoint] = list(missing_methods)

if endpoints_needing_help:
first_endpoint = endpoints_needing_help[0]
needed_method = endpoints_and_needed_methods[first_endpoint][0]
return [
f"For endpoint {first_endpoint} find this missing method: {needed_method}. If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search."]
return []
def chain_of_thought(self, doc=False, hint=""):
"""
Generates a prompt using the chain-of-thought strategy.
If 'doc' is True, it follows a detailed documentation-oriented prompt strategy based on the round number.
If 'doc' is False, it provides general guidance for early round numbers and focuses on HTTP methods for later rounds.

Args:
doc (bool): Determines whether the documentation-oriented chain of thought should be used.
Expand All @@ -126,70 +158,30 @@ def chain_of_thought(self, doc=False, hint=""):
"Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes."
]

http_methods = [ "PUT", "DELETE"]
http_phase = {
5: http_methods[0],
10: http_methods[1]
}

http_methods = ["PUT", "DELETE"]
http_phase = {10: http_methods[0], 15: http_methods[1]}
if doc:
if self.round < 5:

chain_of_thought_steps = [
f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", f"Note down the response structures, status codes, and headers for each endpoint.",
f"For each endpoint, document the following details: URL, HTTP method, "
f"query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps
if self.round <= 5:
chain_of_thought_steps = self.get_initial_steps(common_steps)
elif self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
chain_of_thought_steps = self.get_phase_steps(phase, common_steps)
else:
if self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
print(f'phase:{phase}')
if phase != "DELETE":
chain_of_thought_steps = [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
chain_of_thought_steps = [
f"Check for all endpoints the DELETE method. Delete the first instance for all endpoints. ",
self.get_http_action_template(phase)
] + common_steps
else:
endpoints_needing_help = []
endpoints_and_needed_methods = {}

# Standard HTTP methods
http_methods = {"GET", "POST", "PUT", "DELETE"}

for endpoint in self.endpoint_methods:
# Calculate the missing methods for the current endpoint
missing_methods = http_methods - set(self.endpoint_methods[endpoint])

if len(self.endpoint_methods[endpoint]) < 4:
endpoints_needing_help.append(endpoint)
# Add the missing methods to the dictionary
endpoints_and_needed_methods[endpoint] = list(missing_methods)

print(f'endpoints_and_needed_methods: {endpoints_and_needed_methods}')
print(f'first endpoint in list: {endpoints_needing_help[0]}')
print(f'methods needed for first endpoint: {endpoints_and_needed_methods[endpoints_needing_help[0]][0]}')

chain_of_thought_steps = [f"For enpoint {endpoints_needing_help[0]} find this missing method :{endpoints_and_needed_methods[endpoints_needing_help[0]][0]} "
f"If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search. ",]

chain_of_thought_steps = self.get_endpoints_needing_help()
else:
if self.round == 0:
chain_of_thought_steps = ["Let's think step by step."] # Zero shot prompt
chain_of_thought_steps = ["Let's think step by step."]
elif self.round <= 20:
focus_phases = ["endpoints", "HTTP method GET", "HTTP method POST and PUT", "HTTP method DELETE"]
focus_phase = focus_phases[self.round // 5]
chain_of_thought_steps = [f"Just focus on the {focus_phase} for now."]
else:
chain_of_thought_steps = ["Look for exploits."]

print(f'chain of thought steps: {chain_of_thought_steps}')
prompt = self.check_prompt(self.previous_prompt,
chain_of_thought_steps + [hint] if hint else chain_of_thought_steps)
if hint:
chain_of_thought_steps.append(hint)

prompt = self.check_prompt(self.previous_prompt, chain_of_thought_steps)
return prompt

def token_count(self, text):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hackingBuddyGPT.capabilities.http_request import HTTPRequest
from hackingBuddyGPT.capabilities.record_note import RecordNote
from hackingBuddyGPT.usecases.agents import Agent
from hackingBuddyGPT.usecases.web_api_testing.utils.documentation_handler import DocumentationHandler
from hackingBuddyGPT.usecases.web_api_testing.utils.openapi_specification_manager import OpenAPISpecificationManager
from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler
from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptEngineer, PromptStrategy
from hackingBuddyGPT.usecases.web_api_testing.utils.response_handler import ResponseHandler
Expand Down Expand Up @@ -52,7 +52,7 @@ def init(self):
self.llm_handler = LLMHandler(self.llm, self._capabilities)
self.response_handler = ResponseHandler(self.llm_handler)
self._setup_initial_prompt()
self.documentation_handler = DocumentationHandler(self.llm_handler, self.response_handler)
self.documentation_handler = OpenAPISpecificationManager(self.llm_handler, self.response_handler)

def _setup_capabilities(self):
notes = self._context["notes"]
Expand All @@ -74,7 +74,7 @@ def _setup_initial_prompt(self):
response_handler=self.response_handler)


def all_http_methods_found(self):
def all_http_methods_found(self,turn):
print(f'found endpoints:{self.documentation_handler.endpoint_methods.items()}')
print(f'found endpoints values:{self.documentation_handler.endpoint_methods.values()}')

Expand All @@ -83,17 +83,20 @@ def all_http_methods_found(self):
print(f'found endpoints:{found_endpoints}')
print(f'expected endpoints:{expected_endpoints}')
print(f'correct? {found_endpoints== expected_endpoints}')
if found_endpoints== expected_endpoints or found_endpoints == expected_endpoints -1:
if found_endpoints > 0 and (found_endpoints== expected_endpoints) :
return True
else:
if turn == 20:
if found_endpoints > 0 and (found_endpoints == expected_endpoints):
return True
return False

def perform_round(self, turn: int):
prompt = self.prompt_engineer.generate_prompt(doc=True)
response, completion = self.llm_handler.call_llm(prompt)
return self._handle_response(completion, response)
return self._handle_response(completion, response, turn)

def _handle_response(self, completion, response):
def _handle_response(self, completion, response, turn):
message = completion.choices[0].message
tool_call_id = message.tool_calls[0].id
command = pydantic_core.to_json(response).decode()
Expand All @@ -106,7 +109,6 @@ def _handle_response(self, completion, response):
result_str = self.response_handler.parse_http_status_line(result)
self._prompt_history.append(tool_message(result_str, tool_call_id))
invalid_flags = ["recorded","Not a valid HTTP method", "404" ,"Client Error: Not Found"]
print(f'result_str:{result_str}')
if not result_str in invalid_flags or any(item in result_str for item in invalid_flags):
self.prompt_engineer.found_endpoints = self.documentation_handler.update_openapi_spec(response, result)
self.documentation_handler.write_openapi_to_yaml()
Expand All @@ -120,8 +122,7 @@ def _handle_response(self, completion, response):
http_methods_dict[method].append(endpoint)
self.prompt_engineer.endpoint_found_methods = http_methods_dict
self.prompt_engineer.endpoint_methods = self.documentation_handler.endpoint_methods
print(f'SCHEMAS:{self.prompt_engineer.schemas}')
return self.all_http_methods_found()
return self.all_http_methods_found(turn)



Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .documentation_handler import DocumentationHandler
from .openapi_specification_manager import OpenAPISpecificationManager
from .llm_handler import LLMHandler
from .response_handler import ResponseHandler
from .openapi_parser import OpenAPISpecificationParser
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from hackingBuddyGPT.capabilities.yamlFile import YAMLFile

class DocumentationHandler:
class OpenAPISpecificationManager:
"""
Handles the generation and updating of an OpenAPI specification document based on dynamic API responses.

Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self, llm_handler, response_handler):
"yaml": YAMLFile()
}

def partial_match(self, element, string_list):
def is_partial_match(self, element, string_list):
return any(element in string or string in element for string in string_list)

def update_openapi_spec(self, resp, result):
Expand All @@ -66,7 +66,7 @@ def update_openapi_spec(self, resp, result):

if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work
self.check_openapi_spec(resp)
if request.__class__.__name__ == 'HTTPRequest':
elif request.__class__.__name__ == 'HTTPRequest':
path = request.path
method = request.method
print(f'method: {method}')
Expand Down Expand Up @@ -107,7 +107,7 @@ def update_openapi_spec(self, resp, result):

if '1' not in path and x != "":
endpoint_methods[path].append(method)
elif self.partial_match(x, endpoints.keys()):
elif self.is_partial_match(x, endpoints.keys()):
path = f"/{x}"
print(f'endpoint methods = {endpoint_methods}')
print(f'new path:{path}')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def parse_http_status_line(self, status_line):
"""
if status_line == "Not a valid HTTP method":
return status_line
if status_line and " " in status_line:
protocol, status_code, status_message = status_line.split(' ', 2)
status_message = status_message.split("\r\n")[0]
status_line = status_line.split('\r\n')[0]
# Regular expression to match valid HTTP status lines
match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line)
if match:
protocol, status_code, status_message = match.groups()
return f'{status_code} {status_message}'
raise ValueError("Invalid HTTP status line")
else:
raise ValueError("Invalid HTTP status line")

def extract_response_example(self, html_content):
"""
Expand Down
61 changes: 61 additions & 0 deletions tests/test_llm_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
from unittest.mock import MagicMock, patch
from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model
from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler


class TestLLMHandler(unittest.TestCase):
def setUp(self):
self.llm_mock = MagicMock()
self.capabilities = {'cap1': MagicMock(), 'cap2': MagicMock()}
self.llm_handler = LLMHandler(self.llm_mock, self.capabilities)

'''@patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model')
def test_call_llm(self, mock_capabilities_to_action_model):
prompt = [{'role': 'user', 'content': 'Hello, LLM!'}]
response_mock = MagicMock()
self.llm_mock.instructor.chat.completions.create_with_completion.return_value = response_mock

# Mock the capabilities_to_action_model to return a dummy Pydantic model
mock_model = MagicMock()
mock_capabilities_to_action_model.return_value = mock_model

response = self.llm_handler.call_llm(prompt)

self.llm_mock.instructor.chat.completions.create_with_completion.assert_called_once_with(
model=self.llm_mock.model,
messages=prompt,
response_model=mock_model
)
self.assertEqual(response, response_mock)'''
def test_add_created_object(self):
created_object = MagicMock()
object_type = 'test_type'

self.llm_handler.add_created_object(created_object, object_type)

self.assertIn(object_type, self.llm_handler.created_objects)
self.assertIn(created_object, self.llm_handler.created_objects[object_type])

def test_add_created_object_limit(self):
created_object = MagicMock()
object_type = 'test_type'

for _ in range(8): # Exceed the limit of 7 objects
self.llm_handler.add_created_object(created_object, object_type)

self.assertEqual(len(self.llm_handler.created_objects[object_type]), 7)

def test_get_created_objects(self):
created_object = MagicMock()
object_type = 'test_type'
self.llm_handler.add_created_object(created_object, object_type)

created_objects = self.llm_handler.get_created_objects()

self.assertIn(object_type, created_objects)
self.assertIn(created_object, created_objects[object_type])
self.assertEqual(created_objects, self.llm_handler.created_objects)

if __name__ == "__main__":
unittest.main()
Loading
Loading