Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MHA] Integrate ck mha fp8 solver into miopen #3014

Closed
wants to merge 29 commits into from

Conversation

bghimireamd
Copy link
Contributor

@bghimireamd bghimireamd commented Jun 3, 2024

CK's fp8 col major solver integration into MIOpen

@bghimireamd bghimireamd changed the title initial comment for attemting to integrate ck mha fp8 solver Integrate ck mha fp8 solver into miopen Jun 3, 2024
@bghimireamd bghimireamd changed the title Integrate ck mha fp8 solver into miopen [DRAFT][MHA] Integrate ck mha fp8 solver into miopen Jun 3, 2024
Copy link
Contributor

@CAHEK7 CAHEK7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some minor suggestions

src/solver/mha/mha_ck_solver_forward.cpp Outdated Show resolved Hide resolved
src/solver/mha/mha_ck_solver_forward.cpp Outdated Show resolved Hide resolved
@bghimireamd bghimireamd marked this pull request as ready for review July 19, 2024 18:49
@bghimireamd bghimireamd changed the title [DRAFT][MHA] Integrate ck mha fp8 solver into miopen [MHA] Integrate ck mha fp8 solver into miopen Jul 19, 2024
@bghimireamd bghimireamd marked this pull request as draft July 21, 2024 15:13
@bghimireamd bghimireamd marked this pull request as ready for review July 23, 2024 19:22
@bghimireamd bghimireamd requested a review from CAHEK7 July 24, 2024 14:05
@@ -7,5 +7,5 @@ nlohmann/json@v3.11.2 -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off
ROCm/FunctionalPlus@v0.2.18-p0
ROCm/eigen@3.4.0
ROCm/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50
ROCm/composable_kernel@15baccf2ecad4fb3498b8acb6bbf58fb5359c7a5 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
ROCm/composable_kernel@1dd9875c9ac783cabc8a536e0802b58d43b6b107 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
Copy link
Contributor

@CAHEK7 CAHEK7 Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be aligned with #3142 and #3181?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I also recommend picking CK commit hashes from its amd-develop branch.

Comment on lines +56 to +64
const auto& fptr = GetDescsForward();
if(fptr.oDesc.GetType() == miopenFloat8)
{
return true;
}
else
{
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const auto& fptr = GetDescsForward();
if(fptr.oDesc.GetType() == miopenFloat8)
{
return true;
}
else
{
return false;
}
return GetDescsForward().oDesc.GetType() == miopenFloat8;

But I'm not sure that it has to be here at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following convolution's problem description pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then my suggestion matches this pattern perfectly.

Comment on lines +140 to +142
const std::string name = context.GetStream().GetDeviceName();
if(!(name == "gfx940" || name == "gfx941" || name == "gfx942"))
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StartsWith(handle.GetDeviceName(), "gfx103")

Suggested change
const std::string name = context.GetStream().GetDeviceName();
if(!(name == "gfx940" || name == "gfx941" || name == "gfx942"))
return false;
if(!StartsWith(handle.GetDeviceName(), "gfx94"))
return false;

Comment on lines +143 to +144
if(!problem.IsFFp8()) // forward mha fp8
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably all the tensors should be explicitly checked.

Comment on lines +145 to +147
::miopen::mha::MhaInputDescsForward mha_des = problem.GetDescsForward();
const auto& lens = mha_des.qDesc.GetLengths();
auto [N, H, S, D] = std::tie(lens[0], lens[1], lens[2], lens[3]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
::miopen::mha::MhaInputDescsForward mha_des = problem.GetDescsForward();
const auto& lens = mha_des.qDesc.GetLengths();
auto [N, H, S, D] = std::tie(lens[0], lens[1], lens[2], lens[3]);
auto [N, H, S, D] = miopen::tien<4>(problem.GetDescsForward().qDesc.GetLengths());

Comment on lines +148 to +150
if(D <= 256 && S % 128 == 0 && D % 64 == 0)
return true;
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(D <= 256 && S % 128 == 0 && D % 64 == 0)
return true;
return false;
return (D <= 256 && S % 128 == 0 && D % 64 == 0);

auto [n, h, s, d, drop] = GetParam();
Handle& handle = get_handle();

if((drop > 0.0f) && (s % handle.GetWavefrontWidth() != 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, CK does not support dropout at all.

InitTensor(miopenTensorMhaV, std::move(v.mTensor));

float s_scale = test::cpu::GetF8Scaling(1.0);
// clang-tidy complains about the same expression on both sides of "/": 1.f / 1.f
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// clang-tidy complains about the same expression on both sides of "/": 1.f / 1.f

float amaxO_ref;

bool is_ck_solver = false;
float ck_fp8_solver_threshold = 0.015;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too high threshold, MIOpen fp8 algorithm fits into 0.0002 which is 75 times smaller.
Could you double-check that the values are similar, because rms error is not sensible to permutations and some significant outliers.

Comment on lines +293 to +295
class Test_Fwd_Mha_F32 : public Test_Fwd_Mha<float>
{
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, CK does not support FP32.

else
{
// ck solver
ScaleMult(atten_heads_fp32, GetF8Scaling(aMax_O), multi_head_attention_fp8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aMax_O must not be used here, it's an output parameter and will be passed as scale_o into the next iteration.

@junliume
Copy link
Collaborator

@bghimireamd do you wish to follow up with this PR?

@bghimireamd
Copy link
Contributor Author

bghimireamd commented Sep 12, 2024

Current fp8 V2 implementation is very limited. We will place close this PR since CK now implementing fp8 v3. We will have new PR for the new CK fp8 v3 integration.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants