diff --git a/driver/conv_driver.hpp b/driver/conv_driver.hpp index a80f277039..5d627ba1a5 100644 --- a/driver/conv_driver.hpp +++ b/driver/conv_driver.hpp @@ -57,7 +57,6 @@ #include #include #include -#include #include "random.hpp" #include #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e34dda2266..d4b9aba8a3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -136,7 +136,6 @@ set( MIOpen_Source performance_config.cpp pooling/problem_description.cpp pooling_api.cpp - problem_description.cpp problem.cpp ramdb.cpp readonlyramdb.cpp diff --git a/src/conv/heuristics/ai_heuristics.cpp b/src/conv/heuristics/ai_heuristics.cpp index e6ac8d6b48..58fb88e9cd 100644 --- a/src/conv/heuristics/ai_heuristics.cpp +++ b/src/conv/heuristics/ai_heuristics.cpp @@ -120,9 +120,9 @@ class Model { } virtual ~Model() = default; - virtual bool IsProblemSupported(const ProblemDescription& problem, + virtual bool IsProblemSupported(const conv::ProblemDescription& problem, const ExecutionContext& ctx) const = 0; - std::vector Forward(const ProblemDescription& problem) const + std::vector Forward(const conv::ProblemDescription& problem) const { std::vector features = ToFeatures(problem); std::vector output = model.predict({fdeep::tensor(input_shape, features)}); @@ -142,14 +142,14 @@ class Model MIOPEN_THROW(miopenStatusInternalError, "Unable to load AI model file:" + file_path); return file_path; } - virtual std::vector ToFeatures(const ProblemDescription& problem) const = 0; + virtual std::vector ToFeatures(const conv::ProblemDescription& problem) const = 0; }; class Gfx908Model final : public Model { public: Gfx908Model() : Model("gfx908") {} - bool IsProblemSupported(const ProblemDescription& problem, + bool IsProblemSupported(const conv::ProblemDescription& problem, const ExecutionContext& ctx) const override { // check if problem is of the kind TunaNet was trained to handle @@ -216,7 +216,7 @@ class Gfx908Model final : public Model } protected: - std::vector ToFeatures(const ProblemDescription& problem) const override + std::vector ToFeatures(const conv::ProblemDescription& problem) const override { const bool isFwd = problem.GetDirection() == conv::Direction::Forward; std::vector features = { @@ -259,7 +259,7 @@ class Gfx90aModel final : public Model { public: Gfx90aModel() : Model("gfx90a") {} - bool IsProblemSupported(const ProblemDescription& problem, + bool IsProblemSupported(const conv::ProblemDescription& problem, const ExecutionContext& ctx) const override { // check if problem is of the kind TunaNet was trained to handle @@ -317,7 +317,7 @@ class Gfx90aModel final : public Model } protected: - std::vector ToFeatures(const ProblemDescription& problem) const override + std::vector ToFeatures(const conv::ProblemDescription& problem) const override { const bool isFwd = problem.GetDirection() == conv::Direction::Forward; std::vector features = { @@ -356,7 +356,7 @@ std::unique_ptr GetModel(const std::string& device) return std::make_unique(); } -std::vector PredictSolver(const ProblemDescription& problem, +std::vector PredictSolver(const conv::ProblemDescription& problem, const ExecutionContext& ctx, const std::string& device) { @@ -366,7 +366,7 @@ std::vector PredictSolver(const ProblemDescription& problem, std::string est_name = ":memory:" + device; auto& db = AnyRamDb::GetCached(est_name); - auto db_res = db.FindRecord(static_cast(problem)); + auto db_res = db.FindRecord(problem); if(db_res) { MIOPEN_LOG_I2("Cached heuristic (TunaNet) result found"); @@ -415,7 +415,7 @@ std::vector PredictSolver(const ProblemDescription& problem, sol.push_back(sol_id.Value()); any_sol.push_back(sol_id.Value()); } - db.StoreRecord(static_cast(problem), any_sol); + db.StoreRecord(problem, any_sol); if(miopen::IsLogging(LoggingLevel::Info2)) { std::stringstream ss; diff --git a/src/conv/invokers/impl_gemm.cpp b/src/conv/invokers/impl_gemm.cpp index 649e153491..a474730f9d 100644 --- a/src/conv/invokers/impl_gemm.cpp +++ b/src/conv/invokers/impl_gemm.cpp @@ -10,9 +10,9 @@ namespace miopen { namespace conv { -InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription& problem) +InvokerFactory MakeImplGemmDataInvokerFactory(const ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { return [](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { @@ -24,7 +24,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription& } else { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) MIOPEN_THROW("MakeImplGemmDataInvokerFactory shouldn't be used for WrW invokers."); const auto& conv = problem.GetConv(); diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 2416217ea2..03de29665b 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -86,8 +86,7 @@ static float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, return elapsed; } -InvokerFactory -MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& problem) +InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescription& problem) { return [problem](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { @@ -112,10 +111,8 @@ MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& pr }; } -template <> -InvokerFactory -MakeImplGemmDynamicBackwardDataInvokerFactory(const miopen::ProblemDescription& problem, - const int& cfg) +InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem, + const int cfg) { int hi = problem.GetOutHeight_(); int wi = problem.GetOutWidth_(); @@ -249,10 +246,9 @@ MakeImplGemmDynamicBackwardDataInvokerFactory(const miopen::ProblemDescript }; } -template <> InvokerFactory -MakeImplGemmDynamicBackwardDataInvokerFactory( - const miopen::ProblemDescription& problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg) +MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem, + const solver::TunableImplicitGemmGTCDynamic_t& cfg) { int hi = problem.GetOutHeight_(); int wi = problem.GetOutWidth_(); @@ -439,8 +435,8 @@ MakeImplGemmDynamicBackwardDataInvokerFactory& in_dims, std::vector& in_strides, std::vector& weights_dims, @@ -156,7 +156,7 @@ void InsertGToDimsStrides(const std::string& layout, strides.insert(strides.begin() + index, strides[index] * dims[index + 1]); } -void ComputeMlirDimsStrides(const conv::ProblemDescription& problem, +void ComputeMlirDimsStrides(const ProblemDescription& problem, std::vector& in_dims, std::vector& in_strides, std::vector& weights_dims, @@ -288,9 +288,9 @@ void SetMlirConvArgsPtr( #endif // MIOPEN_BACKEND_HIP } // Anonymous namespace -InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& problem) +InvokerFactory MakeMlirFwdInvokerFactory(const ProblemDescription& problem) { - assert((problem.direction.IsForward())); + assert((problem.IsDirectionForward())); std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; @@ -354,9 +354,9 @@ InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& probl }; } -InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& problem) +InvokerFactory MakeMlirBwdInvokerFactory(const ProblemDescription& problem) { - assert(problem.direction.IsBackwardData()); + assert(problem.IsDirectionBackwardData()); std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; @@ -409,10 +409,9 @@ InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& probl }; } -InvokerFactory MakeMlirWrWInvokerFactory(const miopen::ProblemDescription& problem, - size_t workspace_req) +InvokerFactory MakeMlirWrWInvokerFactory(const ProblemDescription& problem, size_t workspace_req) { - assert((problem.direction.IsBackwardWrW())); + assert((problem.IsDirectionBackwardWrW())); std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; diff --git a/src/conv/solver_finders.cpp b/src/conv/solver_finders.cpp index 809f333bd8..c998fb75a4 100644 --- a/src/conv/solver_finders.cpp +++ b/src/conv/solver_finders.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include namespace miopen { @@ -42,6 +42,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_WINOGRAD) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_FFT) +namespace conv { +namespace { + class DirectSolverFinder : public SolversFinderMixin { protected: @@ -178,6 +181,8 @@ class WinogradSolverFinder : public SolversFinderMixin>& GetConvSolverFinders() { static const auto finders = []() { @@ -193,14 +198,16 @@ const std::vector>& GetConvSolverFinders() return finders; } +} // namespace conv + /// Register invoker only for the best solution within algorithm. /// Add all solutions to the find-db record. -void EvaluateInvokers(Handle& handle, - const std::vector& solutions, - const AlgorithmName& algorithm_name, - const NetworkConfig& network_config, - const AnyInvokeParams& invoke_ctx, - DbRecord& record) +static void EvaluateInvokers(Handle& handle, + const std::vector& solutions, + const AlgorithmName& algorithm_name, + const NetworkConfig& network_config, + const AnyInvokeParams& invoke_ctx, + DbRecord& record) { const char* const arch = miopen::GetStringEnv(MIOPEN_DEVICE_ARCH{}); if(arch != nullptr && strlen(arch) > 0) @@ -309,6 +316,8 @@ void FindCore(const AnyInvokeParams& invoke_ctx, EvaluateInvokers(handle, ss.second, ss.first, network_config, invoke_ctx, record); } +namespace conv { + bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo) { switch(algo) @@ -328,4 +337,5 @@ bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo) } // clang-format on } +} // namespace conv } // namespace miopen diff --git a/src/convolution.cpp b/src/convolution.cpp index ac7c28fdc4..d0f9a64ffb 100644 --- a/src/convolution.cpp +++ b/src/convolution.cpp @@ -76,7 +76,7 @@ std::size_t GetMaxWorkSpaceSize(const std::vector= 16 && problem.GetOutChannels_() % 2 == 0)) return false; - return solver::ConvBinWinograd3x3U{}.IsApplicable(ctx, problem); + return solver::conv::ConvBinWinograd3x3U{}.IsApplicable(ctx, problem); } std::size_t ConvolutionDescriptor::GetWorkSpaceSize(ExecutionContext ctx, diff --git a/src/convolution_api.cpp b/src/convolution_api.cpp index 196ad5bcec..b0c751a46c 100644 --- a/src/convolution_api.cpp +++ b/src/convolution_api.cpp @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/execution_context.cpp b/src/execution_context.cpp index a9cd8806b2..fb64a10c92 100644 --- a/src/execution_context.cpp +++ b/src/execution_context.cpp @@ -33,7 +33,6 @@ #if MIOPEN_BACKEND_OPENCL #include #endif -#include #include #include diff --git a/src/fusion/problem_description.cpp b/src/fusion/problem_description.cpp index 9c9abdc79d..2935f1adf4 100644 --- a/src/fusion/problem_description.cpp +++ b/src/fusion/problem_description.cpp @@ -30,7 +30,7 @@ namespace miopen { -miopen::ProblemDescription FusionDescription::GetConvProblem(conv::Direction dir, int bias) const +conv::ProblemDescription FusionDescription::GetConvProblem(conv::Direction dir, int bias) const { const auto idx = [&]() { switch(dir) diff --git a/src/include/miopen/any_solver.hpp b/src/include/miopen/any_solver.hpp index b2f177b6ea..7edbdf7f03 100644 --- a/src/include/miopen/any_solver.hpp +++ b/src/include/miopen/any_solver.hpp @@ -46,7 +46,8 @@ struct AnySolver AnySolver() : ptr_value(nullptr){}; template AnySolver(U src) : ptr_value(new AnySolver_tmpl(std::forward(src))){}; - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const + bool IsApplicable(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) const { assert(ptr_value != nullptr); return ptr_value->IsApplicable(ctx, problem); @@ -57,14 +58,14 @@ struct AnySolver return ptr_value->IsTunable(); }; bool TestPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const std::string& params) const { assert(ptr_value != nullptr); return ptr_value->TestPerfCfgParams(ctx, problem, params); }; std::vector GetAllSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem) const + const miopen::conv::ProblemDescription& problem) const { assert(ptr_value != nullptr); return ptr_value->GetAllSolutions(ctx, problem); @@ -74,7 +75,7 @@ struct AnySolver assert(ptr_value != nullptr); return ptr_value->IsDynamic(); }; - float GetWti(const ExecutionContext& ctx, const ProblemDescription& problem) const + float GetWti(const ExecutionContext& ctx, const miopen::conv::ProblemDescription& problem) const { assert(ptr_value != nullptr); return ptr_value->GetWti(ctx, problem); @@ -86,7 +87,7 @@ struct AnySolver }; bool IsEmpty() const { return ptr_value == nullptr; }; ConvSolution FindSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg = "") const @@ -95,7 +96,7 @@ struct AnySolver return ptr_value->FindSolution(ctx, problem, db, invoke_ctx, perf_cfg); }; std::string GetPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db) const { assert(ptr_value != nullptr); @@ -107,7 +108,8 @@ struct AnySolver return ptr_value->GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, const ProblemDescription& problem) const + size_t GetWorkspaceSize(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) const { assert(ptr_value != nullptr); return ptr_value->GetWorkspaceSize(ctx, problem); @@ -126,29 +128,30 @@ struct AnySolver virtual ~AnySolver_base(){}; virtual bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const = 0; - virtual bool IsTunable() const = 0; + const miopen::conv::ProblemDescription& problem) const = 0; + virtual bool IsTunable() const = 0; virtual bool TestPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, - const std::string& params) const = 0; + const miopen::conv::ProblemDescription& problem, + const std::string& params) const = 0; virtual std::vector - GetAllSolutions(const ExecutionContext& ctx, const ProblemDescription& problem) const = 0; - virtual bool IsDynamic() const = 0; + GetAllSolutions(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) const = 0; + virtual bool IsDynamic() const = 0; virtual float GetWti(const ExecutionContext& ctx, - const ProblemDescription& problem) const = 0; - virtual const std::type_info& Type() const = 0; - virtual std::string GetSolverDbId() const = 0; + const miopen::conv::ProblemDescription& problem) const = 0; + virtual const std::type_info& Type() const = 0; + virtual std::string GetSolverDbId() const = 0; virtual ConvSolution FindSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db, const miopen::AnyInvokeParams& invoke_ctx, - const std::string& perf_cfg) const = 0; + const std::string& perf_cfg) const = 0; virtual std::string GetPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, - PerformanceDb& db) const = 0; + const miopen::conv::ProblemDescription& problem, + PerformanceDb& db) const = 0; virtual size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const = 0; - virtual bool MayNeedWorkspace() const = 0; + const miopen::conv::ProblemDescription& problem) const = 0; + virtual bool MayNeedWorkspace() const = 0; }; // templated derived class @@ -161,7 +164,7 @@ struct AnySolver static constexpr auto Test(U*) -> typename std::is_class().GetDefaultPerformanceConfig( std::declval(), - std::declval()))>::type; + std::declval()))>::type; template static constexpr std::false_type Test(...); @@ -173,11 +176,11 @@ struct AnySolver struct LegacySolver { template - static constexpr auto Test(U*) -> - typename std::is_same().GetDefaultPerformanceConfig( - std::declval(), - std::declval()))>::type; + static constexpr auto Test(U*) -> typename std::is_same< + LegacyPerformanceConfig, + decltype(std::declval().GetDefaultPerformanceConfig( + std::declval(), + std::declval()))>::type; template static constexpr std::false_type Test(...); @@ -187,13 +190,13 @@ struct AnySolver }; bool TestPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const std::string& params, std::true_type) const { using PerformanceConfig = decltype(value.GetDefaultPerformanceConfig( std::declval(), - std::declval())); + std::declval())); PerformanceConfig config{}; bool success = config.Deserialize(params); @@ -209,7 +212,7 @@ struct AnySolver return success; } bool TestPerfCfgParams(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const std::string&, std::false_type) const { @@ -217,7 +220,7 @@ struct AnySolver } bool TestPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const std::string& params) const override { return TestPerfCfgParams( @@ -226,7 +229,7 @@ struct AnySolver // tunable legacy solver std::vector GetAllSolutions(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, std::true_type, std::true_type) const { @@ -235,7 +238,7 @@ struct AnySolver // tunable solver, not legacy std::vector GetAllSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, std::true_type, std::false_type) const { @@ -244,7 +247,7 @@ struct AnySolver // non tunable solver std::vector GetAllSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, std::false_type, std::true_type) const { @@ -253,7 +256,7 @@ struct AnySolver return solutions; } std::vector GetAllSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, std::false_type, std::false_type) const { @@ -262,8 +265,9 @@ struct AnySolver return solutions; } - std::vector GetAllSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem) const override + std::vector + GetAllSolutions(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) const override { return GetAllSolutions(ctx, problem, @@ -274,19 +278,20 @@ struct AnySolver AnySolver_tmpl(T obj) : value(std::move(obj)){}; bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override + const miopen::conv::ProblemDescription& problem) const override { return value.IsApplicable(ctx, problem); } bool IsTunable() const override { return TunableSolver::Is; } bool IsDynamic() const override { return value.IsDynamic(); } - float GetWti(const ExecutionContext& ctx, const ProblemDescription& problem) const override + float GetWti(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) const override { return value.GetWti(ctx, problem); } ConvSolution FindSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db, const miopen::AnyInvokeParams& invoke_ctx, const std::string& perf_cfg) const override @@ -295,7 +300,7 @@ struct AnySolver }; std::string GetPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db, std::true_type) const { @@ -327,7 +332,7 @@ struct AnySolver return config.ToString(); } std::string GetPerfCfgParams(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceDb&, std::false_type) const { @@ -336,7 +341,7 @@ struct AnySolver } std::string GetPerfCfgParams(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, PerformanceDb& db) const override { return GetPerfCfgParams( @@ -344,7 +349,7 @@ struct AnySolver } size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override + const miopen::conv::ProblemDescription& problem) const override { return value.GetWorkspaceSize(ctx, problem); } diff --git a/src/include/miopen/conv/compiled_in_parameters.hpp b/src/include/miopen/conv/compiled_in_parameters.hpp index 00543bce20..f3d8cf2763 100644 --- a/src/include/miopen/conv/compiled_in_parameters.hpp +++ b/src/include/miopen/conv/compiled_in_parameters.hpp @@ -27,8 +27,7 @@ #pragma once #include -#include -#include +#include #include @@ -38,7 +37,7 @@ namespace miopen { * arguments. */ inline void GetCompiledInParameters(const ExecutionContext& ctx, - const ProblemDescription& problem, + const conv::ProblemDescription& problem, int* const N, int* const C, int* const H, @@ -56,7 +55,7 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, } inline void GetCompiledInParameters(const ExecutionContext& ctx, - const ProblemDescription& problem, + const conv::ProblemDescription& problem, int* const N, int* const C, int* const H, @@ -73,7 +72,7 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, } inline void GetCompiledInParameters(const ExecutionContext& ctx, - const ProblemDescription& problem, + const conv::ProblemDescription& problem, int* const N, int* const C, int* const H, @@ -91,8 +90,8 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, assert(filter_size_H && filter_size_W && pad_H && pad_W); *filter_size_H = problem.GetWeightsHeight_(); *filter_size_W = problem.GetWeightsWidth_(); - *pad_H = problem.direction.IsForward() ? problem.GetPadH() : problem.GetBackwardPadH(); - *pad_W = problem.direction.IsForward() ? problem.GetPadW() : problem.GetBackwardPadW(); + *pad_H = problem.IsDirectionForward() ? problem.GetPadH() : problem.GetBackwardPadH(); + *pad_W = problem.IsDirectionForward() ? problem.GetPadW() : problem.GetBackwardPadW(); } } // namespace miopen diff --git a/src/include/miopen/conv/heuristics/ai_heuristics.hpp b/src/include/miopen/conv/heuristics/ai_heuristics.hpp index 7da9497070..2f0a4770d0 100644 --- a/src/include/miopen/conv/heuristics/ai_heuristics.hpp +++ b/src/include/miopen/conv/heuristics/ai_heuristics.hpp @@ -70,7 +70,7 @@ struct Metadata size_t EncodeLayout(const std::string& layout) const; }; class Model; -std::vector PredictSolver(const ProblemDescription& problem, +std::vector PredictSolver(const conv::ProblemDescription& problem, const ExecutionContext& ctx, const std::string& device); } // namespace immed_mode diff --git a/src/include/miopen/conv/invokers/impl_gemm.hpp b/src/include/miopen/conv/invokers/impl_gemm.hpp index 23c5afddc5..a199f5337a 100644 --- a/src/include/miopen/conv/invokers/impl_gemm.hpp +++ b/src/include/miopen/conv/invokers/impl_gemm.hpp @@ -28,14 +28,14 @@ #include #include -#include +#include #include namespace miopen { namespace conv { -InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription& problem); +InvokerFactory MakeImplGemmDataInvokerFactory(const ProblemDescription& problem); } // namespace conv } // namespace miopen diff --git a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp index b1a0e426a0..8d98c8db65 100644 --- a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp +++ b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp @@ -185,7 +185,7 @@ ComputeDynamicIGemmForwardKernelArgs( template static inline InvokerFactory -MakeImplGemmDynamicForwardInvokerFactory(const miopen::ProblemDescription& problem, const T& cfg) +MakeImplGemmDynamicForwardInvokerFactory(const ProblemDescription& problem, const T& cfg) { auto opArgs = ComputeDynamicIGemmForwardKernelArgs(problem, cfg); return [opArgs](const std::vector& kernels) mutable { @@ -203,34 +203,26 @@ MakeImplGemmDynamicForwardInvokerFactory(const miopen::ProblemDescription& probl }; } -InvokerFactory -MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& problem); +InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescription& problem); -template -InvokerFactory -MakeImplGemmDynamicBackwardDataInvokerFactory(const miopen::ProblemDescription& problem, - const T& cfg); +InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem, + int cfg); -template <> InvokerFactory -MakeImplGemmDynamicBackwardDataInvokerFactory(const miopen::ProblemDescription& problem, - const int& cfg); - -template <> -InvokerFactory -MakeImplGemmDynamicBackwardDataInvokerFactory( - const miopen::ProblemDescription& problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg); +MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem, + const solver::TunableImplicitGemmGTCDynamic_t& cfg); InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( const ExecutionContext& ctx, - const miopen::ProblemDescription& problem, - const solver::PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config); + const ProblemDescription& problem, + const solver::conv::PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config); InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( const ExecutionContext& ctx, - const miopen::ProblemDescription& problem, - const solver::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config); + const ProblemDescription& problem, + const solver::conv::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config); InvokerFactory MakeImplGemmDynamicForwardDlopsNCHWCInvokerFactory( - const miopen::ProblemDescription& problem, - const solver::PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config); + const ProblemDescription& problem, + const solver::conv::PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config); + } // namespace conv } // namespace miopen diff --git a/src/include/miopen/conv/invokers/mlir_impl_gemm.hpp b/src/include/miopen/conv/invokers/mlir_impl_gemm.hpp index db467daad1..64e3d25dfc 100644 --- a/src/include/miopen/conv/invokers/mlir_impl_gemm.hpp +++ b/src/include/miopen/conv/invokers/mlir_impl_gemm.hpp @@ -27,15 +27,14 @@ #pragma once #include -#include +#include namespace miopen { namespace conv { -InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& problem); -InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& problem); -InvokerFactory MakeMlirWrWInvokerFactory(const miopen::ProblemDescription& problem, - size_t workspace_req); +InvokerFactory MakeMlirFwdInvokerFactory(const ProblemDescription& problem); +InvokerFactory MakeMlirBwdInvokerFactory(const ProblemDescription& problem); +InvokerFactory MakeMlirWrWInvokerFactory(const ProblemDescription& problem, size_t workspace_req); } // namespace conv } // namespace miopen diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index 2a6083d684..0b6751f333 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -288,6 +288,9 @@ struct ProblemDescription : ProblemDescriptionBase const ConvolutionDescriptor& GetConv() const { return conv; } Direction GetDirection() const { return direction; } + bool IsDirectionForward() const { return direction == conv::Direction::Forward; } + bool IsDirectionBackwardData() const { return direction == conv::Direction::BackwardData; } + bool IsDirectionBackwardWrW() const { return direction == conv::Direction::BackwardWeights; } std::string GetDirectionStr() const; int GetBias() const { return bias; } diff --git a/src/include/miopen/conv/solver_finders.hpp b/src/include/miopen/conv/solver_finders.hpp index beba3d8a4c..05d7c13b62 100644 --- a/src/include/miopen/conv/solver_finders.hpp +++ b/src/include/miopen/conv/solver_finders.hpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include @@ -138,8 +137,12 @@ class SolversFinderMixin : public ISolversFinder const FindParameters& parameters) const = 0; }; +namespace conv { + const std::vector>& GetConvSolverFinders(); +} // namespace conv + void FindCore(const AnyInvokeParams& invoke_ctx, DbRecord& record, const ExecutionContext& ctx, @@ -147,6 +150,8 @@ void FindCore(const AnyInvokeParams& invoke_ctx, const PrimitiveFindParameters& parameters, const std::vector>& finders); +namespace conv { + bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo); struct ConvFindParameters : PrimitiveFindParameters @@ -155,4 +160,5 @@ struct ConvFindParameters : PrimitiveFindParameters ConvFindParameters(bool use_winograd_only_) : use_winograd_only(use_winograd_only_) {} }; +} // namespace conv } // namespace miopen diff --git a/src/include/miopen/convolution.hpp b/src/include/miopen/convolution.hpp index 35c494eab2..28b68a35b0 100644 --- a/src/include/miopen/convolution.hpp +++ b/src/include/miopen/convolution.hpp @@ -63,26 +63,9 @@ namespace solver { struct ConvSolution; } // namespace solver -struct AnyInvokeParams; -struct ExecutionContext; struct ExecutionContext; struct Handle; struct TensorDescriptor; -struct ProblemDescription; -struct ConvFwdTensors; -struct ConvWrwTensors; - -using ExtraKernelArgs = std::tuple; - -struct ConvFwdTensors; -struct ConvWrwTensors; struct ConvolutionAttribute { @@ -210,7 +193,7 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor miopenDataType_t yType = miopenFloat) const; bool IsWinograd3x3SupportedAndFast(const miopen::ExecutionContext& ctx, - const ProblemDescription& problem) const; + const conv::ProblemDescription& problem) const; std::size_t GetWorkSpaceSize(ExecutionContext ctx, const conv::ProblemDescription& problem) const; @@ -229,31 +212,6 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor std::size_t workSpaceSize, bool exhaustiveSearch) const; - std::vector - FindWinogradSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem, - const AnyInvokeParams& invoke_ctx) const; - - std::vector - FindWinogradSolutions(const ExecutionContext& ctx, const AnyInvokeParams& invoke_ctx) const; - - std::vector - FindDataGemmSolutions(const ExecutionContext& ctx, const AnyInvokeParams& invoke_ctx) const; - - std::vector - FindDataImplicitGemmSolutions(Handle& handle, - const TensorDescriptor& xDesc, - const TensorDescriptor& wDesc, - const TensorDescriptor& yDesc, - bool exhaustiveSearch, - bool isForward, - const AnyInvokeParams& invoke_ctx) const; - - std::vector - FindFftSolutions(const ExecutionContext& ctx, - const ProblemDescription& problem, - const AnyInvokeParams& invoke_ctx) const; - void ConvolutionForward(Handle& handle, const void* alpha, const TensorDescriptor& xDesc, @@ -267,10 +225,10 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor Data_t workSpace, std::size_t workSpaceSize) const; - std::size_t GetSolutionCount(const ExecutionContext& exec_ctx, + std::size_t GetSolutionCount(const ExecutionContext& ctx, const conv::ProblemDescription& problem) const; - std::vector GetSolutions(const ExecutionContext& exec_ctx, + std::vector GetSolutions(const ExecutionContext& ctx, const conv::ProblemDescription& problem, size_t maxSolutionCount, bool* fallbackPathTaken) const; @@ -400,7 +358,7 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor const conv::ProblemDescription& problem, size_t maxSolutionCount) const; - std::size_t GetSolutionCountFallback(const ExecutionContext& exec_ctx, + std::size_t GetSolutionCountFallback(const ExecutionContext& ctx, const conv::ProblemDescription& problem) const; friend void to_json(nlohmann::json& json, const ConvolutionDescriptor& conv); @@ -430,6 +388,7 @@ void DumpTensorToFileFromDevice(const miopen::Handle& handle, const std::string& filename); } // namespace miopen + MIOPEN_DEFINE_OBJECT(miopenConvolutionDescriptor, miopen::ConvolutionDescriptor); #endif // GUARD_MIOPEN_CONVOLUTION_HPP_ diff --git a/src/include/miopen/driver_arguments.hpp b/src/include/miopen/driver_arguments.hpp index b4692d040f..da4064b7f0 100644 --- a/src/include/miopen/driver_arguments.hpp +++ b/src/include/miopen/driver_arguments.hpp @@ -33,7 +33,6 @@ #include #include #include -#include #include #include diff --git a/src/include/miopen/execution_context.hpp b/src/include/miopen/execution_context.hpp index d2195d6061..a308338eb0 100644 --- a/src/include/miopen/execution_context.hpp +++ b/src/include/miopen/execution_context.hpp @@ -65,10 +65,6 @@ class rocm_meta_version namespace miopen { -namespace conv { -struct ProblemDescription; -} // namespace conv - struct ExecutionContext { // Solution-specific diff --git a/src/include/miopen/fusion.hpp b/src/include/miopen/fusion.hpp index 7baa7b9ee4..f00e5d803a 100644 --- a/src/include/miopen/fusion.hpp +++ b/src/include/miopen/fusion.hpp @@ -226,7 +226,7 @@ struct ConvForwardOpDescriptor : FusionOpDescriptor std::string conv_compiler_options; private: - ProblemDescription GetConvProblem(); + conv::ProblemDescription GetConvProblem(); }; namespace fusion { diff --git a/src/include/miopen/fusion/context.hpp b/src/include/miopen/fusion/context.hpp index 43190e6807..65f7cd7b8c 100644 --- a/src/include/miopen/fusion/context.hpp +++ b/src/include/miopen/fusion/context.hpp @@ -27,7 +27,7 @@ #pragma once #include -#include +#include namespace miopen { @@ -37,7 +37,7 @@ struct FusionContext : ExecutionContext { explicit FusionContext(Handle& handle) : ExecutionContext(&handle) {} - ExecutionContext GetConvContext(const ProblemDescription& conv_problem) const + ExecutionContext GetConvContext(const conv::ProblemDescription& conv_problem) const { auto ctx = ExecutionContext{*this}; conv_problem.SetupFloats(ctx); diff --git a/src/include/miopen/fusion/problem_description.hpp b/src/include/miopen/fusion/problem_description.hpp index e6256661d8..cd9c100c04 100644 --- a/src/include/miopen/fusion/problem_description.hpp +++ b/src/include/miopen/fusion/problem_description.hpp @@ -28,6 +28,7 @@ #include #include +#include #include namespace miopen { @@ -73,12 +74,12 @@ struct FusionDescription : ProblemDescriptionBase static void Visit(Self&& self, F f) { auto conv_prob = self.GetConvProblem(conv::Direction::Forward); - ProblemDescription::Visit(conv_prob, f); + conv::ProblemDescription::Visit(conv_prob, f); } #endif // This and the following method should be moved to the Ops once the return type can be unified - miopen::ProblemDescription GetConvProblem(size_t idx, conv::Direction dir, int bias = 0) const + conv::ProblemDescription GetConvProblem(size_t idx, conv::Direction dir, int bias = 0) const { const auto& conv_op = dynamic_cast(*fusion_plan_desc->op_map[idx]); @@ -86,12 +87,12 @@ struct FusionDescription : ProblemDescriptionBase { TensorDescriptor out_desc; conv_op.GetOutputDesc(out_desc); - return miopen::conv::ProblemDescription{conv_op.input_desc, - conv_op.filter_desc, - out_desc, - conv_op.base_desc /* conv desc */, - dir, - bias}; + return conv::ProblemDescription{conv_op.input_desc, + conv_op.filter_desc, + out_desc, + conv_op.base_desc /* conv desc */, + dir, + bias}; } else { @@ -99,7 +100,7 @@ struct FusionDescription : ProblemDescriptionBase } } - miopen::ProblemDescription GetConvProblem(conv::Direction dir, int bias = 0) const; + conv::ProblemDescription GetConvProblem(conv::Direction dir, int bias = 0) const; miopen::batchnorm::ProblemDescription GetBnProblem(size_t idx, miopen::batchnorm::Direction dir) const diff --git a/src/include/miopen/fusion/solvers.hpp b/src/include/miopen/fusion/solvers.hpp index dd8f2df494..0d9b84c918 100644 --- a/src/include/miopen/fusion/solvers.hpp +++ b/src/include/miopen/fusion/solvers.hpp @@ -124,7 +124,7 @@ struct FusionTunableSolver : FusionTunableSolverBase } }; -struct PerformanceConfigConvBiasActivAsm1x1U : PerformanceConfigConvAsm1x1U +struct PerformanceConfigConvBiasActivAsm1x1U : conv::PerformanceConfigConvAsm1x1U { PerformanceConfigConvBiasActivAsm1x1U(const bool spare) : PerformanceConfigConvAsm1x1U(spare) {} PerformanceConfigConvBiasActivAsm1x1U() @@ -217,9 +217,9 @@ struct PerformanceConfigConvCKIgemmFwdBiasActivFused private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvCKIgemmFwdBiasActivFused final @@ -250,7 +250,7 @@ struct ConvCKIgemmFwdBiasActivFused final private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct ConvBinWinogradRxSFused final : FusionSolverBase diff --git a/src/include/miopen/fusion/utils.hpp b/src/include/miopen/fusion/utils.hpp index 5669de990f..26e901d360 100644 --- a/src/include/miopen/fusion/utils.hpp +++ b/src/include/miopen/fusion/utils.hpp @@ -88,7 +88,7 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes return false; if(!conv_problem.IsLayoutDefault()) return false; - if(!conv_problem.direction.IsForward()) + if(!conv_problem.IsDirectionForward()) return false; const auto target = conv_ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/include/miopen/mlo_internal.hpp b/src/include/miopen/mlo_internal.hpp index f5e7d2fb83..017efd37a1 100644 --- a/src/include/miopen/mlo_internal.hpp +++ b/src/include/miopen/mlo_internal.hpp @@ -71,6 +71,7 @@ POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include #include #if MIOPEN_BACKEND_OPENCL @@ -180,74 +181,74 @@ auto mloConstruct(T& x) -> decltype(x.mloConstruct(), void()) std::vector FindAllGemmSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector> AllGemmWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> AllDirectForwardBackwardDataWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> FindAllImplicitGemmWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> FindAllWinogradWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> FindWinogradWrWWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> FindImplicitGemmWrWWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> AllDirectBwdWrW2DWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector> AllFFTForwardBackwardDataWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::vector FindAllDirectSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindAllWinogradSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); std::vector FindAllFFTSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx); struct mlo_construct_base diff --git a/src/include/miopen/problem_description.hpp b/src/include/miopen/problem_description.hpp index 76ae792895..8e1167baad 100644 --- a/src/include/miopen/problem_description.hpp +++ b/src/include/miopen/problem_description.hpp @@ -28,13 +28,8 @@ #define GUARD_PROBLEM_DESCRIPTION_HPP_ #include -#include #include -#if MIOPEN_ENABLE_SQLITE -#include -#endif -#include #include #include @@ -62,32 +57,16 @@ SetDescFromMLDesc(int spatial_dims, TTo& to, const TensorDescriptor& tensor, con return tensor.GetElementSpace(); } -struct ConvolutionDescriptor; - -// Todo: change all uses in convolution to conv::ProblemDescription and remove this +#if FIN_OLD_PROBLEM_DESCRIPTION_COMPAT struct ProblemDescription : conv::ProblemDescription { - struct Direction - { - public: - bool IsForward() const { return v == conv::Direction::Forward; } - bool IsBackwardData() const { return v == conv::Direction::BackwardData; } - bool IsBackwardWrW() const { return v == conv::Direction::BackwardWeights; } - - std::string GetStr() const { return IsForward() ? "F" : IsBackwardData() ? "B" : "W"; } - - Direction() = default; - Direction(conv::Direction value) : v(value) {} - - private: - conv::Direction v = conv::Direction::Forward; - } direction; - ProblemDescription() = default; - ProblemDescription(conv::ProblemDescription desc); + ProblemDescription(conv::ProblemDescription desc) : conv::ProblemDescription(std::move(desc)) + { + conv_problem.p = this; + } -#if FIN_OLD_PROBLEM_DESCRIPTION_COMPAT struct { void SetupFloats(ExecutionContext& ctx) const { p->SetupFloats(ctx); } @@ -96,8 +75,8 @@ struct ProblemDescription : conv::ProblemDescription const conv::ProblemDescription* p = nullptr; friend struct ProblemDescription; } conv_problem; -#endif }; +#endif // For mlo_construct_base // TODO remove this @@ -156,9 +135,9 @@ struct ProblemDescriptionCompatTemporary int GetOutChannelStride() const { return out_channel_stride; } int GetOutBatchStride() const { return out_batch_stride; } - ProblemDescriptionCompatTemporary(miopen::conv::Direction dir) : direction(dir) {} + ProblemDescriptionCompatTemporary(conv::Direction dir) : direction(dir) {} - ProblemDescription::Direction direction; + bool IsDirectionForward() const { return direction == conv::Direction::Forward; } /* * set top tensor @@ -269,6 +248,9 @@ struct ProblemDescriptionCompatTemporary batch_sz = batch; n_inputs = channels; } + +private: + conv::Direction direction; }; struct UnifiedDescriptionConv2d @@ -297,28 +279,28 @@ struct UnifiedDescriptionConv2d // strd := U/V -u/v convolution stride (output stride) kernel_stride // idil := input dilation (n/a except transposed convolutions) ? - UnifiedDescriptionConv2d(const ProblemDescription& problem) + UnifiedDescriptionConv2d(const conv::ProblemDescription& problem) { if(!problem.Is2d()) MIOPEN_THROW(miopenStatusInternalError, "UnifiedDescriptionConv2d supports only 2D"); const auto n_inputs_per_group = problem.GetInChannels_() / problem.GetGroupCount(); const auto n_outputs_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) { R = problem.GetWeightsHeight_(); S = problem.GetWeightsWidth_(); - U = problem.direction.IsForward() ? problem.GetKernelStrideH() : 1; - V = problem.direction.IsForward() ? problem.GetKernelStrideW() : 1; + U = problem.IsDirectionForward() ? problem.GetKernelStrideH() : 1; + V = problem.IsDirectionForward() ? problem.GetKernelStrideW() : 1; C = n_inputs_per_group; // Bwd: C and K is reversed in ProblemDescription. K = n_outputs_per_group; // Ditto. out_h = problem.GetOutHeight_(); // Bwd: height/width is reversed in ProblemDescription. out_w = problem.GetOutWidth_(); // Ditto. N = problem.GetBatchSize_(); - pad_h = problem.direction.IsForward() ? problem.GetPadH() : problem.GetBackwardPadH(); - pad_w = problem.direction.IsForward() ? problem.GetPadW() : problem.GetBackwardPadW(); - input_stride_h = problem.direction.IsForward() ? 1 : problem.GetKernelStrideH(); - input_stride_w = problem.direction.IsForward() ? 1 : problem.GetKernelStrideW(); + pad_h = problem.IsDirectionForward() ? problem.GetPadH() : problem.GetBackwardPadH(); + pad_w = problem.IsDirectionForward() ? problem.GetPadW() : problem.GetBackwardPadW(); + input_stride_h = problem.IsDirectionForward() ? 1 : problem.GetKernelStrideH(); + input_stride_w = problem.IsDirectionForward() ? 1 : problem.GetKernelStrideW(); filter_stride_h = problem.GetDilationH(); filter_stride_w = problem.GetDilationW(); } diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp index ce40d6f081..60be00597e 100644 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -82,7 +82,7 @@ struct SolverBase /// overriden to keep the name to avoid DB corruption. virtual const std::string& SolverDbId() const = 0; - /// In some instances ( particularly fusions) the fused solver might like to + /// In some instances (particularly fusions) the fused solver might like to /// fallback to the non-fused variant for performance parameters, this information /// is returned via AltSolverDbId virtual const std::string& AltSolverDbId() const @@ -180,11 +180,13 @@ struct NonTunableSolverBase : SolverMixin virtual ConvSolution GetSolution(const Context&, const Problem&) const = 0; }; +namespace conv { + /// Typedef for convolution solvers -using ConvSolver = NonTunableSolverBase; +using ConvSolver = NonTunableSolverBase; /// Base class for tunable solvers -struct ConvTunableSolverBase : SolverMixin +struct ConvTunableSolverBase : SolverMixin { /// Initializes performance config to the default values. /// The function may involve some heuristic to guess the best solution @@ -196,13 +198,13 @@ struct ConvTunableSolverBase : SolverMixin /// function in the derived class. Function declarations that differ /// only by its return type cannot be overloaded. virtual boost::any GetDefaultPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, int) const = 0; /// Should return false if performance config is wrong for a problem. /// Main use is validation of values read from the perf db. virtual bool IsValidPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerfConfig& config) const = 0; /// Search @@ -211,13 +213,13 @@ struct ConvTunableSolverBase : SolverMixin /// function in the derived class. Function declarations that differ /// only by its return type cannot be overloaded. virtual boost::any Search(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const AnyInvokeParams& invoke_ctx, int) const = 0; /// Tunable solvers provide a GetSolution that takes a Context and PerformanceConfig virtual ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerfConfig& config) const = 0; }; @@ -227,26 +229,28 @@ struct ConvTunableSolver : ConvTunableSolverBase static_assert(std::is_base_of{}, "PerformanceConfig must be derived of PerfConfig"); - virtual PerformanceConfig GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const = 0; - virtual bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, - const PerformanceConfig&) const = 0; virtual PerformanceConfig - Search(const ExecutionContext&, const ProblemDescription&, const AnyInvokeParams&) const = 0; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const = 0; + virtual bool IsValidPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&, + const PerformanceConfig&) const = 0; + virtual PerformanceConfig Search(const ExecutionContext&, + const miopen::conv::ProblemDescription&, + const AnyInvokeParams&) const = 0; virtual ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, - const PerformanceConfig&) const = 0; + const miopen::conv::ProblemDescription&, + const PerformanceConfig&) const = 0; boost::any GetDefaultPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, int) const final { return GetDefaultPerformanceConfig(ctx, problem); } bool IsValidPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerfConfig& config) const final { return IsValidPerformanceConfig( @@ -254,7 +258,7 @@ struct ConvTunableSolver : ConvTunableSolverBase } boost::any Search(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const AnyInvokeParams& invoke_ctx, int) const final { @@ -262,7 +266,7 @@ struct ConvTunableSolver : ConvTunableSolverBase } ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerfConfig& config) const final { return GetSolution(ctx, problem, dynamic_cast(config)); @@ -287,14 +291,14 @@ struct PerformanceConfigConvAsm3x3U : PerfConfigBase { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceConfigConvAsm3x3U - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm3x3U&) const override; PerformanceConfigConvAsm3x3U Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm3x3U&) const override; }; @@ -362,25 +368,30 @@ struct PerformanceConfigConvAsm1x1U : PerfConfigBase const std::string& SolverDbId() const override { return GetSolverDbId(); } PerformanceConfigConvAsm1x1U - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm1x1U&) const override; PerformanceConfigConvAsm1x1U Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm1x1U&) const override; }; @@ -465,14 +479,14 @@ struct PerformanceConfigConvAsm1x1UV2 : PerfConfigBase const std::string& SolverDbId() const override { return GetSolverDbId(); } PerformanceConfigConvAsm1x1UV2 - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm1x1UV2&) const override; PerformanceConfigConvAsm1x1UV2 Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsm1x1UV2&) const override; }; struct ConvAsm5x10u2v2f1 final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsm5x10u2v2b1 final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsm7x7c3h224w224k64u2v2p3q3f1 final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvOclDirectFwd11x11 final : ConvSolver @@ -543,16 +550,20 @@ struct ConvOclDirectFwd11x11 final : ConvSolver return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvOclDirectFwdGen final : ConvSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct PerformanceImplicitGemm : PerfConfigBase @@ -613,10 +624,10 @@ struct PerformanceImplicitGemm : PerfConfigBase f(self.WeiBlockCopyClusterLengths_K, "WeiBlockCopyClusterLengths_K"); } - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool SetNextValue(const ProblemDescription&); - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; + bool SetNextValue(const miopen::conv::ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceImplicitGemm& other) const; }; @@ -651,7 +662,7 @@ struct PerformanceImplicitGemmV4R1 : public PerformanceImplicitGemm PerformanceImplicitGemmV4R1(bool spare) : PerformanceImplicitGemm(spare) {} - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; }; struct PerformanceImplicitGemmV4R4Fwd : PerfConfigBase @@ -693,22 +704,23 @@ struct PerformanceImplicitGemmV4R4Fwd : PerfConfigBase CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockGemmPerformanceParameters() const; std::tuple CalculateGemmABlockCopyPerformanceParameters() const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmCThreadCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmCThreadCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + bool IsValid(const miopen::conv::ProblemDescription&) const; + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); }; struct PerformanceImplicitGemmV4R4WrW : PerfConfigBase @@ -750,22 +762,23 @@ struct PerformanceImplicitGemmV4R4WrW : PerfConfigBase CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockGemmPerformanceParameters() const; std::tuple CalculateGemmABlockCopyPerformanceParameters() const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmCThreadCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmCThreadCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + bool IsValid(const miopen::conv::ProblemDescription&) const; + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); }; struct PerformanceImplicitGemmBwdDataV1R1 : PerfConfigBase @@ -809,22 +822,23 @@ struct PerformanceImplicitGemmBwdDataV1R1 : PerfConfigBase CalculateGridSize(const ExecutionContext&, - const ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockGemmPerformanceParameters() const; std::tuple CalculateGemmABlockCopyPerformanceParameters(const ExecutionContext&, - const ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; std::tuple CalculateGemmBBlockCopyPerformanceParameters(const ExecutionContext&, - const ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmCThreadCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ExecutionContext&, - const ProblemDescription&) const; + CalculateGemmCThreadCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); }; struct PerformanceImplicitGemmBwdDataV4R1 : PerfConfigBase @@ -867,23 +881,24 @@ struct PerformanceImplicitGemmBwdDataV4R1 : PerfConfigBase CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockGemmPerformanceParameters() const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmCThreadCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmCThreadCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + bool IsValid(const miopen::conv::ProblemDescription&) const; + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); }; struct PerformanceImplicitGemmBwdDataV4R1Xdlops @@ -928,18 +943,20 @@ struct PerformanceImplicitGemmBwdDataV4R1Xdlops f(self.GemmBThreadCopyMoreGemmKPack, "GemmBThreadCopyMoreGemmKPack"); } - std::tuple CalculateGridSize(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsReallyValid(const ProblemDescription&) const; - bool IsFastToBeUsedForTuning(const ExecutionContext&, const ProblemDescription&) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsReallyValid(const miopen::conv::ProblemDescription&) const; + bool IsFastToBeUsedForTuning(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); }; struct ConvHipImplicitGemmV4R1Fwd final : ConvTunableSolver @@ -950,16 +967,18 @@ struct ConvHipImplicitGemmV4R1Fwd final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceImplicitGemmV4R4Fwd - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmV4R4Fwd&) const override; PerformanceImplicitGemmV4R4Fwd Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmV4R4Fwd&) const override; private: - static std::tuple CalculateGemmSize(const ProblemDescription&); + static std::tuple CalculateGemmSize(const miopen::conv::ProblemDescription&); friend struct PerformanceImplicitGemmV4R4Fwd; }; @@ -1031,8 +1052,8 @@ struct PerformanceConvMlirIgemm : PerfConfigBase f(self.GemmNPerThread, "GemmNPerThread"); } - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool SetNextValue(const ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool SetNextValue(const miopen::conv::ProblemDescription&); private: void SetMlirHeuristicInitRequest(); @@ -1042,17 +1063,19 @@ struct ConvMlirIgemmFwd final : ConvTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - PerformanceConvMlirIgemm GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + PerformanceConvMlirIgemm + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; PerformanceConvMlirIgemm Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; }; @@ -1103,8 +1126,8 @@ struct PerformanceConvMlirIgemmXdlops : PerfConfigBase(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceConvMlirIgemmXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; PerformanceConvMlirIgemmXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; }; @@ -1138,21 +1163,23 @@ struct ConvHipImplicitGemmV4R4WrW final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceImplicitGemmV4R4WrW - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmV4R4WrW&) const override; PerformanceImplicitGemmV4R4WrW Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmV4R4WrW&) const override; private: - static std::tuple CalculateGemmSize(const ProblemDescription&); + static std::tuple CalculateGemmSize(const miopen::conv::ProblemDescription&); friend struct PerformanceImplicitGemmV4R4WrW; }; @@ -1161,17 +1188,19 @@ struct ConvMlirIgemmWrW final : ConvTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - PerformanceConvMlirIgemm GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + PerformanceConvMlirIgemm + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; PerformanceConvMlirIgemm Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; }; @@ -1182,19 +1211,22 @@ struct ConvMlirIgemmWrWXdlops final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceConvMlirIgemmXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; PerformanceConvMlirIgemmXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; }; @@ -1231,20 +1263,22 @@ struct PerformanceImplicitGemmForwardV4R4Xdlops bool operator==(const PerformanceImplicitGemmForwardV4R4Xdlops& other) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsReallyValid(const ProblemDescription&) const; - bool IsFastToBeUsedForTuning(const ExecutionContext&, const ProblemDescription&) const; + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsReallyValid(const miopen::conv::ProblemDescription&) const; + bool IsFastToBeUsedForTuning(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockSize() const; - std::tuple CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct PerformanceImplicitGemmForwardV4R5Xdlops @@ -1288,20 +1322,22 @@ struct PerformanceImplicitGemmForwardV4R5Xdlops bool operator==(const PerformanceImplicitGemmForwardV4R5Xdlops& other) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsReallyValid(const ProblemDescription&) const; - bool IsFastToBeUsedForTuning(const ExecutionContext&, const ProblemDescription&) const; + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsReallyValid(const miopen::conv::ProblemDescription&) const; + bool IsFastToBeUsedForTuning(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockSize() const; - std::tuple CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm @@ -1347,20 +1383,22 @@ struct PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm bool operator==(const PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm& other) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsReallyValid(const ProblemDescription&) const; - bool IsFastToBeUsedForTuning(const ExecutionContext&, const ProblemDescription&) const; + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsReallyValid(const miopen::conv::ProblemDescription&) const; + bool IsFastToBeUsedForTuning(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockSize() const; - std::tuple CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct PerformanceImplicitGemmBwdV1R1Xdlops : PerfConfigBase @@ -1393,20 +1431,22 @@ struct PerformanceImplicitGemmBwdV1R1Xdlops : PerfConfigBase CalculateBlockSize() const; - std::tuple CalculateGridSize(const ProblemDescription&) const; + std::tuple CalculateGridSize(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmForwardV4R4Xdlops final @@ -1418,21 +1458,24 @@ struct ConvHipImplicitGemmForwardV4R4Xdlops final } PerformanceImplicitGemmForwardV4R4Xdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R4Xdlops&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R4Xdlops&) const override; PerformanceImplicitGemmForwardV4R4Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; private: - static std::tuple CalculateGemmSize(const ProblemDescription&); + static std::tuple + CalculateGemmSize(const miopen::conv::ProblemDescription&); friend struct PerformanceImplicitGemmForwardV4R4Xdlops; }; @@ -1446,24 +1489,26 @@ struct ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm final } PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm&) const override; PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; private: - static std::tuple - CalculateGemmSize(const ProblemDescription&, int GemmMFactor, int GemmNFactor, int GemmKFactor); + static std::tuple CalculateGemmSize( + const miopen::conv::ProblemDescription&, int GemmMFactor, int GemmNFactor, int GemmKFactor); friend struct PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm; }; @@ -1477,17 +1522,19 @@ struct ConvHipImplicitGemmForwardV4R5Xdlops final } PerformanceImplicitGemmForwardV4R5Xdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R5Xdlops&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R5Xdlops&) const override; PerformanceImplicitGemmForwardV4R5Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; }; @@ -1499,16 +1546,18 @@ struct ConvHipImplicitGemmV4R1WrW final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceImplicitGemmBwdDataV1R1 - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV1R1&) const override; PerformanceImplicitGemmBwdDataV1R1 Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV1R1&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } private: static std::tuple CalculateGemmSize(const ExecutionContext&, - const ProblemDescription&); + const miopen::conv::ProblemDescription&); friend struct PerformanceImplicitGemmBwdDataV1R1; }; @@ -1545,17 +1597,19 @@ struct ConvMlirIgemmBwd final : ConvTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - PerformanceConvMlirIgemm GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + PerformanceConvMlirIgemm + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; PerformanceConvMlirIgemm Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemm&) const override; }; @@ -1566,17 +1620,19 @@ struct ConvMlirIgemmBwdXdlops final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceConvMlirIgemmXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; PerformanceConvMlirIgemmXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvMlirIgemmXdlops&) const override; }; @@ -1587,22 +1643,25 @@ struct ConvHipImplicitGemmBwdDataV4R1 final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; PerformanceImplicitGemmBwdDataV4R1 - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV4R1&) const override; PerformanceImplicitGemmBwdDataV4R1 Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV4R1&) const override; private: - static int CalculateNumberOfGemm(const ProblemDescription&); - static std::tuple CalculateGemmSize(const ProblemDescription&, int gemm_id); + static int CalculateNumberOfGemm(const miopen::conv::ProblemDescription&); + static std::tuple CalculateGemmSize(const miopen::conv::ProblemDescription&, + int gemm_id); friend struct PerformanceImplicitGemmBwdDataV4R1; }; @@ -1616,22 +1675,25 @@ struct ConvHipImplicitGemmBwdDataV4R1Xdlops final } PerformanceImplicitGemmBwdDataV4R1Xdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV4R1Xdlops&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdDataV4R1Xdlops&) const override; PerformanceImplicitGemmBwdDataV4R1Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; private: - static int CalculateNumberOfGemm(const ProblemDescription&); - static std::tuple CalculateGemmSize(const ProblemDescription&, int gemm_id); + static int CalculateNumberOfGemm(const miopen::conv::ProblemDescription&); + static std::tuple CalculateGemmSize(const miopen::conv::ProblemDescription&, + int gemm_id); friend struct PerformanceImplicitGemmBwdDataV4R1Xdlops; }; @@ -1645,201 +1707,185 @@ struct ConvHipImplicitGemmBwdDataV1R1Xdlops final } PerformanceImplicitGemmBwdV1R1Xdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdV1R1Xdlops&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } PerformanceImplicitGemmBwdV1R1Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmBwdV1R1Xdlops&) const override; private: - static std::tuple CalculateGemmSize(const ProblemDescription&); + static std::tuple + CalculateGemmSize(const miopen::conv::ProblemDescription&); friend struct PerformanceImplicitGemmBwdV1R1Xdlops; }; struct ConvAsmImplicitGemmV4R1DynamicFwd final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmV4R1DynamicFwd_1x1 final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmV4R1DynamicWrw final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWorkspaceSize; - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmGTCDynamicWrwXdlops final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWorkspaceSize; - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmV4R1DynamicBwd final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmGTCDynamicFwdXdlops final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvAsmImplicitGemmGTCDynamicBwdXdlops final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; /// Holds common member functions for the Solvers which share the same /// "legacy exhaustive search" machinery. struct ConvOclDirectFwdLegacyExhaustiveSearch : ConvTunableSolver { - LegacyPerformanceConfig GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const override; + LegacyPerformanceConfig + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; LegacyPerformanceConfig Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; private: template LegacyPerformanceConfig SearchImpl(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const; }; struct ConvOclDirectFwd : ConvOclDirectFwdLegacyExhaustiveSearch { - static ConvSolution BaseGetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, - const LegacyPerformanceConfig& config); const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + static ConvSolution BaseGetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&, + const LegacyPerformanceConfig&); + + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const LegacyPerformanceConfig&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const LegacyPerformanceConfig&) const override; }; @@ -1847,13 +1893,14 @@ struct ConvOclDirectFwd1x1 final : ConvOclDirectFwdLegacyExhaustiveSearch { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const LegacyPerformanceConfig&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const LegacyPerformanceConfig&) const override { return true; @@ -1862,34 +1909,28 @@ struct ConvOclDirectFwd1x1 final : ConvOclDirectFwdLegacyExhaustiveSearch struct ConvBinWinograd3x3U final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvBinWinogradRxS final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct PerformanceConfigConvBinWinogradRxS : PerfConfigBase @@ -1907,10 +1948,10 @@ struct PerformanceConfigConvBinWinogradRxS : PerfConfigBase - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool SetNextValue(const ProblemDescription&); - bool IsValid(const ExecutionContext& ctx, const ProblemDescription&) const + bool SetNextValue(const miopen::conv::ProblemDescription&); + bool IsValid(const ExecutionContext& ctx, const miopen::conv::ProblemDescription&) const { return IsValid(ctx); } @@ -1933,17 +1974,19 @@ struct ConvBinWinoRxS final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - float GetWti(const ExecutionContext&, const ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; template @@ -1989,11 +2034,14 @@ struct ConvMPBidirectWinograd final : ConvSolver ConvMPBidirectWinograd>(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; // kernel_file_name for solver identification static std::string GetSolverFileNames(int id) @@ -2045,7 +2093,8 @@ struct ConvMPBidirectWinograd_xdlops final ConvMPBidirectWinograd_xdlops>(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { @@ -2057,7 +2106,7 @@ struct ConvMPBidirectWinograd_xdlops final PerformanceImplicitGemmForwardV4R4Xdlops GetDefaultPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem) const override + const miopen::conv::ProblemDescription& problem) const override { const auto xdlops_problem = GetTransformedProblem(problem); const auto xdlops_ctx = GetTransformedConvContext(ctx, xdlops_problem); @@ -2068,7 +2117,7 @@ struct ConvMPBidirectWinograd_xdlops final bool IsValidPerformanceConfig(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerformanceImplicitGemmForwardV4R4Xdlops& config) const override { const auto xdlops_problem = GetTransformedProblem(problem); @@ -2079,7 +2128,7 @@ struct ConvMPBidirectWinograd_xdlops final } size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override + const miopen::conv::ProblemDescription& problem) const override { const auto xdlops_problem = GetTransformedProblem(problem); const auto xdlops_ctx = GetTransformedConvContext(ctx, xdlops_problem); @@ -2093,16 +2142,18 @@ struct ConvMPBidirectWinograd_xdlops final PerformanceImplicitGemmForwardV4R4Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmForwardV4R4Xdlops&) const override; private: - ExecutionContext GetTransformedConvContext(const ExecutionContext& ctx, - const ProblemDescription& transformed_problem) const; - ProblemDescription GetTransformedProblem(const ProblemDescription& problem) const; + ExecutionContext + GetTransformedConvContext(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& transformed_problem) const; + miopen::conv::ProblemDescription + GetTransformedProblem(const miopen::conv::ProblemDescription& problem) const; // kernel_file_name for solver identification static std::string GetSolverFileNames(int id) @@ -2145,28 +2196,24 @@ extern template struct ConvMPBidirectWinograd_xdlops<6, 3>; template struct ConvWinograd3x3MultipassWrW final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWorkspaceSize; - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId< ConvWinograd3x3MultipassWrW>(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; // kernel_file_name for solver identification static std::string GetSolverFileNames(int id) @@ -2189,7 +2236,7 @@ struct ConvWinograd3x3MultipassWrW final : ConvSolver static int GetGroupCountMult() { return 4; } - static int GetSolverWinoXformHWSize(const ProblemDescription& problem, int id) + static int GetSolverWinoXformHWSize(const miopen::conv::ProblemDescription& problem, int id) { if(id == 0) return WinoDataH + @@ -2201,7 +2248,7 @@ struct ConvWinograd3x3MultipassWrW final : ConvSolver private: InvokerFactory PrepareInvokerFactory(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, std::size_t ws_sz) const; }; @@ -2264,10 +2311,10 @@ struct PerformanceConfigAsmDirect3x3WrW : PerfConfigBase(); } PerformanceConfigAsmDirect3x3WrW - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigAsmDirect3x3WrW&) const override; PerformanceConfigAsmDirect3x3WrW Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigAsmDirect3x3WrW& config) const override; }; @@ -2297,11 +2346,13 @@ struct ConvWinoFuryRxS final : ConvSolver return GetSolverDbId>(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } - float GetWti(const ExecutionContext&, const ProblemDescription&) const override; + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; static constexpr bool is2x3() { return Winodata == 2 && Winofilter == 3; } static constexpr bool is3x2() { return Winodata == 3 && Winofilter == 2; } @@ -2403,10 +2454,10 @@ struct PerformanceConfigConvAsmBwdWrW1x1 : PerfConfigBase(); } PerformanceConfigConvAsmBwdWrW1x1 - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsmBwdWrW1x1&) const override; PerformanceConfigConvAsmBwdWrW1x1 Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvAsmBwdWrW1x1&) const override; }; @@ -2483,10 +2537,10 @@ struct PerformanceConfigConvOclBwdWrw2 int GetNumOutChannelTiles() const { return n_out_channels_tiles; } int GetNumOutRowsPerIterPerWork() const { return n_out_rows_in_lcl; } // clang-format on - void HeuristicInit(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool SetNextValue(const ProblemDescription&); - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; + bool SetNextValue(const miopen::conv::ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConfigConvOclBwdWrw2& other) const; }; @@ -2499,24 +2553,27 @@ struct ConvOclBwdWrW2 : ConvTunableSolver - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvOclBwdWrw2&) const override; PerformanceConfigConvOclBwdWrw2 Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigConvOclBwdWrw2&) const override; protected: - bool IsApplicableBase(const ExecutionContext&, const ProblemDescription&) const; + bool IsApplicableBase(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; }; // To suppress misleading clang warnings @@ -2552,8 +2609,10 @@ struct ConvOclBwdWrW2NonTunable final : ConvOclBwdWrW2<1> return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; private: // This function dervied from ConvOclBwdWrW2 is declared private @@ -2566,40 +2625,42 @@ struct ConvOclBwdWrW53 final : ConvSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvOclBwdWrW1x1 final : ConvSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct fft final : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWorkspaceSize; - using ConvSolver::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct PerformanceImplicitGemmWrwV4R4Xdlops : PerfConfigBase @@ -2637,23 +2698,26 @@ struct PerformanceImplicitGemmWrwV4R4Xdlops : PerfConfigBase - CalculateGemmSizeAndGemmKBlock(const ExecutionContext&, const ProblemDescription&) const; + CalculateGemmSizeAndGemmKBlock(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockSize() const; std::tuple CalculateGridSize(const ExecutionContext&, - const ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmWrwV4R4Xdlops final @@ -2665,18 +2729,21 @@ struct ConvHipImplicitGemmWrwV4R4Xdlops final } PerformanceImplicitGemmWrwV4R4Xdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmWrwV4R4Xdlops&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmWrwV4R4Xdlops&) const override; PerformanceImplicitGemmWrwV4R4Xdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; }; @@ -2721,23 +2788,26 @@ struct PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm bool operator==(const PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm& other) const; - void HeuristicInit(const ExecutionContext&, const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const ExecutionContext&, const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsReallyValid(const ExecutionContext&, const ProblemDescription&) const; - bool IsFastToBeUsedForTuning(const ExecutionContext&, const ProblemDescription&) const; + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsReallyValid(const ExecutionContext&, const miopen::conv::ProblemDescription&) const; + bool IsFastToBeUsedForTuning(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmSizeAndGemmKBlock(const ExecutionContext&, const ProblemDescription&) const; + CalculateGemmSizeAndGemmKBlock(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const; std::tuple CalculateBlockSize() const; std::tuple CalculateGridSize(const ExecutionContext&, - const ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmABlockCopyPerformanceParameters(const ProblemDescription&) const; + CalculateGemmABlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; std::tuple - CalculateGemmBBlockCopyPerformanceParameters(const ProblemDescription&) const; - std::tuple CalculateLdsNumberOfByte(const ProblemDescription&) const; + CalculateGemmBBlockCopyPerformanceParameters(const miopen::conv::ProblemDescription&) const; + std::tuple + CalculateLdsNumberOfByte(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm final @@ -2749,21 +2819,24 @@ struct ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm final } PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm&) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm&) const override; PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; }; @@ -2784,12 +2857,12 @@ struct PerformanceConvCkIgemmFwdV6r1DlopsNchw f(self.ck_tunable_list_id, "ck_tunable_list_id"); } - bool SetNextValue(const ProblemDescription&); - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool SetNextValue(const miopen::conv::ProblemDescription&); + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConvCkIgemmFwdV6r1DlopsNchw& config) const { return ck_tunable_list_id == config.ck_tunable_list_id; @@ -2803,20 +2876,23 @@ struct ConvCkIgemmFwdV6r1DlopsNchw final : ConvTunableSolver(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; - size_t GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } bool IsDynamic() const override { return false; } PerformanceConvCkIgemmFwdV6r1DlopsNchw - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvCkIgemmFwdV6r1DlopsNchw&) const override; PerformanceConvCkIgemmFwdV6r1DlopsNchw Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConvCkIgemmFwdV6r1DlopsNchw&) const override; }; @@ -2827,15 +2903,17 @@ struct ConvDirectNaiveConvFwd final : ConvSolver return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } /// Use very small fixed value enough to backup GEMM for cases when /// GEMM is disabled due to MIOpenGemm or OCL compiler issues. - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.01f; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvDirectNaiveConvBwd final : ConvSolver @@ -2845,15 +2923,17 @@ struct ConvDirectNaiveConvBwd final : ConvSolver return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } /// Use very small fixed value enough to backup GEMM for cases when /// GEMM is disabled due to MIOpenGemm or OCL compiler issues. - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.01f; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct ConvDirectNaiveConvWrw final : ConvSolver @@ -2863,32 +2943,27 @@ struct ConvDirectNaiveConvWrw final : ConvSolver return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } /// Use very small fixed value enough to backup GEMM for cases when /// GEMM is disabled due to MIOpenGemm or OCL compiler issues. - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.01f; } - ConvSolution GetSolution(const ExecutionContext&, const ProblemDescription&) const override; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct GemmFwdBase : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWti; - using ConvSolver::IsApplicable; - bool IsDynamic() const override { return true; } - float GetWti(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return GetWti(ctx, static_cast(problem)); - } + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; private: - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - float GetWti(const ExecutionContext& ctx, const conv::ProblemDescription& problem) const; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmFwd1x1_0_2; friend struct GemmFwd1x1_0_1_int8; @@ -2898,153 +2973,82 @@ struct GemmFwdBase : ConvSolver struct GemmFwd1x1_0_2 final : GemmFwdBase { - // To suppress -Woverloaded-virtual - using GemmFwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } - -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmFwdRest; }; struct GemmFwd1x1_0_1_int8 final : GemmFwdBase { - // To suppress -Woverloaded-virtual - using GemmFwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } - - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmFwdRest; }; struct GemmFwd1x1_0_1 final : GemmFwdBase { - // To suppress -Woverloaded-virtual - using GemmFwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } - - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmFwdRest; }; struct GemmFwdRest final : GemmFwdBase { - // To suppress -Woverloaded-virtual - using GemmFwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } - - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct GemmBwdBase : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWti; - using ConvSolver::IsApplicable; - bool IsDynamic() const override { return true; } - float GetWti(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return GetWti(ctx, static_cast(problem)); - } + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; private: - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - float GetWti(const ExecutionContext& context, const conv::ProblemDescription& problem) const; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmBwd1x1_stride2; friend struct GemmBwd1x1_stride1; @@ -3053,122 +3057,68 @@ struct GemmBwdBase : ConvSolver struct GemmBwd1x1_stride2 final : GemmBwdBase { - // To suppress -Woverloaded-virtual - using GemmBwdBase::GetWorkspaceSize; - using GemmBwdBase::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } - - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmBwdRest; }; struct GemmBwd1x1_stride1 final : GemmBwdBase { - // To suppress -Woverloaded-virtual - using GemmBwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription& problem) const override; private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; bool IsApplicableBeforeWorkaround(const ExecutionContext&, - const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + const miopen::conv::ProblemDescription&) const; friend struct GemmBwdRest; }; struct GemmBwdRest final : GemmBwdBase { - // To suppress -Woverloaded-virtual - using GemmBwdBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } - - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct GemmWrwBase : ConvSolver { - // To suppress -Woverloaded-virtual - using ConvSolver::GetWti; - using ConvSolver::IsApplicable; - bool IsDynamic() const override { return true; } - float GetWti(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return GetWti(ctx, static_cast(problem)); - } + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; private: - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - float GetWti(const ExecutionContext& context, const conv::ProblemDescription& problem) const; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmWrw1x1_stride1; friend struct GemmWrwUniversal; @@ -3176,59 +3126,31 @@ struct GemmWrwBase : ConvSolver struct GemmWrw1x1_stride1 final : GemmWrwBase { - // To suppress -Woverloaded-virtual - using GemmWrwBase::IsApplicable; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } - -private: - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; friend struct GemmWrwUniversal; }; struct GemmWrwUniversal final : GemmWrwBase { - // To suppress -Woverloaded-virtual - using GemmWrwBase::GetWorkspaceSize; - const std::string& SolverDbId() const override { return GetSolverDbId(); } - size_t GetWorkspaceSize(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetWorkspaceSize(ctx, static_cast(problem)); - } + size_t GetWorkspaceSize(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool MayNeedWorkspace() const override { return true; } - bool IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const override - { - return IsApplicable(ctx, static_cast(problem)); - } + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; - ConvSolution GetSolution(const ExecutionContext& ctx, - const ProblemDescription& problem) const override - { - return GetSolution(ctx, static_cast(problem)); - } - -private: - size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const; - bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const; - ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const; + ConvSolution GetSolution(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; }; struct PerformanceConfigAsmImplicitGemmGTC : PerfConfigBase @@ -3421,16 +3343,16 @@ struct PerformanceConfigAsmImplicitGemmGTC : PerfConfigBase(); } PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC&) const override; PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC&) const override; }; @@ -4404,21 +4338,21 @@ struct PerformanceConfigHipImplicitGemmFwdXdlops : PerformanceConfigHipImplicitGemmFwdXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConfigHipImplicitGemmFwdXdlops& other) const; private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmFwdXdlops final @@ -4430,18 +4364,20 @@ struct ConvHipImplicitGemmFwdXdlops final } PerformanceConfigHipImplicitGemmFwdXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmFwdXdlops&) const override; PerformanceConfigHipImplicitGemmFwdXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmFwdXdlops&) const override; /// \anchor igemm_get_wti_magic_number // Magic Number Alert: @@ -4454,14 +4390,14 @@ struct ConvHipImplicitGemmFwdXdlops final // Since we would like to us CK before naive, and use it instead (because // we do expect that CK is faster than Naive), therefore we use a // value bigger than 0.01f, e.g. 0.02f. - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct PerformanceConfigHipImplicitGemmBwdXdlops @@ -4482,21 +4418,21 @@ struct PerformanceConfigHipImplicitGemmBwdXdlops : PerformanceConfigHipImplicitGemmBwdXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConfigHipImplicitGemmBwdXdlops& other) const; private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmBwdXdlops final @@ -4508,28 +4444,30 @@ struct ConvHipImplicitGemmBwdXdlops final } PerformanceConfigHipImplicitGemmBwdXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmBwdXdlops&) const override; PerformanceConfigHipImplicitGemmBwdXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmBwdXdlops&) const override; /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct PerformanceConfigHipImplicitGemmGroupFwdXdlops @@ -4550,21 +4488,21 @@ struct PerformanceConfigHipImplicitGemmGroupFwdXdlops : PerformanceConfigHipImplicitGemmGroupFwdXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConfigHipImplicitGemmGroupFwdXdlops& other) const; private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemmGroupFwdXdlops final @@ -4576,29 +4514,31 @@ struct ConvHipImplicitGemmGroupFwdXdlops final } PerformanceConfigHipImplicitGemmGroupFwdXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmGroupFwdXdlops&) const override; PerformanceConfigHipImplicitGemmGroupFwdXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemmGroupFwdXdlops&) const override; /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct PerformanceConfigHipImplicitGemm3DGroupFwdXdlops @@ -4619,21 +4559,21 @@ struct PerformanceConfigHipImplicitGemm3DGroupFwdXdlops : PerformanceConfigHipImplicitGemm3DGroupFwdXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; bool operator==(const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops& other) const; private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemm3DGroupFwdXdlops final @@ -4645,30 +4585,32 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final } PerformanceConfigHipImplicitGemm3DGroupFwdXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops&) const override; PerformanceConfigHipImplicitGemm3DGroupFwdXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops&) const override; /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct PerformanceConfigHipImplicitGemm3DGroupWrwXdlops @@ -4689,14 +4631,14 @@ struct PerformanceConfigHipImplicitGemm3DGroupWrwXdlops : PerformanceConfigHipImplicitGemm3DGroupWrwXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; template static void Visit(Self&& s, F f) { @@ -4706,9 +4648,9 @@ struct PerformanceConfigHipImplicitGemm3DGroupWrwXdlops private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemm3DGroupWrwXdlops final @@ -4720,30 +4662,32 @@ struct ConvHipImplicitGemm3DGroupWrwXdlops final } PerformanceConfigHipImplicitGemm3DGroupWrwXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops&) const override; PerformanceConfigHipImplicitGemm3DGroupWrwXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops&) const override; /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; struct PerformanceConfigHipImplicitGemm3DGroupBwdXdlops @@ -4764,14 +4708,14 @@ struct PerformanceConfigHipImplicitGemm3DGroupBwdXdlops : PerformanceConfigHipImplicitGemm3DGroupBwdXdlops(0, "") { } - void HeuristicInit(const ProblemDescription&); - bool SetNextValue(const ProblemDescription&); + void HeuristicInit(const miopen::conv::ProblemDescription&); + bool SetNextValue(const miopen::conv::ProblemDescription&); bool IsValidValue() const; - bool IsValid(const ExecutionContext&, const ProblemDescription& problem) const + bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const { return IsValid(problem); } - bool IsValid(const ProblemDescription&) const; + bool IsValid(const miopen::conv::ProblemDescription&) const; template static void Visit(Self&& s, F f) { @@ -4781,9 +4725,9 @@ struct PerformanceConfigHipImplicitGemm3DGroupBwdXdlops private: template - void Init(const ProblemDescription&); + void Init(const miopen::conv::ProblemDescription&); template - bool CheckIsSupportCKArgs(const ProblemDescription&) const; + bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const; }; struct ConvHipImplicitGemm3DGroupBwdXdlops final @@ -4795,32 +4739,36 @@ struct ConvHipImplicitGemm3DGroupBwdXdlops final } PerformanceConfigHipImplicitGemm3DGroupBwdXdlops - GetDefaultPerformanceConfig(const ExecutionContext&, const ProblemDescription&) const override; + GetDefaultPerformanceConfig(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsValidPerformanceConfig( const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupBwdXdlops&) const override; PerformanceConfigHipImplicitGemm3DGroupBwdXdlops Search(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const AnyInvokeParams& invoke_ctx) const override; - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override; + bool IsApplicable(const ExecutionContext&, + const miopen::conv::ProblemDescription&) const override; bool IsDynamic() const override { return true; } ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupBwdXdlops&) const override; /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const ProblemDescription&) const override + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override { return 0.02f; }; private: template - bool CheckCKApplicability(const ProblemDescription&) const; + bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; }; +} // namespace conv + // Use struct as a syntactic sugar to make the intent as clear as possible. struct ThisSolverIsDeprecatedStatic { diff --git a/src/include/miopen/solver/ck_utility_common.hpp b/src/include/miopen/solver/ck_utility_common.hpp index a8d049c389..18285d8a09 100644 --- a/src/include/miopen/solver/ck_utility_common.hpp +++ b/src/include/miopen/solver/ck_utility_common.hpp @@ -138,7 +138,8 @@ static inline auto get_ck_common_compiler_flag(const Handle& handle) return compiler_flag.str(); } -static inline auto get_ck_convolution_problem_descriptor(const ProblemDescription& problem) +static inline auto +get_ck_convolution_problem_descriptor(const miopen::conv::ProblemDescription& problem) { ck::DataTypeEnum_t ck_datatype; diff --git a/src/include/miopen/solver/conv_direct_naive_conv.hpp b/src/include/miopen/solver/conv_direct_naive_conv.hpp index 6d935b249d..f5f1062ea1 100644 --- a/src/include/miopen/solver/conv_direct_naive_conv.hpp +++ b/src/include/miopen/solver/conv_direct_naive_conv.hpp @@ -26,7 +26,7 @@ #pragma once #include -#include +#include #include "miopen/../../kernels/stride_array.hpp" #include @@ -36,29 +36,30 @@ #include namespace miopen { - namespace solver { +namespace conv { -bool ConvDirectNaiveConvIsAssemblyKernel(const ExecutionContext&, const ProblemDescription&); -std::string ConvDirectNaiveConvKernelName(const ProblemDescription&); +bool ConvDirectNaiveConvIsAssemblyKernel(const ExecutionContext&, + const miopen::conv::ProblemDescription&); +std::string ConvDirectNaiveConvKernelName(const miopen::conv::ProblemDescription&); std::string ConvDirectNaiveConvKernelFile(const ExecutionContext& ctx, - const ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); std::string ConvDirectNaiveConvCompileOption(const ExecutionContext& ctx, - const ProblemDescription& problem); + const miopen::conv::ProblemDescription& problem); bool ConvDirectNaiveConvIsApplicableByKernelType(const ExecutionContext&, - const ProblemDescription&); - -bool IsInputFp32(const ProblemDescription&); -bool IsInputFp16(const ProblemDescription&); -bool IsInputBfp16(const ProblemDescription&); -bool IsInputInt8(const ProblemDescription&); -bool IsAccFp64(const ProblemDescription&); -bool IsAccInt32(const ProblemDescription&); -bool IsOutputFp32(const ProblemDescription&); -bool IsOutputFp16(const ProblemDescription&); -bool IsOutputBfp16(const ProblemDescription&); -bool IsOutputInt8(const ProblemDescription&); -bool IsOutputInt32(const ProblemDescription&); + const miopen::conv::ProblemDescription&); + +bool IsInputFp32(const miopen::conv::ProblemDescription&); +bool IsInputFp16(const miopen::conv::ProblemDescription&); +bool IsInputBfp16(const miopen::conv::ProblemDescription&); +bool IsInputInt8(const miopen::conv::ProblemDescription&); +bool IsAccFp64(const miopen::conv::ProblemDescription&); +bool IsAccInt32(const miopen::conv::ProblemDescription&); +bool IsOutputFp32(const miopen::conv::ProblemDescription&); +bool IsOutputFp16(const miopen::conv::ProblemDescription&); +bool IsOutputBfp16(const miopen::conv::ProblemDescription&); +bool IsOutputInt8(const miopen::conv::ProblemDescription&); +bool IsOutputInt32(const miopen::conv::ProblemDescription&); namespace conv_internal { @@ -71,7 +72,7 @@ void DebugPrintTensorStrides(const TensorDescriptor& inDesc, * its strides to NGCHW, and for NHWC, we want to convert its strides to NHWGC. * Same applies for the 3D case. */ -int GetGroupStrideIndex(const ProblemDescription& problem); +int GetGroupStrideIndex(const miopen::conv::ProblemDescription& problem); /** * split the strides for C dimension in a tensor descriptor into (G, C_per_group). @@ -145,7 +146,9 @@ auto MakeStrideArray(V vec) } return ret; } + } // end namespace conv_internal +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/include/miopen/solver/gemm_common.hpp b/src/include/miopen/solver/gemm_common.hpp index 50a22538d2..14ab5b8444 100644 --- a/src/include/miopen/solver/gemm_common.hpp +++ b/src/include/miopen/solver/gemm_common.hpp @@ -30,15 +30,15 @@ #include namespace miopen { -namespace conv { namespace solver { +namespace conv { namespace gemm { bool IsWorkaroundIssue1315(const miopen::ExecutionContext& ctx); } // namespace gemm -} // namespace solver } // namespace conv +} // namespace solver } // namespace miopen #endif diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index 39e8706321..2bf5848dab 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -30,8 +30,15 @@ #include namespace miopen { + +namespace conv { +struct ProblemDescription; +} // namespace conv + namespace solver { +struct ConvSolution; + template typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs, const std::string& kernel_id) @@ -43,7 +50,7 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs, template + typename ProblemDescriptionType = miopen::conv::ProblemDescription> std::vector FillValidKernelsIDs(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; @@ -63,7 +70,7 @@ std::vector FillValidKernelsIDs(const ProblemDescriptionType& probl template + typename ProblemDescriptionType = miopen::conv::ProblemDescription> bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); @@ -74,7 +81,7 @@ bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& template + typename ProblemDescriptionType = miopen::conv::ProblemDescription> bool IsCKApplicable(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; @@ -87,7 +94,7 @@ bool IsCKApplicable(const ProblemDescriptionType& problem) template + typename ProblemDescriptionType = miopen::conv::ProblemDescription> ConvSolution MakeInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); @@ -125,7 +132,7 @@ ConvSolution MakeInvokerFactory(const ProblemDescriptionType& problem, const std template + typename ProblemDescriptionType = miopen::conv::ProblemDescription> ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id) { diff --git a/src/include/miopen/solver/implicitgemm_util.hpp b/src/include/miopen/solver/implicitgemm_util.hpp index d9aad50b98..e634be8d5f 100644 --- a/src/include/miopen/solver/implicitgemm_util.hpp +++ b/src/include/miopen/solver/implicitgemm_util.hpp @@ -60,93 +60,93 @@ namespace solver { // these functions map the dimensions of a bwd-wrw problem into a fwd problem // they are not supposed to be called by backward-data -static inline std::size_t KernelFilterStrideH(const ProblemDescription& problem) +static inline std::size_t KernelFilterStrideH(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetDilationH(); else return problem.GetKernelStrideH(); } -static inline std::size_t KernelFilterStrideW(const ProblemDescription& problem) +static inline std::size_t KernelFilterStrideW(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetDilationW(); else return problem.GetKernelStrideW(); } -static inline std::size_t KernelFilterDilationH(const ProblemDescription& problem) +static inline std::size_t KernelFilterDilationH(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetKernelStrideH(); else return problem.GetDilationH(); } -static inline std::size_t KernelFilterDilationW(const ProblemDescription& problem) +static inline std::size_t KernelFilterDilationW(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetKernelStrideW(); else return problem.GetDilationW(); } -static inline std::size_t KernelOutputChannelK(const ProblemDescription& problem) +static inline std::size_t KernelOutputChannelK(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetInChannels_(); else return problem.GetOutChannels_(); } -static inline std::size_t KernelInputChannelC(const ProblemDescription& problem) +static inline std::size_t KernelInputChannelC(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetBatchSize_(); else return problem.GetInChannels_() / problem.GetGroupCount(); } -static inline std::size_t KernelBatchN(const ProblemDescription& problem) +static inline std::size_t KernelBatchN(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetOutChannels_() / problem.GetGroupCount(); else return problem.GetBatchSize_(); } -static inline std::size_t KernelOutputHeightHo(const ProblemDescription& problem) +static inline std::size_t KernelOutputHeightHo(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutHeight_(); - else if(problem.direction.IsBackwardWrW()) + else if(problem.IsDirectionBackwardWrW()) return problem.GetWeightsHeight_(); else return problem.GetInHeight_(); } -static inline std::size_t KernelOutputWidthWo(const ProblemDescription& problem) +static inline std::size_t KernelOutputWidthWo(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutWidth_(); - else if(problem.direction.IsBackwardWrW()) + else if(problem.IsDirectionBackwardWrW()) return problem.GetWeightsWidth_(); else return problem.GetInWidth_(); } -static inline std::size_t KernelFilterWidthX(const ProblemDescription& problem) +static inline std::size_t KernelFilterWidthX(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetInWidth_(); else return problem.GetWeightsWidth_(); } -static inline std::size_t KernelFilterHeightY(const ProblemDescription& problem) +static inline std::size_t KernelFilterHeightY(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) return problem.GetInHeight_(); else return problem.GetWeightsHeight_(); @@ -225,8 +225,9 @@ inline static uint32_t GetReadWriteVectorSize(const int v) } ///\todo remove -inline static uint32_t -GetEPackLength(const ExecutionContext& ctx, const ProblemDescription& problem, bool isXdlopsInvoked) +inline static uint32_t GetEPackLength(const ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem, + bool isXdlopsInvoked) { // Based on data type, Es are packed int EPACK = 1; @@ -245,42 +246,7 @@ GetEPackLength(const ExecutionContext& ctx, const ProblemDescription& problem, b return EPACK; } -///\todo remove -static inline bool IsValidXdlopsGemm(const int GemmMPerBlock, - const int GemmNPerBlock, - const int GemmKPackedPerBlock, // packed - const int GemmMPerWave, - const int GemmNPerWave) -{ - // unsupported xdlops-gemm - if(GemmMPerWave == 16 && GemmNPerWave == 32) - return false; - if(GemmMPerWave == 32 && GemmNPerWave == 16) - return false; - if(GemmMPerWave == 8 && GemmNPerWave != 64) - return false; - if(GemmMPerWave == 4 && GemmNPerWave != 64) - return false; - if(GemmMPerWave == 32 && GemmNPerWave == 32 && GemmKPackedPerBlock % 2 != 0) - return false; - if(GemmMPerWave == 16 && GemmNPerWave == 16 && GemmKPackedPerBlock % 4 != 0) - return false; - if(GemmMPerWave > 64 && GemmNPerWave < 64) - return false; - if(GemmNPerWave > 64 && GemmMPerWave < 64) - return false; - - const auto WaveSize = 64; - const auto BlockSize = - (GemmNPerBlock * GemmMPerBlock) / (GemmMPerWave * GemmNPerWave) * WaveSize; - - if(BlockSize < 64 || BlockSize > 256) - return false; - - return (GemmMPerBlock % GemmMPerWave) == 0 && (GemmNPerBlock % GemmNPerWave) == 0; -} - -static inline bool IsIndexRangeLargeEnough(const ProblemDescription& problem) +static inline bool IsIndexRangeLargeEnough(const miopen::conv::ProblemDescription& problem) { // composable kernel use int32_t for memory offset, which covers 2GB of memory maximum const std::size_t max_index_range = std::size_t(2) * 1024 * 1024 * 1024; @@ -289,7 +255,7 @@ static inline bool IsIndexRangeLargeEnough(const ProblemDescription& problem) problem.GetOutSize() < max_index_range; } -static inline bool IsValidBlockwiseGemmXdlops(const ProblemDescription& problem, +static inline bool IsValidBlockwiseGemmXdlops(const miopen::conv::ProblemDescription& problem, const int GemmMPerBlock, const int GemmNPerBlock, const int GemmKPerBlock, @@ -366,7 +332,7 @@ IsValidGridGemmXdlops(const std::size_t GemmM, const std::size_t GemmN, const st ///\todo remove static inline bool IsApplicableXdlops(const ExecutionContext& ctx, - const ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { if(!IsXdlopsSupport(ctx)) return false; @@ -381,7 +347,7 @@ static inline bool IsApplicableXdlops(const ExecutionContext& ctx, std::size_t GemmM, GemmN, GemmK; // forward - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { // TBD/ Since bfp16/fp16 fwd kernel extracts epack from c*y*x, // one could relax the following restriction for bfp16/fp16, @@ -394,7 +360,7 @@ static inline bool IsApplicableXdlops(const ExecutionContext& ctx, GemmK = static_cast(nonVectorizedC) * y * x; } // backwardData - else if(problem.direction.IsBackwardData()) + else if(problem.IsDirectionBackwardData()) { if(k % GetEPackLength(ctx, problem, true) != 0) return false; @@ -420,7 +386,7 @@ static inline bool IsApplicableXdlops(const ExecutionContext& ctx, ///\todo remove template inline static auto GetPerformanceConfigBase(const ExecutionContext& ctx, - const ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { PerformanceImplicitGemm_t pp; pp.HeuristicInit(ctx, problem); @@ -429,7 +395,7 @@ inline static auto GetPerformanceConfigBase(const ExecutionContext& ctx, } ///\todo remove -static inline size_t ComputeLDSRequiredSize(const ProblemDescription& problem, +static inline size_t ComputeLDSRequiredSize(const miopen::conv::ProblemDescription& problem, const int BPerBlock, const int KPerBlock, const int EPerBlock, @@ -460,7 +426,7 @@ static inline size_t ComputeLDSRequiredSize(const ProblemDescription& problem, } static inline bool use_amd_inline_asm(const ExecutionContext& ctx, - const ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { if(StartsWith(ctx.GetStream().GetDeviceName(), "gfx8")) diff --git a/src/include/miopen/solver/mlir_common.hpp b/src/include/miopen/solver/mlir_common.hpp index d926277e4a..1b678f9146 100644 --- a/src/include/miopen/solver/mlir_common.hpp +++ b/src/include/miopen/solver/mlir_common.hpp @@ -28,7 +28,7 @@ #define GUARD_MLIR_COMMON_HPP_ #include -#include +#include #include @@ -36,16 +36,17 @@ namespace miopen { namespace solver { namespace mlir { -std::string GetKernelName(const ProblemDescription& problem, bool is_xdlops, int kernel_id = 0); +std::string +GetKernelName(const miopen::conv::ProblemDescription& problem, bool is_xdlops, int kernel_id = 0); std::string ConstructBuildOptions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, bool is_xdlops, int kernel_id = 0); template std::string ConstructBuildOptions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const T& perf_config, bool is_xdlops, int kernel_id = 0) diff --git a/src/include/miopen/solver/problem_description_interpreter.hpp b/src/include/miopen/solver/problem_description_interpreter.hpp index 0690d3d36e..e27533d737 100644 --- a/src/include/miopen/solver/problem_description_interpreter.hpp +++ b/src/include/miopen/solver/problem_description_interpreter.hpp @@ -27,10 +27,7 @@ #ifndef PROBLEM_DESCRIPTION_INTERPRETER_HPP_ #define PROBLEM_DESCRIPTION_INTERPRETER_HPP_ -#include -#include -#include -#include +#include namespace miopen { namespace solver { @@ -42,183 +39,195 @@ namespace solver { // 4. adjust dilation to 1 if filter size is 1 struct ProblemInterpreter { - static auto GetGroupCountG(const ProblemDescription& problem) + static auto GetGroupCountG(const miopen::conv::ProblemDescription& problem) { return problem.GetGroupCount(); } - static int GetBatchN(const ProblemDescription& problem) { return problem.GetBatchSize_(); } + static int GetBatchN(const miopen::conv::ProblemDescription& problem) + { + return problem.GetBatchSize_(); + } - static auto GetOutputLayout(const ProblemDescription& problem) + static auto GetOutputLayout(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutLayout(); else return problem.GetInLayout(); } - static int GetOutputChannelK(const ProblemDescription& problem) + static int GetOutputChannelK(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutChannels_(); else return problem.GetInChannels_(); } - static auto GetInputLayout(const ProblemDescription& problem) + static auto GetInputLayout(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInLayout(); else return problem.GetOutLayout(); } - static int GetInputChannelC(const ProblemDescription& problem) + static int GetInputChannelC(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInChannels_(); else return problem.GetOutChannels_(); } - static int GetInputDepthDi(const ProblemDescription& problem) + static int GetInputDepthDi(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInDepth_(); else return problem.GetOutDepth_(); } - static int GetInputHeightHi(const ProblemDescription& problem) + static int GetInputHeightHi(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInHeight_(); else return problem.GetOutHeight_(); } - static int GetInputWidthWi(const ProblemDescription& problem) + static int GetInputWidthWi(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInWidth_(); else return problem.GetOutWidth_(); } - static auto GetInputCastType(const ProblemDescription& problem) + static auto GetInputCastType(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetInCastType(); else return problem.GetOutCastType(); } - static int GetOutputDepthDo(const ProblemDescription& problem) + static int GetOutputDepthDo(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutDepth_(); else return problem.GetInDepth_(); } - static int GetOutputHeightHo(const ProblemDescription& problem) + static int GetOutputHeightHo(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutHeight_(); else return problem.GetInHeight_(); } - static int GetOutputWidthWo(const ProblemDescription& problem) + static int GetOutputWidthWo(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutWidth_(); else return problem.GetInWidth_(); } - static auto GetOutputCastType(const ProblemDescription& problem) + static auto GetOutputCastType(const miopen::conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) return problem.GetOutCastType(); else return problem.GetInCastType(); } - static auto GetOutputDataType(const ProblemDescription& problem) + static auto GetOutputDataType(const miopen::conv::ProblemDescription& problem) { - return problem.direction.IsForward() ? problem.GetOutDataType() : problem.GetInDataType(); + return problem.IsDirectionForward() ? problem.GetOutDataType() : problem.GetInDataType(); } - static auto GetInputDataType(const ProblemDescription& problem) + static auto GetInputDataType(const miopen::conv::ProblemDescription& problem) { - return problem.direction.IsForward() ? problem.GetInDataType() : problem.GetOutDataType(); + return problem.IsDirectionForward() ? problem.GetInDataType() : problem.GetOutDataType(); } - static int GetFilterDepthZ(const ProblemDescription& problem) + static int GetFilterDepthZ(const miopen::conv::ProblemDescription& problem) { return problem.GetWeightsDepth_(); } - static auto GetFilterLayout(const ProblemDescription& problem) + static auto GetFilterLayout(const miopen::conv::ProblemDescription& problem) { return problem.GetWeightsLayout(); } - static int GetFilterHeightY(const ProblemDescription& problem) + static int GetFilterHeightY(const miopen::conv::ProblemDescription& problem) { return problem.GetWeightsHeight_(); } - static int GetFilterWidthX(const ProblemDescription& problem) + static int GetFilterWidthX(const miopen::conv::ProblemDescription& problem) { return problem.GetWeightsWidth_(); } // adjust conv_stride_d to 1 if Do is 1 - static auto GetAdjustedConvolutionStrideD(const ProblemDescription& problem) + static auto GetAdjustedConvolutionStrideD(const miopen::conv::ProblemDescription& problem) { return GetOutputDepthDo(problem) > 1 ? problem.GetKernelStrideD() : 1; } // adjust conv_stride_h to 1 if Ho is 1 - static auto GetAdjustedConvolutionStrideH(const ProblemDescription& problem) + static auto GetAdjustedConvolutionStrideH(const miopen::conv::ProblemDescription& problem) { return GetOutputHeightHo(problem) > 1 ? problem.GetKernelStrideH() : 1; } // adjust conv_stride_w to 1 if Wo is 1 - static auto GetAdjustedConvolutionStrideW(const ProblemDescription& problem) + static auto GetAdjustedConvolutionStrideW(const miopen::conv::ProblemDescription& problem) { return GetOutputWidthWo(problem) > 1 ? problem.GetKernelStrideW() : 1; } // adjust conv_dilation_d to 1 if Z is 1 - static auto GetAdjustedConvolutionDilationD(const ProblemDescription& problem) + static auto GetAdjustedConvolutionDilationD(const miopen::conv::ProblemDescription& problem) { return GetFilterDepthZ(problem) > 1 ? problem.GetDilationD() : 1; } // adjust conv_dilation_h to 1 if Y is 1 - static auto GetAdjustedConvolutionDilationH(const ProblemDescription& problem) + static auto GetAdjustedConvolutionDilationH(const miopen::conv::ProblemDescription& problem) { return GetFilterHeightY(problem) > 1 ? problem.GetDilationH() : 1; } // adjust conv_dilation_w to 1 if X is 1 - static auto GetAdjustedConvolutionDilationW(const ProblemDescription& problem) + static auto GetAdjustedConvolutionDilationW(const miopen::conv::ProblemDescription& problem) { return GetFilterWidthX(problem) > 1 ? problem.GetDilationW() : 1; } - static auto GetInputLeftPadD(const ProblemDescription& problem) { return problem.GetPadD(); } + static auto GetInputLeftPadD(const miopen::conv::ProblemDescription& problem) + { + return problem.GetPadD(); + } - static auto GetInputLeftPadH(const ProblemDescription& problem) { return problem.GetPadH(); } + static auto GetInputLeftPadH(const miopen::conv::ProblemDescription& problem) + { + return problem.GetPadH(); + } - static auto GetInputLeftPadW(const ProblemDescription& problem) { return problem.GetPadW(); } + static auto GetInputLeftPadW(const miopen::conv::ProblemDescription& problem) + { + return problem.GetPadW(); + } // adjust right padding size so that filter will not move out-of-bound - static auto GetAdjustedInputRightPadD(const ProblemDescription& problem) + static auto GetAdjustedInputRightPadD(const miopen::conv::ProblemDescription& problem) { const int di = GetInputDepthDi(problem); const int dout = GetOutputDepthDo(problem); @@ -236,7 +245,7 @@ struct ProblemInterpreter } // adjust right padding size so that filter will not move out-of-bound - static auto GetAdjustedInputRightPadH(const ProblemDescription& problem) + static auto GetAdjustedInputRightPadH(const miopen::conv::ProblemDescription& problem) { const int hi = GetInputHeightHi(problem); const int ho = GetOutputHeightHo(problem); @@ -254,7 +263,7 @@ struct ProblemInterpreter } // adjust right padding size so that filter will not move out-of-bound - static auto GetAdjustedInputRightPadW(const ProblemDescription& problem) + static auto GetAdjustedInputRightPadW(const miopen::conv::ProblemDescription& problem) { const int wi = GetInputWidthWi(problem); const int wo = GetOutputWidthWo(problem); diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp index 522f5931b5..4d90a479e0 100644 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -65,163 +65,166 @@ miopen::PerformanceDb miopen::GetDb(const miopen::ExecutionContext& ctx) static auto GetGemmSolvers() { - return miopen::solver::SolverContainer{}; + miopen::solver::conv::GemmWrw1x1_stride1, + miopen::solver::conv::GemmWrwUniversal>{}; } static auto GetDirectSolvers() { - return miopen::solver::SolverContainer{}; + return miopen::solver::SolverContainer{}; } static auto GetImplicitGemmSolvers() { return miopen::solver::SolverContainer< - miopen::solver::ConvHipImplicitGemmForwardV4R5Xdlops, - miopen::solver::ConvHipImplicitGemmForwardV4R4Xdlops, - miopen::solver::ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm, - miopen::solver::ConvHipImplicitGemmBwdDataV4R1Xdlops, - miopen::solver::ConvHipImplicitGemmBwdDataV1R1Xdlops, - miopen::solver::ConvHipImplicitGemmV4R1Fwd, - miopen::solver::ConvHipImplicitGemmV4R4Fwd, - miopen::solver::ConvMlirIgemmFwdXdlops, - miopen::solver::ConvMlirIgemmFwd, - miopen::solver::ConvMlirIgemmBwdXdlops, - miopen::solver::ConvMlirIgemmBwd, - miopen::solver::ConvHipImplicitGemmBwdDataV1R1, - miopen::solver::ConvHipImplicitGemmBwdDataV4R1, - miopen::solver::ConvAsmImplicitGemmV4R1DynamicFwd_1x1, - miopen::solver::ConvAsmImplicitGemmV4R1DynamicFwd, - miopen::solver::ConvAsmImplicitGemmV4R1DynamicBwd, - miopen::solver::ConvAsmImplicitGemmGTCDynamicFwdXdlops, - miopen::solver::ConvAsmImplicitGemmGTCDynamicBwdXdlops, - miopen::solver::ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC, - miopen::solver::ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC, - miopen::solver::ConvCkIgemmFwdV6r1DlopsNchw, + miopen::solver::conv::ConvHipImplicitGemmForwardV4R5Xdlops, + miopen::solver::conv::ConvHipImplicitGemmForwardV4R4Xdlops, + miopen::solver::conv::ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm, + miopen::solver::conv::ConvHipImplicitGemmBwdDataV4R1Xdlops, + miopen::solver::conv::ConvHipImplicitGemmBwdDataV1R1Xdlops, + miopen::solver::conv::ConvHipImplicitGemmV4R1Fwd, + miopen::solver::conv::ConvHipImplicitGemmV4R4Fwd, + miopen::solver::conv::ConvMlirIgemmFwdXdlops, + miopen::solver::conv::ConvMlirIgemmFwd, + miopen::solver::conv::ConvMlirIgemmBwdXdlops, + miopen::solver::conv::ConvMlirIgemmBwd, + miopen::solver::conv::ConvHipImplicitGemmBwdDataV1R1, + miopen::solver::conv::ConvHipImplicitGemmBwdDataV4R1, + miopen::solver::conv::ConvAsmImplicitGemmV4R1DynamicFwd_1x1, + miopen::solver::conv::ConvAsmImplicitGemmV4R1DynamicFwd, + miopen::solver::conv::ConvAsmImplicitGemmV4R1DynamicBwd, + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicFwdXdlops, + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicBwdXdlops, + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC, + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC, + miopen::solver::conv::ConvCkIgemmFwdV6r1DlopsNchw, #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - miopen::solver::ConvHipImplicitGemmFwdXdlops, - miopen::solver::ConvHipImplicitGemmBwdXdlops, - miopen::solver::ConvHipImplicitGemmGroupFwdXdlops, - miopen::solver::ConvHipImplicitGemm3DGroupFwdXdlops, - miopen::solver::ConvHipImplicitGemm3DGroupBwdXdlops, + miopen::solver::conv::ConvHipImplicitGemmFwdXdlops, + miopen::solver::conv::ConvHipImplicitGemmBwdXdlops, + miopen::solver::conv::ConvHipImplicitGemmGroupFwdXdlops, + miopen::solver::conv::ConvHipImplicitGemm3DGroupFwdXdlops, + miopen::solver::conv::ConvHipImplicitGemm3DGroupBwdXdlops, #endif // MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - miopen::solver::ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC>{}; + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC>{}; } static auto GetWindogradSolvers() { - return miopen::solver::SolverContainer, - miopen::solver::ConvBinWinoRxS<2, 3>, - miopen::solver::ConvBinWinogradRxSf2x3g1, - miopen::solver::ConvBinWinogradRxS, - miopen::solver::ConvMPBidirectWinograd<3, 3>, - miopen::solver::ConvMPBidirectWinograd<4, 3>, - miopen::solver::ConvMPBidirectWinograd<5, 3>, - miopen::solver::ConvMPBidirectWinograd<6, 3>, - miopen::solver::ConvMPBidirectWinograd_xdlops<2, 3>, - miopen::solver::ConvMPBidirectWinograd_xdlops<3, 3>, - miopen::solver::ConvMPBidirectWinograd_xdlops<4, 3>, - miopen::solver::ConvMPBidirectWinograd_xdlops<5, 3>, - miopen::solver::ConvMPBidirectWinograd_xdlops<6, 3>, - miopen::solver::ConvWinoFuryRxS<2, 3>>{}; + return miopen::solver::SolverContainer< + miopen::solver::conv::ConvBinWinograd3x3U, + miopen::solver::conv::ConvBinWinoRxS<3, 2>, + miopen::solver::conv::ConvBinWinoRxS<2, 3>, + miopen::solver::conv::ConvBinWinogradRxSf2x3g1, + miopen::solver::conv::ConvBinWinogradRxS, + miopen::solver::conv::ConvMPBidirectWinograd<3, 3>, + miopen::solver::conv::ConvMPBidirectWinograd<4, 3>, + miopen::solver::conv::ConvMPBidirectWinograd<5, 3>, + miopen::solver::conv::ConvMPBidirectWinograd<6, 3>, + miopen::solver::conv::ConvMPBidirectWinograd_xdlops<2, 3>, + miopen::solver::conv::ConvMPBidirectWinograd_xdlops<3, 3>, + miopen::solver::conv::ConvMPBidirectWinograd_xdlops<4, 3>, + miopen::solver::conv::ConvMPBidirectWinograd_xdlops<5, 3>, + miopen::solver::conv::ConvMPBidirectWinograd_xdlops<6, 3>, + miopen::solver::conv::ConvWinoFuryRxS<2, 3>>{}; } static auto GetImplicitGemmWrWSolvers() { return miopen::solver::SolverContainer< - miopen::solver::ConvHipImplicitGemmWrwV4R4Xdlops, - miopen::solver::ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm, - miopen::solver::ConvHipImplicitGemmV4R1WrW, - miopen::solver::ConvHipImplicitGemmV4R4WrW, - miopen::solver::ConvAsmImplicitGemmV4R1DynamicWrw, - miopen::solver::ConvMlirIgemmWrWXdlops, - miopen::solver::ConvMlirIgemmWrW, - miopen::solver::ConvAsmImplicitGemmGTCDynamicWrwXdlops, + miopen::solver::conv::ConvHipImplicitGemmWrwV4R4Xdlops, + miopen::solver::conv::ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm, + miopen::solver::conv::ConvHipImplicitGemmV4R1WrW, + miopen::solver::conv::ConvHipImplicitGemmV4R4WrW, + miopen::solver::conv::ConvAsmImplicitGemmV4R1DynamicWrw, + miopen::solver::conv::ConvMlirIgemmWrWXdlops, + miopen::solver::conv::ConvMlirIgemmWrW, + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicWrwXdlops, #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - miopen::solver::ConvHipImplicitGemm3DGroupWrwXdlops, + miopen::solver::conv::ConvHipImplicitGemm3DGroupWrwXdlops, #endif // MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - miopen::solver::ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC>{}; + miopen::solver::conv::ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC>{}; } static auto GetWindogradWrWSolvers() { - return miopen::solver::SolverContainer, - miopen::solver::ConvBinWinoRxS<2, 3>, - miopen::solver::ConvBinWinogradRxSf2x3g1, - miopen::solver::ConvWinograd3x3MultipassWrW<3, 2>, - miopen::solver::ConvWinograd3x3MultipassWrW<3, 3>, - miopen::solver::ConvWinograd3x3MultipassWrW<3, 4>, - miopen::solver::ConvWinograd3x3MultipassWrW<3, 5>, - miopen::solver::ConvWinograd3x3MultipassWrW<3, 6>, - miopen::solver::ConvWinograd3x3MultipassWrW<7, 2>, - miopen::solver::ConvWinograd3x3MultipassWrW<7, 3>, - miopen::solver::ConvWinograd3x3MultipassWrW<7, 3, 1, 1>, - miopen::solver::ConvWinograd3x3MultipassWrW<7, 2, 1, 1>, - miopen::solver::ConvWinograd3x3MultipassWrW<1, 1, 7, 2>, - miopen::solver::ConvWinograd3x3MultipassWrW<1, 1, 7, 3>, - miopen::solver::ConvWinograd3x3MultipassWrW<5, 3>, - miopen::solver::ConvWinograd3x3MultipassWrW<5, 4>>{}; + return miopen::solver::SolverContainer< + miopen::solver::conv::ConvBinWinogradRxS, + miopen::solver::conv::ConvBinWinoRxS<3, 2>, + miopen::solver::conv::ConvBinWinoRxS<2, 3>, + miopen::solver::conv::ConvBinWinogradRxSf2x3g1, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<3, 2>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<3, 3>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<3, 4>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<3, 5>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<3, 6>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<7, 2>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<7, 3>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<7, 3, 1, 1>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<7, 2, 1, 1>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<1, 1, 7, 2>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<1, 1, 7, 3>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<5, 3>, + miopen::solver::conv::ConvWinograd3x3MultipassWrW<5, 4>>{}; } static auto GetBwdWrW2DSolvers() { - return miopen::solver::SolverContainer, - miopen::solver::ConvOclBwdWrW2<2>, - miopen::solver::ConvOclBwdWrW2<4>, - miopen::solver::ConvOclBwdWrW2<8>, - miopen::solver::ConvOclBwdWrW2<16>, - miopen::solver::ConvOclBwdWrW2NonTunable, - miopen::solver::ConvOclBwdWrW53, - miopen::solver::ConvOclBwdWrW1x1, - miopen::solver::ConvDirectNaiveConvFwd, - miopen::solver::ConvDirectNaiveConvBwd, - miopen::solver::ConvDirectNaiveConvWrw>{}; + return miopen::solver::SolverContainer, + miopen::solver::conv::ConvOclBwdWrW2<2>, + miopen::solver::conv::ConvOclBwdWrW2<4>, + miopen::solver::conv::ConvOclBwdWrW2<8>, + miopen::solver::conv::ConvOclBwdWrW2<16>, + miopen::solver::conv::ConvOclBwdWrW2NonTunable, + miopen::solver::conv::ConvOclBwdWrW53, + miopen::solver::conv::ConvOclBwdWrW1x1, + miopen::solver::conv::ConvDirectNaiveConvFwd, + miopen::solver::conv::ConvDirectNaiveConvBwd, + miopen::solver::conv::ConvDirectNaiveConvWrw>{}; } -static auto GetFFTSolvers() { return miopen::solver::SolverContainer{}; } +static auto GetFFTSolvers() { return miopen::solver::SolverContainer{}; } std::vector FindAllGemmSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetGemmSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); } std::vector> -AllGemmWorkspaceSize(const miopen::ExecutionContext& ctx, const miopen::ProblemDescription& problem) +AllGemmWorkspaceSize(const miopen::ExecutionContext& ctx, + const miopen::conv::ProblemDescription& problem) { return GetGemmSolvers().GetWorkspaceSizes(ctx, problem); } std::vector FindAllDirectSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetDirectSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); @@ -229,28 +232,28 @@ FindAllDirectSolutions(const miopen::ExecutionContext& ctx, std::vector> AllDirectForwardBackwardDataWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { return GetDirectSolvers().GetWorkspaceSizes(ctx, problem); } std::vector> FindAllWinogradWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { return GetWindogradSolvers().GetWorkspaceSizes(ctx, problem); } std::vector> FindWinogradWrWWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { return GetWindogradWrWSolvers().GetWorkspaceSizes(ctx, problem); } std::vector> FindAllImplicitGemmWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { #if WORKAROUND_SWDEV_227826 if(miopen::IsEnabled(MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS{})) @@ -264,7 +267,7 @@ FindAllImplicitGemmWorkspaceSizes(const miopen::ExecutionContext& ctx, std::vector FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { #if WORKAROUND_SWDEV_227826 @@ -280,7 +283,7 @@ FindAllImplicitGemmSolutions(const miopen::ExecutionContext& ctx, std::vector FindAllWinogradSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetWindogradSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); @@ -288,7 +291,7 @@ FindAllWinogradSolutions(const miopen::ExecutionContext& ctx, std::vector FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetWindogradWrWSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); @@ -296,14 +299,14 @@ FindWinogradWrWAllSolutions(const miopen::ExecutionContext& ctx, std::vector> AllDirectBwdWrW2DWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { return GetBwdWrW2DSolvers().GetWorkspaceSizes(ctx, problem); } std::vector> FindImplicitGemmWrWWorkspaceSizes(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { #if WORKAROUND_SWDEV_227826 if(miopen::IsEnabled(MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS{})) @@ -317,7 +320,7 @@ FindImplicitGemmWrWWorkspaceSizes(const miopen::ExecutionContext& ctx, std::vector FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { #if WORKAROUND_SWDEV_227826 @@ -334,7 +337,7 @@ FindImplicitGemmWrWAllSolutions(const miopen::ExecutionContext& ctx, std::vector FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetBwdWrW2DSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); @@ -342,7 +345,7 @@ FindAllBwdWrW2DSolutions(const miopen::ExecutionContext& ctx, std::vector FindAllFFTSolutions(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const miopen::AnyInvokeParams& invoke_ctx) { return GetFFTSolvers().SearchForAllSolutions(ctx, problem, GetDb(ctx), invoke_ctx); @@ -350,7 +353,7 @@ FindAllFFTSolutions(const miopen::ExecutionContext& ctx, std::vector> AllFFTForwardBackwardDataWorkspaceSize(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem) + const miopen::conv::ProblemDescription& problem) { return GetFFTSolvers().GetWorkspaceSizes(ctx, problem); } diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp index 19c65ca32d..29650085d1 100644 --- a/src/ocl/convolutionocl.cpp +++ b/src/ocl/convolutionocl.cpp @@ -139,12 +139,10 @@ static Invoker PrepareInvoker(ExecutionContext ctx, problem.SetupFloats(ctx); ctx.do_search = false; - const auto legacy_problem = ProblemDescription{problem}; - const auto solver = solver_id.GetSolver(); - auto db = GetDb(ctx); - auto solution = - solver.FindSolution(ctx, legacy_problem, db, {}); // auto tune is not expected here - auto& handle = ctx.GetStream(); + const auto solver = solver_id.GetSolver(); + auto db = GetDb(ctx); + auto solution = solver.FindSolution(ctx, problem, db, {}); // auto tune is not expected here + auto& handle = ctx.GetStream(); auto invoker = handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params); const auto algo = AlgorithmName{solver_id.GetAlgo(problem.GetDirection())}; @@ -225,11 +223,10 @@ static inline std::vector FindConvolution(const ExecutionContext& ctx results = UserFindDbRecord::TryLoad(ctx.GetStream(), problem, [&](DbRecord& record) { auto ctx_copy = ctx; ctx_copy.use_dynamic_solutions_only = findMode.IsDynamicHybrid(ctx); - auto legacy_problem = ProblemDescription(problem); const auto params = - ConvFindParameters{conv.IsWinograd3x3SupportedAndFast(ctx_copy, legacy_problem)}; + conv::ConvFindParameters{conv.IsWinograd3x3SupportedAndFast(ctx_copy, problem)}; - FindCore(invoke_ctx, record, ctx_copy, legacy_problem, params, GetConvSolverFinders()); + FindCore(invoke_ctx, record, ctx_copy, problem, params, conv::GetConvSolverFinders()); }); } @@ -535,12 +532,12 @@ static const char immFallbackFailed[] = "Requested convolution is not supported or Immediate mode Fallback unsuccessful."; std::size_t -ConvolutionDescriptor::GetSolutionCountFallback(const ExecutionContext& exec_ctx, +ConvolutionDescriptor::GetSolutionCountFallback(const ExecutionContext& ctx, const conv::ProblemDescription& problem) const { const auto maxSolutionCount = solver::GetSolversByPrimitive(solver::Primitive::Convolution) .size(); // Simple and guarantees to provide enough space. - const auto n = GetSolutionsFallback(exec_ctx, problem, maxSolutionCount).size(); + const auto n = GetSolutionsFallback(ctx, problem, maxSolutionCount).size(); if(n > 0) return n; MIOPEN_LOG_I(immFallbackFailed); @@ -558,14 +555,14 @@ ConvolutionDescriptor::GetSolutionCountFallback(const ExecutionContext& exec_ctx MIOPEN_THROW(miopenStatusNotImplemented, immFallbackFailed); } -std::size_t ConvolutionDescriptor::GetSolutionCount(const ExecutionContext& exec_ctx, +std::size_t ConvolutionDescriptor::GetSolutionCount(const ExecutionContext& ctx, const conv::ProblemDescription& problem) const { MIOPEN_LOG_I(""); - const auto n = miopen::GetSolutionCount(exec_ctx.GetStream(), problem); + const auto n = miopen::GetSolutionCount(ctx.GetStream(), problem); if(n > 0) return n; - return GetSolutionCountFallback(exec_ctx, problem); + return GetSolutionCountFallback(ctx, problem); } struct SolutionTimeComparator @@ -597,9 +594,6 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx, return {}; } - /// \todo This is terrible. Should do away when we converge to - /// single conv::ProblemDescription type. - const auto legacy_problem = ProblemDescription{problem}; const auto& xDesc = (problem.GetDirection() == conv::Direction::Forward) ? problem.GetIn() : problem.GetOut(); const auto& weightsDesc = problem.GetWeights(); @@ -615,7 +609,7 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx, if(!miopen::IsDisabled(MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK{})) { const static std::string arch = ctx.GetStream().GetDeviceName(); - auto solvers = ai::immed_mode::PredictSolver(legacy_problem, ctx, arch); + auto solvers = ai::immed_mode::PredictSolver(problem, ctx, arch); if(!solvers.empty()) { MIOPEN_LOG_I2("Using TunaNet Fallback"); @@ -628,7 +622,7 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx, const auto solver_id = solver::Id{kinder}; const auto sol = solver_id.GetSolver(); const auto algo = solver_id.GetAlgo(); - if(IsAlgorithmDisabled(algo)) + if(conv::IsAlgorithmDisabled(algo)) continue; if(!sol.IsDynamic()) continue; // branch should never be taken @@ -659,7 +653,7 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx, // solver_id is always valid here, because taken from registry. // Validity check is not required. const auto algo = solver_id.GetAlgo(); - if(IsAlgorithmDisabled(algo)) // Algos can be disabled globally. + if(conv::IsAlgorithmDisabled(algo)) // Algos can be disabled globally. continue; const auto& s = solver_id.GetSolver(); // Let's allow non-dynamic later, if necessary. @@ -685,6 +679,8 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx, return interim; } +namespace { + std::vector GetSolutions(const ExecutionContext& ctx, const conv::ProblemDescription& problem, const size_t maxSolutionCount) @@ -711,7 +707,7 @@ std::vector GetSolutions(const ExecutionContext& ctx, for(const auto& pair : fdb_record) { const auto algo = static_cast(algo_resolver(pair.second.algorithm)); - if(IsAlgorithmDisabled(algo)) + if(conv::IsAlgorithmDisabled(algo)) continue; const auto solver_id = solver::Id{pair.first}; @@ -744,18 +740,20 @@ std::vector GetSolutions(const ExecutionContext& ctx, return interim; } +} // namespace + /// \todo Extend miopenConvSolution_t with an attribute indicating /// how the solution was obtained (benchmarked on the current system, /// taken from the System find-db, heuristically estimated, produced by /// MLP classifier...) and then remove the fallbackPathTaken out param. std::vector -ConvolutionDescriptor::GetSolutions(const ExecutionContext& exec_ctx, +ConvolutionDescriptor::GetSolutions(const ExecutionContext& ctx, const conv::ProblemDescription& problem, size_t maxSolutionCount, bool* fallbackPathTaken) const { MIOPEN_LOG_I(""); - auto solutions = miopen::GetSolutions(exec_ctx, problem, maxSolutionCount); + auto solutions = miopen::GetSolutions(ctx, problem, maxSolutionCount); if(fallbackPathTaken != nullptr) *fallbackPathTaken = solutions.empty(); @@ -763,8 +761,9 @@ ConvolutionDescriptor::GetSolutions(const ExecutionContext& exec_ctx, if(!solutions.empty()) return solutions; - return GetSolutionsFallback(exec_ctx, problem, maxSolutionCount); + return GetSolutionsFallback(ctx, problem, maxSolutionCount); } + std::size_t ConvolutionDescriptor::GetForwardSolutionWorkspaceSize(Handle& handle, const TensorDescriptor& wDesc, const TensorDescriptor& xDesc, @@ -891,6 +890,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle, MIOPEN_LOG_I("BWD Chosen Algorithm: " << results[0].solver_id << " , " << results[0].workspace << ", " << results[0].time); } + static void ConvBwdCheckNumerics(const Handle& handle, const ConvBwdTensors& tensors, const void* beta, diff --git a/src/ocl/fusionopconvocl.cpp b/src/ocl/fusionopconvocl.cpp index b4ea5b42aa..61ea49835f 100644 --- a/src/ocl/fusionopconvocl.cpp +++ b/src/ocl/fusionopconvocl.cpp @@ -2,7 +2,7 @@ namespace miopen { -ProblemDescription ConvForwardOpDescriptor::GetConvProblem() +conv::ProblemDescription ConvForwardOpDescriptor::GetConvProblem() { TensorDescriptor o_desc; GetOutputDesc(o_desc); @@ -15,7 +15,7 @@ ProblemDescription ConvForwardOpDescriptor::GetConvProblem() miopenStatus_t ConvForwardOpDescriptor::GetNetworkConfig(std::ostringstream& network_config) { - ProblemDescription conv_problem = GetConvProblem(); + const conv::ProblemDescription conv_problem = GetConvProblem(); std::string conv_config; conv_problem.MakeNetworkConfig(conv_config); diff --git a/src/ocl/mloNorm.cpp b/src/ocl/mloNorm.cpp index 92d45375ec..ea96da6439 100644 --- a/src/ocl/mloNorm.cpp +++ b/src/ocl/mloNorm.cpp @@ -33,7 +33,7 @@ void mlo_construct_norm::mloConstruct() { - if(_problem.direction.IsForward()) + if(_problem.IsDirectionForward()) { mloConstructFwd(); } diff --git a/src/problem.cpp b/src/problem.cpp index 0eaa7b7442..99eb994921 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -386,9 +386,8 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, : conv::Direction::Forward; })(); - const auto legacy_problem = ProblemDescription{conv_problem}; - const auto netcfg = conv_problem.MakeNetworkConfig(); - auto conv_ctx = ExecutionContext{&handle}; + const auto netcfg = conv_problem.MakeNetworkConfig(); + auto conv_ctx = ExecutionContext{&handle}; conv_problem.SetupFloats(conv_ctx); decltype(auto) db = GetDb(conv_ctx); @@ -403,7 +402,7 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, solution.SetWorkspaceSize(find1_solutions[i].memory); solution.SetSolver(handle.GetFound1_0SolverId(netcfg, AlgorithmName{algo}).value()); solution.SetPerfConfig( - solution.GetSolver().GetSolver().GetPerfCfgParams(conv_ctx, legacy_problem, db)); + solution.GetSolver().GetSolver().GetPerfCfgParams(conv_ctx, conv_problem, db)); solution.SetProblem(*this); MIOPEN_LOG_I("Found solution: " << solution.GetSolver().ToString() << " , " << solution.GetWorkspaceSize() << ", " diff --git a/src/problem_description.cpp b/src/problem_description.cpp deleted file mode 100644 index 3d05031bbc..0000000000 --- a/src/problem_description.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -namespace miopen { - -ProblemDescription::ProblemDescription(conv::ProblemDescription desc) - : conv::ProblemDescription(std::move(desc)), direction(GetDirection()) -{ -#if FIN_OLD_PROBLEM_DESCRIPTION_COMPAT - conv_problem.p = this; -#endif -} - -} // namespace miopen diff --git a/src/solution.cpp b/src/solution.cpp index 1997d5f8a1..aa4907079f 100644 --- a/src/solution.cpp +++ b/src/solution.cpp @@ -179,13 +179,12 @@ void Solution::RunImpl(Handle& handle, return; } - const auto legacy_problem = ProblemDescription{conv_problem}; - auto conv_ctx = ExecutionContext{&handle}; + auto conv_ctx = ExecutionContext{&handle}; conv_problem.SetupFloats(conv_ctx); decltype(auto) db = GetDb(conv_ctx); const auto conv_solution = GetSolver().GetSolver().FindSolution( - conv_ctx, legacy_problem, db, invoke_ctx, perf_cfg.value_or("")); + conv_ctx, conv_problem, db, invoke_ctx, perf_cfg.value_or("")); decltype(auto) invoker = handle.PrepareInvoker(*conv_solution.invoker_factory, conv_solution.construction_params); handle.RegisterInvoker(invoker, net_cfg, GetSolver().ToString()); diff --git a/src/solver.cpp b/src/solver.cpp index 4cd680dd9c..01835dcb1c 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -175,7 +175,7 @@ AnySolver Id::GetSolver() const return it != IdRegistry().value_to_entry.end() ? it->second.solver : AnySolver{}; } -std::string Id::GetAlgo(conv::Direction dir) const +std::string Id::GetAlgo(miopen::conv::Direction dir) const { return ConvolutionAlgoToDirectionalString(GetAlgo(), dir); } @@ -274,220 +274,251 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) uint64_t id = 0; // 0 is reserved for invalid value. // IMPORTANT: New solvers should be added to the end of the function! - RegisterWithSolver(registry, ++id, ConvAsm3x3U{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvAsm1x1U{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvAsm1x1UV2{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsm3x3U{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsm1x1U{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsm1x1UV2{}, miopenConvolutionAlgoDirect); Register(registry, ++id, Primitive::Fusion, - solver::fusion::ConvBiasActivAsm1x1U{}.SolverDbId(), + fusion::ConvBiasActivAsm1x1U{}.SolverDbId(), miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvAsm5x10u2v2f1{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvAsm5x10u2v2b1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsm5x10u2v2f1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsm5x10u2v2b1{}, miopenConvolutionAlgoDirect); RegisterWithSolver( - registry, ++id, ConvAsm7x7c3h224w224k64u2v2p3q3f1{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclDirectFwd11x11{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclDirectFwdGen{}, miopenConvolutionAlgoDirect); + registry, ++id, conv::ConvAsm7x7c3h224w224k64u2v2p3q3f1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclDirectFwd11x11{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclDirectFwdGen{}, miopenConvolutionAlgoDirect); ++id; // removed ConvOclDirectFwd3x3 - RegisterWithSolver(registry, ++id, ConvOclDirectFwd{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclDirectFwd{}, miopenConvolutionAlgoDirect); Register(registry, ++id, Primitive::Fusion, - solver::fusion::ConvOclDirectFwdFused{}.SolverDbId(), + fusion::ConvOclDirectFwdFused{}.SolverDbId(), miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclDirectFwd1x1{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvBinWinograd3x3U{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver(registry, ++id, ConvBinWinogradRxS{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver(registry, ++id, ConvAsmBwdWrW3x3{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvAsmBwdWrW1x1{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2<1>{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2<2>{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2<4>{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2<8>{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2<16>{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW2NonTunable{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW53{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvOclBwdWrW1x1{}, miopenConvolutionAlgoDirect); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmV4R1Fwd{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, ++id, conv::ConvOclDirectFwd1x1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvBinWinograd3x3U{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, ++id, conv::ConvBinWinogradRxS{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, ++id, conv::ConvAsmBwdWrW3x3{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvAsmBwdWrW1x1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW2<1>{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW2<2>{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW2<4>{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW2<8>{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW2<16>{}, miopenConvolutionAlgoDirect); + RegisterWithSolver( + registry, ++id, conv::ConvOclBwdWrW2NonTunable{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW53{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvOclBwdWrW1x1{}, miopenConvolutionAlgoDirect); + RegisterWithSolver( + registry, ++id, conv::ConvHipImplicitGemmV4R1Fwd{}, miopenConvolutionAlgoImplicitGEMM); ++id; // removed solver ConvHipImplicitGemmV4Fwd ++id; // removed solver ConvHipImplicitGemmV4_1x1 ++id; // removed solver ConvHipImplicitGemmV4R4FwdXdlops ++id; // removed solver ConvHipImplicitGemmV4R4Xdlops_1x1 RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmV4R1WrW{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmV4R1WrW{}, miopenConvolutionAlgoImplicitGEMM); ++id; // removed solver ConvHipImplicitGemmV4WrW // Several ids w/o solver for immediate mode ++id; // old gemm pseudo-solverid - RegisterWithSolver(registry, ++id, fft{}, miopenConvolutionAlgoFFT); + RegisterWithSolver(registry, ++id, conv::fft{}, miopenConvolutionAlgoFFT); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<3, 4>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<3, 4>{}, miopenConvolutionAlgoWinograd); ++id; // Id for ConvSCGemmFGemm. - RegisterWithSolver(registry, ++id, ConvBinWinoRxS<3, 2>{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, ++id, conv::ConvBinWinoRxS<3, 2>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<3, 5>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<3, 5>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<3, 6>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<3, 6>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<3, 2>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<3, 2>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<3, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<3, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<7, 2>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<7, 2>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<7, 3>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<7, 2, 1, 1>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<7, 3, 1, 1>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<1, 1, 7, 2>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<1, 1, 7, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<7, 3>{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, + ++id, + conv::ConvWinograd3x3MultipassWrW<7, 2, 1, 1>{}, + miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, + ++id, + conv::ConvWinograd3x3MultipassWrW<7, 3, 1, 1>{}, + miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, + ++id, + conv::ConvWinograd3x3MultipassWrW<1, 1, 7, 2>{}, + miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, + ++id, + conv::ConvWinograd3x3MultipassWrW<1, 1, 7, 3>{}, + miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<5, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<5, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvWinograd3x3MultipassWrW<5, 4>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvWinograd3x3MultipassWrW<5, 4>{}, miopenConvolutionAlgoWinograd); ++id; // removed solver ConvHipImplicitGemmV4R4WrWXdlops ++id; // removed solver ConvHipImplicitGemmV4R4GenFwdXdlops ++id; // removed solver ConvHipImplicitGemmV4R4GenWrWXdlops - RegisterWithSolver(registry, ++id, ConvBinWinoRxS<2, 3>{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, ++id, conv::ConvBinWinoRxS<2, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmV4R4Fwd{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmV4R4Fwd{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmBwdDataV1R1{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmBwdDataV1R1{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmBwdDataV4R1{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmBwdDataV4R1{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmBwdDataV1R1Xdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmBwdDataV1R1Xdlops{}, + miopenConvolutionAlgoImplicitGEMM); ++id; // removed solver ConvHipImplicitGemmV4R4GenXdlopsFwdFp32 ++id; // removed solver ConvHipImplicitGemmV4R4GenXdlopsWrWFp32 - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmBwdDataV4R1Xdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmBwdDataV4R1Xdlops{}, + miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmV4R4WrW{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmV4R4WrW{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvAsmImplicitGemmV4R1DynamicFwd{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvAsmImplicitGemmV4R1DynamicFwd{}, + miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvAsmImplicitGemmV4R1DynamicFwd_1x1{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvAsmImplicitGemmV4R1DynamicFwd_1x1{}, + miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmForwardV4R4Xdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmForwardV4R4Xdlops{}, + miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvAsmImplicitGemmV4R1DynamicBwd{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvAsmImplicitGemmV4R1DynamicBwd{}, + miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvAsmImplicitGemmV4R1DynamicWrw{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvAsmImplicitGemmV4R1DynamicWrw{}, + miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd<2, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd<2, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd<3, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd<3, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd<4, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd<4, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd<5, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd<5, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd<6, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd<6, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicWrwXdlops{}, + conv::ConvAsmImplicitGemmGTCDynamicWrwXdlops{}, + miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmWrwV4R4Xdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmWrwV4R4Xdlops{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicFwdXdlops{}, + conv::ConvAsmImplicitGemmGTCDynamicFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd_xdlops<2, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd_xdlops<2, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd_xdlops<3, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd_xdlops<3, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd_xdlops<4, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd_xdlops<4, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd_xdlops<5, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd_xdlops<5, 3>{}, miopenConvolutionAlgoWinograd); RegisterWithSolver( - registry, ++id, ConvMPBidirectWinograd_xdlops<6, 3>{}, miopenConvolutionAlgoWinograd); + registry, ++id, conv::ConvMPBidirectWinograd_xdlops<6, 3>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmForwardV4R5Xdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmForwardV4R5Xdlops{}, + miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver(registry, ++id, - ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm{}, + conv::ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicBwdXdlops{}, + conv::ConvAsmImplicitGemmGTCDynamicBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver(registry, ++id, - ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm{}, + conv::ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvBinWinogradRxSf2x3g1{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver( + registry, ++id, conv::ConvBinWinogradRxSf2x3g1{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver(registry, ++id, ConvDirectNaiveConvFwd{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvDirectNaiveConvBwd{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, ConvDirectNaiveConvWrw{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvDirectNaiveConvFwd{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvDirectNaiveConvBwd{}, miopenConvolutionAlgoDirect); + RegisterWithSolver(registry, ++id, conv::ConvDirectNaiveConvWrw{}, miopenConvolutionAlgoDirect); - RegisterWithSolver(registry, ++id, GemmFwd1x1_0_1{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmFwd1x1_0_1_int8{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmFwd1x1_0_2{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmFwdRest{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmFwd1x1_0_1{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmFwd1x1_0_1_int8{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmFwd1x1_0_2{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmFwdRest{}, miopenConvolutionAlgoGEMM); ++id; // removed solver ConvHipImplicitGemmMlirCppFwd ++id; // removed solver ConvHipImplicitGemmMlirCppBwd ++id; // removed solver ConvHipImplicitGemmMlirCppWrW - RegisterWithSolver(registry, ++id, GemmBwd1x1_stride2{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmBwd1x1_stride1{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmBwdRest{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmBwd1x1_stride2{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmBwd1x1_stride1{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmBwdRest{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmFwd{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmBwd{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmWrW{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, ++id, conv::ConvMlirIgemmFwd{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, ++id, conv::ConvMlirIgemmBwd{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, ++id, conv::ConvMlirIgemmWrW{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, GemmWrw1x1_stride1{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, GemmWrwUniversal{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmWrw1x1_stride1{}, miopenConvolutionAlgoGEMM); + RegisterWithSolver(registry, ++id, conv::GemmWrwUniversal{}, miopenConvolutionAlgoGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvMlirIgemmWrWXdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver( + registry, ++id, conv::ConvMlirIgemmFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver( + registry, ++id, conv::ConvMlirIgemmBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver( + registry, ++id, conv::ConvMlirIgemmWrWXdlops{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Activation, activ::ActivFwdSolver0{}.SolverDbId()); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC{}, + conv::ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC{}, + conv::ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Activation, activ::ActivFwdSolver1{}.SolverDbId()); RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC{}, + conv::ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Activation, activ::ActivBwdSolver0{}.SolverDbId()); @@ -497,7 +528,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) registry, ++id, Primitive::Batchnorm, batchnorm::BnFwdTrainingSpatialSingle{}.SolverDbId()); RegisterWithSolver( - registry, ++id, ConvCkIgemmFwdV6r1DlopsNchw{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvCkIgemmFwdV6r1DlopsNchw{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, @@ -529,45 +560,49 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) RegisterWithSolver(registry, ++id, - ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC{}, + conv::ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvHipImplicitGemmBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Fusion, - solver::fusion::ConvBinWinogradRxSFused{}.SolverDbId(), + fusion::ConvBinWinogradRxSFused{}.SolverDbId(), miopenConvolutionAlgoWinograd); Register(registry, ++id, Primitive::Fusion, - solver::fusion::ConvBinWinogradRxSf2x3g1Fused{}.SolverDbId(), + fusion::ConvBinWinogradRxSf2x3g1Fused{}.SolverDbId(), miopenConvolutionAlgoWinograd); + Register(registry, ++id, Primitive::Fusion, fusion::BnFwdInferActivationFused{}.SolverDbId()); + Register(registry, ++id, Primitive::Fusion, fusion::BnFwdTrgActivationFused{}.SolverDbId()); + Register(registry, ++id, Primitive::Fusion, fusion::BnBwdTrgActivationFused{}.SolverDbId()); Register(registry, ++id, Primitive::Fusion, - solver::fusion::BnFwdInferActivationFused{}.SolverDbId()); - Register( - registry, ++id, Primitive::Fusion, solver::fusion::BnFwdTrgActivationFused{}.SolverDbId()); - Register( - registry, ++id, Primitive::Fusion, solver::fusion::BnBwdTrgActivationFused{}.SolverDbId()); - Register(registry, - ++id, - Primitive::Fusion, - solver::fusion::ConvCKIgemmFwdBiasActivFused{}.SolverDbId(), + fusion::ConvCKIgemmFwdBiasActivFused{}.SolverDbId(), miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Pooling, pooling::PoolingForwardNaive{}.SolverDbId()); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemmGroupFwdXdlops{}, + miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemm3DGroupFwdXdlops{}, + miopenConvolutionAlgoImplicitGEMM); RegisterWithSolver( - registry, ++id, ConvHipImplicitGemmGroupFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemm3DGroupFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver(registry, ++id, ConvWinoFuryRxS<2, 3>{}, miopenConvolutionAlgoWinograd); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemm3DGroupWrwXdlops{}, miopenConvolutionAlgoImplicitGEMM); - RegisterWithSolver( - registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); + registry, ++id, conv::ConvWinoFuryRxS<2, 3>{}, miopenConvolutionAlgoWinograd); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemm3DGroupWrwXdlops{}, + miopenConvolutionAlgoImplicitGEMM); + RegisterWithSolver(registry, + ++id, + conv::ConvHipImplicitGemm3DGroupBwdXdlops{}, + miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp index 1ecc210144..c70ea319bd 100644 --- a/src/solver/conv_MP_bidirectional_winograd.cpp +++ b/src/solver/conv_MP_bidirectional_winograd.cpp @@ -59,6 +59,10 @@ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; + MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3) @@ -80,14 +84,13 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING) // Introduces a number of shader-specific aliases (names) in the current scope at zero cost. // These names represent shader parameters, e.g. shader C is batch_size etc and useful for // programming. -#define DEFINE_GETXFORMHWSIZE() \ - const auto \ - wino_xform_h = \ - solver::ConvMPBidirectWinograd:: \ - GetSolverWinoXformHWSize(), \ - wino_xform_w = \ - solver::ConvMPBidirectWinograd:: \ - GetSolverWinoXformHWSize(); +#define DEFINE_GETXFORMHWSIZE() \ + const auto wino_xform_h = \ + ConvMPBidirectWinograd:: \ + GetSolverWinoXformHWSize(), \ + wino_xform_w = \ + ConvMPBidirectWinograd:: \ + GetSolverWinoXformHWSize(); #define DEFINE_SHADER_ALIASES(problem) \ const auto group_cnt = (problem).GetGroupCount(); \ @@ -186,7 +189,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc return false; if(!problem.Is2d()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(!(problem.IsFp32() || problem.IsFp16())) return false; @@ -263,7 +266,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc group_cnt, GetTypeSize(problem.GetOutDataType())), // cppcheck-suppress unreadVariable - wei_buff(GetGroupConvLayout(problem.direction.IsForward() + wei_buff(GetGroupConvLayout(problem.IsDirectionForward() ? (MemLayout_t::NCHW) : GetSwappedNCLayout(MemLayout_t::NCHW), false), @@ -382,8 +385,8 @@ static InvokerFactory MakeWinogradInvokerFactory(const ExecutionContext& ctx, bool isXdlops = false) { #if MIOPEN_BACKEND_HIP - const int pad_H = problem.direction.IsForward() ? problem.GetPadH() : problem.GetBackwardPadH(); - const int pad_W = problem.direction.IsForward() ? problem.GetPadW() : problem.GetBackwardPadW(); + const int pad_H = problem.IsDirectionForward() ? problem.GetPadH() : problem.GetBackwardPadH(); + const int pad_W = problem.IsDirectionForward() ? problem.GetPadW() : problem.GetBackwardPadW(); const int n_groups = ctx.GetStream().GetMaxComputeUnits(); DEFINE_SHADER_ALIASES(problem) DEFINE_GETXFORMHWSIZE() @@ -403,7 +406,7 @@ static InvokerFactory MakeWinogradInvokerFactory(const ExecutionContext& ctx, group_cnt, GetTypeSize(problem.GetOutDataType())), // cppcheck-suppress unreadVariable - weights_buff(GetGroupConvLayout(problem.direction.IsForward() + weights_buff(GetGroupConvLayout(problem.IsDirectionForward() ? (MemLayout_t::NCHW) : GetSwappedNCLayout(MemLayout_t::NCHW), false), @@ -468,8 +471,9 @@ static InvokerFactory MakeWinogradInvokerFactory(const ExecutionContext& ctx, gemm_conv_factory = [=](const std::vector&) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { #if MIOPEN_USE_ROCBLAS - const auto& data_ctx = primitive_parameters.CastTo(); - Data_t workSpace = data_ctx.workSpace; + const auto& data_ctx = + primitive_parameters.CastTo(); + Data_t workSpace = data_ctx.workSpace; CallGemmStridedBatched( handle, wino_gemm_desc, @@ -499,7 +503,7 @@ static InvokerFactory MakeWinogradInvokerFactory(const ExecutionContext& ctx, auto gemm_conv_invoker = gemm_conv_factory(conv_kernels); return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - const auto& data_ctx = primitive_parameters.CastTo(); + const auto& data_ctx = primitive_parameters.CastTo(); const auto tensors = data_ctx.tensors; Data_t workSpace = data_ctx.workSpace; auto workSpaceSize = data_ctx.workSpaceSize; @@ -520,7 +524,7 @@ static InvokerFactory MakeWinogradInvokerFactory(const ExecutionContext& ctx, // xdlops_conv use tensors.in, tensors.w, tensors.out ConvDataTensors xdlops_tensor = ConvDataTensors(ConvFwdTensors{ zeroDesc, wino_in_ptr, zeroDesc, wino_w_ptr, zeroDesc, wino_out_ptr}); - const auto invoke_params = conv::DataInvokeParams{ + const auto invoke_params = miopen::conv::DataInvokeParams{ xdlops_tensor, workSpace, workSpaceSize, data_ctx.gfx90aFp16alt}; gemm_conv_invoker(handle, invoke_params); @@ -661,7 +665,7 @@ ConvSolution ConvMPBidirectWinograd -static conv::DataInvokeParams GetTransformedInvokeContext(const ProblemDescription& problem, - const AnyInvokeParams& invoke_ctx) +static miopen::conv::DataInvokeParams GetTransformedInvokeContext(const ProblemDescription& problem, + const AnyInvokeParams& invoke_ctx) { #if MIOPEN_BACKEND_HIP const miopenDataType_t transform_data_type = @@ -812,7 +816,7 @@ static conv::DataInvokeParams GetTransformedInvokeContext(const ProblemDescripti const WinoOffsets transform_offset(wino_in.buff_info.total_byte_size, wino_out.buff_info.total_byte_size); - const auto& data_ctx = invoke_ctx.CastTo(); + const auto& data_ctx = invoke_ctx.CastTo(); auto workSpace = data_ctx.workSpace; @@ -833,7 +837,7 @@ static conv::DataInvokeParams GetTransformedInvokeContext(const ProblemDescripti const auto zeroDesc = TensorDescriptor(); ConvDataTensors xdlops_tensor = ConvDataTensors( ConvFwdTensors{zeroDesc, wino_in_ptr, zeroDesc, wino_w_ptr, zeroDesc, wino_out_ptr}); - return conv::DataInvokeParams{ + return miopen::conv::DataInvokeParams{ xdlops_tensor, gemm_workSpace, gemm_workSpaceSize, data_ctx.gfx90aFp16alt}; #else std::ignore = problem; @@ -936,5 +940,6 @@ template struct ConvMPBidirectWinograd_xdlops<4, 3>; template struct ConvMPBidirectWinograd_xdlops<5, 3>; template struct ConvMPBidirectWinograd_xdlops<6, 3>; +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp index be72cf1fe3..e162059246 100644 --- a/src/solver/conv_asm_1x1u.cpp +++ b/src/solver/conv_asm_1x1u.cpp @@ -48,17 +48,20 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static inline bool UseSubsample(const ProblemDescription& problem) { return (problem.GetKernelStrideW() > 1 || problem.GetKernelStrideH() > 1) && - problem.direction.IsForward(); + problem.IsDirectionForward(); } static inline bool UseUpsample(const ProblemDescription& problem) { return (problem.GetKernelStrideW() > 1 || problem.GetKernelStrideH() > 1) && - problem.direction.IsBackwardData(); + problem.IsDirectionBackwardData(); } /// After 2x subsampling kernel, image size on asm kernel input becomes 4x (2*2) smaller. @@ -302,7 +305,7 @@ bool PerformanceConfigConvAsm1x1U::IsValidImpl(const ProblemDescription& problem { if((k_mult % elements_in_dword) != 0) return false; - if(problem.direction.IsBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) + if(problem.IsDirectionBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) return false; } if(sequence_length > 2) @@ -403,12 +406,12 @@ static std::vector TransformFeatures(const ProblemDescription& problem, s // values nominal param can take). std::vector features(n * n, 0.0f); features[0] = problem.IsFp32() ? 2.0 : 1.0; - int offset = (problem.direction.IsForward() ? 0 : 1) + 1; + int offset = (problem.IsDirectionForward() ? 0 : 1) + 1; features[(offset)*n + offset] = 1.0; features[3 * n + 3] = - float(problem.direction.IsForward() ? problem.GetInChannels_() : problem.GetOutChannels_()); + float(problem.IsDirectionForward() ? problem.GetInChannels_() : problem.GetOutChannels_()); features[4 * n + 4] = - float(problem.direction.IsForward() ? problem.GetOutChannels_() : problem.GetInChannels_()); + float(problem.IsDirectionForward() ? problem.GetOutChannels_() : problem.GetInChannels_()); features[5 * n + 5] = float(problem.GetInHeight_()); features[6 * n + 6] = float(problem.GetInWidth_()); features[7 * n + 7] = float(problem.GetBatchSize_()); @@ -529,7 +532,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(problem.HasNonPackedTensors()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; @@ -577,7 +580,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip && problem.GetGroupCount() == 1 && img_hw >= elements_in_dword && (elements_in_dword == 1 || problem.GetOutChannels_() >= 4)); - if(problem.direction.IsBackwardData() && elements_in_dword != 1) + if(problem.IsDirectionBackwardData() && elements_in_dword != 1) ok = ok && (problem.GetOutChannels_() % 4 == 0); if(!ok) { @@ -674,7 +677,7 @@ ConvSolution ConvAsm1x1U::GetSolution(const ExecutionContext& ctx, std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetOutStrideH_()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + std::to_string(in_batch_stride) + std::string(" -DMLO_IN0_BATCH_STRIDE=") + - std::to_string(problem.direction.IsForward() ? problem.GetInBatchStride_() + std::to_string(problem.IsDirectionForward() ? problem.GetInBatchStride_() : problem.GetOutBatchStride_()) + std::string(" -DMLO_IN0_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride_()) + std::string(" -DMLO_IN0_STRIDE=") + std::to_string(problem.GetInStrideH_()) + @@ -717,7 +720,7 @@ ConvSolution ConvAsm1x1U::GetSolution(const ExecutionContext& ctx, GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); - GenerateClangDefsym(options, "weights_layout", problem.direction.IsForward() ? 0 : 1); + GenerateClangDefsym(options, "weights_layout", problem.IsDirectionForward() ? 0 : 1); GenerateClangDefsym(options, "vec_c_in", 1); GenerateClangDefsym(options, "vec_k_out", 1); @@ -792,7 +795,7 @@ ConvSolution ConvAsm1x1U::GetSolution(const ExecutionContext& ctx, 1, data_len); // cppcheck-suppress unreadVariable - buff_info fbuf(problem.direction.IsForward() ? MemLayout::NCHW : MemLayout::CNHW, + buff_info fbuf(problem.IsDirectionForward() ? MemLayout::NCHW : MemLayout::CNHW, problem.GetOutChannels_(), problem.GetInChannels_(), 1, @@ -891,21 +894,22 @@ ConvSolution ConvAsm1x1U::GetSolution(const ExecutionContext& ctx, { int N, C, H, W, K, n_groups, out_H, out_W; GetCompiledInParameters(ctx, problem, &N, &C, &H, &W, &K, &n_groups, &out_H, &out_W); - result.invoker_factory = conv::MakeGcnAsm1x1USSInvokerFactory( + result.invoker_factory = miopen::conv::MakeGcnAsm1x1USSInvokerFactory( N, C, K, n_groups, out_H, out_W, result.workspace_sz); } else if(UseUpsample(problem)) { int N, C, H, W, K, n_groups; GetCompiledInParameters(ctx, problem, &N, &C, &H, &W, &K, &n_groups); - result.invoker_factory = - conv::MakeGcnAsm1x1UUSInvokerFactory(N, C, K, n_groups, H, W, result.workspace_sz); + result.invoker_factory = miopen::conv::MakeGcnAsm1x1UUSInvokerFactory( + N, C, K, n_groups, H, W, result.workspace_sz); } else { int N, C, H, W, K, n_groups; GetCompiledInParameters(ctx, problem, &N, &C, &H, &W, &K, &n_groups); - result.invoker_factory = conv::MakeGcnAsm1x1UInvokerFactory(N, C, H, W, K, n_groups); + result.invoker_factory = + miopen::conv::MakeGcnAsm1x1UInvokerFactory(N, C, H, W, K, n_groups); } return result; @@ -918,5 +922,6 @@ PerformanceConfigConvAsm1x1U ConvAsm1x1U::Search(const ExecutionContext& ctx, return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp index 343a9e5887..f4cd160f18 100644 --- a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp +++ b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp @@ -50,13 +50,12 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_GCN_ASM_KERNELS) namespace miopen { namespace solver { - namespace fusion { void PerformanceConfigConvBiasActivAsm1x1U::HeuristicInit(const FusionContext& ctx, const FusionDescription& problem) { - auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); auto conv_ctx = ctx.GetConvContext(conv_problem); PerformanceConfigConvAsm1x1U::HeuristicInit(conv_ctx, conv_problem); } @@ -64,13 +63,13 @@ void PerformanceConfigConvBiasActivAsm1x1U::HeuristicInit(const FusionContext& c bool PerformanceConfigConvBiasActivAsm1x1U::SetNextValue(const FusionDescription& problem) { return PerformanceConfigConvAsm1x1U::SetNextValue( - problem.GetConvProblem(0, conv::Direction::Forward)); + problem.GetConvProblem(0, miopen::conv::Direction::Forward)); } bool PerformanceConfigConvBiasActivAsm1x1U::IsValid(const FusionDescription& problem) const { return PerformanceConfigConvAsm1x1U::IsValid( - problem.GetConvProblem(0, conv::Direction::Forward)); + problem.GetConvProblem(0, miopen::conv::Direction::Forward)); } PerformanceConfigConvBiasActivAsm1x1U @@ -104,9 +103,9 @@ ConvBiasActivAsm1x1U::GetSolution(const FusionContext& context, const FusionDescription& problem, const PerformanceConfigConvBiasActivAsm1x1U& config) const { - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); - ConvAsm1x1U base_sol{}; + conv::ConvAsm1x1U base_sol{}; auto sol = base_sol.GetSolution(conv_ctx, conv_problem, config); @@ -239,8 +238,8 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, return false; } - ConvAsm1x1U sol{}; - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + conv::ConvAsm1x1U sol{}; + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); if(conv_problem.GetPadH() != conv_problem.GetPadW()) diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp index 3a102d4701..f442a8f410 100644 --- a/src/solver/conv_asm_1x1u_stride2.cpp +++ b/src/solver/conv_asm_1x1u_stride2.cpp @@ -43,6 +43,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; /// \todo Rework, factor out to separate header and use in other solvers. /// \todo Clarify functions semantics. @@ -116,7 +119,7 @@ struct config_helper { config_helper(const ProblemDescription& problem, const PerformanceConfigConvAsm1x1UV2& config) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { stride_w = problem.GetKernelStrideW(); stride_h = problem.GetKernelStrideH(); @@ -386,7 +389,7 @@ bool PerformanceConfigConvAsm1x1UV2::IsValid(const ProblemDescription& problem) const auto c_per_wave = (problem.GetInChannels_() + waves_c_in_group - 1) / waves_c_in_group; const auto c_per_last_wave = problem.GetInChannels_() - (c_per_wave * (waves_c_in_group - 1)); - if(problem.direction.IsBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) + if(problem.IsDirectionBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) return false; { @@ -418,8 +421,8 @@ bool PerformanceConfigConvAsm1x1UV2::IsValid(const ProblemDescription& problem) void PerformanceConfigConvAsm1x1UV2::HeuristicInit(const ProblemDescription& problem) { - int c_check = problem.direction.IsForward() ? problem.GetInChannels_() : 0; - int k_check = problem.direction.IsForward() ? 0 : problem.GetInChannels_(); + int c_check = problem.IsDirectionForward() ? problem.GetInChannels_() : 0; + int k_check = problem.IsDirectionForward() ? 0 : problem.GetInChannels_(); chunk_size = 16; dwords_per_ld = 1; c_mult = (c_check % 2 == 0) ? 2 : ((c_check % 3 == 0) ? 3 : 1); @@ -487,7 +490,7 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(problem.HasNonPackedTensors()) return false; @@ -651,7 +654,7 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ExecutionContext& ctx, GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); - GenerateClangDefsym(options, "weights_layout", problem.direction.IsForward() ? 0 : 1); + GenerateClangDefsym(options, "weights_layout", problem.IsDirectionForward() ? 0 : 1); GenerateClangDefsym(options, "vec_c_in", 1); GenerateClangDefsym(options, "vec_k_out", 1); @@ -677,7 +680,7 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ExecutionContext& ctx, 1, data_len); // cppcheck-suppress unreadVariable - buff_info fbuf(problem.direction.IsForward() ? MemLayout::NCHW : MemLayout::CNHW, + buff_info fbuf(problem.IsDirectionForward() ? MemLayout::NCHW : MemLayout::CNHW, problem.GetOutChannels_(), problem.GetInChannels_(), 1, @@ -746,7 +749,8 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ExecutionContext& ctx, { int N, C, H, W, K, n_groups; GetCompiledInParameters(ctx, problem, &N, &C, &H, &W, &K, &n_groups); - result.invoker_factory = conv::MakeGcnAsm1x1UInvokerFactory(N, C, H, W, K, n_groups); + result.invoker_factory = + miopen::conv::MakeGcnAsm1x1UInvokerFactory(N, C, H, W, K, n_groups); } return result; @@ -759,5 +763,6 @@ PerformanceConfigConvAsm1x1UV2 ConvAsm1x1UV2::Search(const ExecutionContext& ctx return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp index c486b4d867..284841e465 100644 --- a/src/solver/conv_asm_3x3u.cpp +++ b/src/solver/conv_asm_3x3u.cpp @@ -45,6 +45,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; namespace { // clang-format off @@ -180,7 +183,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(!ctx.rmv.IsV2orV3()) return false; @@ -285,8 +288,8 @@ ConvSolution ConvAsm3x3U::GetSolution(const ExecutionContext& ctx, {"img_height", problem.GetInHeight_()}, {"input_channels", problem.GetInChannels_()}, {"output_channels", problem.GetOutChannels_()}, - {"weights_layout", problem.direction.IsForward() ? 0 : 1}, - {"reverse_weights", problem.direction.IsForward() ? 0 : 1}, + {"weights_layout", problem.IsDirectionForward() ? 0 : 1}, + {"reverse_weights", problem.IsDirectionForward() ? 0 : 1}, {"ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 5 : 4}, {"limit_wave_cnt", pcfg->limit_wave_cnt}, {"filters_per_wave", pcfg->filters_per_wave}, @@ -315,7 +318,7 @@ ConvSolution ConvAsm3x3U::GetSolution(const ExecutionContext& ctx, construction_params.kernel_name = "miopenGcnAsmConv3x3U"; result.construction_params.push_back(construction_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } @@ -327,5 +330,6 @@ PerformanceConfigConvAsm3x3U ConvAsm3x3U::Search(const ExecutionContext& ctx, return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp index 98b604ef3e..a46f7bd749 100644 --- a/src/solver/conv_asm_5x10u2v2b1.cpp +++ b/src/solver/conv_asm_5x10u2v2b1.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -66,7 +69,7 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, #endif if(!device_is_gfx8_9_no_xnack) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.IsLayoutDefault()) return false; @@ -130,8 +133,10 @@ ConvSolution ConvAsm5x10u2v2b1::GetSolution(const ExecutionContext& ctx, constr_params.kernel_name = "miopenConv5x10u2v2b1"; result.construction_params.push_back(constr_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp index 8d3c1a1716..ebc77c2490 100644 --- a/src/solver/conv_asm_5x10u2v2f1.cpp +++ b/src/solver/conv_asm_5x10u2v2f1.cpp @@ -36,6 +36,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -67,7 +70,7 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, #endif if(!device_is_gfx8_9_no_xnack) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.IsLayoutDefault()) return false; @@ -147,9 +150,11 @@ ConvSolution ConvAsm5x10u2v2f1::GetSolution(const ExecutionContext& ctx, construction_params.kernel_name = "miopenConv5x10u2v2f1"; result.construction_params.push_back(construction_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp index 309b178dfb..5864e1a92e 100644 --- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp +++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp @@ -36,6 +36,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -69,7 +72,7 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" || name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908")) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.IsLayoutDefault()) return false; @@ -124,8 +127,10 @@ ConvSolution ConvAsm7x7c3h224w224k64u2v2p3q3f1::GetSolution(const ExecutionConte constr_params.kernel_name = "miopenGcnAsmConv7x7c3h224w224k64u2v2p3q3f1"; result.construction_params.push_back(constr_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp index 04e0f89b4e..79c3046b83 100644 --- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp +++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp @@ -43,6 +43,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static inline bool UseSubsample(const ProblemDescription& problem) { @@ -310,7 +313,6 @@ bool PerformanceConfigConvAsmBwdWrW1x1::IsValidValue() const bool PerformanceConfigConvAsmBwdWrW1x1::IsValid(const ExecutionContext& ctx, const ProblemDescription& problem) const { - if(!IsValidValue()) return false; if(!((chunk_size * c_per_gpr) >= 16 && ((chunk_size == 1 || c_per_gpr * chunk_size <= 16)))) @@ -477,7 +479,7 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.HasNonPackedTensors()) return false; @@ -800,14 +802,15 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ExecutionContext& ctx, { result.invoker_factory = [N, C, H, W, K, n_groups](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto ss_kernel = handle.Run(kernels[0]); - const auto main_kernel = handle.Run(kernels[1]); - const auto& invoke_params = primitive_params.CastTo(); - const auto& x = invoke_params.tensors.x; - const auto& dy = invoke_params.tensors.dy; - const auto& dw = invoke_params.tensors.dw; - const auto& workSpace = invoke_params.workSpace; - auto elapsed = 0.f; + const auto ss_kernel = handle.Run(kernels[0]); + const auto main_kernel = handle.Run(kernels[1]); + const auto& invoke_params = + primitive_params.CastTo(); + const auto& x = invoke_params.tensors.x; + const auto& dy = invoke_params.tensors.dy; + const auto& dw = invoke_params.tensors.dw; + const auto& workSpace = invoke_params.workSpace; + auto elapsed = 0.f; if(invoke_params.type != InvokeType::AutoTune) { @@ -833,13 +836,14 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ExecutionContext& ctx, { result.invoker_factory = [N, C, H, W, K, n_groups](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto main_kernel = handle.Run(kernels[0]); - const auto& invoke_params = primitive_params.CastTo(); - const auto& x = invoke_params.tensors.x; - const auto& dy = invoke_params.tensors.dy; - const auto& dw = invoke_params.tensors.dw; - int unused = 0; - int* return_addr = nullptr; + const auto main_kernel = handle.Run(kernels[0]); + const auto& invoke_params = + primitive_params.CastTo(); + const auto& x = invoke_params.tensors.x; + const auto& dy = invoke_params.tensors.dy; + const auto& dw = invoke_params.tensors.dw; + int unused = 0; + int* return_addr = nullptr; main_kernel(N, C, H, W, K, n_groups, unused, unused, x, dw, dy, return_addr); }; }; @@ -855,5 +859,6 @@ PerformanceConfigConvAsmBwdWrW1x1 ConvAsmBwdWrW1x1::Search(const ExecutionContex return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp index 6ae7330c63..f6972648cb 100644 --- a/src/solver/conv_asm_dir_BwdWrW3x3.cpp +++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp @@ -47,6 +47,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; inline static bool Inc_1_2_4_8(int& v) { @@ -141,7 +144,7 @@ static bool IsReverseInOutAllowed(const ProblemDescription& problem) return problem.GetKernelStrideW() == 1 && problem.GetKernelStrideH() == 1; } -inline int elements_in_dword(const ProblemDescription& problem) { return problem.IsFp16() ? 2 : 1; } +static int elements_in_dword(const ProblemDescription& problem) { return problem.IsFp16() ? 2 : 1; } bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -360,7 +363,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.HasNonPackedTensors()) return false; @@ -548,7 +551,7 @@ ConvSolution ConvAsmBwdWrW3x3::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [N, C, H, W, K, n_groups](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { const auto k = handle.Run(kernels[0]); - const auto& invoke_params = primitive_params.CastTo(); + const auto& invoke_params = primitive_params.CastTo(); int unused = 0; int* return_addr = nullptr; const auto& x = invoke_params.tensors.x; @@ -568,5 +571,6 @@ PerformanceConfigAsmDirect3x3WrW ConvAsmBwdWrW3x3::Search(const ExecutionContext return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp index 6ff60242dc..207df443a6 100644 --- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static inline bool FindImplicitGemmDynamicKernelBwd(const ProblemDescription& problem, std::string& kernel_name, @@ -140,7 +143,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(problem.HasNonPackedTensors()) @@ -208,10 +211,12 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicBwd::GetSolution(const ExecutionConte kernel.comp_options = options.str(); - result.invoker_factory = conv::MakeImplGemmDynamicBackwardDataInvokerFactory(problem, int(0)); + result.invoker_factory = + miopen::conv::MakeImplGemmDynamicBackwardDataInvokerFactory(problem, int(0)); result.construction_params.push_back(kernel); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp index d0551138de..3679f632e6 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static inline const std::vector& GetImplicitGemmGtcDynamicBwdTunablesList(const ProblemDescription& problem) @@ -986,7 +989,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is2d()) @@ -1053,10 +1056,12 @@ ConvAsmImplicitGemmGTCDynamicBwdXdlops::GetSolution(const ExecutionContext& ctx, MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name); - result.invoker_factory = conv::MakeImplGemmDynamicBackwardDataInvokerFactory(problem, cfg); + result.invoker_factory = + miopen::conv::MakeImplGemmDynamicBackwardDataInvokerFactory(problem, cfg); result.construction_params.push_back(kernel); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index 42f1e9f03e..426131b99b 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static const inline std::vector& GetBwdXdlopsNHWCConfigList() @@ -388,11 +391,13 @@ GetBwdXdlopsNHWCConfigLargestTileFp32() { return {"bwd", "nhwc", miopenFloat, 0, 1, 256, 64, 16, 32, 32, 2, 1, 1, 2, 2, 1, 0, 0, 0, 0, { 1, 4, 4, 1}, { 1, 4, 1, 64}, { 1, 4, 1, 1}, { 1, 4, 1, 64}}; } + static inline PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC GetBwdXdlopsNHWCConfigLargestTileFp16() { return {"bwd", "nhwc", miopenHalf, 0, 1, 256, 256, 32, 32, 32, 8, 2, 2, 2, 2, 1, 0, 0, 0, 0, { 1, 8, 4, 1}, { 1, 4, 1, 64}, { 1, 8, 1, 4}, { 1, 4, 1, 64}}; } + static inline PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC GetBwdXdlopsNHWCConfigLargestTileBf16() { @@ -744,6 +749,7 @@ void PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::HeuristicInit( find_with_gemm_k_pad(); } } + bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::IsValidValue() const { if(IsDefaultConstructed()) @@ -753,6 +759,7 @@ bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::IsValidValue() const return true; return miopen::any_of(config_list, [&](auto v) { return (*this == v); }); } + bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::SetNextValue(const ProblemDescription&) { if(use_spare_set) @@ -787,6 +794,7 @@ bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::SetNextValue(const Proble return false; } } + bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::IsValid( const ProblemDescription& problem) const { @@ -894,6 +902,7 @@ ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetDefaultPerformanceConfig( MIOPEN_LOG_I(pp.ToString()); return pp; } + bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsValidPerformanceConfig( const ExecutionContext&, const ProblemDescription& problem, @@ -932,7 +941,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is2d()) @@ -1156,9 +1165,10 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( MIOPEN_LOG_I2(SolverDbId() << ": " << config.ToString() << msg.str()); result.invoker_factory = - conv::MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(ctx, problem, config); + miopen::conv::MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(ctx, problem, config); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp index 5bc9da9b81..57834af505 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp @@ -37,6 +37,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static const inline std::vector& GetImplicitGemmGtcDynamicFwdXdlopsTunablesList() @@ -1322,9 +1325,8 @@ GetImplicitGemmGtcDynamicFwdXdlopsTunablesList() } // This is a helper function for selecting better performing config -bool mayHaveBiggerN1bClusterSize(int gemm_m, - int gemm_n, - const TunableImplicitGemmGTCDynamic_t& tunable) +static bool +mayHaveBiggerN1bClusterSize(int gemm_m, int gemm_n, const TunableImplicitGemmGTCDynamic_t& tunable) { float n_times_m = static_cast(gemm_n) / static_cast(gemm_m); @@ -1512,7 +1514,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -1590,11 +1592,12 @@ ConvAsmImplicitGemmGTCDynamicFwdXdlops::GetSolution(const ExecutionContext& ctx, MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name); result.invoker_factory = - conv::MakeImplGemmDynamicForwardInvokerFactory(problem, - cfg); + miopen::conv::MakeImplGemmDynamicForwardInvokerFactory( + problem, cfg); result.construction_params.push_back(kernel); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index d7b692cfb3..8e789ddc0c 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -39,6 +39,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static const inline std::vector& GetFwdDlopsNCHWCConfigList() @@ -455,6 +458,7 @@ bool PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::SetNextValue(const Proble return false; } } + bool PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::IsValidValue() const { if(IsDefaultConstructed()) @@ -464,6 +468,7 @@ bool PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::IsValidValue() const return true; return miopen::any_of(config_list, [&](auto v) { return (*this == v); }); } + bool PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::IsValid( const ProblemDescription& problem) const { @@ -533,6 +538,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsValidPerformanceConfig( { return config.IsValidValue() && config.IsValid(problem); } + PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::Search(const ExecutionContext& ctx, const ProblemDescription& problem, @@ -554,7 +560,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable( if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -634,9 +640,10 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::GetSolution( MIOPEN_LOG_I2(SolverDbId() << ": " << config.ToString() << msg.str()); result.invoker_factory = - conv::MakeImplGemmDynamicForwardDlopsNCHWCInvokerFactory(problem, config); + miopen::conv::MakeImplGemmDynamicForwardDlopsNCHWCInvokerFactory(problem, config); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 2e38366b74..219e2b9d78 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static const inline std::vector& GetFwdXdlopsNHWCConfigList() @@ -318,11 +321,13 @@ GetFwdXdlopsNHWCConfigLargestTileFp32() { return {"fwd", "nhwc", miopenFloat, 0, 1, 256, 64, 16, 32, 32, 2, 1, 1, 2, 2, 0, 0, 0, 0, 0, { 1, 4, 4, 1}, { 1, 4, 1, 64}, { 1, 4, 1, 1}, { 1, 4, 1, 64}}; } + static inline PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC GetFwdXdlopsNHWCConfigLargestTileFp16() { return {"fwd", "nhwc", miopenHalf, 0, 1, 256, 128, 32, 32, 32, 8, 2, 1, 2, 2, 0, 0, 0, 0, 0, { 1, 8, 4, 1}, { 1, 4, 1, 64}, { 1, 8, 2, 1}, { 1, 4, 1, 64}}; } + static inline PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC GetFwdXdlopsNHWCConfigLargestTileBf16() { @@ -870,7 +875,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -918,6 +923,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( return true; } + ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( const ExecutionContext& ctx, const ProblemDescription& problem, @@ -1041,9 +1047,10 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( MIOPEN_LOG_I2(SolverDbId() << ": " << config.ToString() << msg.str()); result.invoker_factory = - conv::MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(ctx, problem, config); + miopen::conv::MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(ctx, problem, config); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_perf_config.cpp b/src/solver/conv_asm_implicit_gemm_gtc_perf_config.cpp index 8560c65052..88c322c9c2 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_perf_config.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_perf_config.cpp @@ -30,6 +30,7 @@ namespace miopen { namespace solver { +namespace conv { PerformanceConfigAsmImplicitGemmGTC::PerformanceConfigAsmImplicitGemmGTC( std::string dir, @@ -208,6 +209,7 @@ bool PerformanceConfigAsmImplicitGemmGTC::operator==( && std::equal(std::begin(tensor_b_cluster_lengths), std::end(tensor_b_cluster_lengths), std::begin(other.tensor_b_cluster_lengths)); // clang-format on } + void PerformanceConfigAsmImplicitGemmGTC::CopyParameters( const PerformanceConfigAsmImplicitGemmGTC& other) { @@ -289,6 +291,7 @@ std::string PerformanceConfigAsmImplicitGemmGTC::ToKernelName(const ExecutionCon return kernel_name.str(); } + int PerformanceConfigAsmImplicitGemmGTC::BlockSize() const { return std::accumulate(std::begin(tensor_a_cluster_lengths), @@ -445,6 +448,7 @@ bool PerformanceConfigAsmImplicitGemmGTCvector::operator==( && std::equal(std::begin(tensor_b_cluster_lengths), std::end(tensor_b_cluster_lengths), std::begin(other.tensor_b_cluster_lengths)); // clang-format on } + void PerformanceConfigAsmImplicitGemmGTCvector::CopyParameters( const PerformanceConfigAsmImplicitGemmGTCvector& other) { @@ -509,6 +513,7 @@ PerformanceConfigAsmImplicitGemmGTCvector::ToKernelName(const ExecutionContext& return kernel_name.str(); } + int PerformanceConfigAsmImplicitGemmGTCvector::BlockSize() const { return std::accumulate(std::begin(tensor_a_cluster_lengths), @@ -517,5 +522,6 @@ int PerformanceConfigAsmImplicitGemmGTCvector::BlockSize() const std::multiplies()); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index 2959b114fb..2cfaeeb2ee 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -43,6 +43,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static inline std::size_t GetTypeSize(const std::string& s) { @@ -316,11 +319,13 @@ GetWrwXdlopsNHWCConfigLargestTileFp32() { return {"wrw", "nhwc", miopenFloat, 0, 0, 256, 128, 16, 32, 32, 2, 2, 1, 2, 2, 0, 0, 0, 0, 0, { 1, 1, 1,16}, { 1, 16, 1, 16}, { 1, 1, 1, 8}, { 1, 16, 1, 16}}; } + static inline PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC GetWrwXdlopsNHWCConfigLargestTileFp16() { return {"wrw", "nhwc", miopenHalf, 0, 1, 256, 256, 32, 32, 32, 8, 2, 2, 2, 2, 0, 0, 0, 0, 0, { 1, 4, 1, 8}, { 1, 8, 1, 32}, { 1, 4, 1, 8}, { 1, 8, 1, 32}}; } + static inline PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC GetWrwXdlopsNHWCConfigLargestTileBf16() { @@ -831,6 +836,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsValidPerformanceConfig( { return config.IsValidValue() && config.IsValid(problem); } + PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::Search(const ExecutionContext& ctx, const ProblemDescription& problem, @@ -868,7 +874,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d()) @@ -918,7 +924,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( } static std::vector -ComputeDynamicIGemmWrwKernelArgsNHWC(const conv::ProblemDescription& problem, +ComputeDynamicIGemmWrwKernelArgsNHWC(const ProblemDescription& problem, const int gemm_k_global_splits, const int gemm_k_per_wg, const int splits_4G) @@ -1233,7 +1239,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = - primitive_parameters.CastTo(); + primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; const auto ker = handle.Run( kernels[(isGfx90aFp16altSupport && wrw_invoke_params.gfx90aFp16alt) ? 1 : 0]); @@ -1335,7 +1341,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = - primitive_parameters.CastTo(); + primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; const auto ker = handle.Run( kernels[(isGfx90aFp16altSupport && wrw_invoke_params.gfx90aFp16alt) ? 1 : 0]); @@ -1426,5 +1432,6 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index baceb8089f..45f611fabf 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; struct TunableImplicitGemmV4R1Dynamic { @@ -286,7 +289,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -332,7 +335,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ExecutionContext& if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -402,12 +405,12 @@ static inline ConvSolution GetSolutionBase(const ExecutionContext& ctx, MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name); if(kernel_is_1x1) - result.invoker_factory = conv::MakeImplGemmDynamicForward1x1InvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDynamicForward1x1InvokerFactory(problem); else { int packed_value = 0; result.invoker_factory = - conv::MakeImplGemmDynamicForwardInvokerFactory(problem, packed_value); + miopen::conv::MakeImplGemmDynamicForwardInvokerFactory(problem, packed_value); } result.construction_params.push_back(kernel); return result; @@ -446,5 +449,6 @@ ConvAsmImplicitGemmV4R1DynamicFwd_1x1::GetSolution(const ExecutionContext& ctx, return GetSolutionBase(ctx, problem, *it, AsmImplicitGemmV4R1_1x1); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index aa84c8c76e..19e8b29282 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -38,6 +38,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static const inline std::vector& GetImplicitGemmWrwGTCDynamicXdlopsKernelList() @@ -444,7 +447,7 @@ static inline int if_gemm_k_global_split(const ProblemDescription& problem, } inline std::vector -ComputeDynamicIGemmWrwKernelArgs(const conv::ProblemDescription& problem, +ComputeDynamicIGemmWrwKernelArgs(const ProblemDescription& problem, const int log2_gemm_k_global_splits, const int nxb, const int gemm_k_per_block) @@ -827,7 +830,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(!ctx.use_asm_kernels) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d()) @@ -934,7 +937,7 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = - primitive_parameters.CastTo(); + primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; const auto k = handle.Run(kernels[0]); float elapsed = 0; @@ -967,7 +970,7 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = - primitive_parameters.CastTo(); + primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; const auto k = handle.Run(kernels[0]); const auto& workSpace = wrw_invoke_params.workSpace; @@ -1017,7 +1020,7 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = - primitive_parameters.CastTo(); + primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; const auto k = handle.Run(kernels[0]); @@ -1033,5 +1036,6 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp index 7adb28fdae..bef7f4841c 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp @@ -36,13 +36,16 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; // 3 possible configs: //{ 16, 128, 16, 2, 4, 4, 4, 4, 4, 4, 16, 1, 16, 1, 4, 64}, //{ 16, 128, 16, 2, 4, 4, 4, 4, 4, 4, 16, 1, 16, 1, 16, 16}, //{ 8, 32, 4, 2, 2, 2, 2, 4, 4, 2, 4, 2, 8, 1, 4, 16} -static inline int GetImplicitGemmWrwV4R1DynamicGemmkGroups(const conv::ProblemDescription& problem, +static inline int GetImplicitGemmWrwV4R1DynamicGemmkGroups(const ProblemDescription& problem, const int& GemmKPerBlock) { int n = problem.GetInBatchSize_(); @@ -69,7 +72,7 @@ static inline int GetImplicitGemmWrwV4R1DynamicGemmkGroups(const conv::ProblemDe } static inline float CallImplicitGemmWrwDynamic(const miopen::Handle& handle, - const conv::ProblemDescription& problem, + const ProblemDescription& problem, ConstData_t src, ConstData_t dst, Data_t wei, @@ -309,7 +312,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx if(!ctx.use_hip_kernels) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d()) @@ -412,7 +415,7 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicWrw::GetSolution(const ExecutionConte result.invoker_factory = [problem](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); + decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; MIOPEN_LOG_I("wrw workspace size: " << data_ctx.workSpaceSize); const auto& workSpace = data_ctx.workSpace; @@ -435,5 +438,6 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicWrw::GetSolution(const ExecutionConte return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp index 0fc9263413..6d23f77920 100644 --- a/src/solver/conv_bin_wino3x3U.cpp +++ b/src/solver/conv_bin_wino3x3U.cpp @@ -39,6 +39,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_3X3) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -47,7 +50,7 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(!(ctx.rmv.IsV2orV3() && ctx.use_asm_kernels)) return false; @@ -139,7 +142,7 @@ ConvSolution ConvBinWinograd3x3U::GetSolution(const ExecutionContext& ctx, result.construction_params.push_back(kernel); - const auto is_forward = problem.direction.IsForward(); + const auto is_forward = problem.IsDirectionForward(); result.invoker_factory = [=](const std::vector& kernels) { constexpr int F_REVERSE_R = 1 << 0; @@ -168,7 +171,7 @@ ConvSolution ConvBinWinograd3x3U::GetSolution(const ExecutionContext& ctx, return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { const auto k = handle.Run(kernels[0]); - const auto& fwd_ctx = primitive_params.CastTo(); + const auto& fwd_ctx = primitive_params.CastTo(); const auto& tensors = fwd_ctx.tensors; k(N, @@ -188,5 +191,7 @@ ConvSolution ConvBinWinograd3x3U::GetSolution(const ExecutionContext& ctx, return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp index f1fd0868c3..9285108b30 100644 --- a/src/solver/conv_bin_winoRxS.cpp +++ b/src/solver/conv_bin_winoRxS.cpp @@ -72,7 +72,7 @@ static inline int FloorDiv(const int x, const int y) } static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, - const miopen::ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const int R, const int S, const int R_stride, @@ -140,7 +140,7 @@ static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, { return false; } - const bool is_dilated_stride_2 = (problem.direction.IsBackwardData() && S_stride != 1); + const bool is_dilated_stride_2 = (problem.IsDirectionBackwardData() && S_stride != 1); if(fp16) { if(is_dilated_stride_2) @@ -177,7 +177,7 @@ static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, return false; } // Padding for bwd data shall not be negative. - if(problem.direction.IsBackwardData() || problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardData() || problem.IsDirectionBackwardWrW()) { if(!(0 <= problem.GetBackwardPadW() && problem.GetBackwardPadW() < std::pow(2, 16))) return false; @@ -212,6 +212,9 @@ static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -226,7 +229,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, return false; if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS{})) return false; - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW{})) return false; @@ -257,7 +260,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, } else { - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { if(!(name == "gfx900" || name == "gfx906" || name == "gfx908")) return false; @@ -280,7 +283,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, return false; // clang-format on - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { return IsShaderContraintsMet(ctx, problem, @@ -350,7 +353,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, if(problem.GetKernelStrideW() == 2) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) kernel.kernel_file += "2_dec"; else kernel.kernel_file += "2_dil"; @@ -360,7 +363,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, kernel.kernel_file += "1"; } } - else if(problem.direction.IsBackwardWrW()) + else if(problem.IsDirectionBackwardWrW()) { kernel.kernel_name = "miopenSp3AsmConvRxSf3x2"; kernel.kernel_file = "Conv_Winograd_v16_5_0_stride1"; @@ -371,7 +374,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, kernel.kernel_file = "conv_3x3_wheel_alpha_v9_0_15"; if(problem.GetKernelStrideW() == 2) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) kernel.kernel_file += "_stride_2_dec"; else kernel.kernel_file += "_stride_2_dil"; @@ -381,7 +384,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, result.construction_params.push_back(kernel); - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { int unused = 0; int N, C, H, W, K, out_H, out_W, R, S, n_groups_; @@ -404,8 +407,9 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - decltype(auto) invoke_params = primitive_params.CastTo(); - const auto& tensors = invoke_params.tensors; + decltype(auto) invoke_params = + primitive_params.CastTo(); + const auto& tensors = invoke_params.tensors; // clang-format off MIOPEN_LOG_I2(" N=" << N << " C=" << C << " H=" << H << " W=" << W << " K=" << K << " n_groups=" << n_groups_ << " flags=" << flags << " R=" << R << " S=" << S @@ -444,7 +448,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, } else { - const auto is_forward = problem.direction.IsForward(); + const auto is_forward = problem.IsDirectionForward(); constexpr int F_REVERSE_R = 1 << 0; constexpr int F_REVERSE_S = 1 << 1; constexpr int F_FLIP_K_C = 1 << 2; @@ -473,7 +477,7 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, << " out_H=" << out_H << " out_W=" << out_W); decltype(auto) k = handle.Run(kernels[0]); - decltype(auto) fwd_ctx = primitive_params.CastTo(); + decltype(auto) fwd_ctx = primitive_params.CastTo(); const auto& tensors = fwd_ctx.tensors; k(N, @@ -501,5 +505,6 @@ ConvSolution ConvBinWinogradRxS::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_bin_winoRxS_fused.cpp b/src/solver/conv_bin_winoRxS_fused.cpp index f11edc368e..ac2272fcf8 100644 --- a/src/solver/conv_bin_winoRxS_fused.cpp +++ b/src/solver/conv_bin_winoRxS_fused.cpp @@ -145,7 +145,7 @@ bool ConvBinWinogradRxSFused::IsApplicable(const FusionContext& context, ConvSolution ConvBinWinogradRxSFused::GetSolution(const FusionContext& context, const FusionDescription& problem) const { - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); ConvSolution result; KernelInfo kernel; diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index 75f5cd1ac3..1457c7d309 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -63,10 +63,13 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( namespace miopen { namespace solver { namespace fusion { + #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +namespace { + struct CKArgs { - CKArgs(const ProblemDescription& problem) + CKArgs(const miopen::conv::ProblemDescription& problem) { N = ProblemInterpreter::GetBatchN(problem); K = ProblemInterpreter::GetOutputChannelK(problem); @@ -98,8 +101,11 @@ struct CKArgs std::vector rPadding; }; +} // namespace + template -void PerformanceConfigConvCKIgemmFwdBiasActivFused::Init(const ProblemDescription& problem) +void PerformanceConfigConvCKIgemmFwdBiasActivFused::Init( + const miopen::conv::ProblemDescription& problem) { const auto& args = CKArgs{problem}; std::vector conv_ptrs; @@ -138,7 +144,7 @@ void PerformanceConfigConvCKIgemmFwdBiasActivFused::Init(const ProblemDescriptio template bool PerformanceConfigConvCKIgemmFwdBiasActivFused::CheckIsSupportCKArgs( - const ProblemDescription& problem) const + const miopen::conv::ProblemDescription& problem) const { const auto& args = CKArgs{problem}; std::vector conv_ptrs; @@ -178,7 +184,8 @@ bool PerformanceConfigConvCKIgemmFwdBiasActivFused::CheckIsSupportCKArgs( } template -bool ConvCKIgemmFwdBiasActivFused::CheckCKApplicability(const ProblemDescription& problem) const +bool ConvCKIgemmFwdBiasActivFused::CheckCKApplicability( + const miopen::conv::ProblemDescription& problem) const { std::vector conv_ptrs; ck::tensor_operation::device::instance:: @@ -210,10 +217,12 @@ bool ConvCKIgemmFwdBiasActivFused::CheckCKApplicability(const ProblemDescription return false; } +namespace { + template void RunCKSolution(const Handle& handle, const AnyInvokeParams& primitive_parameters, - const ProblemDescription& problem, + const miopen::conv::ProblemDescription& problem, const PerformanceConfigConvCKIgemmFwdBiasActivFused& config) { const auto& args = CKArgs{problem}; @@ -270,6 +279,8 @@ void RunCKSolution(const Handle& handle, handle.AccumKernelTime(elapsed_time); } } + +} // namespace #endif void PerformanceConfigConvCKIgemmFwdBiasActivFused::HeuristicInit( @@ -278,7 +289,7 @@ void PerformanceConfigConvCKIgemmFwdBiasActivFused::HeuristicInit( #if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL std::ignore = fdesc_problem; #else - const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = fdesc_problem.GetConvProblem(0, miopen::conv::Direction::Forward); switch(conv_problem.GetInDataType()) { case miopenHalf: Init(conv_problem); break; @@ -332,7 +343,7 @@ bool PerformanceConfigConvCKIgemmFwdBiasActivFused::IsValid( return false; #else // Extract convolution problem from the fusion context. - const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = fdesc_problem.GetConvProblem(0, miopen::conv::Direction::Forward); switch(conv_problem.GetInDataType()) { case miopenHalf: return CheckIsSupportCKArgs(conv_problem); @@ -407,7 +418,7 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, const auto& activ_op = dynamic_cast(*desc.op_map[2]); if(activ_op.activMode != miopenActivationRELU) return false; - const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = fdesc_problem.GetConvProblem(0, miopen::conv::Direction::Forward); if(conv_problem.IsTensorsCasted()) return false; @@ -452,7 +463,7 @@ ConvSolution ConvCKIgemmFwdBiasActivFused::GetSolution( std::ignore = config; return {}; #else - const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = fdesc_problem.GetConvProblem(0, miopen::conv::Direction::Forward); ConvSolution result; result.invoker_factory = [=](const std::vector& kernels) { std::ignore = kernels; diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index 9c42d8b8db..c4c3398fd9 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -43,7 +43,7 @@ namespace solver { namespace ck_utility { static inline auto get_ck_tunable_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw( - const PerformanceConvCkIgemmFwdV6r1DlopsNchw& config) + const conv::PerformanceConvCkIgemmFwdV6r1DlopsNchw& config) { return ck::driver::ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetTunableList()[config .ck_tunable_list_id]; @@ -51,6 +51,10 @@ static inline auto get_ck_tunable_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw( } // namespace ck_utility +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; + bool PerformanceConvCkIgemmFwdV6r1DlopsNchw::SetNextValue(const ProblemDescription&) { if(ck_tunable_list_id < @@ -99,7 +103,7 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) return false; @@ -208,7 +212,7 @@ ConvCkIgemmFwdV6r1DlopsNchw::GetSolution(const ExecutionContext& ctx, sol.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& data_ctx = primitive_params.CastTo(); + const auto& data_ctx = primitive_params.CastTo(); const auto& tensors = data_ctx.tensors; auto kernel0 = handle.Run(kernels[0]); auto kernel1 = handle.Run(kernels[1]); @@ -268,5 +272,6 @@ ConvCkIgemmFwdV6r1DlopsNchw::Search(const ExecutionContext& ctx, return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_direct_naive_conv.cpp b/src/solver/conv_direct_naive_conv.cpp index f87511f911..443ae6a667 100644 --- a/src/solver/conv_direct_naive_conv.cpp +++ b/src/solver/conv_direct_naive_conv.cpp @@ -27,7 +27,7 @@ #include "miopen/env.hpp" #include #include -#include +#include #include #include #include @@ -43,6 +43,9 @@ bool AlwaysEnableConvDirectNaive = false; } // namespace debug namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvIsAssemblyKernel(const ExecutionContext& ctx, const ProblemDescription& problem) @@ -61,12 +64,14 @@ bool IsInputFp32(const ProblemDescription& problem) problem.GetWeightsDataType() == miopenFloat) || (problem.GetInDataType() == miopenFloat && problem.GetOutDataType() == miopenFloat); } + bool IsInputFp16(const ProblemDescription& problem) { return (problem.GetInDataType() == miopenHalf && problem.GetWeightsDataType() == miopenHalf) || (problem.GetOutDataType() == miopenHalf && problem.GetWeightsDataType() == miopenHalf) || (problem.GetInDataType() == miopenHalf && problem.GetOutDataType() == miopenHalf); } + bool IsInputBfp16(const ProblemDescription& problem) { return (problem.GetInDataType() == miopenBFloat16 && @@ -76,30 +81,38 @@ bool IsInputBfp16(const ProblemDescription& problem) (problem.GetInDataType() == miopenBFloat16 && problem.GetOutDataType() == miopenBFloat16); } + bool IsInputInt8(const ProblemDescription& problem) { return (problem.GetInDataType() == miopenInt8 && problem.GetWeightsDataType() == miopenInt8) || (problem.GetOutDataType() == miopenInt8 && problem.GetWeightsDataType() == miopenInt8) || (problem.GetInDataType() == miopenInt8 && problem.GetOutDataType() == miopenInt8); } + bool IsAccFp64(const ProblemDescription& problem) { return IsInputFp32(problem) || IsInputFp16(problem) || IsInputBfp16(problem); } + bool IsAccInt32(const ProblemDescription& problem) { return IsInputInt8(problem); } + bool IsOutputFp32(const ProblemDescription& problem) { return problem.IsFp32() || (problem.GetInDataType() == miopenInt8 && problem.GetWeightsDataType() == miopenInt8 && problem.GetOutDataType() == miopenFloat); } + bool IsOutputFp16(const ProblemDescription& problem) { return problem.IsFp16(); } + bool IsOutputBfp16(const ProblemDescription& problem) { return problem.IsBfp16(); } + bool IsOutputInt8(const ProblemDescription& problem) { return problem.GetInDataType() == miopenInt8 && problem.GetWeightsDataType() == miopenInt8 && problem.GetOutDataType() == miopenInt8; } + bool IsOutputInt32(const ProblemDescription& problem) { return problem.GetInDataType() == miopenInt8 && problem.GetWeightsDataType() == miopenInt8 && @@ -124,11 +137,11 @@ std::string ConvDirectNaiveConvKernelName(const ProblemDescription& problem) kernel_name << "naive_conv_nonpacked_"; } - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) kernel_name << "fwd_"; - else if(problem.direction.IsBackwardData()) + else if(problem.IsDirectionBackwardData()) kernel_name << "bwd_"; - else if(problem.direction.IsBackwardWrW()) + else if(problem.IsDirectionBackwardWrW()) kernel_name << "wrw_"; else MIOPEN_THROW("unsupported convolution direction"); @@ -303,5 +316,6 @@ void conv_internal::DebugPrintTensorStrides(const TensorDescriptor& inDesc, printOneStrideVec("outDesc = ", outDesc.GetStrides()); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index 77406744b7..ad96b8badb 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -33,6 +33,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -44,7 +47,7 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC()) return false; @@ -144,9 +147,10 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( group, tensors.inDesc.GetStrides(), G_stride_idx)); // For weights, we split K to (G, K_per_group), which is always index 0 @@ -225,9 +229,10 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( group, tensors.inDesc.GetStrides(), G_stride_idx)); @@ -286,5 +291,6 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp index f1ed2f5b10..9fee363c24 100644 --- a/src/solver/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv_direct_naive_conv_fwd.cpp @@ -32,6 +32,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -50,7 +53,7 @@ bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx, problem.IsFp8() || problem.IsBfp8())) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(problem.IsTensorsCasted()) @@ -144,9 +147,10 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( group, tensors.inDesc.GetStrides(), G_stride_idx)); @@ -226,9 +230,10 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( group, tensors.inDesc.GetStrides(), G_stride_idx)); @@ -281,5 +286,6 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_direct_naive_conv_wrw.cpp b/src/solver/conv_direct_naive_conv_wrw.cpp index b83b334faa..eaf7ed7d68 100644 --- a/src/solver/conv_direct_naive_conv_wrw.cpp +++ b/src/solver/conv_direct_naive_conv_wrw.cpp @@ -33,6 +33,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvWrw::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -51,7 +54,7 @@ bool ConvDirectNaiveConvWrw::IsApplicable(const ExecutionContext& ctx, problem.IsBfp8())) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.IsTensorsCasted()) { @@ -132,9 +135,10 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( group, tensors.xDesc.GetStrides(), G_stride_idx)); @@ -214,9 +218,10 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - decltype(auto) data_ctx = primitive_parameters.CastTo(); - const auto& tensors = data_ctx.tensors; - float elapsed = 0; + decltype(auto) data_ctx = + primitive_parameters.CastTo(); + const auto& tensors = data_ctx.tensors; + float elapsed = 0; auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( group, tensors.xDesc.GetStrides(), G_stride_idx)); @@ -270,5 +275,6 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index 82f3411cb8..57d8709c06 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -312,7 +315,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(problem.IsTensorsCasted()) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is3d()) return false; @@ -344,13 +347,14 @@ ConvSolution ConvHipImplicitGemm3DGroupBwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenInt8: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( - problem, config.kernel_id); + return MakeInvokerFactory, + CKArgs, + miopen::conv::DataInvokeParams>(problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt32: case miopenBFloat16: @@ -365,5 +369,6 @@ ConvSolution ConvHipImplicitGemm3DGroupBwdXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index b2a09e26d5..54632adc9c 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -179,6 +182,7 @@ struct CKArgs std::array lPadding; std::array rPadding; }; + } // namespace template @@ -310,7 +314,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(problem.HasMixedDataTypes()) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is3d()) return false; @@ -342,13 +346,14 @@ ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenInt8: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( - problem, config.kernel_id); + return MakeInvokerFactory, + CKArgs, + miopen::conv::DataInvokeParams>(problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt32: case miopenBFloat16: @@ -363,5 +368,6 @@ ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index d395e576f0..dc2f0f6218 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -306,7 +309,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(problem.HasMixedDataTypes()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is3d()) return false; @@ -338,13 +341,14 @@ ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenInt8: - return MakeInvokerFactory, CKArgs, conv::WrWInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>( problem, config.kernel_id); case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::WrWInvokeParams>( - problem, config.kernel_id); + return MakeInvokerFactory, + CKArgs, + miopen::conv::WrWInvokeParams>(problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::WrWInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::WrWInvokeParams>( problem, config.kernel_id); case miopenInt32: case miopenBFloat16: @@ -359,5 +363,6 @@ ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 85bec04104..3cda7a1cc1 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -266,7 +269,7 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable( return false; if(problem.IsTensorsCasted()) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is2d()) return false; @@ -307,10 +310,11 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( - problem, config.kernel_id); + return MakeInvokerFactory, + CKArgs, + miopen::conv::DataInvokeParams>(problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt8: case miopenInt32: @@ -326,5 +330,6 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp index 5b3cb4933f..241aba33c9 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp @@ -38,6 +38,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmBwdDataV1R1::PerformanceImplicitGemmBwdDataV1R1(int BlockSize_, int GemmMPerBlock_, @@ -644,7 +647,7 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is2d() && !(problem.Is3d() && problem.IsFp32())) return false; @@ -869,10 +872,11 @@ ConvHipImplicitGemmBwdDataV1R1::GetSolution(const ExecutionContext& ctx, std::to_string(GemmBBlockCopyDstDataPerWrite_GemmKPACK); } - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp index 6db26950d5..b8f04e2b38 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmBwdV1R1Xdlops::PerformanceImplicitGemmBwdV1R1Xdlops() : PerformanceImplicitGemmBwdV1R1Xdlops::PerformanceImplicitGemmBwdV1R1Xdlops( @@ -778,7 +781,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!problem.Is2d()) @@ -927,10 +930,11 @@ ConvSolution ConvHipImplicitGemmBwdDataV1R1Xdlops::GetSolution( get_static_ck_common_compiler_flag(ctx) + ctx.general_compile_options; - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp index a0cca73af9..a58b0df45e 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp @@ -36,6 +36,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmBwdDataV4R1::PerformanceImplicitGemmBwdDataV4R1(int BlockSize_, int GemmMPerBlock_, @@ -742,7 +745,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!ctx.use_hip_kernels) @@ -972,9 +975,10 @@ ConvHipImplicitGemmBwdDataV4R1::GetSolution(const ExecutionContext& ctx, } } - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index 2c4f36e820..485bdbbc5b 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -41,6 +41,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; std::tuple PerformanceImplicitGemmBwdDataV4R1Xdlops::CalculateGridSize(const ProblemDescription& problem) const @@ -832,7 +835,7 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext& return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(!ctx.use_hip_kernels) return false; @@ -1064,9 +1067,10 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( } } - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp index 88d0f7b314..0a4babf380 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp @@ -39,6 +39,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -51,7 +54,7 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetConv().attribute.deterministic) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!ctx.use_hip_kernels) return false; @@ -97,7 +100,7 @@ bool ConvHipImplicitGemmV4R1WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!ctx.use_hip_kernels) return false; @@ -389,7 +392,7 @@ ConvHipImplicitGemmV4R1Fwd::GetSolution(const ExecutionContext& ctx, } result.construction_params.push_back(construction_parameters); - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); return result; } @@ -598,7 +601,7 @@ ConvHipImplicitGemmV4R1WrW::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& invoke_params = primitive_params.CastTo(); + const auto& invoke_params = primitive_params.CastTo(); const auto& tensors = invoke_params.tensors; handle.Run(kernels[0])(tensors.x, tensors.dy, tensors.dw); }; @@ -607,5 +610,6 @@ ConvHipImplicitGemmV4R1WrW::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index 9babd8fff8..2f51ebf819 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmV4R4Fwd::PerformanceImplicitGemmV4R4Fwd(int BlockSize_, int GemmMPerBlock_, @@ -589,7 +592,7 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d() && !problem.Is3d()) return false; @@ -770,10 +773,11 @@ ConvHipImplicitGemmV4R4Fwd::GetSolution(const ExecutionContext& ctx, // clang-format on - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 8a7b4b150b..c9e7e9bcb0 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -43,6 +43,9 @@ MIOPEN_DECLARE_ENV_VAR( namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmForwardV4R4Xdlops::PerformanceImplicitGemmForwardV4R4Xdlops() : PerformanceImplicitGemmForwardV4R4Xdlops::PerformanceImplicitGemmForwardV4R4Xdlops( @@ -961,7 +964,7 @@ ConvSolution ConvHipImplicitGemmForwardV4R4Xdlops::GetSolution( ctx.general_compile_options; // clang-format on - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } @@ -996,7 +999,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& if(problem.IsTensorsCasted()) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -1041,5 +1044,6 @@ ConvHipImplicitGemmForwardV4R4Xdlops::Search(const ExecutionContext& ctx, return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp index ed868b3d04..fb94e30441 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp @@ -53,6 +53,9 @@ MIOPEN_DECLARE_ENV_VAR( namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm:: PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm() @@ -1027,7 +1030,7 @@ ConvSolution ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::GetSolution( ctx.general_compile_options; // clang-format on - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } @@ -1059,7 +1062,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -1133,5 +1136,6 @@ ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::Search(const ExecutionContext& return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp index 6006cb0caa..838e39fdae 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp @@ -43,6 +43,9 @@ MIOPEN_DECLARE_ENV_VAR( namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; static std::tuple CalculateGemmSize(const ProblemDescription& problem) { @@ -991,7 +994,7 @@ ConvSolution ConvHipImplicitGemmForwardV4R5Xdlops::GetSolution( ctx.general_compile_options; // clang-format on - result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeImplGemmDataInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; } @@ -1033,7 +1036,7 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& if(y == 1 && x == 1) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) @@ -1077,5 +1080,6 @@ ConvHipImplicitGemmForwardV4R5Xdlops::Search(const ExecutionContext& ctx, return GenericSearch(*this, ctx, problem, invoke_ctx); } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 1b1127d6e1..9dfad4478d 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -40,6 +40,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -265,7 +268,7 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( return false; if(problem.HasMixedDataTypes()) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) return false; @@ -308,13 +311,13 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenInt8: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt32: case miopenBFloat16: @@ -329,5 +332,6 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp index 8e7898ea1a..783dabce43 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp @@ -39,6 +39,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL template @@ -160,6 +163,7 @@ struct CKArgs std::array lPadding; std::array rPadding; }; + } // namespace template @@ -294,7 +298,7 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable( return false; if(problem.HasMixedDataTypes()) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(!problem.Is2d()) return false; @@ -327,13 +331,14 @@ ConvSolution ConvHipImplicitGemmGroupFwdXdlops::GetSolution( switch(problem.GetInDataType()) { case miopenHalf: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( - problem, config.kernel_id); + return MakeInvokerFactory, + CKArgs, + miopen::conv::DataInvokeParams>(problem, config.kernel_id); case miopenFloat: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt8: - return MakeInvokerFactory, CKArgs, conv::DataInvokeParams>( + return MakeInvokerFactory, CKArgs, miopen::conv::DataInvokeParams>( problem, config.kernel_id); case miopenInt32: case miopenBFloat16: @@ -348,5 +353,6 @@ ConvSolution ConvHipImplicitGemmGroupFwdXdlops::GetSolution( return {}; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_nonxdlops_common.cpp b/src/solver/conv_hip_implicit_gemm_nonxdlops_common.cpp index 472b58f913..66ef3b8ea6 100644 --- a/src/solver/conv_hip_implicit_gemm_nonxdlops_common.cpp +++ b/src/solver/conv_hip_implicit_gemm_nonxdlops_common.cpp @@ -33,6 +33,9 @@ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool PerformanceImplicitGemm::operator==(const PerformanceImplicitGemm& other) const { @@ -89,7 +92,7 @@ bool PerformanceImplicitGemm::IsValid(const ExecutionContext& ctx, N2 % InBlockCopyClusterLengths_N2 == 0)) return false; - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { if(!((X * Y) % (EPerBlock / WeiBlockCopyClusterLengths_E) == 0)) return false; @@ -551,5 +554,6 @@ PerformanceImplicitGemm::PerformanceImplicitGemm(int BPerBlock_, { } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index c9b0d94241..8bdfb5a544 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -34,6 +34,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmV4R4WrW::PerformanceImplicitGemmV4R4WrW(int BlockSize_, int GemmMPerBlock_, @@ -592,7 +595,7 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d() && !problem.Is3d()) return false; @@ -780,7 +783,7 @@ ConvHipImplicitGemmV4R4WrW::GetSolution(const ExecutionContext& ctx, result.invoker_factory = [](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& invoke_params = primitive_params.CastTo(); + const auto& invoke_params = primitive_params.CastTo(); const auto& tensors = invoke_params.tensors; handle.Run(kernels[0])(tensors.x, tensors.dy, tensors.dw); }; @@ -789,5 +792,6 @@ ConvHipImplicitGemmV4R4WrW::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp index 0f38df9e6c..f110bc8695 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp @@ -39,6 +39,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmWrwV4R4Xdlops::PerformanceImplicitGemmWrwV4R4Xdlops() : PerformanceImplicitGemmWrwV4R4Xdlops::PerformanceImplicitGemmWrwV4R4Xdlops( @@ -985,7 +988,7 @@ ConvSolution ConvHipImplicitGemmWrwV4R4Xdlops::GetSolution( const auto& lowp_quant = conv.lowp_quant; result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& invoke_params = primitive_params.CastTo(); + const auto& invoke_params = primitive_params.CastTo(); const auto& tensors = invoke_params.tensors; auto kernel = handle.Run(kernels[0]); float elapsed = 0; @@ -1063,7 +1066,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d()) @@ -1125,5 +1128,7 @@ ConvHipImplicitGemmWrwV4R4Xdlops::GetWorkspaceSize(const ExecutionContext&, return miopen::GetTypeSize(miopenFloat) * k * c * y * x; } } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index b910d9658a..19c320390d 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -42,6 +42,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_PADDED_GEMM_ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm::PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm() : PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm:: @@ -1051,7 +1054,7 @@ ConvSolution ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::GetSolution( const auto& lowp_quant = conv.lowp_quant; result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& invoke_params = primitive_params.CastTo(); + const auto& invoke_params = primitive_params.CastTo(); const auto& tensors = invoke_params.tensors; auto kernel = handle.Run(kernels[0]); float elapsed = 0; @@ -1126,7 +1129,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.Is2d()) @@ -1212,5 +1215,7 @@ std::size_t ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::GetWorkspaceSize( return miopen::GetTypeSize(miopenFloat) * k * c * y * x; } } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp index 6804cccf42..012f99b304 100644 --- a/src/solver/conv_mlir_igemm_bwd.cpp +++ b/src/solver/conv_mlir_igemm_bwd.cpp @@ -36,6 +36,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -45,7 +48,7 @@ bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetConv().attribute.deterministic) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(problem.HasNonPackedTensors()) return false; @@ -123,7 +126,7 @@ ConvSolution ConvMlirIgemmBwd::GetSolution(const ExecutionContext& ctx, result.construction_params.push_back(construction_parameters); } - result.invoker_factory = conv::MakeMlirBwdInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeMlirBwdInvokerFactory(problem); return result; #else std::ignore = ctx; @@ -133,5 +136,6 @@ ConvSolution ConvMlirIgemmBwd::GetSolution(const ExecutionContext& ctx, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp index 439b1ab394..71c1cb9020 100644 --- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp @@ -37,6 +37,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -48,7 +51,7 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!IsXdlopsSupport(ctx)) return false; - if(!problem.direction.IsBackwardData()) + if(!problem.IsDirectionBackwardData()) return false; if(problem.HasNonPackedTensors()) return false; @@ -119,7 +122,7 @@ ConvSolution ConvMlirIgemmBwdXdlops::GetSolution(const ExecutionContext& ctx, result.construction_params.push_back(construction_parameters); } - result.invoker_factory = conv::MakeMlirBwdInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeMlirBwdInvokerFactory(problem); return result; #else std::ignore = ctx; @@ -129,5 +132,6 @@ ConvSolution ConvMlirIgemmBwdXdlops::GetSolution(const ExecutionContext& ctx, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp index 034313e1d0..a0e1accc7e 100644 --- a/src/solver/conv_mlir_igemm_fwd.cpp +++ b/src/solver/conv_mlir_igemm_fwd.cpp @@ -36,6 +36,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_FWD) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; void PerformanceConvMlirIgemm::SetMlirHeuristicInitRequest() { @@ -165,7 +168,7 @@ bool ConvMlirIgemmFwd::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetConv().attribute.deterministic) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(problem.HasNonPackedTensors()) return false; @@ -215,7 +218,7 @@ ConvSolution ConvMlirIgemmFwd::GetSolution(const ExecutionContext& ctx, construction_parameters.g_wk.push_back(1); construction_parameters.g_wk.push_back(1); - result.invoker_factory = conv::MakeMlirFwdInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeMlirFwdInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; #else @@ -226,5 +229,6 @@ ConvSolution ConvMlirIgemmFwd::GetSolution(const ExecutionContext& ctx, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index fa17520f07..5c26b30c26 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -37,6 +37,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_FWD_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; void PerformanceConvMlirIgemmXdlops::SetMlirHeuristicInitRequest() { @@ -62,7 +65,7 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!IsXdlopsSupport(ctx)) return false; - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) return false; if(problem.HasNonPackedTensors()) return false; @@ -238,7 +241,7 @@ ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ExecutionContext& ctx, construction_parameters.g_wk.push_back(1); construction_parameters.g_wk.push_back(1); - result.invoker_factory = conv::MakeMlirFwdInvokerFactory(problem); + result.invoker_factory = miopen::conv::MakeMlirFwdInvokerFactory(problem); result.construction_params.push_back(construction_parameters); return result; #else @@ -249,5 +252,6 @@ ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ExecutionContext& ctx, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp index 7b8286d896..3a56d7eb7f 100644 --- a/src/solver/conv_mlir_igemm_wrw.cpp +++ b/src/solver/conv_mlir_igemm_wrw.cpp @@ -37,6 +37,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -46,7 +49,7 @@ bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetConv().attribute.deterministic) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; @@ -118,7 +121,7 @@ ConvSolution ConvMlirIgemmWrW::GetSolution(const ExecutionContext& ctx, construction_parameters.g_wk.push_back(1); construction_parameters.g_wk.push_back(1); - result.invoker_factory = conv::MakeMlirWrWInvokerFactory(problem, 0); + result.invoker_factory = miopen::conv::MakeMlirWrWInvokerFactory(problem, 0); result.construction_params.push_back(construction_parameters); return result; #else @@ -129,5 +132,6 @@ ConvSolution ConvMlirIgemmWrW::GetSolution(const ExecutionContext& ctx, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp index 5dd480da63..32fd4a0d00 100644 --- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp @@ -38,6 +38,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -49,7 +52,7 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!IsXdlopsSupport(ctx)) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.HasNonPackedTensors()) return false; @@ -121,7 +124,7 @@ ConvSolution ConvMlirIgemmWrWXdlops::GetSolution(const ExecutionContext& ctx, } size_t workspace_req = GetWorkspaceSize(ctx, problem); - result.invoker_factory = conv::MakeMlirWrWInvokerFactory(problem, workspace_req); + result.invoker_factory = miopen::conv::MakeMlirWrWInvokerFactory(problem, workspace_req); result.workspace_sz = workspace_req; return result; #else @@ -145,5 +148,6 @@ std::size_t ConvMlirIgemmWrWXdlops::GetWorkspaceSize(const ExecutionContext& ctx #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp index 4f7a1626cc..3bf93a2c48 100644 --- a/src/solver/conv_multipass_wino3x3WrW.cpp +++ b/src/solver/conv_multipass_wino3x3WrW.cpp @@ -44,6 +44,10 @@ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; + MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4) @@ -58,14 +62,13 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_WORKSPACE_MAX) // Introduces a number of shader-specific aliases (names) in the current scope at zero cost. // These names represent shader parameters, e.g. shader C is batch_size etc and useful for // programming. -#define DEFINE_GETXFORMHWSIZE(problem) \ - const auto \ - wino_xform_h = \ - solver::ConvWinograd3x3MultipassWrW:: \ - GetSolverWinoXformHWSize(problem, 0), \ - wino_xform_w = \ - solver::ConvWinograd3x3MultipassWrW:: \ - GetSolverWinoXformHWSize(problem, 1); +#define DEFINE_GETXFORMHWSIZE(problem) \ + const auto wino_xform_h = \ + ConvWinograd3x3MultipassWrW:: \ + GetSolverWinoXformHWSize(problem, 0), \ + wino_xform_w = \ + ConvWinograd3x3MultipassWrW:: \ + GetSolverWinoXformHWSize(problem, 1); #define DEFINE_SHADER_ALIASES(problem) \ const int C = (problem).GetBatchSize_(); \ @@ -434,7 +437,7 @@ bool ConvWinograd3x3MultipassWrW return false; if(!problem.Is2d()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.HasNonPackedTensors()) return false; @@ -623,7 +626,7 @@ ConvWinograd3x3MultipassWrW::Pre return [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - decltype(auto) invoke_params = primitive_params.CastTo(); + decltype(auto) invoke_params = primitive_params.CastTo(); const auto& tensors = invoke_params.tensors; float total_time = 0; @@ -779,5 +782,6 @@ template struct ConvWinograd3x3MultipassWrW<7, 3, 1, 1>; template struct ConvWinograd3x3MultipassWrW<5, 3>; template struct ConvWinograd3x3MultipassWrW<5, 4>; +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp index 482892946e..d6706c5403 100644 --- a/src/solver/conv_ocl_dir2D11x11.cpp +++ b/src/solver/conv_ocl_dir2D11x11.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -58,7 +61,7 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx, if(!problem.IsLayoutDefault()) return false; - return problem.direction.IsForward() && problem.GetGroupCount() == 1 && + return problem.IsDirectionForward() && problem.GetGroupCount() == 1 && problem.GetDilationH() == 1 && problem.GetDilationW() == 1 && problem.GetWeightsHeight_() == 11 && problem.GetWeightsWidth_() == 11 && problem.GetKernelStrideH() == 4 && problem.GetKernelStrideW() == 4; @@ -68,7 +71,7 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ExecutionContext& ctx, const ProblemDescription& problem) const { ConvSolution result; - const bool is_forward = problem.direction.IsForward(); + const bool is_forward = problem.IsDirectionForward(); // size_t localMemSize = 64 * 1024; auto hw_wave_sz = 64; // auto dev_local_mem_sz = localMemSize; // in bytes @@ -342,8 +345,9 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ExecutionContext& ctx, MIOPEN_THROW("Two kernels were expected by solver"); return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - const auto& invoke_params = primitive_parameters.CastTo(); - const auto& tensors = invoke_params.tensors; + const auto& invoke_params = + primitive_parameters.CastTo(); + const auto& tensors = invoke_params.tensors; const auto first_pass_kernel = handle.Run(kernels[0]); const auto second_pass_kernel = handle.Run(kernels[1]); @@ -373,9 +377,11 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ExecutionContext& ctx, } else { - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; } return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp index 2c13d29f11..4056779ad7 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp @@ -38,6 +38,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -56,7 +59,7 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.HasNonPackedTensors()) return false; @@ -456,9 +459,10 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ExecutionContext& ctx, { result.invoker_factory = [ws_sz](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto ss_kernel = handle.Run(kernels[0]); - const auto main_kernel = handle.Run(kernels[1]); - const auto& invoke_params = primitive_params.CastTo(); + const auto ss_kernel = handle.Run(kernels[0]); + const auto main_kernel = handle.Run(kernels[1]); + const auto& invoke_params = + primitive_params.CastTo(); if(invoke_params.workSpaceSize < ws_sz) MIOPEN_THROW("Not enough workspace for ConvOclBwdWrW1x1"); @@ -487,10 +491,11 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ExecutionContext& ctx, { result.invoker_factory = [](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto k = handle.Run(kernels[0]); - const auto& invoke_params = primitive_params.CastTo(); - const auto& tensors = invoke_params.tensors; - const auto padding_val = 0.f; + const auto k = handle.Run(kernels[0]); + const auto& invoke_params = + primitive_params.CastTo(); + const auto& tensors = invoke_params.tensors; + const auto padding_val = 0.f; visit_float(tensors.dyDesc.GetType(), [&](auto as_float) { k(tensors.dy, tensors.x, tensors.dw, as_float(padding_val)); @@ -501,5 +506,7 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ExecutionContext& ctx, } return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp index c59449eab8..0dbd8a1930 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp @@ -41,6 +41,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; inline static bool Is_1_to_8(const int& v) { @@ -457,7 +460,7 @@ bool ConvOclBwdWrW2::IsApplicableBase(const ExecutionContext& ctx return false; if(!problem.Is2d()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; @@ -623,8 +626,8 @@ ConvSolution ConvOclBwdWrW2::GetSolution( std::string UT_READ_TYPE = (utility_read_unit == 1) ? "_FLOAT" : "_FLOAT" + std::to_string((utility_read_unit)); - if(!problem.direction.IsBackwardWrW()) - MIOPEN_THROW("!problem.direction.IsBackwardWrW()"); + if(!problem.IsDirectionBackwardWrW()) + MIOPEN_THROW("!problem.IsDirectionBackwardWrW()"); // it's backward - inputs are outputs and vs versa const auto comp_options = std::string(" -DMLO_DIR_FORWARD=0") + std::string(" -DMLO_GRP_SZ=") + @@ -737,7 +740,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( const auto ws_sz = GetWorkspaceSize(ctx, problem); result.workspace_sz = ws_sz; - result.invoker_factory = conv::MakeOclWrWRdcInvokerFactory(n_batch_blks > 1, ws_sz); + result.invoker_factory = miopen::conv::MakeOclWrWRdcInvokerFactory(n_batch_blks > 1, ws_sz); return result; } @@ -766,5 +769,6 @@ template struct ConvOclBwdWrW2<4>; template struct ConvOclBwdWrW2<8>; template struct ConvOclBwdWrW2<16>; +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp index 7f50e7f3ee..2549cf11ad 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp @@ -34,6 +34,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; // Once the compiler fix (SWDEV-168168) is available, that version of compiler needs to be // checked to skip workarounds. Till then, true is returned in all cases so as to skip @@ -60,7 +63,7 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, if(problem.IsTensorsCasted()) return false; - if(!problem.direction.IsBackwardWrW()) + if(!problem.IsDirectionBackwardWrW()) return false; if(!problem.IsLayoutDefault()) return false; @@ -501,8 +504,8 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ExecutionContext& ctx, int n_input_channels_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); int n_output_channels_per_group = problem.GetInChannels_() / problem.GetGroupCount(); - if(!problem.direction.IsBackwardWrW()) - MIOPEN_THROW("!problem.direction.IsBackwardWrW()"); + if(!problem.IsDirectionBackwardWrW()) + MIOPEN_THROW("!problem.IsDirectionBackwardWrW()"); // it's backward - inputs are outputs and vs versa auto comp_options = std::string(" -DMLO_DIR_FORWARD=0") + std::string(" -DMLO_GRP_SZ=") + @@ -647,9 +650,11 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ExecutionContext& ctx, const auto ws_sz = GetWorkspaceSize(ctx, problem); result.workspace_sz = ws_sz; - result.invoker_factory = conv::MakeOclWrWRdcInvokerFactory(n_batch_blks > 1, ws_sz); + result.invoker_factory = miopen::conv::MakeOclWrWRdcInvokerFactory(n_batch_blks > 1, ws_sz); return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp index 6b7d2f1f8e..c07c410ba0 100644 --- a/src/solver/conv_ocl_dir2Dfwd.cpp +++ b/src/solver/conv_ocl_dir2Dfwd.cpp @@ -34,6 +34,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -46,7 +49,7 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(problem.HasNonPackedTensors()) return false; @@ -61,7 +64,7 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, // clang-format off // Cases when dy has negative padding are not supported (issue 918) - if(problem.direction.IsBackwardData() + if(problem.IsDirectionBackwardData() && (problem.GetBackwardPadW() < 0 || problem.GetBackwardPadH() < 0)) return false; @@ -105,7 +108,7 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, /// While MIOpenConvUni is up to 4x faster than MIOpenCDFGen (even not auto-tuned), /// it seems that is has 4x..20x worse precision, and some "test_conv --half" tests fail. /// See issue #1626. - && !(problem.direction.IsForward() + && !(problem.IsDirectionForward() && problem.IsFp16() && problem.GetKernelStrideW() == 2) && IsValidPerformanceConfig(ctx, problem, GetDefaultPerformanceConfig(ctx, problem)); @@ -127,7 +130,7 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ExecutionContext&, // auto pad_w = problem.GetPadW(); // auto pad_h = problem.GetPadH(); // auto hw_wave_sz = 64; - // if(!problem.direction.IsForward()) + // if(!problem.IsDirectionForward()) // { // // backward // pad_w = problem.GetBackwardPadW(); @@ -140,10 +143,10 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ExecutionContext&, config.n_out_pix_tiles); // hacky fix of the incorrect kernel local memory address calculation for data - result.out_pix_tile1 = (!problem.direction.IsForward() && problem.GetKernelStrideH() > 1) + result.out_pix_tile1 = (!problem.IsDirectionForward() && problem.GetKernelStrideH() > 1) ? problem.GetKernelStrideH() : config.out_pix_tile1; - result.out_pix_tile0 = (!problem.direction.IsForward() && problem.GetKernelStrideW() > 1) + result.out_pix_tile0 = (!problem.IsDirectionForward() && problem.GetKernelStrideW() > 1) ? problem.GetKernelStrideW() : config.out_pix_tile0; @@ -245,7 +248,7 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ExecutionContext&, // e.g. mlo_in_lcl_width here represents MLO_IN_LCL_WIDTH macro in the opencl source. long long mlo_in_lcl_width; long long mlo_in_lcl_height; - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { mlo_in_lcl_width = ((mlo_in_tile0 - 1) * mlo_filter_stride0 + mlo_filter_size0); mlo_in_lcl_height = ((mlo_in_tile1 - 1) * mlo_filter_stride1 + mlo_filter_size1); @@ -290,7 +293,7 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ExecutionContext& ctx, auto hw_wave_sz = 64; // auto dev_local_mem_sz = localMemSize; // in bytes - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) { // backward pad_w = problem.GetBackwardPadW(); @@ -303,10 +306,10 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ExecutionContext& ctx, config.n_out_pix_tiles); // hacky fix of the incorrect kernel local memory address calculation for data - result.out_pix_tile1 = (!problem.direction.IsForward() && problem.GetKernelStrideH() > 1) + result.out_pix_tile1 = (!problem.IsDirectionForward() && problem.GetKernelStrideH() > 1) ? problem.GetKernelStrideH() : config.out_pix_tile1; - result.out_pix_tile0 = (!problem.direction.IsForward() && problem.GetKernelStrideW() > 1) + result.out_pix_tile0 = (!problem.IsDirectionForward() && problem.GetKernelStrideW() > 1) ? problem.GetKernelStrideW() : config.out_pix_tile0; @@ -368,7 +371,7 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ExecutionContext& ctx, kernel_params.comp_options = std::string(" -DMLO_HW_WAVE_SZ=") + std::to_string(static_cast(hw_wave_sz)) + - std::string(" -DMLO_DIR_FORWARD=") + (problem.direction.IsForward() ? "1" : "0") + + std::string(" -DMLO_DIR_FORWARD=") + (problem.IsDirectionForward() ? "1" : "0") + std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(static_cast(problem.GetWeightsWidth_())) + std::string(" -DMLO_FILTER_SIZE1=") + @@ -476,10 +479,10 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ExecutionContext& ctx, kernel_params.kernel_name = "MIOpenConvUni"; result.construction_params.push_back(kernel_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; - if(problem.direction.IsForward()) - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + if(problem.IsDirectionForward()) + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } @@ -500,5 +503,6 @@ ConvSolution ConvOclDirectFwd::GetSolution(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp index b3e0bd0a2b..dde0a23467 100644 --- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp +++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp @@ -37,6 +37,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -55,7 +58,7 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; - if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) + if(!(problem.IsDirectionForward() || problem.IsDirectionBackwardData())) return false; if(problem.HasNonPackedTensors()) return false; @@ -86,7 +89,7 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ExecutionContext& ctx, { // int version = result.out_pix_tile1; - if((problem.direction.IsForward() && problem.GetInChannels_() % 16 == 0 && + if((problem.IsDirectionForward() && problem.GetInChannels_() % 16 == 0 && problem.GetOutChannels_() % 16 == 0) && (problem.GetInDataType() == miopenFloat)) { @@ -259,7 +262,7 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ExecutionContext& ctx, std::min(static_cast(problem.GetOutWidth_()), result.out_pix_tile0); result.out_pix_tile1 = std::min(static_cast(problem.GetOutHeight_()), result.out_pix_tile1); - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) { while(problem.GetOutWidth_() % result.out_pix_tile0 != 0 && result.out_pix_tile0 > 1) @@ -281,7 +284,7 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ExecutionContext& ctx, int wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); // backward: inputs are forward outputs - const bool is_forward = problem.direction.IsForward(); + const bool is_forward = problem.IsDirectionForward(); int wei_bstride = (is_forward ? problem.GetInChannels_() : problem.GetOutChannels_()) * wei_cstride; @@ -413,8 +416,10 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ExecutionContext& ctx, } } - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp b/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp index 3024cd1bdd..f590067a6a 100644 --- a/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp +++ b/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp @@ -49,6 +49,9 @@ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; /* * select default configuration if a known configuration has not been found. @@ -87,7 +90,7 @@ LegacyPerformanceConfig ConvOclDirectFwdLegacyExhaustiveSearch::GetDefaultPerfor { // version - if(problem.GetInDataType() == miopenFloat && problem.direction.IsForward() && + if(problem.GetInDataType() == miopenFloat && problem.IsDirectionForward() && problem.GetInChannels_() % 16 == 0 && problem.GetOutChannels_() % 16 == 0) { result.n_in_data_tiles = 128; @@ -107,7 +110,7 @@ LegacyPerformanceConfig ConvOclDirectFwdLegacyExhaustiveSearch::GetDefaultPerfor if(problem.GetPadW() > 0 || problem.GetKernelStrideW() > 1) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { result.out_pix_tile0 = (problem.GetOutWidth_() & 1) != 0 ? 1 : 2; } @@ -250,7 +253,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ExecutionContext& ctx, candidate.n_stacks = 1; auto& profile_h = ctx.GetStream(); - const auto& invoke_params = invoke_ctx.CastTo(); + const auto& invoke_params = invoke_ctx.CastTo(); const auto bot_ocl_ptr = invoke_params.tensors.in; const auto top_ocl_ptr = invoke_params.tensors.out; const auto wei_ocl_ptr = invoke_params.tensors.w; @@ -312,7 +315,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ExecutionContext& ctx, report_inteval = 5; // Add 1x1_stride : no padding support yet - if(problem.GetInDataType() == miopenFloat && problem.direction.IsForward() && + if(problem.GetInDataType() == miopenFloat && problem.IsDirectionForward() && problem.GetInChannels_() % 16 == 0 && problem.GetOutChannels_() % 16 == 0) { @@ -348,7 +351,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ExecutionContext& ctx, } else { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { out_pix_tl_cnt = (problem.GetOutWidth_() & 1) != 0 ? 1 : 2; } @@ -664,5 +667,6 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ExecutionContext& ctx, return result; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_ocl_dir2Dfwd_fused.cpp b/src/solver/conv_ocl_dir2Dfwd_fused.cpp index 72d393a8dd..0cc517be91 100644 --- a/src/solver/conv_ocl_dir2Dfwd_fused.cpp +++ b/src/solver/conv_ocl_dir2Dfwd_fused.cpp @@ -47,9 +47,9 @@ ConvOclDirectFwdFused::Search(const FusionContext& context, const FusionDescription& problem, const AnyInvokeParams& invoke_params) const { - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); - const auto legacy = ConvOclDirectFwd{}; + const auto legacy = conv::ConvOclDirectFwd{}; const auto& fusion_invoke_params = invoke_params.CastTo(); const auto wei_ocl_ptr = dynamic_cast( *fusion_invoke_params.op_args.params[0]) @@ -97,10 +97,10 @@ bool ConvOclDirectFwdFused::IsApplicable(const FusionContext& context, if(!(prim == miopenFusionOpActivForward)) return false; } - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); if(!conv_problem.IsFp32()) return false; - const auto base = ConvOclDirectFwd{}; + const auto base = conv::ConvOclDirectFwd{}; const auto conv_ctx = context.GetConvContext(conv_problem); return base.IsApplicable(conv_ctx, conv_problem); } @@ -110,9 +110,9 @@ ConvOclDirectFwdFused::GetSolution(const FusionContext& context, const FusionDescription& problem, const PerformanceConfigConvOclDirectFwdFused& config) const { - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); - ConvSolution result = ConvOclDirectFwd::BaseGetSolution(conv_ctx, conv_problem, config); + ConvSolution result = conv::ConvOclDirectFwd::BaseGetSolution(conv_ctx, conv_problem, config); if(result.construction_params.size() != 1) MIOPEN_THROW("ConvOclDirectFwdFused expects only one kernel"); @@ -232,9 +232,9 @@ PerformanceConfigConvOclDirectFwdFused ConvOclDirectFwdFused::GetDefaultPerformanceConfig(const FusionContext& context, const FusionDescription& problem) const { - const auto base = ConvOclDirectFwd{}; + const auto base = conv::ConvOclDirectFwd{}; MIOPEN_LOG_I("Using Unfused class to initialize performance config"); - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); return base.GetDefaultPerformanceConfig(conv_ctx, conv_problem); } @@ -244,8 +244,8 @@ bool ConvOclDirectFwdFused::IsValidPerformanceConfig( const FusionDescription& problem, const PerformanceConfigConvOclDirectFwdFused& c) const { - const auto base = ConvOclDirectFwd{}; - const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); + const auto base = conv::ConvOclDirectFwd{}; + const auto conv_problem = problem.GetConvProblem(0, miopen::conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); return base.IsValidPerformanceConfig(conv_ctx, conv_problem, c); } diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp index 659f0ddf73..df9b7ab851 100644 --- a/src/solver/conv_ocl_dir2Dfwdgen.cpp +++ b/src/solver/conv_ocl_dir2Dfwdgen.cpp @@ -33,6 +33,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const @@ -84,7 +87,7 @@ bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx, return false; } - return problem.direction.IsForward() + return problem.IsDirectionForward() && problem.GetKernelStrideW() == problem.GetKernelStrideH() && problem.GetPadW() == problem.GetPadH() && problem.GetDilationW() == 1 @@ -309,8 +312,10 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ExecutionContext& ctx, ConvSolution result; result.construction_params.push_back(construction_params); - result.invoker_factory = &conv::MakeGenericXWYPadInvoker; + result.invoker_factory = &miopen::conv::MakeGenericXWYPadInvoker; return result; } + +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index 6c7d94632f..67486a274c 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -164,16 +164,17 @@ static inline int GetBestNGroupParam(const int R, namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; namespace { -// clang-format off - auto PerfFieldRules() - { - return seq::MakeRuleSet( - std::make_tuple(seq::Span{}, &PerformanceConfigConvBinWinogradRxS::n_groups) - ); - } -// clang-format on + +auto PerfFieldRules() +{ + return seq::MakeRuleSet(std::make_tuple(seq::Span{}, + &PerformanceConfigConvBinWinogradRxS::n_groups)); +} // Winograd v21 is preferred on Vega10/Vega20 ASICs due to ~25% performance regression with Winograd // v30. The exception is Winograd F(3,2) stride2 as this mode is unsupported in v21. Details: @@ -276,7 +277,7 @@ inline bool IsShaderConstraintsMet(const ProblemDescription& problem, { // Padding for bwd data shall not be negative. /// \todo Either remove WrW related code or re-use function from RxS - if(problem.direction.IsBackwardData()) + if(problem.IsDirectionBackwardData()) { if(!(0 <= problem.GetBackwardPadW() && problem.GetBackwardPadW() < std::pow(2, 16))) return false; @@ -311,7 +312,7 @@ void PerformanceConfigConvBinWinogradRxS::HeuristicInit(const ExecutionContext& return; } - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { n_groups = GetBestNGroupParam(problem.GetInHeight_(), problem.GetInWidth_(), @@ -660,7 +661,7 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti const auto n_inputs_per_group = problem.GetInChannels_() / problem.GetGroupCount(), n_outputs_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); - if(problem.direction.IsBackwardWrW()) + if(problem.IsDirectionBackwardWrW()) { if(problem.GetKernelStrideW() == 2) return false; @@ -701,7 +702,7 @@ bool ConvBinWinoRxS::IsApplicable(const ExecutionContext& if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3{})) return false; #if !WORKAROUND_ISSUE_1681 - if(problem.GetGroupCount() == 1 && !problem.direction.IsBackwardWrW()) + if(problem.GetGroupCount() == 1 && !problem.IsDirectionBackwardWrW()) return false; #endif } @@ -838,7 +839,7 @@ ConvSolution ConvBinWinoRxS::GetSolution( { kernel_postfix += "_stride1"; } - else if(problem.GetKernelStrideW() == 2 && !problem.direction.IsBackwardData()) + else if(problem.GetKernelStrideW() == 2 && !problem.IsDirectionBackwardData()) { kernel_postfix += "_stride2"; } @@ -879,8 +880,8 @@ ConvSolution ConvBinWinoRxS::GetSolution( int N, C, H, W, K, out_H, out_W, R, S, pad_H, pad_W; MemLayout_t d_layout, o_layout, f_layout; - const bool is_forward = problem.direction.IsForward(); - const bool is_backWrW = problem.direction.IsBackwardWrW(); + const bool is_forward = problem.IsDirectionForward(); + const bool is_backWrW = problem.IsDirectionBackwardWrW(); const int group_cnt = problem.GetGroupCount(); if(!is_backWrW) @@ -957,14 +958,14 @@ ConvSolution ConvBinWinoRxS::GetSolution( const auto k = handle.Run(kernels[0]); const auto data_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.in - : primitive_params.CastTo().tensors.x; + !is_backWrW ? primitive_params.CastTo().tensors.in + : primitive_params.CastTo().tensors.x; const auto filter_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.w - : primitive_params.CastTo().tensors.dy; + !is_backWrW ? primitive_params.CastTo().tensors.w + : primitive_params.CastTo().tensors.dy; const auto out_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.out - : primitive_params.CastTo().tensors.dw; + !is_backWrW ? primitive_params.CastTo().tensors.out + : primitive_params.CastTo().tensors.dw; // clang-format off MIOPEN_LOG_I2(" N=" << N << " G=" << group_cnt << " C=" << C << " H=" << H << " W=" << W << " K=" << K @@ -1098,5 +1099,6 @@ ConvSolution ConvBinWinogradRxSf2x3g1::GetSolution(const ExecutionContext& ctx, template struct ConvBinWinoRxS<2, 3>; template struct ConvBinWinoRxS<3, 2>; +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/conv_winoRxS_fused.cpp b/src/solver/conv_winoRxS_fused.cpp index 377656cfb7..9010a9a61a 100644 --- a/src/solver/conv_winoRxS_fused.cpp +++ b/src/solver/conv_winoRxS_fused.cpp @@ -62,13 +62,14 @@ namespace { // v30. The exception is Winograd F(3,2) stride2 as this mode is unsupported in v21. Details: // https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1927#issuecomment-1412741130 template -inline bool IsWinogradV21Preferred(const std::string& asic, const ProblemDescription& problem) +inline bool IsWinogradV21Preferred(const std::string& asic, + const miopen::conv::ProblemDescription& problem) { return (StartsWith(asic, "gfx900") || StartsWith(asic, "gfx906")) && !(IS3X2 && problem.GetKernelStrideW() == 2); } -inline bool IsShaderConstraintsMetV21(const ProblemDescription& problem, +inline bool IsShaderConstraintsMetV21(const miopen::conv::ProblemDescription& problem, const int R, const int S, const int C, @@ -112,7 +113,7 @@ inline bool IsShaderConstraintsMetV21(const ProblemDescription& problem, // clang-format on } -inline bool IsShaderConstraintsMetV30(const ProblemDescription& problem, +inline bool IsShaderConstraintsMetV30(const miopen::conv::ProblemDescription& problem, const int R, const int S, const int C, diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 150deac983..cf9e3a0ef6 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -35,6 +35,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; namespace { constexpr size_t max_cu_limit = 512; @@ -94,10 +97,10 @@ class ShaderModel : public UnifiedDescriptionConv2d Hs{Ceil(out_h, Toh)}, We{Tow * (Ceil(out_w, Tow) + Ceil(Tw, Tow) - 1)}, - W{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutWidth_() - : problem.GetInWidth_())}, - H{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutHeight_() - : problem.GetInHeight_())}, + W{static_cast(problem.IsDirectionBackwardWrW() ? problem.GetOutWidth_() + : problem.GetInWidth_())}, + H{static_cast(problem.IsDirectionBackwardWrW() ? problem.GetOutHeight_() + : problem.GetInHeight_())}, d_H_clip{static_cast(static_cast(Hs * Toh) - pad_h)}, d_W_clip{static_cast(We - pad_w)}, @@ -123,7 +126,7 @@ class ShaderModel : public UnifiedDescriptionConv2d n_groups{static_cast(groups)} { is_applicable = problem.IsFp16() && problem.Is2d() && problem.GetGroupCount() == 1 && - problem.GetInLayout() == "NCHW" && !problem.direction.IsBackwardWrW(); + problem.GetInLayout() == "NCHW" && !problem.IsDirectionBackwardWrW(); } float GetWTI() const { return -2.f; } // unknown @@ -276,7 +279,7 @@ ConvWinoFuryRxS::GetSolution(const ExecutionContext& ctx, kernel_name << "_stride1"; kernel_file << "_stride1"; } - else if(problem.GetKernelStrideW() == 2 && !problem.direction.IsBackwardData()) + else if(problem.GetKernelStrideW() == 2 && !problem.IsDirectionBackwardData()) { kernel_name << "_stride2"; kernel_file << "_stride2"; @@ -304,8 +307,8 @@ ConvWinoFuryRxS::GetSolution(const ExecutionContext& ctx, // constexpr uint32_t F_TENSOR_OFFSETS = 1 << 13; uint32_t flags = 0; - const bool is_forward = problem.direction.IsForward(); - const bool is_backWrW = problem.direction.IsBackwardWrW(); + const bool is_forward = problem.IsDirectionForward(); + const bool is_backWrW = problem.IsDirectionBackwardWrW(); const int group_cnt = problem.GetGroupCount(); MemLayout_t d_layout, o_layout, f_layout; @@ -361,14 +364,14 @@ ConvWinoFuryRxS::GetSolution(const ExecutionContext& ctx, return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { const auto k = handle.Run(kernels[0]); const auto data_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.in - : primitive_params.CastTo().tensors.x; + !is_backWrW ? primitive_params.CastTo().tensors.in + : primitive_params.CastTo().tensors.x; const auto filter_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.w - : primitive_params.CastTo().tensors.dy; + !is_backWrW ? primitive_params.CastTo().tensors.w + : primitive_params.CastTo().tensors.dy; const auto out_tensors = - !is_backWrW ? primitive_params.CastTo().tensors.out - : primitive_params.CastTo().tensors.dw; + !is_backWrW ? primitive_params.CastTo().tensors.out + : primitive_params.CastTo().tensors.dw; float alpha_beta_reserved = 0.0f; uint64_t offset_reserved = 0; @@ -459,5 +462,6 @@ ConvWinoFuryRxS::GetSolution(const ExecutionContext& ctx, template struct ConvWinoFuryRxS<2, 3>; // template struct ConvWinoFuryRxS<3, 2>; +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/fft.cpp b/src/solver/fft.cpp index 3ca98b4720..4302325dab 100644 --- a/src/solver/fft.cpp +++ b/src/solver/fft.cpp @@ -37,6 +37,9 @@ namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_FFT) @@ -112,13 +115,13 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr std::ignore = ctx; // disable running any FFT based convolutions by checking this env variable - if(problem.direction.IsBackwardWrW() || !problem.IsFp32()) + if(problem.IsDirectionBackwardWrW() || !problem.IsFp32()) return false; if(!problem.IsLayoutDefault()) return false; - const auto is_fwd = problem.direction.IsForward(); + const auto is_fwd = problem.IsDirectionForward(); decltype(auto) conv = problem.GetConv(); decltype(auto) xDesc = is_fwd ? problem.GetIn() : problem.GetOut(); decltype(auto) yDesc = is_fwd ? problem.GetOut() : problem.GetIn(); @@ -158,7 +161,7 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr size_t fft::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { - const auto fwd = problem.direction.IsForward(); + const auto fwd = problem.IsDirectionForward(); decltype(auto) xDesc = fwd ? problem.GetIn() : problem.GetOut(); decltype(auto) yDesc = fwd ? problem.GetOut() : problem.GetIn(); decltype(auto) wDesc = problem.GetWeights(); @@ -364,7 +367,7 @@ ConvSolution fft::GetSolution(const ExecutionContext& ctx, const ProblemDescript parms += " -DCFF_HALFW="; parms += std::to_string(workSpaceSize / (sizeof(float) * 2 * 2)); - if(!problem.direction.IsForward()) + if(!problem.IsDirectionForward()) { parms += " -DCFF_BACKWARD"; } @@ -422,7 +425,7 @@ ConvSolution fft::GetSolution(const ExecutionContext& ctx, const ProblemDescript const int padding = FFTConvParams::TransposePadding; return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& params = primitive_params.CastTo(); + const auto& params = primitive_params.CastTo(); const auto& tensors = params.tensors; if(params.workSpaceSize < workSpaceSize) @@ -482,5 +485,6 @@ ConvSolution fft::GetSolution(const ExecutionContext& ctx, const ProblemDescript return sol; } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/gemm.cpp b/src/solver/gemm.cpp index 2cfbd257dc..3ba7aa5ea7 100644 --- a/src/solver/gemm.cpp +++ b/src/solver/gemm.cpp @@ -50,6 +50,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_USE_GEMM #ifdef CPPCHECK @@ -78,11 +81,10 @@ static inline bool IsAnyBufferFp16(const TensorDescriptor& xDesc, } #endif -bool GemmFwdBase::IsApplicable(const ExecutionContext& ctx, - const conv::ProblemDescription& problem) const +bool GemmFwdBase::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM - if(conv::solver::gemm::IsWorkaroundIssue1315(ctx)) + if(conv::gemm::IsWorkaroundIssue1315(ctx)) return false; const auto& xDesc = problem.GetIn(); const auto& wDesc = problem.GetWeights(); @@ -130,7 +132,7 @@ bool GemmFwdBase::IsApplicable(const ExecutionContext& ctx, MIOPEN_LOG_I2("GEMM not applicable for F8 on this GPU architecture"); return false; } - return problem.GetDirection() == conv::Direction::Forward && problem.IsLayoutDefault() && + return problem.IsDirectionForward() && problem.IsLayoutDefault() && !(IsAnyBufferBF16(xDesc, yDesc, wDesc) && !IsBf16Supported) && !(IsAnyBufferFp16(xDesc, yDesc, wDesc) && !IsFp16Supported); #else @@ -154,7 +156,7 @@ SlowdownFactor(int n_oper, const double oper_factor, const double multiple_oper_ return 1.0; } -float GemmFwdBase::GetWti(const ExecutionContext&, const conv::ProblemDescription& problem) const +float GemmFwdBase::GetWti(const ExecutionContext&, const ProblemDescription& problem) const { decltype(auto) conv = problem.GetConv(); decltype(auto) wDesc = problem.GetWeights(); @@ -236,7 +238,7 @@ float GemmFwdBase::GetWti(const ExecutionContext&, const conv::ProblemDescriptio #define MAX_MEM_ALLOC_SZ (std::min(handle.GetMaxMemoryAllocSize(), size_t(7287183769))) size_t GemmFwd1x1_0_2::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) handle = context.GetStream(); @@ -274,7 +276,7 @@ size_t GemmFwd1x1_0_2::GetWorkspaceSize(const ExecutionContext& context, } bool GemmFwd1x1_0_2::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmFwdBase::IsApplicable(context, problem)) @@ -299,7 +301,7 @@ bool GemmFwd1x1_0_2::IsApplicable(const ExecutionContext& context, } ConvSolution GemmFwd1x1_0_2::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) conv = problem.GetConv(); @@ -354,7 +356,7 @@ ConvSolution GemmFwd1x1_0_2::GetSolution(const ExecutionContext& context, return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { float time_gemm = 0; - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& workSpace = conv_params.workSpace; const auto workSpaceSize = conv_params.workSpaceSize; const auto x = conv_params.tensors.in; @@ -517,7 +519,7 @@ ConvSolution GemmFwd1x1_0_2::GetSolution(const ExecutionContext& context, } size_t GemmFwd1x1_0_1_int8::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) handle = context.GetStream(); @@ -555,7 +557,7 @@ size_t GemmFwd1x1_0_1_int8::GetWorkspaceSize(const ExecutionContext& context, } bool GemmFwd1x1_0_1_int8::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmFwdBase::IsApplicable(context, problem)) @@ -582,7 +584,7 @@ bool GemmFwd1x1_0_1_int8::IsApplicable(const ExecutionContext& context, } ConvSolution GemmFwd1x1_0_1_int8::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) conv = problem.GetConv(); @@ -637,7 +639,7 @@ ConvSolution GemmFwd1x1_0_1_int8::GetSolution(const ExecutionContext& context, (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& workSpace = conv_params.workSpace; const auto workSpaceSize = conv_params.workSpaceSize; const auto x = conv_params.tensors.in; @@ -726,14 +728,13 @@ ConvSolution GemmFwd1x1_0_1_int8::GetSolution(const ExecutionContext& context, #endif } -size_t GemmFwd1x1_0_1::GetWorkspaceSize(const ExecutionContext&, - const conv::ProblemDescription&) const +size_t GemmFwd1x1_0_1::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription&) const { return 0; } bool GemmFwd1x1_0_1::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmFwdBase::IsApplicable(context, problem)) @@ -757,7 +758,7 @@ bool GemmFwd1x1_0_1::IsApplicable(const ExecutionContext& context, } ConvSolution GemmFwd1x1_0_1::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) conv = problem.GetConv(); @@ -813,7 +814,7 @@ ConvSolution GemmFwd1x1_0_1::GetSolution(const ExecutionContext& context, return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { float time_gemm = 0; - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto x = conv_params.tensors.in; const auto w = conv_params.tensors.w; const auto y = conv_params.tensors.out; @@ -912,12 +913,13 @@ ConvSolution GemmFwd1x1_0_1::GetSolution(const ExecutionContext& context, (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - float time = 0; - decltype(auto) conv_params = primitive_params.CastTo(); - const auto& tensors = conv_params.tensors; - const auto& x = tensors.in; - const auto& w = tensors.w; - const auto& y = tensors.out; + float time = 0; + decltype(auto) conv_params = + primitive_params.CastTo(); + const auto& tensors = conv_params.tensors; + const auto& x = tensors.in; + const auto& w = tensors.w; + const auto& y = tensors.out; MIOPEN_LOG_FUNCTION("convolution, 1x1"); @@ -972,7 +974,7 @@ ConvSolution GemmFwd1x1_0_1::GetSolution(const ExecutionContext& context, } size_t GemmFwdRest::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) handle = context.GetStream(); @@ -1012,7 +1014,7 @@ size_t GemmFwdRest::GetWorkspaceSize(const ExecutionContext& context, } bool GemmFwdRest::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmFwdBase::IsApplicable(context, problem)) @@ -1072,7 +1074,7 @@ bool GemmFwdRest::IsApplicable(const ExecutionContext& context, } ConvSolution GemmFwdRest::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM decltype(auto) conv = problem.GetConv(); @@ -1132,7 +1134,7 @@ ConvSolution GemmFwdRest::GetSolution(const ExecutionContext& context, return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { float time_gemm = 0; - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& workSpace = conv_params.workSpace; const auto workSpaceSize = conv_params.workSpaceSize; const auto x = conv_params.tensors.in; @@ -1273,5 +1275,6 @@ ConvSolution GemmFwdRest::GetSolution(const ExecutionContext& context, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/gemm_bwd.cpp b/src/solver/gemm_bwd.cpp index df7d08304b..2001cad929 100644 --- a/src/solver/gemm_bwd.cpp +++ b/src/solver/gemm_bwd.cpp @@ -52,6 +52,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_USE_GEMM #ifdef CPPCHECK @@ -94,11 +97,10 @@ SlowdownFactor(int n_oper, const double oper_factor, const double multiple_oper_ return 1.0; } -bool GemmBwdBase::IsApplicable(const ExecutionContext& ctx, - const conv::ProblemDescription& problem) const +bool GemmBwdBase::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM - if(conv::solver::gemm::IsWorkaroundIssue1315(ctx)) + if(conv::gemm::IsWorkaroundIssue1315(ctx)) return false; const auto& dyDesc = problem.GetIn(); const auto& wDesc = problem.GetWeights(); @@ -139,7 +141,7 @@ bool GemmBwdBase::IsApplicable(const ExecutionContext& ctx, MIOPEN_LOG_I2("GEMM not applicable for F8 on this GPU architecture"); return false; } - return problem.GetDirection() == conv::Direction::BackwardData && problem.IsLayoutDefault() && + return problem.IsDirectionBackwardData() && problem.IsLayoutDefault() && !(IsAnyBufferBF16(dxDesc, dyDesc, wDesc) && !IsBf16Supported) && !(IsAnyBufferFp16(dxDesc, dyDesc, wDesc) && !IsFp16Supported); #else @@ -149,7 +151,7 @@ bool GemmBwdBase::IsApplicable(const ExecutionContext& ctx, #endif } -float GemmBwdBase::GetWti(const ExecutionContext&, const conv::ProblemDescription& problem) const +float GemmBwdBase::GetWti(const ExecutionContext&, const ProblemDescription& problem) const { const auto& conv = problem.GetConv(); const auto& wDesc = problem.GetWeights(); @@ -204,7 +206,7 @@ float GemmBwdBase::GetWti(const ExecutionContext&, const conv::ProblemDescriptio } size_t GemmBwd1x1_stride2::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM auto& handle = context.GetStream(); @@ -243,7 +245,7 @@ size_t GemmBwd1x1_stride2::GetWorkspaceSize(const ExecutionContext& context, } bool GemmBwd1x1_stride2::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmBwdBase::IsApplicable(context, problem)) @@ -268,7 +270,7 @@ bool GemmBwd1x1_stride2::IsApplicable(const ExecutionContext& context, } ConvSolution GemmBwd1x1_stride2::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto& dyDesc = problem.GetIn(); @@ -315,7 +317,7 @@ ConvSolution GemmBwd1x1_stride2::GetSolution(const ExecutionContext& context, const bool time_precision = (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& workspace = conv_params.workSpace; const auto& workspace_size = conv_params.workSpaceSize; const auto& dy = conv_params.tensors.in; @@ -451,13 +453,13 @@ ConvSolution GemmBwd1x1_stride2::GetSolution(const ExecutionContext& context, } size_t GemmBwd1x1_stride1::GetWorkspaceSize(const ExecutionContext&, - const conv::ProblemDescription&) const + const ProblemDescription&) const { return 0; } bool GemmBwd1x1_stride1::IsApplicableBeforeWorkaround(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmBwdBase::IsApplicable(context, problem)) @@ -480,7 +482,7 @@ bool GemmBwd1x1_stride1::IsApplicableBeforeWorkaround(const ExecutionContext& co } bool GemmBwd1x1_stride1::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM && (!MIOPEN_USE_MIOPENGEMM || !WORKAROUND_MIOPENGEMM_ISSUE_59) return IsApplicableBeforeWorkaround(context, problem); @@ -492,7 +494,7 @@ bool GemmBwd1x1_stride1::IsApplicable(const ExecutionContext& context, } ConvSolution GemmBwd1x1_stride1::GetSolution(const ExecutionContext&, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto group_count = problem.GetConv().group_count; @@ -505,7 +507,7 @@ ConvSolution GemmBwd1x1_stride1::GetSolution(const ExecutionContext&, const bool time_precision = (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& dy = conv_params.tensors.in; const auto& w = conv_params.tensors.w; const auto& dx = conv_params.tensors.out; @@ -639,7 +641,7 @@ ConvSolution GemmBwd1x1_stride1::GetSolution(const ExecutionContext&, } size_t GemmBwdRest::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM auto& handle = context.GetStream(); @@ -680,7 +682,7 @@ size_t GemmBwdRest::GetWorkspaceSize(const ExecutionContext& context, } bool GemmBwdRest::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmBwdBase::IsApplicable(context, problem)) @@ -697,7 +699,7 @@ bool GemmBwdRest::IsApplicable(const ExecutionContext& context, } ConvSolution GemmBwdRest::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto& dyDesc = problem.GetIn(); @@ -756,7 +758,7 @@ ConvSolution GemmBwdRest::GetSolution(const ExecutionContext& context, const bool time_precision = (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& workspace = conv_params.workSpace; const auto& workspace_size = conv_params.workSpaceSize; const auto& dy = conv_params.tensors.in; @@ -900,5 +902,6 @@ ConvSolution GemmBwdRest::GetSolution(const ExecutionContext& context, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/gemm_common.cpp b/src/solver/gemm_common.cpp index ff606265bc..22d4b3a2e6 100644 --- a/src/solver/gemm_common.cpp +++ b/src/solver/gemm_common.cpp @@ -38,8 +38,8 @@ #define WORKAROUND_ISSUE_1315 (MIOPEN_USE_MIOPENGEMM && (HIP_PACKAGE_VERSION_FLAT >= 4004000000ULL)) namespace miopen { -namespace conv { namespace solver { +namespace conv { namespace gemm { bool IsWorkaroundIssue1315(const miopen::ExecutionContext& ctx) @@ -54,6 +54,6 @@ bool IsWorkaroundIssue1315(const miopen::ExecutionContext& ctx) } } // namespace gemm -} // namespace solver } // namespace conv +} // namespace solver } // namespace miopen diff --git a/src/solver/gemm_wrw.cpp b/src/solver/gemm_wrw.cpp index 1bda06eca6..b7a206b171 100644 --- a/src/solver/gemm_wrw.cpp +++ b/src/solver/gemm_wrw.cpp @@ -19,6 +19,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING) namespace miopen { namespace solver { +namespace conv { + +using ProblemDescription = miopen::conv::ProblemDescription; #if MIOPEN_USE_GEMM #ifdef CPPCHECK @@ -61,11 +64,10 @@ SlowdownFactor(int n_oper, const double oper_factor, const double multiple_oper_ } #endif -bool GemmWrwBase::IsApplicable(const ExecutionContext& ctx, - const conv::ProblemDescription& problem) const +bool GemmWrwBase::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM - if(conv::solver::gemm::IsWorkaroundIssue1315(ctx)) + if(conv::gemm::IsWorkaroundIssue1315(ctx)) return false; const auto& dyDesc = problem.GetIn(); const auto& dwDesc = problem.GetWeights(); @@ -106,8 +108,7 @@ bool GemmWrwBase::IsApplicable(const ExecutionContext& ctx, MIOPEN_LOG_I2("GEMM not applicable for F8 on this GPU architecture"); return false; } - return problem.GetDirection() == conv::Direction::BackwardWeights && - problem.IsLayoutDefault() && + return problem.IsDirectionBackwardWrW() && problem.IsLayoutDefault() && !(IsAnyBufferBF16(xDesc, dyDesc, dwDesc) && !IsBF16PathValid) && !(IsAnyBufferFp16(xDesc, dyDesc, dwDesc) && !IsFp16Supported); #else @@ -117,7 +118,7 @@ bool GemmWrwBase::IsApplicable(const ExecutionContext& ctx, #endif } -float GemmWrwBase::GetWti(const ExecutionContext&, const conv::ProblemDescription& problem) const +float GemmWrwBase::GetWti(const ExecutionContext&, const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto& dwDesc = problem.GetWeights(); @@ -165,7 +166,7 @@ float GemmWrwBase::GetWti(const ExecutionContext&, const conv::ProblemDescriptio } bool GemmWrw1x1_stride1::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmWrwBase::IsApplicable(context, problem)) @@ -188,7 +189,7 @@ bool GemmWrw1x1_stride1::IsApplicable(const ExecutionContext& context, } ConvSolution GemmWrw1x1_stride1::GetSolution(const ExecutionContext&, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto& dyDesc = problem.GetIn(); @@ -246,7 +247,7 @@ ConvSolution GemmWrw1x1_stride1::GetSolution(const ExecutionContext&, const bool time_precision = (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& dy = conv_params.tensors.dy; const auto& dw = conv_params.tensors.dw; const auto& dwDesc_ = conv_params.tensors.dwDesc; @@ -351,7 +352,7 @@ ConvSolution GemmWrw1x1_stride1::GetSolution(const ExecutionContext&, } size_t GemmWrwUniversal::GetWorkspaceSize(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM auto& handle = context.GetStream(); @@ -387,7 +388,7 @@ size_t GemmWrwUniversal::GetWorkspaceSize(const ExecutionContext& context, } bool GemmWrwUniversal::IsApplicable(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM if(!GemmWrwBase::IsApplicable(context, problem)) @@ -403,7 +404,7 @@ bool GemmWrwUniversal::IsApplicable(const ExecutionContext& context, } ConvSolution GemmWrwUniversal::GetSolution(const ExecutionContext& context, - const conv::ProblemDescription& problem) const + const ProblemDescription& problem) const { #if MIOPEN_USE_GEMM const auto& dyDesc = problem.GetIn(); @@ -463,7 +464,7 @@ ConvSolution GemmWrwUniversal::GetSolution(const ExecutionContext& context, const bool time_precision = (!IsDisabled(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING{})); return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { - const auto& conv_params = primitive_params.CastTo(); + const auto& conv_params = primitive_params.CastTo(); const auto& dy = conv_params.tensors.dy; const auto& dyDesc_ = conv_params.tensors.dyDesc; const auto& dwDesc_ = conv_params.tensors.dwDesc; @@ -611,5 +612,6 @@ ConvSolution GemmWrwUniversal::GetSolution(const ExecutionContext& context, #endif } +} // namespace conv } // namespace solver } // namespace miopen diff --git a/src/solver/mlir_common.cpp b/src/solver/mlir_common.cpp index e54da1ac89..ffaf54fa17 100644 --- a/src/solver/mlir_common.cpp +++ b/src/solver/mlir_common.cpp @@ -74,16 +74,16 @@ static std::string GetIsaName(const miopen::TargetProperties& target) #endif } -std::string GetKernelName(const ProblemDescription& problem, bool is_xdlops, int kernel_id) +std::string GetKernelName(const conv::ProblemDescription& problem, bool is_xdlops, int kernel_id) { std::string version; std::string direction; - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { version = "_v4r4"; direction = "_fwd"; } - else if(problem.direction.IsBackwardData()) + else if(problem.IsDirectionBackwardData()) { version = "_v4r1"; direction = "_bwd"; @@ -102,13 +102,13 @@ std::string GetKernelName(const ProblemDescription& problem, bool is_xdlops, int return kernel_name + std::to_string(kernel_id); } -static std::string GetOperation(const ProblemDescription& problem) +static std::string GetOperation(const conv::ProblemDescription& problem) { - if(problem.direction.IsForward()) + if(problem.IsDirectionForward()) { return "conv2d"; } - else if(problem.direction.IsBackwardData()) + else if(problem.IsDirectionBackwardData()) { return "conv2d_bwd_data"; } @@ -121,7 +121,7 @@ static std::string GetOperation(const ProblemDescription& problem) /* Construct the options string passed to MLIR to cause it to generate a given convolution.*/ std::string ConstructBuildOptions(const ExecutionContext& ctx, - const ProblemDescription& problem, + const conv::ProblemDescription& problem, bool is_xdlops, int kernel_id) { diff --git a/src/sqlite_db.cpp b/src/sqlite_db.cpp index 5b42ea13a6..a46c5b2cb8 100644 --- a/src/sqlite_db.cpp +++ b/src/sqlite_db.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #if MIOPEN_EMBED_DB @@ -447,10 +447,10 @@ SQLitePerfDb::SQLitePerfDb(const std::string& filename_, bool is_system_) // Check fields for the tables if(!dbInvalid) { - if(!CheckTableColumns(ProblemDescription::table_name(), prob_desc.FieldNames())) + if(!CheckTableColumns(conv::ProblemDescription::table_name(), prob_desc.FieldNames())) { std::ostringstream ss; - ss << "Invalid fields in table: " << ProblemDescription::table_name() + ss << "Invalid fields in table: " << conv::ProblemDescription::table_name() << " disabling access to " << filename; MIOPEN_LOG_W(ss.str()); dbInvalid = true; diff --git a/test/conv_common.hpp b/test/conv_common.hpp index a3c2e97463..60291b8d48 100644 --- a/test/conv_common.hpp +++ b/test/conv_common.hpp @@ -133,10 +133,9 @@ static inline bool skip_config(miopen::Handle& handle, if(convDesc.mode != miopenConvolution) return false; - const auto conv_problem = miopen::conv::ProblemDescription{ + const auto problem = miopen::conv::ProblemDescription{ xDesc, wDesc, yDesc, convDesc, miopen::conv::Direction::Forward}; - const auto problem = miopen::ProblemDescription{conv_problem}; - auto ctx = miopen::ExecutionContext{}; + auto ctx = miopen::ExecutionContext{}; ctx.do_search = false; ctx.save_srch_req = false; diff --git a/test/embed_sqlite.cpp b/test/embed_sqlite.cpp index 32aa1371a8..8c0d2e403b 100644 --- a/test/embed_sqlite.cpp +++ b/test/embed_sqlite.cpp @@ -64,9 +64,8 @@ struct EmbedSQLite : test_driver void run() { // create a context/problem decriptor - const auto conv_problem = miopen::conv::ProblemDescription{ + const auto problem = miopen::conv::ProblemDescription{ x.desc, w.desc, y.desc, filter, miopen::conv::Direction::Forward}; - const auto problem = miopen::ProblemDescription{conv_problem}; miopen::ExecutionContext ctx{}; ctx.SetStream(&handle); // Check PerfDb diff --git a/test/gpu_reference_kernel.cpp b/test/gpu_reference_kernel.cpp index 61db5257de..d31c295fb8 100644 --- a/test/gpu_reference_kernel.cpp +++ b/test/gpu_reference_kernel.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include diff --git a/test/gtest/group_conv3d_bwd.cpp b/test/gtest/group_conv3d_bwd.cpp index a9bffceff1..439dbb5eb7 100644 --- a/test/gtest/group_conv3d_bwd.cpp +++ b/test/gtest/group_conv3d_bwd.cpp @@ -80,15 +80,15 @@ void SolverBwd(const miopen::TensorDescriptor& inputDesc, TEST_P(ConvBwdSolverTest3D, CKGroupConvBwd3D) { - SolverBwd(input.desc, - in_dev.get(), - weights.desc, - wei_dev.get(), - output.desc, - out_dev.get(), - conv_desc, - conv_config, - test_skipped); + SolverBwd(input.desc, + in_dev.get(), + weights.desc, + wei_dev.get(), + output.desc, + out_dev.get(), + conv_desc, + conv_config, + test_skipped); } INSTANTIATE_TEST_SUITE_P(ConvBwdTest, diff --git a/test/gtest/group_conv3d_fwd.cpp b/test/gtest/group_conv3d_fwd.cpp index 18d54355e8..5e17aa7c0d 100644 --- a/test/gtest/group_conv3d_fwd.cpp +++ b/test/gtest/group_conv3d_fwd.cpp @@ -80,15 +80,15 @@ void SolverFwd(const miopen::TensorDescriptor& inputDesc, TEST_P(ConvFwdSolverTest3D, CKGroupConvFwd3D) { - SolverFwd(input.desc, - in_dev.get(), - weights.desc, - wei_dev.get(), - output.desc, - out_dev.get(), - conv_desc, - conv_config, - test_skipped); + SolverFwd(input.desc, + in_dev.get(), + weights.desc, + wei_dev.get(), + output.desc, + out_dev.get(), + conv_desc, + conv_config, + test_skipped); } INSTANTIATE_TEST_SUITE_P(ConvFwdTest, diff --git a/test/gtest/group_conv3d_wrw.cpp b/test/gtest/group_conv3d_wrw.cpp index 977a06220a..8633b78b1e 100644 --- a/test/gtest/group_conv3d_wrw.cpp +++ b/test/gtest/group_conv3d_wrw.cpp @@ -80,15 +80,15 @@ void SolverWrw(const miopen::TensorDescriptor& inputDesc, TEST_P(ConvWrwSolverTest3D, CKGroupConvWrw3D) { - SolverWrw(input.desc, - in_dev.get(), - weights.desc, - wei_dev.get(), - output.desc, - out_dev.get(), - conv_desc, - conv_config, - test_skipped); + SolverWrw(input.desc, + in_dev.get(), + weights.desc, + wei_dev.get(), + output.desc, + out_dev.get(), + conv_desc, + conv_config, + test_skipped); } INSTANTIATE_TEST_SUITE_P( diff --git a/test/gtest/group_conv_fwd.cpp b/test/gtest/group_conv_fwd.cpp index c8fdec4cae..4efea68c1d 100644 --- a/test/gtest/group_conv_fwd.cpp +++ b/test/gtest/group_conv_fwd.cpp @@ -80,15 +80,15 @@ void SolverFwd(const miopen::TensorDescriptor& inputDesc, TEST_P(ConvFwdSolverTestFloat, CKGroupConvFwd) { - SolverFwd(input.desc, - in_dev.get(), - weights.desc, - wei_dev.get(), - output.desc, - out_dev.get(), - conv_desc, - conv_config, - test_skipped); + SolverFwd(input.desc, + in_dev.get(), + weights.desc, + wei_dev.get(), + output.desc, + out_dev.get(), + conv_desc, + conv_config, + test_skipped); } INSTANTIATE_TEST_SUITE_P(ConvFwdTest, diff --git a/test/gtest/kernel_tuning_net.cpp b/test/gtest/kernel_tuning_net.cpp index 7657cbbf55..18f8d91a44 100644 --- a/test/gtest/kernel_tuning_net.cpp +++ b/test/gtest/kernel_tuning_net.cpp @@ -69,7 +69,7 @@ struct KernelTuningNetTest : public ::testing::TestWithParam }; template -void TestParameterPredictionModel(miopen::ProblemDescription problem, +void TestParameterPredictionModel(miopen::conv::ProblemDescription problem, bool expected_valid, std::string expected) { @@ -116,13 +116,13 @@ void TestParameterPredictionModel(miopen::ProblemDescription problem, TEST_P(KernelTuningNetTestFloat, ConvAsm1x1UParameterPredictionModelFloat) { - TestParameterPredictionModel( + TestParameterPredictionModel( problem, expected_valid, expected); } TEST_P(KernelTuningNetTestHalf, ConvAsm1x1UParameterPredictionModelHalf) { - TestParameterPredictionModel( + TestParameterPredictionModel( problem, expected_valid, expected); } diff --git a/test/gtest/nonpack_conv3d_fwd.cpp b/test/gtest/nonpack_conv3d_fwd.cpp index 35cc492c74..0d08e5215a 100644 --- a/test/gtest/nonpack_conv3d_fwd.cpp +++ b/test/gtest/nonpack_conv3d_fwd.cpp @@ -80,15 +80,15 @@ void SolverFwd(const miopen::TensorDescriptor& inputDesc, TEST_P(ConvFwdSolverTest3D, CKNonPackConvFwd3D) { - SolverFwd(input.desc, - in_dev.get(), - weights.desc, - wei_dev.get(), - output.desc, - out_dev.get(), - conv_desc, - conv_config, - test_skipped); + SolverFwd(input.desc, + in_dev.get(), + weights.desc, + wei_dev.get(), + output.desc, + out_dev.get(), + conv_desc, + conv_config, + test_skipped); } INSTANTIATE_TEST_SUITE_P(ConvFwdTest, diff --git a/test/gtest/solver_bwd.hpp b/test/gtest/solver_bwd.hpp index d99d149a94..cb55a5951b 100644 --- a/test/gtest/solver_bwd.hpp +++ b/test/gtest/solver_bwd.hpp @@ -53,16 +53,15 @@ struct ConvBwdSolverTest const auto tensors = miopen::ConvBwdTensors{ output.desc, out_dev.get(), weights.desc, wei_dev.get(), input.desc, in_dev.get()}; - const auto conv_problem = + const auto problem = miopen::conv::ProblemDescription(input.desc, weights.desc, output.desc, conv_desc, miopen::conv::Direction::BackwardData); - const auto problem = miopen::ProblemDescription{conv_problem}; const miopen::ExecutionContext ctx = [&] { auto tmp = miopen::ExecutionContext{&handle}; - problem.conv_problem.SetupFloats(tmp); + problem.SetupFloats(tmp); return tmp; }(); diff --git a/test/gtest/solver_bwd_f8.cpp b/test/gtest/solver_bwd_f8.cpp index 7c286cbfc2..6f2cda02bf 100644 --- a/test/gtest/solver_bwd_f8.cpp +++ b/test/gtest/solver_bwd_f8.cpp @@ -35,27 +35,28 @@ struct ConvBwdFp8Naive : ConvBwdSolverTest TEST_P(ConvBwdFp8, DISABLED_GemmBwd1x1_stride2) { - miopen::solver::GemmBwd1x1_stride2 solv{}; + miopen::solver::conv::GemmBwd1x1_stride2 solv{}; SolverBwd(solv); } TEST_P(ConvBwdFp8, DISABLED_GemmBwd1x1_stride1) { - miopen::solver::GemmBwd1x1_stride1 solv{}; + miopen::solver::conv::GemmBwd1x1_stride1 solv{}; SolverBwd(solv); } TEST_P(ConvBwdFp8, DISABLED_GemmBwdRest) { - miopen::solver::GemmBwdRest solv{}; + miopen::solver::conv::GemmBwdRest solv{}; SolverBwd(solv); } TEST_P(ConvBwdFp8Naive, DISABLED_Bwd) { - miopen::solver::ConvDirectNaiveConvBwd solv{}; + miopen::solver::conv::ConvDirectNaiveConvBwd solv{}; SolverBwd(solv); } + INSTANTIATE_TEST_SUITE_P(ConvBwdTest, ConvBwdFp8, testing::Combine(testing::Values(miopenConvolutionAlgoGEMM), diff --git a/test/gtest/solver_convasm3x3u.cpp b/test/gtest/solver_convasm3x3u.cpp index 91133c68ba..3ed1e466dc 100644 --- a/test/gtest/solver_convasm3x3u.cpp +++ b/test/gtest/solver_convasm3x3u.cpp @@ -31,8 +31,8 @@ struct ConvFwdSolverTestFloat : ConvFwdSolverTest TEST_P(ConvFwdSolverTestFloat, ConvASM3x3UFwd) { - miopen::solver::ConvAsm3x3U solv{}; - SolverFwd(solv); + miopen::solver::conv::ConvAsm3x3U solv{}; + SolverFwd(solv); } INSTANTIATE_TEST_SUITE_P(ConvFwdTest, diff --git a/test/gtest/solver_fwd.hpp b/test/gtest/solver_fwd.hpp index a78f65dc04..88fa9a9c55 100644 --- a/test/gtest/solver_fwd.hpp +++ b/test/gtest/solver_fwd.hpp @@ -50,21 +50,20 @@ struct ConvFwdSolverTest { auto&& handle = get_handle(); - const auto tensors = miopen::ConvFwdTensors{this->input.desc, + const auto tensors = miopen::ConvFwdTensors{this->input.desc, this->in_dev.get(), this->weights.desc, this->wei_dev.get(), this->output.desc, this->out_dev.get()}; - const auto problem = miopen::ProblemDescription( - miopen::conv::ProblemDescription{this->input.desc, - this->weights.desc, - this->output.desc, - this->conv_desc, - miopen::conv::Direction::Forward}); + const auto problem = miopen::conv::ProblemDescription(this->input.desc, + this->weights.desc, + this->output.desc, + this->conv_desc, + miopen::conv::Direction::Forward); const miopen::ExecutionContext ctx = [&] { auto tmp = miopen::ExecutionContext{&handle}; - problem.conv_problem.SetupFloats(tmp); + problem.SetupFloats(tmp); return tmp; }(); diff --git a/test/gtest/solver_fwd_f8.cpp b/test/gtest/solver_fwd_f8.cpp index 36f0ec67cd..a48513480c 100644 --- a/test/gtest/solver_fwd_f8.cpp +++ b/test/gtest/solver_fwd_f8.cpp @@ -35,27 +35,28 @@ struct ConvFwdFp8Naive : ConvFwdSolverTest TEST_P(ConvFwdFp8, DISABLED_GemmFwdRest) { - miopen::solver::GemmFwdRest solv{}; + miopen::solver::conv::GemmFwdRest solv{}; SolverFwd(solv); } TEST_P(ConvFwdFp8, DISABLED_GemmFwd1x1_0_2) { - miopen::solver::GemmFwd1x1_0_2 solv{}; + miopen::solver::conv::GemmFwd1x1_0_2 solv{}; SolverFwd(solv); } TEST_P(ConvFwdFp8, DISABLED_Gemm1x1x0x1) { - miopen::solver::GemmFwd1x1_0_1 solv{}; + miopen::solver::conv::GemmFwd1x1_0_1 solv{}; SolverFwd(solv); } TEST_P(ConvFwdFp8Naive, DISABLED_Fwd) { - miopen::solver::ConvDirectNaiveConvFwd solv{}; - SolverFwd(solv); + miopen::solver::conv::ConvDirectNaiveConvFwd solv{}; + SolverFwd(solv); } + INSTANTIATE_TEST_SUITE_P(ConvFwdTest, ConvFwdFp8, testing::Combine(testing::Values(miopenConvolutionAlgoGEMM), diff --git a/test/gtest/solver_wrw.hpp b/test/gtest/solver_wrw.hpp index a3268dbe2e..dde92e2071 100644 --- a/test/gtest/solver_wrw.hpp +++ b/test/gtest/solver_wrw.hpp @@ -53,15 +53,15 @@ struct ConvWrwSolverTest const auto tensors = miopen::ConvWrwTensors{ output.desc, out_dev.get(), input.desc, in_dev.get(), weights.desc, wei_dev.get()}; - const auto problem = miopen::ProblemDescription( - miopen::conv::ProblemDescription{output.desc, + const auto problem = + miopen::conv::ProblemDescription(output.desc, weights.desc, input.desc, conv_desc, - miopen::conv::Direction::BackwardWeights}); + miopen::conv::Direction::BackwardWeights); const miopen::ExecutionContext ctx = [&] { auto tmp = miopen::ExecutionContext{&handle}; - problem.conv_problem.SetupFloats(tmp); + problem.SetupFloats(tmp); return tmp; }(); diff --git a/test/gtest/solver_wrw_f8.cpp b/test/gtest/solver_wrw_f8.cpp index 76c608b622..a970effd02 100644 --- a/test/gtest/solver_wrw_f8.cpp +++ b/test/gtest/solver_wrw_f8.cpp @@ -24,15 +24,17 @@ * *******************************************************************************/ #include "solver_wrw.hpp" + struct ConvWrwFp8Naive : ConvWrwSolverTest { }; TEST_P(ConvWrwFp8Naive, DISABLED_Wrw) { - miopen::solver::ConvDirectNaiveConvWrw solv{}; - SolverWrw(solv); + miopen::solver::conv::ConvDirectNaiveConvWrw solv{}; + SolverWrw(solv); } + // Since NaiveConv is verified against the CPU, we are conservative in the number and type // of test cases we instantiate INSTANTIATE_TEST_SUITE_P(ConvWrwTest, diff --git a/test/gtest/tuna_net.cpp b/test/gtest/tuna_net.cpp index 8b4f76fbe0..24c744268e 100644 --- a/test/gtest/tuna_net.cpp +++ b/test/gtest/tuna_net.cpp @@ -66,7 +66,7 @@ struct TunaNetTest : public ::testing::TestWithParam GTEST_SKIP(); #endif } - miopen::ProblemDescription problem; + miopen::conv::ProblemDescription problem; std::size_t expected_solver; }; @@ -82,7 +82,8 @@ struct TunaNetTestBF16 : TunaNetTest { }; -void TestSolverPredictionModel(miopen::ProblemDescription& problem, std::size_t expected_solver) +void TestSolverPredictionModel(miopen::conv::ProblemDescription& problem, + std::size_t expected_solver) { #if MIOPEN_ENABLE_AI_IMMED_MODE_FALLBACK auto&& handle = get_handle(); diff --git a/test/solver.cpp b/test/solver.cpp index d61524a29e..86d9a0671d 100644 --- a/test/solver.cpp +++ b/test/solver.cpp @@ -42,20 +42,22 @@ namespace miopen { namespace tests { -class TrivialTestSolver final : public solver::ConvSolver + +class TrivialTestSolver final : public solver::conv::ConvSolver { public: static const char* FileName() { return "TrivialTestSolver"; } const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription& problem) const override + bool IsApplicable(const ExecutionContext&, + const conv::ProblemDescription& problem) const override { return problem.GetInWidth_() == 1; } solver::ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&) const override + const conv::ProblemDescription&) const override { solver::ConvSolution ret; solver::KernelInfo kernel; @@ -79,7 +81,7 @@ struct TestConfig : solver::PerfConfigBase } }; -class SearchableTestSolver final : public solver::ConvTunableSolver +class SearchableTestSolver final : public solver::conv::ConvTunableSolver { public: static int searches_done() { return _serches_done; } @@ -88,13 +90,13 @@ class SearchableTestSolver final : public solver::ConvTunableSolver const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext&, const ProblemDescription&) const override + bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const override { return true; } TestConfig GetDefaultPerformanceConfig(const ExecutionContext&, - const ProblemDescription&) const override + const conv::ProblemDescription&) const override { TestConfig config{}; config.str = NoSearchFileName(); @@ -102,14 +104,14 @@ class SearchableTestSolver final : public solver::ConvTunableSolver } bool IsValidPerformanceConfig(const ExecutionContext&, - const ProblemDescription&, + const conv::ProblemDescription&, const TestConfig&) const override { return true; } TestConfig Search(const ExecutionContext&, - const ProblemDescription&, + const conv::ProblemDescription&, const AnyInvokeParams&) const override { TestConfig config; @@ -119,7 +121,7 @@ class SearchableTestSolver final : public solver::ConvTunableSolver } solver::ConvSolution GetSolution(const ExecutionContext&, - const ProblemDescription&, + const conv::ProblemDescription&, const TestConfig& config) const override { @@ -141,7 +143,7 @@ class SearchableTestSolver final : public solver::ConvTunableSolver int SearchableTestSolver::_serches_done = 0; static solver::ConvSolution FindSolution(const ExecutionContext& ctx, - const ProblemDescription& problem, + const conv::ProblemDescription& problem, const std::string& db_path) { PlainTextDb db(db_path); @@ -225,6 +227,7 @@ class SolverTest EXPECT_EQUAL(sol.construction_params[0].kernel_file, expected_kernel); } }; + } // namespace tests } // namespace miopen diff --git a/test/sqlite_perfdb.cpp b/test/sqlite_perfdb.cpp index f435c10b35..2d67312132 100644 --- a/test/sqlite_perfdb.cpp +++ b/test/sqlite_perfdb.cpp @@ -27,7 +27,7 @@ #include "test.hpp" #include "driver.hpp" -#include +#include #include #include #include