Can akg describe a whole network model? #8

Open
MingliSun opened this issue May 2, 2021 · 5 comments

@MingliSun

Hi, I noticed that we can pass a compute/hybrid function or autodiff result to akg, but how do we describe a whole network model that includes many operators? Should we combine all the operators into one tvm.compute and then call akg.build(schedule, args, ...)?
I'm confused about that.
Thanks a lot.

@anyrenwei
Contributor

Hi Mingli @MingliSun. akg also supports composite operators, which are described by a JSON file. In fact, if you enable graph kernel fusion in MindSpore (https://www.mindspore.cn/tutorial/zh-CN/r0.5/advanced_use/graph_kernel_fusion.html), you can see akg's detailed compilation process.
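
For context, a minimal sketch of turning on graph kernel fusion in a MindSpore script (so that fused composite operators are handed to akg as JSON) might look like the following; the exact context flags can differ between MindSpore versions, so treat this as an assumption rather than a verified recipe:

from mindspore import context

# Assumed setup: enable graph kernel fusion so MindSpore compiles fused
# subgraphs through akg (flag names may vary across MindSpore releases).
context.set_context(mode=context.GRAPH_MODE,
                    device_target="GPU",        # or "Ascend"
                    enable_graph_kernel=True)   # turn on graph kernel fusion

# Then define and run the network as usual; fused composite operators are
# passed to akg for compilation.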

@MingliSun
Author

MingliSun commented May 13, 2021

@anyrenwei Thanks for your reply. Graph kernel fusion only supports Ascend. I found that all I need to do to implement a network with tvm.compute (tvm.te.compute) is to define the different operators and call them sequentially, so I wrote a simple demo (LeNet), but a check-failed error occurred. Here are my code and the error message:

from akg import tvm
import akg
import numpy as np

def padding(X, ph, pw, val=0):
    """Pad X with the given value in 2-D

    ph, pw : height and width padding
    val : padding value, default 0
    """
    assert len(X.shape) >= 2
    nh, nw = X.shape[-2], X.shape[-1]
    return tvm.compute(
            (*X.shape[0:-2], nh+ph*2, nw+pw*2),
            lambda *i: tvm.if_then_else(
                tvm.any(i[-2]<ph, i[-2]>=nh+ph, i[-1]<pw, i[-1]>=nw+pw),
                val, X[i[:-2]+(i[-2]-ph, i[-1]-pw)]),
            name='PaddedX')
def conv_out_size(n, k, p, s):
    """Compute the output size by given input size n (width or height),
    kernel size k, padding p, and stride s
    Return output size (width or height)
    """
    return (n - k + 2 * p)//s + 1

def conv2d(X,K,ph=0,pw=0,sh=1,sw=1):
	"""
	conv2d with activation = relu
	padding = (ph,pw)
	stride = (sh,sw)
	X : layout NCHW

	"""
	batch_size,ic,nh,nw = X.shape
	oc,ic,kh,kw = K.shape
	ric = tvm.reduce_axis((0, ic), name='ric')
	rkh = tvm.reduce_axis((0, kh), name='rkh')
	rkw = tvm.reduce_axis((0, kw), name='rkw')

	oh = conv_out_size(nh, kh, ph, sh)
	ow = conv_out_size(nw, kw, pw, sw)

	PaddedX = padding(X, ph, pw) if ph * pw != 0 else X
	Y = tvm.compute(
	    (batch_size,oc, oh, ow),
	    lambda n,c, i, j: tvm.sum(
	        PaddedX[n,ric, i*sh+rkh, j*sw+rkw] * K[c, ric, rkh, rkw],
	        axis=[ric, rkh, rkw]), name='Y')
	out = tvm.compute(
			(batch_size,oc, oh, ow),
			lambda n,c,i,j:tvm.if_then_else(
					Y[n,c,i,j]<=0,0,Y[n,c,i,j]
				)
		)
	return out

def pooling(pool_type,X,kh,kw,ph=0,pw=0,sh=1,sw=1):
	"""
	pool_type : max or avg
	kernel size = (kh,kw)
	padding = (ph,pw)
	stride = (sh,sw)
	X : layout NCHW
	"""
	batch_size,c,nh,nw = X.shape
	rkh = tvm.reduce_axis((0, kh), name='rkh')
	rkw = tvm.reduce_axis((0, kw), name='rkw')
    # output height and width
	oh = conv_out_size(nh, kh, ph, sh)
	ow = conv_out_size(nw, kw, pw, sw)
	if pool_type == 'max':
	    PaddedX = padding(X, ph, pw, val=tvm.min_value(X.dtype)) \
	        if ph * pw != 0 else X
	    Y = tvm.compute((batch_size,c, oh, ow), \
	                        lambda n,c, h, w: \
	                        tvm.max(PaddedX[n,c, h*sh+rkh, w*sw+rkw], \
	                            axis=[rkh, rkw]), \
	                        tag="pool_max", name='PoolMax')
	elif pool_type == 'avg':
	    PaddedX = padding(X, ph, pw) if ph * pw != 0 else X
	    tsum = tvm.compute((batch_size,c, oh, ow), \
	                        lambda n,c, h, w: \
	                        tvm.sum(PaddedX[n,c, h*sh+rkh, w*sw+rkw], \
	                            axis=[rkh, rkw]), \
	                        tag="pool_avg1", name='PoolSum')
	    Y = tvm.compute((batch_size,c, oh, ow), \
	                        lambda n,c, h, w: \
	                        tsum[n,c, h, w] / (kh*kw), \
	                        tag='pool_avg2', name='PoolAvg')
	else:
	    raise ValueError("Pool type should be 'avg' or 'max'.")
	return Y
def reshape(X):
	"""
	X layout NCHW====>(N,C*H*W)
	"""
	assert len(X.shape)==4
	batch_size,c,nh,nw = X.shape
	return tvm.compute(
			(batch_size,c*nh*nw),
			lambda n,s:X[n,s//(nh*nw),s%(nh*nw)//nw,s%(nh*nw)%nw]
		)
def dense(X,W,activaction=False):
	"""
	fully-connected layer
	"""
	reshapeX = X if len(X.shape)==2 else reshape(X)
	n,_ = reshapeX.shape
	l,m = W.shape
	k = tvm.reduce_axis((0, l), name='k')
	Y = tvm.compute(
			(n,m),
			lambda x,y:tvm.sum(reshapeX[x,k]*W[k,y],axis=k),
			name='Y'
		)
	if activaction==False:return Y
	out = tvm.compute(
    		(n,m),
    		lambda i,j:tvm.if_then_else(
    				Y[i,j]<=0,0,Y[i,j]
    			),
    		name = 'dense'
    	)
	return out
def lenet():
	batch_size,oc, ic, n, k = 1,6, 1, 28, 5
	X = tvm.placeholder((batch_size,ic,n,n),name='data')
	conv1_weight = tvm.placeholder((oc,ic,k,k),name='conv1_weight')
	conv1 = conv2d(X,conv1_weight)
	print("conv1 shape:",conv1.shape)
	pool1 = pooling('max',conv1,2,2,0,0,2,2)
	print("pool1 shape:",pool1.shape)
	conv2_weight = tvm.placeholder((16,6,5,5),name='conv2_weight')
	conv2 = conv2d(pool1,conv2_weight)
	print("conv2 shape:",conv2.shape)
	pool2 = pooling('max',conv2,2,2,0,0,2,2)
	print("pool2 shape:",pool2.shape)

	dense1_weight = tvm.placeholder((sum(pool2.shape[-3:]),120),name='dense1_weight')
	dense2_weight = tvm.placeholder((120,84),name='dense2_weight')
	dense3_weight = tvm.placeholder((84,10),name='dense3_weight')
	dense1 = dense(pool2,dense1_weight,True)
	dense2 = dense(dense1,dense2_weight,True)
	dense3 = dense(dense2,dense3_weight,True)
	print("dense1 shape:",dense1.shape)
	print("dense2 shape:",dense2.shape)
	print("dense3 shape:",dense3.shape)

	return X,conv1_weight,conv2_weight,dense1_weight,dense2_weight,dense3_weight,dense3



X,conv1_weight,conv2_weight,dense1_weight,dense2_weight,dense3_weight,Y = lenet()
sch = tvm.create_schedule(Y.op)
mod = akg.build(sch, [X,conv1_weight,conv2_weight,dense1_weight,dense2_weight,dense3_weight,Y], 'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)

Error message:

conv1 shape: [1, 6, 24, 24]
pool1 shape: [1, 6, 12, 12]
conv2 shape: [1, 16, 8, 8]
pool2 shape: [1, 16, 4, 4]
dense1 shape: [1, 120]
dense2 shape: [1, 84]
dense3 shape: [1, 10]
[ERROR] AKG:2021-05-13-09:16:42.265.267 [tiling_utils.cc:171] [tiling] Check failed: gemm_m.size() <= FormatM.size() (5 vs. 2) : 
Stack trace:
  [bt] (0) /home/sun/gitDownload/akg_new/akg/build/libakg.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7f5144beb13f]
  [bt] (1) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::ExtractLoopIndicesFromMatrices(std::vector<std::vector<std::string, std::allocator<std::string> >, std::allocator<std::vector<std::string, std::allocator<std::string> > > >)+0x2e2c) [0x7f5145295b6c]
  [bt] (2) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::MarkGemmAxes(akg::ir::poly::SpaceAnalyzer::ProvideEntry const&)+0x409) [0x7f5145228819]
  [bt] (3) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::IdentifyInsnType()+0x87a) [0x7f514522d76a]
  [bt] (4) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::AnalyzeSpecialAxes()+0x865) [0x7f514522e195]
  [bt] (5) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::TilingAnalyzer::Prepare()+0x5f9) [0x7f514524a1f9]
  [bt] (6) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GenerateTiling(isl::schedule const&, akg::ir::poly::ScopInfo&, air::Stmt)+0x12a) [0x7f5145235aaa]
  [bt] (7) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GpuDmaAnalysis::GetTiledNode(isl::schedule, isl::schedule_node)+0x2cf) [0x7f514514bd0f]
  [bt] (8) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GpuDmaAnalysis::RemoveInjectiveTensorFromMemFlows(isl::schedule)+0x116) [0x7f514514f946]

Traceback (most recent call last):

  File "lenet_compute.py", line 155, in <module>
    mod = akg.build(sch, [X,conv1_weight,conv2_weight,dense1_weight,dense2_weight,dense3_weight,Y], 'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)

  File "/home/sun/gitDownload/akg_new/akg/python/akg/utils/validation_check.py", line 135, in in_wrapper
    return func(*args, **kwargs)

  File "/home/sun/gitDownload/akg_new/akg/python/akg/build_module.py", line 155, in build
    tmp_rst = build_to_func(inputs, args, shape_params=shape_params, name=name, binds=binds,

  File "/home/sun/gitDownload/akg_new/akg/python/akg/utils/validation_check.py", line 135, in in_wrapper
    return func(*args, **kwargs)

  File "/home/sun/gitDownload/akg_new/akg/python/akg/build_module.py", line 148, in build_to_func
    return _api_internal._BuildToFunc(inputs, args, shape_params, name, tmp_binds, tmp_attrs,

  File "/home/sun/gitDownload/akg_new/akg/third_party/incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GpuDmaAnalysis::RemoveInjectiveTensorFromMemFlows(isl::schedule)+0x116) [0x7f514514f946]
  [bt] (7) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GpuDmaAnalysis::GetTiledNode(isl::schedule, isl::schedule_node)+0x2cf) [0x7f514514bd0f]
  [bt] (6) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::GenerateTiling(isl::schedule const&, akg::ir::poly::ScopInfo&, air::Stmt)+0x12a) [0x7f5145235aaa]
  [bt] (5) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::TilingAnalyzer::Prepare()+0x5f9) [0x7f514524a1f9]
  [bt] (4) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::AnalyzeSpecialAxes()+0x865) [0x7f514522e195]
  [bt] (3) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::IdentifyInsnType()+0x87a) [0x7f514522d76a]
  [bt] (2) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::SpaceAnalyzer::MarkGemmAxes(akg::ir::poly::SpaceAnalyzer::ProvideEntry const&)+0x409) [0x7f5145228819]
  [bt] (1) /home/sun/gitDownload/akg_new/akg/build/libakg.so(akg::ir::poly::ExtractLoopIndicesFromMatrices(std::vector<std::vector<std::string, std::allocator<std::string> >, std::allocator<std::vector<std::string, std::allocator<std::string> > > >)+0x2e2c) [0x7f5145295b6c]
  [bt] (0) /home/sun/gitDownload/akg_new/akg/build/libakg.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7f5144beb13f]
  File "/home/sun/gitDownload/akg_new/akg/src/poly/tiling/tiling_utils.cc", line 171
TVMError: Check failed: gemm_m.size() <= FormatM.size() (5 vs. 2) :

What does that check failure mean? If my way of writing a network is wrong, please give me some tips to fix it. If akg doesn't support this for now, can you give me a demo of a typical network on CUDA or Ascend (especially Ascend)? Using tvm.compute, topi, or something else is fine. My akg version is updated to the latest release.
Thanks a lot!

@anyrenwei
Contributor

First, graph kernel fusion also supports GPU. You can run BERT on MindSpore and it will enable graph kernel fusion by default. If you want to check all the fusion operators we currently support, you can just export the environment variables below:
export MS_AKG_DUMP_CODE=on  # save the generated CUDA code for each operator
export MS_AKG_DUMP_IR=on  # save the IR for each operator
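
As a usage example, these switches could also be set from inside a Python script before the network runs; the dump behavior is as described above and is assumed here, not verified:

import os

# Assumed from the reply above: with graph kernel fusion enabled, these
# variables make akg dump the generated CUDA code and the operator IR.
os.environ["MS_AKG_DUMP_CODE"] = "on"
os.environ["MS_AKG_DUMP_IR"] = "on"

# ... then build and run the MindSpore network with graph kernel fusion on.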

@anyrenwei
Contributor

Second, the error you listed suggests that the auto-tiling module does not support the GEMM format you defined... If you want to run related operators on Ascend, I suggest you directly use the operators defined in: https://gitee.com/mindspore/akg/tree/master/tests/common/test_op

@MingliSun
Author

MingliSun commented May 17, 2021

@anyrenwei Thanks for your reply. For now I am running multiple operators on CUDA (because all I have is an Ascend 310, which is not supported, as I was told in another issue), and I tried using topi.nn instead of defining my own operators with tvm.compute. Here is my code:

import akg
from akg import tvm
from akg import topi
X0 = tvm.placeholder((2,3,28,28),dtype='float32',name='X0')
conv1_weight = tvm.placeholder((6,3,5,5),dtype='float32',name='conv1_weight')
conv1 = akg.topi.nn.conv2d(X0,conv1_weight,1,0,1)
pool1 = akg.topi.nn.pool(conv1,(2,2),(2,2),(0,0,0,0),'max')

conv2_weight = tvm.placeholder((16,6,5,5),dtype='float32',name='conv2_weight')
conv2 = akg.topi.nn.conv2d(pool1,conv2_weight,1,0,1)
pool2 = akg.topi.nn.pool(conv2,(2,2),(2,2),(0,0,0,0),'max')
print(conv2,pool2)

s = tvm.create_schedule([conv1.op,pool1.op,conv2.op,pool2.op])
mod = akg.build(s,[X0,conv1_weight,conv1,pool1,conv2_weight,conv2,pool2],'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)

An error occurred. On an earlier version of akg (commit id: 8ac73f9) it was a segmentation fault with no error message; on the latest version of akg it was:

terminate called after throwing an instance of 'isl::exception_invalid'
  what():  isl_tab_pip.c:714: unbounded optimum

So is that an akg bug? How can I solve it? And can you give me any network examples (such as LeNet, ResNet, VGG...) that work fine?
Using topi, compute, tvm.hybrid.script, or anything else is fine.
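
For reference, a minimal single-operator version of the same build call, using only the APIs already shown above, would look like the sketch below; whether it builds cleanly on a given akg commit is an assumption, not a verified result:

import akg
from akg import tvm
from akg import topi

# Single conv2d only, to check whether one topi operator builds with akg
# before chaining several operators together.
X0 = tvm.placeholder((2, 3, 28, 28), dtype='float32', name='X0')
conv1_weight = tvm.placeholder((6, 3, 5, 5), dtype='float32', name='conv1_weight')
conv1 = akg.topi.nn.conv2d(X0, conv1_weight, 1, 0, 1)

s = tvm.create_schedule(conv1.op)
mod = akg.build(s, [X0, conv1_weight, conv1], 'cuda', [], name='conv_only',
                attrs={}, polyhedral=True, binds=None)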
