Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add boolean ndarray #15940

Merged
merged 4 commits into from
Oct 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ enum TypeFlag {
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kBool = 7,
};

template<typename DType>
Expand Down Expand Up @@ -411,6 +412,11 @@ struct DataType<int64_t> {
static const int kFlag = kInt64;
static const int kLanes = 1;
};
template<>
struct DataType<bool> {
static const int kFlag = kBool;
static const int kLanes = 1;
};

/*! \brief type enum value for default real type */
const int default_type_flag = DataType<default_real_t>::kFlag;
Expand Down Expand Up @@ -1138,10 +1144,64 @@ struct minimum {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
typedef bool DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*! \brief get data type size from type enum */
inline size_t mshadow_sizeof(int type) {
int size = 0;
MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType););
MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType););
return size;
}

Expand Down
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -756,11 +756,15 @@ if(USE_TVM_OP)
endif()
endif()

set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so")
if(CUDA_ARCH_BIN)
set(TVM_OP_COMPILE_OPTIONS "${TVM_OP_COMPILE_OPTIONS}" "--cuda-arch" "${CUDA_ARCH_BIN}")
endif()
add_custom_command(TARGET mxnet POST_BUILD
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib"
LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}:${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm:$ENV{LD_LIBRARY_PATH}
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py ${TVM_OP_COMPILE_OPTIONS}
)
endif()

Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,15 @@ lib/libtvm_runtime.so:
ls $(ROOTDIR)/lib; \
cd $(ROOTDIR)

TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so
ifneq ($(CUDA_ARCH),)
TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)"
endif
lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
echo "Compile TVM operators"
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \
LD_LIBRARY_PATH=$(ROOTDIR)/lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so
python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS)

NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc)
Expand Down
4 changes: 2 additions & 2 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ build_ubuntu_gpu_mkldnn_release() {
# $1 -> mxnet_variant: the mxnet variant to build, e.g. cpu, cu100, cu92mkl, etc.
build_dynamic_libmxnet() {
set -ex

local mxnet_variant=${1:?"This function requires a mxnet variant as the first argument"}

# relevant licenses will be placed in the licenses directory
Expand Down Expand Up @@ -948,7 +948,7 @@ cd_unittest_ubuntu() {
fi

$nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/unittest
$nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization
$nose_cmd $NOSE_TIMER_ARGUMENTS --verbose tests/python/quantization

# https://github.com/apache/incubator-mxnet/issues/11801
# if [[ ${mxnet_variant} = "cpu" ]] || [[ ${mxnet_variant} = "mkl" ]]; then
Expand Down
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .utils import assign_by_req, reduce_axes

from . import basic
from . import core
41 changes: 40 additions & 1 deletion contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@

import os
import argparse
import re
import logging
from tvmop.opdef import __OP_DEF__
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

logging.basicConfig(level=logging.INFO)


def get_target(device):
if device == "cpu":
Expand All @@ -31,12 +37,39 @@ def get_target(device):
assert False, "Unknown device " + device


def get_cuda_arch(arch):
if arch is None:
return None

if not isinstance(arch, str):
raise TypeError('Expecting parameter arch as a str, while got a {}'.format(str(type(arch))))

if len(arch) == 0:
return None

# the arch string contains '-arch=sm_xx'
flags = arch.split()
for flag in flags:
if flag.startswith('-arch='):
return flag[len('-arch='):]

# find the highest compute capability
comp_caps = re.findall(r'\d+', arch)
if len(comp_caps) == 0:
return None

comp_caps = [int(c) for c in comp_caps]
return 'sm_' + str(max(comp_caps))


if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(sys.path[0]))
parser = argparse.ArgumentParser(description="Generate tvm operators")
parser.add_argument("-o", action="store", required=True, dest="target_path",
help="Target path which stores compiled library")
parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch',
help='The cuda arch for compiling kernels for')
arguments = parser.parse_args()

func_list_llvm = []
Expand All @@ -52,8 +85,14 @@ def get_target(device):
binds=operator_def.get_binds(args))
func_list.append(func_lower)

lowered_funcs = {get_target("cpu") : func_list_llvm}
lowered_funcs = {get_target("cpu"): func_list_llvm}
if len(func_list_cuda) > 0:
lowered_funcs[get_target("cuda")] = func_list_cuda
cuda_arch = get_cuda_arch(arguments.cuda_arch)
if cuda_arch is None:
logging.info('No cuda arch specified. TVM will try to detect it from the build platform.')
else:
logging.info('Cuda arch {} set for compiling TVM operator kernels.'.format(cuda_arch))
set_cuda_target_arch(cuda_arch)
func_binary = tvm.build(lowered_funcs, name="tvmop")
func_binary.export_library(arguments.target_path)
18 changes: 18 additions & 0 deletions contrib/tvmop/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from . import umath, fromnumeric
63 changes: 63 additions & 0 deletions contrib/tvmop/core/fromnumeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import tvm
from .. import defop
from ..utils import reduce_axes, assign_by_req


def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=itype)
reduce_output = reduce_axes(a, axes, tvm.sum, otype)
output_placeholder, final_output = assign_by_req(reduce_output, req)
s = tvm.create_schedule(final_output.op)
return s, a, output_placeholder, final_output, [reduce_output, final_output]


@defop(name='sum_cpu', target='cpu', itype=['bool'],
otype=['float32', 'float64', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
attrs=["reduce1st_dim", "req"])
def _sum_cpu(itype, otype, ndim, reduce1st_dim, req):
s, a, output_placeholder, final_output, tensor_list = _compute_sum(
itype, otype, ndim, reduce1st_dim, req)
for t in tensor_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [a, output_placeholder, final_output]


@defop(name='sum_gpu', target='gpu', itype=['bool'],
otype=['float32', 'float64', 'int32', 'int64'],
ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
attrs=["reduce1st_dim", "req"])
def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
s, a, output_placeholder, final_output, tensor_list = _compute_sum(
itype, otype, ndim, reduce1st_dim, req)
num_threads = 64
for t in tensor_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_threads)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [a, output_placeholder, final_output]
Loading