Batch Matmul Op #1023
Conversation
Reviewed 6 of 6 files at r1, 6 of 6 files at r4, all commit messages.
Reviewable status: all files reviewed, 25 unresolved discussions (waiting on @KateUnger)
lib/kernels/include/kernels/batch_matmul_kernels.h
line 23 at r1 (raw file):
BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, Allocator allocator,
Suggestion:
BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
Allocator const &allocator,
lib/kernels/include/kernels/batch_matmul_kernels.h
line 29 at r1 (raw file):
void forward_kernel(ffStream_t stream, BMMPerDeviceState const *meta, float *o_ptr,
Suggestion:
void forward_kernel(ffStream_t stream,
BMMPerDeviceState const &meta,
float *output_ptr,
lib/kernels/include/kernels/batch_matmul_kernels.h
line 30 at r1 (raw file):
BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr,
Suggestion:
float const *lhs_input_ptr,
lib/kernels/include/kernels/batch_matmul_kernels.h
line 31 at r1 (raw file):
float *o_ptr, float const *a_ptr, float const *b_ptr,
Suggestion:
float const *rhs_input_ptr,
lib/kernels/include/kernels/batch_matmul_kernels.h
line 32 at r1 (raw file):
float const *a_ptr, float const *b_ptr, float const *c_ptr,
lib/kernels/include/kernels/batch_matmul_kernels.h
line 41 at r1 (raw file):
void backward_kernel(ffStream_t stream, BMMPerDeviceState const *meta, float const *o_ptr,
Suggestion:
BMMPerDeviceState const &meta,
float const *o_ptr,
lib/kernels/src/hip/batch_matmul_kernels.cpp
line 0 at r1 (raw file):
If you have to modify the .cpp file (i.e., AMD) you probably also have to modify the .cu file (i.e., CUDA).
lib/kernels/src/hip/batch_matmul_kernels.cpp
line 33 at r1 (raw file):
void forward_kernel(hipStream_t stream, BatchMatmulPerDeviceState const *meta, float *o_ptr,
Suggestion:
void forward_kernel(hipStream_t stream,
BatchMatmulPerDeviceState const &meta,
float *o_ptr,
lib/op-attrs/include/op-attrs/ops/batch_matmul.h
line 15 at r1 (raw file):
FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim);
int get_aSeqLengthDim(BatchMatmulAttrs const &attrs);
Probably not necessary--the additional functions were present for attention because iirc they involved some computation
lib/runtime/include/runtime/config.h
line 110 at r4 (raw file):
FFIterationConfig();
void reset();
req<int> seq_length;
req shouldn't be necessary here since it's using nonstandard construction (so it's allowed to be default constructible)
Suggestion:
int seq_length;
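For reference, a minimal sketch of the resulting struct, assuming the FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION macro used elsewhere in this PR (field list as quoted above):
struct FFIterationConfig {
  FFIterationConfig();
  void reset();
  int seq_length; // plain int; no req<> wrapper needed with nonstandard construction
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length);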
lib/runtime/src/ops/batch_matmul.h
line 86 at r1 (raw file):
/* public: */
/*   int a_seq_length_dim, b_seq_length_dim; */
/* }; */
lib/runtime/src/ops/batch_matmul.h
line 91 at r1 (raw file):
#endif
Why is there this giant commented-out block after the #endif?
lib/runtime/src/ops/batch_matmul.cc
line 20 at r1 (raw file):
// #include "kernels/profiling.h"
// #include "legion/legion_utilities.h"
// #include "tasks.h"
lib/runtime/src/ops/batch_matmul.cc
line 77 at r1 (raw file):
static DeviceSpecificArg<BMMPerDeviceState>
    init_task_impl(TaskArgumentAccessor const &acc) {
  auto const a_seq_length_dim = acc.get_argument<int>(A_SEQ_LENGTH_DIM);
As long as the type is short you can skip auto -- in general auto harms readability unless the type name is long (here it's a bit weird because the type is repeated twice, but if the type name is super short I'd lean toward using the type name)
Suggestion:
int const a_seq_length_dim = acc.get_argument<int>(A_SEQ_LENGTH_DIM);
lib/runtime/src/ops/batch_matmul.cc
line 89 at r1 (raw file):
b_seq_length_dim)); // assert(weight.shape.get_volume() * sizeof(float) ==
Should this be commented out? Or should it be removed entirely?
lib/runtime/src/ops/batch_matmul.cc
line 105 at r1 (raw file):
static optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
  assert(regions.size() == 3);
  assert(task->regions.size() == 3);
There's no way this compiles -- regions and task don't exist
Code quote:
assert(regions.size() == 3);
assert(task->regions.size() == 3);
lib/runtime/src/ops/batch_matmul.cc
line 113 at r1 (raw file):
ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
auto per_device_state = acc.get_argument<BMMPerDeviceState>(PER_DEVICE_STATE);
FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args;
task does not exist -- this should be done via a new RuntimeArgRef for the iteration config
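A hedged sketch of the RuntimeArgRef-based retrieval (the ITERATION_CONFIG slot name is hypothetical, not from this PR):
// Hypothetical slot; bound per model rather than read from task->args.
FFIterationConfig iter_config =
    acc.get_argument<FFIterationConfig>(ITERATION_CONFIG);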
lib/runtime/src/ops/batch_matmul.cc
line 140 at r1 (raw file):
a_input.get_float_ptr(), b_input.get_float_ptr(), NULL, //c_ptr
Okay, this appears to just be null everywhere, let's just delete it
lib/runtime/src/ops/batch_matmul.cc
line 159 at r1 (raw file):
// Currently assume C is NULL
assert(regions.size() == 6);
assert(task->regions.size() == 6);
task and regions do not exist
Code quote:
assert(regions.size() == 6);
assert(task->regions.size() == 6);
lib/runtime/src/ops/batch_matmul.cc
line 169 at r1 (raw file):
auto output_grad = acc.get_tensor_grad<Permissions::RW>(OUTPUT);
// is this equivalent to checking `Domain` equality?
assert(output == output_grad);
Suggestion:
assert(output.shape == output_grad.shape);
lib/runtime/src/ops/batch_matmul.cc
line 173 at r1 (raw file):
auto a_input = acc.get_tensor<Permissions::RO>(A_INPUT);
auto a_input_grad = acc.get_tensor_grad<Permissions::RW>(A_INPUT);
assert(a_input == a_input_grad);
Suggestion:
assert(a_input.shape == a_input_grad.shape);
lib/runtime/src/ops/batch_matmul.cc
line 196 at r1 (raw file):
}
// TODO: add support for meta->a_seq_length_dim >= 0
Create issue
lib/runtime/src/ops/batch_matmul.cc
line 275 at r1 (raw file):
init.add_unchecked_arg_slot(HANDLE, ff_handle());
register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task);
You don't have a variable called attrs in scope, so I'm not sure how the previous code would ever compile
Suggestion:
OpTaskSignature init(OpTaskType::INIT);
init.add_arg_slot<int>(A_SEQ_LENGTH_DIM);
init.add_arg_slot<int>(B_SEQ_LENGTH_DIM);
init.add_unchecked_arg_slot<PerDeviceFFHandle>(HANDLE);
register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task);
lib/runtime/src/ops/batch_matmul.cc
line 188 at r4 (raw file):
int batch = 1;
for (int i = 2; i < a_input.shape.get_dim(); i++) { //@colin get_dim() or get_volume()?
input.shape.dims.num_dims()
Code quote:
//@colin get_dim() or get_volume()?
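The underlying idea, as a self-contained sketch in plain C++ rather than the FlexFlow API (dims 0 and 1 are the matrix dimensions; everything above them is batch):
#include <vector>

int batch_size(std::vector<int> const &dims) {
  int batch = 1;
  for (size_t i = 2; i < dims.size(); i++) {
    batch *= dims[i]; // fold every non-matrix dimension into the batch count
  }
  return batch;
}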
lib/runtime/src/task_spec/op_arg_ref.h
line 31 at r4 (raw file):
}
OpArgRef<FFIterationConfig> iteration_config() {
Should be a RuntimeArgRef, not an OpArgRef, as there is an FFIterationConfig per model and not per operator
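i.e., roughly (a sketch, assuming RuntimeArgRef mirrors the OpArgRef declaration style):
RuntimeArgRef<FFIterationConfig> iteration_config();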
Reviewable status: all files reviewed, 25 unresolved discussions (waiting on @lockshaw)
lib/kernels/include/kernels/batch_matmul_kernels.h
line 23 at r1 (raw file):
BMMPerDeviceState init_kernel(PerDeviceFFHandle handle, Allocator allocator,
Done.
lib/kernels/include/kernels/batch_matmul_kernels.h
line 29 at r1 (raw file):
void forward_kernel(ffStream_t stream, BMMPerDeviceState const *meta, float *o_ptr,
Done.
lib/kernels/include/kernels/batch_matmul_kernels.h
line 30 at r1 (raw file):
BMMPerDeviceState const *meta, float *o_ptr, float const *a_ptr,
Done.
lib/kernels/include/kernels/batch_matmul_kernels.h
line 31 at r1 (raw file):
float *o_ptr, float const *a_ptr, float const *b_ptr,
Done.
lib/kernels/include/kernels/batch_matmul_kernels.h
line 32 at r1 (raw file):
float const *a_ptr, float const *b_ptr, float const *c_ptr,
Done.
lib/kernels/include/kernels/batch_matmul_kernels.h
line 41 at r1 (raw file):
void backward_kernel(ffStream_t stream, BMMPerDeviceState const *meta, float const *o_ptr,
Done.
lib/kernels/src/hip/batch_matmul_kernels.cpp
line 33 at r1 (raw file):
void forward_kernel(hipStream_t stream, BatchMatmulPerDeviceState const *meta, float *o_ptr,
Done.
lib/kernels/src/hip/batch_matmul_kernels.cpp
line 0 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
If you have to modify the .cpp file (i.e., AMD) you probably also have to modify the .cu file (i.e., CUDA)
Done.
lib/op-attrs/include/op-attrs/ops/batch_matmul.h
line 15 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Probably not necessary--the additional functions were present for attention because iirc they involved some computation
Done.
lib/runtime/include/runtime/config.h
line 110 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
req shouldn't be necessary here since it's using nonstandard construction (so it's allowed to be default constructible)
Done.
lib/runtime/src/ops/batch_matmul.h
line 86 at r1 (raw file):
/* public: */
/*   int a_seq_length_dim, b_seq_length_dim; */
/* }; */
Done.
lib/runtime/src/ops/batch_matmul.h
line 91 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why is there this giant commented-out block after the #endif?
That's all the old code from the .cc file
lib/runtime/src/ops/batch_matmul.cc
line 20 at r1 (raw file):
// #include "kernels/profiling.h"
// #include "legion/legion_utilities.h"
// #include "tasks.h"
Done.
lib/runtime/src/ops/batch_matmul.cc
line 77 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
As long as the type is short you can skip auto -- in general auto harms readability unless the type name is long (here it's a bit weird because the type is repeated twice, but if the type name is super short I'd lean toward using the type name)
Done.
lib/runtime/src/ops/batch_matmul.cc
line 89 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Should this be commented out? Or should it be removed entirely?
I think I deleted it in the final version
lib/runtime/src/ops/batch_matmul.cc
line 105 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
There's no way this compiles -- regions and task don't exist
I think the asserts don't stop compilation if they don't work. I'll delete these for now
lib/runtime/src/ops/batch_matmul.cc
line 113 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
task does not exist -- this should be done via a new RuntimeArgRef for the iteration config
Done.
lib/runtime/src/ops/batch_matmul.cc
line 140 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Okay, this appears to just be null everywhere, let's just delete it
Done.
lib/runtime/src/ops/batch_matmul.cc
line 159 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
task and regions do not exist
Done.
lib/runtime/src/ops/batch_matmul.cc
line 169 at r1 (raw file):
auto output_grad = acc.get_tensor_grad<Permissions::RW>(OUTPUT);
// is this equivalent to checking `Domain` equality?
assert(output == output_grad);
Done.
lib/runtime/src/ops/batch_matmul.cc
line 173 at r1 (raw file):
auto a_input = acc.get_tensor<Permissions::RO>(A_INPUT);
auto a_input_grad = acc.get_tensor_grad<Permissions::RW>(A_INPUT);
assert(a_input == a_input_grad);
Done.
lib/runtime/src/ops/batch_matmul.cc
line 196 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Create issue
Done.
lib/runtime/src/ops/batch_matmul.cc
line 275 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
You don't have a variable called attrs in scope, so I'm not sure how the previous code would ever compile
Done.
lib/runtime/src/ops/batch_matmul.cc
line 188 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
input.shape.dims.num_dims()
Done.
lib/runtime/src/task_spec/op_arg_ref.h
line 31 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Should be a RuntimeArgRef, not an OpArgRef, as there is an FFIterationConfig per model and not per operator
Done.
Reviewed 5 of 9 files at r6, 2 of 4 files at r7, 3 of 5 files at r8, 1 of 1 files at r9, 3 of 3 files at r10, all commit messages.
Reviewable status: all files reviewed, 7 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)
lib/kernels/src/cuda/attention_kernels.cu
line 195 at r10 (raw file):
}
MHAPerDeviceState per_device_state = {handle,
Shouldn't this be returned?
lib/kernels/src/cuda/attention_kernels.cu
line 209 at r10 (raw file):
reserveSpace, allocator};
free(qoSeqArray);
allocator.deallocate
lib/kernels/src/cuda/batch_matmul_kernels.cu
line 35 at r10 (raw file):
int a_seq_length_dim, int b_seq_length_dim, int seq_length = -1) {
Why?
lib/op-attrs/include/op-attrs/ops/attention.h
line 15 at r10 (raw file):
req<bool> bias, add_bias_kv, add_zero_attn;
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(MultiHeadAttentionAttrs,
Why?
Reviewed 4 of 4 files at r11, all commit messages.
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
allocator};
allocator.deallocate(qoSeqArray);
allocator.deallocate(kvSeqArray);
Correct me if this is wrong
Suggestion:
allocator.deallocate(devQoSeqArray);
allocator.deallocate(devKvSeqArray);
lib/kernels/src/cuda/attention_kernels.cu
line 212 at r11 (raw file):
allocator.deallocate(kvSeqArray);
return per_device_state;
Should this return a MHAPerDeviceState or a DeviceSpecific<MHAPerDeviceState>?
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @lockshaw)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Correct me if this is wrong
It should be qoSeqArray and kvSeqArray because the dev___ arrays are pointers to GPU memory that we don't want to deallocate. That being said, I am unclear on the allocator interface: is it able to (de)allocate both host and device memory? If it's only supposed to be used for device memory, then these lines should just be free(qoSeqArray); free(kvSeqArray);
lib/kernels/src/cuda/attention_kernels.cu
line 212 at r11 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Should this return a MHAPerDeviceState or a DeviceSpecific<MHAPerDeviceState>?
This should be a MHAPerDeviceState because in init_task_impl we call acc.create_device_specific<>(init_kernel(...))
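A minimal sketch of that pattern, using only names that appear in this thread (arguments to init_kernel elided):
static DeviceSpecific<MHAPerDeviceState>
    init_task_impl(TaskArgumentAccessor const &acc) {
  // init_kernel returns a plain MHAPerDeviceState; the accessor then wraps it
  // into a device-specific handle for use by later tasks.
  return acc.create_device_specific<MHAPerDeviceState>(
      init_kernel(/* handle, allocator, ... */));
}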
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
It should be qoSeqArray and kvSeqArray because the dev___ arrays are pointers to GPU memory that we don't want to deallocate. That being said, I am unclear on the allocator interface: is it able to (de)allocate both host and device memory? If it's only supposed to be used for device memory, then these lines should just be free(qoSeqArray); free(kvSeqArray);
Correct me if I'm wrong here, but devQoSeqArray and devKvSeqArray were created by allocator.allocate, and so should be cleaned up by allocator.deallocate? qoSeqArray was created by malloc, so should be cleaned up by free (or probably better would be to use a smart ptr/container so manual cleanup isn't necessary)
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @lockshaw)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Correct me if I'm wrong here, but devQoSeqArray and devKvSeqArray were created by allocator.allocate, and so should be cleaned up by allocator.deallocate? qoSeqArray was created by malloc, so should be cleaned up by free (or probably better would be to use a smart ptr/container so manual cleanup isn't necessary)
Ok, yes, it should still be qoSeqArray and kvSeqArray that need to be deallocated, so we should change to unique_ptr<int>
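One wrinkle: for an array, the array form of the smart pointer is needed. A sketch under the assumption that qoSeqArray holds num_samples ints on the host (num_samples is a placeholder name; requires <memory>):
std::unique_ptr<int[]> qoSeqArray = std::make_unique<int[]>(num_samples);
// ... fill qoSeqArray and copy it to devQoSeqArray as before ...
// no free() needed: the memory is released when qoSeqArray goes out of scope
Alternatively, a std::vector<int> would work just as well and is harder to misuse.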
Reviewed 1 of 1 files at r12, all commit messages.
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Ok, yes, it should still be qoSeqArray and kvSeqArray that need to be deallocated, so we should change to unique_ptr<int>
Why does devQoSeqArray not get cleaned up?
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @lockshaw)
lib/kernels/src/cuda/attention_kernels.cu
line 210 at r11 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why does devQoSeqArray not get cleaned up?
It is still needed for the device state.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)
Description of changes:
TODO: Change batch_matmul_kernels.cpp and batch_matmul_kernels.cu