Conversation
I think you can get rid of the change from
@haojin2: it gets rid of the integer issue by getting rid of integer support (which seems unnecessarily restrictive). But adding integer support could also be a separate PR. What do you think I should do?
@sbodenstein I would say that floating inputs should be the most common use case, so I would suggest that we start with floating types only, and if we see a need for integer types later we can address it accordingly.
@haojin2: ok, done. Can we merge?
@sbodenstein I cannot decide that as I'm not a committer, I'll ping some committers for you.
Y = mx.symbol.Pad(data=X, mode=mode, pad_width=pad_width)
x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
if dtype in real_types:
    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu(), dtype=dtype).copyto(xpu)
What's the purpose of looping over real_types and generating a new x to override the previous one?
It's not looping over real_types, and it's not overriding anything (X and x are different). This is simply written to be compatible with the integer case one day (happy to remove it for now if you want). Unless I'm misunderstanding something?
It would be great if I could get some idea of when this can be merged, as we want to make new builds with this feature enabled.
@eric-haibin-lin @anirudh2290 Please take a look when you have time, thanks!
Y = mx.symbol.Pad(data=X, mode=mode, pad_width=pad_width)
x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
if dtype in real_types:
I think you do not really need this if-else branch at this moment as the caller of this function only passes real types.
Removed. I had kept it there as I thought it would be useful when integer support is added, but I've removed it for now.
@with_seed()
def test_pad():
    ct = default_context()
nit: ctx is probably a more commonly used name for context?
Changed to ctx.
check_pad_with_shape(shape2, default_context(), pad2, 'reflect')
# note: this op doesn't support ints yet. Add tests when supported
test_types = ["float32", "float64", "float16"]
for d in test_types:
nit: to be consistent with most other tests within this file I would prefer:
dtypes = [np.float16, np.float32, np.float64]
for dtype in dtypes:
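The suggested pattern can be sketched outside of MXNet as well; here is a minimal, self-contained illustration using NumPy's np.pad as a stand-in for the pad operator (check_pad_dtypes is a hypothetical helper, not from this PR):

```python
import numpy as np

def check_pad_dtypes(shape, pad_width):
    # Loop over the floating dtypes the test covers, one padded
    # array per dtype, as the reviewer suggests.
    dtypes = [np.float16, np.float32, np.float64]
    for dtype in dtypes:
        x = np.random.uniform(-1, 1, shape).astype(dtype)
        y = np.pad(x, pad_width, mode="edge")
        # The forward pass of padding should preserve the input dtype.
        assert y.dtype == dtype
    return True
```

The point of looping over np.float16/np.float32/np.float64 objects rather than strings is consistency with the other tests in the file, as the review comment notes.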
Changed.
LGTM, I triggered the CI again, we can merge once CI passes.
@sbodenstein changes to the
check_pad_with_shape(shape1, default_context(), pad1, 'reflect')
check_pad_with_shape(shape2, default_context(), pad2, 'reflect')
# note: this op doesn't support ints yet. Add tests when supported
test_types = ["float16", "float32", "float64"]
Should be dtypes = [np.float16, np.float32, np.float64]?
* fix no data type inference for pad
* add support for int types
* add tests for all types
* fix gpu type switch
* remove integer support
* fix python op test style issues
* fix type bug in python tests
This PR allows the pad operator to operate on any type. Fixes Issue #11967.

One potential issue: I've enabled this layer for integer types, for which gradients don't get computed. Is there a standard way of handling this?