-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
generic.py
1176 lines (973 loc) · 39.6 KB
/
generic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of generic operator strategy."""
# pylint: disable=invalid-name,unused-argument
import logging
import re
from tvm import topi
from tvm.topi.util import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from .. import op as _op
from ....target import generic_func, override_native_generic_func
# Module-level logger; the strategy functions in this file use it to warn when a
# generic implementation ("not optimized for this platform") is selected.
logger = logging.getLogger("strategy")
def wrap_topi_schedule(topi_schedule):
    """Adapt a TOPI schedule that ignores attrs to the ``(attrs, outs, target)`` signature.

    Parameters
    ----------
    topi_schedule : callable
        Schedule function that is called with ``outs`` only.

    Returns
    -------
    callable
        Function taking ``(attrs, outs, target)`` that calls ``topi_schedule(outs)``
        while ``target`` is entered as a context manager.
    """

    def _schedule_in_target(attrs, outs, target):
        # ``attrs`` is part of the expected signature but is intentionally unused.
        with target:
            schedule = topi_schedule(outs)
            return schedule

    return _schedule_in_target
def get_conv2d_in_channels(data_shape, data_layout):
    """Return the number of conv2d input channels for the given data shape and layout.

    For a 4-D shape the result is the extent at the position of ``"C"`` in
    ``data_layout``. For any other rank only layouts matching ``NCHW[x]c`` are
    handled, and the result is ``shape[1] * shape[4]``.

    Raises
    ------
    AssertionError
        If the shape is 4-D and the layout contains no ``"C"``.
    ValueError
        If the shape is not 4-D and the layout does not match ``NCHW[x]c``.
    """
    shape = get_const_tuple(data_shape)
    if len(shape) != 4:
        if re.match(r"NCHW\d*c", data_layout):
            # NCHW[8]c: outer channel blocks times the inner channel block size.
            return shape[1] * shape[4]
        raise ValueError("Unknown conv2d data layout {}".format(data_layout))
    channel_axis = data_layout.find("C")
    assert channel_axis >= 0, "Invalid conv2d data layout {}".format(data_layout)
    return shape[channel_axis]
def get_conv2d_out_channels(kernel_shape, kernel_layout):
    """Return the number of conv2d output channels for the given kernel shape and layout.

    For a 4-D shape the result is the extent at the position of ``"O"`` in
    ``kernel_layout``. Otherwise the layout is matched against ``OIHW[x]i[y]o``
    (result ``shape[0] * shape[5]``) and then ``OIHW[x]o`` (result
    ``shape[0] * shape[4]``).

    Raises
    ------
    AssertionError
        If the shape is 4-D and the layout contains no ``"O"``.
    ValueError
        If the shape is not 4-D and neither packed layout pattern matches.
    """
    shape = get_const_tuple(kernel_shape)
    if len(shape) == 4:
        out_axis = kernel_layout.find("O")
        assert out_axis >= 0, "Invalid conv2d kernel layout {}".format(kernel_layout)
        return shape[out_axis]
    # Packed layouts, tried in the same order as listed: (pattern, inner "o" axis).
    packed_layouts = ((r"OIHW\d*i\d*o", 5), (r"OIHW\d*o", 4))
    for pattern, inner_axis in packed_layouts:
        if re.match(pattern, kernel_layout):
            return shape[0] * shape[inner_axis]
    raise ValueError("Unknown conv2d kernel layout {}".format(kernel_layout))
def is_depthwise_conv2d(data_shape, data_layout, kernel_shape, kernel_layout, groups):
    """Check whether a conv2d is depthwise.

    Returns the result of comparing the input channel count, the output channel
    count and ``groups`` for equality; the channel counts come from
    ``get_conv2d_in_channels`` and ``get_conv2d_out_channels``.
    """
    in_channels = get_conv2d_in_channels(data_shape, data_layout)
    out_channels = get_conv2d_out_channels(kernel_shape, kernel_layout)
    return in_channels == out_channels and out_channels == groups
@generic_func
def schedule_injective(attrs, outs, target):
    """Build the generic TOPI schedule for injective ops.

    ``attrs`` is unused; the schedule is created from ``outs`` inside the
    ``target`` context.
    """
    with target:
        schedule = topi.generic.schedule_injective(outs)
        return schedule
@generic_func
def schedule_reduce(attrs, outs, target):
    """Build the generic TOPI schedule for reduction ops.

    ``attrs`` is unused; the schedule is created from ``outs`` inside the
    ``target`` context.
    """
    with target:
        schedule = topi.generic.schedule_reduce(outs)
        return schedule
# Expose the generic injective and reduce schedule functions as attributes of the
# op module (imported here as ``_op``).
_op._schedule_injective = schedule_injective
_op._schedule_reduce = schedule_reduce
# concatenate
@generic_func
def schedule_concatenate(attrs, outs, target):
    """Build the schedule for the concatenate op.

    Uses the generic injective schedule; ``attrs`` is unused.
    """
    with target:
        schedule = topi.generic.schedule_injective(outs)
        return schedule
# pool
@generic_func
def schedule_pool(attrs, outs, target):
    """Build the generic schedule for pooling ops.

    The layout passed to the TOPI schedule is read from ``attrs.layout``.
    """
    with target:
        schedule = topi.generic.schedule_pool(outs, attrs.layout)
        return schedule
# pool_grad
@generic_func
def schedule_pool_grad(attrs, outs, target):
    """Build the generic schedule for pooling gradient ops (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_pool_grad(outs)
        return schedule
# adaptive pool
@generic_func
def schedule_adaptive_pool(attrs, outs, target):
    """Build the generic schedule for adaptive pooling ops (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_adaptive_pool(outs)
        return schedule
# softmax
def wrap_compute_softmax(topi_compute):
    """Wrap a softmax TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function reads the integer ``"axis"`` attribute, calls
    ``topi_compute(inputs[0], axis)`` and returns the result in a one-element list.
    """

    def _compute_softmax(attrs, inputs, out_type):
        softmax_axis = attrs.get_int("axis")
        data = inputs[0]
        result = topi_compute(data, softmax_axis)
        return [result]

    return _compute_softmax
@override_native_generic_func("softmax_strategy")
def softmax_strategy(attrs, inputs, out_type, target):
    """Generic strategy for softmax.

    Returns an ``OpStrategy`` with a single implementation named
    ``"softmax.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_softmax(topi.nn.softmax)
    schedule = wrap_topi_schedule(topi.generic.schedule_softmax)
    strategy.add_implementation(compute, schedule, name="softmax.generic")
    return strategy
# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
    """Build the schedule for the log_softmax op.

    Reuses the generic softmax schedule; ``attrs`` is unused.
    """
    with target:
        schedule = topi.generic.schedule_softmax(outs)
        return schedule
# lrn
@generic_func
def schedule_lrn(attrs, outs, target):
    """Build the generic schedule for the LRN op (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_lrn(outs)
        return schedule
# bitpack
@generic_func
def schedule_bitpack(attrs, outs, target):
    """Build the generic schedule for the bitpack op (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_bitpack(outs)
        return schedule
# conv2d
def wrap_compute_conv2d(
    topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False
):
    """Wrap a conv2d TOPI compute into an ``(attrs, inputs, out_type)`` function.

    Parameters
    ----------
    topi_compute : callable
        Called as ``topi_compute(data, kernel, strides, padding, dilation,
        [groups], [data_layout], [out_layout], out_dtype)``.
    need_data_layout : bool
        Pass the ``data_layout`` attribute to ``topi_compute``.
    need_out_layout : bool
        Pass the ``out_layout`` attribute to ``topi_compute``.
    has_groups : bool
        Pass ``attrs.groups`` to ``topi_compute``.

    Returns
    -------
    callable
        Compute function returning a one-element list with the TOPI result.
    """

    def _compute_conv2d(attrs, inputs, out_type):
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        dilation = get_const_tuple(attrs.dilation)
        data_layout = attrs.get_str("data_layout")
        out_layout = attrs.get_str("out_layout")
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the data input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        call_args = [inputs[0], inputs[1], strides, padding, dilation]
        if has_groups:
            call_args.append(attrs.groups)
        # Optional layout arguments keep the order data_layout, then out_layout.
        for wanted, layout in ((need_data_layout, data_layout), (need_out_layout, out_layout)):
            if wanted:
                call_args.append(layout)
        call_args.append(out_dtype)
        return [topi_compute(*call_args)]

    return _compute_conv2d
@override_native_generic_func("conv2d_strategy")
def conv2d_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv2d.

    Selects one implementation from ``attrs.groups``, ``attrs.data_layout`` and
    ``attrs.kernel_layout``:

    * ``groups == 1``: regular conv2d for NCHW, NHWC or HWCN data layouts;
    * depthwise (as decided by ``is_depthwise_conv2d``): NCHW or NHWC;
    * otherwise group conv2d: NCHW or NHWC.

    Raises
    ------
    ValueError
        If either dilation value is smaller than 1.
    RuntimeError
        If the data layout is not supported for the selected kind of convolution.
    AssertionError
        If the kernel layout does not match the one expected for the data layout.
    """
    logger.warning("conv2d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    dilation_h, dilation_w = dilation
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")

    # Pick (compute, schedule, name) first; the implementation is registered once below.
    if groups == 1:
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            compute = wrap_compute_conv2d(topi.nn.conv2d_nchw)
            schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_nchw)
            impl_name = "conv2d_nchw.generic"
        elif layout == "NHWC":
            assert kernel_layout == "HWIO"
            compute = wrap_compute_conv2d(topi.nn.conv2d_nhwc)
            schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc)
            impl_name = "conv2d_nhwc.generic"
        elif layout == "HWCN":
            assert kernel_layout == "HWIO"
            compute = wrap_compute_conv2d(topi.nn.conv2d_hwcn)
            schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn)
            impl_name = "conv2d_hwcn.generic"
        else:
            raise RuntimeError("Unsupported conv2d layout {}".format(layout))
    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            compute = wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw)
            schedule = wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw)
            impl_name = "depthwise_conv2d_nchw.generic"
        elif layout == "NHWC":
            assert kernel_layout == "HWOI"
            compute = wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc)
            schedule = wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc)
            impl_name = "depthwise_conv2d_nhwc.generic"
        else:
            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
    else:  # group_conv2d
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            compute = wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True)
            schedule = wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw)
            impl_name = "group_conv2d_nchw.generic"
        elif layout == "NHWC":
            assert kernel_layout == "HWIO"
            compute = wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True)
            schedule = wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc)
            impl_name = "group_conv2d_nhwc.generic"
        else:
            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))

    strategy.add_implementation(compute, schedule, name=impl_name)
    return strategy
# conv2d_NCHWc
@override_native_generic_func("conv2d_NCHWc_strategy")
def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv2d_NCHWc.

    Uses the int8 implementation when the dtype of the first input is ``"int8"``
    or ``"uint8"``, and the regular NCHWc implementation otherwise. Both are
    wrapped so that the data layout and output layout are passed to TOPI.
    """
    logger.warning("conv2d_NCHWc is not optimized for this platform.")
    strategy = _op.OpStrategy()
    if inputs[0].dtype in ("int8", "uint8"):
        compute = wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8)
        impl_name = "conv2d_NCHWc_int8.generic"
    else:
        compute = wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc)
        impl_name = "conv2d_NCHWc.generic"
    strategy.add_implementation(compute, schedule, name=impl_name)
    return strategy
# depthwise_conv2d_NCHWc
@override_native_generic_func("depthwise_conv2d_NCHWc_strategy")
def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
    """Generic strategy for depthwise_conv2d_NCHWc.

    Returns an ``OpStrategy`` with a single implementation named
    ``"depthwise_conv2d_NCHWc.generic"``; the compute wrapper passes both the
    data layout and the output layout to TOPI.
    """
    logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
    strategy = _op.OpStrategy()
    compute = wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True)
    schedule = wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc)
    strategy.add_implementation(compute, schedule, name="depthwise_conv2d_NCHWc.generic")
    return strategy
# conv2d_winograd_without_weight_transform
@override_native_generic_func("conv2d_winograd_without_weight_transform_strategy")
def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv2d_winograd_without_weight_transform.

    There is no generic implementation, so this always raises ``ValueError``.
    """
    message = "No generic implemenation for conv2d_winograd_without_weight_transform"
    raise ValueError(message)
# conv2d_gemm_without_weight_transform
@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv2d_gemm_without_weight_transform.

    There is no generic implementation, so this always raises ``ValueError``.
    """
    message = "No generic implemenation for conv2d_gemm_without_weight_transform"
    raise ValueError(message)
# conv2d_winograd_weight_transform
@generic_func
def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
    """Build the generic schedule for conv2d_winograd_weight_transform (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_conv2d_winograd_weight_transform(outs)
        return schedule
# conv2d_winograd_nnpack_weight_transform
@generic_func
def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
    """Build the generic schedule for conv2d_winograd_nnpack_weight_transform.

    ``attrs`` is unused; the schedule is created from ``outs`` inside the
    ``target`` context.
    """
    with target:
        schedule = topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
        return schedule
# conv2d_gemm_weight_transform
@generic_func
def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
    """Build the generic schedule for conv2d_gemm_weight_transform (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_conv2d_gemm_weight_transform(outs)
        return schedule
# deformable_conv2d
def wrap_compute_deformable_conv2d(topi_compute):
    """Wrap a deformable_conv2d TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function asserts that ``attrs.data_layout`` is ``"NCHW"`` and calls
    ``topi_compute(inputs[0], inputs[1], inputs[2], strides, padding, dilation,
    deformable_groups, groups, out_dtype)``, returning the result in a
    one-element list.
    """

    def _compute_deformable_conv2d(attrs, inputs, out_type):
        assert attrs.data_layout == "NCHW"
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        dilation = get_const_tuple(attrs.dilation)
        deformable_groups = attrs.deformable_groups
        groups = attrs.groups
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        call_args = (
            inputs[0],
            inputs[1],
            inputs[2],
            strides,
            padding,
            dilation,
            deformable_groups,
            groups,
            out_dtype,
        )
        return [topi_compute(*call_args)]

    return _compute_deformable_conv2d
@override_native_generic_func("deformable_conv2d_strategy")
def deformable_conv2d_strategy(attrs, inputs, out_type, target):
    """Generic strategy for deformable_conv2d.

    Only the ``"NCHW"`` data layout is accepted (checked with an assertion).
    Returns an ``OpStrategy`` with a single implementation named
    ``"deformable_conv2d.generic"``.
    """
    logger.warning("deformable_conv2d is not optimized for this platform.")
    data_layout = attrs.data_layout
    assert data_layout == "NCHW"
    strategy = _op.OpStrategy()
    compute = wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw)
    schedule = wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw)
    strategy.add_implementation(compute, schedule, name="deformable_conv2d.generic")
    return strategy
# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
    """Wrap a conv2d_transpose TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1], strides,
    padding, out_dtype, output_padding)`` and returns the result in a
    one-element list.
    """

    def compute_conv2d_transpose(attrs, inputs, out_type):
        """Compute definition of conv2d_transpose"""
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        output_padding = get_const_tuple(attrs.output_padding)
        result = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
        return [result]

    return compute_conv2d_transpose
@override_native_generic_func("conv2d_transpose_strategy")
def conv2d_transpose_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv2d_transpose.

    Asserts that the data layout is ``"NCHW"``, the dilation is ``(1, 1)`` and
    ``groups`` is 1, then returns an ``OpStrategy`` with a single implementation
    named ``"conv2d_transpose_nchw.generic"``.
    """
    logger.warning("conv2d_transpose is not optimized for this platform.")
    data_layout = attrs.data_layout
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    assert data_layout == "NCHW", "only support nchw for now"
    assert dilation == (1, 1), "not support dilate now"
    assert groups == 1, "only support groups == 1 for now"
    strategy = _op.OpStrategy()
    compute = wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw)
    schedule = wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw)
    strategy.add_implementation(compute, schedule, name="conv2d_transpose_nchw.generic")
    return strategy
# conv3d_transpose
def wrap_compute_conv3d_transpose(topi_compute):
    """Wrap a conv3d_transpose TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1], strides,
    padding, out_dtype, output_padding)`` and returns the result in a
    one-element list.
    """

    def compute_conv3d_transpose(attrs, inputs, out_type):
        """Compute definition of conv3d_transpose"""
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        output_padding = get_const_tuple(attrs.output_padding)
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        result = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
        return [result]

    return compute_conv3d_transpose
@override_native_generic_func("conv3d_transpose_strategy")
def conv3d_transpose_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv3d_transpose.

    Asserts that the data layout is ``"NCDHW"``, the dilation is ``(1, 1, 1)``
    and ``groups`` is 1, then returns an ``OpStrategy`` with a single
    implementation named ``"conv3d_transpose_ncdhw.generic"``.
    """
    logger.warning("conv3d_transpose is not optimized for this platform.")
    data_layout = attrs.data_layout
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    assert data_layout == "NCDHW", "only support ncdhw for now"
    assert dilation == (1, 1, 1), "not support dilate now"
    assert groups == 1, "only support groups == 1 for now"
    strategy = _op.OpStrategy()
    compute = wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw)
    schedule = wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw)
    strategy.add_implementation(compute, schedule, name="conv3d_transpose_ncdhw.generic")
    return strategy
# conv3d
def wrap_compute_conv3d(topi_compute, need_layout=False):
    """Wrap a conv3d TOPI compute into an ``(attrs, inputs, out_type)`` function.

    Parameters
    ----------
    topi_compute : callable
        Called as ``topi_compute(data, kernel, strides, padding, dilation,
        [layout], out_dtype)``.
    need_layout : bool
        Pass ``attrs.data_layout`` to ``topi_compute``.

    Returns
    -------
    callable
        Compute function returning a one-element list with the TOPI result. It
        raises ``ValueError`` if any dilation value is smaller than 1 or if
        ``attrs.groups`` is not 1.
    """

    def _compute_conv3d(attrs, inputs, out_type):
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        dilation = get_const_tuple(attrs.dilation)
        groups = attrs.groups
        layout = attrs.data_layout
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype

        dilation_d, dilation_h, dilation_w = dilation
        if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
            raise ValueError("Dilation should be positive value")
        if groups != 1:
            raise ValueError("Not support arbitrary group number for conv3d")

        call_args = [inputs[0], inputs[1], strides, padding, dilation]
        if need_layout:
            call_args.append(layout)
        call_args.append(out_dtype)
        return [topi_compute(*call_args)]

    return _compute_conv3d
@override_native_generic_func("conv3d_strategy")
def conv3d_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv3d.

    Supports the ``"NCDHW"`` and ``"NDHWC"`` data layouts and raises
    ``ValueError`` for any other layout.
    """
    logger.warning("conv3d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    layout = attrs.data_layout
    if layout == "NCDHW":
        compute = wrap_compute_conv3d(topi.nn.conv3d_ncdhw)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw)
        impl_name = "conv3d_ncdhw.generic"
    elif layout == "NDHWC":
        compute = wrap_compute_conv3d(topi.nn.conv3d_ndhwc)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv3d_ndhwc)
        impl_name = "conv3d_ndhwc.generic"
    else:
        raise ValueError("Not support this layout {} yet".format(layout))
    strategy.add_implementation(compute, schedule, name=impl_name)
    return strategy
# conv3d_winograd_without_weight_transform
@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv3d_winograd_without_weight_transform.

    There is no generic implementation, so this always raises ``ValueError``.
    """
    message = "No generic implemenation for conv3d_winograd_without_weight_transform"
    raise ValueError(message)
# conv3d_winograd_weight_transform
@generic_func
def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
    """Build the generic schedule for conv3d_winograd_weight_transform (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_conv3d_winograd_weight_transform(outs)
        return schedule
# conv1d
def wrap_compute_conv1d(topi_compute):
    """Wrap a conv1d TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1], strides,
    padding, dilation, out_dtype)`` and returns the result in a one-element list.
    """

    def _compute_conv1d(attrs, inputs, out_type):
        """Compute definition of conv1d"""
        strides = get_const_tuple(attrs.strides)
        padding = get_const_tuple(attrs.padding)
        dilation = get_const_tuple(attrs.dilation)
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        result = topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype)
        return [result]

    return _compute_conv1d
@override_native_generic_func("conv1d_strategy")
def conv1d_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv1d.

    Supports the ``"NCW"`` and ``"NWC"`` data layouts.

    Raises
    ------
    ValueError
        If the first dilation value is smaller than 1, or if the data layout is
        not supported.
    """
    logger.warning("conv1d is not optimized for this platform.")
    layout = attrs.data_layout
    dilation = get_const_tuple(attrs.dilation)
    if dilation[0] < 1:
        raise ValueError("dilation should be a positive value")
    strategy = _op.OpStrategy()
    if layout == "NCW":
        compute = wrap_compute_conv1d(topi.nn.conv1d_ncw)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv1d_ncw)
        impl_name = "conv1d_ncw.generic"
    elif layout == "NWC":
        compute = wrap_compute_conv1d(topi.nn.conv1d_nwc)
        schedule = wrap_topi_schedule(topi.generic.schedule_conv1d_nwc)
        impl_name = "conv1d_nwc.generic"
    else:
        raise ValueError("Unsupported conv1d layout {}".format(layout))
    strategy.add_implementation(compute, schedule, name=impl_name)
    return strategy
# conv1d_transpose
def wrap_compute_conv1d_transpose(topi_compute):
    """Wrap a conv1d_transpose TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1], strides,
    padding, out_dtype, output_padding)`` and returns the result in a
    one-element list.
    """

    def _compute_conv1d_transpose(attrs, inputs, out_type):
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        output_padding = get_const_tuple(attrs.output_padding)
        result = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
        return [result]

    return _compute_conv1d_transpose
@override_native_generic_func("conv1d_transpose_strategy")
def conv1d_transpose_strategy(attrs, inputs, out_type, target):
    """Generic strategy for conv1d_transpose.

    Asserts that the data layout is ``"NCW"``, the dilation is ``(1,)`` and
    ``groups`` is 1, then returns an ``OpStrategy`` with a single implementation
    named ``"conv1d_transpose_ncw.generic"``.
    """
    logger.warning("conv1d_transpose is not optimized for this platform.")
    strategy = _op.OpStrategy()
    data_layout = attrs.data_layout
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    assert data_layout == "NCW", "conv1d_transpose ncw only supported"
    assert dilation == (1,), "conv1d_transpose dilation is not supported"
    assert groups == 1, "conv1d_transpose groups == 1 only supported"
    compute = wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw)
    schedule = wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw)
    strategy.add_implementation(compute, schedule, name="conv1d_transpose_ncw.generic")
    return strategy
# dilation2d
def wrap_compute_dilation2d(topi_compute, need_data_layout=False):
    """Wrap a dilation2d TOPI compute into an ``(attrs, inputs, out_type)`` function.

    Parameters
    ----------
    topi_compute : callable
        Called as ``topi_compute(inputs[0], inputs[1], strides, padding,
        dilations, [data_layout], out_dtype)``.
    need_data_layout : bool
        Pass the ``data_layout`` attribute to ``topi_compute``.

    Returns
    -------
    callable
        Compute function returning a one-element list with the TOPI result.
    """

    def _compute_dilation2d(attrs, inputs, out_type):
        padding = get_const_tuple(attrs.padding)
        strides = get_const_tuple(attrs.strides)
        dilations = get_const_tuple(attrs.dilations)
        data_layout = attrs.get_str("data_layout")
        out_dtype = attrs.out_dtype
        # "same" or an empty dtype means: use the dtype of the first input.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        call_args = [inputs[0], inputs[1], strides, padding, dilations]
        if need_data_layout:
            call_args.append(data_layout)
        call_args.append(out_dtype)
        return [topi_compute(*call_args)]

    return _compute_dilation2d
@override_native_generic_func("dilation2d_strategy")
def dilation2d_strategy(attrs, inputs, out_type, target):
    """Generic strategy for dilation2d.

    Supports the ``"NCHW"`` data layout (kernel layout ``"IHW"``) and the
    ``"NHWC"`` data layout (kernel layout ``"HWI"``); the layouts are checked
    with assertions.

    Raises
    ------
    ValueError
        If either dilation value is smaller than 1.
    RuntimeError
        If the data layout is neither ``"NCHW"`` nor ``"NHWC"`` (only reachable
        when the layout assertion is disabled).
    """
    logger.warning("dilation2d_strategy is not optimized for this platform.")
    strategy = _op.OpStrategy()
    dilations = get_const_tuple(attrs.dilations)
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    assert layout in ["NCHW", "NHWC"]
    dilation_h, dilation_w = dilations
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")
    if layout == "NCHW":
        assert kernel_layout == "IHW"
        compute = wrap_compute_dilation2d(topi.image.dilation2d_nchw)
        schedule = wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw)
        impl_name = "dilation2d_nchw.generic"
    elif layout == "NHWC":
        assert kernel_layout == "HWI"
        compute = wrap_compute_dilation2d(topi.image.dilation2d_nhwc)
        schedule = wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc)
        impl_name = "dilation2d_nhwc.generic"
    else:
        raise RuntimeError("Unsupported dilation2d layout {}".format(layout))
    strategy.add_implementation(compute, schedule, name=impl_name)
    return strategy
# dense
def wrap_compute_dense(topi_compute):
    """Wrap a dense TOPI compute into an ``(attrs, inputs, out_type)`` function.

    Parameters
    ----------
    topi_compute : callable
        Called as ``topi_compute(inputs[0], inputs[1], None, out_dtype)``.

    Returns
    -------
    callable
        Compute function returning a one-element list with the TOPI result.
    """

    def _compute_dense(attrs, inputs, out_type):
        """Compute definition of dense"""
        out_dtype = attrs.out_dtype
        # Treat both "" and "same" as "use the dtype of the first input", like the
        # convolution wrappers in this file do; previously only "" was recognized.
        if out_dtype in ("same", ""):
            out_dtype = inputs[0].dtype
        # The third positional argument is always passed as None.
        return [topi_compute(inputs[0], inputs[1], None, out_dtype)]

    return _compute_dense
@override_native_generic_func("dense_strategy")
def dense_strategy(attrs, inputs, out_type, target):
    """Generic strategy for dense.

    Returns an ``OpStrategy`` with a single implementation named
    ``"dense.generic"``.
    """
    logger.warning("dense is not optimized for this platform.")
    strategy = _op.OpStrategy()
    compute = wrap_compute_dense(topi.nn.dense)
    schedule = wrap_topi_schedule(topi.generic.schedule_dense)
    strategy.add_implementation(compute, schedule, name="dense.generic")
    return strategy
# batch_matmul
def wrap_compute_batch_matmul(topi_compute):
    """Wrap a batch_matmul TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1],
    out_type.shape)`` and returns the result in a one-element list; ``attrs`` is
    unused.
    """

    def _compute_batch_matmul(attrs, inputs, out_type):
        lhs = inputs[0]
        rhs = inputs[1]
        result = topi_compute(lhs, rhs, out_type.shape)
        return [result]

    return _compute_batch_matmul
@override_native_generic_func("batch_matmul_strategy")
def batch_matmul_strategy(attrs, inputs, out_type, target):
    """Generic strategy for batch_matmul.

    Returns an ``OpStrategy`` with a single implementation named
    ``"batch_matmul.generic"``.
    """
    logger.warning("batch_matmul is not optimized for this platform.")
    strategy = _op.OpStrategy()
    compute = wrap_compute_batch_matmul(topi.nn.batch_matmul)
    schedule = wrap_topi_schedule(topi.generic.schedule_batch_matmul)
    strategy.add_implementation(compute, schedule, name="batch_matmul.generic")
    return strategy
# sparse dense
def wrap_compute_sparse_dense(topi_compute):
    """Wrap a sparse dense TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function forwards the first four inputs positionally to
    ``topi_compute`` and returns the result in a one-element list; ``attrs`` and
    ``out_type`` are unused.
    """

    def _compute_sparse_dense(attrs, inputs, out_type):
        operands = [inputs[index] for index in range(4)]
        return [topi_compute(*operands)]

    return _compute_sparse_dense
@override_native_generic_func("sparse_dense_strategy")
def sparse_dense_strategy(attrs, inputs, out_type, target):
    """Generic strategy for sparse dense.

    Returns an ``OpStrategy`` with a single implementation named
    ``"sparse_dense.generic"``.
    """
    logger.warning("sparse dense is not optimized for this platform.")
    strategy = _op.OpStrategy()
    compute = wrap_compute_sparse_dense(topi.nn.sparse_dense)
    schedule = wrap_topi_schedule(topi.generic.schedule_sparse_dense)
    strategy.add_implementation(compute, schedule, name="sparse_dense.generic")
    return strategy
@override_native_generic_func("sparse_dense_padded_strategy")
def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
    """Generic strategy for sparse dense padded.

    There is no generic implementation, so this always raises
    ``NotImplementedError``.
    """
    message = "sparse_dense_padded is only implemented for cuda"
    raise NotImplementedError(message)
# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
    """Build the generic schedule for the sparse_transpose op (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_sparse_transpose(outs)
        return schedule
# argsort
def wrap_compute_argsort(topi_compute):
    """Wrap an argsort TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], axis=..., is_ascend=...,
    dtype=...)`` with values read from ``attrs`` and returns the result in a
    one-element list.
    """

    def _compute_argsort(attrs, inputs, _):
        axis = get_const_int(attrs.axis)
        # The attribute is read as a constant int and converted to bool.
        is_ascend = bool(get_const_int(attrs.is_ascend))
        dtype = attrs.dtype
        result = topi_compute(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)
        return [result]

    return _compute_argsort
@override_native_generic_func("argsort_strategy")
def argsort_strategy(attrs, inputs, out_type, target):
    """Generic strategy for argsort.

    Returns an ``OpStrategy`` with a single implementation named
    ``"argsort.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_argsort(topi.argsort)
    schedule = wrap_topi_schedule(topi.generic.schedule_argsort)
    strategy.add_implementation(compute, schedule, name="argsort.generic")
    return strategy
# topk
def wrap_compute_topk(topi_compute):
    """Wrap a topk TOPI compute into an ``(attrs, inputs, out_type)`` function.

    ``k`` is taken from ``attrs.k`` when it is not ``None`` and from
    ``inputs[1]`` otherwise. The returned function always returns a list: a list
    result from ``topi_compute`` is returned as-is, any other result is wrapped
    in a one-element list.
    """

    def _compute_topk(attrs, inputs, out_type):
        k = attrs.k if attrs.k is not None else inputs[1]
        axis = get_const_int(attrs.axis)
        ret_type = attrs.ret_type
        is_ascend = bool(get_const_int(attrs.is_ascend))
        dtype = attrs.dtype
        result = topi_compute(inputs[0], k, axis, ret_type, is_ascend, dtype)
        if isinstance(result, list):
            return result
        return [result]

    return _compute_topk
@override_native_generic_func("topk_strategy")
def topk_strategy(attrs, inputs, out_type, target):
    """Generic strategy for topk.

    Returns an ``OpStrategy`` with a single implementation named
    ``"topk.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_topk(topi.topk)
    schedule = wrap_topi_schedule(topi.generic.schedule_topk)
    strategy.add_implementation(compute, schedule, name="topk.generic")
    return strategy
# multibox_prior
def wrap_compute_multibox_prior(topi_compute):
    """Wrap a multibox_prior TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], sizes, ratios, steps,
    offsets, clip)`` with values read from ``attrs`` and returns the result in a
    one-element list.
    """

    def _compute_multibox_prior(attrs, inputs, _):
        """Compute definition of multibox_prior"""
        sizes = get_float_tuple(attrs.sizes)
        ratios = get_float_tuple(attrs.ratios)
        steps = get_float_tuple(attrs.steps)
        offsets = get_float_tuple(attrs.offsets)
        clip = bool(get_const_int(attrs.clip))
        result = topi_compute(inputs[0], sizes, ratios, steps, offsets, clip)
        return [result]

    return _compute_multibox_prior
@override_native_generic_func("multibox_prior_strategy")
def multibox_prior_strategy(attrs, inputs, out_type, target):
    """Generic strategy for multibox_prior.

    Returns an ``OpStrategy`` with a single implementation named
    ``"multibox_prior.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior)
    schedule = wrap_topi_schedule(topi.generic.schedule_multibox_prior)
    strategy.add_implementation(compute, schedule, name="multibox_prior.generic")
    return strategy
# multibox_transform_loc
def wrap_compute_multibox_transform_loc(topi_compute):
    """Wrap a multibox_transform_loc TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], inputs[1], inputs[2],
    clip, threshold, variances)`` with values read from ``attrs``. The TOPI
    result is returned as-is, i.e. it is not wrapped in a list here.
    """

    def _compute_multibox_transform_loc(attrs, inputs, _):
        """Compute definition of multibox_transform_loc"""
        clip = bool(get_const_int(attrs.clip))
        threshold = get_const_float(attrs.threshold)
        variances = get_float_tuple(attrs.variances)
        outputs = topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
        return outputs

    return _compute_multibox_transform_loc
@override_native_generic_func("multibox_transform_loc_strategy")
def multibox_transform_loc_strategy(attrs, inputs, out_type, target):
    """Generic strategy for multibox_transform_loc.

    Returns an ``OpStrategy`` with a single implementation named
    ``"multibox_transform_loc.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_multibox_transform_loc(topi.vision.ssd.multibox_transform_loc)
    schedule = wrap_topi_schedule(topi.generic.schedule_multibox_transform_loc)
    strategy.add_implementation(compute, schedule, name="multibox_transform_loc.generic")
    return strategy
# get_valid_counts
def wrap_compute_get_valid_counts(topi_compute):
    """Wrap a get_valid_counts TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function calls ``topi_compute(inputs[0], score_threshold,
    id_index, score_index)`` with values read from ``attrs``. The TOPI result is
    returned as-is, i.e. it is not wrapped in a list here.
    """

    def _compute_get_valid_counts(attrs, inputs, out_type):
        score_threshold = get_const_float(attrs.score_threshold)
        id_index = get_const_int(attrs.id_index)
        score_index = get_const_int(attrs.score_index)
        outputs = topi_compute(inputs[0], score_threshold, id_index, score_index)
        return outputs

    return _compute_get_valid_counts
@override_native_generic_func("get_valid_counts_strategy")
def get_valid_counts_strategy(attrs, inputs, out_type, target):
    """Generic strategy for get_valid_counts.

    Returns an ``OpStrategy`` with a single implementation named
    ``"get_valid_counts.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_get_valid_counts(topi.vision.get_valid_counts)
    schedule = wrap_topi_schedule(topi.generic.schedule_get_valid_counts)
    strategy.add_implementation(compute, schedule, name="get_valid_counts.generic")
    return strategy
# non-maximum suppression
def wrap_compute_nms(topi_compute):
    """Wrap a non-maximum suppression TOPI compute into an ``(attrs, inputs, out_type)`` function.

    ``max_output_size`` is taken from ``attrs.max_output_size`` when it is not
    ``None`` and from ``inputs[3]`` otherwise. When ``attrs.return_indices`` is
    true the TOPI result is returned as-is; otherwise it is wrapped in a
    one-element list.
    """

    def _compute_nms(attrs, inputs, out_type):
        # inputs[3] is read first and then overridden by the attribute, if set.
        max_output_size = inputs[3]
        if attrs.max_output_size is not None:
            max_output_size = attrs.max_output_size
        return_indices = bool(get_const_int(attrs.return_indices))
        iou_threshold = get_const_float(attrs.iou_threshold)
        force_suppress = bool(get_const_int(attrs.force_suppress))
        top_k = get_const_int(attrs.top_k)
        coord_start = get_const_int(attrs.coord_start)
        score_index = get_const_int(attrs.score_index)
        id_index = get_const_int(attrs.id_index)
        invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
        result = topi_compute(
            inputs[0],
            inputs[1],
            inputs[2],
            max_output_size,
            iou_threshold,
            force_suppress,
            top_k,
            coord_start,
            score_index,
            id_index,
            return_indices,
            invalid_to_bottom,
        )
        # The call is identical in both modes; only the wrapping of the result differs.
        return result if return_indices else [result]

    return _compute_nms
@override_native_generic_func("non_max_suppression_strategy")
def nms_strategy(attrs, inputs, out_type, target):
    """Generic strategy for non-maximum suppression.

    Returns an ``OpStrategy`` with a single implementation named
    ``"nms.generic"``.
    """
    strategy = _op.OpStrategy()
    compute = wrap_compute_nms(topi.vision.non_max_suppression)
    schedule = wrap_topi_schedule(topi.generic.schedule_nms)
    strategy.add_implementation(compute, schedule, name="nms.generic")
    return strategy
# roi_align
def wrap_compute_roi_align(topi_compute):
    """Wrap a roi_align TOPI compute into an ``(attrs, inputs, out_type)`` function.

    The returned function asserts that ``attrs.layout`` is ``"NCHW"`` and calls
    ``topi_compute(inputs[0], inputs[1], pooled_size=..., spatial_scale=...,
    sample_ratio=...)`` with values read from ``attrs``, returning the result in
    a one-element list.
    """

    def _compute_roi_align(attrs, inputs, out_type):
        assert attrs.layout == "NCHW"
        pooled_size = get_const_tuple(attrs.pooled_size)
        result = topi_compute(
            inputs[0],
            inputs[1],
            pooled_size=pooled_size,
            spatial_scale=attrs.spatial_scale,
            sample_ratio=attrs.sample_ratio,
        )
        return [result]

    return _compute_roi_align
@override_native_generic_func("roi_align_strategy")
def roi_align_strategy(attrs, inputs, out_type, target):
    """Generic strategy for roi_align.

    Only the ``"NCHW"`` layout is accepted (checked with an assertion). Returns
    an ``OpStrategy`` with a single implementation named ``"roi_align.generic"``.
    """
    strategy = _op.OpStrategy()
    roi_layout = attrs.layout
    assert roi_layout == "NCHW", "only support nchw for now"
    compute = wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw)
    schedule = wrap_topi_schedule(topi.generic.schedule_roi_align)
    strategy.add_implementation(compute, schedule, name="roi_align.generic")
    return strategy
# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
    """Build the generic schedule for the roi_pool op (``attrs`` is unused)."""
    with target:
        schedule = topi.generic.schedule_roi_pool(outs)
        return schedule
# proposal
def wrap_compute_proposal(topi_compute):
"""wrap proposal topi compute"""
def _compute_proposal(attrs, inputs, out_type):
scales = get_float_tuple(attrs.scales)
ratios = get_float_tuple(attrs.ratios)
feature_stride = attrs.feature_stride
threshold = attrs.threshold
rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
rpn_min_size = attrs.rpn_min_size
iou_loss = bool(get_const_int(attrs.iou_loss))
return [
topi_compute(
inputs[0],
inputs[1],
inputs[2],
scales,