-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmath.py
102 lines (94 loc) · 2.92 KB
/
math.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from ..expr import *
from ..spec import TypeSpec, Op, rank_ran, dim_ran
def create_identity():
    """Return the type spec shared by the shape-preserving unary operators.

    The spec has a single input whose rank and dtype are free ``Var`` s and
    whose shape is one ``Var(tmpl=True)`` per axis (no explicit ``ran`` is
    given for these dims). The single output copies the input's rank, dtype
    and shape unchanged. There are no attributes and no extra constraints.
    """
    def _free_dim(_axis):
        # One dimension variable per input axis; the axis index is unused.
        return Var(tmpl=True)

    # Input side: a tensor whose rank/dtype/shape are all left to the solver.
    inputs = dict(
        in_num=1,
        in_ranks=[Var()],
        in_dtypes=[Var()],
        in_shapes=[List(IN[0].rank, _free_dim)],
    )
    # Output side: an exact copy of the input tensor's type.
    outputs = dict(
        out_num=1,
        out_ranks=[IN[0].rank],
        out_dtypes=[IN[0].dtype],
        out_shapes=[IN[0].shape],
    )
    return TypeSpec(attrs=[], extra=[], **inputs, **outputs)
# Register each unary operator below with the `create_identity` spec factory,
# in this exact order.
for _op_name in (
    'negative', 'abs', 'ceil', 'floor', 'round', 'trunc',
    'exp', 'sin', 'cos', 'tan', 'sigmoid', 'tanh',
):
    Op(_op_name, create_identity)
# Keep the module namespace free of the loop variable.
del _op_name
def _create_bcast():
    """Return the type spec for binary operators with broadcasting.

    Both inputs get one dimension variable per axis, ranged by ``dim_ran``.
    Axes are aligned from the trailing end: every aligned pair must be equal
    or have a 1 on either side. The output takes the ``max`` of each aligned
    pair and copies through the leading axes of the higher-rank input. The
    output dtype is the first input's dtype.

    Two variants are built depending on ``TypeSpec.for_graph``:

    * truthy: the second input's rank is ranged by ``iran(2, m)`` (presumably
      the inclusive range 2..m, so it never exceeds the first input's rank),
      and the output rank is the first input's rank ``m``;
    * falsy: both ranks are independent ``Var(ran=rank_ran, tmpl=True)``, so
      compatibility is checked over ``min(m, n)`` trailing axes, the output
      rank is ``max(m, n)`` and the output shape is selected with ``Cond``.

    NOTE(review): the two input dtypes are independent ``Var`` s and nothing
    here constrains them to be equal. Confirm that this is intended (or that
    the dtype domain is a single type) for ops such as ``add``.
    """
    m = IN[0].rank
    n = IN[1].rank

    def _dims(rank):
        # One dimension variable per axis, each ranged by ``dim_ran``.
        return List(rank, lambda _axis: Var(ran=dim_ran, tmpl=True))

    def _compatible(i):
        # Axis ``i`` counted from the end: the sizes match, or either is 1.
        return Or(
            IN[0].shape[m - i - 1] == IN[1].shape[n - i - 1],
            IN[0].shape[m - i - 1] == 1,
            IN[1].shape[n - i - 1] == 1,
        )

    if not TypeSpec.for_graph:
        # Ranks are independent here, so the output shape has to branch on
        # which input has more axes.
        return TypeSpec(
            attrs=[],
            in_num=2,
            in_ranks=List(2, lambda _k: Var(ran=rank_ran, tmpl=True)),
            in_dtypes=List(2, lambda _k: Var()),
            in_shapes=[_dims(m), _dims(n)],
            extra=[ForAll(Range(end=m.min(n)), _compatible)],
            out_num=1,
            out_ranks=[m.max(n)],
            out_dtypes=[IN[0].dtype],
            out_shapes=[
                Cond(
                    m >= n,
                    Concat(
                        IN[0].shape[Range(end=m - n)],
                        List(n, lambda j: IN[0].shape[m - n + j].max(IN[1].shape[j])),
                    ),
                    Concat(
                        IN[1].shape[Range(end=n - m)],
                        List(m, lambda j: IN[0].shape[j].max(IN[1].shape[n - m + j])),
                    ),
                ),
            ],
        )

    # Graph variant: the second input's rank is bounded by the first's, so
    # the first input always supplies the leading axes of the output.
    return TypeSpec(
        attrs=[],
        in_num=2,
        in_ranks=[Var(), Var(ran=iran(2, m))],
        in_dtypes=List(2, lambda _k: Var()),
        in_shapes=[_dims(m), _dims(n)],
        extra=[ForAll(Range(end=n), _compatible)],
        out_num=1,
        out_ranks=[m],
        out_dtypes=[IN[0].dtype],
        out_shapes=[
            Concat(
                IN[0].shape[Range(end=m - n)],
                List(n, lambda j: IN[0].shape[m - n + j].max(IN[1].shape[j])),
            ),
        ],
    )
# Register each binary operator below with the `_create_bcast` spec factory,
# in this exact order.
for _op_name in ('add', 'subtract', 'multiply', 'divide', 'maximum', 'minimum'):
    Op(_op_name, _create_bcast)
# Keep the module namespace free of the loop variable.
del _op_name