-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathmobilenet_program_test.py
239 lines (191 loc) · 7.94 KB
/
mobilenet_program_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import array
import asyncio
import time
import functools
import pytest
import shortfin as sf
import shortfin.array as sfnp
@pytest.fixture
def lsys():
    """Provide a shortfin system for one test and shut it down afterwards."""
    builder = sf.SystemBuilder()
    system = builder.create_system()
    yield system
    # Everything after the yield is the fixture's teardown step.
    system.shutdown()
@pytest.fixture
def fiber0(lsys):
    """Provide a single fiber created on the ``lsys`` system."""
    fiber = lsys.create_fiber()
    return fiber
@pytest.fixture
def device(fiber0):
    """Provide the device at index 0 of ``fiber0``."""
    first_device = fiber0.device(0)
    return first_device
@pytest.fixture
def mobilenet_program_function(
    lsys, mobilenet_compiled_path
) -> sf.ProgramFunction:
    """Load the compiled MobileNet module and return its entry function.

    The program is built without an explicit ``isolation`` argument; the tests
    in this file assert that the resulting function reports ``PER_FIBER``.

    Args:
        lsys: System used to load the module; all of its devices are handed to
            the program.
        mobilenet_compiled_path: Path of the compiled module (provided by a
            fixture that is not defined in this part of the file).

    Returns:
        The single function registered as ``"module.torch-jit-export"`` (a
        function object, not a tuple).
    """
    program_module = lsys.load_module(mobilenet_compiled_path)
    program = sf.Program([program_module], devices=lsys.devices)
    return program["module.torch-jit-export"]
@pytest.fixture
def mobilenet_program_function_per_call(
    lsys, mobilenet_compiled_path
) -> sf.ProgramFunction:
    """Load the compiled MobileNet module with ``PER_CALL`` program isolation.

    Args:
        lsys: System used to load the module; all of its devices are handed to
            the program.
        mobilenet_compiled_path: Path of the compiled module (provided by a
            fixture that is not defined in this part of the file).

    Returns:
        The single function registered as ``"module.torch-jit-export"`` (a
        function object, not a tuple).
    """
    program_module = lsys.load_module(mobilenet_compiled_path)
    program = sf.Program(
        [program_module],
        devices=lsys.devices,
        isolation=sf.ProgramIsolation.PER_CALL,
    )
    return program["module.torch-jit-export"]
def get_mobilenet_ref_input(device) -> sfnp.device_array:
    """Build the fixed reference input used by the MobileNet tests.

    Creates a float32 array of shape ``[1, 3, 224, 224]`` on ``device`` and
    fills it with three consecutive runs of ``224 * 224`` values: 0.2, then
    0.4, then -0.2 (presumably one run per channel plane). The data is written
    into the array returned by ``for_transfer()`` and then copied over.

    Args:
        device: Device on which the input array is allocated.

    Returns:
        The populated device array.
    """
    plane_size = 224 * 224
    pixels = array.array("f")
    for channel_value in (0.2, 0.4, -0.2):
        pixels.extend([channel_value] * plane_size)

    target = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32)
    staging = target.for_transfer()
    with staging.map(discard=True) as mapping:
        mapping.fill(pixels)
    target.copy_from(staging)
    return target
async def assert_mobilenet_ref_output(device, device_output):
    """Assert that a MobileNet output matches the recorded reference value.

    Copies ``device_output`` into the array returned by its ``for_transfer()``,
    awaits ``device``, and compares the mean absolute value of the copied
    items against a recorded constant.

    Args:
        device: Awaited after the copy is issued and before the items are read
            (presumably so the copy has completed; confirm against shortfin).
        device_output: Program output to verify.
    """
    host_copy = device_output.for_transfer()
    host_copy.copy_from(device_output)
    await device

    values = host_copy.items
    count = len(values)
    # Accumulate abs(v) / count term by term, in iteration order.
    absmean = 0.0
    for value in values:
        absmean = absmean + abs(value) / count

    # Note: this value was just copied from a sample run of the test.
    # Comparison against a reference backend for this model is tested upstream
    # in https://github.com/iree-org/iree-test-suites/tree/main/onnx_models.
    assert absmean == pytest.approx(0.81196929)
# Tests that a single invocation on a single fiber works.
def test_invoke_mobilenet_single_per_fiber(lsys, fiber0, mobilenet_program_function):
    """Invoke the program once on ``fiber0`` and verify its output."""
    assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
    target_device = fiber0.device(0)

    async def invoke_once():
        device_input = get_mobilenet_ref_input(target_device)
        outputs = await mobilenet_program_function(device_input, fiber=fiber0)
        # Exactly one output is expected; the unpack fails otherwise.
        (device_output,) = outputs
        await assert_mobilenet_ref_output(target_device, device_output)

    lsys.run(invoke_once())
# Tests that a single invocation on a single fiber in per_call mode works.
def test_invoke_mobilenet_single_per_call(
    lsys, fiber0, mobilenet_program_function_per_call
):
    """Invoke the ``PER_CALL`` program once on ``fiber0`` and verify its output."""
    assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL
    target_device = fiber0.device(0)

    async def invoke_once():
        device_input = get_mobilenet_ref_input(target_device)
        outputs = await mobilenet_program_function_per_call(
            device_input, fiber=fiber0
        )
        # Exactly one output is expected; the unpack fails otherwise.
        (device_output,) = outputs
        await assert_mobilenet_ref_output(target_device, device_output)

    lsys.run(invoke_once())
# Tests that chained back to back invocations on the same fiber work correctly.
# Does an async gather/assert with all results at the end.
def test_invoke_mobilenet_chained_per_fiber(lsys, fiber0, mobilenet_program_function):
    """Run five back-to-back invocations on ``fiber0``, then verify every output."""
    assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
    target_device = fiber0.device(0)

    async def invoke_chain():
        device_input = get_mobilenet_ref_input(target_device)

        # Each invocation is awaited before the next one is issued.
        results = []
        for _ in range(5):
            results.append(
                await mobilenet_program_function(device_input, fiber=fiber0)
            )

        # All output checks are gathered together at the end.
        checks = [
            assert_mobilenet_ref_output(target_device, device_output)
            for (device_output,) in results
        ]
        await asyncio.gather(*checks)

    lsys.run(invoke_chain())
# Tests that parallel invocations on a single fiber with a program in PER_CALL
# isolation functions properly. Note that in this variant, the await is done
# on all invocations vs serially per invocation (as in
# test_invoke_mobilenet_chained_per_fiber). This would be illegal if done on the
# same fiber without PER_CALL isolation managing forks.
#
# Note that since these are all operating on the same fiber, they are added to
# the device-side work graph with a one-after-the-other dependency, but the
# host side schedules concurrently.
def test_invoke_mobilenet_parallel_per_call(
    lsys, fiber0, mobilenet_program_function_per_call
):
    """Gather five invocations of the ``PER_CALL`` program on one fiber."""
    assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL
    target_device = fiber0.device(0)

    async def invoke_parallel():
        device_input = get_mobilenet_ref_input(target_device)

        # All five invocations are created first and awaited together.
        invocations = [
            mobilenet_program_function_per_call(device_input, fiber=fiber0)
            for _ in range(5)
        ]
        results = await asyncio.gather(*invocations)

        checks = [
            assert_mobilenet_ref_output(target_device, device_output)
            for (device_output,) in results
        ]
        await asyncio.gather(*checks)

    lsys.run(invoke_parallel())
# Same as above but uses explicit isolation controls on the function vs at the
# program level, and issues 50 invocations instead of 5. If the isolation
# constraint were violated (parallel calls on one fiber without PER_CALL
# isolation), shortfin makes a best effort attempt to detect the situation and
# raise an exception, but there is a subset of programs which are purely async
# and would make detection of this condition lossy in the synchronous
# completion case.
def test_invoke_mobilenet_parallel_per_call_explicit(
    lsys, fiber0, mobilenet_program_function
):
    """Gather 50 invocations that request ``PER_CALL`` isolation per call.

    The program function itself reports ``PER_FIBER``; the isolation mode is
    overridden through the ``isolation`` keyword on each invocation.
    """
    assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
    target_device = fiber0.device(0)

    async def invoke_parallel():
        device_input = get_mobilenet_ref_input(target_device)

        # All invocations are created first and awaited together.
        invocations = [
            mobilenet_program_function(
                device_input, fiber=fiber0, isolation=sf.ProgramIsolation.PER_CALL
            )
            for _ in range(50)
        ]
        results = await asyncio.gather(*invocations)

        checks = [
            assert_mobilenet_ref_output(target_device, device_output)
            for (device_output,) in results
        ]
        await asyncio.gather(*checks)

    lsys.run(invoke_parallel())
# Tests that independent executions on multiple fibers all run concurrently.
# All fibers share the same host thread but schedule concurrently. Since
# each fiber has its own timeline, device side graphs have no dependency on
# each other and also schedule concurrently.
def test_invoke_mobilenet_multi_fiber_per_fiber(lsys, mobilenet_program_function):
    """Run one inference on each of five fibers and wait for all of them.

    Each fiber gets its own ``InferProcess``, which builds the reference input
    on the fiber's device 0, invokes the program on that fiber and verifies
    the output. Progress and elapsed times are printed along the way.

    Elapsed times are taken from ``time.perf_counter()`` rather than the wall
    clock, so system clock adjustments cannot skew them or make them negative.
    """
    assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER

    def elapsed_ms(since):
        # Whole milliseconds elapsed since ``since`` (a perf_counter reading).
        return round((time.perf_counter() - since) * 1000.0)

    class InferProcess(sf.Process):
        async def run(self):
            start_time = time.perf_counter()
            print(f"{self}: Start")
            device = self.fiber.device(0)
            device_input = get_mobilenet_ref_input(device)
            (device_output,) = await mobilenet_program_function(
                device_input, fiber=self.fiber
            )
            print(f"{self}: Program complete (+{elapsed_ms(start_time)}ms)")
            await assert_mobilenet_ref_output(device, device_output)
            print(f"{self} End (+{elapsed_ms(start_time)}ms)")

    async def main():
        start_time = time.perf_counter()
        fibers = [lsys.create_fiber() for _ in range(5)]
        print("Fibers:", fibers)
        # Launch every process before waiting on any of them.
        processes = [InferProcess(fiber=f).launch() for f in fibers]
        print("Waiting for processes:", processes)
        await asyncio.gather(*processes)
        print(f"All processes complete: (+{elapsed_ms(start_time)}ms)")

    lsys.run(main())