Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for generic text parsing of capabilities #58

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion capabilities/capability.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import inspect
from typing import Union, Type, Dict
from typing import Union, Type, Dict, Callable, Any

from pydantic import create_model, BaseModel

Expand Down Expand Up @@ -78,3 +78,95 @@ class Model(Action):

return Model


SimpleTextHandlerResult = tuple[bool, Union[str, tuple[str, str, ...]]]
SimpleTextHandler = Callable[[str], SimpleTextHandlerResult]


def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], default_capability: Capability = None, include_description: bool = True) -> tuple[Dict[str, str], SimpleTextHandler]:
"""
This function generates a simple text handler from a set of capabilities.
It is to be used when no function calling is available, and structured output is not to be trusted, which is why it
only supports the most basic of parameter types for the capabilities (str, int, float, bool).

As result it returns a dictionary of capability names to their descriptions and a parser function that can be used
to parse the text input and execute it. The first return value of the parser function is a boolean indicating
whether the parsing was successful, the second return value is a tuple containing the capability name, the parameters
as a string and the result of the capability execution.
"""
def get_simple_fields(func, name) -> Dict[str, Type]:
sig = inspect.signature(func)
fields = {param: param_info.annotation for param, param_info in sig.parameters.items()}
for param, param_type in fields.items():
if param_type not in (str, int, float, bool):
raise ValueError(f"The command {name} is not compatible with this calling convention (this is not a LLM error, but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))")
return fields

def parse_params(fields, params) -> tuple[bool, Union[str, Dict[str, Any]]]:
split_params = params.split(" ", maxsplit=len(fields) - 1)
if len(split_params) != len(fields):
return False, "Invalid number of parameters"

parsed_params = dict()
for param, param_type in fields.items():
try:
parsed_params[param] = param_type(split_params.pop(0))
except ValueError as e:
return False, f"Could not parse parameter {param}: {e}"
return True, parsed_params

capability_descriptions = dict()
capability_params = dict()
for capability_name, capability in capabilities.items():
fields = get_simple_fields(capability.__call__, capability_name)

description = f"`{capability_name}"
if len(fields) > 0:
description += " " + " ".join(param for param in fields)
description += "`"
if include_description:
description += f": {capability.describe()}"

capability_descriptions[capability_name] = description
capability_params[capability_name] = fields

def parser(text: str) -> SimpleTextHandlerResult:
capability_name_and_params = text.split(" ", maxsplit=1)
if len(capability_name_and_params) == 1:
capability_name = capability_name_and_params[0]
params = ""
else:
capability_name, params = capability_name_and_params
if capability_name not in capabilities:
return False, "Unknown command"

success, parsing_result = parse_params(capability_params[capability_name], params)
if not success:
return False, parsing_result

return True, (capability_name, params, capabilities[capability_name](**parsing_result))

resolved_parser: SimpleTextHandler = parser

if default_capability is not None:
default_fields = get_simple_fields(default_capability.__call__, "__default__")

def default_capability_parser(text: str) -> SimpleTextHandlerResult:
success, *output = parser(text)
if success:
return success, *output

params = text
success, parsing_result = parse_params(default_fields, params)
if not success:
params = text.split(" ", maxsplit=1)[1]
success, parsing_result = parse_params(default_fields, params)
if not success:
return False, parsing_result

return True, (capability_name, params, default_capability(**parsing_result))


resolved_parser = default_capability_parser

return capability_descriptions, resolved_parser
2 changes: 1 addition & 1 deletion capabilities/psexec_test_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PSExecTestCredential(Capability):
conn: PSExecConnection

def describe(self) -> str:
return f"give credentials to be tested by stating `{self.get_name()} username password`"
return f"give credentials to be tested"

def get_name(self) -> str:
return "test_credential"
Expand Down
10 changes: 4 additions & 6 deletions capabilities/ssh_run_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@
@dataclass
class SSHRunCommand(Capability):
conn: SSHConnection
timeout: int = 10

def describe(self) -> str:
return f"give a command to be executed by stating `{self.get_name()} command arguments` and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction."
return f"give a command to be executed and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction."

def get_name(self):
return "exec_command"

def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]:
def __call__(self, command: str) -> Tuple[str, bool]:
got_root = False

cmd_parts = command.split(" ", 1)
command = cmd_parts[1]

sudo_pass = Responder(
pattern=r'\[sudo\] password for ' + self.conn.username + ':',
response=self.conn.password + '\n',
Expand All @@ -39,7 +37,7 @@ def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]:
out = StringIO()

try:
resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=timeout)
resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=self.timeout)
except Exception as e:
print("TIMEOUT! Could we have become root?")
out.seek(0)
Expand Down
12 changes: 3 additions & 9 deletions capabilities/ssh_test_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,13 @@ class SSHTestCredential(Capability):
conn: SSHConnection

def describe(self) -> str:
return f"give credentials to be tested by stating `{self.get_name()} username password`"
return f"give credentials to be tested"

def get_name(self):
return "test_credential"

def __call__(self, command: str) -> Tuple[str, bool]:
cmd_parts = command.split(" ")
assert (cmd_parts[0] == self.get_name())

if len(cmd_parts) != 3:
return "didn't provide username/password", False

test_conn = self.conn.new_with(username=cmd_parts[1], password=cmd_parts[2])
def __call__(self, username: str, password: str) -> Tuple[str, bool]:
test_conn = self.conn.new_with(username=username, password=password)
try:
test_conn.init()
user = test_conn.run("whoami")[0].strip('\n\r ')
Expand Down
13 changes: 7 additions & 6 deletions usecases/agents.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict

from capabilities.capability import Capability
from capabilities.capability import Capability, capabilities_to_simple_text_handler
from usecases.common_patterns import RoundBasedUseCase


@dataclass
class Agent(RoundBasedUseCase):

class Agent(RoundBasedUseCase, ABC):
_capabilities: Dict[str, Capability] = field(default_factory=dict)
_default_capability: Capability = None

def init(self):
super().init()

def add_capability(self, cap:Capability, default:bool=False):
def add_capability(self, cap: Capability, default: bool = False):
self._capabilities[cap.get_name()] = cap
if default:
self._default_capability = cap

def get_capability(self, name:str) -> Capability:
def get_capability(self, name: str) -> Capability:
return self._capabilities.get(name, self._default_capability)

def get_capability_block(self) -> str:
return "You can either\n\n" + "\n".join(map(lambda i: f"- {i.describe()}", self._capabilities.values()))
capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities)
return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values())
12 changes: 8 additions & 4 deletions usecases/privesc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich.panel import Panel

from capabilities import Capability
from capabilities.capability import capabilities_to_simple_text_handler
from usecases.agents import Agent
from utils import llm_util, ui
from utils.cli_history import SlidingCliHistory
Expand Down Expand Up @@ -48,10 +49,13 @@ def perform_round(self, turn):

with self.console.status("[bold green]Executing that command..."):
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
capability = cmd.split(" ", 1)[0]
result, got_root = self.get_capability(capability)(cmd)
if capability == "exec_command":
cmd = cmd[len(capability)+1:]
_capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability)
success, *output = parser(cmd)
if not success:
self.console.print(Panel(output[0], title=f"[bold red]Error parsing command:"))
return False

capability, cmd, (result, got_root) = output

# log and output the command and its result
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
Expand Down
2 changes: 1 addition & 1 deletion usecases/privesc/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def read_hint(self):

run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'"

result, got_root = SSHRunCommand(conn=self.conn)(run_cmd, timeout=120)
result, got_root = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd)

self.console.print("[yellow]got the output: " + result)
cmd = self.llm.get_response(template_lse, lse_output=result, number=3)
Expand Down
2 changes: 1 addition & 1 deletion utils/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No

if hasattr(type, "__parameters__"):
params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}"))
elif type in (str, int, bool):
elif type in (str, int, float, bool):
params[name] = ParameterDefinition(name, type, default, description)
else:
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")
Expand Down