diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py
index 112fc320973b..0e336c1c82b9 100644
--- a/python/tvm/contrib/miopen.py
+++ b/python/tvm/contrib/miopen.py
@@ -136,3 +136,55 @@ def conv2d_forward(
         ),
         name="y",
     )
+
+
+def softmax(x, axis=-1):
+    """Compute softmax with MIOpen
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute softmax over
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape,
+        [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.miopen.softmax.forward", ins[0], outs[0], axis
+        ),
+        name="y",
+    )
+
+
+def log_softmax(x, axis=-1):
+    """Compute log softmax with MIOpen
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute log softmax over
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape,
+        [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.miopen.log_softmax.forward", ins[0], outs[0], axis
+        ),
+        name="y",
+    )
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index f4538071e11e..64373dcdd7bf 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -20,6 +20,7 @@
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
 from tvm.contrib.thrust import can_use_rocthrust
+from tvm.contrib import miopen
 
 from .generic import *
 from .. import op as _op
@@ -304,3 +305,41 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
         plevel=15,
     )
     return strategy
+
+
+@softmax_strategy.register(["rocm"])
+def softmax_strategy_rocm(attrs, inputs, out_type, target):
+    """rocm strategy for softmax"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="softmax.rocm",
+    )
+    if "miopen" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(miopen.softmax),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="softmax.miopen",
+            plevel=15,
+        )
+    return strategy
+
+
+@log_softmax_strategy.register(["rocm"])
+def log_softmax_strategy_rocm(attrs, inputs, out_type, target):
+    """rocm strategy for log softmax"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="log_softmax.rocm",
+    )
+    if "miopen" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(miopen.log_softmax),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="log_softmax.miopen",
+            plevel=15,
+        )
+    return strategy
diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc
index 426d2f24ddf5..ff748b8826de 100644
--- a/src/runtime/contrib/miopen/miopen_utils.cc
+++ b/src/runtime/contrib/miopen/miopen_utils.cc
@@ -89,6 +89,10 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+SoftmaxEntry::SoftmaxEntry() { MIOPEN_CALL(miopenCreateTensorDescriptor(&shape_desc)); }
+
+SoftmaxEntry::~SoftmaxEntry() { MIOPEN_CALL(miopenDestroyTensorDescriptor(shape_desc)); }
+
 }  // namespace miopen
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h
index d3a8c7b9ad64..76913696b0b9 100644
--- a/src/runtime/contrib/miopen/miopen_utils.h
+++ b/src/runtime/contrib/miopen/miopen_utils.h
@@ -62,11 +62,18 @@ struct ConvEntry {
   void CleanWorkspace();
 };  // ConvThreadEntry
 
+struct SoftmaxEntry {
+  miopenTensorDescriptor_t shape_desc;
+  SoftmaxEntry();
+  ~SoftmaxEntry();
+};  // SoftmaxEntry
+
 struct MIOpenThreadEntry {
   MIOpenThreadEntry();
   ~MIOpenThreadEntry();
   miopenHandle_t handle{nullptr};
   ConvEntry conv_entry;
+  SoftmaxEntry softmax_entry;
   runtime::DeviceAPI* rocm_api{nullptr};
   static MIOpenThreadEntry* ThreadLocal();
 };  // MIOpenThreadEntry
diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc
new file mode 100644
index 000000000000..5a0f24ed7a84
--- /dev/null
+++ b/src/runtime/contrib/miopen/softmax.cc
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/miopen/softmax.cc
+ * \brief Use external miopen softmax function
+ */
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+
+#include "miopen_utils.h"
+
+namespace tvm {
+namespace contrib {
+namespace miopen {
+
+using namespace runtime;
+
+void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  ICHECK(axis >= 0 && axis < ndim);
+  // just fp32 for now
+  ICHECK(TypeMatch(x->dtype, kDLFloat, 32));
+  ICHECK(TypeMatch(y->dtype, kDLFloat, 32));
+
+  MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
+
+  miopenSoftmaxMode_t mode;
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    mode = MIOPEN_SOFTMAX_MODE_INSTANCE;
+    MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->softmax_entry.shape_desc, miopenFloat,
+                                            static_cast<int>(N), static_cast<int>(shape[ndim - 1]),
+                                            1, 1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
+      }
+    }
+    mode = MIOPEN_SOFTMAX_MODE_CHANNEL;
+    MIOPEN_CALL(miopenSet4dTensorDescriptor(
+        entry_ptr->softmax_entry.shape_desc, miopenFloat, static_cast<int>(pre_axis_dim),
+        static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
+  }
+
+  const float alpha = 1.f;
+  const float beta = 0.f;
+  MIOPEN_CALL(miopenSoftmaxForward_V2(entry_ptr->handle, &alpha,
+                                      entry_ptr->softmax_entry.shape_desc, x->data, &beta,
+                                      entry_ptr->softmax_entry.shape_desc, y->data, alg, mode));
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE);
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); });
+
+}  // namespace miopen
+}  // namespace contrib
+}  // namespace tvm
diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py
index 27a8ec6df357..81115b6c0238 100644
--- a/tests/python/contrib/test_miopen.py
+++ b/tests/python/contrib/test_miopen.py
@@ -19,9 +19,17 @@
 from tvm import te
 from tvm.contrib import miopen
 import numpy as np
+import pytest
+
+
+requires_miopen = pytest.mark.skipif(
+    tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is None,
+    reason="MIOpen is not enabled",
+)
 
 
 @tvm.testing.requires_rocm
+@requires_miopen
 def test_conv2d():
     in_channel = 3
     out_channel = 64
@@ -35,9 +43,6 @@ def test_conv2d():
     dilation_w = 1
 
     xshape = [1, in_channel, 128, 128]
-    if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
-        print("skip because miopen is not enabled...")
-        return
     wshape = (out_channel, in_channel, filter_h, filter_w)
 
     X = te.placeholder(xshape, name="X")
@@ -72,5 +77,59 @@ def verify():
     verify()
 
 
+def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
+    miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
+    A = te.placeholder(shape, dtype=dtype, name="A")
+    B = miopen_op(A, axis)
+    s = te.create_schedule([B.op])
+
+    dev = tvm.rocm(0)
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = testing_op(a_np)
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)
+
+
+def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
+    miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
+    A = te.placeholder(shape, dtype=dtype, name="A")
+    B = miopen_op(A, axis=1)
+    s = te.create_schedule([B.op])
+
+    dev = tvm.rocm(0)
+    n, c, h, w = shape
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
+    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)
+
+
+@tvm.testing.requires_rocm
+@requires_miopen
+def test_softmax():
+    verify_softmax((32, 10), -1)
+    verify_softmax((3, 4), -1)
+    verify_softmax_4d((1, 16, 256, 256))
+
+    verify_softmax((32, 10), -1, log_softmax=True)
+    verify_softmax((3, 4), -1, log_softmax=True)
+    verify_softmax_4d((1, 16, 256, 256), log_softmax=True)
+
+
 if __name__ == "__main__":
     test_conv2d()
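
Usage sketch (not part of the patch; the module and shapes below are illustrative). With the strategy registrations above, Relay offloads softmax and log_softmax to MIOpen whenever "miopen" appears in target.libs, since plevel=15 outranks the default TOPI implementation. Assuming a ROCm-enabled TVM build with MIOpen available:

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# A one-op Relay module; the (16, 1000) shape is arbitrary.
x = relay.var("x", shape=(16, 1000), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.softmax(x, axis=-1))

# "-libs=miopen" puts "miopen" into target.libs, so softmax_strategy_rocm
# selects the MIOpen-backed implementation over the TOPI default.
target = tvm.target.Target("rocm -libs=miopen", host="llvm")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)

dev = tvm.rocm(0)
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input("x", np.random.uniform(size=(16, 1000)).astype("float32"))
m.run()
out = m.get_output(0).numpy()  # each row sums to ~1.0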