Skip to content

Commit e338f4e

Browse files
committed
Restructure sub-tests, remove JSON support
1 parent 7b6cfd5 commit e338f4e

File tree

2 files changed

+72
-157
lines changed

2 files changed

+72
-157
lines changed

eval_tests_old.json

Lines changed: 0 additions & 76 deletions
This file was deleted.

evaluation_function/auto_tests.py

Lines changed: 72 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,18 @@
1-
import json
21
import yaml
32
from typing import Union
3+
from dataclasses import dataclass
44

55
class TestFile:
6-
"""An abstraction over a test file, which may be in one of several different formats.
7-
Currently, JSON and YAML are supported.
6+
"""An abstraction over a test file.
7+
Currently, only YAML files are supported.
88
"""
99

10-
def __init__(self, path: str) -> None:
10+
def __init__(self, file_content: str, file_name: str) -> None:
1111
self.groups = []
1212

13-
# Attempt to open the given file. Exit with an error if this
14-
# is not possible.
15-
file_content = ""
16-
try:
17-
with open(path, "r") as test_file:
18-
file_content = test_file.read()
19-
except IOError as e:
20-
raise Exception(f'Failed to open test file: "{e}"')
21-
2213
# Get the file extension to determine which format should be used.
23-
extension = path.split(".")[-1]
24-
if extension == "json":
25-
try:
26-
questions = json.loads(file_content)
27-
28-
for question in questions:
29-
out = []
30-
title = question["title"]
31-
for part in question["parts"]:
32-
for response_area in part["responseAreas"]:
33-
params = response_area["params"]
34-
answer = response_area["answer"]
35-
for test in response_area["tests"]:
36-
test.update({"answer": answer})
37-
test.update({"params": params})
38-
out.append(SingleTest(test))
39-
self.groups.append({"title": title, "tests": out})
40-
41-
except KeyError as e:
42-
raise Exception(f'The key "{e.args[0]}" doesn\'t exist, or is in the wrong place.')
43-
except json.JSONDecodeError as e:
44-
raise Exception(f'Error parsing JSON: "{e}"')
45-
elif extension == "yaml":
14+
extension = file_name.split(".")[-1]
15+
if extension == "yaml":
4616
try:
4717
# Tests are organised in groups of separate YAML documents (separated by "---")
4818
docs = yaml.safe_load_all(file_content)
@@ -53,66 +23,85 @@ def __init__(self, path: str) -> None:
5323
# Add an empty params field if none was provided.
5424
if test.get("params") == None:
5525
test["params"] = {}
56-
57-
# Does this test have sub-tests?
58-
sub_tests = test.get("sub_tests")
59-
if sub_tests != None:
60-
params = test["params"]
61-
answer = test["answer"]
62-
63-
for sub_test in sub_tests:
64-
sub_test["params"] = params
65-
sub_test["answer"] = answer
66-
tests.append(SingleTest(sub_test))
67-
else:
68-
tests.append(SingleTest(test))
26+
27+
tests.append(SingleTest(test))
6928

7029
self.groups.append({"title": title, "tests": tests})
7130
except yaml.YAMLError as e:
7231
raise Exception(f'Error parsing YAML: {e}')
7332
else:
7433
raise Exception(f'"{extension}" files are not supported as a test format.')
7534

35+
7636
class SingleTest:
7737
def __init__(self, test_dict: dict):
78-
self.response = test_dict.get("response", "")
7938
self.answer = test_dict.get("answer", "")
8039
self.params = test_dict.get("params", {})
81-
expected_result = test_dict.get("expected_result")
82-
if not expected_result:
83-
raise Exception("No expected result given for test")
84-
self.is_correct = expected_result.get("is_correct")
85-
self.results = expected_result
8640
self.desc = test_dict.get("description", "")
8741

88-
def evaluate(self, func) -> dict:
89-
return func(self.response, self.answer, self.params)
42+
self.sub_tests = []
43+
if "sub_tests" in test_dict:
44+
for sub_test in test_dict["sub_tests"]:
45+
expected_result = sub_test.get("expected_result")
46+
if not expected_result:
47+
raise Exception("No expected result given for test")
48+
49+
self.sub_tests.append(SubTest(
50+
sub_test.get("description", ""),
51+
sub_test.get("response", ""),
52+
expected_result.get("is_correct"),
53+
expected_result,
54+
))
55+
else:
56+
expected_result = test_dict.get("expected_result")
57+
if not expected_result:
58+
raise Exception("No expected result given for test")
59+
60+
self.sub_tests.append(SubTest(
61+
"",
62+
test_dict.get("response", ""),
63+
expected_result.get("is_correct"),
64+
expected_result,
65+
))
66+
67+
def evaluate_all(self, func) -> list[dict]:
68+
return [func(test.response, self.answer, self.params) for test in self.sub_tests]
9069

91-
def compare(self, eval_result: dict) -> tuple[bool, str]:
92-
eval_correct = eval_result["is_correct"]
93-
94-
if eval_correct != self.is_correct:
95-
return (
96-
False,
97-
f"response \"{self.response}\" with answer \"{self.answer}\" was {'' if eval_correct else 'in'}correct: {eval_result['feedback']}\nTest description: {self.desc}"
98-
)
99-
100-
# Are there any other fields in the eval function result that need to be checked?
101-
if self.results != None:
102-
# Check each one in turn
103-
for key, value in self.results.items():
104-
actual_result_val = eval_result.get(key)
105-
if actual_result_val == None:
106-
return (False, f"No value returned for \"{key}\"")
70+
def compare_all(self, eval_results: list[dict]) -> tuple[bool, str]:
71+
for i, eval_result in enumerate(eval_results):
72+
eval_correct = eval_result["is_correct"]
10773

108-
if actual_result_val != value:
109-
return (
110-
False,
111-
f"expected {key} = \"{value}\", got {key} = \"{actual_result_val}\"\nTest description: {self.desc}"
112-
)
74+
if eval_correct != self.sub_tests[i].is_correct:
75+
return (
76+
False,
77+
(f"response \"{self.sub_tests[i].response}\" with answer "
78+
f"\"{self.answer}\" was {'' if eval_correct else 'in'}correct: "
79+
f"{eval_result['feedback']}\nTest description: {self.sub_tests[i].desc}")
80+
)
81+
82+
# Are there any other fields in the eval function result that need to be checked?
83+
if self.sub_tests[i].expected_result != None:
84+
# Check each one in turn
85+
for key, value in self.sub_tests[i].expected_result.items():
86+
actual_result_val = eval_result.get(key)
87+
if actual_result_val == None:
88+
return (False, f"No value returned for \"{key}\"")
89+
90+
if actual_result_val != value:
91+
return (
92+
False,
93+
f"expected {key} = \"{value}\", got {key} = \"{actual_result_val}\"\nTest description: {self.desc}"
94+
)
11395

11496
return (True, "")
11597

98+
@dataclass
99+
class SubTest:
100+
desc: str
101+
response: str
102+
is_correct: bool
103+
expected_result: dict
104+
116105

117106
def auto_test(path, func):
118107
"""A decorator that adds the necessary infrastructure to run tests defined
@@ -124,16 +113,18 @@ def _auto_test(orig_class):
124113
def test_auto(self):
125114
# Creating a TestFile can fail for several reasons.
126115
# If so, an exception is raised with a suitable error message
116+
tests = {}
127117
try:
128-
tests = TestFile(path)
118+
with open(path, "r") as f:
119+
tests = TestFile(f.read(), path)
129120
except Exception as e:
130121
self.fail(e)
131122

132123
# Successfully loaded
133124
for group in tests.groups:
134125
for test in group["tests"]:
135-
results = test.evaluate(func)
136-
self.assertTrue(*test.compare(results.to_dict()))
126+
results = test.evaluate_all(func)
127+
self.assertTrue(*test.compare_all(map(lambda r: r.to_dict(), results)))
137128

138129
orig_class.test_auto = test_auto # Add the test_auto function to the class
139130
return orig_class

0 commit comments

Comments
 (0)