1
- import json
2
1
import yaml
3
2
from typing import Union
3
+ from dataclasses import dataclass
4
4
5
5
class TestFile :
6
- """An abstraction over a test file, which may be in one of several different formats .
7
- Currently, JSON and YAML are supported.
6
+ """An abstraction over a test file.
7
+ Currently, only YAML files are supported.
8
8
"""
9
9
10
- def __init__ (self , path : str ) -> None :
10
+ def __init__ (self , file_content : str , file_name : str ) -> None :
11
11
self .groups = []
12
12
13
- # Attempt to open the given file. Exit with an error if this
14
- # is not possible.
15
- file_content = ""
16
- try :
17
- with open (path , "r" ) as test_file :
18
- file_content = test_file .read ()
19
- except IOError as e :
20
- raise Exception (f'Failed to open test file: "{ e } "' )
21
-
22
13
# Get the file extension to determine which format should be used.
23
- extension = path .split ("." )[- 1 ]
24
- if extension == "json" :
25
- try :
26
- questions = json .loads (file_content )
27
-
28
- for question in questions :
29
- out = []
30
- title = question ["title" ]
31
- for part in question ["parts" ]:
32
- for response_area in part ["responseAreas" ]:
33
- params = response_area ["params" ]
34
- answer = response_area ["answer" ]
35
- for test in response_area ["tests" ]:
36
- test .update ({"answer" : answer })
37
- test .update ({"params" : params })
38
- out .append (SingleTest (test ))
39
- self .groups .append ({"title" : title , "tests" : out })
40
-
41
- except KeyError as e :
42
- raise Exception (f'The key "{ e .args [0 ]} " doesn\' t exist, or is in the wrong place.' )
43
- except json .JSONDecodeError as e :
44
- raise Exception (f'Error parsing JSON: "{ e } "' )
45
- elif extension == "yaml" :
14
+ extension = file_name .split ("." )[- 1 ]
15
+ if extension == "yaml" :
46
16
try :
47
17
# Tests are organised in groups of separate YAML documents (separated by "---")
48
18
docs = yaml .safe_load_all (file_content )
@@ -53,66 +23,85 @@ def __init__(self, path: str) -> None:
53
23
# Add an empty params field if none was provided.
54
24
if test .get ("params" ) == None :
55
25
test ["params" ] = {}
56
-
57
- # Does this test have sub-tests?
58
- sub_tests = test .get ("sub_tests" )
59
- if sub_tests != None :
60
- params = test ["params" ]
61
- answer = test ["answer" ]
62
-
63
- for sub_test in sub_tests :
64
- sub_test ["params" ] = params
65
- sub_test ["answer" ] = answer
66
- tests .append (SingleTest (sub_test ))
67
- else :
68
- tests .append (SingleTest (test ))
26
+
27
+ tests .append (SingleTest (test ))
69
28
70
29
self .groups .append ({"title" : title , "tests" : tests })
71
30
except yaml .YAMLError as e :
72
31
raise Exception (f'Error parsing YAML: { e } ' )
73
32
else :
74
33
raise Exception (f'"{ extension } " files are not supported as a test format.' )
75
34
35
+
76
36
class SingleTest:
    """One test case: a shared answer and params, checked against one or more responses.

    Each response to check is stored as a ``SubTest`` in ``self.sub_tests``.
    When the test dictionary has a ``sub_tests`` list, one ``SubTest`` is built
    per entry; otherwise the dictionary itself is treated as a single sub-test.
    """

    def __init__(self, test_dict: dict):
        """Build the test from a parsed test dictionary.

        Args:
            test_dict: Mapping that may contain ``answer``, ``params``,
                ``description`` and either a ``sub_tests`` list or top-level
                ``response`` / ``expected_result`` fields.

        Raises:
            Exception: If a (sub-)test has no ``expected_result``.
        """
        self.answer = test_dict.get("answer", "")
        self.params = test_dict.get("params", {})
        self.desc = test_dict.get("description", "")

        self.sub_tests = []
        # A missing key and an explicit null (`sub_tests:` with no value in
        # YAML) both mean "no sub-tests"; iterating None would raise TypeError.
        raw_sub_tests = test_dict.get("sub_tests")
        if raw_sub_tests is not None:
            for raw_sub_test in raw_sub_tests:
                self.sub_tests.append(
                    self._build_sub_test(
                        raw_sub_test, raw_sub_test.get("description", "")
                    )
                )
        else:
            # A plain test carries its response/expected_result at top level.
            self.sub_tests.append(self._build_sub_test(test_dict, ""))

    @staticmethod
    def _build_sub_test(source: dict, desc: str) -> "SubTest":
        """Validate ``source`` and turn it into a ``SubTest`` described by ``desc``."""
        expected_result = source.get("expected_result")
        if not expected_result:
            raise Exception("No expected result given for test")

        return SubTest(
            desc,
            source.get("response", ""),
            expected_result.get("is_correct"),
            expected_result,
        )

    def evaluate_all(self, func) -> list[dict]:
        """Call ``func(response, answer, params)`` once per sub-test.

        Returns:
            The values returned by ``func``, in sub-test order.
        """
        return [func(test.response, self.answer, self.params) for test in self.sub_tests]

    def compare_all(self, eval_results: list[dict]) -> tuple[bool, str]:
        """Check evaluation results against each sub-test's expected result.

        Args:
            eval_results: One result dictionary per sub-test, in sub-test
                order. Any iterable is accepted (it is materialised first).

        Returns:
            ``(True, "")`` when every result matches, otherwise ``(False, msg)``
            describing the first mismatch found.
        """
        # Materialise so a lazy iterable (e.g. a `map` object) can be counted.
        eval_results = list(eval_results)
        if len(eval_results) != len(self.sub_tests):
            # Pairing by position is meaningless when the counts differ:
            # extra results used to raise IndexError, missing ones passed
            # silently.
            return (
                False,
                f"expected {len(self.sub_tests)} evaluation results, got {len(eval_results)}"
            )

        for sub_test, eval_result in zip(self.sub_tests, eval_results):
            # Plain tests have an empty sub-test description, so fall back to
            # the description of the test itself.
            desc = sub_test.desc or self.desc
            eval_correct = eval_result["is_correct"]

            if eval_correct != sub_test.is_correct:
                return (
                    False,
                    (f"response \"{sub_test.response}\" with answer "
                     f"\"{self.answer}\" was {'' if eval_correct else 'in'}correct: "
                     f"{eval_result['feedback']}\nTest description: {desc}")
                )

            # Every other field in the expected result must be returned by
            # the evaluation function with exactly the expected value.
            for key, value in (sub_test.expected_result or {}).items():
                actual_result_val = eval_result.get(key)
                if actual_result_val is None:
                    return (False, f"No value returned for \"{key}\"")

                if actual_result_val != value:
                    return (
                        False,
                        f"expected {key} = \"{value}\", got {key} = \"{actual_result_val}\"\nTest description: {desc}"
                    )

        return (True, "")
115
97
98
@dataclass
class SubTest:
    """A single response to evaluate, together with the outcome expected for it.

    Instances are created by ``SingleTest`` with positional arguments in field
    order, so the order of the fields below must not change.
    """

    # Description shown in failure messages; may be an empty string.
    desc: str
    # Response passed as the first argument to the evaluation function.
    response: str
    # Expected value of the "is_correct" entry in the evaluation result.
    is_correct: bool
    # Full expected-result mapping; every key/value pair in it is compared
    # against the evaluation result.
    expected_result: dict
104
+
116
105
117
106
def auto_test (path , func ):
118
107
"""A decorator that adds the necessary infrastructure to run tests defined
@@ -124,16 +113,18 @@ def _auto_test(orig_class):
124
113
def test_auto (self ):
125
114
# Creating a TestFile can fail for several reasons.
126
115
# If so, an exception is raised with a suitable error message
116
+ tests = {}
127
117
try :
128
- tests = TestFile (path )
118
+ with open (path , "r" ) as f :
119
+ tests = TestFile (f .read (), path )
129
120
except Exception as e :
130
121
self .fail (e )
131
122
132
123
# Successfully loaded
133
124
for group in tests .groups :
134
125
for test in group ["tests" ]:
135
- results = test .evaluate (func )
136
- self .assertTrue (* test .compare ( results .to_dict ()))
126
+ results = test .evaluate_all (func )
127
+ self .assertTrue (* test .compare_all ( map ( lambda r : r .to_dict (), results )))
137
128
138
129
orig_class .test_auto = test_auto # Add the test_auto function to the class
139
130
return orig_class
0 commit comments