Batch Matmul Op #1023

Merged
merged 46 commits into flexflow:repo-refactor on Dec 31, 2023

Conversation

Collaborator

@KateUnger KateUnger commented Aug 22, 2023

Description of changes:

  • Updates the BatchMatmul operation

TODO: Change batch_matmul_kernels.cpp and batch_matmul_kernels.cu

Related Issues:

Linked Issues:

  • Issue #

Issues closed by this PR:



lockshaw previously approved these changes Aug 23, 2023
@lockshaw lockshaw self-requested a review August 23, 2023 11:41
@KateUnger KateUnger added the repo-refactor Topics related to the repo and search refactors label Aug 24, 2023
Collaborator

@lockshaw lockshaw left a comment

Reviewed 6 of 6 files at r1, 6 of 6 files at r4, all commit messages.
Reviewable status: all files reviewed, 25 unresolved discussions (waiting on @KateUnger)


lib/kernels/include/kernels/batch_matmul_kernels.h line 23 at r1 (raw file):

BMMPerDeviceState init_kernel(PerDeviceFFHandle handle,
                  Allocator allocator,

Suggestion:

BMMPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
                  Allocator const &allocator,

lib/kernels/include/kernels/batch_matmul_kernels.h line 29 at r1 (raw file):

void forward_kernel(ffStream_t stream,
                    BMMPerDeviceState const *meta,
                    float *o_ptr,

Suggestion:

void forward_kernel(ffStream_t stream,
                    BMMPerDeviceState const &meta,
                    float *output_ptr,

lib/kernels/include/kernels/batch_matmul_kernels.h line 30 at r1 (raw file):

                    BMMPerDeviceState const *meta,
                    float *o_ptr,
                    float const *a_ptr,

Suggestion:

                    float const *lhs_input_ptr,

lib/kernels/include/kernels/batch_matmul_kernels.h line 31 at r1 (raw file):

                    float *o_ptr,
                    float const *a_ptr,
                    float const *b_ptr,

Suggestion:

                    float const *rhs_input_ptr,

lib/kernels/include/kernels/batch_matmul_kernels.h line 32 at r1 (raw file):

                    float const *a_ptr,
                    float const *b_ptr,
                    float const *c_ptr,

lib/kernels/include/kernels/batch_matmul_kernels.h line 41 at r1 (raw file):

void backward_kernel(ffStream_t stream,
                    BMMPerDeviceState const *meta,
                     float const *o_ptr,

Suggestion:

                    BMMPerDeviceState const &meta,
                     float const *o_ptr,

lib/kernels/src/hip/batch_matmul_kernels.cpp line 0 at r1 (raw file):
If you have to modify the .cpp file (i.e., AMD) you probably also have to modify the .cu file (i.e., CUDA)


lib/kernels/src/hip/batch_matmul_kernels.cpp line 33 at r1 (raw file):

void forward_kernel(hipStream_t stream,
                    BatchMatmulPerDeviceState const *meta,
                    float *o_ptr,

Suggestion:

void forward_kernel(hipStream_t stream,
                    BatchMatmulPerDeviceState const &meta,
                    float *o_ptr,

lib/op-attrs/include/op-attrs/ops/batch_matmul.h line 15 at r1 (raw file):

FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim);

int get_aSeqLengthDim(BatchMatmulAttrs const &attrs);

Probably not necessary--the additional functions were present for attention because iirc they involved some computation


lib/runtime/include/runtime/config.h line 110 at r4 (raw file):

  FFIterationConfig();
  void reset();
  req<int> seq_length;

req shouldn't be necessary here since it's using nonstandard construction (so it's allowed to be default constructible)

Suggestion:

  int seq_length;
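
For context, a hedged sketch of what the surrounding struct might look like after this change (assuming FFIterationConfig is registered with FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION, as the nonstandard-construction remark above implies; not a verbatim copy of config.h):

struct FFIterationConfig {
  FFIterationConfig();
  void reset();
  int seq_length; // plain int; req<> is only needed when relying on standard construction
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length);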

lib/runtime/src/ops/batch_matmul.h line 86 at r1 (raw file):

/* public: */
/*   int a_seq_length_dim, b_seq_length_dim; */
/* }; */

lib/runtime/src/ops/batch_matmul.h line 91 at r1 (raw file):

#endif

Why is there this giant commented-out block after the #endif?


lib/runtime/src/ops/batch_matmul.cc line 20 at r1 (raw file):

// #include "kernels/profiling.h"
// #include "legion/legion_utilities.h"
// #include "tasks.h"

lib/runtime/src/ops/batch_matmul.cc line 77 at r1 (raw file):

static DeviceSpecificArg<BMMPerDeviceState> init_task_impl(TaskArgumentAccessor const &acc) {
  auto const a_seq_length_dim = acc.get_argument<int>(A_SEQ_LENGTH_DIM);

As long as the type is short you can skip auto--in general auto harms readability unless the type name is long (here it's a bit weird because the type is repeated twice, but if the type name is super short I'd lean toward using the type name)

Suggestion:

  int const a_seq_length_dim = acc.get_argument<int>(A_SEQ_LENGTH_DIM);

lib/runtime/src/ops/batch_matmul.cc line 89 at r1 (raw file):

                      b_seq_length_dim));

  // assert(weight.shape.get_volume() * sizeof(float) ==

Should this be commented out? Or should it be removed entirely?


lib/runtime/src/ops/batch_matmul.cc line 105 at r1 (raw file):

static optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
  assert(regions.size() == 3);
  assert(task->regions.size() == 3);

There's no way this compiles--regions and task don't exist

Code quote:

  assert(regions.size() == 3);
  assert(task->regions.size() == 3);

lib/runtime/src/ops/batch_matmul.cc line 113 at r1 (raw file):

  ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
  auto per_device_state = acc.get_argument<BMMPerDeviceState>(PER_DEVICE_STATE);
  FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args;

task does not exist--this should be done via a new RuntimeArgRef for the iteration config


lib/runtime/src/ops/batch_matmul.cc line 140 at r1 (raw file):

          a_input.get_float_ptr(),
          b_input.get_float_ptr(),
          NULL, //c_ptr

Okay, this appears to just be null everywhere, let's just delete it


lib/runtime/src/ops/batch_matmul.cc line 159 at r1 (raw file):

  // Currently assume C is NULL
  assert(regions.size() == 6);
  assert(task->regions.size() == 6);

task and regions do not exist

Code quote:

  assert(regions.size() == 6);
  assert(task->regions.size() == 6);

lib/runtime/src/ops/batch_matmul.cc line 169 at r1 (raw file):

  auto output_grad = acc.get_tensor_grad<Permissions::RW>(OUTPUT);
  // is this equivalent to checking `Domain` equality?
  assert(output == output_grad); 

Suggestion:

  assert(output.shape == output_grad.shape);

lib/runtime/src/ops/batch_matmul.cc line 173 at r1 (raw file):

  auto a_input = acc.get_tensor<Permissions::RO>(A_INPUT);
  auto a_input_grad = acc.get_tensor_grad<Permissions::RW>(A_INPUT);
  assert(a_input == a_input_grad);

Suggestion:

  assert(a_input.shape == a_input_grad.shape);

lib/runtime/src/ops/batch_matmul.cc line 196 at r1 (raw file):

  }

  // TODO: add support for meta->a_seq_length_dim >= 0

Create issue


lib/runtime/src/ops/batch_matmul.cc line 275 at r1 (raw file):

  init.add_unchecked_arg_slot(HANDLE, ff_handle());

  register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task);

You don't have a variable called attrs in scope so I'm not sure how the previous code would ever compile

Suggestion:

  OpTaskSignature init(OpTaskType::INIT);

  init.add_arg_slot<int>(A_SEQ_LENGTH_DIM);
  init.add_arg_slot<int>(B_SEQ_LENGTH_DIM);
  init.add_unchecked_arg_slot<PerDeviceFFHandle>(HANDLE);

  register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", init, init_task);

lib/runtime/src/ops/batch_matmul.cc line 188 at r4 (raw file):

  int batch = 1;
  for (int i = 2; i < a_input.shape.get_dim();
       i++) { //@colin get_dim() or get_volume()?

input.shape.dims.num_dims()

Code quote:

 //@colin get_dim() or get_volume()?

lib/runtime/src/task_spec/op_arg_ref.h line 31 at r4 (raw file):

}

OpArgRef<FFIterationConfig> iteration_config() {

Should be a RuntimeArgRef not an OpArgRef as there is an FFIterationConfig per model and not per operator
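
A hedged sketch of the requested direction (the RuntimeArgRef spelling is taken from this comment; the ITERATION_CONFIG slot name and the exact accessor call in forward_task_impl are assumptions for illustration, not the actual API):

// runtime-scoped analogue of the OpArgRef accessor above
RuntimeArgRef<FFIterationConfig> iteration_config();

// then, in forward_task_impl, instead of casting task->args:
FFIterationConfig iter_config =
    acc.get_argument<FFIterationConfig>(ITERATION_CONFIG); // hypothetical slot name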

@lockshaw lockshaw linked an issue Aug 29, 2023 that may be closed by this pull request
Collaborator Author

@KateUnger KateUnger left a comment

Reviewable status: all files reviewed, 25 unresolved discussions (waiting on @lockshaw)


lib/kernels/include/kernels/batch_matmul_kernels.h line 23 at r1 (raw file):

BMMPerDeviceState init_kernel(PerDeviceFFHandle handle,
                  Allocator allocator,

Done.


lib/kernels/include/kernels/batch_matmul_kernels.h line 29 at r1 (raw file):

void forward_kernel(ffStream_t stream,
                    BMMPerDeviceState const *meta,
                    float *o_ptr,

Done.


lib/kernels/include/kernels/batch_matmul_kernels.h line 30 at r1 (raw file):

                    BMMPerDeviceState const *meta,
                    float *o_ptr,
                    float const *a_ptr,

Done.


lib/kernels/include/kernels/batch_matmul_kernels.h line 31 at r1 (raw file):

                    float *o_ptr,
                    float const *a_ptr,
                    float const *b_ptr,

Done.


lib/kernels/include/kernels/batch_matmul_kernels.h line 32 at r1 (raw file):

                    float const *a_ptr,
                    float const *b_ptr,
                    float const *c_ptr,

Done.


lib/kernels/include/kernels/batch_matmul_kernels.h line 41 at r1 (raw file):

void backward_kernel(ffStream_t stream,
                    BMMPerDeviceState const *meta,
                     float const *o_ptr,

Done.


lib/kernels/src/hip/batch_matmul_kernels.cpp line 33 at r1 (raw file):

void forward_kernel(hipStream_t stream,
                    BatchMatmulPerDeviceState const *meta,
                    float *o_ptr,

Done.


lib/kernels/src/hip/batch_matmul_kernels.cpp line 0 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

If you have to modify the .cpp file (i.e., AMD) you probably also have to modify the .cu file (i.e., CUDA)

Done.


lib/op-attrs/include/op-attrs/ops/batch_matmul.h line 15 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Probably not necessary--the additional functions were present for attention because iirc they involved some computation

Done.


lib/runtime/include/runtime/config.h line 110 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

req shouldn't be necessary here since it's using nonstandard construction (so it's allowed to be default constructible)

Done.


lib/runtime/src/ops/batch_matmul.h line 86 at r1 (raw file):

/* public: */
/*   int a_seq_length_dim, b_seq_length_dim; */
/* }; */

Done.


lib/runtime/src/ops/batch_matmul.h line 91 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why is there this giant commented-out block after the #endif?

That's all the old code from the .cc file.


lib/runtime/src/ops/batch_matmul.cc line 20 at r1 (raw file):

// #include "kernels/profiling.h"
// #include "legion/legion_utilities.h"
// #include "tasks.h"

Done.


lib/runtime/src/ops/batch_matmul.cc line 77 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

As long as the type is short you can skip auto--in general auto harms readability unless the type name is long (here it's a bit weird because the type is repeated twice, but if the type name is super short I'd lean toward using the type name)

Done.


lib/runtime/src/ops/batch_matmul.cc line 89 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should this be commented out? Or should it be removed entirely?

I think I deleted it in the final version


lib/runtime/src/ops/batch_matmul.cc line 105 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

There's no way this compiles--regions and task don't exist

I think the asserts don't stop compilation if they don't work. I'll delete these for now


lib/runtime/src/ops/batch_matmul.cc line 113 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

task does not exist--this should be done via a new RuntimeArgRef for the iteration config

Done.


lib/runtime/src/ops/batch_matmul.cc line 140 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Okay, this appears to just be null everywhere, let's just delete it

Done.


lib/runtime/src/ops/batch_matmul.cc line 159 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

task and regions do not exist

Done.


lib/runtime/src/ops/batch_matmul.cc line 169 at r1 (raw file):

  auto output_grad = acc.get_tensor_grad<Permissions::RW>(OUTPUT);
  // is this equivalent to checking `Domain` equality?
  assert(output == output_grad); 

Done.


lib/runtime/src/ops/batch_matmul.cc line 173 at r1 (raw file):

  auto a_input = acc.get_tensor<Permissions::RO>(A_INPUT);
  auto a_input_grad = acc.get_tensor_grad<Permissions::RW>(A_INPUT);
  assert(a_input == a_input_grad);

Done.


lib/runtime/src/ops/batch_matmul.cc line 196 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Create issue

Done.


lib/runtime/src/ops/batch_matmul.cc line 275 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

You don't have a variable called attrs in scope so I'm not sure how the previous code would ever compile

Done.


lib/runtime/src/ops/batch_matmul.cc line 188 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

input.shape.dims.num_dims()

Done.


lib/runtime/src/task_spec/op_arg_ref.h line 31 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should be a RuntimeArgRef not an OpArgRef as there is an FFIterationConfig per model and not per operator

Done.

@lockshaw lockshaw requested review from reyna-abhyankar and removed request for lockshaw and wmdi October 7, 2023 00:12
reyna-abhyankar previously approved these changes Oct 7, 2023
Collaborator

@lockshaw lockshaw left a comment

Reviewed 5 of 9 files at r6, 2 of 4 files at r7, 3 of 5 files at r8, 1 of 1 files at r9, 3 of 3 files at r10, all commit messages.
Reviewable status: all files reviewed, 7 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)


lib/kernels/src/cuda/attention_kernels.cu line 195 at r10 (raw file):

  }

  MHAPerDeviceState per_device_state = {handle,

Shouldn't this be returned?


lib/kernels/src/cuda/attention_kernels.cu line 209 at r10 (raw file):

                                        reserveSpace,
                                        allocator};
  free(qoSeqArray);

allocator.deallocate


lib/kernels/src/cuda/batch_matmul_kernels.cu line 35 at r10 (raw file):

                    int a_seq_length_dim,
                    int b_seq_length_dim,
                    int seq_length = -1) {

Why?


lib/op-attrs/include/op-attrs/ops/attention.h line 15 at r10 (raw file):

  req<bool> bias, add_bias_kv, add_zero_attn;
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(MultiHeadAttentionAttrs,

Why?

Collaborator

@lockshaw lockshaw left a comment

Reviewed 4 of 4 files at r11, all commit messages.
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

                                        allocator};
  allocator.deallocate(qoSeqArray);
  allocator.deallocate(kvSeqArray);

Correct me if this is wrong

Suggestion:

  allocator.deallocate(devQoSeqArray);
  allocator.deallocate(devKvSeqArray);

lib/kernels/src/cuda/attention_kernels.cu line 212 at r11 (raw file):

  allocator.deallocate(kvSeqArray);

  return per_device_state;

Should this return a MHAPerDeviceState or a DeviceSpecific<MHAPerDeviceState>?

Collaborator

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @lockshaw)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Correct me if this is wrong

It should be qoSeqArray and kvSeqArray, because the dev* arrays (devQoSeqArray / devKvSeqArray) are pointers to GPU memory that we don't want to deallocate. That being said, I am unclear on the allocator interface: is it able to (de)allocate both host and device memory? If it's only supposed to be used for device memory, then these lines should just be free(qoSeqArray); free(kvSeqArray);


lib/kernels/src/cuda/attention_kernels.cu line 212 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should this return a MHAPerDeviceState or a DeviceSpecific<MHAPerDeviceState>?

This should be a MHAPerDeviceState because in init_task_impl we call acc.create_device_specific<>(init_kernel(...))
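
For reference, a hedged sketch that stitches together the two pieces quoted in this thread (the init_task_impl signature from the batch-matmul discussion earlier in this review and the create_device_specific call mentioned here); the init_kernel argument list is elided and the attention attribute plumbing is not shown:

static DeviceSpecificArg<MHAPerDeviceState>
    init_task_impl(TaskArgumentAccessor const &acc) {
  // ... fetch handle, allocator, and attention attrs from acc ...
  return acc.create_device_specific<MHAPerDeviceState>(
      init_kernel(/* handle, allocator, attention dims, ... */));
}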

Collaborator

@lockshaw lockshaw left a comment

Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

It should be qoSeqArray and kvSeqArray because the dev___ arrays are pointers to GPU memory that we don't want to deallocate. That being said, I am unclear on the allocator interface. Is it able to (de)allocate host and device memory. If it's only supposed to be used for device memory, then these lines should just be free(qoSeqArray); free(kvSeqArray);

Correct me if I'm wrong here, but devQoSeqArray and devKvSeqArray were created by allocator.allocate, and so should be cleaned up by allocator.deallocate? qoSeqArray was created by malloc, so should be cleaned up by free (or probably better would be to use a smart ptr/container so manual cleanup isn't necessary)

Collaborator

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @lockshaw)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Correct me if I'm wrong here, but devQoSeqArray and devKvSeqArray were created by allocator.allocate, and so should be cleaned up by allocator.deallocate? qoSeqArray was created by malloc, so should be cleaned up by free (or probably better would be to use a smart ptr/container so manual cleanup isn't necessary)

Ok, yes it should still be qoSeqArray and kvSeqArray that need to be deallocated, so we should change to unique_ptr<int>
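
A small, self-contained sketch of the pattern agreed on here (names such as qoSeqArray and num_samples follow the discussion; the device-side copy is only indicated in a comment). Note that for a heap-allocated array it is the array form std::unique_ptr<int[]>, or simply std::vector<int>, that makes the manual free() calls unnecessary:

#include <algorithm>
#include <memory>

void init_seq_arrays(int num_samples, int qo_seq_length, int kv_seq_length) {
  // Host-side staging buffers, owned by unique_ptr so no free() is needed.
  std::unique_ptr<int[]> qoSeqArray(new int[num_samples]);
  std::unique_ptr<int[]> kvSeqArray(new int[num_samples]);
  std::fill_n(qoSeqArray.get(), num_samples, qo_seq_length);
  std::fill_n(kvSeqArray.get(), num_samples, kv_seq_length);

  // ... copy qoSeqArray.get() / kvSeqArray.get() into the allocator-provided
  // device buffers (devQoSeqArray / devKvSeqArray); those stay alive because
  // they become part of the per-device state ...
}  // host arrays are released automatically here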

Collaborator

@lockshaw lockshaw left a comment

Reviewed 1 of 1 files at r12, all commit messages.
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @KateUnger and @reyna-abhyankar)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Ok, yes it should still be qoSeqArray and kvSeqArray that need to be deallocated, so we should change to unique_ptr<int>

Why does devQoSeqArray not get cleaned up?

Collaborator

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @lockshaw)


lib/kernels/src/cuda/attention_kernels.cu line 210 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why does devQoSeqArray not get cleaned up?

It is still needed for the device state.

Collaborator

@lockshaw lockshaw left a comment

Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)

@lockshaw lockshaw merged commit 2940646 into flexflow:repo-refactor Dec 31, 2023
7 of 8 checks passed
Successfully merging this pull request may close these issues.

Update BatchMatmul operator
3 participants