
akg build error: Invalid Schedule #5

Open
MingliSun opened this issue Apr 5, 2021 · 1 comment
@MingliSun
Hi, I'm working through the Dive into Deep Learning Compiler tutorial and replacing `tvm.build` with `akg.build(sch, (X, Y), 'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)`.
While trying the AvgPooling operator, I want to do some scheduling to merge its stages, such as `AutoInlineInjective`. But after merging the PoolSum and PoolAvg stages with `PoolSum = Y.op.input_tensors[0]` followed by `sch[PoolSum].compute_at(sch[Y], sch[Y].op.axis[2])`, an error occurred:
```
[ERROR] AKG:2021-04-05-17:43:31.410.549 [graph.cc:223] [schedule] Check failed: start_attach: Invalid Schedule: cannot find attach point iter_var(h, range(min=0, ext=12)) in the schedule of compute(PoolAvg, 0x3126cc0)
Stack trace:
[bt] (0) /home/sun/gitDownload/akg/mybuild/libakg.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7fd326aa5fcf]
[bt] (1) /home/sun/gitDownload/akg/mybuild/libakg.so(air::schedule::CreateAttachPath(air::Schedule)+0x5d4) [0x7fd32789e654]
[bt] (2) /home/sun/gitDownload/akg/mybuild/libakg.so(air::schedule::InferBound(air::Schedule const&)+0xda4) [0x7fd327899ad4]
[bt] (3) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::LowerStmt(air::Schedule, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, bool, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&, air::Array<air::NodeRef, void>, air::Array<air::NodeRef, void>, air::Map<air::Tensor, air::Buffer, void, void>, air::Map<air::Tensor, air::Buffer, void, void>, bool)+0x384) [0x7fd326af3b34]
[bt] (4) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::Lower(air::Schedule, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, bool, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)+0x166) [0x7fd326af67f6]
[bt] (5) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::BuildToFunc(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)+0x24f) [0x7fd326b00dbf]
[bt] (6) /home/sun/gitDownload/akg/mybuild/libakg.so(void air::runtime::detail::unpack_call_dispatcher<akg::BuildRst, 0, 9, akg::BuildRst ()(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)>::run<air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue>(akg::BuildRst ( const&)(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&), air::runtime::TVMArgs const&, air::runtime::TVMRetValue*, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVM
Traceback (most recent call last):

File "pooling.py", line 70, in
mod = akg.build(sch, (X,Y), 'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)

File "/home/sun/gitDownload/akg/python/akg/utils/validation_check.py", line 135, in in_wrapper
return func(*args, **kwargs)

File "/home/sun/gitDownload/akg/python/akg/build_module.py", line 141, in build
tmp_rst = build_to_func(inputs, args, shape_params=shape_params, name=name, binds=binds,

File "/home/sun/gitDownload/akg/python/akg/utils/validation_check.py", line 135, in in_wrapper
return func(*args, **kwargs)

File "/home/sun/gitDownload/akg/python/akg/build_module.py", line 134, in build_to_func
return _api_internal._BuildToFunc(inputs, args, shape_params, name, tmp_binds, tmp_attrs,

File "/home/sun/gitDownload/akg/third_party/incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in call
raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) /home/sun/gitDownload/akg/mybuild/libakg.so(TVMFuncCall+0x65) [0x7fd32780e305]
[bt] (7) /home/sun/gitDownload/akg/mybuild/libakg.so(std::_Function_handler<void (air::runtime::TVMArgs, air::runtime::TVMRetValue*), air::runtime::TypedPackedFunc<akg::BuildRst (air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)>::AssignTypedLambda<akg::BuildRst ()(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)>(akg::BuildRst ()(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&))::{lambda(air::runtime::TVMArgs const&, air::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, air::runtime::TVMArgs&&, air::runtime::TVMRetValue*&&)+0x13a) [0x7fd326b1003a]
[bt] (6) /home/sun/gitDownload/akg/mybuild/libakg.so(void air::runtime::detail::unpack_call_dispatcher<akg::BuildRst, 0, 9, akg::BuildRst ()(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)>::run<air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue, air::runtime::TVMArgValue>(akg::BuildRst ( const&)(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&), air::runtime::TVMArgs const&, air::runtime::TVMRetValue*, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&, air::runtime::TVMArgValue&&)+0x176) [0x7fd326b0fcd6]
[bt] (5) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::BuildToFunc(air::Schedule const&, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)+0x24f) [0x7fd326b00dbf]
[bt] (4) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::Lower(air::Schedule, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, bool, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&)+0x166) [0x7fd326af67f6]
[bt] (3) /home/sun/gitDownload/akg/mybuild/libakg.so(akg::LowerStmt(air::Schedule, air::Array<air::NodeRef, void> const&, air::Array<air::NodeRef, void> const&, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::Map<air::Tensor, air::Buffer, void, void> const&, air::Map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, air::NodeRef, void, void> const&, bool, bool, bool, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&, air::BuildConfig const&, air::Array<air::NodeRef, void>, air::Array<air::NodeRef, void>, air::Map<air::Tensor, air::Buffer, void, void>, air::Map<air::Tensor, air::Buffer, void, void>, bool)+0x384) [0x7fd326af3b34]
[bt] (2) /home/sun/gitDownload/akg/mybuild/libakg.so(air::schedule::InferBound(air::Schedule const&)+0xda4) [0x7fd327899ad4]
[bt] (1) /home/sun/gitDownload/akg/mybuild/libakg.so(air::schedule::CreateAttachPath(air::Schedule)+0x5d4) [0x7fd32789e654]
[bt] (0) /home/sun/gitDownload/akg/mybuild/libakg.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7fd326aa5fcf]
File "/home/sun/gitDownload/akg/third_party/incubator-tvm/src/schedule/graph.cc", line 223
TVMError: Check failed: start_attach: Invalid Schedule: cannot find attach point iter_var(h, range(min=0, ext=12)) in the schedule of compute(PoolAvg, 0x3126cc0)
```

Here is my source code:

```python
import akg
from akg import tvm

import numpy as np

def padding(X, ph, pw, val=0):
    """Pad X with the given value in 2-D

    ph, pw : height and width padding
    val : padding value, default 0
    """
    assert len(X.shape) >= 2
    nh, nw = X.shape[-2], X.shape[-1]
    return tvm.compute(
        (*X.shape[0:-2], nh + ph*2, nw + pw*2),
        lambda *i: tvm.if_then_else(
            tvm.any(i[-2] < ph, i[-2] >= nh + ph, i[-1] < pw, i[-1] >= nw + pw),
            val, X[i[:-2] + (i[-2] - ph, i[-1] - pw)]),
        name='PaddedX')

# Save to the d2ltvm package.

def conv_out_size(n, k, p, s):
    """Compute the output size given input size n (width or height),
    kernel size k, padding p, and stride s.
    Return output size (width or height).
    """
    return (n - k + 2*p)//s + 1

def get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None, ctx=tvm.gpu(0), conv_type='direct'):
    """Return random 3-D data tensor, 3-D kernel tensor and empty 3-D output
    tensor with the shapes specified by input arguments.

    oc, ic : output and input channels
    n : input width and height
    k : kernel width and height
    p : padding size, default 0
    s : stride, default 1
    constructor : user-defined tensor constructor
    """
    np.random.seed(0)
    data = np.random.normal(size=(ic, n, n)).astype('float32')
    ic_weight = ic
    if conv_type == 'depthwise':
        ic_weight = 1
    weight = np.random.normal(size=(oc, ic_weight, k, k)).astype('float32')
    on = conv_out_size(n, k, p, s)
    out = np.empty((oc, on, on), dtype='float32')
    if constructor:
        data, weight, out = (constructor(x, ctx) for x in [data, weight, out])
    return data, weight, out

def pool(pool_type, c, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    rkh = tvm.reduce_axis((0, kh), name='rkh')
    rkw = tvm.reduce_axis((0, kw), name='rkw')

    oh = conv_out_size(nh, kh, ph, sh)
    ow = conv_out_size(nw, kw, pw, sw)

    X = tvm.placeholder((c, nh, nw), name='X')
    if pool_type == 'max':
        PaddedX = padding(X, ph, pw, val=tvm.min_value(X.dtype)) if ph*pw != 0 else X
        Y = tvm.compute(
            (c, oh, ow),
            lambda c, h, w: tvm.max(PaddedX[c, h*sh+rkh, w*sw+rkw], axis=[rkh, rkw]),
            tag='pool_max', name='PoolMax')
    elif pool_type == 'avg':
        PaddedX = padding(X, ph, pw) if ph*pw != 0 else X
        tsum = tvm.compute(
            (c, oh, ow),
            lambda c, h, w: tvm.sum(PaddedX[c, h*sh+rkh, w*sw+rkw], axis=[rkh, rkw]),
            tag='pool_avg1', name='PoolSum')
        Y = tvm.compute(
            (c, oh, ow),
            lambda c, h, w: tsum[c, h, w] / (kh*kw),
            tag='pool_avg2', name='PoolAvg')
    else:
        raise ValueError("Pool type should be 'avg' or 'max'.")
    return X, Y, PaddedX

c, n, k, p, s = 4, 12, 3, 1, 1
X, Y, PaddedX = pool('avg', c, n, n, k, k, p, p, s, s)
sch = tvm.create_schedule(Y.op)
tvm.schedule.AutoInlineInjective(sch)
PoolSum = Y.op.input_tensors[0]
sch[PoolSum].compute_at(sch[Y], sch[Y].op.axis[2])

print(tvm.lower(sch, [X, Y], simple_mode=True))
mod = akg.build(sch, (X, Y), 'cuda', [], name='myfunc', attrs={}, polyhedral=True, binds=None)

ctx = tvm.context('cuda')
data, _, out_max = get_conv_data(c, c, n, k, p, s, tvm.nd.array, ctx)

mod(data, out_max)
ctx.sync()
```
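For reference, the two stages above (PoolSum followed by PoolAvg) just compute ordinary average pooling. Here is a minimal pure-Python sketch of the same computation, assuming stride 1 and no padding for brevity; `avg_pool` is a made-up helper, not part of the tutorial or akg:

```python
def avg_pool(x, kh, kw):
    """Plain average pooling, stride 1, no padding.

    x is a list of 2-D channel grids (lists of lists of floats);
    each kh x kw window is summed (the PoolSum stage) and then
    divided by the window size (the PoolAvg stage).
    """
    out = []
    for chan in x:
        nh, nw = len(chan), len(chan[0])
        oh, ow = nh - kh + 1, nw - kw + 1
        grid = [[sum(chan[h + i][w + j] for i in range(kh) for j in range(kw)) / (kh * kw)
                 for w in range(ow)]
                for h in range(oh)]
        out.append(grid)
    return out

# 1 channel, 3x3 input, 2x2 window -> 2x2 output
x = [[[1.0, 2.0, 3.0],
      [4.0, 5.0, 6.0],
      [7.0, 8.0, 9.0]]]
print(avg_pool(x, 2, 2))  # [[[3.0, 4.0], [6.0, 7.0]]]
```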

/device gpu
The IR printed by `tvm.lower()` looks normal, so something goes wrong inside `akg.build`.
Does akg create a default schedule internally, so that I can't schedule things the normal TVM way? Any tips on merging the two stages of avg pooling?

@anyrenwei
Contributor

Hi @MingliSun. I think this issue occurs because you call tvm's `create_schedule` and set `polyhedral=True` at the same time. In akg, you don't need to create a schedule yourself; the auto-poly pass creates one automatically. If you still want to do some scheduling to merge the stages of avg pooling (such as `AutoInlineInjective`), you could use `akg.build(sch, (X, Y), 'cuda', [], name='myfunc', attrs={"enable_auto_fuse": False}, polyhedral=True, binds=None)` to disable the "auto_fuse" pass.
