-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path eval.py
30 lines (22 loc) · 1011 Bytes
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
import os
import paddle
from model import Model
from evaluator import Evaluator
def _build_parser():
    """Return the command-line parser for checkpoint evaluation.

    Options:
        -d / --data_dir: directory containing the LMDB data (default ``./data``).
        checkpoint: positional path of the parameter file to evaluate.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-d', '--data_dir',
        default='./data',
        help='directory to read LMDB files',
    )
    cli.add_argument(
        'checkpoint',
        type=str,
        help='path to evaluate checkpoint, e.g. ./logs/model-100.pdparams',
    )
    return cli


parser = _build_parser()
def _eval(path_to_checkpoint_file, path_to_eval_lmdb_dir):
    """Load a checkpoint into a fresh ``Model``, evaluate it, and report accuracy.

    Args:
        path_to_checkpoint_file: Path of the saved parameters, read with
            ``paddle.load`` (e.g. ``./logs/model-100.pdparams``).
        path_to_eval_lmdb_dir: LMDB directory handed to ``Evaluator``.

    Returns:
        The value returned by ``Evaluator.evaluate`` (printed as well), so
        callers can use the accuracy programmatically.
    """
    model = Model()
    param_state_dict = paddle.load(path_to_checkpoint_file)
    model.set_dict(param_state_dict)
    # NOTE(review): nothing here switches the model to inference mode; confirm
    # that Evaluator.evaluate calls model.eval() before running the data.
    accuracy = Evaluator(path_to_eval_lmdb_dir).evaluate(model)
    print('Evaluate %s on %s, accuracy = %f' % (path_to_checkpoint_file, path_to_eval_lmdb_dir, accuracy))
    return accuracy
def main(args):
    """Run evaluation for the checkpoint and data directory given on the CLI.

    Args:
        args: Parsed namespace; ``args.data_dir`` is the directory that holds
            ``test.lmdb`` and ``args.checkpoint`` is the parameter file path.
    """
    test_lmdb = os.path.join(args.data_dir, 'test.lmdb')
    checkpoint_path = args.checkpoint

    print('Start evaluating')
    _eval(checkpoint_path, test_lmdb)
    print('Done')
if __name__ == '__main__':
    # Script entry point: parse sys.argv with the module-level parser, then run.
    cli_args = parser.parse_args()
    main(cli_args)