# testcase_utils.py
import ast
import json
import re
from typing import Any, Dict, List, Optional

from loguru import logger
from tqdm import tqdm


def get_parameter_names(prompt: str, entry_point: str) -> List[str]:
    """
    Extract parameter names from the function signature in the prompt.
    """
    tree = ast.parse(prompt)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == entry_point:
            # Return the parameter names from the function definition that
            # matches the entry point.
            return [param.arg for param in node.args.args]
    return []
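

# Hedged usage sketch for get_parameter_names (the prompt is made up, not
# taken from the dataset):
def _demo_get_parameter_names() -> None:
    example_prompt = "def add(x, y):\n    return x + y\n"
    assert get_parameter_names(example_prompt, "add") == ["x", "y"]
    # A def whose name does not match the entry point is ignored.
    assert get_parameter_names(example_prompt, "sub") == []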


def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Parse the test string into a structured format using AST.

    Returns {"test_cases": [{"input": {...}, "expected_output": ...}, ...]}.
    """
# Remove the METADATA section
test = re.sub(r'METADATA = \{[^}]*\}', '', test)
# Parse the entire test string
tree = ast.parse(test)
test_cases = []
for node in ast.walk(tree):
if isinstance(node, ast.Assert):
# Process each assert statement
test_case = process_assert(node, entry_point, parameter_names)
if test_case:
test_cases.append(test_case)
return {"test_cases": test_cases}


def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Optional[Dict[str, Any]]:
    """
    Process a single assert statement and extract its input and expected output.

    Only the form `assert candidate(...) == <expr>` is handled; anything else
    yields None. `entry_point` is unused here because the tests are rewritten
    to call `candidate` before parsing.
    """
if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Eq):
left = node.test.left
right = node.test.comparators[0]
if isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id == "candidate":
            input_dict = process_input(left.args, parameter_names)
            try:
                # Attempt to evaluate with literal_eval for plain literals.
                expected_output = ast.literal_eval(right)
            except (ValueError, TypeError):
                # Fall back to eval for non-literal expressions (e.g. 2 ** 3);
                # acceptable only because the test strings are trusted.
                expected_output = eval(compile(ast.Expression(body=right), filename="<ast>", mode="eval"))
            return {"input": input_dict, "expected_output": expected_output}
return None
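

# Hedged usage sketch for process_assert ("cube" and "n" are made-up names):
# `2 ** 3` is not a literal, so literal_eval raises and the eval fallback
# computes 8.
def _demo_process_assert() -> None:
    node = ast.parse("assert candidate(2) == 2 ** 3").body[0]
    assert process_assert(node, "cube", ["n"]) == {"input": {"n": 2}, "expected_output": 8}

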
def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]:
"""
Process the input arguments and match them with parameter names.
"""
input_dict = {}
for i, arg in enumerate(args):
        try:
            # Attempt to evaluate with literal_eval for plain literals.
            evaluated_arg = ast.literal_eval(arg)
        except (ValueError, TypeError):
            # Fall back to eval for non-literal argument expressions;
            # acceptable only because the test strings are trusted.
            evaluated_arg = eval(compile(ast.Expression(body=arg), filename="<ast>", mode="eval"))
if i < len(parameter_names):
input_dict[parameter_names[i]] = evaluated_arg
else:
# Handle extra arguments if any
input_dict[f"arg_{i}"] = evaluated_arg
return input_dict
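

# Hedged usage sketch for process_input: arguments beyond the known
# parameter names fall back to positional "arg_<i>" keys.
def _demo_process_input() -> None:
    call = ast.parse("candidate(1, 2, 3)", mode="eval").body
    assert process_input(call.args, ["x", "y"]) == {"x": 1, "y": 2, "arg_2": 3}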


def parse_all_problems(problems):
    """Parse every problem once, counting parse successes and failures."""
    success_count = 0
    unhandled_failures = 0
    for problem in problems:
        try:
            problem = json.loads(problem)
            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            # Rewrite the entry point to `candidate` so the given tests match
            # the form process_assert expects.
            given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # Parse the hidden test cases using the parameter names.
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem["is_solved"]:
                unhandled_failures += 1
            continue
    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {unhandled_failures}")
def parse_specific_problem(problem):
try:
if isinstance(problem, str):
problem = json.loads(problem)
logger.info(f"Problem: {problem}")
logger.debug(f"Test: {problem['test']}")
logger.debug(f"Given Test: {problem['given_tests']}")
entry_point = problem["entry_point"]
parameter_names = get_parameter_names(problem["prompt"], entry_point)
logger.debug(f"Parameter names: {parameter_names}")
given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
logger.debug(f"Given tests: {given_tests}")
# Parse the test cases using the parameter names
all_tests = parse_tests(problem["test"], parameter_names, entry_point)
logger.debug(f"Parsed tests: {all_tests}")
return all_tests
    except Exception:
        logger.exception(f"Error processing problem {problem['task_id']}")
        return None


# Assert forms like the following are NOT handled by process_assert, which
# only recognizes plain `assert candidate(...) == <expr>` comparisons:
#   assert next_smallest([]) is None
#   assert decode_cyclic(encode_cyclic("abc")) == "abc"
#   assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0
#   assert abs(candidate(1.33) - 0.33) < 1e-6


def check_all_problems(problems):
problems_q = []
success_count = 0
fail_count = 0
for problem in tqdm(problems):
try:
problem = json.loads(problem)
logger.info(f"Problem: {problem}")
logger.debug(f"Test: {problem['test']}")
logger.debug(f"All Test: {problem['given_tests']}")
entry_point = problem["entry_point"]
parameter_names = get_parameter_names(problem["prompt"], entry_point)
logger.info(f"Parameter names: {parameter_names}")
# given_tests_len = len(problem["given_tests"])
# given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
# given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
# parsed_given_tests_len = len(given_tests['test_cases'])
# assert given_tests_len == parsed_given_tests_len
# success_count += 1
            # Parse the test cases using the parameter names.
            # Sanity check: "candidate" occurs once in the `def check(candidate):`
            # signature and once per assert, hence the off-by-one below.
            tests_len_candidate = problem["test"].count('candidate')
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            parsed_test_len = len(parsed_tests['test_cases'])
            assert tests_len_candidate - 1 == parsed_test_len
logger.info(f"Parsed tests: {parsed_tests}")
success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                fail_count += 1
                problems_q.append(problem['task_id'])
            continue
    with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f:
        fixed = f.readlines()
    for fix_problem in fixed:
        fix_problem = json.loads(fix_problem)
        if fix_problem['task_id'] in problems_q:
            # Flag failing problems that also appear in the debugged seeds.
            logger.info(f"Failing problem {fix_problem['task_id']} is present in the debugged seeds")
logger.info(f"Success count: {success_count}")
logger.info(f"Total problems: {len(problems)}")
logger.info(f"Unhandled failures: {fail_count}")
if __name__ == "__main__":
input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl"
with open(input_seeds, "r") as f:
problems = f.readlines()
check_all_problems(problems)
    # parse_all_problems(problems)
    # Parse a single problem, e.g. the one with 'task_id': 'HumanEval/33':
    # for problem in problems:
    #     problem = json.loads(problem)
    #     if problem['task_id'] == 'HumanEval/33':
    #         parsed_tests = parse_specific_problem(problem)
    #         break