polish einsum(phy)
cjld committed Apr 5, 2022
1 parent 385ab26 commit 2ff5eba
Showing 3 changed files with 28 additions and 53 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.2.3'
+__version__ = '1.3.2.4'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
6 changes: 5 additions & 1 deletion python/jittor/linalg.py
@@ -446,7 +446,11 @@ def forward_code(np, data):
     def backward_code(np, data, argnum=0):
         real_len = len(data["inputs"]) - 2
         operands = data["inputs"][:real_len]
-        in_subs, out_subs, _ = np.core.einsumfunc._parse_einsum_input([string] + operands)
+        _ops = operands
+        if np_cpu is not np:
+            # fake a numpy array
+            _ops = [ np_cpu.zeros((1,)*o.ndim) for o in _ops ]
+        in_subs, out_subs, _ = np_cpu.core.einsumfunc._parse_einsum_input([string] + _ops)
         dout = data["dout"]
         out_index = data["out_index"]
         out = data["outputs"][0]
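When Jittor runs on CUDA, `np` inside this closure is no longer the host NumPy module (presumably cupy, given the cupy-gated CUDA tests added in this commit), so the old call handed device arrays to NumPy's private einsum parser. The patched backward pass always parses with the host module `np_cpu` and feeds it throwaway arrays that only match the operands' ranks, so no GPU data is touched just to recover the subscripts. Below is a minimal standalone sketch of the same trick; the helper name `parse_einsum_subscripts` is illustrative and not part of Jittor, and `_parse_einsum_input` is a private NumPy API under the legacy `np.core` namespace, so treat this as a sketch rather than a supported interface.

import numpy as np

def parse_einsum_subscripts(equation, operands):
    # Only each operand's rank matters for parsing, so cheap dummy host
    # arrays stand in for the real (possibly GPU-resident) tensors.
    dummies = [np.zeros((1,) * op.ndim) for op in operands]
    # Private NumPy helper: expands implicit outputs and ellipses and
    # returns (input_subscripts, output_subscript, operands).
    in_subs, out_subs, _ = np.core.einsumfunc._parse_einsum_input([equation] + dummies)
    return in_subs.split(','), out_subs

a = np.ones((2, 3))
b = np.ones((3, 4))
print(parse_einsum_subscripts("ij,jk", [a, b]))   # (['ij', 'jk'], 'ik')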
73 changes: 22 additions & 51 deletions python/jittor/test/test_einsum.py
@@ -22,6 +22,11 @@
 except:
     has_autograd = False
 
+cupy = None
+try:
+    import cupy
+except:
+    pass
 
 @unittest.skipIf(not has_autograd, "No autograd found.")
 class TestEinsum(unittest.TestCase):
@@ -37,27 +42,13 @@ def test_einsum_ijjk(self):
             t_y = Variable(t_y, requires_grad=True)
             jq = jt.linalg.einsum(string, x, y)
             tq = torch.einsum(string, t_x, t_y)
-            try:
-                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' results:")
-                print(jq)
-                print("pytorch's results:")
-                print(tq)
+            np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
             gq = jt.grad(jq, x).data
             gr = jt.grad(jq, y).data
             tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
             tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq))
-            try:
-                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
-                assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' grad results:")
-                print(gq)
-                print(gr)
-                print("pytorch's grad result")
-                print(tgq[0])
-                print(tgr[0])
+            np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
+            np.testing.assert_allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
 
     def test_einsum_ii(self):
         for i in range(30):
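Every test in this file gets the same treatment: the hand-rolled try / assert np.allclose / print blocks are replaced by np.testing.assert_allclose, whose AssertionError message already carries the diagnostics the old code printed by hand. A small illustration of the difference, using made-up values:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0, 3.0 + 1e-3])

# np.allclose only returns a bool, so the old tests printed both tensors
# themselves to see what differed.
print(np.allclose(a, b, rtol=1e-4, atol=1e-6))        # False

# assert_allclose raises on mismatch, and its message already reports the
# mismatch count, max absolute/relative error, and both arrays.
try:
    np.testing.assert_allclose(a, b, rtol=1e-4, atol=1e-6)
except AssertionError as err:
    print(err)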
@@ -69,22 +60,10 @@ def test_einsum_ii(self):
             t_x = Variable(t_x, requires_grad=True)
             jq = jt.linalg.einsum(string, x)
             tq = torch.einsum(string, t_x)
-            try:
-                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' results:")
-                print(jq)
-                print("pytorch's results:")
-                print(tq)
+            np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
             gq = jt.grad(jq, x).data
             tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq))
-            try:
-                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' grad results:")
-                print(gq)
-                print("pytorch's grad result")
-                print(tgq[0])
+            np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
 
     def test_einsum_multi(self):
         for i in range(30):
@@ -102,32 +81,24 @@ def test_einsum_multi(self):
             t_z = Variable(t_z, requires_grad=True)
             jq = jt.linalg.einsum(string, x, y, z)
             tq = torch.einsum(string, t_x, t_y, t_z)
-            try:
-                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' results:")
-                print(jq)
-                print("pytorch's results:")
-                print(tq)
+            np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
             gq = jt.grad(jq, x).data
             gr = jt.grad(jq, y).data
             gz = jt.grad(jq, z).data
             tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
             tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq), retain_graph=True)
             tgz = torch.autograd.grad(tq, t_z, torch.ones_like(tq), retain_graph=True)
-            try:
-                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
-                assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
-                assert np.allclose(gz, tgz[0].numpy(), rtol=1e-4, atol=1e-6)
-            except AssertionError:
-                print("ours' grad results:")
-                print(gq)
-                print(gr)
-                print(gz)
-                print("pytorch's grad result")
-                print(tgq[0])
-                print(tgr[0])
-                print(tgz[0])
+            np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
+            np.testing.assert_allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
+            np.testing.assert_allclose(gz, tgz[0].numpy(), rtol=1e-4, atol=1e-6)
 
+
+@unittest.skipIf(not jt.compiler.has_cuda or cupy is None, "No CUDA found")
+class TestCudaEinsum(TestEinsum):
+    def setUp(self):
+        jt.flags.use_cuda = 1
+    def tearDown(self):
+        jt.flags.use_cuda = 0
+
 if __name__ == "__main__":
     unittest.main()
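The new TestCudaEinsum class re-runs every test inherited from TestEinsum on the GPU by flipping jt.flags.use_cuda in setUp/tearDown, and is skipped entirely when CUDA or cupy is unavailable. The same subclass-and-toggle pattern in a self-contained sketch; the flags object and has_cuda below are stand-ins for jt.flags and jt.compiler.has_cuda, not the Jittor API:

import unittest

class _Flags:
    use_cuda = 0

flags = _Flags()      # stand-in for jt.flags
has_cuda = False      # stand-in for jt.compiler.has_cuda

class TestOpsCPU(unittest.TestCase):
    def test_device_flag(self):
        # each test just reads whatever device flag the fixture set up
        self.assertIn(flags.use_cuda, (0, 1))

@unittest.skipIf(not has_cuda, "No CUDA found")
class TestOpsGPU(TestOpsCPU):
    # inherits every test_* method; only the fixture around them changes
    def setUp(self):
        flags.use_cuda = 1
    def tearDown(self):
        flags.use_cuda = 0

if __name__ == "__main__":
    unittest.main()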
