-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils_clr_downstream.py
115 lines (100 loc) · 4.62 KB
/
utils_clr_downstream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
from torch_geometric.data import InMemoryDataset
from torch_geometric import data as DATA
import torch
from creat_data_DC import smile_to_graph
import deepchem as dc
class TestbedDataset(InMemoryDataset):
    """Graph dataset for MoleculeNet downstream (classification) tasks.

    Loads a MoleculeNet benchmark through DeepChem, converts every SMILES
    string to a graph with ``smile_to_graph`` and caches one ``.pt`` file
    per split (train/valid/test) under ``processed_dir``.

    Args:
        root: Directory for raw/processed data (default ``'tmp'``).
        dataset: Which split to expose: ``'train'``, ``'valid'`` or ``'test'``.
        task: MoleculeNet task name; ``process()`` recognises
            'BBBP', 'Tox21', 'ClinTox', 'HIV', 'BACE', 'SIDER', 'MUV'.
            NOTE(review): the default ``'bbbp'`` matches none of those —
            callers must pass the capitalised name.
        transform, pre_transform: standard PyTorch-Geometric hooks.
    """

    # processed_paths index for each split name.
    _SPLIT_INDEX = {'train': 0, 'valid': 1, 'test': 2}

    def __init__(self, root='tmp', dataset='train', task='bbbp',
                 transform=None, pre_transform=None):
        # Bug fix: assign these BEFORE super().__init__() — the parent
        # constructor may consult processed_file_names, which reads self.task.
        self.dataset = dataset
        self.task = task
        super(TestbedDataset, self).__init__(root, transform, pre_transform)
        if dataset not in self._SPLIT_INDEX:
            # Robustness fix: the original silently left self.data unset
            # for an unknown split name.
            raise ValueError('Unknown split: {!r}'.format(dataset))
        idx = self._SPLIT_INDEX[dataset]
        # All three split files are written together by process(), so
        # checking the train file is sufficient.
        if os.path.isfile(self.processed_paths[0]):
            print('Pre-processed data found: {}, loading ...'.format(
                self.processed_paths[idx]))
        else:
            print('Pre-processed data {} not found, doing pre-processing...'.format(
                self.processed_paths[0]))
            self.process(root, task)
        self.data, self.slices = torch.load(self.processed_paths[idx])

    @property
    def raw_file_names(self):
        # Raw data is fetched by DeepChem inside process(); nothing to declare.
        pass

    @property
    def processed_file_names(self):
        # One cache file per split, namespaced by task.
        return [self.task + '_train.pt', self.task + '_valid.pt',
                self.task + '_test.pt']

    def download(self):
        # Download is delegated to DeepChem in process().
        pass

    def _download(self):
        pass

    def _process(self):
        # Override the PyG hook so the parent constructor only ensures the
        # output directory exists instead of triggering process() itself.
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)

    def process(self, root, task):
        """Featurize every split of ``task`` and cache each to disk.

        Raises:
            ValueError: if ``task`` is not a recognised MoleculeNet name.
        """
        splitter = 'scaffold'
        featurizer = 'ECFP'
        print(dc.__version__)
        # Dispatch table replaces the original if/elif chain.
        loaders = {
            'BBBP': dc.molnet.load_bbbp,
            'Tox21': dc.molnet.load_tox21,
            'ClinTox': dc.molnet.load_clintox,
            'HIV': dc.molnet.load_hiv,
            'BACE': dc.molnet.load_bace_classification,
            'SIDER': dc.molnet.load_sider,
            'MUV': dc.molnet.load_muv,
        }
        if task not in loaders:
            # Bug fix: the original fell through the elif chain and crashed
            # with NameError on the undefined 'datasets' variable.
            raise ValueError('Unknown task: {!r}'.format(task))
        tasks, datasets, transformers = loaders[task](
            featurizer=featurizer, splitter=splitter)
        train, valid, test = datasets
        # Bug fix: save() is an instance method; the original called it as a
        # free function ('save(self, train, 0)'), a guaranteed NameError.
        self.save(train, 0)
        self.save(valid, 1)
        self.save(test, 2)

    def save(self, dataset, path):
        """Convert one DeepChem split into PyG ``Data`` objects and save it.

        Args:
            dataset: DeepChem dataset exposing ``.ids`` (SMILES strings)
                and ``.y`` (label arrays) — assumed aligned; TODO confirm.
            path: Index into ``self.processed_paths``
                (0=train, 1=valid, 2=test).
        """
        data_list = []
        for smile, label in zip(dataset.ids, dataset.y):
            # Very short strings are unlikely to be valid molecules; skip.
            if len(smile) <= 5:
                continue
            print('smiles', smile)
            c_size, features, edge_index, atoms = smile_to_graph(smile)
            if len(edge_index) > 0:
                # PyG wants edge_index shaped (2, num_edges).
                edge_index = torch.LongTensor(edge_index).transpose(1, 0)
            else:
                # No bonds: keep an empty LongTensor so PyG still accepts it.
                edge_index = torch.LongTensor(edge_index)
            graph = DATA.Data(x=torch.Tensor(features),
                              edge_index=edge_index,
                              y=torch.Tensor(label),
                              smiles=smile)
            data_list.append(graph)
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        print('Graph construction done. Saving to file.')
        data, slices = self.collate(data_list)
        # Cache the collated split for fast reload on the next run.
        torch.save((data, slices), self.processed_paths[path])
def save_AUCs(AUCs, filename):
    """Append one tab-separated row of AUC metrics to *filename*."""
    row = '\t'.join(str(value) for value in AUCs)
    with open(filename, 'a') as log_file:
        log_file.write(row + '\n')