Integrate CK's layer norm into MIOpen solver #2481

Merged: 36 commits merged into develop from integrate_CK_layernorm on Nov 15, 2023.

Commits (the diff below shows changes from 18 of the 36 commits):
3608d58
apply issue #2403
seungmanhan Oct 12, 2023
7962087
apply clang format
seungmanhan Oct 12, 2023
dc87589
Init solver structure for layernorm HIP kernel
seungmanhan Oct 12, 2023
b7a2494
apply clang format
seungmanhan Oct 13, 2023
5bea80e
Fix build error
seungmanhan Oct 13, 2023
46ea166
Redesign Solver/Solution/Invoker architecture
seungmanhan Oct 13, 2023
1f1292c
integrate CK layernorm
seungmanhan Oct 18, 2023
a69cab1
add 2D test, fix CK call
seungmanhan Oct 23, 2023
506c4ba
clang format
seungmanhan Oct 23, 2023
41fb4b1
fix driver and gtest, remove layernorm 4d
seungmanhan Oct 24, 2023
18ce140
Merge remote-tracking branch 'origin' into integrate_CK_layernorm
seungmanhan Oct 24, 2023
b6daff2
fix build error
seungmanhan Oct 25, 2023
9f27fa4
remove duplicate check
seungmanhan Oct 25, 2023
4733ad4
check unused parameter
seungmanhan Oct 25, 2023
29f14b2
add check normalized dim
seungmanhan Oct 25, 2023
7abff75
Merge branch 'develop' into integrate_CK_layernorm
seungmanhan Oct 25, 2023
414e5ec
add override
seungmanhan Oct 26, 2023
1369a9a
Merge branch 'integrate_CK_layernorm' of https://github.com/ROCmSoftw…
seungmanhan Oct 26, 2023
6f93b03
update CK kernel call host API latest version
seungmanhan Oct 27, 2023
27173d2
normalization to norm
seungmanhan Oct 27, 2023
13ddac8
static code analysis, add driver fail error, sort cmake list
seungmanhan Nov 1, 2023
b080240
fix logic error
seungmanhan Nov 1, 2023
749ee5a
add local memory check
seungmanhan Nov 1, 2023
8be5a7b
add MIOPEN_BETA_API, change Env vars
seungmanhan Nov 2, 2023
682927d
fix build error
seungmanhan Nov 2, 2023
d3e1e8c
remove MIOPEN_BACKEND_HIP
seungmanhan Nov 4, 2023
bea14c9
remove unnecessary network config
seungmanhan Nov 7, 2023
610b26f
Apply the latest layernorm ck format
seungmanhan Nov 8, 2023
0e781f1
init 4d layernorm CK
seungmanhan Nov 10, 2023
70b060b
clang format
seungmanhan Nov 10, 2023
208f3a5
fix driver, 4d layernorm ck call
seungmanhan Nov 10, 2023
8b801c8
clang format
seungmanhan Nov 10, 2023
d28a36a
add CK test
seungmanhan Nov 10, 2023
0298f0d
Merge remote-tracking branch 'origin/develop' into integrate_CK_layer…
seungmanhan Nov 13, 2023
201bb53
remove MIOPEN_BETA_API, organize problem description, change mean, rs…
seungmanhan Nov 13, 2023
9a5b754
init layernorm doc
seungmanhan Nov 13, 2023
13 changes: 3 additions & 10 deletions driver/driver.hpp
@@ -150,11 +150,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], "
"pool[fp16], lrn[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16,fp64]"
#ifdef MIOPEN_BETA_API
", layernorm[bf16, fp16, fp32]"
#endif
"\n");
"tensorop[fp16], reduce[fp16,fp64], layernorm[bfp16, fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

@@ -175,11 +171,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "bnormfp16" && arg != "rnn" && arg != "rnnfp16" && arg != "rnn_seq" &&
arg != "rnn_seqfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" &&
#ifdef MIOPEN_BETA_API
arg != "layernorm" && arg != "layernormfp16" && arg != "layernormbfp16" &&
#endif
arg != "--version")
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" &&
arg != "layernormfp16" && arg != "layernormbfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
55 changes: 32 additions & 23 deletions driver/layernorm_driver.hpp
@@ -24,7 +24,6 @@
*
*******************************************************************************/
#include <miopen/miopen.h>
#ifdef MIOPEN_BETA_API
#ifndef GUARD_MIOPEN_LAYERNORM_DRIVER_HPP
#define GUARD_MIOPEN_LAYERNORM_DRIVER_HPP

@@ -164,7 +163,7 @@ int LayerNormDriver<Tgpu, Tref>::GetandSetData()
eps = static_cast<double>(inflags.GetValueDouble("eps"));
mode = miopenLayerNormMode_t(inflags.GetValueInt("mode"));

return (0);
return 0;
}

template <typename Tgpu, typename Tref>
@@ -200,24 +199,31 @@ std::vector<int> LayerNormDriver<Tgpu, Tref>::GetInputTensorLengthsFromCmdLine()
int in_h = inflags.GetValueInt("in_h");
int in_d = inflags.GetValueInt("in_d");

if(in_h != 0)
if((in_n != 0) && (in_c != 0) && (in_d != 0) && (in_h != 0) && (in_w != 0))
{
if(in_d != 0)
{
dim_size = 5;
return std::vector<int>({in_n, in_c, in_d, in_h, in_w});
}
else
{
dim_size = 4;
return std::vector<int>({in_n, in_c, in_h, in_w});
}
dim_size = 5;
return std::vector<int>({in_n, in_c, in_d, in_h, in_w});
}
else
else if((in_n != 0) && (in_c != 0) && (in_h != 0) && (in_w != 0))
{
dim_size = 4;
return std::vector<int>({in_n, in_c, in_h, in_w});
}
else if((in_n != 0) && (in_c != 0) && (in_w != 0))
{
dim_size = 3;
return std::vector<int>({in_n, in_c, in_w});
}
else if((in_n != 0) && (in_w != 0))
{
dim_size = 2;
return std::vector<int>({in_n, in_w});
}
else
{
std::cout << "Error Input Tensor Lengths\n" << std::endl;
return std::vector<int>({0});
}
}

template <typename Tgpu, typename Tref>
@@ -230,7 +236,6 @@ int LayerNormDriver<Tgpu, Tref>::AllocateBuffersAndCopy()
size_t mean_sz = GetTensorSize(meanDesc);
size_t rstd_sz = GetTensorSize(rstdDesc);

// MIOPEN_BACKEND_HIP
uint32_t ctx = 0;

in_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, in_sz, sizeof(Tgpu)));
@@ -250,7 +255,6 @@ int LayerNormDriver<Tgpu, Tref>::AllocateBuffersAndCopy()
meanhost = std::vector<Tref>(mean_sz, static_cast<Tref>(0));
rstdhost = std::vector<Tref>(rstd_sz, static_cast<Tref>(0));

// MIOPEN_BACKEND_HIP
int status;

for(int i = 0; i < in_sz; i++)
@@ -261,22 +265,28 @@ int LayerNormDriver<Tgpu, Tref>::AllocateBuffersAndCopy()

for(int i = 0; i < weight_sz; i++)
{
weight[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
if(mode == MIOPEN_ELEMENTWISE_AFFINE)
weight[i] = static_cast<Tgpu>(1);
else
weight[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
}
status = weight_dev->ToGPU(q, weight.data());
status |= weight_dev->ToGPU(q, weight.data());

for(int i = 0; i < bias_sz; i++)
{
bias[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
if(mode == MIOPEN_ELEMENTWISE_AFFINE)
bias[i] = static_cast<Tgpu>(0);
else
bias[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
}
status = bias_dev->ToGPU(q, bias.data());
status |= bias_dev->ToGPU(q, bias.data());

status |= out_dev->ToGPU(q, out.data());
status |= mean_dev->ToGPU(q, mean.data());
status |= rstd_dev->ToGPU(q, rstd.data());

if(status != CL_SUCCESS)
printf("Error copying data to GPU\n");
if(status != 0)
std::cout << "Error copying data to GPU\n" << std::endl;

return miopenStatusSuccess;
}
@@ -426,4 +436,3 @@ int LayerNormDriver<Tgpu, Tref>::VerifyBackward()
}

#endif // GUARD_MIOPEN_SOFTMAX_DRIVER_HPP
#endif
4 changes: 0 additions & 4 deletions driver/main.cpp
@@ -43,9 +43,7 @@
#include "reduce_driver.hpp"
#include <miopen/config.h>
#include <miopen/stringutils.hpp>
#ifdef MIOPEN_BETA_API
#include "layernorm_driver.hpp"
#endif

int main(int argc, char* argv[])
{
@@ -199,7 +197,6 @@ int main(int argc, char* argv[])
{
drv = new ReduceDriver<double, double>();
}
#ifdef MIOPEN_BETA_API
else if(base_arg == "layernorm")
{
drv = new LayerNormDriver<float, float>();
@@ -212,7 +209,6 @@ int main(int argc, char* argv[])
{
drv = new LayerNormDriver<bfloat16, float>();
}
#endif
else
{
printf("Incorrect BaseArg\n");
6 changes: 2 additions & 4 deletions driver/mloLayerNormHost.hpp
@@ -23,7 +23,6 @@
* SOFTWARE.
*
*******************************************************************************/
#ifdef MIOPEN_BETA_API
#ifndef MLO_LAYERNORMHOST_H_
#define MLO_LAYERNORMHOST_H_

@@ -79,13 +78,12 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc,

for(i = 0; i < inner_size; i++)
{
Tcheck pweight = mode ? 1 : static_cast<Tcheck>(weight[i]);
Tcheck pbias = mode ? 0 : static_cast<Tcheck>(bias[i]);
Tcheck pweight = mode ? static_cast<Tcheck>(weight[i]) : 1;
Tcheck pbias = mode ? static_cast<Tcheck>(bias[i]) : 0;
outputhost[o * inner_size + i] =
(static_cast<Tcheck>(input[o * inner_size + i]) - pmean) * prstd * pweight + pbias;
}
}
return ret;
}
#endif
#endif
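
Note: the change above fixes an inverted ternary, so the host reference now applies weight and bias only when the mode calls for them. For reference, here is a minimal standalone sketch of the computation mloLayerNormForwardRunHost performs. It is an illustration, not the PR's code: it assumes MIOPEN_ELEMENTWISE_AFFINE is the zero-valued mode (as the driver initialization above implies) and computes the variance as E[x^2] - E[x]^2, while the real template is generic over Tgpu/Tcheck.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// mode == 0 (assumed MIOPEN_ELEMENTWISE_AFFINE): weight/bias act as 1/0,
// matching the corrected ternaries; any other mode applies them per element.
void LayerNormForwardHost(const std::vector<float>& input,
                          const std::vector<float>& weight,
                          const std::vector<float>& bias,
                          std::vector<float>& output,
                          std::vector<float>& mean,
                          std::vector<float>& rstd,
                          std::size_t outer_size,
                          std::size_t inner_size,
                          float eps,
                          int32_t mode)
{
    for(std::size_t o = 0; o < outer_size; o++)
    {
        // Mean and variance over the normalized (inner) dimensions.
        double sum = 0.0, sqsum = 0.0;
        for(std::size_t i = 0; i < inner_size; i++)
        {
            double v = input[o * inner_size + i];
            sum += v;
            sqsum += v * v;
        }
        double pmean = sum / inner_size;
        double pvar  = sqsum / inner_size - pmean * pmean;
        double prstd = 1.0 / std::sqrt(pvar + eps);
        mean[o] = static_cast<float>(pmean);
        rstd[o] = static_cast<float>(prstd);

        // y = (x - mean) * rstd * weight + bias
        for(std::size_t i = 0; i < inner_size; i++)
        {
            double pweight = mode ? static_cast<double>(weight[i]) : 1.0;
            double pbias   = mode ? static_cast<double>(bias[i]) : 0.0;
            output[o * inner_size + i] = static_cast<float>(
                (input[o * inner_size + i] - pmean) * prstd * pweight + pbias);
        }
    }
}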
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
@@ -131,6 +131,7 @@ set( MIOpen_Source
logger.cpp
layernorm_api.cpp
lrn_api.cpp
normalization/problem_description.cpp
op_args.cpp
operator.cpp
performance_config.cpp
@@ -166,6 +167,8 @@ set( MIOpen_Source
solver/batchnorm/forward_spatial_multiple.cpp
solver/batchnorm/forward_spatial_single.cpp
solver/batchnorm/forward_training_ck.cpp
solver/normalization/forward_layernorm.cpp
solver/normalization/forward_layernorm2d_ck.cpp
solver/conv_asm_1x1u.cpp
solver/conv_asm_1x1u_bias_activ_fused.cpp
solver/conv_asm_1x1u_stride2.cpp
5 changes: 5 additions & 0 deletions src/include/miopen/kernel_build_params.hpp
@@ -142,6 +142,11 @@ struct GcnAsm
{
static std::string Generate(const std::vector<KernelBuildParameter>& options);
};

struct HIP
{
static std::string Generate(const std::vector<KernelBuildParameter>& options);
};
} // namespace kbp

} // namespace miopen
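
The new kbp::HIP entry gives HIP-compiled kernels the same compile-option generator that kbp::OpenCL and kbp::GcnAsm already provide. A minimal sketch of a call site follows, assuming the KernelBuildParameters/GenerateFor pattern used with the existing generators; the define names are placeholders, not taken from this PR.

// Hypothetical fragment from a solver's GetSolution(); only the
// GenerateFor(kbp::HIP{}) call exercises the new struct.
auto build_params = KernelBuildParameters{
    {"MIOPEN_USE_FP16", static_cast<int>(dtype == miopenHalf)},
    {"MIOPEN_USE_FP32", static_cast<int>(dtype == miopenFloat)},
};
kernel.comp_options = build_params.GenerateFor(kbp::HIP{});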
4 changes: 1 addition & 3 deletions src/include/miopen/layernorm.hpp
@@ -24,7 +24,6 @@
*
*******************************************************************************/
#include <miopen/miopen.h>
#ifdef MIOPEN_BETA_API
#ifndef MIOPEN_LAYERNORM_HPP_
#define MIOPEN_LAYERNORM_HPP_

@@ -35,7 +34,7 @@ namespace miopen {
struct Handle;
struct TensorDescriptor;

miopenStatus_t LayerNormForward(const Handle& handle,
miopenStatus_t LayerNormForward(Handle& handle,
const TensorDescriptor& xDesc,
ConstData_t x,
const TensorDescriptor& weightDesc,
@@ -54,4 +53,3 @@ miopenStatus_t LayerNormForward(const Handle& handle,

} // namespace miopen
#endif // _MIOPEN_LAYERNORM_HPP_
#endif
57 changes: 57 additions & 0 deletions src/include/miopen/normalization/invoke_params.hpp
@@ -0,0 +1,57 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#pragma once

#include <miopen/invoke_params.hpp>
#include <miopen/tensor.hpp>

namespace miopen {
namespace normalization {

struct InvokeParams : public miopen::InvokeParams
{
InvokeParams() = default;

const TensorDescriptor* xDesc = nullptr;

ConstData_t x = nullptr;
ConstData_t weight = nullptr;
ConstData_t bias = nullptr;
Data_t y = nullptr;
Data_t mean = nullptr;
Data_t rstd = nullptr;
float epsilon = 0;
int32_t normalized_dim = 0;
miopenLayerNormMode_t mode = MIOPEN_ELEMENTWISE_AFFINE;

std::size_t GetWorkspaceSize() const { return 0; }
Data_t GetWorkspace() const { return nullptr; }
};

} // namespace normalization

} // namespace miopen
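
These fields are what a normalization solver's invoker unpacks at dispatch time. Below is a hedged sketch of the consuming side, following the invoker-factory pattern MIOpen solvers commonly use; the kernel argument order is illustrative, not this PR's.

// Sketch of an invoker factory returned by a solver's GetSolution().
// AnyInvokeParams::CastTo and Handle::Run are existing MIOpen APIs; the
// argument list passed to the kernel is assumed for illustration.
return [](const std::vector<Kernel>& kernels) {
    return [=](const Handle& handle, const AnyInvokeParams& raw_params) {
        decltype(auto) kernel = handle.Run(kernels.front());
        decltype(auto) params = raw_params.CastTo<miopen::normalization::InvokeParams>();
        kernel(params.x,
               params.weight,
               params.bias,
               params.y,
               params.mean,
               params.rstd,
               params.epsilon,
               params.normalized_dim);
    };
};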