Skip to content

Commit

Permalink
[ETHOSN] Per-tensor support for int8 operations (apache#10018)
Browse files Browse the repository at this point in the history
* Per-axis quantization to follow
  • Loading branch information
leo-blonk authored and ylc committed Feb 16, 2022
1 parent 52cce12 commit e0ecc2f
Show file tree
Hide file tree
Showing 18 changed files with 529 additions and 287 deletions.
3 changes: 0 additions & 3 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ def check_sigmoid(extract):
if not ethosn_available():
return False

if extract.attrs.out_dtype != "uint8":
return False

return support.sigmoid(extract)

return [
Expand Down
89 changes: 61 additions & 28 deletions src/relay/backend/contrib/ethosn/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ namespace relay {
namespace contrib {
namespace ethosn {

constexpr size_t kReasonMaxLength = sl::g_ReasonMaxLength;

sl::TensorInfo GetTensorInfo(std::map<Expr, std::vector<sl::TensorInfo>> tensor_table,
const Call& call) {
if (tensor_table.find(call) != tensor_table.end()) return tensor_table[call][0];
Expand Down Expand Up @@ -442,11 +444,6 @@ EthosnError ConstructNetworkVisitor::MakeAdditionLayer(const Call& call,

EthosnError ConstructNetworkVisitor::MakeSigmoidLayer(const Call& call,
sl::TensorAndId<sl::Operand>* out) {
SigmoidParams params;
if (auto err = EthosnAPI::Sigmoid(call->op.as<FunctionNode>()->body, &params)) {
return err;
}

auto input = operand_table_[call->args[0]][0];

try {
Expand Down Expand Up @@ -644,15 +641,18 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d")
ConvolutionParams params;
auto err = EthosnAPI::QnnConv2d(call, &params);
err += EthosnCompiler::SupportedSetup();
char reason[kReasonMaxLength];
reason[0] = '\0';
if (params.is_depthwise) {
*rv = !err &&
EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported(
params.bias_info, params.weights_info, params.conv_info, params.activation_info);
*rv = !err && EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported(
params.bias_info, params.weights_info, params.conv_info,
params.activation_info, nullptr, reason, sizeof(reason));
} else {
*rv = !err &&
EthosnCompiler::GetSupported()->IsConvolutionSupported(
params.bias_info, params.weights_info, params.conv_info, params.activation_info);
*rv = !err && EthosnCompiler::GetSupported()->IsConvolutionSupported(
params.bias_info, params.weights_info, params.conv_info,
params.activation_info, nullptr, reason, sizeof(reason));
}
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc")
Expand All @@ -661,8 +661,12 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc")
FullyConnectedParams params;
auto err = EthosnAPI::QnnFullyConnected(call, &params);
err += EthosnCompiler::SupportedSetup();
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported(
params.bias_info, params.weights_info, params.fc_info, params.input_info);
params.bias_info, params.weights_info, params.fc_info, params.input_info,
nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d")
Expand All @@ -671,8 +675,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d")
MaxPool2DParams params;
auto err = EthosnAPI::MaxPool2D(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err &&
EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported(
params.pool_info, params.input_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d")
Expand All @@ -681,8 +688,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d")
AvgPool2DParams params;
auto err = EthosnAPI::AvgPool2D(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err &&
EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported(
params.pool_info, params.input_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape")
Expand All @@ -691,8 +701,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape")
ReshapeParams params;
auto err = EthosnAPI::Reshape(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err &&
EthosnCompiler::GetSupported()->IsReshapeSupported(params.new_shape, params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsReshapeSupported(
params.new_shape, params.input_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition")
Expand All @@ -701,8 +714,12 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition")
AdditionParams params;
auto err = EthosnAPI::Addition(call, &params);
err += EthosnCompiler::SupportedSetup();
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported(
params.lhs_info, params.rhs_info, params.output_quantization_info);
params.lhs_info, params.rhs_info, params.output_quantization_info, nullptr,
reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid")
Expand All @@ -711,7 +728,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid")
SigmoidParams params;
auto err = EthosnAPI::Sigmoid(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info, nullptr,
reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
Expand All @@ -720,8 +741,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
ConcatenateParams params;
auto err = EthosnAPI::Concatenate(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(params.input_infos,
params.concat_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(
params.input_infos, params.concat_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.split")
Expand All @@ -730,8 +754,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split")
SplitParams params;
auto err = EthosnAPI::Split(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err &&
EthosnCompiler::GetSupported()->IsSplitSupported(params.input_info, params.split_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsSplitSupported(
params.input_info, params.split_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space")
Expand All @@ -740,8 +767,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space")
DepthToSpaceParams params;
auto err = EthosnAPI::DepthToSpace(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(params.input_info,
params.depth_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(
params.input_info, params.depth_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu")
Expand All @@ -750,8 +780,11 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu")
ReluParams params;
auto err = EthosnAPI::Relu(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err &&
EthosnCompiler::GetSupported()->IsReluSupported(params.relu_info, params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsReluSupported(
params.relu_info, params.input_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
Expand Down
Loading

0 comments on commit e0ecc2f

Please sign in to comment.