Skip to content

Commit

Permalink
fix flaky TF test (apache#8431)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and ylc committed Jan 13, 2022
1 parent c4a8c37 commit 5ed33bd
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def get_3d_indices(indices, layout="NCDHW"):
return n, c, z, y, x, cc


def get_1d_pixel(data, layout, boxes, image_width, n, c, x, cc, ib, ic):
def get_1d_pixel(data, layout, image_width, n, c, x, cc, ib, ic):
"""Get 1d pixel"""
if boxes is None:
x = tvm.te.max(tvm.te.min(x, image_width - 1), 0)
x = tvm.te.max(tvm.te.min(x, image_width - 1), 0)
if layout == "NWC":
return data(n, x, c).astype("float")
if layout == "NCW":
Expand All @@ -91,11 +90,10 @@ def get_1d_pixel(data, layout, boxes, image_width, n, c, x, cc, ib, ic):
return data(n, c, x, cc).astype("float")


def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic):
def get_2d_pixel(data, layout, image_height, image_width, n, c, y, x, cc, ib, ic):
"""Get 2d pixel"""
if boxes is None:
y = tvm.te.max(tvm.te.min(y, image_height - 1), 0)
x = tvm.te.max(tvm.te.min(x, image_width - 1), 0)
y = tvm.te.max(tvm.te.min(y, image_height - 1), 0)
x = tvm.te.max(tvm.te.min(x, image_width - 1), 0)
if layout == "NHWC":
return data(n, y, x, c).astype("float")
if layout == "NCHW":
Expand Down Expand Up @@ -288,7 +286,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
value = get_1d_pixel(
data,
layout,
boxes,
image_width,
box_idx,
c,
Expand All @@ -307,7 +304,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
p[i] = get_1d_pixel(
data,
layout,
boxes,
image_width,
box_idx,
c,
Expand All @@ -329,7 +325,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
p[i] = get_1d_pixel(
data,
layout,
boxes,
image_width,
box_idx,
c,
Expand Down Expand Up @@ -576,7 +571,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
value = get_2d_pixel(
data,
layout,
boxes,
image_height,
image_width,
box_idx,
Expand All @@ -600,7 +594,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
p[j][i] = get_2d_pixel(
data,
layout,
boxes,
image_height,
image_width,
box_idx,
Expand Down Expand Up @@ -630,7 +623,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
p[j][i] = get_2d_pixel(
data,
layout,
boxes,
image_height,
image_width,
box_idx,
Expand Down

0 comments on commit 5ed33bd

Please sign in to comment.