Skip to content

Commit

Permalink
Merge branch 'master' into f_frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Jun 22, 2023
2 parents 9f7e181 + ee5bb6f commit a7f383a
Show file tree
Hide file tree
Showing 7 changed files with 1,064 additions and 20 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ Cliff Hodel
Tiancheng Chen
Reid Wahl
Yihang Luo
Alexandru Calotoiu

and other contributors listed in https://github.com/spcl/dace/graphs/contributors
321 changes: 303 additions & 18 deletions dace/sdfg/analysis/cutout.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
""" This module initializes the inter-state transformations package."""

from .state_fusion import StateFusion
from .state_fusion_with_happens_before import StateFusionExtended
from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination,
SymbolAliasPromotion, HoistState)
from .fpga_transform_state import FPGATransformState
Expand Down
590 changes: 590 additions & 0 deletions dace/transformation/interstate/state_fusion_with_happens_before.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dace/transformation/subgraph/temporal_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
src_nodes = subgraph.source_nodes()
dst_nodes = subgraph.sink_nodes()
srcdst_nodes = src_nodes + dst_nodes
srcdst_arrays = [sdfg.arrays[node.data] for node in srcdst_nodes]
srcdst_arrays = [sdfg.arrays[node.data] for node in srcdst_nodes if isinstance(node, nodes.AccessNode)]
access_nodes = [
node for node in subgraph.nodes() if isinstance(node, nodes.AccessNode) and not node in srcdst_nodes
]
Expand Down
104 changes: 103 additions & 1 deletion tests/sdfg/cutout_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
import numpy as np
import dace
from dace.sdfg.analysis.cutout import SDFGCutout
from dace.sdfg.analysis.cutout import SDFGCutout, _reduce_in_configuration
import pytest


Expand Down Expand Up @@ -312,6 +312,106 @@ def test_input_output_configuration():
assert len(ct.arrays) == 4


def test_minimum_cut_simple_no_further_input_config():
sdfg = dace.SDFG('mincut')
N = dace.symbol('N')
sdfg.add_array('A', [N], dace.float64)
sdfg.add_array('B', [N], dace.float64)
sdfg.add_array('C', [N, N], dace.float64)
sdfg.add_array('tmp1', [1], dace.float64, transient=True)
sdfg.add_array('tmp2', [1], dace.float64, transient=True)
sdfg.add_array('tmp3', [1], dace.float64, transient=True)
sdfg.add_array('tmp4', [1], dace.float64, transient=True)
sdfg.add_array('tmp5', [1], dace.float64, transient=True)
sdfg.add_array('tmp6', [1], dace.float64, transient=True)
state = sdfg.add_state('state')
mi, mo = state.add_map('map', dict(i='0:N', j='0:N'))
t1 = state.add_tasklet('t1', {'a', 'b'}, {'t'}, 't = a + b')
t2 = state.add_tasklet(
't2', {'tin'}, {'t1', 't2', 't3', 't4'}, 't1 = tin + 2\nt2 = tin * 2\nt3 = tin / 2\nt4 = tin + 1'
)
t3 = state.add_tasklet('t3', {'a', 'b'}, {'t'}, 't = a + b')
t4 = state.add_tasklet('t4', {'a', 'b', 'c'}, {'t'}, 't = (a - b) * c')
a_access = state.add_access('A')
b_access = state.add_access('B')
c_access = state.add_access('C')
tmp1_access = state.add_access('tmp1')
tmp2_access = state.add_access('tmp2')
tmp3_access = state.add_access('tmp3')
tmp4_access = state.add_access('tmp4')
tmp5_access = state.add_access('tmp5')
tmp6_access = state.add_access('tmp6')
state.add_memlet_path(a_access, mi, t1, dst_conn='a', memlet=dace.Memlet('A[i]'))
state.add_memlet_path(b_access, mi, t1, dst_conn='b', memlet=dace.Memlet('B[j]'))
state.add_edge(t1, 't', tmp1_access, None, dace.Memlet('tmp1[0]'))
state.add_edge(tmp1_access, None, t2, 'tin', dace.Memlet('tmp1[0]'))
state.add_edge(t2, 't1', tmp2_access, None, dace.Memlet('tmp2[0]'))
state.add_edge(t2, 't2', tmp3_access, None, dace.Memlet('tmp3[0]'))
state.add_edge(t2, 't3', tmp4_access, None, dace.Memlet('tmp4[0]'))
state.add_edge(t2, 't4', tmp5_access, None, dace.Memlet('tmp5[0]'))
state.add_edge(tmp2_access, None, t3, 'a', dace.Memlet('tmp2[0]'))
state.add_edge(tmp3_access, None, t3, 'b', dace.Memlet('tmp3[0]'))
state.add_edge(tmp4_access, None, t4, 'a', dace.Memlet('tmp4[0]'))
state.add_edge(tmp5_access, None, t4, 'b', dace.Memlet('tmp5[0]'))
state.add_edge(t3, 't', tmp6_access, None, dace.Memlet('tmp6[0]'))
state.add_edge(tmp6_access, None, t4, 'c', dace.Memlet('tmp6[0]'))
state.add_memlet_path(t4, mo, c_access, src_conn='t', memlet=dace.Memlet('C[i, j]'))

cutout = SDFGCutout.singlestate_cutout(state, t3, t4, tmp6_access, reduce_input_config=True)

c_state = cutout.nodes()[0]
c_nodes = set(c_state.nodes())
o_nodes = {t2, t3, t4, tmp6_access, tmp4_access, tmp5_access, tmp2_access, tmp3_access, tmp1_access, c_access}
assert len(c_nodes) == 10
for n in o_nodes:
assert cutout._in_translation[n] in c_nodes
for n in c_nodes:
assert cutout._out_translation[n] in o_nodes


def test_minimum_cut_map_scopes():
sdfg = dace.SDFG('mincut')
sdfg.add_array('A', [10, 10], dace.float64)
sdfg.add_array('B', [10, 10], dace.float64)
sdfg.add_array('tmp_1', [10, 10], dace.float64, transient=True)
sdfg.add_array('tmp_2', [10, 10], dace.float64, transient=True)
sdfg.add_array('C', [10, 10], dace.float64)

state = sdfg.add_state('state')
t1 = state.add_tasklet('t1', {'in1', 'in2'}, {'out1'}, 'out1 = in1 + in2')
t2 = state.add_tasklet('t2', {'in1'}, {'out1'}, 'out1 = in1 * 2')
t3 = state.add_tasklet('t3', {'in1', 'in2'}, {'out1'}, 'out1 = in1 + in2')
m1_i, m1_o = state.add_map('m1', dict(i='0:10', j='0:10'))
m2_i, m2_o = state.add_map('m2', dict(i='0:10', j='0:10'))
m3_i, m3_o = state.add_map('m3', dict(i='0:10', j='0:10'))

a_access = state.add_access('A')
b_access = state.add_access('B')
c_access = state.add_access('C')
tmp1_access = state.add_access('tmp_1')
tmp2_access = state.add_access('tmp_2')

state.add_memlet_path(a_access, m1_i, t1, dst_conn='in1', memlet=dace.Memlet('A[i, j]'))
state.add_memlet_path(b_access, m1_i, t1, dst_conn='in2', memlet=dace.Memlet('B[i, j]'))
state.add_memlet_path(t1, m1_o, tmp1_access, src_conn='out1', memlet=dace.Memlet('tmp_1[i, j]'))
state.add_memlet_path(tmp1_access, m2_i, t2, dst_conn='in1', memlet=dace.Memlet('tmp_1[i, j]'))
state.add_memlet_path(t2, m2_o, tmp2_access, src_conn='out1', memlet=dace.Memlet('tmp_2[i, j]'))
state.add_memlet_path(tmp1_access, m3_i, t3, dst_conn='in1', memlet=dace.Memlet('tmp_1[i, j]'))
state.add_memlet_path(tmp2_access, m3_i, t3, dst_conn='in2', memlet=dace.Memlet('tmp_2[i, j]'))
state.add_memlet_path(t3, m3_o, c_access, src_conn='out1', memlet=dace.Memlet('C[i, j]'))

cutout = SDFGCutout.singlestate_cutout(state, t3, m3_i, m3_o, reduce_input_config=True)

c_state = cutout.nodes()[0]
c_nodes = set(c_state.nodes())
o_nodes = {t2, t3, tmp1_access, tmp2_access, c_access, m2_i, m2_o, m3_i, m3_o}
assert len(c_nodes) == 9
for n in o_nodes:
assert cutout._in_translation[n] in c_nodes
for n in c_nodes:
assert cutout._out_translation[n] in o_nodes


if __name__ == '__main__':
test_cutout_onenode()
test_cutout_multinode()
Expand All @@ -322,3 +422,5 @@ def test_input_output_configuration():
test_multistate_cutout_simple_expand()
test_multistate_cutout_complex_expand()
test_input_output_configuration()
test_minimum_cut_simple_no_further_input_config()
test_minimum_cut_map_scopes()
65 changes: 65 additions & 0 deletions tests/transformations/state_fusion_extended_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from dace import SDFG, InterstateEdge,Memlet
from dace import dtypes
from dace.transformation.interstate import StateFusionExtended


def test_extended_fusion():
"""
Test the extended state fusion transformation.
It should fuse the two states into one and add a dependency between the two uses of tmp.
"""
sdfg=SDFG('extended_state_fusion_test')
sdfg.add_array('A', [20, 20], dtypes.float64)
sdfg.add_array('B', [20, 20], dtypes.float64)
sdfg.add_array('C', [20, 20], dtypes.float64)
sdfg.add_array('D', [20, 20], dtypes.float64)
sdfg.add_array('E', [20, 20], dtypes.float64)
sdfg.add_array('F', [20, 20], dtypes.float64)

sdfg.add_scalar('tmp', dtypes.float64)

strt = sdfg.add_state("start")
mid = sdfg.add_state("middle")

sdfg.add_edge(strt, mid, InterstateEdge())

acc_a = strt.add_read('A')
acc_b = strt.add_read('B')
acc_c = strt.add_write('C')
acc_tmp = strt.add_access('tmp')

acc2_d = mid.add_read('D')
acc2_e = mid.add_read('E')
acc2_f = mid.add_write('F')
acc2_tmp = mid.add_access('tmp')

t1 = strt.add_tasklet('t1', {'a', 'b'}, {
'c',
}, 'c[1,1] = a[1,1] + b[1,1]')
t2 = strt.add_tasklet('t2', {}, {
'tmpa',
}, 'tmpa=4')

t3 = mid.add_tasklet('t3', {'d', 'e'}, {
'f',
}, 'f[1,1] = e[1,1] + d[1,1]')
t4 = mid.add_tasklet('t4', {}, {
'tmpa',
}, 'tmpa=7')

strt.add_edge(acc_a, None, t1, 'a', Memlet.simple('A', '1,1'))
strt.add_edge(acc_b, None, t1, 'b', Memlet.simple('B', '1,1'))
strt.add_edge(t1, 'c', acc_c, None, Memlet.simple('C', '1,1'))
strt.add_edge(t2, 'tmpa', acc_tmp, None, Memlet.simple('tmp', '0'))

mid.add_edge(acc2_d, None, t3, 'd', Memlet.simple('D', '1,1'))
mid.add_edge(acc2_e, None, t3, 'e', Memlet.simple('E', '1,1'))
mid.add_edge(t3, 'f', acc2_f, None, Memlet.simple('F', '1,1'))
mid.add_edge(t4, 'tmpa', acc2_tmp, None, Memlet.simple('tmp', '0'))
sdfg.simplify()
sdfg.apply_transformations_repeated(StateFusionExtended)
assert sdfg.number_of_nodes()==1


if __name__ == '__main__':
test_extended_fusion()

0 comments on commit a7f383a

Please sign in to comment.