-
Notifications
You must be signed in to change notification settings - Fork 2
/
cal_metrics_from_pt.py
52 lines (42 loc) · 2.19 KB
/
cal_metrics_from_pt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import numpy as np
import torch
from glob import glob
# Directory searched for metrics_*.pt files when no path is given.
DEFAULT_EVAL_PATH = '/path/to/eval_results/'

# Docking modes understood by collect_metrics(); each selects which docking
# affinities are read from a result's 'vina' entry.
DOCKING_MODES = ('qvina', 'vina_score', 'vina_dock')

# Metric keys produced by collect_metrics(), in the order they are reported.
_METRIC_NAMES = ('qed', 'sa', 'qvina', 'vina_score', 'vina_min', 'vina_dock')


def _validate_docking_mode(docking_mode):
    """Raise ValueError unless *docking_mode* is one of DOCKING_MODES."""
    if docking_mode not in DOCKING_MODES:
        raise ValueError(
            f"unknown docking_mode {docking_mode!r}; expected one of {DOCKING_MODES}")


def collect_metrics(results, docking_mode='vina_score'):
    """Extract per-molecule metric values from a list of evaluation results.

    Args:
        results: iterable of per-molecule result dicts (the ``'all_results'``
            entry of a ``metrics_*.pt`` file). Each must provide
            ``r['chem_results']['qed']`` and ``r['chem_results']['sa']``; the
            docking entries read from ``r['vina']`` depend on *docking_mode*.
        docking_mode: ``'qvina'`` reads ``r['vina'][0]['affinity']``;
            ``'vina_score'`` reads ``r['vina']['score_only'][0]['affinity']``
            and ``r['vina']['minimize'][0]['affinity']``; ``'vina_dock'``
            additionally reads ``r['vina']['dock'][0]['affinity']``.

    Returns:
        dict mapping each name in ``_METRIC_NAMES`` to a list of values. Lists
        for docking metrics not selected by *docking_mode* stay empty.

    Raises:
        ValueError: if *docking_mode* is not one of ``DOCKING_MODES``.
    """
    _validate_docking_mode(docking_mode)
    metrics = {name: [] for name in _METRIC_NAMES}
    for r in results:
        metrics['qed'].append(r['chem_results']['qed'])
        metrics['sa'].append(r['chem_results']['sa'])
        if docking_mode == 'qvina':
            metrics['qvina'].append(r['vina'][0]['affinity'])
        else:  # 'vina_score' or 'vina_dock'
            vina = r['vina']
            metrics['vina_score'].append(vina['score_only'][0]['affinity'])
            metrics['vina_min'].append(vina['minimize'][0]['affinity'])
            if docking_mode == 'vina_dock':
                metrics['vina_dock'].append(vina['dock'][0]['affinity'])
    return metrics


def summarize(values):
    """Return ``(mean, median)`` of *values* as floats, or ``(nan, nan)`` if empty."""
    if len(values) == 0:
        # np.mean / np.median of an empty sequence return nan but also emit
        # RuntimeWarnings; short-circuit to a quiet nan instead.
        return float('nan'), float('nan')
    return float(np.mean(values)), float(np.median(values))


def _report(name, values):
    """Print the mean and median of *values* under the ``<name>_all_*`` label."""
    mean, median = summarize(values)
    print(f"{name}_all_mean, {name}_all_median:", mean, median)


def main(eval_path=DEFAULT_EVAL_PATH, docking_mode='vina_score'):
    """Aggregate and print QED / SA / docking statistics over ``metrics_*.pt`` files.

    Args:
        eval_path: directory searched (non-recursively) for ``metrics_*.pt``.
        docking_mode: one of ``DOCKING_MODES``; see ``collect_metrics``.

    Raises:
        ValueError: if *docking_mode* is not one of ``DOCKING_MODES``.
    """
    _validate_docking_mode(docking_mode)
    # Sorted only for a deterministic load order; statistics are order-free.
    results_fn_list = sorted(glob(os.path.join(eval_path, 'metrics_*.pt')))
    print("num of results.pt: ", len(results_fn_list))
    if not results_fn_list:
        # Nothing to aggregate; every statistic would be nan.
        print(f"no metrics_*.pt files found in {eval_path!r}")
        return

    metrics = {name: [] for name in _METRIC_NAMES}
    for rfn in results_fn_list:
        # NOTE: torch.load unpickles the file -- only load trusted results.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True, which
        # rejects arbitrary Python objects; if these files contain any, pass
        # weights_only=False here.
        file_results = torch.load(rfn, map_location='cpu')['all_results']
        # Accumulate per file so each loaded file can be freed afterwards.
        for name, values in collect_metrics(file_results, docking_mode).items():
            metrics[name].extend(values)

    # QED and SA are always reported; docking metrics only when collected.
    for name in _METRIC_NAMES:
        if name in ('qed', 'sa') or metrics[name]:
            _report(name, metrics[name])


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Aggregate QED/SA/docking metrics over metrics_*.pt files.")
    parser.add_argument('eval_path', nargs='?', default=DEFAULT_EVAL_PATH,
                        help="directory containing metrics_*.pt files")
    parser.add_argument('--docking_mode', default='vina_score', choices=DOCKING_MODES,
                        help="which docking affinities to aggregate")
    args = parser.parse_args()
    main(args.eval_path, args.docking_mode)