-
Notifications
You must be signed in to change notification settings - Fork 1
/
distanceregression_task_example.py
161 lines (127 loc) · 4.51 KB
/
distanceregression_task_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python
# coding: utf-8
# %% Distance Regression with CHILI-3K using GCN model
# %% Imports
import warnings
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN
from benchmark.dataset_class import CHILI
# %% Model Setup
# Hyperparameters
learning_rate = 0.001
batch_size = 16
max_epochs = 10
seeds = 42
max_patience = 50  # Epochs without validation improvement before early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNG for reproducibility.
# Fix: `seeds` was previously defined but never used, so runs were not reproducible.
torch.manual_seed(seeds)

# Model and Optimizer
# in_channels = 7: presumably node features concatenated with absolute positions
# (see the torch.cat((data.x, data.pos_abs), dim=1) calls below) — TODO confirm
# the 7-column split against the CHILI dataset class.
model = GCN(in_channels=7, hidden_channels=32, out_channels=1, num_layers=4).to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# %% Dataset Module
# Create dataset.
# Fix: the original reused the name `dataset` for both the dataset-name string
# and the dataset object; use a distinct `dataset_name` for clarity.
root = 'benchmark/dataset/'
dataset_name = 'CHILI-3K'
dataset = CHILI(root, dataset_name)
print(f'Running DistanceRegression example on {dataset}\n', flush=True)

# Create a random split and load it into the dataset class.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')  # split creation/loading is noisy; silence it
    dataset.create_data_split(split_strategy='random', test_size=0.1)
    dataset.load_data_split(split_strategy='random')

# Define dataloaders; only the training set is shuffled.
train_loader = DataLoader(dataset.train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset.validation_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset.test_set, batch_size=batch_size, shuffle=False)
print(f"Number of training samples: {len(dataset.train_set)}", flush=True)
print(f"Number of validation samples: {len(dataset.validation_set)}", flush=True)
print(f"Number of test samples: {len(dataset.test_set)}\n", flush=True)
# %% Train, validate and test
# Initialise loss function and metric function.
loss_function = nn.SmoothL1Loss()  # robust training loss
metric_function = nn.MSELoss()     # reported evaluation metric


def improved_function(best, new):
    """Return True when ``new`` improves on ``best`` (lower is better).

    ``best`` is None before the first validation epoch, in which case any
    score counts as an improvement.
    (Fix: PEP 8 E731 — a named ``def`` replaces the original assigned lambda.)
    """
    return True if best is None else new < best
# Training & Validation
patience = 0       # epochs since the last validation improvement
best_error = None  # best (lowest) validation MSE seen so far

for epoch in range(max_epochs):
    # Early stopping: give up after max_patience epochs without improvement.
    if patience >= max_patience:
        print("Max Patience reached, quitting...", flush=True)
        break

    # Training loop
    model.train()
    train_loss = 0
    for data in train_loader:
        # Send to device
        data = data.to(device)

        # Forward pass.
        # Fix: call the module (`model(...)`) instead of `model.forward(...)`
        # so nn.Module.__call__ runs and registered hooks are honoured.
        node_emb = model(
            x=torch.cat((data.x, data.pos_abs), dim=1),
            edge_index=data.edge_index,
            edge_attr=None,
            edge_weight=None,
            batch=data.batch,
        )
        # Predict each edge's distance as the dot product of its two endpoint
        # embeddings; ground truth is the edge attribute.
        pred = torch.sum(node_emb[data.edge_index[0, :]] * node_emb[data.edge_index[1, :]], dim=-1)
        truth = data.edge_attr
        loss = loss_function(pred, truth)

        # Back-propagate and step the optimizer.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Mean training loss over batches
    train_loss = train_loss / len(train_loader)

    # Validation loop
    model.eval()
    val_error = 0
    for data in val_loader:
        # Send to device
        data = data.to(device)
        # Forward pass without gradient tracking.
        with torch.no_grad():
            node_emb = model(
                x=torch.cat((data.x, data.pos_abs), dim=1),
                edge_index=data.edge_index,
                edge_attr=None,
                edge_weight=None,
                batch=data.batch,
            )
            pred = torch.sum(node_emb[data.edge_index[0, :]] * node_emb[data.edge_index[1, :]], dim=-1)
            truth = data.edge_attr
            metric = metric_function(pred, truth)
        # Aggregate errors
        val_error += metric.item()
    val_error = val_error / len(val_loader)

    # Track the best validation error for early stopping.
    if improved_function(best_error, val_error):
        best_error = val_error
        patience = 0
    else:
        patience += 1

    # Print checkpoint
    print(f'Epoch: {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val MSE: {val_error:.4f}')
# Testing loop
model.eval()
test_error = 0
for data in test_loader:
    # Send to device
    data = data.to(device)
    # Forward pass without gradient tracking.
    # Fix: call the module (`model(...)`) instead of `model.forward(...)`
    # so nn.Module.__call__ runs and registered hooks are honoured.
    with torch.no_grad():
        node_emb = model(
            x=torch.cat((data.x, data.pos_abs), dim=1),
            edge_index=data.edge_index,
            edge_attr=None,
            edge_weight=None,
            batch=data.batch,
        )
        # Edge distance prediction = dot product of endpoint embeddings.
        pred = torch.sum(node_emb[data.edge_index[0, :]] * node_emb[data.edge_index[1, :]], dim=-1)
        truth = data.edge_attr
        metric = metric_function(pred, truth)
    # Aggregate errors
    test_error += metric.item()

# Final mean test error over batches
test_error = test_error / len(test_loader)
print(f"Test MSE: {test_error:.4f}")