Skip to content

Commit

Permalink
feat: update JSONParseEvaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
gwkrsrch committed Aug 23, 2022
1 parent d2fd95a commit 86bcafe
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions donut/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso

class JSONParseEvaluator:
"""
- Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
+ Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score
"""

@staticmethod
def flatten(data: dict):
"""
Convert Dictionary into Non-nested Dictionary
Example:
input(dict)
{
Expand All @@ -153,13 +152,15 @@ def flatten(data: dict):
{"name" : ["juice"], "count" : ["1"]},
]
}
- output(dict)
- {
- "menu.name": ["cake", "juice"],
- "menu.count": ["2", "1"],
- }
+ output(list)
+ [
+ ("menu.name", "cake"),
+ ("menu.count", "2"),
+ ("menu.name", "juice"),
+ ("menu.count", "1"),
+ ]
"""
- flatten_data = defaultdict(list)
+ flatten_data = list()

def _flatten(value, key=""):
if type(value) is dict:
Expand All @@ -169,10 +170,10 @@ def _flatten(value, key=""):
for value_item in value:
_flatten(value_item, key)
else:
- flatten_data[key].append(value)
+ flatten_data.append((key, value))

_flatten(data)
- return dict(flatten_data)
+ return flatten_data

@staticmethod
def update_cost(label1: str, label2: str):
Expand Down Expand Up @@ -225,10 +226,11 @@ def normalize_dict(self, data: Union[Dict, List, Any]):
elif isinstance(data, list):
if all(isinstance(item, dict) for item in data):
new_data = []
- for item in sorted(data, key=lambda x: str(sorted(x.items()))):
+ for item in data:
item = self.normalize_dict(item)
if item:
new_data.append(item)
+ new_data = sorted(new_data, key=lambda x: str(x.keys())+str(x.values()))
else:
new_data = sorted([str(item) for item in data if type(item) in {str, int, float} and str(item)])
else:
Expand All @@ -243,14 +245,14 @@ def cal_f1(self, preds: List[dict], answers: List[dict]):
total_tp, total_fn_or_fp = 0, 0
for pred, answer in zip(preds, answers):
pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer))
- for pred_key, pred_values in pred.items():
- for pred_value in pred_values:
- if pred_key in answer and pred_value in answer[pred_key]:
- answer[pred_key].remove(pred_value)
- total_tp += 1
- else:
- total_fn_or_fp += 1
- return total_tp / (total_tp + (total_fn_or_fp) / 2)
+ for field in pred:
+ if field in answer:
+ total_tp += 1
+ answer.remove(field)
+ else:
+ total_fn_or_fp += 1
+ total_fn_or_fp += len(answer)
+ return total_tp / (total_tp + total_fn_or_fp / 2)

def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None):
"""
Expand Down

0 comments on commit 86bcafe

Please sign in to comment.