This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Still too strict dtype requirements for broadcast_like #19343

Open
fhieber opened this issue Oct 13, 2020 · 1 comment
Labels: Bug, needs triage, v1.x (Targeting v1.x branch)

Comments

@fhieber
Contributor

fhieber commented Oct 13, 2020

Description

mx.nd.broadcast_like has unnecessarily strict dtype requirements for its two data inputs. PR #17977 was intended to relax them, but in MXNet 1.7 I still get the following error from this minimal example:

> a = mx.nd.ones((96, 1), dtype='float32')
> b = mx.nd.ones((96, 32, 32), dtype='float16')
> mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))
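For context, broadcast_like broadcasts lhs to the shape of rhs (or, with lhs_axes/rhs_axes, copies only the listed axis sizes from rhs), so the dtype of rhs should be irrelevant. The NumPy model below is a hypothetical sketch of those semantics, not MXNet's implementation: the helper name broadcast_like_sketch is made up, and keeping lhs's dtype in the output is an assumption of the sketch.

```python
import numpy as np


def broadcast_like_sketch(lhs, rhs, lhs_axes=None, rhs_axes=None):
    """Hypothetical NumPy model of mx.nd.broadcast_like.

    Only rhs.shape is read; rhs's dtype and values are never used,
    and the result keeps lhs's dtype (an assumption of this sketch).
    """
    if lhs_axes is None and rhs_axes is None:
        # No axes given: broadcast lhs to the full shape of rhs.
        target = tuple(rhs.shape)
    else:
        # Copy the size of each listed rhs axis onto the matching lhs axis.
        target = list(lhs.shape)
        for la, ra in zip(lhs_axes, rhs_axes):
            target[la] = rhs.shape[ra]
        target = tuple(target)
    return np.broadcast_to(lhs, target)


a = np.ones((96, 1), dtype=np.float32)
b = np.ones((96, 32, 32), dtype=np.float16)
out = broadcast_like_sketch(a, b, lhs_axes=(1,), rhs_axes=(1,))
print(out.shape, out.dtype)  # (96, 32) float32
```

In this model the float16 rhs contributes nothing but its shape, so mixing dtypes between the two inputs is harmless.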

Error Message

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-17-3cde0a810695> in <module>
----> 1 mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))

~/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/register.py in broadcast_like(lhs, rhs, lhs_axes, rhs_axes, out, name, **kwargs)

~/miniconda3/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list)
     89         c_str_array(keys),
     90         c_str_array([str(s) for s in vals]),
---> 91         ctypes.byref(out_stypes)))
     92
     93     create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls

~/miniconda3/lib/python3.7/site-packages/mxnet/base.py in check_call(ret)
    244     """
    245     if ret != 0:
--> 246         raise get_last_ffi_error()
    247
    248

MXNetError: Traceback (most recent call last):
  File "/home/centos/mxnet/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h", line 135
MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float32, got float16
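The failing check is raised from elemwise_op_common.h, which suggests that broadcast_like on this branch still uses an elementwise-style type rule forcing all inputs to share one dtype. A simplified Python model of such a rule, next to one possible relaxed rule, is sketched below. Both functions are hypothetical illustrations rather than MXNet code, and the relaxed variant is an assumption, not necessarily what PR #17977 implements.

```python
def strict_infer_dtype(in_dtypes):
    """Elementwise-style rule (model): every input must share a single dtype."""
    dtype = None
    for i, t in enumerate(in_dtypes):
        if dtype is None:
            dtype = t  # the first known dtype becomes the reference
        elif t != dtype:
            raise TypeError(
                f"Incompatible attr at {i}-th input: expected {dtype}, got {t}")
    return dtype  # the output gets the shared dtype


def relaxed_infer_dtype(lhs_dtype, rhs_dtype):
    """Possible relaxed rule (model): output follows lhs, rhs supplies only a shape."""
    return lhs_dtype  # rhs_dtype is intentionally left unconstrained


try:
    strict_infer_dtype(["float32", "float16"])
except TypeError as err:
    print(err)  # Incompatible attr at 1-th input: expected float32, got float16

print(relaxed_infer_dtype("float32", "float16"))  # float32
```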

To Reproduce

import mxnet as mx

a = mx.nd.ones((96, 1), dtype='float32')
b = mx.nd.ones((96, 32, 32), dtype='float16')
mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))

What have you tried to solve it?

Aligning the dtypes by casting either a or b to the other's dtype resolves the issue. However, this is not a practical solution in my AMP (automatic mixed precision) use case: both tensors are fairly large and I want to avoid a copy.
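To make the cost of that workaround concrete, the NumPy analogy below (an illustration, not MXNet code) shows that a dtype cast allocates a new buffer for the whole tensor. It uses the shapes from the minimal example; the author notes that the tensors in the real use case are large, so the extra allocation scales accordingly.

```python
import numpy as np

a = np.ones((96, 1), dtype=np.float32)
b = np.ones((96, 32, 32), dtype=np.float16)

# Workaround: cast rhs to lhs's dtype. astype() copies by default,
# so this allocates a second, wider buffer for the entire tensor.
b32 = b.astype(np.float32)

print(np.shares_memory(b, b32))  # False: b32 is a separate buffer
print(b.nbytes, b32.nbytes)      # 196608 393216
```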

Environment

mxnet-cu92             1.7.0
fhieber changed the title from "Still too strict type requirements for broadcast_like" to "Still too strict dtype requirements for broadcast_like" on Oct 13, 2020
@ptrendx
Member

ptrendx commented Oct 13, 2020

It seems that PR #17977 was merged to master after the v1.x branch was created, so it would need to be cherry-picked to the v1.x branch.

szha added the "v1.x" (Targeting v1.x branch) label on Feb 8, 2021