-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathhybrid_gpu_pipeline.py
118 lines (92 loc) · 3.07 KB
/
hybrid_gpu_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Hybrid GPU pipeline: the quantum part runs in TensorFlow (or JAX) and the
neural part runs in PyTorch, with both parts running on the GPU.
"""
import os
# Set before `import tensorflow` so the flag is already in the environment when
# TensorFlow sets up its GPU devices. With this flag TensorFlow grows its GPU
# memory allocation on demand instead of reserving it up front; presumably
# this is what leaves GPU memory available for PyTorch in the same process.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
import time
import numpy as np
import tensorflow as tf
import torch
import tensorcircuit as tc
# Select TensorFlow as the tensorcircuit backend; K is the returned backend
# object and is used later for K.stack / K.real inside the circuit function.
K = tc.set_backend("tensorflow")
# Run the PyTorch side on CUDA when it is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Forwarded to tc.TorchLayer as its `enable_dlpack` argument.
# NOTE(review): presumably switches the torch <-> TensorFlow tensor exchange
# to DLPack -- confirm against the tensorcircuit documentation.
enable_dlpack = True
# enable_dlpack = False  # fallback for old versions of the ML libraries
# TensorFlow device string used by qpreds via tf.device(...).
tf_device = "/GPU:0"
# tf_device = "/device:CPU:0"  # run the TensorFlow part on the CPU instead
# Another scheme is to disable the GPU globally for TensorFlow only, see
# https://datascience.stackexchange.com/a/76039
# However, if TensorFlow's GPU support is fully shut down that way,
# enable_dlpack = True won't work.
# Dataset preparation: load the MNIST train/test splits through Keras.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist_train, mnist_test
# Append a trailing axis to the training images, then divide the pixel
# values by 255.0.
x_train = x_train[..., np.newaxis]
x_train = x_train / 255.0
def filter_pair(x, y, a, b):
    """Keep only the samples labelled ``a`` or ``b`` and binarize the labels.

    :param x: Array of samples, indexed along its first axis in step with ``y``.
    :param y: Array of labels, one entry per sample in ``x``.
    :param a: Label that becomes ``True`` in the returned labels.
    :param b: Label that becomes ``False`` in the returned labels.
    :return: Tuple ``(x, y)`` with the retained samples and a boolean array
        that is ``True`` exactly where the original label equals ``a``.
    """
    # Boolean mask selecting the two classes of interest.
    keep = (y == a) | (y == b)
    x, y = x[keep], y[keep]
    # Turn the two remaining labels into a binary target: a -> True, b -> False.
    y = y == a
    return x, y
# Circuit size: number of qubits (one per pixel of the 3x3 image) and the
# number of variational layers.
n = 9
nlayers = 3

# Keep only the digits 1 and 5; labels become True for 1 and False for 5.
x_train, y_train = filter_pair(x_train, y_train, 1, 5)

# Resize every image to 3x3, threshold at 0.5 into {0.0, 1.0} and flatten
# each image into a vector of 9 values.
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
above_threshold = x_train_small > 0.5
x_train_bin = np.array(above_threshold, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin).reshape([-1, 9])

# Build the torch tensors and move them to the selected torch device.
x_train_torch = torch.tensor(x_train_bin).to(device=device)
y_train_torch = torch.tensor(y_train, dtype=torch.float32).to(device=device)
# We define the quantum function,
# note how this function is running on tensorflow
def qpreds(x, weights):
    """Build the parameterized circuit and return the per-qubit Z expectations.

    The whole computation is placed on ``tf_device`` through ``tf.device``.

    :param x: Input angles; ``x[i]`` is used as the RX angle on qubit ``i``
        for ``i`` in ``range(n)``.
    :param weights: Trainable angles read as ``weights[2 * j, i]`` (RX) and
        ``weights[2 * j + 1, i]`` (RY) for layer ``j`` and qubit ``i``, i.e.
        compatible with the shape ``[2 * nlayers, n]``.
    :return: Backend tensor stacking the real part of the Pauli-Z expectation
        value of each of the ``n`` qubits.
    """
    with tf.device(tf_device):
        c = tc.Circuit(n)
        # Encode the input: one RX rotation per qubit.
        for i in range(n):
            c.rx(i, theta=x[i])
        for j in range(nlayers):
            # CNOTs between neighbouring qubits (i, i + 1).
            for i in range(n - 1):
                c.cnot(i, i + 1)
            # Trainable RX and RY rotations; rows 2j and 2j + 1 of weights.
            for i in range(n):
                c.rx(i, theta=weights[2 * j, i])
                c.ry(i, theta=weights[2 * j + 1, i])

        return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])
# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
# qpreds_batch = tc.interfaces.torch_interface(qpreds_vmap, jit=True, enable_dlpack=True)
# Wrap the TensorFlow circuit function qpreds as a torch layer with a
# trainable weight tensor of shape [2 * nlayers, n].
# NOTE(review): the use_vmap / use_interface / use_jit flags are passed
# through unchanged; see the tensorcircuit TorchLayer docs for their meaning.
quantumnet = tc.TorchLayer(
    qpreds,
    enable_dlpack=enable_dlpack,
    use_jit=True,
    use_interface=True,
    use_vmap=True,
    weights_shape=[2 * nlayers, n],
)

# Quantum layer -> linear map from 9 features to 1 -> sigmoid, moved to the
# selected torch device.
model = torch.nn.Sequential(
    quantumnet,
    torch.nn.Linear(9, 1),
    torch.nn.Sigmoid(),
).to(device=device)

# Binary cross-entropy on the sigmoid output, optimized with Adam.
criterion = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# Training schedule and the list collecting the measured time of every step.
nepochs, nbatch = 300, 32
times = []
# Training loop: every "epoch" is a single optimizer step on one random batch.
for epoch in range(nepochs):
    # NOTE(review): indices are drawn from [0, 100), so only the first 100
    # filtered samples are ever used -- presumably intentional to keep the
    # demo small; confirm before reusing this for real training.
    index = np.random.randint(low=0, high=100, size=nbatch)
    # index = np.arange(nbatch)
    inputs, labels = x_train_torch[index], y_train_torch[index]
    opt.zero_grad()

    with torch.set_grad_enabled(True):
        # CUDA kernels are launched asynchronously; synchronize around the
        # timed region so the measurement covers the GPU work of this step.
        if device.type == "cuda":
            torch.cuda.synchronize()
        # perf_counter is monotonic and meant for measuring short intervals.
        time0 = time.perf_counter()
        yps = model(inputs)
        loss = criterion(
            torch.reshape(yps, [nbatch, 1]), torch.reshape(labels, [nbatch, 1])
        )
        loss.backward()
        if epoch % 100 == 0:
            print(loss)
        opt.step()
        if device.type == "cuda":
            torch.cuda.synchronize()
        time1 = time.perf_counter()
        times.append(time1 - time0)

# The first step is left out of the average; presumably it includes one-off
# JIT compilation / warm-up cost.
print("training time per step: ", np.mean(times[1:]))