forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CONTRIB] MPS DNN Dense (apache#615)
* mps * update
- Loading branch information
Showing
8 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,3 +68,6 @@ USE_NNPACK = 0 | |
|
||
# Whether to use CuDNN | ||
USE_CUDNN = 0 | ||
|
||
# Whether to use MPS (Metal Performance Shaders) | ||
USE_MPS = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Build rules for the MPS (Metal Performance Shaders) contrib sources.
#
# $(wildcard) takes whitespace-separated patterns. A comma between patterns
# becomes part of the first pattern ("*.mm,"), which then matches nothing, so
# only the Objective-C++ pattern is listed here. The object list below and the
# pattern rule both work from .mm files only.
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))

ifeq ($(USE_MPS), 1)
FRAMEWORKS += -framework MetalPerformanceShaders
CFLAGS +=
ADD_LDFLAGS +=
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
endif

# Each object is built from the .mm file with the same stem. Requiring an
# additional same-stem .cc prerequisite would make the rule inapplicable to
# sources that exist only as .mm files.
build/contrib/mps/%.o: src/contrib/mps/%.mm
	@mkdir -p $(@D)
	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""External function interface to MPS libraries.""" | ||
from __future__ import absolute_import as _abs | ||
|
||
from .. import api as _api | ||
from .. import intrin as _intrin | ||
|
||
|
||
def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with MPS.

    The work is delegated to the packed function "tvm.contrib.mps.matmul",
    which receives both inputs, the output buffer and the two transpose flags.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether transpose lhs
    transb : bool
        Whether transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor, of shape (rows of op(lhs), columns of op(rhs)),
        where op(X) is X transposed when the matching flag is set.
    """
    # The result has as many rows as op(lhs) and as many columns as op(rhs);
    # a transposed operand contributes its other dimension.
    rows = lhs.shape[1] if transa else lhs.shape[0]
    cols = rhs.shape[0] if transb else rhs.shape[1]
    return _api.extern(
        (rows, cols), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
        name="C")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#include "../../runtime/metal/metal_common.h" | ||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h> | ||
#include <dmlc/logging.h> | ||
#include <tvm/runtime/device_api.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/runtime/util.h> | ||
|
||
namespace tvm { | ||
namespace contrib { | ||
|
||
using namespace runtime; | ||
|
||
// Packed function "tvm.contrib.mps.matmul".
//
// Arguments: A, B, C (DLTensor*), transa, transb (bool).
// Encodes C = op(A) * op(B) on the Metal command queue of A's context, where
// op(X) is X transposed when the matching flag is set. All three tensors must
// be compact (no strides) 2-D float32 tensors living on the same context.
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor *A = args[0];
    DLTensor *B = args[1];
    DLTensor *C = args[2];
    bool transa = args[3];
    bool transb = args[4];
    // call gemm for simple compact code.
    CHECK_EQ(A->ndim, 2);
    CHECK_EQ(B->ndim, 2);
    CHECK_EQ(C->ndim, 2);
    CHECK(C->strides == nullptr);
    CHECK(B->strides == nullptr);
    CHECK(A->strides == nullptr);
    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
    // Get Metal device API
    MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
    CHECK_EQ(A->ctx, B->ctx);
    CHECK_EQ(A->ctx, C->ctx);
    id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
    id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
    id<MTLCommandBuffer> cb = [queue commandBuffer];

    // Logical GEMM sizes: op(A) is M x K, op(B) is K x N, C is M x N.
    // The conditional is parenthesised explicitly; `0 + transa ? 1 : 0`
    // parses as `(0 + transa) ? 1 : 0`, which is easy to misread.
    NSUInteger M = static_cast<NSUInteger>(A->shape[transa ? 1 : 0]);
    NSUInteger N = static_cast<NSUInteger>(B->shape[transb ? 0 : 1]);
    NSUInteger K = static_cast<NSUInteger>(B->shape[transb ? 1 : 0]);
    CHECK_EQ(static_cast<NSUInteger>(A->shape[transa ? 0 : 1]), K);
    CHECK_EQ(static_cast<NSUInteger>(C->shape[0]), M);
    CHECK_EQ(static_cast<NSUInteger>(C->shape[1]), N);

    // Storage layout of each tensor as it sits in memory. The descriptors
    // describe the stored matrices; the transpose flags are applied by the
    // kernel, not by the descriptors.
    NSUInteger rowsA = static_cast<NSUInteger>(A->shape[0]);
    NSUInteger colsA = static_cast<NSUInteger>(A->shape[1]);
    NSUInteger rowsB = static_cast<NSUInteger>(B->shape[0]);
    NSUInteger colsB = static_cast<NSUInteger>(B->shape[1]);
    // All tensors were checked to be float32 above. rowBytes is the byte
    // stride between consecutive rows, i.e. columns * element size for a
    // compact row-major matrix.
    const NSUInteger elem_bytes = sizeof(float);

    MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
    // mps a
    MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
        matrixDescriptorWithDimensions:rowsA
                               columns:colsA
                              rowBytes:colsA * elem_bytes
                              dataType:dtype];
    id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
    MPSMatrix *matrixA =
        [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
    // mps b
    MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
        matrixDescriptorWithDimensions:rowsB
                               columns:colsB
                              rowBytes:colsB * elem_bytes
                              dataType:dtype];
    id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
    MPSMatrix *matrixB =
        [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
    // mps c
    MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
        matrixDescriptorWithDimensions:M
                               columns:N
                              rowBytes:N * elem_bytes
                              dataType:dtype];
    id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
    MPSMatrix *matrixC =
        [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];

    // kernel: allocate and initialise exactly once with the designated
    // initializer (an object must not be sent a second init message).
    MPSMatrixMultiplication *sgemm =
        [[MPSMatrixMultiplication alloc] initWithDevice:dev
                                          transposeLeft:transa
                                         transposeRight:transb
                                             resultRows:M
                                          resultColumns:N
                                        interiorColumns:K
                                                  alpha:1.0f
                                                   beta:0.0f];
    CHECK(sgemm != nil);
    [sgemm encodeToCommandBuffer:cb
                      leftMatrix:matrixA
                     rightMatrix:matrixB
                    resultMatrix:matrixC];
    [cb commit];

    // Give up our ownership of the temporaries. -dealloc must never be sent
    // directly; under manual reference counting the balancing call for
    // alloc/init is -release, and under ARC no explicit call is allowed.
#if !__has_feature(objc_arc)
    [sgemm release];
    [matrixA release];
    [matrixB release];
    [matrixC release];
#endif
  });
|
||
} // namespace contrib | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file Use external mps utils function | ||
*/ | ||
#include "mps_utils.h" | ||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h> | ||
#include <dmlc/thread_local.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
|
||
namespace tvm { | ||
namespace contrib { | ||
|
||
// MPS Data Type | ||
/*!
 * \brief Map a DLPack data type onto the matching MPS data type.
 *
 * Supported single-lane types: int8/int16, uint8/uint16/uint32 and
 * float16/float32. Anything else is reported with LOG(FATAL).
 *
 * \param dtype The DLPack data type to convert.
 * \return The corresponding MPSDataType.
 */
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
  // Every supported mapping requires a scalar (single-lane) type, so the
  // lanes test is done once instead of in every branch.
  if (dtype.lanes == 1) {
    switch (dtype.code) {
      case kDLInt:
        if (dtype.bits == 8) return MPSDataTypeInt8;
        if (dtype.bits == 16) return MPSDataTypeInt16;
        break;
      case kDLUInt:
        if (dtype.bits == 8) return MPSDataTypeUInt8;
        if (dtype.bits == 16) return MPSDataTypeUInt16;
        if (dtype.bits == 32) return MPSDataTypeUInt32;
        break;
      case kDLFloat:
        if (dtype.bits == 16) return MPSDataTypeFloat16;
        if (dtype.bits == 32) return MPSDataTypeFloat32;
        break;
      default:
        break;
    }
  }
  LOG(FATAL) << "Unsupported type";
  // LOG(FATAL) is not expected to return; this return only guarantees that
  // control never flows off the end of a non-void function.
  return MPSDataTypeFloat32;
}
|
||
// MetalThreadEntry | ||
|
||
/*!
 * \brief Look up the Metal device API through the global function registry
 *        and cache the returned workspace pointer in metal_api.
 */
MetalThreadEntry::MetalThreadEntry() {
  auto func = runtime::Registry::Get("device_api.metal");
  // Fail with a readable message instead of dereferencing a null pointer
  // when no function is registered under this name.
  CHECK(func != nullptr)
      << "Cannot find global function device_api.metal";
  void *ret = (*func)();
  metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
}
|
||
/*! \brief Nothing to clean up: the destructor performs no work of its own. */
MetalThreadEntry::~MetalThreadEntry() = default;
|
||
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore; | ||
|
||
MetalThreadEntry* MetalThreadEntry::ThreadLocal() { | ||
return MetalThreadStore::Get(); | ||
} | ||
|
||
} // namespace contrib | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file Use external mps utils function | ||
*/ | ||
|
||
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_ | ||
#define TVM_CONTRIB_MPS_MPS_UTILS_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <tvm/runtime/device_api.h> | ||
#include "../../runtime/metal/metal_common.h" | ||
|
||
|
||
namespace tvm { | ||
namespace contrib { | ||
|
||
/*! breif Convert DLTensor type to MPS type */ | ||
struct MPSType { | ||
static MPSDataType DLTypeToMPSType(const DLDataType &dtype); | ||
}; // struct MPSType | ||
|
||
|
||
struct MetalThreadEntry { | ||
MetalThreadEntry(); | ||
~MetalThreadEntry(); | ||
runtime::MetalWorkspace *metal_api{nullptr}; | ||
static MetalThreadEntry* ThreadLocal(); | ||
}; // MetalThreadEntry | ||
|
||
} // namespace contrib | ||
} // namespace tvm | ||
|
||
#endif // TVM_CONTRIB_MPS_MPS_UTILS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import tvm | ||
import numpy as np | ||
from tvm.contrib import mps | ||
|
||
def test_matmul_add():
    """Check mps.matmul followed by a scalar bias add, with and without transposes."""
    n = 1024
    l = 128
    m = 235
    bias = tvm.var('bias', dtype=tvm.float32)
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    # C1 = A * B has shape (n, m).
    # C2 = B^T * A^T = (A * B)^T has shape (m, n).
    C1 = mps.matmul(A, B)
    C2 = mps.matmul(B, A, True, True)
    D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
    D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
    s1 = tvm.create_schedule(D1.op)
    s2 = tvm.create_schedule(D2.op)

    def verify(A, B, D, s, bias, target="llvm", transposed=False):
        """Build and run one schedule, then compare against numpy.

        When `transposed` is set the expected result is (a . b)^T, which is
        what B^T * A^T evaluates to.
        """
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
            print("skip because extern function is not avalable")
            return
        # NOTE(review): the registered kernel bridges tensor data to Metal
        # buffers, so CPU arrays may not be valid inputs -- confirm the
        # intended target/context for this test.
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D, bias], target)
        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
        reference = np.dot(a.asnumpy(), b.asnumpy())
        if transposed:
            reference = reference.T
        # The output buffer takes the shape of the expected result, so the
        # transposed case gets an (m, n) buffer rather than (n, m).
        d = tvm.nd.array(np.zeros(reference.shape, dtype=D.dtype), ctx)
        bb = 10.0
        f(a, b, d, bb)
        np.testing.assert_allclose(
            d.asnumpy(), reference + bb, rtol=1e-5)

    verify(A, B, D1, s1, bias)
    verify(A, B, D2, s2, bias, transposed=True)


if __name__ == "__main__":
    test_matmul_add()