Hi,
I have been experimenting with `ti.ad.Tape` and `kernel.grad`.
Am I doing it the wrong way, or is there a bug?
The following two chunks of code should give the same results, but they don't...
**Using `with ti.ad.Tape(loss)`**
```python
import taichi as ti

ti.init()

max_t = 4
theta = ti.field(dtype=ti.f32, shape=())
loss = ti.field(dtype=ti.f32, shape=())
z = ti.field(dtype=ti.f32, shape=(max_t,))
z_ = ti.field(dtype=ti.f32, shape=())
ti.root.lazy_grad()  # allocate grad fields for all fields above

@ti.kernel
def compute(t: ti.i32):
    z[t + 1] = z[t] + theta[None] * 2

@ti.kernel
def loss_func(t: ti.i32):
    loss[None] = ti.abs(z[t] - z_[None])

theta[None] = 2
z_[None] = 200
z[0] = 1

for k in range(10):
    loss[None] = 0
    with ti.ad.Tape(loss):
        for i in range(max_t - 1):
            compute(i)
        loss_func(max_t - 1)
    print(f'loss = {loss[None]: 0.3f}, theta.grad = {theta.grad[None]}')
    theta[None] -= 0.2 * theta.grad[None]
    print(f'new theta = {theta[None]: 0.3f}')
```
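For context, my mental model (from the Taichi autodiff docs) is that `ti.ad.Tape` clears the gradient fields on entry, records every kernel launch inside the `with` block, and on exit seeds `loss.grad[None] = 1` and replays each recorded launch's `grad` kernel in reverse order. Here is a minimal sketch of what I expect the tape-based loop body above to expand to; the `clear_all_gradients` call and the exact replay order are my assumptions, not something I verified in the source:

```python
# Sketch of my understanding of what `with ti.ad.Tape(loss):` does above.
# Assumptions: grads are cleared on entry, and each recorded forward launch
# gets exactly one grad replay, in reverse launch order.
ti.ad.clear_all_gradients()   # on entering the tape (assumed)
for i in range(max_t - 1):    # forward launches are recorded
    compute(i)
loss_func(max_t - 1)
loss.grad[None] = 1           # on exit: seed d(loss)/d(loss) = 1
loss_func.grad(max_t - 1)     # replay grad kernels in reverse order
for i in range(max_t - 2, -1, -1):
    compute.grad(i)
```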
**Using `kernel.grad()`**
```python
for k in range(10):
    loss[None] = 0
    loss.grad[None] = 1  # seed d(loss)/d(loss) by hand
    # forward pass
    for i in range(max_t - 1):
        compute(i)
    loss_func(max_t - 1)
    # backward pass
    loss_func.grad(max_t - 1)
    for i in range(max_t - 1, -1, -1):
        compute.grad(i)
    print(f'loss = {loss[None]: 0.3f}, theta.grad = {theta.grad[None]}')
    theta[None] -= 0.2 * theta.grad[None]
    print(f'new theta = {theta[None]: 0.3f}')
    # clear gradients before the next iteration
    theta.grad[None] = 0
    z.grad.fill(0)
    z_.grad[None] = 0
```
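For reference, this is how I assume a manual backward pass should correspond to the forward launches: one `kernel.grad(args)` call per matching forward `kernel(args)` call, in reverse launch order. A sketch under that assumption (note it visits one fewer index than the backward loop above, which also calls `compute.grad(max_t - 1)`):

```python
# Sketch: backward replay that mirrors the forward pass one-to-one.
# Forward ran compute(i) for i = 0 .. max_t - 2, so the reverse replay
# visits i = max_t - 2 .. 0 (assumption: one grad call per forward call).
loss.grad[None] = 1
loss_func.grad(max_t - 1)
for i in range(max_t - 2, -1, -1):
    compute.grad(i)
```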
Any thoughts?