From eae836cdf66f54f1e81e78e48bfa051431e8556f Mon Sep 17 00:00:00 2001
From: Gayatri P K
Date: Thu, 5 May 2022 21:04:30 +0530
Subject: [PATCH] Fix mixed precision output type to original type (#11142)

---
 src/relay/transforms/to_mixed_precision.cc    | 60 ++++++++++++++++---
 tests/python/relay/test_to_mixed_precision.py | 39 ++++++++----
 2 files changed, 82 insertions(+), 17 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4ad3482f7464..e1d3a264c222 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
 // A callable which hashes std::pair
 struct pair_hash {
   template <class T1, class T2>
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
    * encountered. Used for emitting warnings on missing ops in the pass.
    */
   std::unordered_map<std::string, int> missing_ops_;
+  const RelayExprNode* root_;
+  std::vector<DataType> original_dtype_;
+  bool keep_orig_output_dtype_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
-      : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
+                              DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(),
+        mixed_precision_type_(mixed_precision_type),
+        root_(Downcast<Function>(base)->body.get()),
+        keep_orig_output_dtype_(keep_orig_output_dtype) {
+    if (keep_orig_output_dtype_) {
+      if (root_->IsInstance<tvm::relay::TupleNode>()) {
+        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+        for (Type t : tuple_type->fields) {
+          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+          original_dtype_.push_back(tensor_type->dtype);
+        }
+      } else if (root_->IsInstance<tvm::relay::CallNode>()) {
+        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
+      }
+    }
     if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
       LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                  << mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (accumulation_dtype != output_dtype) {
       output = CastArg(output, GetType(output), output_dtype);
     }
+    if (pre_call_node == root_ && keep_orig_output_dtype_) {
+      if (original_dtype_[0] != output_dtype) {
+        output = CastArg(output, GetType(output), original_dtype_[0]);
+      }
+    }
     return output;
   }
 
@@ -396,6 +420,21 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
+    if (pre == root_ && keep_orig_output_dtype_) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < original_dtype_.size(); i++) {
+        Expr output_element = GetField(post, i);
+        Expr casted_element;
+        auto output_element_type = transform::InferTypeLocal(output_element);
+        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(output_element);
+      }
+      if (!all_same) {
+        return Tuple(new_expr);
+      }
+    }
     return post;
   }
 
@@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   // To access map of ops not registered for error reporting
-  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
-                               int missing_op_mode);
+  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                               const DataType& mixed_precision_type, int missing_op_mode);
 };
 
-Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                      const DataType& mixed_precision_type, int missing_op_mode) {
   /*
   missing_op_mode:
 
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
   ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
       << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
-  MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
 Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
+        bool keep_orig_output_dtype = false;
+        keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
+                                               Bool(keep_orig_output_dtype))
+                                     .value();
+        return Downcast<Function>(
+            ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 2afd6ff247ab..026b458bde12 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
+    keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
-    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
-    result_fp16 = run_module(fp16_mod, mod_params)
+
+    if not keep_orig_output_dtype:
+        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+        result_fp16 = run_module(fp16_mod, mod_params)
+    else:
+        with tvm.transform.PassContext(
+            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
+        ):
+            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+            result_fp16 = run_module(fp16_mod, mod_params)
 
     # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
+    if keep_orig_output_dtype:
+        assert (
+            np.array(result_fp16).dtype == np.array(result_fp32).dtype
+        ), "output type and original type mismatch"
+
     return fp16_mod
 
 
@@ -117,16 +131,21 @@ def test_convert_single_conv():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(
+        mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
+    )
 
     expected_mod = tvm.IRModule.from_expr(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float16",
-        ),
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float16",
+            ),
+            "float32",
+        )
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
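
For reference, a minimal sketch of how the new "relay.ToMixedPrecision.keep_orig_output_dtype"
option is enabled from Python; the small dense module below is illustrative only and is not
taken from this patch or its tests.

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    # Build a small fp32 module (hypothetical example).
    data = relay.var("data", shape=(1, 16), dtype="float32")
    weight = relay.var("weight", shape=(8, 16), dtype="float32")
    mod = InferType()(tvm.IRModule.from_expr(relay.nn.dense(data, weight)))

    # With the config set, the pass computes in float16 internally but casts
    # the function's outputs back to their original float32 dtype.
    with tvm.transform.PassContext(
        config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
    ):
        fp16_mod = ToMixedPrecision("float16")(mod)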