add support for softmax and log_softmax with MIOpen (apache#8543)
altanh authored and ylc committed Sep 29, 2021
1 parent 280fb7b commit b2d9e50
Showing 6 changed files with 257 additions and 3 deletions.
52 changes: 52 additions & 0 deletions python/tvm/contrib/miopen.py
@@ -136,3 +136,55 @@ def conv2d_forward(
),
name="y",
)


def softmax(x, axis=-1):
"""Compute softmax with MIOpen
Parameters
----------
x : tvm.te.Tensor
The input tensor
axis : int
The axis to compute softmax over
Returns
-------
ret : tvm.te.Tensor
The result tensor
"""
return te.extern(
x.shape,
[x],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.softmax.forward", ins[0], outs[0], axis
),
name="y",
)


def log_softmax(x, axis=-1):
"""Compute log softmax with MIOpen
Parameters
----------
x : tvm.te.Tensor
The input tensor
axis : int
The axis to compute log softmax over
Returns
-------
ret : tvm.te.Tensor
The result tensor
"""
return te.extern(
x.shape,
[x],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.log_softmax.forward", ins[0], outs[0], axis
),
name="y",
)
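
For reference, these wrappers can be exercised directly from TE. A minimal sketch, assuming a ROCm build of TVM with MIOpen enabled (the shape and names here are illustrative):

    import tvm
    from tvm import te
    from tvm.contrib import miopen

    # Lowers to the "tvm.contrib.miopen.softmax.forward" packed function above.
    x = te.placeholder((32, 10), name="x")
    y = miopen.softmax(x, axis=-1)
    s = te.create_schedule(y.op)
    f = tvm.build(s, [x, y], target="rocm --host=llvm")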
39 changes: 39 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
@@ -20,6 +20,7 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib.thrust import can_use_rocthrust
from tvm.contrib import miopen

from .generic import *
from .. import op as _op
@@ -304,3 +305,41 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
plevel=15,
)
return strategy


@softmax_strategy.register(["rocm"])
def softmax_strategy_rocm(attrs, inputs, out_type, target):
"""rocm strategy for softmax"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="softmax.rocm",
)
if "miopen" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(miopen.softmax),
wrap_topi_schedule(topi.generic.schedule_extern),
name="softmax.miopen",
plevel=15,
)
return strategy


@log_softmax_strategy.register(["rocm"])
def log_softmax_strategy_rocm(attrs, inputs, out_type, target):
"""rocm strategy for log softmax"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.log_softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="log_softmax.rocm",
)
if "miopen" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(miopen.log_softmax),
wrap_topi_schedule(topi.generic.schedule_extern),
name="log_softmax.miopen",
plevel=15,
)
return strategy
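
Both MIOpen implementations register at plevel=15, which outranks the default TOPI entries (plevel 10), so they are selected whenever "miopen" appears in target.libs. A sketch of opting in from Relay, with an illustrative one-operator module:

    import tvm
    from tvm import relay

    # "-libs=miopen" puts "miopen" into target.libs, enabling the
    # softmax.miopen / log_softmax.miopen implementations registered above.
    target = tvm.target.Target("rocm -libs=miopen")
    x = relay.var("x", shape=(1, 1000), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.nn.softmax(x))
    lib = relay.build(mod, target=target)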
4 changes: 4 additions & 0 deletions src/runtime/contrib/miopen/miopen_utils.cc
@@ -89,6 +89,10 @@ void ConvEntry::CleanWorkspace() {
workspace_size = 0;
}

SoftmaxEntry::SoftmaxEntry() { MIOPEN_CALL(miopenCreateTensorDescriptor(&shape_desc)); }

SoftmaxEntry::~SoftmaxEntry() { MIOPEN_CALL(miopenDestroyTensorDescriptor(shape_desc)); }

} // namespace miopen
} // namespace contrib
} // namespace tvm
7 changes: 7 additions & 0 deletions src/runtime/contrib/miopen/miopen_utils.h
@@ -62,11 +62,18 @@ struct ConvEntry {
void CleanWorkspace();
}; // ConvThreadEntry

struct SoftmaxEntry {
miopenTensorDescriptor_t shape_desc;
SoftmaxEntry();
~SoftmaxEntry();
}; // SoftmaxEntry

struct MIOpenThreadEntry {
MIOpenThreadEntry();
~MIOpenThreadEntry();
miopenHandle_t handle{nullptr};
ConvEntry conv_entry;
SoftmaxEntry softmax_entry;
runtime::DeviceAPI* rocm_api{nullptr};
static MIOpenThreadEntry* ThreadLocal();
}; // MIOpenThreadEntry
92 changes: 92 additions & 0 deletions src/runtime/contrib/miopen/softmax.cc
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/runtime/contrib/miopen/softmax.cc
* \brief Use external miopen softmax function
*/
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/registry.h>

#include "miopen_utils.h"

namespace tvm {
namespace contrib {
namespace miopen {

using namespace runtime;

void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) {
DLTensor* x = args[0];
DLTensor* y = args[1];
int axis = args[2];
int ndim = x->ndim;
int64_t* shape = x->shape;
if (axis < 0) axis += ndim;
ICHECK(axis >= 0 && axis < ndim);
// just fp32 for now
ICHECK(TypeMatch(x->dtype, kDLFloat, 32));
ICHECK(TypeMatch(y->dtype, kDLFloat, 32));

MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();

miopenSoftmaxMode_t mode;
if (axis == ndim - 1) {
int64_t N = 1;
for (int i = 0; i < ndim - 1; ++i) {
N *= shape[i];
}
mode = MIOPEN_SOFTMAX_MODE_INSTANCE;
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->softmax_entry.shape_desc, miopenFloat,
static_cast<int>(N), static_cast<int>(shape[ndim - 1]),
1, 1));
} else {
int64_t pre_axis_dim = 1;
int64_t post_axis_dim = 1;
for (int i = 0; i < ndim; ++i) {
if (i < axis) {
pre_axis_dim *= shape[i];
} else if (i > axis) {
post_axis_dim *= shape[i];
}
}
mode = MIOPEN_SOFTMAX_MODE_CHANNEL;
MIOPEN_CALL(miopenSet4dTensorDescriptor(
entry_ptr->softmax_entry.shape_desc, miopenFloat, static_cast<int>(pre_axis_dim),
static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
}

const float alpha = 1.f;
const float beta = 0.f;
MIOPEN_CALL(miopenSoftmaxForward_V2(entry_ptr->handle, &alpha,
entry_ptr->softmax_entry.shape_desc, x->data, &beta,
entry_ptr->softmax_entry.shape_desc, y->data, alg, mode));
}

TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) {
softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE);
});

TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); });

} // namespace miopen
} // namespace contrib
} // namespace tvm
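
softmax_impl folds an input of any rank into the 4-D descriptor MIOpen expects: for a last-axis softmax it collapses all leading dimensions into N and uses MIOPEN_SOFTMAX_MODE_INSTANCE; for an interior axis it maps that axis to C, folds the dimensions before and after it into N and H, and uses MIOPEN_SOFTMAX_MODE_CHANNEL. The same folding in Python, purely as an illustration of the descriptor math:

    def fold_to_4d(shape, axis):
        # Mirrors the descriptor setup in softmax_impl (sketch only).
        ndim = len(shape)
        axis = axis + ndim if axis < 0 else axis
        if axis == ndim - 1:
            n = 1
            for dim in shape[:-1]:
                n *= dim
            return "INSTANCE", (n, shape[-1], 1, 1)
        pre, post = 1, 1
        for i, dim in enumerate(shape):
            if i < axis:
                pre *= dim
            elif i > axis:
                post *= dim
        return "CHANNEL", (pre, shape[axis], post, 1)

    # fold_to_4d((2, 3, 4, 5), axis=1)  -> ("CHANNEL", (2, 3, 20, 1))
    # fold_to_4d((2, 3, 4, 5), axis=-1) -> ("INSTANCE", (24, 5, 1, 1))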
66 changes: 63 additions & 3 deletions tests/python/contrib/test_miopen.py
@@ -19,9 +19,17 @@
from tvm import te
from tvm.contrib import miopen
import numpy as np
import pytest


requires_miopen = pytest.mark.skipif(
tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is None,
reason="MIOpen is not enabled",
)


@tvm.testing.requires_rocm
@requires_miopen
def test_conv2d():
in_channel = 3
out_channel = 64
@@ -35,9 +43,6 @@ def test_conv2d():
dilation_w = 1

xshape = [1, in_channel, 128, 128]
if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
print("skip because miopen is not enabled...")
return
wshape = (out_channel, in_channel, filter_h, filter_w)

X = te.placeholder(xshape, name="X")
@@ -72,5 +77,60 @@ def verify():
verify()


def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = miopen_op(A, axis)
s = te.create_schedule([B.op])

dev = tvm.rocm(0)
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = testing_op(a_np)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = miopen_op(A, axis=1)
s = te.create_schedule([B.op])

dev = tvm.rocm(0)
n, c, h, w = shape
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_rocm
@requires_miopen
def test_softmax():
verify_softmax((32, 10), -1)
verify_softmax((3, 4), -1)
verify_softmax_4d((1, 16, 256, 256))
verify_softmax_4d((1, 16, 256, 256))

verify_softmax((32, 10), -1, log_softmax=True)
verify_softmax((3, 4), -1, log_softmax=True)
verify_softmax_4d((1, 16, 256, 256), log_softmax=True)


if __name__ == "__main__":
test_conv2d()
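
Both tests are skipped rather than failed when ROCm or MIOpen is unavailable, via tvm.testing.requires_rocm and the requires_miopen marker above. On a suitable machine they can also be run through pytest, e.g. (illustrative invocation):

    pytest tests/python/contrib/test_miopen.py -k softmax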
