Skip to content

Commit

Permalink
fix axis
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 10, 2023
1 parent 286cb18 commit 91af964
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
if not isinstance(axis, int):
axis = get_const_int(axis)

def gen_ir(data, indices, updates, out):
def gen_ir(data, indices, updates, out, axis):
ib = tir.ir_builder.create()

data_ptr = ib.buffer_ptr(data)
Expand All @@ -92,7 +92,7 @@ def gen_ir(data, indices, updates, out):
full_range = before_axis_range * before_axis_stride

ind_shape = indices.shape
ind_axis_range = shape[axis]
ind_axis_range = ind_shape[axis]

ind_before_axis_range = 1
ind_after_axis_range = 1
Expand Down Expand Up @@ -173,7 +173,7 @@ def gen_ir(data, indices, updates, out):
return te.extern(
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_elements_cuda",
Expand Down

0 comments on commit 91af964

Please sign in to comment.