Create BlockSparse Tensor #202
Commenting on the intent already: nice, I think it would be great to consolidate the attentions! I can have a look at the blocksparse tests which don't pass, and @ptillet was wondering whether to update the upstream blocksparse kernel in Triton. Worst case we could host it here; the assumption is that the existing known bug (faulty result when a row is full of zeros) is not too complicated to fix.
from xformers.ops import masked_matmul


class BlockSparseTensor(torch.Tensor):
I like the abstraction personally. I think we'll need to be watertight on the mask description so that users clearly understand what the options are, but it consolidates the code nicely and makes a lot of sense (don't swap attentions when all you want is to change the mask).
I agree we will need further documentation and checks.
One thing that I'm still considering is what the internal representation for the block-sparse tensor should be. For now we are just passing whatever Triton expects (which internally gets converted to a combination of CSR and COO), but ultimately we would want to have some guidelines on what we should be doing.
Keeping combinations of CSR and COO is fine for block-sparse as it uses less memory per element (the index cost gets amortized by the block size), but the cost for generic sparsity might be higher, so this will need to be weighed.
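To make the amortization argument concrete, here is a minimal sketch. It is not the Triton or xformers implementation: it assumes a hypothetical 0/1 block layout given as nested lists, and the names `layout_to_indices` and `index_ints_per_value` are made up for illustration. It counts how many index integers (COO row and column indices plus CSR row pointers) are stored per value.

```python
from typing import List, Tuple


def layout_to_indices(layout: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Return COO (rows, cols) and CSR row pointers for the nonzero blocks."""
    rows, cols, row_ptr = [], [], [0]
    for r, row in enumerate(layout):
        for c, keep in enumerate(row):
            if keep:
                rows.append(r)
                cols.append(c)
        # row_ptr[r + 1] is the number of nonzero blocks seen up to and including row r
        row_ptr.append(len(rows))
    return rows, cols, row_ptr


def index_ints_per_value(layout: List[List[int]], block_size: int) -> float:
    """Index integers stored per value when each block holds block_size**2 values."""
    rows, cols, row_ptr = layout_to_indices(layout)
    if not rows:
        raise ValueError("layout has no nonzero blocks")
    n_index_ints = len(rows) + len(cols) + len(row_ptr)
    n_values = len(rows) * block_size * block_size
    return n_index_ints / n_values


# Hypothetical 2x2 block layout with three nonzero blocks.
layout = [[1, 0], [1, 1]]
for block_size in (1, 16, 32):
    print(block_size, index_ints_per_value(layout, block_size))
```

With `block_size=1` (generic element-wise sparsity) the index data is larger than the values themselves, while with 32x32 blocks it becomes negligible, which is the trade-off described above.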
res_gt.sum().backward()
res._blocksparse_values.sum().backward()
# TODO: this is not passing!!!
Is that only when a row is [0], or do you have other issues?
The failures here are due to triton-lang/triton#419
@ptillet Are we planning on releasing a new version of 1.1.x sometime soon?
tests/test_sparse_tensors.py
aa = a.clone()
bb = b.clone()

b = b.transpose(-2, -1)
Something which could happen here (not sure) is that the kernel assumes contiguous tensors, and these are not. But even if that were the case, it should probably be caught.
We check in https://github.com/facebookresearch/xformers/pull/202/files#diff-cca707218d1b441069abce210ffe653d2719bd9fadc51cc90d15a33cd918a2bcR92-R93 that the tensor is contiguous, so we should be fine here
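As a side note on why the transpose matters, here is a small pure-Python sketch of what "contiguous" means in terms of shapes and strides. It is a simplified model (empty tensors are ignored) and not the check used in this PR; the helper names are hypothetical.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides (in elements) for a densely packed tensor of this shape."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if the strides match the row-major layout (size-1 dims are ignored)."""
    expected = contiguous_strides(shape)
    return all(size == 1 or s == e for size, s, e in zip(shape, strides, expected))


def transpose_last_two(shape: Sequence[int], strides: Sequence[int]):
    """Model of transpose(-2, -1): swap shape and stride entries, no data is moved."""
    shape, strides = list(shape), list(strides)
    shape[-2], shape[-1] = shape[-1], shape[-2]
    strides[-2], strides[-1] = strides[-1], strides[-2]
    return tuple(shape), tuple(strides)


shape = (2, 3, 4)
strides = contiguous_strides(shape)  # (12, 4, 1)
t_shape, t_strides = transpose_last_two(shape, strides)
print(is_contiguous(shape, strides), is_contiguous(t_shape, t_strides))
```

In PyTorch the corresponding calls are `tensor.is_contiguous()` to check and `tensor.contiguous()` to get a densely packed copy when needed.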
Yep. FWIW the zero-row bug is almost surely an issue with the LUT computation rather than the Triton kernel/compiler.
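Until the zero-row case is fixed, one option is to detect it before calling the kernel. The sketch below is an assumption about how such a guard could look, not code from Triton or xformers: it takes a hypothetical 0/1 block layout as nested lists and reports the block rows that contain no nonzero block.

```python
from typing import List


def empty_block_rows(layout: List[List[int]]) -> List[int]:
    """Indices of block rows that have no nonzero block at all."""
    return [r for r, row in enumerate(layout) if not any(row)]


# Hypothetical 3x3 block layout whose middle row is entirely zero.
layout_with_gap = [
    [1, 0, 0],
    [0, 0, 0],
    [0, 1, 1],
]
print(empty_block_rows(layout_with_gap))
```

A caller could raise an error or fall back to a dense path when the returned list is non-empty, so that the faulty result is never produced silently.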
Ohh, interesting. Not sure I'll have the cycles, but maybe I can do a PR tomorrow, unless you can smash that before me? Thoughts on passing non-contiguous tensors? I had a quick look and I'm not sure that this case is caught.
I can definitely prioritize this. It would be a shame for such relatively minor bugs to turn people off the Triton blocksparse kernels. We've used those internally at OpenAI for a while now, and so has Anthropic. We found a bunch of bugs (mostly related to FP32) over the past few months but they've all been fixed in v2.0. What's left to do is probably the zero-row edge case and adding a bunch of asserts. Re: contiguous, I believe Triton blocksparse now converts to contiguous automatically when necessary. There was indeed a bug related to that, but it was fixed within a day of being reported (triton-lang/triton#419).
Gradients are modified in-place, and grad tensor is not checked to be contiguous, yielding wrong results
BTW, @ptillet I found another potential issue with Triton performs While this memory optimization is nice in principle, it doesn't work if the In general, I think it would be preferable to let the user specify if they want to perform the operation in-place or not, as the current in-place operation blocks the user to be able to apply EDIT: looking at triton-lang/triton#419, the problem I'm facing is the same as the reported one, but for
Codecov Report
@@            Coverage Diff             @@
##             main     #202      +/-   ##
==========================================
- Coverage   91.97%   91.97%   -0.01%
==========================================
  Files          57       58       +1
  Lines        2929     3128     +199
==========================================
+ Hits         2694     2877     +183
- Misses        235      251      +16
I believe this is ready to be merged. The current implementation as it stands is fairly fragile, as many configurations don't actually work (e.g., when non-contiguous tensors are involved). Some of this is already fixed in triton-lang/triton#419, but the
Nonetheless, on a V100 GPU using Triton 1.1.1, for the forward pass only on fp32 (backward will probably be much better, but benchmarking it turned out to be harder due to random errors), there doesn't seem to be much benefit from using the Triton version over the naive PyTorch operations that I have implemented in this PR. Here are the results I got:
I can post the benchmark script in the PR if it helps.
I can check it out this evening, afk at the moment. FYI for Triton blocksparse we even block fp32 and force fp16; on a V100, fp32 cannot use tensor cores, so the speed is really slow anyway. I don't think that people using fp32 and people using blocksparse overlap much, so I would focus on fp16 first?
LGTM, thanks for this @fmassa, looking forward to the next steps! I think that we need to cover fp16 here, but that can be part of the next PR.
from xformers.sparse import BlockSparseTensor


cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
I think that the tests should also cover fp16; most people will use this in fp16, and for some GPUs it will follow a different code path (V100 for instance).
I tried running on fp16 but I got consistent segfaults. This might be due to my version of Triton, or something else I'm doing wrong. Anyway, I'll leave the traceback here in case it is of use.
cc @ptillet
Traceback of the segfault
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007ffef65bb34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
(gdb) bt
#0 0x00007ffef65bb34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#1 0x00007ffef65c375f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#2 0x00007ffef650b1e4 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#3 0x00007ffef650b34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#4 0x00007ffef64e503b in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#5 0x00007ffef64e5bca in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#6 0x00007ffef66b2f83 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#7 0x00007ffef66b3027 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#8 0x00007ffef63a9bf4 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#9 0x00007ffef63b2578 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#10 0x00007ffef63b67c2 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#11 0x00007ffef63b7c2c in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#12 0x00007ffef63ab05c in __cuda_CallJitEntryPoint () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#13 0x00007fff37951942 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#14 0x00007fff379a010d in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#15 0x00007fff37733d7a in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#16 0x00007fff376e248e in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#17 0x00007fff377a617c in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#18 0x00007fff0d671216 in triton::driver::dispatch::cuModuleLoadData(CUmod_st**, void const*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#19 0x00007fff0d6ae865 in cu_load_binary(std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#20 0x00007fff0d6b21f5 in pybind11::cpp_function::initialize<init_triton_codegen(pybind11::module&&)::{lambda(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long)#2}, std::tuple<unsigned long, unsigned long>, backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::return_value_policy>(init_triton_codegen(pybind11::module&&)::{lambda(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long)#2}&&, std::tuple<unsigned long, unsigned long> (*)(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::return_value_policy const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#21 0x00007fff0d6ab6b2 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#22 0x00005555556a8348 in cfunction_call_varargs (kwargs=<optimized out>, args=<optimized out>, func=0x7fff3751ed60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:743
#23 PyCFunction_Call (func=0x7fff3751ed60, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:773
#24 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff3751ed60, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#25 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff29043b40, callable=0x7fff3751ed60) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#26 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#27 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#28 0x00005555556eee3f in function_code_fastcall (globals=<optimized out>, nargs=3, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#29 _PyFunction_Vectorcall (kwnames=0x0, nargsf=<optimized out>, stack=0x7fffffffa300, func=0x7fff28e98040) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#30 _PyObject_FastCallDict (kwargs=<optimized out>, nargsf=<optimized out>, args=0x7fffffffa300, callable=0x7fff28e98040) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:96
#31 _PyObject_Call_Prepend (callable=0x7fff28e98040, obj=<optimized out>, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#32 0x00005555556eef9a in slot_tp_init (self=0x7fff28f9b910, args=0x7fff28e14440, kwds=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6790
#33 0x0000555555697d2e in type_call (kwds=0x0, args=0x7fff28e14440, type=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:994
#34 _PyObject_MakeTpCall (callable=0x555558cf9240, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#35 0x000055555571f545 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5555f36a2098, callable=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#36 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#37 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#38 0x00005555556ed821 in PyEval_EvalFrameEx (throwflag=0, f=0x5555f36a1e10) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#39 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x7fff28f84988, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x7fff28f009c0, closure=0x0,
name=0x7ffff78cc1f0, qualname=0x7fff29b4b8b0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#40 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff28f848f0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#41 0x00005555556eec71 in _PyObject_FastCallDict (kwargs=0x7fff29bf0ac0, nargsf=19, args=0x7fff21cbbe90, callable=0x7fff28e985e0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:104
#42 _PyObject_Call_Prepend (callable=0x7fff28e985e0, obj=<optimized out>, args=<optimized out>, kwargs=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#43 0x00005555556eef0a in slot_tp_call (self=0x7fff28f9b550, args=0x7fff21ceb040, kwds=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#44 0x00005555556985fb in PyObject_Call (callable=0x7fff28f9b550, args=0x7fff21ceb040, kwargs=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:246
#45 0x00005555557210b6 in do_call_core (kwdict=0x7fff29bf0ac0, callargs=0x7fff21ceb040, func=0x7fff28f9b550, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:5010
#46 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3559
#47 0x00005555556ed821 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff29c6d800) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#48 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x7fff29083740, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x7fff28f3c8c0,
name=0x7ffff76b36b0, qualname=0x7fff28e31490) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#49 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff290836b0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#50 0x0000555555698693 in PyVectorcall_Call (kwargs=<optimized out>, tuple=<optimized out>, callable=0x7fff28f9ae50) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:200
#51 PyObject_Call (callable=0x7fff28f9ae50, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:228
#52 0x00005555557210b6 in do_call_core (kwdict=0x7fff29bf0400, callargs=0x7fff2900ab80, func=0x7fff28f9ae50, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:5010
#53 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3559
#54 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff29b74440) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#55 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x7fff21cda598, kwargs=0x7fff29083678, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x7ffff78cc1f0,
qualname=0x7fff28e1b760) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#56 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff290835e0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#57 0x00005555556eec71 in _PyObject_FastCallDict (kwargs=0x7fff3751be80, nargsf=19, args=0x7fff21cbbdf0, callable=0x7fff28e98700) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:104
#58 _PyObject_Call_Prepend (callable=0x7fff28e98700, obj=<optimized out>, args=<optimized out>, kwargs=0x7fff3751be80) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#59 0x00005555556eef0a in slot_tp_call (self=0x7fff28f9b760, args=0x7fff2900a700, kwds=0x7fff3751be80) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#60 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff28f9b760, args=<optimized out>, nargs=<optimized out>, keywords=0x7fff28e79160) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#61 0x00005555557202ab in _PyObject_Vectorcall (kwnames=0x7fff28e79160, nargsf=<optimized out>, args=<optimized out>, callable=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#62 call_function (kwnames=0x7fff28e79160, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#63 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3515
#64 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=11, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#65 _PyFunction_Vectorcall (func=<optimized out>, stack=0x5555e73fccb0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#66 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5555e73fccb0, callable=0x7fff28e7f4c0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#67 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#68 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#69 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=21, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#70 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff29083538, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#71 0x00005555556f3a22 in PyVectorcall_Call (kwargs=0x0, tuple=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:200
#72 PyObject_Call (kwargs=0x0, args=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:228
#73 PyEval_CallObjectWithKeywords (kwargs=0x0, args=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:810
#74 PyObject_CallObject (callable=0x7fff28ff4a60, args=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:818
#75 0x00007ffff5c9b608 in THPFunction_apply(_object*, _object*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#76 0x00005555556a83d0 in cfunction_call_varargs (kwargs=<optimized out>, args=<optimized out>, func=0x7fff290a5ea0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:758
#77 PyCFunction_Call (func=0x7fff290a5ea0, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:773
#78 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff290a5ea0, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#79 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x55555832ddb0, callable=0x7fff290a5ea0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#80 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#81 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#82 0x00005555556eee3f in function_code_fastcall (globals=<optimized out>, nargs=3, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#83 _PyFunction_Vectorcall (kwnames=0x0, nargsf=<optimized out>, stack=0x7fffffffb670, func=0x7fff28ff4ca0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#84 _PyObject_FastCallDict (kwargs=<optimized out>, nargsf=<optimized out>, args=0x7fffffffb670, callable=0x7fff28ff4ca0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:96
#85 _PyObject_Call_Prepend (callable=0x7fff28ff4ca0, obj=<optimized out>, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#86 0x00005555556eef0a in slot_tp_call (self=0x7ffff7820490, args=0x7ffff772e8c0, kwds=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#87 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7ffff7820490, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#88 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff29084580, callable=0x7ffff7820490) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#89 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#90 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#91 0x00005555556ee36b in function_code_fastcall (globals=<optimized out>, nargs=4, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#92 _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x555558b50d48, func=0x7fff28fe21f0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#93 _PyObject_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, args=0x555558b50d48, callable=0x7fff28fe21f0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#94 method_vectorcall (method=<optimized out>, args=0x555558b50d50, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/classobject.c:60
#95 0x0000555555657a61 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x555558b50d50, callable=0x7fff2a6deb40) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#96 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#97 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#98 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x555558b50b90) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#99 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff2a043428, kwcount=<optimized out>, kwstep=1, defs=0x7fff28f1e258, defcount=2, kwdefs=0x0, closure=0x0, name=0x7fff374d4b70,
qualname=0x7fff28fc5510) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#100 0x00005555556ee480 in _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x7fff2a043400, func=0x7fff28fe2550) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#101 _PyObject_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, args=0x7fff2a043400, callable=0x7fff28fe2550) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#102 method_vectorcall (method=<optimized out>, args=0x7fff2a043408, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/classobject.c:60
#103 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff2a043408, callable=0x7fff29df0540) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#104 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#105 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#106 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff2a043240) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#107 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff36373c00, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x7ffff77364e0,
qualname=0x7ffff77364e0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#108 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff36373bd8, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#109 0x0000555555657a61 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff36373bd8, callable=0x7fff3743b040) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#110 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#111 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#112 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff36373a40) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#113 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff2a0ba1f8, kwcount=<optimized out>, kwstep=1, defs=0x7fff28fc85f8, defcount=1, kwdefs=0x0, closure=0x0,
name=0x7ffff77322f0, qualname=0x7ffff77322f0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#114 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff2a0ba1e0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#115 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff2a0ba1e0, callable=0x7fff28fc7790) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#116 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#117 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#118 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=4, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#119 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7ffff78185b0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#120 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff78185b0, callable=0x7fff29ec93a0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#121 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#122 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#123 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7ffff7818440) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#124 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0)
at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#125 0x0000555555782543 in PyEval_EvalCodeEx () at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4327
#126 PyEval_EvalCode (co=<optimized out>, globals=<optimized out>, locals=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:718
#127 0x00005555557825e4 in run_eval_code_obj (co=0x7ffff77c4c90, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1165
#128 0x00005555557a8854 in run_mod (mod=<optimized out>, filename=<optimized out>, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0, flags=<optimized out>, arena=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1187
#129 0x0000555555669390 in pyrun_file (fp=0x5555558f0340, filename=0x7ffff7848eb0, start=<optimized out>, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0, closeit=1, flags=0x7fffffffc5b8) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1084
#130 0x000055555566c0d2 in pyrun_simple_file (flags=0x7fffffffc5b8, closeit=1, filename=0x7ffff7848eb0, fp=0x5555558f0340) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:439
#131 PyRun_SimpleFileExFlags (fp=0x5555558f0340, filename=<optimized out>, closeit=1, flags=0x7fffffffc5b8) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:472
#132 0x000055555566cbf0 in pymain_run_file (cf=0x7fffffffc5b8, config=0x5555558f39b0) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:391
#133 pymain_run_python (exitcode=0x7fffffffc5b0) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:616
#134 Py_RunMain () at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:695
#135 0x00005555557aba09 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:1141
#136 0x00007ffff7db40b3 in __libc_start_main (main=0x55555566d460 <main>, argc=2, argv=0x7fffffffc7b8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffc7a8) at ../csu/libc-start.c:308
#137 0x000055555573afe5 in _start () at ../sysdeps/x86_64/elf/start.S:103
Merging, I'll iterate on improving the overall pipeline on follow-up PRs.
FYI, I've merged a bunch of fixes in Triton blocksparse that should take care of the issues mentioned (and also improve performance on triangular matrices).
What does this PR do?
This PR creates a BlockSparse Tensor (similar to SparseCSRTensor). This will ultimately enable using the same code path in scaled_dot_product_attention, so that users just need to pass a block-sparse matrix in the mask to get the expected results (as is the case now with SparseCSRTensor). This is not yet entirely plugged in, but will be done in a follow-up PR.
I haven't made the Blocksparse nn.Module use our new Tensor for now, as it relies on additional information for the softmax which I would rather not add to the BlockSparseTensor API (all the different types of masking).
Note that for now the backward of masked_matmul is not passing, which is very weird as it was pretty much a copy-paste of what we had before. This needs to be further investigated.