Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Jun 27, 2024
1 parent 28a4bb6 commit 03a6ad4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 10 deletions.
4 changes: 1 addition & 3 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def te_layout_transform(data, name):
)

def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
if len(axis_sep) != 0:
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)

index_map: tvm.tir.IndexMap = call.attrs.index_map
pad_value = call.attrs.pad_value
Expand Down Expand Up @@ -214,7 +213,6 @@ def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
set_axis_sep(axis_separators, sch, "write")
if input_axis_separators is not None:
input_axis_separators = list(map(lambda x: x.value, input_axis_separators))
set_axis_sep(input_axis_separators, sch, "read")
gvar = bb.add_func(sch.mod["main"], primfunc_name)
output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Expr flatten(Expr x);
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators,
Optional<Array<IntImm>> input_axis_separators);
Optional<Array<IntImm>> input_axis_separators = NullOpt);

/*!
* \brief Permutes the dimensions of an array.
Expand Down
6 changes: 4 additions & 2 deletions src/relax/transform/alter_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class AlterOpImplMutator : public ExprMutator {
const auto& replacement_func = op_impl_map_[op_kind];

Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators, input_axis_separators;
Optional<Array<Array<IntImm>>> axis_separators;
Optional<Array<Array<IntImm>>> input_axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];
Expand Down Expand Up @@ -293,7 +294,8 @@ class AlterOpImplMutator : public ExprMutator {
Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator, input_axis_separator;
Array<IntImm> axis_separator;
Array<IntImm> input_axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
Expand Down
4 changes: 0 additions & 4 deletions tests/python/relax/test_transform_alter_op_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,6 @@ def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer(
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]

Expand All @@ -633,8 +631,6 @@ def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float3
for ax0, ax1 in T.grid(4, 4):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
# fmt: on
Expand Down

0 comments on commit 03a6ad4

Please sign in to comment.