-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_tree_3c.py
46 lines (42 loc) · 1.59 KB
/
test_tree_3c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from parsing import Tensor, SparseIndex, BinaryContraction, NaryContraction, IntermediateResult
from fused_ir import FusedIR, get_includes
# Driver script: build the "3c" tensor-contraction tree and generate fused
# TACO kernels for two variants, each written to its own C++ header:
#   * "filter_const"   -> 3c_filter_fused.hpp
#       Int*C -> IntC, IntC*Phat -> IntCPhat, IntCPhat*L -> X
#   * "nofilter_const" -> 3c_nofilter_fused.hpp
#       Int*C -> IntC, IntC*Phat -> X_nofilter
# Each stage (contraction, fused IR, kernel source) is printed for inspection.


def _emit_fused_kernel(contraction, statements, kernel_name, header_path):
    """Fuse *contraction* and write its TACO kernel to *header_path*.

    Attaches the explicit binary-contraction order ``statements`` to
    ``contraction``, lowers it with ``fuse_loops`` into a ``FusedIR``,
    calls ``reduce_intermediates`` on it, and prints the contraction, the IR
    and the emitted kernel. The header receives ``get_includes()`` followed
    by the kernel source generated under ``kernel_name``.

    Returns the ``FusedIR`` so callers can inspect it further.
    """
    contraction.statements = statements
    print(contraction)
    fir = FusedIR(contraction.fuse_loops())
    fir.reduce_intermediates()
    print(fir)
    # Emit the kernel once so the printed and written sources cannot diverge.
    kernel = fir.emit_taco_kernel(kernel_name)
    print(kernel)
    with open(header_path, "w", encoding="utf-8") as f:
        f.write(get_includes())
        f.write(kernel)
    return fir


# Index spaces. NOTE(review): the second argument is presumably the
# dimension size -- confirm against parsing.SparseIndex.
k = SparseIndex("k", 100)
mu = SparseIndex("mu", 100)
nu = SparseIndex("nu", 100)
i = SparseIndex("i", 100)
muhat = SparseIndex("muhat", 100)

Int = Tensor("Int", [mu, nu, k])
C = Tensor("C", [nu, i])
Phat = Tensor("Phat", [mu, muhat])
L = Tensor("L", [k, i])

# --- Variant 1: "filter_const" (includes the L factor) ---
X = Tensor("X", [k, i, muhat])
# NOTE(review): the index list ([nu], [mu]) is presumably the contracted
# index of each pairwise step -- confirm against IntermediateResult.
IntC = IntermediateResult(Int, C, [nu], const_shape="PAO")
IntCPhat = IntermediateResult(IntC, Phat, [mu], const_shape="PAO")
statements = [
    BinaryContraction(IntC, Int, C),
    BinaryContraction(IntCPhat, IntC, Phat),
    BinaryContraction(X, IntCPhat, L),
]
contraction = NaryContraction(X, [Int, C, Phat, L])
fir = _emit_fused_kernel(
    contraction, statements, "filter_const", "3c_filter_fused.hpp"
)

# --- Variant 2: "nofilter_const" (stops after Phat) ---
X_nofilter = Tensor("X_nofilter", [k, i, muhat])
# IntC is rebuilt rather than reused. NOTE(review): presumably because the
# first pipeline may mutate it; confirm whether reuse is safe.
IntC = IntermediateResult(Int, C, [nu], const_shape="PAO")
statements = [
    BinaryContraction(IntC, Int, C),
    BinaryContraction(X_nofilter, IntC, Phat),
]
contraction = NaryContraction(X_nofilter, [Int, C, Phat])
fir = _emit_fused_kernel(
    contraction, statements, "nofilter_const", "3c_nofilter_fused.hpp"
)