Skip to content

Commit

Permalink
[NFC] Replace miopen::ProblemDescription with conv::ProblemDescription, part 4 (#2410)
Browse files Browse the repository at this point in the history
  • Loading branch information
averinevg authored Nov 2, 2023
1 parent 73a3710 commit 19c6046
Show file tree
Hide file tree
Showing 139 changed files with 2,300 additions and 2,061 deletions.
1 change: 0 additions & 1 deletion driver/conv_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
#include <miopen/convolution.hpp>
#include <miopen/solver.hpp>
#include <miopen/find_controls.hpp>
#include <miopen/problem_description.hpp>
#include "random.hpp"
#include <numeric>
#include <sstream>
Expand Down
1 change: 0 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ set( MIOpen_Source
performance_config.cpp
pooling/problem_description.cpp
pooling_api.cpp
problem_description.cpp
problem.cpp
ramdb.cpp
readonlyramdb.cpp
Expand Down
20 changes: 10 additions & 10 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ class Model
{
}
virtual ~Model() = default;
virtual bool IsProblemSupported(const ProblemDescription& problem,
virtual bool IsProblemSupported(const conv::ProblemDescription& problem,
const ExecutionContext& ctx) const = 0;
std::vector<float> Forward(const ProblemDescription& problem) const
std::vector<float> Forward(const conv::ProblemDescription& problem) const
{
std::vector<float> features = ToFeatures(problem);
std::vector<fdeep::tensor> output = model.predict({fdeep::tensor(input_shape, features)});
Expand All @@ -142,14 +142,14 @@ class Model
MIOPEN_THROW(miopenStatusInternalError, "Unable to load AI model file:" + file_path);
return file_path;
}
virtual std::vector<float> ToFeatures(const ProblemDescription& problem) const = 0;
virtual std::vector<float> ToFeatures(const conv::ProblemDescription& problem) const = 0;
};

class Gfx908Model final : public Model
{
public:
Gfx908Model() : Model("gfx908") {}
bool IsProblemSupported(const ProblemDescription& problem,
bool IsProblemSupported(const conv::ProblemDescription& problem,
const ExecutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
Expand Down Expand Up @@ -216,7 +216,7 @@ class Gfx908Model final : public Model
}

protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
std::vector<float> ToFeatures(const conv::ProblemDescription& problem) const override
{
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
Expand Down Expand Up @@ -259,7 +259,7 @@ class Gfx90aModel final : public Model
{
public:
Gfx90aModel() : Model("gfx90a") {}
bool IsProblemSupported(const ProblemDescription& problem,
bool IsProblemSupported(const conv::ProblemDescription& problem,
const ExecutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
Expand Down Expand Up @@ -317,7 +317,7 @@ class Gfx90aModel final : public Model
}

protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
std::vector<float> ToFeatures(const conv::ProblemDescription& problem) const override
{
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
Expand Down Expand Up @@ -356,7 +356,7 @@ std::unique_ptr<Model> GetModel(const std::string& device)
return std::make_unique<Gfx908Model>();
}

std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
std::vector<uint64_t> PredictSolver(const conv::ProblemDescription& problem,
const ExecutionContext& ctx,
const std::string& device)
{
Expand All @@ -366,7 +366,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,

std::string est_name = ":memory:" + device;
auto& db = AnyRamDb::GetCached(est_name);
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
auto db_res = db.FindRecord(problem);
if(db_res)
{
MIOPEN_LOG_I2("Cached heuristic (TunaNet) result found");
Expand Down Expand Up @@ -415,7 +415,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
sol.push_back(sol_id.Value());
any_sol.push_back(sol_id.Value());
}
db.StoreRecord(static_cast<const conv::ProblemDescription&>(problem), any_sol);
db.StoreRecord(problem, any_sol);
if(miopen::IsLogging(LoggingLevel::Info2))
{
std::stringstream ss;
Expand Down
6 changes: 3 additions & 3 deletions src/conv/invokers/impl_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
namespace miopen {
namespace conv {

InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription& problem)
InvokerFactory MakeImplGemmDataInvokerFactory(const ProblemDescription& problem)
{
if(problem.direction.IsForward())
if(problem.IsDirectionForward())
{
return [](const std::vector<Kernel>& kernels) {
return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
Expand All @@ -24,7 +24,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription&
}
else
{
if(problem.direction.IsBackwardWrW())
if(problem.IsDirectionBackwardWrW())
MIOPEN_THROW("MakeImplGemmDataInvokerFactory shouldn't be used for WrW invokers.");

const auto& conv = problem.GetConv();
Expand Down
26 changes: 11 additions & 15 deletions src/conv/invokers/impl_gemm_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ static float CallImplGemmDynamicForward1x1(const miopen::Handle& handle,
return elapsed;
}

InvokerFactory
MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& problem)
InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescription& problem)
{
return [problem](const std::vector<Kernel>& kernels) {
return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
Expand All @@ -112,10 +111,8 @@ MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& pr
};
}

template <>
InvokerFactory
MakeImplGemmDynamicBackwardDataInvokerFactory<int>(const miopen::ProblemDescription& problem,
const int& cfg)
InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
const int cfg)
{
int hi = problem.GetOutHeight_();
int wi = problem.GetOutWidth_();
Expand Down Expand Up @@ -249,10 +246,9 @@ MakeImplGemmDynamicBackwardDataInvokerFactory<int>(const miopen::ProblemDescript
};
}

template <>
InvokerFactory
MakeImplGemmDynamicBackwardDataInvokerFactory<solver::TunableImplicitGemmGTCDynamic_t>(
const miopen::ProblemDescription& problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg)
MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
const solver::TunableImplicitGemmGTCDynamic_t& cfg)
{
int hi = problem.GetOutHeight_();
int wi = problem.GetOutWidth_();
Expand Down Expand Up @@ -439,8 +435,8 @@ MakeImplGemmDynamicBackwardDataInvokerFactory<solver::TunableImplicitGemmGTCDyna

InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(
const ExecutionContext& ctx,
const miopen::ProblemDescription& problem,
const solver::PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config)
const ProblemDescription& problem,
const solver::conv::PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config)
{
int hi = problem.GetInHeight_();
int wi = problem.GetInWidth_();
Expand Down Expand Up @@ -732,8 +728,8 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(

InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
const ExecutionContext& ctx,
const miopen::ProblemDescription& problem,
const solver::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config)
const ProblemDescription& problem,
const solver::conv::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config)
{
int hi = problem.GetOutHeight_();
int wi = problem.GetOutWidth_();
Expand Down Expand Up @@ -1047,8 +1043,8 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
}

InvokerFactory MakeImplGemmDynamicForwardDlopsNCHWCInvokerFactory(
const miopen::ProblemDescription& problem,
const solver::PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config)
const ProblemDescription& problem,
const solver::conv::PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config)
{
int hi = problem.GetInHeight_();
int wi = problem.GetInWidth_();
Expand Down
17 changes: 8 additions & 9 deletions src/conv/invokers/mlir_impl_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct MlirConvArgs
#endif

#if MIIR_BARE_POINTER_ABI
void ComputeMlirDimsStrides(const conv::ProblemDescription& problem,
void ComputeMlirDimsStrides(const ProblemDescription& problem,
std::vector<size_t>& in_dims,
std::vector<size_t>& in_strides,
std::vector<size_t>& weights_dims,
Expand Down Expand Up @@ -156,7 +156,7 @@ void InsertGToDimsStrides(const std::string& layout,
strides.insert(strides.begin() + index, strides[index] * dims[index + 1]);
}

void ComputeMlirDimsStrides(const conv::ProblemDescription& problem,
void ComputeMlirDimsStrides(const ProblemDescription& problem,
std::vector<size_t>& in_dims,
std::vector<size_t>& in_strides,
std::vector<size_t>& weights_dims,
Expand Down Expand Up @@ -288,9 +288,9 @@ void SetMlirConvArgsPtr(
#endif // MIOPEN_BACKEND_HIP
} // Anonymous namespace

InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& problem)
InvokerFactory MakeMlirFwdInvokerFactory(const ProblemDescription& problem)
{
assert((problem.direction.IsForward()));
assert((problem.IsDirectionForward()));

std::vector<size_t> in_dims, in_strides;
std::vector<size_t> weights_dims, weights_strides;
Expand Down Expand Up @@ -354,9 +354,9 @@ InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& probl
};
}

InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& problem)
InvokerFactory MakeMlirBwdInvokerFactory(const ProblemDescription& problem)
{
assert(problem.direction.IsBackwardData());
assert(problem.IsDirectionBackwardData());

std::vector<size_t> in_dims, in_strides;
std::vector<size_t> weights_dims, weights_strides;
Expand Down Expand Up @@ -409,10 +409,9 @@ InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& probl
};
}

InvokerFactory MakeMlirWrWInvokerFactory(const miopen::ProblemDescription& problem,
size_t workspace_req)
InvokerFactory MakeMlirWrWInvokerFactory(const ProblemDescription& problem, size_t workspace_req)
{
assert((problem.direction.IsBackwardWrW()));
assert((problem.IsDirectionBackwardWrW()));

std::vector<size_t> in_dims, in_strides;
std::vector<size_t> weights_dims, weights_strides;
Expand Down
24 changes: 17 additions & 7 deletions src/conv/solver_finders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <miopen/config.h>
#include <miopen/mlo_internal.hpp>
#include <miopen/perf_field.hpp>
#include <miopen/problem_description.hpp>
#include <miopen/conv/problem_description.hpp>

namespace miopen {

Expand All @@ -42,6 +42,9 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_WINOGRAD)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_FFT)

namespace conv {
namespace {

class DirectSolverFinder : public SolversFinderMixin<ProblemDescription, ConvFindParameters>
{
protected:
Expand Down Expand Up @@ -178,6 +181,8 @@ class WinogradSolverFinder : public SolversFinderMixin<ProblemDescription, ConvF
}
};

} // namespace

const std::vector<std::unique_ptr<ISolversFinder>>& GetConvSolverFinders()
{
static const auto finders = []() {
Expand All @@ -193,14 +198,16 @@ const std::vector<std::unique_ptr<ISolversFinder>>& GetConvSolverFinders()
return finders;
}

} // namespace conv

/// Register invoker only for the best solution within algorithm.
/// Add all solutions to the find-db record.
void EvaluateInvokers(Handle& handle,
const std::vector<solver::ConvSolution>& solutions,
const AlgorithmName& algorithm_name,
const NetworkConfig& network_config,
const AnyInvokeParams& invoke_ctx,
DbRecord& record)
static void EvaluateInvokers(Handle& handle,
const std::vector<solver::ConvSolution>& solutions,
const AlgorithmName& algorithm_name,
const NetworkConfig& network_config,
const AnyInvokeParams& invoke_ctx,
DbRecord& record)
{
const char* const arch = miopen::GetStringEnv(MIOPEN_DEVICE_ARCH{});
if(arch != nullptr && strlen(arch) > 0)
Expand Down Expand Up @@ -309,6 +316,8 @@ void FindCore(const AnyInvokeParams& invoke_ctx,
EvaluateInvokers(handle, ss.second, ss.first, network_config, invoke_ctx, record);
}

namespace conv {

bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo)
{
switch(algo)
Expand All @@ -328,4 +337,5 @@ bool IsAlgorithmDisabled(miopenConvAlgorithm_t algo)
} // clang-format on
}

} // namespace conv
} // namespace miopen
Loading

0 comments on commit 19c6046

Please sign in to comment.