-
Notifications
You must be signed in to change notification settings - Fork 434
/
Copy pathtensor.py
3953 lines (3428 loc) · 204 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
"""
tensor
"""
import logging
import sys
import numpy as np
from onnx import onnx_pb, helper
from onnx.onnx_pb import TensorProto
from tf2onnx import constants, utils
from tf2onnx.graph_builder import GraphBuilder
from tf2onnx.handler import tf_op
from tf2onnx.onnx_opset import nn, math
from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
logger = logging.getLogger(__name__)
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
def _convert_shapenode_to_int64(ctx, node, input_number):
    """Insert a Cast to INT64 in front of ``node.input[input_number]``.

    The new Cast output is typed INT64 and inherits the shape recorded for the
    original input tensor.
    """
    int64 = onnx_pb.TensorProto.INT64
    shape_input = node.input[input_number]
    casted = ctx.insert_new_node_on_input(node, "Cast", shape_input, to=int64).output[0]
    ctx.set_dtype(casted, int64)
    ctx.copy_shape(shape_input, casted)
def _wrap_concat_with_cast(ctx, node):
    """Wrap a Concat in Casts for opset < 8, where Concat only supports float types.

    If the Concat output dtype is neither FLOAT nor FLOAT16:
      * every input is cast to FLOAT,
      * a Cast back to the original dtype is appended to the output, unless the
        output has exactly one consumer and that consumer is already a Cast,
      * the Concat output itself is retyped to FLOAT.
    Otherwise the node is left untouched.
    """
    supported_types = [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT16]
    dtype = ctx.get_dtype(node.output[0])
    if dtype in supported_types:
        return
    output_name = node.output[0]
    # cast each input to float
    for i in range(len(node.input)):
        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
                                                  to=onnx_pb.TensorProto.FLOAT)
        ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
    next_nodes = ctx.find_output_consumers(output_name)
    # Cast the output back to dtype unless its single consumer is already a Cast.
    # With no consumers (e.g. a graph output) or several consumers, the cast back
    # is required, otherwise those consumers would see FLOAT instead of dtype.
    if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
                                                    to=dtype)
        ctx.set_dtype(output_cast.output[0], dtype)
        ctx.copy_shape(output_name, output_cast.output[0])
    ctx.set_dtype(node.output[0], onnx_pb.TensorProto.FLOAT)
@tf_op("Size")
class Size:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Map TF Size onto ONNX Size.

        TF size can output int32 or int64 but onnx only does int 64: the node's
        output is retyped to INT64 and, if TF requested another dtype, a Cast back
        to that dtype is appended.
        """
        out = node.output[0]
        wanted_dtype = ctx.get_dtype(out)
        if wanted_dtype == onnx_pb.TensorProto.INT64:
            return
        ctx.set_dtype(out, onnx_pb.TensorProto.INT64)
        cast = ctx.insert_new_node_on_output("Cast", out, name=node.child_name(), to=wanted_dtype)
        ctx.set_dtype(cast.output[0], wanted_dtype)
        ctx.copy_shape(out, cast.output[0])
@tf_op("Flatten")
class Flatten:
    """Flatten needs no graph rewriting; every opset handler leaves the node unchanged."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """No conversion work: the node is kept as-is."""

    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Opset 9: same handling as opset 1."""
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Opset 11: same handling as opset 1."""
        cls.version_1(ctx, node, **kwargs)
@tf_op("Dropout")
class Dropout:
    """Dropout is left untouched for every supported opset; all handlers are no-ops."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """No rewrite needed."""

    @classmethod
    def version_6(cls, ctx, node, **kwargs):
        """No rewrite needed."""

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        """No rewrite needed."""

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """No rewrite needed."""

    @classmethod
    def version_12(cls, ctx, node, **kwargs):
        """No rewrite needed."""
@tf_op("Identity")
class Identity:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Fold away an Identity fed by a constant, unless its output is a graph output.

        Consumers of the Identity output are rewired to the constant directly and
        the Identity node is removed.
        """
        if not node.inputs[0].is_const():
            return
        src, dst = node.input[0], node.output[0]
        # should not remove the identity node if it is output of the graph
        if dst in ctx.outputs:
            return
        ctx.replace_all_inputs(dst, src)  # ops=ctx.get_nodes()
        ctx.remove_node(node.name)
@tf_op("IdentityN")
class IdentityN:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Remove IdentityN and point the consumers of each output at the matching input."""
        ctx.remove_node(node.name)
        for pair in zip(node.input, node.output):
            src, dst = pair
            ctx.replace_all_inputs(dst, src)  # ops=ctx.get_nodes()
@tf_op("EnsureShape")
class EnsureShape:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Rewrite EnsureShape into a plain ONNX Identity node."""
        # Only the op type changes; inputs, outputs and attributes are kept.
        node.type = "Identity"
@tf_op("Reshape")
class Reshape:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Opset 1: the target shape must be constant and is moved into the "shape" attribute."""
        # T output = Reshape(T tensor, Tshape shape, @type Tshape)
        # T reshaped = Reshape(T data, @INTS shape) - but takes a optional 2nd input for shape
        shape = node.inputs[1].get_tensor_value()
        if shape is None:
            logger.error("Reshape on node %s does not have a const shape", node.name)
            return
        ctx.remove_input(node, node.input[1], 1)
        node.set_attr("shape", shape)
        ctx.set_shape(node.output[0], shape)

    @classmethod
    def version_5(cls, ctx, node, **kwargs):
        """Opset 5+: shape stays an input (cast to INT64).

        A reshape to the input's own known const shape becomes an Identity. Below
        opset 8, integer reshapes are wrapped in FLOAT casts.
        """
        out_dtype = ctx.get_dtype(node.output[0])
        if node.inputs[1].is_const():
            wanted_shape = node.inputs[1].get_tensor_value(as_list=True)
            current_shape = ctx.get_shape(node.input[0])
            if current_shape is not None and current_shape == wanted_shape:
                # Remove useless Reshape
                node.type = "Identity"
                ctx.replace_inputs(node, [node.input[0]])
                return

        integer_types = [onnx_pb.TensorProto.INT32,
                         onnx_pb.TensorProto.INT16,
                         onnx_pb.TensorProto.INT64]
        is_integer = out_dtype in integer_types

        # onnx wants reshape.input[1] to have the value be int64 which is not the case for tensorflow.
        _convert_shapenode_to_int64(ctx, node, 1)

        if ctx.opset < 8 and is_integer:
            # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
            float_type = onnx_pb.TensorProto.FLOAT
            ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=float_type)
            consumers = ctx.find_output_consumers(node.output[0])
            # if the next node is already a cast we don't need to insert another one
            already_cast = len(consumers) == 1 and consumers[0].type == "Cast"
            if not already_cast:
                back_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
                                                          to=out_dtype)
                ctx.set_dtype(back_cast.output[0], out_dtype)
                ctx.copy_shape(node.output[0], back_cast.output[0])
            ctx.set_dtype(node.output[0], float_type)
@tf_op("Squeeze")
class Squeeze:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Move TF squeeze_dims into ONNX axes (attribute before opset 13, input from 13).

        Below opset 11, negative axes are made non-negative using the input rank.
        """
        # T output = Squeeze(T input, @list(int) squeeze_dims)
        # T squeezed = Squeeze(T data, @AttrType.INTS axes), axes are list of positive integers.
        axes = node.get_attr_value("squeeze_dims")
        if axes is not None:
            del node.attr["squeeze_dims"]
        else:
            axes = []

        # TF uses empty axes to indicate that all 1 dims should be squeezed
        if not axes:
            return

        if ctx.opset < 11 and any(a < 0 for a in axes):
            in_shape = ctx.get_shape(node.input[0])
            utils.make_sure(in_shape is not None, "squeeze with negative axes and unknown rank requires opset >= 11")
            rank = len(in_shape)
            axes = [a + rank if a < 0 else a for a in axes]

        if ctx.opset >= 13:
            axes_const = ctx.make_const(utils.make_name("axes_const"), np.array(axes, dtype=np.int64))
            ctx.replace_inputs(node, [node.input[0], axes_const.output[0]])
        else:
            node.set_attr("axes", axes)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # Opset 11 supports negative axis, but core logic is same
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        # Opset 13: parameters moved to inputs
        cls.version_1(ctx, node, **kwargs)
@tf_op("Transpose")
class Transpose:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Move a constant perm input into the "perm" attribute; a dynamic perm is rejected."""
        # T y = Transpose(T x, Tperm perm, @type Tperm)
        # T transposed = Transpose(T data, @INTS perm)
        if len(node.input) <= 1:
            # graph rewrite moved perm to attribute
            return
        perm_node = node.inputs[1]
        utils.make_sure(perm_node.is_const(), "perm can't be dynamic in ONNX")
        # perms is passed as const
        dims = perm_node.get_tensor_value()
        ctx.remove_input(node, node.input[1], 1)
        node.set_attr("perm", dims)
@tf_op("Concat")
class Concat:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert legacy TF Concat, whose axis is input 0, to ONNX Concat with an axis attribute.

        Raises (via utils.make_sure) when the axis is not constant, or when it is
        negative and the shape of the first concatenated input is unknown.
        """
        # old concat op has axis as input[0]
        node.type = "Concat"
        axis_node = node.inputs[0]
        utils.make_sure(axis_node.is_const(), "{} needs to be const".format(axis_node.name))
        axis_val = axis_node.get_tensor_value()
        ctx.remove_input(node, node.input[0], 0)
        if axis_val < 0:  # onnxruntime does not support -1 axis, but TF supports.
            # after removing the axis input, input[0] is the first tensor being concatenated
            input_shape = ctx.get_shape(node.input[0])
            utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
            axis_val = len(input_shape) + axis_val
        node.set_attr("axis", axis_val)
        if ctx.opset < 8:
            # opset < 8: might need to wrap concat in casts since only float is supported
            _wrap_concat_with_cast(ctx, node)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # Opset 11 supports negative axis, but core logic is same
        cls.version_1(ctx, node, **kwargs)
@tf_op("ConcatV2")
class ConcatV2:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert ConcatV2 (axis is the last input) to ONNX Concat with an axis attribute.

        Constant inputs with zero elements are dropped first. Raises RuntimeError
        if that leaves no data input at all.
        """
        # T output = ConcatV2(T values, Tidx axis, @int N, @type Tidx)
        # T concat_result = Concat(T inputs, @INT axis)
        # if any input is empty, remove the input and concat the others
        # NOTE: workaround for https://github.com/Microsoft/onnxruntime/issues/681
        node.type = "Concat"
        # Only data inputs are candidates; the last input is the axis and must stay.
        removed_indices = [i for i, inp in enumerate(node.inputs[:-1])
                           if inp.is_const() and inp.get_tensor_value(as_list=False).size == 0]
        for i in reversed(removed_indices):
            ctx.remove_input(node, node.input[i], i)
        # all data inputs are deleted: only the axis input is left
        if len(node.input) <= 1:
            raise RuntimeError("all inputs of {} are empty".format(node.name))
        axis_node = node.inputs[-1]
        utils.make_sure(axis_node.is_const(), "{} needs to be const".format(axis_node.name))
        axis_val = axis_node.get_tensor_value()
        ctx.remove_input(node, node.input[-1], len(node.input) - 1)
        if axis_val < 0:  # onnxruntime does not support -1 axis, but TF supports.
            input_shape = ctx.get_shape(node.input[0])
            utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
            axis_val = len(input_shape) + axis_val
        node.set_attr("axis", axis_val)
        if ctx.opset < 8:
            # opset < 8: might need to wrap concat in casts since only float is supported
            _wrap_concat_with_cast(ctx, node)
@tf_op("Slice")
class Slice:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert TF Slice(input, begin, size) to ONNX Slice(data, starts, ends).

        ends = begin + size. A size of -1 means "to the end of the axis" and is
        replaced by the largest value of the index dtype. ONNX Slice clamps ends
        that run past the axis length.
        """
        # T output = Slice(T input, Index begin, Index size)
        # T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
        # "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
        input_tensor = node.input[0]
        starts = node.input[1]
        size = node.input[2]
        # in tf, size can be -1 which means all elem are taken, so size can't be added starts directly.
        # the way to make sure size are not less than 0: set "sizes"'s elem to be int_max if elem val is -1
        size_dtype = ctx.get_dtype(size)
        size_np_dtype = utils.map_onnx_to_numpy_type(size_dtype)
        starts_node = ctx.get_node_by_output(starts)
        size_node = ctx.get_node_by_output(size)
        if size_node.is_const() and starts_node.is_const():
            starts = starts_node.get_tensor_value()
            sizes = size_node.get_tensor_value()
            ends = []
            for start_val, size_val in zip(starts, sizes):
                if size_val == -1:
                    # get all elements
                    starts_dtype = ctx.get_dtype(node.input[1])
                    utils.make_sure(starts_dtype, "dtype of {} is None".format(node.input[1]))
                    # np.iinfo needs a numpy dtype, not an ONNX TensorProto enum value
                    ends.append(np.iinfo(utils.map_onnx_to_numpy_type(starts_dtype)).max)
                else:
                    ends.append(start_val + size_val)
        else:
            neg_one_val = np.array([-1]).astype(size_np_dtype)
            neg_one = ctx.make_const(utils.make_name("const"), neg_one_val).output[0]
            int_max_val = np.array([utils.get_max_value(size_np_dtype)]).astype(size_np_dtype)
            int_max = ctx.make_const(utils.make_name("largest_int_val"), int_max_val).output[0]
            size_are_neg_one_flag = ctx.make_node("Equal", [neg_one, size]).output[0]
            size_are_neg_one_flag = ctx.make_node("Cast", [size_are_neg_one_flag], attr={"to": size_dtype}).output[0]
            value_to_add = ctx.make_node("Mul", [int_max, size_are_neg_one_flag]).output[0]
            size_processed = ctx.make_node("Add", [size, value_to_add]).output[0]
            ends = ctx.make_node("Add", [starts, size_processed]).output[0]
        ctx.remove_node(node.name)
        inputs_map = {"data": input_tensor, "starts": starts, "ends": ends}
        kwargs = {**inputs_map, "outputs": node.output}
        _ = GraphBuilder(ctx).make_slice(kwargs, name=node.name)

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)
@tf_op("Roll")
class Roll:
    @classmethod
    def any_version(cls, opset, ctx, node, **kwargs):
        """Rewrite TF Roll as two Slices joined by a Concat for each rolled axis.

        For an axis of length L and shift s, let r = L - (s mod L). The result along
        that axis is concat(data[r:L], data[0:r]). Axes are processed one after
        another. The axis input must be constant; negative axes additionally
        require a known input rank.
        """
        utils.make_sure(node.inputs[2].is_const(), "Can only convert Roll if axis is const")
        axes = node.inputs[2].get_tensor_value()
        if not isinstance(axes, list):
            axes = [axes]
        if any(a < 0 for a in axes):
            # negative axes count from the end, which needs the input rank
            rank = ctx.get_rank(node.input[0])
            utils.make_sure(rank is not None, "Roll with negative axes requires an input of known rank")
            axes = [a if a >= 0 else a + rank for a in axes]
        shifts_dtype = ctx.get_dtype(node.input[1])
        if shifts_dtype != TensorProto.INT64:
            shifts_casted = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64).output[0]
        else:
            shifts_casted = node.input[1]

        # give each rolled axis its own 1-element shift tensor
        if len(axes) == 1:
            unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
                {'data': shifts_casted, "axes": [0]}, op_name_scope=node.name, return_node=True)
            shifts_split = [unsqueeze_node.output[0]]
        else:
            shifts_split = ctx.make_node("Split", [shifts_casted], attr={'axis': 0},
                                         output_count=len(axes), op_name_scope=node.name).output

        zero_const = ctx.make_const(utils.make_name("zeros_const"), np.array([0], np.int64)).output[0]
        shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)

        data = node.input[0]
        for axis, shift in zip(axes, shifts_split):
            len_along_axis = GraphBuilder(ctx).make_slice(
                {"data": shape_node.output[0], "ends": [axis + 1], "starts": [axis]})
            shift_mod = ctx.make_node("Mod", [shift, len_along_axis]).output[0]
            remaining_len = ctx.make_node("Sub", [len_along_axis, shift_mod], op_name_scope=node.name).output[0]
            axes_const = ctx.make_const(utils.make_name("axes_const"), np.array([axis], np.int64)).output[0]
            slice_one = ctx.make_node("Slice", [data, zero_const, remaining_len, axes_const], op_name_scope=node.name)
            slice_two = ctx.make_node("Slice", [data, remaining_len, len_along_axis, axes_const],
                                      op_name_scope=node.name)
            concat_node = ctx.make_node("Concat", [slice_two.output[0], slice_one.output[0]],
                                        attr={'axis': axis}, op_name_scope=node.name)
            data = concat_node.output[0]

        ctx.replace_all_inputs(node.output[0], data)
        ctx.remove_node(node.name)

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        cls.any_version(10, ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        cls.any_version(11, ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        # Parameters moved to inputs for operator Squeeze, Unsqueeze.
        cls.any_version(13, ctx, node, **kwargs)
@tf_op("Gather")
class Gather:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """TF Gather only needs its op type set to the ONNX name."""
        node.type = "Gather"

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Opset 11 handling is identical to opset 1."""
        cls.version_1(ctx, node, **kwargs)
@tf_op("GatherV2")
class GatherV2:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert GatherV2 (axis given as input 2) to ONNX Gather with an axis attribute."""
        # for GatherV2 axis come as input
        err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
        utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
        node.type = "Gather"
        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
        axis = node.inputs[2].get_tensor_value()
        ctx.remove_input(node, node.input[2], 2)
        node.set_attr("axis", axis)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # no change
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_12(cls, ctx, node, **kwargs):
        """Opset 12+: a non-zero batch_dims is emulated with GatherND.

        GatherND takes no axis. When axis != batch_dims, the data's gathered axis is
        first transposed to sit right after the batch dims, and the GatherND result
        is transposed back afterwards. Negative axis (relative to the data rank)
        and negative batch_dims (relative to the indices rank) are normalized first.
        """
        batch_dims = node.get_attr_value("batch_dims", 0)
        if batch_dims == 0:
            cls.version_1(ctx, node, **kwargs)
            return
        # If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
        data_inp, indices_inp, axis_inp = node.input
        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
        axis = node.inputs[2].get_tensor_value()
        ctx.remove_input(node, axis_inp, 2)
        # Query ranks on the original tensors, before any Cast/Transpose is inserted.
        data_rank = ctx.get_rank(data_inp)
        indices_rank = ctx.get_rank(indices_inp)
        rank_err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
        # The permutations below and GatherND's batch_dims attribute need non-negative values.
        if axis < 0 or batch_dims < 0:
            utils.make_sure(data_rank is not None and indices_rank is not None, rank_err_msg)
            if axis < 0:
                axis += data_rank
            if batch_dims < 0:
                batch_dims += indices_rank
                node.set_attr("batch_dims", batch_dims)
        if ctx.get_dtype(indices_inp) != TensorProto.INT64:
            indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
        unperm = None
        # GatherND doesn't take an axis so we have to transpose stuff around
        if axis != batch_dims:
            # check ranks before doing any arithmetic with them
            utils.make_sure(data_rank is not None and indices_rank is not None, rank_err_msg)
            result_rank = data_rank + indices_rank - 1 - batch_dims
            shift_amt = axis - batch_dims
            perm = list(range(data_rank))
            perm = perm[:batch_dims] + perm[axis:axis+1] + perm[batch_dims:axis] + perm[axis+1:]
            data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
            ctx.replace_input(node, node.input[0], data_inp, 0)
            unperm = list(range(result_rank))
            j = indices_rank + shift_amt
            unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
        node.type = "GatherND"
        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
        ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
        if unperm is not None:
            ctx.update_node_shape_dtype(node, override=True)
            ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)
def _make_gathernd_inner_loop(ctx, params, index, dtype):
    """Create the inner Loop used by make_gathernd.

    The Loop runs Size(index) iterations. Each iteration gathers the next index
    value along axis 0 of the loop-carried tensor (initially ``params``) and
    squeezes that axis away:

        gather_cur = params
        for (int i = 0; i < size(index); i++)
            gather_cur = squeeze(gather(gather_cur, index[i]), 0)

    Args:
        ctx: graph the Loop node is added to.
        params: name of the initial loop-carried tensor.
        index: node whose output[0] holds the index row to walk.
        dtype: ONNX dtype of the loop-carried tensor.

    Returns:
        The created Loop node; its output[0] is the final loop-carried value.
    """
    scope_name = utils.make_name("gathernd_inner_loop")
    trip_node = ctx.make_node("Size", [index.output[0]])
    # np.bool was removed in NumPy 1.24; np.bool_ works in every version.
    cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool_))
    trip_name = utils.make_name("i")
    cond_name = utils.make_name("cond")
    cond_out_name = utils.make_name("cond_out")
    cur_name = utils.make_name("gather_cur")
    result_name = utils.make_name("res")

    # body graph creation
    g = ctx.create_new_graph_with_same_config()
    g.add_graph_input(trip_name, TensorProto.INT64, [1])
    g.add_graph_input(cond_name, TensorProto.BOOL, [])
    g.add_graph_input(cur_name, dtype, [])
    g.parent_graph = ctx

    index_i = g.make_node("Gather", [index.output[0], trip_name], attr={"axis": 0})
    gather = g.make_node("Gather", [cur_name, index_i.output[0]], attr={"axis": 0})
    GraphBuilder(g).make_squeeze(
        {'data': gather.output[0], "axes": [0], 'outputs': [result_name]})
    g.make_node("Identity", [cond_name], outputs=[cond_out_name])

    g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
    g.add_graph_output(result_name, dtype, [])

    branches = {"body": g}
    inner_loop = ctx.make_node("Loop",
                               [trip_node.output[0], cond_const.output[0], params],
                               op_name_scope=scope_name, skip_conversion=False, branches=branches)
    return inner_loop
def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dtypes):
    """Build a GatherNd equivalent out of Loop / Gather / Reshape nodes.

    1. ``indices`` is flattened to ``[N, K]``, where K is its last dimension and N
       is the product of the others.
    2. An outer Loop runs over the N index rows. For each row,
       ``_make_gathernd_inner_loop`` applies K successive Gathers on ``params``.
    3. The stacked Loop result is reshaped to
       ``indices.shape[:-1] + per_row_result.shape`` and written to ``output``.

    Args:
        ctx: graph the nodes are added to.
        params: name of the tensor gathered from.
        indices: name of the index tensor.
        output: name given to the final Reshape output.
        scope_name: base name for the outer Loop's name scope.
        t_params: ONNX dtype of ``params``.
        shapes: output shapes recorded on the final Reshape.
        dtypes: output dtypes recorded on the final Reshape.
    """
    # Tparams output = GatherNd(Tparams params, Tidx indices)
    scope_name = utils.make_name(scope_name)
    # reshape indices into [prod(indices.shape[:-1]), indices.shape[-1]]
    indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
    indices_size = ctx.make_node("Size", [indices])
    attr = {"axes": [0], "ends": [sys.maxsize], "starts": [-1]}
    inputs_map = {"data": indices_shape.output[0], **attr}
    inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    outer_shape = ctx.make_node("Div",
                                [indices_size.output[0], inner_shape],
                                dtypes=[TensorProto.INT64])
    flatten_shape = ctx.make_node("Concat",
                                  [outer_shape.output[0], inner_shape],
                                  attr={"axis": 0},
                                  dtypes=[TensorProto.INT64])
    flatten_indices = ctx.make_node("Reshape", [indices, flatten_shape.output[0]])

    # outer loop for each index
    # for (int i=0; i<outer_shape; i++) inner_loop(params, flatten_indices[i])
    # np.bool was removed in NumPy 1.24; np.bool_ works in every version.
    cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool_))

    # body graph creation
    g = ctx.create_new_graph_with_same_config()
    trip_name = utils.make_name("i")
    cond_name = utils.make_name("cond")
    cond_out_name = utils.make_name("cond_out")
    dummy_name = utils.make_name("dummy")
    dummy_out_name = utils.make_name("dummy_out")
    result_name = utils.make_name("res")

    g.add_graph_input(trip_name, TensorProto.INT64, [1])
    g.add_graph_input(cond_name, TensorProto.BOOL, [])
    # params is threaded through as an unused loop-carried value; the per-row
    # results are collected as the Loop's scan output (output[1]).
    g.add_graph_input(dummy_name, t_params, [])
    g.parent_graph = ctx

    index = g.make_node("Gather", [flatten_indices.output[0], trip_name], attr={"axis": 0})
    index_squeeze = GraphBuilder(g).make_squeeze(
        {'data': index.output[0], "axes": [0]}, return_node=True)
    # inner loop to gather result
    inner_loop = _make_gathernd_inner_loop(g, params, index_squeeze, t_params)
    g.make_node("Identity", [cond_name], outputs=[cond_out_name])
    g.make_node("Identity", [dummy_name], outputs=[dummy_out_name])
    g.make_node("Identity", [inner_loop.output[0]], outputs=[result_name])

    g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
    g.add_graph_output(dummy_out_name, t_params, [])
    g.add_graph_output(result_name, t_params, [])

    branches = {"body": g}
    gathernd_loop = ctx.make_node("Loop",
                                  [outer_shape.output[0], cond_const.output[0], params],
                                  output_count=2,
                                  op_name_scope=scope_name, skip_conversion=False, branches=branches)

    # reshape to target shape
    # output shape of gathernd: indices.shape[:-1] + gathernd_output.shape[1:]
    inner_loop_shape = ctx.make_node("Shape", [gathernd_loop.output[1]], dtypes=[TensorProto.INT64])
    # workaround in case gathernd_loop is 1-dimensional: append a 1 so the [1:] slice
    # below is never empty; the trailing 1 is sliced off again at the end
    one_const = ctx.make_const(utils.make_name("one"), np.array([1], dtype=np.int64))
    inner_loop_shape_ = ctx.make_node("Concat",
                                      [inner_loop_shape.output[0], one_const.output[0]],
                                      attr={"axis": 0},
                                      dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]}
    inputs_map = {"data": inner_loop_shape_.output[0], **attr}
    output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [-1], "starts": [0]}
    inputs_map = {"data": indices_shape.output[0], **attr}
    indices_outer_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    output_shape_ = ctx.make_node("Concat",
                                  [indices_outer_shape, output_inner_shape],
                                  attr={"axis": 0},
                                  dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [-1], "starts": [0]}
    inputs_map = {"data": output_shape_.output[0], **attr}
    output_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    ctx.make_node("Reshape",
                  [gathernd_loop.output[1], output_shape],
                  outputs=[output],
                  shapes=shapes,
                  dtypes=dtypes)
@tf_op("GatherNd", onnx_op="GatherND")
class GatherND:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Pre-opset-11 path: replace the node with the Loop-based emulation from make_gathernd."""
        # Tparams output = GatherNd(Tparams params, Tidx indices)
        params = node.input[0]
        indices = node.input[1]
        output = node.output[0]
        # same as the attr Tparams
        t_params = ctx.get_dtype(params)
        # the dtype being checked belongs to params, so the message names params
        utils.make_sure(t_params, "Dtype of {} is None".format(params))
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        make_gathernd(ctx, params, indices, output, node.name, t_params, shapes, dtypes)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Opset 11+: keep the node and cast the indices input to INT64 if needed."""
        # indicies input
        input1 = node.input[1]
        target_dtype = TensorProto.INT64
        if ctx.get_dtype(input1) != TensorProto.INT64:
            inp_cast = ctx.insert_new_node_on_input(node, "Cast", input1, to=target_dtype)
            ctx.copy_shape(input1, inp_cast.output[0])
            ctx.set_dtype(inp_cast.output[0], target_dtype)
@tf_op("ScatterNd", onnx_op="ScatterND")
class ScatterND:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """TF ScatterNd(indices, updates, shape) -> ONNX ScatterND(zeros(shape), indices, updates).

        A zero-filled base tensor of the requested shape is built with
        ConstantOfShape. Both the shape and the indices are cast to INT64.
        """
        int64 = TensorProto.INT64
        update_dtype = ctx.get_dtype(node.input[1])
        zero = helper.make_tensor("value", update_dtype, dims=[1], vals=[0])
        base = ctx.make_node("ConstantOfShape", [node.input[2]], attr={'value': zero},
                             shapes=node.output_shapes, dtypes=[update_dtype])
        # feed the zero-filled base tensor in place of the shape input
        ctx.replace_input(node, node.input[2], base.output[0], 2)
        ctx.insert_new_node_on_input(base, "Cast", base.input[0], to=int64)
        ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=int64)
        # reorder inputs to match onnx
        data, indices, updates = node.input[2], node.input[0], node.input[1]
        ctx.replace_inputs(node, [data, indices, updates])
@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
class TensorScatterUpdate:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Keep the input order unchanged; only cast input 1 (indices) to INT64 when needed."""
        indices = node.input[1]
        if ctx.get_dtype(indices) == TensorProto.INT64:
            return
        ctx.insert_new_node_on_input(node, "Cast", indices, to=TensorProto.INT64)
@tf_op("Split")
class Split:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Move TF's split axis (input 0) into the ONNX "axis" attribute."""
        # T output = Split(int32 split_dim, T value, @int num_split)
        # T outputs = Split(T input, @INT axis, @INTS split)
        axis = node.inputs[0].get_tensor_value()
        ctx.remove_input(node, node.input[0], 0)
        node.set_attr("axis", axis)

    @classmethod
    def version_2(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # no change
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        # Default axis is not -1 but doesn't matter since we always set it.
        cls.version_1(ctx, node, **kwargs)
@tf_op("SplitV")
class SplitV:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert SplitV to ONNX Split with "split" and "axis" attributes.

        A -1 entry in size_splits takes whatever remains of the split axis. This
        requires the static length of that axis; otherwise utils.make_sure raises.
        """
        # T output = SplitV(T value, Tlen size_splits, int32 split_dim, @int num_split, @type Tlen)
        # T outputs = Split(T input, @INT axis, @INTS split)
        node.type = "Split"
        split = node.inputs[1].get_tensor_value()
        split_dims = node.inputs[2].get_tensor_value()
        if -1 in split:
            # negative split = use the remaining size
            shape = ctx.get_shape(node.input[0])
            utils.make_sure(shape is not None,
                            "SplitV {} with a -1 split size requires a known input shape".format(node.name))
            final_sum = shape[split_dims]
            utils.make_sure(final_sum is not None and final_sum >= 0,
                            "SplitV {} with a -1 split size requires a known length on the split axis"
                            .format(node.name))
            known_sum = sum(v for v in split if v >= 0)
            split = [final_sum - known_sum if v == -1 else v for v in split]
        ctx.remove_input(node, node.input[2], 2)
        ctx.remove_input(node, node.input[1], 1)
        node.set_attr("split", split)
        node.set_attr("axis", split_dims)

    @classmethod
    def version_2(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        """Opset 13: split sizes become an INT64 input, so they may also be dynamic."""
        # Split now supports dynamic split lengths
        if node.inputs[1].is_const():
            # Call version 1 to deal with -1 cases
            cls.version_1(ctx, node, **kwargs)
            # Convert attr to input
            split_val = node.get_attr_value("split")
            split_const = ctx.make_const(utils.make_name("split"), np.array(split_val, np.int64))
            ctx.replace_inputs(node, [node.input[0], split_const.output[0]])
            del node.attr["split"]
        else:
            # Technically incorrect if any of the splits are -1
            node.type = "Split"
            split_dims = node.inputs[2].get_tensor_value()
            ctx.remove_input(node, node.input[2], 2)
            node.set_attr("axis", split_dims)
            if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
                ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
@tf_op("ExpandDims")
class ExpandDims:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Rewrite ExpandDims as Unsqueeze with a constant "axes" attribute.

        Below opset 11 a negative axis is shifted by the output rank.
        """
        out_shape = ctx.get_shape(node.output[0])
        axis_node = node.inputs[1]
        utils.make_sure(axis_node.is_const(), "ExpandDims with non-const axes requires opset 13")
        node.type = "Unsqueeze"
        # tf.expanddims() wants a scalar per doc but quietly accepts any single-element tensor
        axis = axis_node.get_tensor_value(as_list=False).flatten()[0]
        if ctx.opset < 11 and axis < 0:
            utils.make_sure(out_shape is not None,
                            "ExpandDims with negative axes and unknown rank requires opset >= 11")
            axis += len(out_shape)
        node.set_attr("axes", [axis])
        ctx.remove_input(node, node.input[1], 1)

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        cls.version_1(ctx, node, **kwargs)

    @classmethod
    def version_13(cls, ctx, node, **kwargs):
        """Opset 13: axes stays an input, cast to INT64 and reshaped to a 1-element vector."""
        # Parameters moved to inputs for operator Squeeze, Unsqueeze.
        int64 = onnx_pb.TensorProto.INT64
        if ctx.get_dtype(node.input[1]) != int64:
            ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=int64)
        if ctx.get_shape(node.input[1]) != [1]:
            one = ctx.make_const(utils.make_name("reshape_const"), np.array([1], dtype=np.int64))
            flat_axis = ctx.make_node("Reshape", [node.input[1], one.output[0]])
            ctx.replace_inputs(node, [node.input[0], flat_axis.output[0]])
        node.type = "Unsqueeze"
@tf_op("StridedSlice")
class StridedSlice:
@classmethod
def version_1(cls, ctx, node, **kwargs):
    """Lower a StridedSlice with constant begin/end/strides (all strides == 1).

    The TF node is removed and re-created under the same name and outputs as an
    ONNX Slice, or as an Identity when no axis ends up being sliced. A Squeeze is
    appended for ``shrink_axis_mask`` axes. For opset < 9 with a non-float input,
    the result is wrapped in Cast nodes (to float and back).

    Raises:
        ValueError: if ``new_axis_mask`` is non-zero or any stride is not 1.
    Also fails via ``utils.make_sure`` when an ellipsis is present but the input
    shape is unknown.
    """
    # for now we implement common cases. Things like strides!=1 are not mappable to onnx.
    not_supported_attr = ["new_axis_mask"]
    for attr_name in not_supported_attr:
        attr = node.get_attr(attr_name)
        if attr is not None and attr.i != 0:
            raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
    # Largest value of the begin-tensor's integer dtype; used as a "slice to the end" sentinel.
    onnx_dtype = ctx.get_dtype(node.input[1])
    np_dtype = utils.ONNX_TO_NUMPY_DTYPE[onnx_dtype]
    max_size = np.iinfo(np_dtype).max
    # begin/end/strides must be constants here (get_tensor_value reads their values).
    begin = node.inputs[1].get_tensor_value()
    end = node.inputs[2].get_tensor_value()
    strides = node.inputs[3].get_tensor_value()
    # Absent mask attributes default to 0 (no bits set).
    end_mask = node.get_attr("end_mask")
    end_mask = end_mask.i if end_mask is not None else 0
    begin_mask = node.get_attr("begin_mask")
    begin_mask = begin_mask.i if begin_mask is not None else 0
    shrink_axis_mask = node.get_attr("shrink_axis_mask")
    shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
    ellipsis_mask = node.get_attr("ellipsis_mask")
    ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
    new_begin = []
    new_end = []
    axes = []
    # onnx slice op can't remove an axis, track axis and add a squeeze op if needed
    needs_squeeze = []
    # ellipsis: one bit at most can be 1. An ellipsis implicitly creates as many range specifications as
    # necessary to fully specify the sliced range for every dimension.
    # For example for a 4-dimensional tensor foo the slice foo[2, ..., 5:8] implies foo[2, :, :, 5:8]
    # NOTE: we ignore those axes denoted by ellipsis using `axes` attribute
    ellipsis_gap = 0
    for idx, begin_item in enumerate(begin):
        if strides[idx] != 1:
            raise ValueError("StridedSlice: only strides=1 is supported")
        if (ellipsis_mask >> idx) & 1:
            input_shape = ctx.get_shape(node.input[0])
            utils.make_sure(
                input_shape is not None,
                "StridedSlice op {} requires the shape of input".format(node.name)
            )
            # The ellipsis occupies one spec entry, so every spec entry after it maps
            # to input axis idx + (rank - len(begin)).
            ellipsis_gap = len(input_shape) - len(begin)
            continue
        # ignore ellipsis axes
        axes.append(idx + ellipsis_gap)
        end_item = end[idx]
        # an implicit condition is stride == 1 (checked in above)
        # NOTE(review): TF itself would produce an empty slice for begin < 0, end == 0
        # without end_mask; this maps it to "slice to the end" instead — confirm intended.
        if begin_item < 0 and end_item == 0:
            end_item = max_size
        mask = (shrink_axis_mask >> idx) & 1
        if mask != 0:
            # Shrink axis: take the single element [begin, begin + 1). For begin == -1,
            # begin + 1 would be 0 (an empty range), so slice to the end instead.
            new_begin.append(begin_item)
            end_item = begin_item + 1 if begin_item != -1 else max_size
            new_end.append(end_item)
            needs_squeeze.append(idx + ellipsis_gap)
            continue
        # begin_mask / end_mask bits mean "use the full range" on that side.
        mask = (begin_mask >> idx) & 1
        if mask != 0:
            new_begin.append(0)
        else:
            new_begin.append(begin_item)
        mask = (end_mask >> idx) & 1
        if mask != 0:
            new_end.append(max_size)
        else:
            new_end.append(end_item)
    # Capture output metadata before removing the node, then rebuild it under the same
    # name and output names so downstream consumers stay connected.
    out_dtypes = [ctx.get_dtype(node.output[0])]
    out_shapes = [ctx.get_shape(node.output[0])]
    ctx.remove_node(node.name)
    attr = {"starts": new_begin, "ends": new_end, "axes": axes}
    inputs_map = {"data": node.input[0], **attr}
    # NOTE: this rebinds the **kwargs parameter; the original kwargs are not used below.
    kwargs = {**inputs_map, "outputs": node.output}
    if len(axes) > 0:
        node = GraphBuilder(ctx).make_slice(
            kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes, return_node=True)
    else:
        # Nothing to slice (e.g. the spec consisted only of an ellipsis): pass through.
        node = ctx.make_node("Identity", [node.input[0]], name=node.name, outputs=node.output,
                             dtypes=out_dtypes, shapes=out_shapes)
    # Nodes in data-flow order; nodes[-1] is whatever currently produces the final output.
    nodes = [node]
    if needs_squeeze:
        # insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
        # ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
        name = utils.make_name(node.name)
        shape = ctx.get_shape(node.output[0])
        dtype = ctx.get_dtype(node.output[0])
        squeeze_node = GraphBuilder(ctx).make_squeeze(
            {"axes": needs_squeeze, 'data': node.output[0]}, name=name,
            dtypes=[dtype], shapes=[shape], return_node=True)
        ctx.insert_node_on_output(squeeze_node)
        nodes.append(squeeze_node)
        # The slice output no longer has the final (squeezed) shape; recompute it.
        ctx.update_node_shape_dtype(node, override=True)
    # onnx slice as of opset 7 does only take float tensors ... cast if needed
    input_dtype = ctx.get_dtype(node.input[0])
    if ctx.opset < 9:
        if input_dtype != onnx_pb.TensorProto.FLOAT:
            if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
                # override the previous cast
                cast_node = node.inputs[0]
                cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
            else:
                cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                                         to=onnx_pb.TensorProto.FLOAT)
                nodes.insert(0, cast_node)
            ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
            # NOTE(review): node.input[0] presumably already is cast_node.output[0] here,
            # which would make this copy a no-op — confirm.
            ctx.copy_shape(node.input[0], cast_node.output[0])
            # undo the cast after slice (applied after the Squeeze, if one was added)
            name = utils.make_name(node.name)
            cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
                                                      to=input_dtype)
            ctx.set_dtype(cast_node.output[0], input_dtype)
            ctx.copy_shape(node.output[0], cast_node.output[0])
            nodes.append(cast_node)
            # The slice itself now operates on (and emits) float data.
            ctx.set_dtype(node.output[0], onnx_pb.TensorProto.FLOAT)
@classmethod
def any_version_after10(cls, opset, ctx, node, **kwargs):
# T output = Slice(T input, Index begin, Index end, Index strides
# @int begin_mask, @int end_mask, @int ellipsis_mask
# @int shrink_axis_mask, @int new_axis_mask)
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
input_x = node.input[0]
begin = node.input[1]
end = node.input[2]
strides = node.input[3]
new_axis_mask = node.get_attr("new_axis_mask")
new_axis_mask = new_axis_mask.i if new_axis_mask is not None else 0
if ctx.is_const(begin) and ctx.is_const(end) and ctx.is_const(strides) \
and all(val == 1 for val in ctx.get_tensor_value(strides)) \
and new_axis_mask == 0:
cls.version_1(ctx, node, **kwargs)
return
onnx_dtype = ctx.get_dtype(node.input[1])
np_dtype = utils.ONNX_TO_NUMPY_DTYPE[onnx_dtype]
# NOTE: Max op only supports float32, deal with overflow when cast back to int32
# enable it after Max supports int32 and int64
# max_size = utils.get_max_value(np_dtype)
# min_size = utils.get_min_value(np_dtype)
max_size = 1e9
min_size = -1e9
end_mask = node.get_attr("end_mask")
end_mask = end_mask.i if end_mask is not None else 0
begin_mask = node.get_attr("begin_mask")
begin_mask = begin_mask.i if begin_mask is not None else 0
ellipsis_mask = node.get_attr("ellipsis_mask")
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
shrink_axis_mask = node.get_attr("shrink_axis_mask")
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
param_shape = ctx.get_shape(node.input[1]) or \
ctx.get_shape(node.input[2]) or \
ctx.get_shape(node.input[3])
utils.make_sure(
param_shape is not None,
"StridedSlice op {} requires the shape of begin/end/strides".format(node.name)
)
param_rank = param_shape[0]
if new_axis_mask != 0:
unqueeze_at = []
ellipsis_gap = 0
num_new = 0
for bit in range(32):
if (new_axis_mask >> bit) & 1 == 1:
num_new += 1
if (ellipsis_mask >> bit) & 1:
input_shape = ctx.get_shape(input_x)
# calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1)
ellipsis_gap = len(input_shape) - param_rank + num_new + 1
if (new_axis_mask >> bit) & 1 == 1:
unqueeze_at.append(bit + ellipsis_gap)
begin_mask |= 1 << bit
end_mask |= 1 << bit
input_x = GraphBuilder(ctx).make_unsqueeze(
{'data': input_x, 'axes': unqueeze_at})
# use in onnx graph to mask begin
new_begin_mask = [1] * param_rank
# use in onnx graph to mask end
new_end_mask = [min_size] * param_rank
# for shrink mask, if shrink mask is 1, set stride to be max_size
shrink_strided_mask = [min_size] * param_rank
axes = []
# onnx slice op can't remove a axis, track axis and add a squeeze op if needed
needs_squeeze = []
ellipsis_gap = 0
for idx in range(param_rank):
if (ellipsis_mask >> idx) & 1:
input_shape = ctx.get_shape(input_x)
utils.make_sure(
input_shape is not None,
"StridedSlice op {} requires the shape of input".format(node.name)
)
ellipsis_gap = len(input_shape) - param_rank
# handle the redundant param
new_begin_mask[idx] = 0
new_end_mask[idx] = max_size
axes.append(idx)
continue
# ignore ellipsis axes
axes.append(idx + ellipsis_gap)
mask = (shrink_axis_mask >> idx) & 1
if mask != 0:
shrink_strided_mask[idx] = max_size