[Relax] Support input_axis_separator to allow 2D to 1D conversion #17115

Merged 3 commits on Jul 1, 2024
8 changes: 8 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -66,6 +66,12 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
* first input axis that is part of a new flattened axis.
*/
Optional<Array<IntImm>> axis_separators;
/*!
* \brief axis_separators for the input buffer.
* Needed to identify whether the input buffer to layout_transform
* contains axis separators.
*/
Optional<Array<IntImm>> input_axis_separators;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
@@ -74,6 +80,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
"padding. If not specified, the compiler is free to choose any value.");
TVM_ATTR_FIELD(axis_separators)
.describe("The separators between input axes when generating flat output axes");
TVM_ATTR_FIELD(input_axis_separators)
.describe("The separators between axes to regenerate output");
}
}; // struct LayoutTransformAttrs

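For context, `axis_separators` mark where a flattened physical layout splits into separate axes. A minimal sketch of how an index map and its separators are typically produced together (the tile sizes are illustrative):

```python
from tvm import tir

# Tile a 2D (h, w) buffer into 8x8 blocks. AXIS_SEPARATOR marks where the
# physical axes divide, so axis_separators comes back as [2].
index_map, axis_separators = tir.IndexMap.from_func_with_separators(
    lambda h, w: [h // 8, w // 8, tir.IndexMap.AXIS_SEPARATOR, h % 8, w % 8]
)
```

The new `input_axis_separators` field carries the same kind of list, but for the buffer feeding the transform rather than the one it produces.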
4 changes: 3 additions & 1 deletion include/tvm/relax/transform.h
@@ -559,11 +559,13 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms
* \param input_axis_separators Map from kOperatorName attr to axis_separators of each input buffer
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
const Map<String, Array<Array<IntImm>>>& axis_separators);
const Map<String, Array<Array<IntImm>>>& axis_separators,
const Map<String, Array<Array<IntImm>>>& input_axis_separators);

/*!
* \brief Layout conversion pass.
8 changes: 7 additions & 1 deletion python/tvm/relax/op/manipulate.py
@@ -115,6 +115,7 @@ def layout_transform(
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
input_axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.

@@ -158,7 +159,12 @@ def te_layout_transform(data, name):
if axis_separators is None:
axis_separators = []

return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators) # type: ignore
if input_axis_separators is None:
input_axis_separators = []

return _ffi_api.layout_transform(
x, index_map, pad_value, axis_separators, input_axis_separators
)


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
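A hedged usage sketch of the extended Python API; the shape and the 2D-to-1D flattening map are hypothetical:

```python
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((64, 64), "float32"))
# Flatten 2D -> 1D. input_axis_separators records that the incoming buffer
# is physically laid out as two axes ([i] | [j]).
y = relax.op.layout_transform(
    x,
    index_map=lambda i, j: [i * 64 + j],
    axis_separators=[],
    input_axis_separators=[1],
)
```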
13 changes: 10 additions & 3 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -181,6 +181,9 @@ def te_layout_transform(data, name):
name=name,
)

def set_axis_sep(axis_sep: list, sch: tir.Schedule, buffer_type: str):
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)

index_map: tvm.tir.IndexMap = call.attrs.index_map
pad_value = call.attrs.pad_value
if pad_value is not None:
@@ -192,8 +195,10 @@ def te_layout_transform(data, name):
pad_value = float(0.0)

axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators

# Convert to list from array
axis_separators = list(map(lambda x: x.value, axis_separators))
axis_separators = [int(sep) for sep in axis_separators]
primfunc_name = "te_layout_transform"
_, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -206,8 +211,10 @@ def te_layout_transform(data, name):
# Create TIR schedule to apply layout changes with axis separators
sch = tir.Schedule(tir_func)
sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
if len(axis_separators) != 0:
sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
set_axis_sep(axis_separators, sch, "write")
if input_axis_separators is not None:
input_axis_separators = [int(sep) for sep in input_axis_separators]
set_axis_sep(input_axis_separators, sch, "read")
gvar = bb.add_func(sch.mod["main"], primfunc_name)
output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
output_dtype = call_args[0].struct_info.dtype
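For intuition, the legalization above reduces to a few TIR schedule calls. A standalone sketch, with an illustrative identity compute standing in for the generated te_layout_transform PrimFunc:

```python
import tvm
from tvm import te, tir

# Stand-in PrimFunc: an identity copy whose block is named like the kernel.
data = te.placeholder((16, 16), "float32", name="data")
out = te.compute(data.shape, lambda i, j: data[i, j], name="te_layout_transform")
func = te.create_prim_func([data, out])

sch = tir.Schedule(func)
# What set_axis_sep(..., "read") does: mark the input buffer as physically 2D.
sch.set_axis_separator("te_layout_transform", ("read", 0), axis_separators=[1])
# And the "write" counterpart for the output buffer.
sch.set_axis_separator("te_layout_transform", ("write", 0), axis_separators=[1])
```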
12 changes: 11 additions & 1 deletion python/tvm/relax/transform/transform.py
@@ -24,6 +24,7 @@
import numpy as np # type: ignore

import tvm.ir
from tvm.ir.container import Array
from tvm.relax import Expr, Var, StructInfo
from tvm.relax.dpl import DFPattern
from tvm.runtime import NDArray, Object
@@ -1280,6 +1281,7 @@ def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -1295,6 +1297,8 @@
op_kind to layout transformation map for each of the buffers
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separator for each index_map
op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separators for the input buffers

Returns
-------
@@ -1303,13 +1307,19 @@
for operator_name, transform_list in op_buffer_transforms.items():
l = []
for transform in transform_list:
# Extract the index_map
if isinstance(transform, Callable):
transform = IndexMap.from_func_with_separators(transform)[0]
elif isinstance(transform, (Array, tuple)) and isinstance(transform[0], IndexMap):
transform = transform[0]
l.append(transform)
op_buffer_transforms[operator_name] = l

return _ffi_api.AlterOpImpl(
op_impl_map, op_buffer_transforms, op_buffer_axis_separators
op_impl_map,
op_buffer_transforms,
op_buffer_axis_separators,
op_buffer_input_axis_separators,
) # type: ignore


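A hedged sketch of supplying the new dictionary to the pass; `"my_op"`, `replacement_func`, and `mod` are hypothetical placeholders assumed to exist in the caller's code, and one transform entry is given per i/o buffer of the replacement PrimFunc:

```python
from tvm import relax
from tvm.tir import IndexMap

# Tile the op's buffers 8x8 and separate the outer tile axes physically.
transform, seps = IndexMap.from_func_with_separators(
    lambda i, j: [i // 8, j // 8, IndexMap.AXIS_SEPARATOR, i % 8, j % 8]
)
mod = relax.transform.AlterOpImpl(
    op_impl_map={"my_op": replacement_func},            # hypothetical PrimFunc
    op_buffer_transforms={"my_op": [transform, transform]},
    op_buffer_axis_separators={"my_op": [seps, seps]},
    # New argument: separators already present on the op's input buffers.
    op_buffer_input_axis_separators={"my_op": [[1], [1]]},
)(mod)
```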
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.cc
@@ -472,11 +472,13 @@ TVM_REGISTER_OP("relax.flatten")
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators) {
Optional<Array<IntImm>> axis_separators,
Optional<Array<IntImm>> input_axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
attrs->axis_separators = std::move(axis_separators);
attrs->input_axis_separators = std::move(input_axis_separators);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.h
@@ -67,10 +67,12 @@ Expr flatten(Expr x);
* not specified, any value can be used.
* \param axis_separators Array of values to differentiate between input axes
* when generating flattened output axes.
* \param input_axis_separators Array of values to differentiate between axes of the input buffer.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators);
Optional<Array<IntImm>> axis_separators,
Optional<Array<IntImm>> input_axis_separators = NullOpt);

/*!
* \brief Permutes the dimensions of an array.
68 changes: 50 additions & 18 deletions src/relax/transform/alter_op_impl.cc
@@ -81,12 +81,14 @@ class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_)
const Map<String, Array<Array<IntImm>>>& axis_separators_,
const Map<String, Array<Array<IntImm>>>& input_axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
op_buffer_transforms__(op_buffer_transforms_),
op_buffer_axis_separators__(axis_separators_) {}
op_buffer_axis_separators__(axis_separators_),
op_buffer_input_axis_separators__(input_axis_separators_) {}

IRModule Run() {
for (const auto& gv : mod_->GetGlobalVars()) {
@@ -127,9 +129,12 @@

Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators;
Optional<Array<Array<IntImm>>> input_axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];
if (op_buffer_input_axis_separators__.count(op_kind))
input_axis_separators = op_buffer_input_axis_separators__[op_kind];

ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or transformations for each "
@@ -140,15 +145,17 @@
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);

auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
input_axis_separators);

ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
auto updated_call = builder_->Normalize(
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));

// Now transform each of the outputs to previous layout.
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators,
input_axis_separators);
}

Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -175,7 +182,8 @@
}

Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
const Array<IntImm>& axis_separators) {
const Array<IntImm>& axis_separators,
const Array<IntImm>& input_axis_separators) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
@@ -185,6 +193,7 @@
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
attrs->axis_separators = std::move(axis_separators);
attrs->input_axis_separators = std::move(input_axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}

@@ -232,7 +241,8 @@

Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo,
const Array<IntImm>& axis_separator) {
const Array<IntImm>& axis_separator,
const Array<IntImm>& input_axis_separator) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
@@ -243,10 +253,10 @@
index_map.NonSurjectiveInverse(initial_ranges, &analyzer);

if (tir::is_zero(padding_predicate)) {
return TransformLayout(expr, inverse_index_map, axis_separator);
return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
} else {
auto padded_expr =
builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
auto padded_expr = builder_->Normalize(
TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);

GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
@@ -277,19 +287,26 @@
* \brief Updates call inputs with layout transformed inputs
*/
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
const Optional<Array<Array<IntImm>>>& axis_separators) {
const Optional<Array<Array<IntImm>>>& axis_separators,
const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (transforms.empty()) return inputs;

Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator;
Array<IntImm> input_axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
}
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_separator = input_axis_separators_value[index];
}
auto transform = transforms[index++];
updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
updated_inputs.push_back(
TransformLayout(input, transform, axis_separator, input_axis_separator));
}
return Tuple(updated_inputs);
}
@@ -338,12 +355,13 @@

Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info,
const Optional<Array<Array<IntImm>>>& axis_separators) {
const Optional<Array<Array<IntImm>>>& axis_separators,
const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (buffer_transforms.empty()) return expr;

Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);

Array<IntImm> axis_sep;
Array<IntImm> axis_sep, input_axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;

@@ -355,7 +373,12 @@
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_sep = input_axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep,
input_axis_sep);
}

// In case of more than one output, we would have to get each item of the output tuple,
@@ -367,9 +390,13 @@
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[i + first_output_index];
}
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_sep = input_axis_separators_value[i + first_output_index];
}
auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
transformed_outputs.push_back(
TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i],
axis_sep, input_axis_sep));
}
return Tuple(transformed_outputs);
}
Expand All @@ -387,6 +414,8 @@ class AlterOpImplMutator : public ExprMutator {
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
/*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
/*! \brief Map from kOperatorName attribute to the input axis separators */
const Map<String, Array<Array<IntImm>>>& op_buffer_input_axis_separators__;

const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -396,10 +425,13 @@ namespace transform {

Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_) {
const Map<String, Array<Array<IntImm>>>& axis_separators_,
const Map<String, Array<Array<IntImm>>>& input_axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_,
input_axis_separators_)
.Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
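The output-side restoration above hinges on inverting the index map; a small sketch of the inversion primitive the mutator relies on (the region size is illustrative):

```python
from tvm import tir

index_map = tir.IndexMap.from_func(lambda i, j: [i // 8, j // 8, i % 8, j % 8])
# Inverting over a 14x14 region requires padding to 16x16; the returned
# predicate is non-zero exactly where the inverse would read padding, which
# is what triggers the remove_pad path in TransformLayoutInverse.
inverse_map, pad_predicate = index_map.non_surjective_inverse([14, 14])
```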