Skip to content

Commit

Permalink
Merge pull request #31 from crytic/tests-from-corpus
Browse files Browse the repository at this point in the history
Add flag to generate tests from the entire corpus
  • Loading branch information
tuturu-tech authored Mar 25, 2024
2 parents 61a4a66 + 24da301 commit 4db6306
Show file tree
Hide file tree
Showing 148 changed files with 7,112 additions and 385,158 deletions.
10 changes: 8 additions & 2 deletions fuzz_utils/fuzzers/Echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self.slither = slither
self.target = self.get_target_contract()
self.reproducer_dir = f"{corpus_path}/reproducers"
self.corpus_dirs = [f"{corpus_path}/coverage", self.reproducer_dir]
self.named_inputs = named_inputs

def get_target_contract(self) -> Contract:
Expand All @@ -43,7 +44,7 @@ def get_target_contract(self) -> Contract:

handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, calls: Any, index: int) -> str:
def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
"""
Expand All @@ -59,7 +60,9 @@ def parse_reproducer(self, calls: Any, index: int) -> str:

# 2. Generate the test string and return it
template = jinja2.Template(templates["TEST"])
return template.render(function_name=function_name, call_list=call_list)
return template.render(
function_name=function_name, call_list=call_list, file_path=file_path
)

# pylint: disable=too-many-locals,too-many-branches
def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
Expand Down Expand Up @@ -103,6 +106,9 @@ def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
f"\n* Slither could not find the function `{function_name}` specified in the call object"
)

if not slither_entry_point.payable:
value = 0

# 2. Decode the function parameters
variable_definition, call_definition = self._decode_function_params(
function_parameters, False, slither_entry_point
Expand Down
18 changes: 14 additions & 4 deletions fuzz_utils/fuzzers/Medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from fuzz_utils.utils.error_handler import handle_exit


class Medusa:
class Medusa: # pylint: disable=too-many-instance-attributes
"""
Handles the generation of Foundry test files from Medusa reproducers
"""
Expand All @@ -32,6 +32,11 @@ def __init__(
self.slither = slither
self.target = self.get_target_contract()
self.reproducer_dir = f"{corpus_path}/test_results"
self.corpus_dirs = [
f"{corpus_path}/call_sequences/immutable",
f"{corpus_path}/call_sequences/mutable",
self.reproducer_dir,
]
self.named_inputs = named_inputs

def get_target_contract(self) -> Contract:
Expand All @@ -44,7 +49,7 @@ def get_target_contract(self) -> Contract:

handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, calls: Any, index: int) -> str:
def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
"""
Expand All @@ -58,7 +63,9 @@ def parse_reproducer(self, calls: Any, index: int) -> str:
function_name = fn_name + "_" + str(index)

template = jinja2.Template(templates["TEST"])
return template.render(function_name=function_name, call_list=call_list)
return template.render(
function_name=function_name, call_list=call_list, file_path=file_path
)
# 1. Take a reproducer list and create a test file based on the name of the last function of the list e.g. test_auto_$function_name
# 2. For each object in the list process the call object and add it to the call list
# 3. Using the call list to generate a test string
Expand All @@ -73,6 +80,7 @@ def _parse_call_object(self, call_dict: dict) -> tuple[str, str]:
# 1. Parse call object and save the variables
time_delay = int(call_dict["blockTimestampDelay"])
block_delay = int(call_dict["blockNumberDelay"])
value = int(call_dict["call"]["value"], 16)
has_delay = time_delay > 0 or block_delay > 0
function_name: str = ""

Expand All @@ -94,7 +102,6 @@ def _parse_call_object(self, call_dict: dict) -> tuple[str, str]:
if len(function_parameters) == 0:
function_parameters = ""
caller = call_dict["call"]["from"]
value = int(call_dict["call"]["value"], 16)

slither_entry_point: FunctionContract

Expand All @@ -107,6 +114,9 @@ def _parse_call_object(self, call_dict: dict) -> tuple[str, str]:
f"\n* Slither could not find the function `{function_name}` specified in the call object"
)

if not slither_entry_point.payable:
value = 0

# 2. Decode the function parameters
parameters: list = []
variable_definition: str = ""
Expand Down
57 changes: 41 additions & 16 deletions fuzz_utils/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import json
import argparse
from typing import Any
import jinja2

from pkg_resources import require
Expand All @@ -16,7 +17,7 @@
from fuzz_utils.utils.error_handler import handle_exit


class FoundryTest:
class FoundryTest: # pylint: disable=too-many-instance-attributes
"""
Handles the generation of Foundry test files
"""
Expand All @@ -29,6 +30,7 @@ def __init__(
test_dir: str,
slither: Slither,
fuzzer: Echidna | Medusa,
all_sequences: bool,
) -> None:
self.inheritance_path = inheritance_path
self.target_name = target_name
Expand All @@ -37,6 +39,7 @@ def __init__(
self.slither = slither
self.target = self.get_target_contract()
self.fuzzer = fuzzer
self.all_sequences = all_sequences

def get_target_contract(self) -> Contract:
"""Gets the Slither Contract object for the specified contract file"""
Expand All @@ -52,25 +55,34 @@ def get_target_contract(self) -> Contract:
def create_poc(self) -> str:
"""Takes in a directory path to the echidna reproducers and generates a test file"""

file_list = []
file_list: list[dict[str, Any]] = []
tests_list = []
# 1. Iterate over each reproducer file (open it)
for entry in os.listdir(self.fuzzer.reproducer_dir):
full_path = os.path.join(self.fuzzer.reproducer_dir, entry)

if os.path.isfile(full_path):
try:
with open(full_path, "r", encoding="utf-8") as file:
file_list.append(json.load(file))
except Exception: # pylint: disable=broad-except
print(f"Fail on {full_path}")
dir_list = []
if self.all_sequences:
dir_list = self.fuzzer.corpus_dirs
else:
dir_list = [self.fuzzer.reproducer_dir]

# 1. Iterate over each directory and reproducer file (open it)
for directory in dir_list:
for entry in os.listdir(directory):
full_path = os.path.join(directory, entry)

if os.path.isfile(full_path):
try:
with open(full_path, "r", encoding="utf-8") as file:
file_list.append({"path": full_path, "content": json.load(file)})
except Exception: # pylint: disable=broad-except
print(f"Fail on {full_path}")

# 2. Parse each reproducer file and add each test function to the functions list
for idx, file in enumerate(file_list):
for idx, file_obj in enumerate(file_list):
try:
tests_list.append(self.fuzzer.parse_reproducer(file, idx))
tests_list.append(
self.fuzzer.parse_reproducer(file_obj["path"], file_obj["content"], idx)
)
except Exception: # pylint: disable=broad-except
print(f"Parsing fail on {file}: index: {idx}")
print(f"Parsing fail on {file_obj['content']}: index: {idx}")

# 4. Generate the test file
template = jinja2.Template(templates["CONTRACT"])
Expand Down Expand Up @@ -135,6 +147,13 @@ def main() -> None: # type: ignore[func-returns-value]
default=False,
action="store_true",
)
parser.add_argument(
"--all-sequences",
dest="all_sequences",
help="Include all corpus sequences when generating unit tests.",
default=False,
action="store_true",
)

args = parser.parse_args()

Expand Down Expand Up @@ -165,7 +184,13 @@ def main() -> None: # type: ignore[func-returns-value]
f"Generating Foundry unit tests based on the {fuzzer.name} reproducers..."
)
foundry_test = FoundryTest(
inheritance_path, target_contract, corpus_dir, test_directory, slither, fuzzer
inheritance_path,
target_contract,
corpus_dir,
test_directory,
slither,
fuzzer,
args.all_sequences,
)
foundry_test.create_poc()
CryticPrint().print_success("Done!")
Expand Down
2 changes: 2 additions & 0 deletions fuzz_utils/templates/foundry_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
{%- else %}
target.{{function_name}}({{function_parameters}});
{%- endif %}
"""

__TRANSFER__TEMPLATE: str = """
Expand All @@ -52,6 +53,7 @@
"""

__TEST_TEMPLATE: str = """
// Reproduced from: {{file_path}}
function test_auto_{{function_name}}() public { {% for call in call_list %}
{{call}}{% endfor %}
}"""
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def __init__(self, target: str, target_path: str, corpus_dir: str):
echidna = Echidna(target, f"echidna-corpora/{corpus_dir}", slither, False)
medusa = Medusa(target, f"medusa-corpora/{corpus_dir}", slither, False)
self.echidna_generator = FoundryTest(
"../src/", target, f"echidna-corpora/{corpus_dir}", "./test/", slither, echidna
"../src/", target, f"echidna-corpora/{corpus_dir}", "./test/", slither, echidna, False
)
self.medusa_generator = FoundryTest(
"../src/", target, f"medusa-corpora/{corpus_dir}", "./test/", slither, medusa
"../src/", target, f"medusa-corpora/{corpus_dir}", "./test/", slither, medusa, False
)

def echidna_generate_tests(self) -> None:
Expand Down
Loading

0 comments on commit 4db6306

Please sign in to comment.