Batch Matmul Op #1023

Merged · 46 commits · Dec 31, 2023

Changes from 2 commits

Commits (46):
6b248cf  batch matmul initial commit (KateUnger, Aug 22, 2023)
d9aae74  fix FF_VISITABLE_STRUCT_NO_EQ (KateUnger, Aug 22, 2023)
45ecd8c  Merge branch 'repo-refactor' into batch_matmul (lockshaw, Aug 23, 2023)
2073822  finish draft 1 batch_matmul (KateUnger, Aug 23, 2023)
5c322bd  Merge branch 'batch_matmul' of github.com:KateUnger/FlexFlow into bat… (KateUnger, Aug 23, 2023)
d5806a0  add output and weights (KateUnger, Aug 23, 2023)
f620070  format (KateUnger, Aug 23, 2023)
e954ab0  Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow … (KateUnger, Aug 23, 2023)
5458923  fix DeviceSpecific (KateUnger, Aug 23, 2023)
f2205f4  change (KateUnger, Aug 29, 2023)
2f4662d  change (KateUnger, Aug 29, 2023)
642eb90  change (KateUnger, Aug 29, 2023)
e79406a  change (KateUnger, Aug 29, 2023)
bb3c10f  change (KateUnger, Aug 29, 2023)
8488509  change (KateUnger, Aug 29, 2023)
36eba29  change (KateUnger, Aug 29, 2023)
98773b4  change (KateUnger, Aug 29, 2023)
ae59261  change (KateUnger, Aug 29, 2023)
8db2512  fix asserts (KateUnger, Aug 30, 2023)
09acbe5  format (KateUnger, Aug 30, 2023)
9f121d5  Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow … (KateUnger, Sep 10, 2023)
c5be7a7  batch_matmul (KateUnger, Sep 11, 2023)
67cf7be  format (KateUnger, Sep 11, 2023)
0bddb9e  Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow … (KateUnger, Sep 13, 2023)
bbdfe9a  add cuda (KateUnger, Sep 13, 2023)
dce19e2  fix attention kuda (KateUnger, Sep 13, 2023)
e4a796d  format (KateUnger, Sep 13, 2023)
e817e75  draft (KateUnger, Sep 30, 2023)
5c80bcd  format (KateUnger, Sep 30, 2023)
70dc9f0  reyna fixes (KateUnger, Sep 30, 2023)
03b291c  format (KateUnger, Sep 30, 2023)
2bb4501  format repo-refactor (KateUnger, Oct 4, 2023)
944c1b4  Merge branch 'form' of github.com:KateUnger/FlexFlow into batch_matmul (KateUnger, Oct 4, 2023)
2b24fc3  Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow … (KateUnger, Oct 4, 2023)
125921b  delete init and split register_task (KateUnger, Oct 4, 2023)
dc225a3  finish bmm and att. (KateUnger, Oct 4, 2023)
34825a1  Merge branch 'repo-refactor' into batch_matmul (lockshaw, Oct 6, 2023)
b9b7642  Merge branch 'repo-refactor' into batch_matmul (reyna-abhyankar, Oct 7, 2023)
d6ce742  Align hip kernel (reyna-abhyankar, Oct 7, 2023)
ca9efd9  Merge branch 'repo-refactor' into batch_matmul (reyna-abhyankar, Oct 7, 2023)
3fb5f6f  Merge branch 'repo-refactor' into batch_matmul (reyna-abhyankar, Oct 9, 2023)
3f6b8b0  Fix attention kernels (reyna-abhyankar, Oct 10, 2023)
edc1935  Replace with unique ptr (reyna-abhyankar, Oct 19, 2023)
07d9db9  Format (reyna-abhyankar, Oct 23, 2023)
1d4ac9a  Merge branch 'repo-refactor' into batch_matmul (reyna-abhyankar, Nov 1, 2023)
bae778d  Merge branch 'repo-refactor' into batch_matmul (reyna-abhyankar, Dec 31, 2023)

23 changes: 15 additions & 8 deletions lib/kernels/include/kernels/batch_matmul_kernels.h
@@ -6,17 +6,26 @@

 namespace FlexFlow {

-class BatchMatmulPerDeviceState : public PerDeviceOpState {
-public:
-  BatchMatmulPerDeviceState(FFHandler handler);
-  int a_seq_length_dim, b_seq_length_dim;
+struct BMMPerDeviceState {
+  PerDeviceFFHandle handle;
+  Allocator allocator;
+  int a_seq_length_dim;
+  int b_seq_length_dim;
 };
+
+FF_VISITABLE_STRUCT_NO_EQ(BMMPerDeviceState,
+                          handle,);

 namespace Kernels {
 namespace BatchMatmul {

+BMMPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                              Allocator allocator,
+                              int a_seq_length_dim,
+                              int b_seq_length_dim);
+
 void forward_kernel(ffStream_t stream,
-                    BatchMatmulPerDeviceState const *,
+                    BMMPerDeviceState const *meta,
                     float *o_ptr,
                     float const *a_ptr,
                     float const *b_ptr,
@@ -25,12 +34,10 @@ void forward_kernel(ffStream_t stream,
                     int n,
                     int k,
                     int batch,
-                    int a_seq_length_dim = -1,
-                    int b_seq_length_dim = -1,
                     int seq_length = -1);

 void backward_kernel(ffStream_t stream,
-                     BatchMatmulPerDeviceState const *,
+                     BMMPerDeviceState const *meta,
                      float const *o_ptr,
                      float const *o_grad_ptr,
                      float const *a_ptr,
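
For orientation, a minimal caller sketch of the refactored API (not part of this diff): the per-device state is now produced by init_kernel and passed into forward_kernel, instead of being a constructor-initialized PerDeviceOpState subclass. The hypothetical run_bmm_forward wrapper and all of its arguments are assumptions, as is the int m parameter standing in for the signature line hidden by the collapsed hunk between b_ptr and n.

namespace FlexFlow {

// Sketch only: `handle`, `allocator`, `stream`, and the device buffers
// are assumed to be supplied by the surrounding runtime, and `m` is a
// guess for the line hidden in the collapsed hunk above.
void run_bmm_forward(PerDeviceFFHandle handle, Allocator allocator,
                     ffStream_t stream, float *o_ptr,
                     float const *a_ptr, float const *b_ptr,
                     int m, int n, int k, int batch) {
  // Build the per-device state once; the seq-length dims now live here
  // rather than being passed to every forward call.
  BMMPerDeviceState state = Kernels::BatchMatmul::init_kernel(
      handle, allocator, /*a_seq_length_dim=*/-1, /*b_seq_length_dim=*/-1);
  Kernels::BatchMatmul::forward_kernel(stream, &state,
                                       o_ptr, a_ptr, b_ptr,
                                       m, n, k, batch,
                                       /*seq_length=*/-1);
}

} // namespace FlexFlow
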
7 changes: 2 additions & 5 deletions lib/kernels/src/hip/batch_matmul_kernels.cpp
@@ -19,9 +19,6 @@

 namespace FlexFlow {

-BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler)
-    : PerDeviceOpState(handler) {}
-
 namespace Kernels {
 namespace BatchMatmul {

@@ -42,9 +39,9 @@ void forward_kernel(hipStream_t stream,
                     int k,
                     int batch,
-                    hipStream_t stream,
-                    int a_seq_length_dim,
-                    int b_seq_length_dim,
                     int seq_length) {
+  int a_seq_length_dim = meta->a_seq_length_dim;
+  int b_seq_length_dim = meta->b_seq_length_dim;
   checkCUDA(hipblasSetStream(meta->handle.blas, stream));
   checkCUDNN(miopenSetStream(meta->handle.dnn, stream));
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/batch_matmul.h
@@ -12,8 +12,10 @@ struct BatchMatmulAttrs {
 };
 FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim);

-CHECK_VALID_OP_ATTR(BatchMatmulAttrs);
+int get_aSeqLengthDim(BatchMatmulAttrs const &attrs);
+int get_bSeqLengthDim(BatchMatmulAttrs const &attrs);

+CHECK_VALID_OP_ATTR(BatchMatmulAttrs);
 } // namespace FlexFlow

 #endif
8 changes: 8 additions & 0 deletions lib/op-attrs/src/batch_matmul.cc
@@ -2,6 +2,14 @@

 namespace FlexFlow {

+int get_aSeqLengthDim(BatchMatmulAttrs const &attrs) {
+  return attrs.a_seq_length_dim;
+}
+
+int get_bSeqLengthDim(BatchMatmulAttrs const &attrs) {
+  return attrs.b_seq_length_dim;
+}
+
 /* bool BatchMatmulAttrs::is_valid( */
 /*     ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const {
  */
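
The new accessors are thin wrappers over the struct fields. A hedged sanity-check snippet, assuming BatchMatmulAttrs is aggregate-initializable from its two visited fields (the check_seq_length_accessors helper is hypothetical, not part of the PR):

#include <cassert>

// Hypothetical check; aggregate initialization of BatchMatmulAttrs is
// an assumption. The getters are found via argument-dependent lookup.
void check_seq_length_accessors() {
  FlexFlow::BatchMatmulAttrs attrs{/*a_seq_length_dim=*/1,
                                   /*b_seq_length_dim=*/2};
  assert(get_aSeqLengthDim(attrs) == 1);
  assert(get_bSeqLengthDim(attrs) == 2);
}
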