-
Notifications
You must be signed in to change notification settings - Fork 21
/
utils.py
28 lines (22 loc) · 883 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Utility Functions."""
import json
def load_dataset(path: str) -> list[dict]:
    """Load a dataset from a JSON or JSONL file.

    The format is chosen from the file name suffix: ``.json`` files are
    parsed as a single JSON document, ``.jsonl`` files are parsed as one
    JSON value per line.

    Args:
        path: Path of the file to read. Must end in ``.json`` or ``.jsonl``.

    Returns:
        The parsed content. For ``.json`` this is whatever the document
        decodes to; for ``.jsonl`` it is a list with one entry per
        non-blank line, in file order.

    Raises:
        ValueError: If ``path`` ends in neither ``.json`` nor ``.jsonl``.
    """
    if path.endswith(".json"):
        # Context manager guarantees the handle is closed even if parsing fails.
        with open(path, "r", encoding="utf-8") as reader:
            return json.load(reader)

    if path.endswith(".jsonl"):
        records = []
        with open(path, "r", encoding="utf-8") as reader:
            for line in reader:
                stripped = line.strip()
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing on json.loads("").
                if stripped:
                    records.append(json.loads(stripped))
        return records

    extension = path.split(".")[-1]
    raise ValueError(f"File extension [{extension}] not valid.")
def write_dataset(path: str, dataset: list[dict]) -> None:
    """Write a dataset to a JSON or JSONL file.

    The format is chosen from the file name suffix: ``.json`` files receive
    the whole dataset as a single JSON document, ``.jsonl`` files receive
    one JSON-encoded entry per line, each terminated by a newline.

    Args:
        path: Path of the file to write. Must end in ``.json`` or
            ``.jsonl``. An existing file is overwritten.
        dataset: The entries to serialize.

    Raises:
        ValueError: If ``path`` ends in neither ``.json`` nor ``.jsonl``.
    """
    if path.endswith(".json"):
        # Context manager ensures the data is flushed and the handle closed
        # before returning, rather than whenever the file object is collected.
        with open(path, "w", encoding="utf-8") as writer:
            json.dump(dataset, writer)
    elif path.endswith(".jsonl"):
        with open(path, "w", encoding="utf-8") as writer:
            writer.writelines(json.dumps(record) + "\n" for record in dataset)
    else:
        extension = path.split(".")[-1]
        raise ValueError(f"File extension [{extension}] not valid.")