From 84f7bbf778040ce7736a665f9e5fbbfaa2c1f1b0 Mon Sep 17 00:00:00 2001
From: masahi <masahi129@gmail.com>
Date: Sun, 8 Aug 2021 05:40:42 +0900
Subject: [PATCH] [Contrib] Support fp16 input in cpu sort (#8672)

---
 src/runtime/contrib/sort/sort.cc     | 71 +++++++++++++++++++++++-----
 tests/python/relay/test_op_level6.py | 42 ++++++++--------
 web/Makefile                         |  2 +-
 3 files changed, 83 insertions(+), 32 deletions(-)

diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc
index 66f36ffa50d64..4aa8c92f5199f 100644
--- a/src/runtime/contrib/sort/sort.cc
+++ b/src/runtime/contrib/sort/sort.cc
@@ -21,6 +21,7 @@
  * \file Use standard C library call.
  */
 
+#include <builtin_fp16.h>
 #include <dlpack/dlpack.h>
 #include <tvm/runtime/registry.h>
 
@@ -42,6 +43,24 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64
   return lhs.second > rhs.second;
 }
 
+struct float16 {
+  uint16_t bits;
+  float to_float() const {
+    return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
+  }
+};
+
+template <>
+bool CompareAscend(const std::pair<int64_t, float16>& lhs,
+                   const std::pair<int64_t, float16>& rhs) {
+  return lhs.second.to_float() < rhs.second.to_float();
+}
+
+template <>
+bool CompareDescend(const std::pair<int64_t, float16>& lhs,
+                    const std::pair<int64_t, float16>& rhs) {
+  return lhs.second.to_float() > rhs.second.to_float();
+}
+
 // Argsort implemented C library sort for nms.
 // Return indices of sorted tensor.
 // By default, the last axis will be used to sort.
@@ -125,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV
 });
 
 template <typename DataType, typename OutType>
-void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) {
+void sort_impl(
+    DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
+    std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
   auto data_ptr = static_cast<DataType*>(input->data);
   auto out_ptr = static_cast<OutType*>(output->data);
   std::vector<std::pair<int64_t, DataType>> sorter;
@@ -153,14 +174,8 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
       } else {
         std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
       }
-      if (is_argsort) {
-        for (int64_t k = 0; k < input->shape[axis]; ++k) {
-          out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
-        }
-      } else {
-        for (int64_t k = 0; k < input->shape[axis]; ++k) {
-          out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].second);
-        }
+      for (int64_t k = 0; k < input->shape[axis]; ++k) {
+        epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]);
       }
     }
   }
@@ -168,12 +183,20 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
 
 template <typename DataType, typename OutType>
 void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
-  return sort_impl<DataType, OutType>(input, output, axis, is_ascend, true);
+  return sort_impl<DataType, OutType>(
+      input, output, axis, is_ascend,
+      [](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
+        out_ptr[index] = static_cast<OutType>(sort_pair.first);
+      });
 }
 
 template <typename DataType>
 void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
-  return sort_impl<DataType, DataType>(input, output, axis, is_ascend, false);
+  return sort_impl<DataType, DataType>(
+      input, output, axis, is_ascend,
+      [](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
+        out_ptr[index] = sort_pair.second;
+      });
 }
 
 // Argsort implemented C library sort.
@@ -254,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
+  } else if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      argsort<float16, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<float16, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<float16, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<float16, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
@@ -295,6 +330,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal
     sort<int32_t>(input, output, axis, is_ascend);
   } else if (data_dtype == "int64") {
     sort<int64_t>(input, output, axis, is_ascend);
+  } else if (data_dtype == "float16") {
+    sort<float16>(input, output, axis, is_ascend);
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
@@ -432,6 +469,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
+  } else if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index 1838233e3a3af..f4a4dd4e61349 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -16,24 +16,23 @@
 # under the License.
 """ Support level6 operator test cases.
 """
+import pytest
 import numpy as np
 import tvm
-from tvm import te
 from tvm import relay
 import tvm.testing
 
 
 @tvm.testing.uses_gpu
 def test_sort():
-    def verify_sort(shape, axis, is_ascend, is_dyn=False):
-
+    def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
         if is_dyn:
-            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
+            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
         else:
-            x = relay.var("x", relay.TensorType(shape, "float32"))
+            x = relay.var("x", relay.TensorType(shape, in_dtype))
         z = relay.sort(x, axis=axis, is_ascend=is_ascend)
         func = relay.Function([x], z)
-        x_data = np.random.uniform(size=shape).astype("float32")
+        x_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
             ref_res = np.sort(x_data, axis=axis)
         else:
@@ -56,18 +55,19 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False):
         verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
         verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
         verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
+        verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16")
 
 
 @tvm.testing.uses_gpu
 def test_argsort():
-    def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
+    def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"):
         if is_dyn:
-            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
+            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
         else:
-            x = relay.var("x", relay.TensorType(shape, "float32"))
+            x = relay.var("x", relay.TensorType(shape, in_dtype))
         z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
         func = relay.Function([x], z)
-        x_data = np.random.uniform(size=shape).astype("float32")
+        x_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
             ref_res = np.argsort(x_data, axis=axis, kind="stable")
         else:
@@ -93,31 +93,34 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
         verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
         verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
         verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort(
+            (1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16"
+        )
 
 
 @tvm.testing.uses_gpu
 def test_topk():
-    def verify_topk(k, axis, ret_type, is_ascend, dtype):
+    def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
         shape = (20, 100)
-        x = relay.var("x", relay.TensorType(shape, "float32"))
+        x = relay.var("x", relay.TensorType(shape, in_dtype))
         out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
         if isinstance(out, relay.expr.TupleWrapper):
             out = out.astuple()
         func = relay.Function([x], out)
-        np_data = np.random.uniform(size=shape).astype("float32")
+        np_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
-            np_indices = np.argsort(np_data, axis=axis)
+            np_indices = np.argsort(np_data, axis=axis, kind="stable")
         else:
-            np_indices = np.argsort(-np_data, axis=axis)
+            np_indices = np.argsort(-np_data, axis=axis, kind="stable")
         kk = k if k >= 1 else shape[axis]
         if axis == 0:
             np_indices = np_indices[:kk, :]
-            np_values = np.zeros(np_indices.shape).astype("float32")
+            np_values = np.zeros(np_indices.shape).astype(in_dtype)
             for i in range(shape[1]):
                 np_values[:, i] = np_data[np_indices[:, i], i]
         else:
             np_indices = np_indices[:, :kk]
-            np_values = np.zeros(np_indices.shape).astype("float32")
+            np_values = np.zeros(np_indices.shape).astype(in_dtype)
             for i in range(shape[0]):
                 np_values[i, :] = np_data[i, np_indices[i, :]]
         np_indices = np_indices.astype(dtype)
@@ -140,9 +143,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
         for ret_type in ["both", "values", "indices"]:
             verify_topk(k, axis, ret_type, True, "int64")
             verify_topk(k, axis, ret_type, False, "float32")
+            verify_topk(k, axis, ret_type, False, "int64", "float16")
 
 
 if __name__ == "__main__":
-    test_sort()
-    test_argsort()
-    test_topk()
+    pytest.main([__file__])
diff --git a/web/Makefile b/web/Makefile
index 8c4dbc20dadc6..34a1b8172484a 100644
--- a/web/Makefile
+++ b/web/Makefile
@@ -18,7 +18,7 @@
 TVM_ROOT=$(shell cd ..; pwd)
 
 INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\
-	-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include
+	-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt
 
 .PHONY: clean all rmtypedep preparetest