Merge branch 'einsum' of https://github.com/Exusial/jittor
cjld committed Apr 5, 2022
2 parents 5a8b6c9 + 1aa77b9 commit 385ab26
Showing 2 changed files with 204 additions and 1 deletion.
72 changes: 71 additions & 1 deletion python/jittor/linalg.py
@@ -11,7 +11,6 @@
import jittor as jt
from functools import partial


#TODO:full_matrices=1
def svd(x):
r'''
@@ -430,3 +429,74 @@ def T(x):
        [backward_code],
    )
    return q, r


def einsum(string, *args):
r"""
do the einsum operation. Using the implementation in https://github.com/HIPS/autograd
:param string, args:
:return: return values depend on the input string kinds.
"""
import numpy as np_cpu
def forward_code(np, data):
out = data["outputs"][0]
npout = np.einsum(string, *data["inputs"])
np.copyto(out, npout)

    def backward_code(np, data, argnum=0):
        # Gradient w.r.t. operand `argnum`, following HIPS autograd: contract
        # the output gradient with the remaining operands, targeting the
        # subscripts of the differentiated operand.
        # Only the first `real_len` entries of `inputs` are einsum operands;
        # the trailing ones are appended by numpy_code for the backward pass.
        real_len = len(data["inputs"]) - 2
        operands = data["inputs"][:real_len]
        in_subs, out_subs, _ = np.core.einsumfunc._parse_einsum_input([string] + operands)
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][argnum]

        in_subs_list = in_subs.split(',')
        op_num = argnum
        subs_wrt = in_subs_list[op_num]
        rest_of_ops = operands[:op_num] + operands[op_num+1:]
        rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num+1:]
        # Indices that appear only in the differentiated operand were summed
        # out in the forward pass; contract with ones to restore those axes.
        other_named_subs = set(''.join([out_subs] + rest_of_subs))
        naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
                        if sub not in other_named_subs]
        if naked_summed:
            naked_summed_dims, ones_subs = zip(*naked_summed)
            ones_subs = ''.join(ones_subs)
            ones = np_cpu.ones(np_cpu.array(operands[op_num].shape)[list(naked_summed_dims)])
            new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
            new_operands = [dout, ones] + rest_of_ops
        else:
            new_input_subs = ','.join([out_subs] + rest_of_subs)
            new_operands = [dout] + rest_of_ops

        new_subscripts = new_input_subs + '->' + subs_wrt
        x = np.einsum(new_subscripts, *new_operands)
        # Sum away extra leading axes and broadcast (size-1) axes so the
        # gradient matches the shape of the original operand.
        while np.ndim(x) > np.ndim(inp):
            x = np.sum(x, axis=0)
        for axis, size in enumerate(inp.shape):
            if size == 1:
                x = np.sum(x, axis=axis, keepdims=True)
        np.copyto(out, x)

    def einsum_outshape(einsum_expr, inputs):
        # Infer the output shape by matching each output subscript to the
        # first input subscript with the same letter; e.g. "ij,jk->ik" with
        # shapes (3, 4) and (4, 5) gives (3, 5). The expression is expected
        # to contain an explicit "->" output.
        shps = np_cpu.concatenate([in_.shape for in_ in inputs])
        p = einsum_expr.split(',')
        s = p[:-1] + p[-1].split('->')
        if s[-1] == '':
            return ()
        else:
            inop = list(map(list, s))
            return tuple(shps[(np_cpu.concatenate(inop[:-1])[:, None] == inop[-1]).argmax(0)].astype(np_cpu.int64))

    # numpy_code needs the output shape and dtype up front; one backward
    # closure is registered per input operand.
    output_shape = [int(x) for x in einsum_outshape(string, args)]
    backwards = [partial(backward_code, argnum=idx) for idx in range(len(args))]
    a = jt.numpy_code(
        [output_shape],
        [args[0].dtype],
        args,
        forward_code,
        backwards,
    )[0]
    return a
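
As a quick illustration of the new API (a minimal sketch; the shapes and values below are arbitrary and not part of the diff), matrix multiplication can be written as an einsum and differentiated with jt.grad:

import numpy as np
import jittor as jt

a = jt.array(np.random.randn(3, 4).astype('float32'))
b = jt.array(np.random.randn(4, 5).astype('float32'))

# Forward: "ij,jk->ik" is a matrix product, so c has shape (3, 5).
c = jt.linalg.einsum("ij,jk->ik", a, b)
assert tuple(c.shape) == (3, 5)

# Backward: for this expression the gradient w.r.t. `a` is again an einsum
# of the output gradient with `b`, i.e. einsum("ik,jk->ij", dout, b).
ga = jt.grad(c, a)
gb = jt.grad(c, b)
assert tuple(ga.shape) == (3, 4) and tuple(gb.shape) == (4, 5)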
133 changes: 133 additions & 0 deletions python/jittor/test/test_einsum.py
@@ -0,0 +1,133 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
import unittest

try:
    import torch
    from torch.autograd import Variable
    import autograd.numpy as anp
    from autograd import jacobian

    has_autograd = True
except Exception:
    has_autograd = False


@unittest.skipIf(not has_autograd, "No autograd found.")
class TestEinsum(unittest.TestCase):
    def test_einsum_ijjk(self):
        for i in range(30):
            string = "ij,jk->ik"
            tn, tm = np.random.randn(3, 3).astype('float32'), np.random.randn(3, 3).astype('float32')
            x = jt.array(tn)
            y = jt.array(tm)
            t_x = torch.from_numpy(tn)
            t_y = torch.from_numpy(tm)
            t_x = Variable(t_x, requires_grad=True)
            t_y = Variable(t_y, requires_grad=True)
            jq = jt.linalg.einsum(string, x, y)
            tq = torch.einsum(string, t_x, t_y)
            try:
                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our results:")
                print(jq)
                print("pytorch's results:")
                print(tq)
            gq = jt.grad(jq, x).data
            gr = jt.grad(jq, y).data
            tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
            tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq))
            try:
                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
                assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our grad results:")
                print(gq)
                print(gr)
                print("pytorch's grad results:")
                print(tgq[0])
                print(tgr[0])

    def test_einsum_ii(self):
        for i in range(30):
            string = "ij->i"
            tn, tm = np.random.randn(3, 3).astype('float32'), np.random.randn(3, 3).astype('float32')
            x = jt.array(tn)
            # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            t_x = torch.from_numpy(tn)
            t_x = Variable(t_x, requires_grad=True)
            jq = jt.linalg.einsum(string, x)
            tq = torch.einsum(string, t_x)
            try:
                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our results:")
                print(jq)
                print("pytorch's results:")
                print(tq)
            gq = jt.grad(jq, x).data
            tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq))
            try:
                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our grad results:")
                print(gq)
                print("pytorch's grad results:")
                print(tgq[0])

    def test_einsum_multi(self):
        for i in range(30):
            string = "ij,ijk,jk->ik"
            tn, tm, tk = np.random.randn(3, 4).astype('float32'), np.random.randn(3, 4, 5).astype('float32'), np.random.randn(4, 5).astype('float32')
            x = jt.array(tn)
            y = jt.array(tm)
            z = jt.array(tk)
            # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            t_x = torch.from_numpy(tn)
            t_y = torch.from_numpy(tm)
            t_z = torch.from_numpy(tk)
            t_x = Variable(t_x, requires_grad=True)
            t_y = Variable(t_y, requires_grad=True)
            t_z = Variable(t_z, requires_grad=True)
            jq = jt.linalg.einsum(string, x, y, z)
            tq = torch.einsum(string, t_x, t_y, t_z)
            try:
                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our results:")
                print(jq)
                print("pytorch's results:")
                print(tq)
            gq = jt.grad(jq, x).data
            gr = jt.grad(jq, y).data
            gz = jt.grad(jq, z).data
            tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
            tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq), retain_graph=True)
            tgz = torch.autograd.grad(tq, t_z, torch.ones_like(tq), retain_graph=True)
            try:
                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
                assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
                assert np.allclose(gz, tgz[0].numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("our grad results:")
                print(gq)
                print(gr)
                print(gz)
                print("pytorch's grad results:")
                print(tgq[0])
                print(tgr[0])
                print(tgz[0])

if __name__ == "__main__":
    unittest.main()
