
Commit

[BYOC-OpenCLML] Cleanup and review.
srkreddy1238 committed Jun 7, 2022
1 parent cbb8eda commit 8a36c25
Showing 6 changed files with 101 additions and 234 deletions.
2 changes: 1 addition & 1 deletion cmake/modules/contrib/CLML.cmake
@@ -47,7 +47,7 @@ if(USE_CLML_GRAPH_EXECUTOR)
message(STATUS "Build with CLML graph runtime support: "
${EXTERN_CLML_COMPUTE_LIB})

# Set flag to detect ADRENO DNN graph runtime support.
# Set flag to detect CLML graph runtime support.
add_definitions(-DTVM_GRAPH_EXECUTOR_CLML)

message(STATUS "Enable OpenCL as fallback to CLML")
107 changes: 22 additions & 85 deletions python/tvm/relay/op/contrib/clml.py
@@ -79,7 +79,7 @@ def partition_for_clml(mod, params=None):
@register_func("relay.ext.clml.optimize")
def preprocess_module(mod):
"""
Pre-process a module containing functions ready for CLML codegen. For now we enforce OHWI
Pre-process a module containing functions ready for CLML codegen. For now we enforce OIHW
kernel layout and fold the transforms away.
Parameters
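For context, a minimal sketch of the kind of layout-conversion sequence this hook relies on (the function name convert_to_oihw is illustrative only; ConvertLayout and FoldConstant are standard Relay passes, and the exact sequence used by the CLML hook may differ):

import tvm
from tvm import relay

def convert_to_oihw(mod):
    """Force OIHW kernel layout for nn.conv2d and fold away the inserted layout transforms."""
    seq = tvm.transform.Sequential(
        [
            relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
            relay.transform.FoldConstant(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)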
@@ -132,15 +132,7 @@ def clml_pattern_table():
"""Get the CLML pattern table."""

def conv_pattern():
"""Create a convolution pattern.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
# pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
# pattern = is_op("nn.conv2d")(pattern, is_constant())
"""Create a convolution pattern."""
pattern = is_op("nn.conv2d")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
pattern = pattern.optional(
@@ -153,39 +145,21 @@ def conv_pattern():
return pattern

def batch_norm_pattern():
"""Create a batch norm pattern.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the batch norm pattern.
"""
"""Create a batch norm pattern."""
pattern = is_op("nn.batch_norm")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
)
pattern = is_tuple_get_item(pattern)
return pattern

def dense_pattern():
"""Create a dense pattern.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the dense pattern.
"""
"""Create a dense pattern."""
pattern = is_op("nn.dense")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
return pattern

def pad_pattern():
"""Create a pad pattern.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the dense pattern.
"""
"""Create a pad pattern."""
pattern = is_op("nn.pad")(wildcard(), wildcard())
return pattern

Expand All @@ -200,25 +174,27 @@ def check_conv(extract):
call = call.tuple_value
while call.op.name != "nn.conv2d":
call = call.args[0]
return conv2d(call)

def check_batch_norm(extract):
"""Check batch norm pattern is supported by CLML."""
return True

def check_dense(extract):
"""Check dense pattern is supported by CLML."""
return True

def check_pad(extract):
"""Check pad pattern is supported by CLML."""
attrs, args = call.attrs, call.args
if attrs.data_layout != "NCHW":
return False
data_typ = args[0].checked_type
kernel_typ = args[1].checked_type
is_depthwise = is_depthwise_conv2d(
data_typ.shape,
attrs["data_layout"],
kernel_typ.shape,
attrs["kernel_layout"],
attrs["groups"],
)
if attrs.groups != 1 and not is_depthwise:
return False
return True
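A worked illustration of the depthwise test used above (shapes are hypothetical; is_depthwise_conv2d treats a convolution as depthwise when input channels, output channels and groups coincide):

from tvm.relay.op.strategy.generic import is_depthwise_conv2d

# 32-channel NCHW input, (32, 1, 3, 3) OIHW kernel, groups=32: depthwise, so the
# check above accepts it even though groups != 1.
assert is_depthwise_conv2d((1, 32, 56, 56), "NCHW", (32, 1, 3, 3), "OIHW", 32)
# Grouped but not depthwise (groups=4): rejected by the check above.
assert not is_depthwise_conv2d((1, 32, 56, 56), "NCHW", (32, 8, 3, 3), "OIHW", 4)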

return [
("clml.conv2d", conv_pattern(), check_conv),
("clml.dense", dense_pattern(), check_dense),
("clml.pad", pad_pattern(), check_pad),
("clml.batch_norm", batch_norm_pattern(), check_batch_norm),
("clml.dense", dense_pattern()),
("clml.pad", pad_pattern()),
("clml.batch_norm", batch_norm_pattern()),
]
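For context, a minimal sketch of how this pattern table is consumed when partitioning a module for CLML (mod is assumed to be an existing Relay IRModule; partition_for_clml earlier in this file applies a sequence along these lines, though the exact passes may differ):

import tvm
from tvm import relay
from tvm.relay.op.contrib.clml import clml_pattern_table

seq = tvm.transform.Sequential(
    [
        relay.transform.MergeComposite(clml_pattern_table()),
        relay.transform.AnnotateTarget("clml"),
        relay.transform.MergeCompilerRegions(),
        relay.transform.PartitionGraph(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    partitioned_mod = seq(mod)  # mod: a pre-built relay.IRModule (assumed)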


@@ -230,45 +206,6 @@ def _func_wrapper(expr):
return _func_wrapper


@tvm.ir.register_op_attr("nn.conv2d", "target.clml")
def conv2d(expr):
"""Check if the external CLML codegen for conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.data_layout != "NCHW":
return False
# if attrs.out_dtype != "float32" and attrs.out_dtype != "":
# return False
data_typ = args[0].checked_type
# if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "float32":
# return False
kernel_typ = args[1].checked_type
# if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32":
# return False
is_depthwise = is_depthwise_conv2d(
data_typ.shape,
attrs["data_layout"],
kernel_typ.shape,
attrs["kernel_layout"],
attrs["groups"],
)
if is_depthwise:
return depthwise_conv2d(attrs, args)
if attrs.groups != 1 and not is_depthwise:
return False
return True


def depthwise_conv2d(attrs, args):
"""Check if the CLML codegen for depthwise convolution should be used.
Note
----
Relay does not have a depthwise conv2d operator whilst ACL does. We simply
separate the checks for depthwise for clarity.
"""
return True


_register_external_op_helper("clip")
_register_external_op_helper("relu")
_register_external_op_helper("nn.global_avg_pool2d")
19 changes: 5 additions & 14 deletions src/relay/backend/contrib/clml/codegen.cc
@@ -139,10 +139,8 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
} else {
current_call = fn->body.as<CallNode>();
}
// if (backend::IsOp(current_call, "nn.relu") || backend::IsOp(current_call, "nn.relu6")) {
if (backend::IsOp(current_call, "nn.relu")) {
nodes.activation = current_call;
// nodes.act_type = backend::IsOp(current_call, "nn.relu") ? "relu" : "relu6";
nodes.act_type = "relu";
if (current_call->args[0].as<TupleGetItemNode>()) {
auto tuple_item = current_call->args[0].as<TupleGetItemNode>();
@@ -151,13 +149,6 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
current_call = current_call->args[0].as<CallNode>();
}
}
#if 0
if (backend::IsOp(current_call, "nn.relu6")) {
nodes.activation = current_call;
current_call = current_call->args[0].as<CallNode>();
nodes.act_type = "relu6";
}
#endif
if (backend::IsOp(current_call, "nn.batch_norm")) {
nodes.bn = current_call;
current_call = current_call->args[0].as<CallNode>();
@@ -240,8 +231,8 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
const auto* pad_attr = nodes.pad->attrs.as<PadAttrs>();
ICHECK(pad_attr);
auto p = pad_attr->pad_width;
// Standard convolution pad layout for TVM: before and after dimension wise.
// CLML Takes dimensions as all befores and afters.
// Standard convolution pad layout for TVM: a dimension-wise pair of pre and post padding.
// CLML takes dimension-wise pre-padding followed by dimension-wise post-padding.
std::vector<std::string> padding = {std::to_string(p[2][0].as<IntImmNode>()->value),
std::to_string(p[3][0].as<IntImmNode>()->value),
std::to_string(p[2][1].as<IntImmNode>()->value),
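To illustrate the reordering described in the comment above (values are hypothetical; TVM stores a (before, after) pair per dimension, while the serializer emits the H/W pre-paddings followed by the H/W post-paddings):

# TVM pad_width for an NCHW tensor: one (before, after) pair per dimension.
pad_width = [[0, 0], [0, 0], [1, 2], [3, 4]]  # N, C, H, W

# Reordered as in the serializer: pre-H, pre-W, post-H, post-W.
clml_padding = [pad_width[2][0], pad_width[3][0], pad_width[2][1], pad_width[3][1]]
assert clml_padding == [1, 3, 2, 4]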
@@ -342,8 +333,8 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
const auto* pad_attr = pad->attrs.as<PadAttrs>();
ICHECK(pad_attr);
auto p = pad_attr->pad_width;
// Standard convolution pad layout for TVM: before and after dimension wise.
// CLML Takes dimensions as all befores and afters.
// TVM padding format: a dimension-wise pair of pre and post padding.
// CLML padding format: dimension-wise pre-padding followed by dimension-wise post-padding.
std::vector<std::string> padding = {std::to_string(p[2][0].as<IntImmNode>()->value),
std::to_string(p[2][1].as<IntImmNode>()->value),
std::to_string(p[3][0].as<IntImmNode>()->value),
@@ -385,7 +376,7 @@ runtime::Module CLMLCompiler(const ObjectRef& ref) {
std::string graph_json = serializer.GetJSON();
auto param_names = serializer.GetParams();
const auto* pf = runtime::Registry::Get("runtime.clml_runtime_create");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
ICHECK(pf != nullptr) << "Cannot find CLML runtime module to create";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
return lib;
}
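As a side note, the registry lookup above can be mirrored from Python to check whether the CLML runtime was compiled in; a small sketch (allow_missing avoids an exception when the symbol is absent):

import tvm

# Returns None when TVM was built without CLML runtime support.
clml_create = tvm.get_global_func("runtime.clml_runtime_create", allow_missing=True)
print("CLML runtime available:", clml_create is not None)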
