Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 #4484

Merged
merged 3 commits into from
Dec 28, 2019

Conversation

optima2005
Copy link
Contributor

This is to support 'SAME' padding for conv3d_transpose for kernel more than 1x1

Discussions:
https://discuss.tvm.ai/t/why-we-only-support-kernel-1-1-for-tf-conv2d-transpose-same/4957

@yongwww @apivovarov

@optima2005 optima2005 changed the title [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1 [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 Dec 9, 2019
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add test case for

kernel 2x2, strides 2x2, SAME
kernel 3x3, strides 2x2, SAME

E.g. if input is 5x5 then valid outputs are 9x9 or 10x10 (you can use one or another in the output_shape tensor) regardless of the kernel size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The strides in transpose convolution is used to dilate the input, see this. If the input is dilated, there would be no way to get 'SAME' size of output by only padding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. I can get your point. 'SAME' means the size of the enlarged input size(by dilation).

Copy link
Contributor

@apivovarov apivovarov Dec 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TF code for kernel 2x2, strides 2x2 and padding SAME is

dshape=(1,5,5,4)
#hwoi
kshape=(2,2,2,4)
oshape=(1,9,9,2)
# or 
oshape=(1,10,10,2)
dtype='float32'
with tf.Session() as sess:
    x = tf.placeholder(shape=dshape, dtype=dtype)
    w = tf.placeholder(shape=kshape, dtype=dtype)
    dc = tf.nn.conv2d_transpose(x, w, output_shape=oshape, strides=(1,2,2,1), padding='SAME')

    res_dc = sess.run(dc, feed_dict={x: data, w:weight})

2 * param->padding[0] + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]));
if ( param->padding.size() == 2 ) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor formatting comment. I think we do not need spaces after ( and before )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix it.

oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
param->padding[0] - param->padding[2] + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
param->padding[1] - param->padding[3] + param->output_padding[1]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formulas are very similar except of using

2 * param->padding[0]  vs  param->padding[0] + param->padding[2]
and 
2 * param->padding[1]   vs   param->padding[1] + param->padding[3]

Can we calculate correct padding and then use the same formulas with calculated padding for both cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The additional condition is to handle the head padding and tail padding are diffenct, for even kernel in this case. I didn't quite understand your point. Please clearify.

Copy link
Contributor

@apivovarov apivovarov Dec 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the following?

  int pad_h, pad_w;
  if (param->padding.size() == 2) {
    pad_h = 2 * param->padding[0];
    pad_w = 2 * param->padding[1];
  } else if (param->padding.size() == 4) {
    pad_h = param->padding[0] + param->padding[2];
    pad_w = param->padding[1] + param->padding[3];
  } else {
    CHECK_EQ(param->padding.size(), 4);
  }
  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
                 pad_h + param->output_padding[0]));
  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
                 pad_w + param->output_padding[1]));

Copy link
Contributor Author

@optima2005 optima2005 Dec 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is better. Thanks!

@FrozenGene
Copy link
Member

@optima2005 do you mind taking up this RFC? #2682 Because I see you start to support 4D padding in Conv2dTranspose. I think the work should be similar with convolution, which could make us avoid pad operator and get potential performance.

@optima2005
Copy link
Contributor Author

@FrozenGene let me do some evaluations first and then I will get back to you if I think I could do it. Actually I am still in the progess to understand those ops' implementation, especially those optimazation approach.

@optima2005
Copy link
Contributor Author

@apivovarov I added the test cases per as your proposal and tried to pass them. But one condition (10x10 in your example) still not working in cuda target. I haven't found the root cause. I am wondering whether you could help.

@optima2005
Copy link
Contributor Author

I got this error on cuda target with test case:

        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
                          'NCHW', [4, 176, 16, 16])

incubator-tvm/src/pass/loop_partition.cc:544: Cannot prove: ((((floordiv(((((floordiv((floormod(dh, 2) + 7), 2)16) + 16)(floordiv((floormod(dw, 2) + 7), 2) + 1)) + 63), 64) - 1) - ((((floordiv((floormod(dh, 2) + 7), 2)16) + 16)(floordiv((floormod(dw, 2) + 7), 2) + 1)) - (floordiv(((((floordiv((floormod(dh, 2) + 7), 2)16) + 16)(floordiv((floormod(dw, 2) + 7), 2) + 1)) + 63), 64)*63))) + 1) >= 0), when generating the post doubt loop

Is it relative to issue #4470? @apivovarov

@optima2005
Copy link
Contributor Author

This can be confirmed that after I add exclusion for this specific case in workaround #4472, that test case can pass.

@tqchen tqchen added the status: need update need update based on feedbacks label Dec 19, 2019
@tqchen
Copy link
Member

tqchen commented Dec 19, 2019

please comment about the status of the PR, @FrozenGene

_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
# cuda target not working
#_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
Copy link
Member

@FrozenGene FrozenGene Dec 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case failed due to the "Cannot prove" error when compiling the module. see #4470. It seems there is no fix for it yet.

'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
# cuda target not working
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason why it doen't work for cuda

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case failed due to the "Cannot prove" error when compiling the module. see #4470. It seems there is no fix for it yet.

@optima2005
Copy link
Contributor Author

For stride '2x2', I tried kernel size up to '7x7', all has the "Cannot prove" compiling error for cuda target. So I added the walkaround. for other strides > '2x2', I didn't find this error.
@yongwww @FrozenGene @apivovarov Would you please take a look again? Thanks!

Copy link
Member

@yongwww yongwww left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@optima2005
Copy link
Contributor Author

@yzhliu, would you mind to take a look? It is the fix to 'SAME' mode of #4300
Thanks!

@tqchen tqchen merged commit 227c7af into apache:master Dec 28, 2019
@tqchen
Copy link
Member

tqchen commented Dec 28, 2019

zhiics pushed a commit to zhiics/tvm that referenced this pull request Dec 31, 2019
…pache#4484)

* [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1

* revised per as review comments

* add more fallback wolkaround to make all tests pass
@optima2005 optima2005 deleted the fix_conv2d_transpose_same branch January 1, 2020 05:41
zhiics added a commit to neo-ai/tvm that referenced this pull request Jan 9, 2020
* Change upstream url

* Fix bias_add gradient (apache#4516)

* Fix bias_add gradient

A change caused collapse_sum_like to reject implicit dimension
broadcasting for bias_add gradient, so switch to explicit sum reduction
on the non-bias axis dimensions.

* Lint fix

* [Bugfix][Frontend][TFlite] Fix wrong function call in TANH tests (apache#4517)

* Replace sigmoid() with tanh() in tests for TANH

* Fixed extra reshape parameter bug. (apache#4524)

* Use the best tuner possible (apache#4397)

* Use the best tuner possible

* Add comment denoting availability of better tuners

* Fix typos and wording

* [ir] use DataType instead of Type for readability because Type has been deprecated (apache#4513)

* add bfloat16 typeflag support (apache#4525)

* fix empty config caused KeyError (apache#4520)

* fix onnx shape dtype (apache#4528)

* fix crash issue in tsim backend (apache#4527)

* PIL is depreciated and should be replaced with pillow (a fork of PIL) (apache#4533)

Change-Id: If2075df5475505f2da87dae7145af5a7ab83d8a4

* [Relay] External codegen (apache#4482)

* Update legacy places from nnvm to relay. (apache#4535)

* Update legacy places from nnvm to relay.

This PR prepares the current mainline to remove nnvm compiler dep.

* remove legacy stage

* Implement 1d deconvolution (apache#4476)

* [relay][op] add expand op (from ONNX) to relay frontend (apache#4483)

* Add Expand to onnx.py

* add test function for expand

* Fix a onnx frontend test

* Add tests for the value itself instead of shape only on test_expand

* Cleaned up some unnecessary modifications.

* [TOPI] Allow batch matmul to be fused into injective ops (apache#4537)

* [TOPI] Fixed nms max_output_size loop (apache#4541)

One of the loops in hybrid_nms used for
performing the max_output_size reordering
was incorrectly designated as parallel
resulting in incorrect behaviour. This patch
changes that loop to a serial loop.

Change-Id: I97184f5887f5f028d8ab339fa2808eb7630a4017

* [DOCS] Mention Ninja build system in install/from_source.rst (apache#4554)

* [DOCS] Mention Ninja build system in install/from_source.rst

* Address comments

* [PYTHON][FFI] Cythonize NDArray.copyto (apache#4549)

* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property

* vm external codegen (apache#4544)

* [COMMUNITY] @cchung100m -> reviewer (apache#4557)

* [VTA] improved virtual memory mapping (apache#4545)

* [VTA] improved virtual memory mapping

* Update virtual_memory.cc

* [IR] fix style in ir_mutator and ir_visitor (apache#4561)

* [RUNTIME][VULKAN] Fix compiler warning (apache#4559)

* [REFACTOR][DTYPE] Isolate dtype to runtime (apache#4560)

dtype.h -> runtime/data_type.h

Changes:
- Rename all old reference of tvm::Type to DataType
- ExprNode.type -> ExprNode.dtype
- Expr.type() -> Expr.dtype()
- Change Expr related functions to expr_operator.
  - DataType::min() -> min_value(DataType)
  - DataType::max() -> max_value(DataType)
- Move type constructor Int, UInt, Float, Handle, Bool into DataType.
  - Int(bits) -> DataType::Int(bits)
  - UInt(bits) -> DataType::UInt(bits)

* Support standardize runtime module (apache#4532)

* [Relay][Frontend][ONNX] Support auto_pad in Conv and ConvTranspose (apache#4563)

* [TEST] Remove nnvm related code in topi and test script (apache#4562)

* [TEST] Remove nnvm related code in topi and test script

* Remove docs dep

* [Relay] add max_pool3d in relay and TF converter (apache#4551)

* [Relay] add max_pool3d in relay and TF converter

* fix comments

* Remove nnvm (apache#4565)

* [VTA][Chisel] End-to-end Inference with Chisel VTA (apache#4574)

* [VTA][Chisel] End-to-end Inference with Chisel VTA

* Update TensorAlu.scala

* remove unnecessary cast to int32 (apache#4573)

* Fix llvm-enabled build by adding missing intrinsics headers (apache#4575)

* [DEPRECATION] Remove NNVM compiler (apache#4571)

* Remove NNVM compiler

* [Relay/Topi][Op] Added native DepthToSpace and SpaceToDepth Operators (apache#4566)

* Added tvm function stencil for subpixel operations to topi.

* Topi subpixel operators added and tested.

* Added subpixel attrs.

* Added depth_to_space relay attributes.

* depth_to_space fully working.

* Fixed NHWC shape bug.

* SpaceToDepth in and all tests passing.

* lint fixes.

* Added string include

* Fixed topi formatting.

* Added DCR/CDR mode to depthtospace operator.

* [DOC] fix doc in api.py (apache#4580)

* [DEPRECATION] Cleanup legacy verilog support (apache#4576)

This PR cleans up the left over code for legacy verilog support which was experimental.
The new hardware backend path is now support by VTA via TSIM.

* [RUNTIME] Remove Extension VTable in favor of Unified Object system. (apache#4578)

Before the unified object protocol, we support pass
additional extension objects around by declaring a type as an extension type.
The old extension mechanism requires the types to register their
constructor and deleter to a VTable and does not enjoy the benefit of the
self-contained deletion property of the new Object system.

This PR upgrades the extension example to make use of the new object system
and removed the old Extension VTable.

Note that the register_extension funtion in the python side continues to work
when the passed argument does not require explicit container copy/deletion,
which covers the current usecases of the extension mechanism.

* Some Windows and MSVC fixes (apache#4569)

* fix python exception creation in Windows

* better string conversion for msvc

* fix cpp style issue

* [NEWS] add v0.6 release (apache#4558)

* [NEWS] add v0.6 release

* remove link prefix

* fix issue number

* [DOCS]fix typos in autotvm tutorial (apache#4585)

* [Quantization, Calibrate] Fix context creation when current_target is explicity set (apache#4582)

* [Container] Fix NDArray SaveDLTensor declaration and implementation signature different (apache#4586)

* [TOPI][AutoTVM] NHWC conv2d templates for ARM (apache#3859)

* [AutoTVM][TOPI] NHWC conv2d templates (spatial pack) for ARM

As some frontends (tflite for example) are using NHWC as the default
layout, we are enabling NHWC schedule templates in TOPI and AutoTVM.

* some comments fix

* [FIX][TOPI][X86] schedule dense pack (apache#4539)

* [Relay] Convert Layout Pass. (apache#4335)

* [Relay][AlterLayout] Broadcast with scalar shape (apache#4577)

* [TOPI] add 3D upsampling Op. (apache#4584)

* [TOPI] add 3D upsampling Op.

* fix lint issues

* change align_corners to coordinate_transformation_mode

* fix resize3d half_pixel

* make a simple function and clean up trilinear_resize3d_python

* fix doc

* [Runtime] add necessary const qualifier for NDArray container of parameters (apache#4590)

* [autotvm] fix typos in comment (apache#4591)

* fix tf.compat.v1 issue for tf verison <=1.12 (apache#4593)

* [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (apache#4484)

* [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1

* revised per as review comments

* add more fallback wolkaround to make all tests pass

* [GraphRuntime] Support parameter out in the graph runtime debug (apache#4598)

* [GraphRuntime] Support parameter out in the graph runtime debug

* Dummy commit to trigger build

* [Perf] Add CublasLt extern support for better Igemm performance (apache#4550)

* cublaslt added

* fix lint

* address comments

* address more comments

* Trigger CI

* Trigger CI

* fix codegenc (apache#4597)

* [REFACTOR][RUNTIME] Update NDArray use the Unified Object System (apache#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments

* [Relay][Convert Layout] Handling batch norm layout change. (apache#4600)

* [relay][refactor] Cache Op::Get in passes to reduce lookup overhead (apache#4594)

* Refactor to use IsOp utility

* retrigger CI

* Update dmlc_tvm_commit_id.txt

* disable one test_batch_norm unit test for now to check CI

* enable test_batch_norm

Co-authored-by: SWu <SWu@users.noreply.github.com>
Co-authored-by: Ina Dobreva <55383260+inadob@users.noreply.github.com>
Co-authored-by: Josh Fromm <jwfromm@uw.edu>
Co-authored-by: miheer vaidya <v.miheer@gmail.com>
Co-authored-by: Liang ZOU <liang.d.zou@gmail.com>
Co-authored-by: YixinBao <yixin.bao@intel.com>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: Liangfu Chen <liangfu.chen@icloud.com>
Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Alex Gladkov <gladkov_alex@yahoo.com>
Co-authored-by: Takato Yamada <tkclimb0911@gmail.com>
Co-authored-by: Haichen Shen <shenhaichen@gmail.com>
Co-authored-by: mbarrett97 <55580676+mbarrett97@users.noreply.github.com>
Co-authored-by: Hideto Ueno <uenoku.tokotoko@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Zhao Wu <wuzhaozju@gmail.com>
Co-authored-by: Neo Chien <cchung100m@cs.ccu.edu.tw>
Co-authored-by: Yong Wu <55wuyong@163.com>
Co-authored-by: Dmitri Makarov <dmakarov@users.noreply.github.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: kice <wslikerqs@gmail.com>
Co-authored-by: Yizhi Liu <liuyizhi@apache.org>
Co-authored-by: Wang Yucheng <wyc91543@163.com>
Co-authored-by: 王振华(Zhenhua WANG) <i@jackwish.net>
Co-authored-by: deepIgnorance <zhengsizemax@outlook.com>
Co-authored-by: Animesh Jain <anijain@umich.edu>
Co-authored-by: optima2005 <56945758+optima2005@users.noreply.github.com>
Co-authored-by: zhuochen <zhuochen@outlook.com>
Co-authored-by: Leyuan Wang <laurawly@gmail.com>
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Jan 11, 2020
…pache#4484)

* [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1

* revised per as review comments

* add more fallback wolkaround to make all tests pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: need update need update based on feedbacks
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants