Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize workspace abstraction #2524

Merged
merged 25 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5d9d7fc
added checks on workspace params
amberhassaan Nov 1, 2023
75e18bb
addressed review comments
amberhassaan Nov 8, 2023
9859d97
Merge remote-tracking branch 'origin/develop' into amber/add-workspac…
amberhassaan Nov 9, 2023
9bb0750
fix release build warning
amberhassaan Nov 9, 2023
675b6a9
WIP: workspace abstraction
amberhassaan Nov 11, 2023
4b33079
added an abstraction for workspace
amberhassaan Nov 15, 2023
6b12bcf
Merge branch 'develop' into amber/add-workspace-check
amberhassaan Nov 15, 2023
cdb9325
fix a check
amberhassaan Nov 15, 2023
71f709b
fix format
amberhassaan Nov 15, 2023
ae5d50e
missed some instances
amberhassaan Nov 15, 2023
02bf9d0
Merge remote-tracking branch 'origin/develop' into amber/workspace-ab…
amberhassaan Nov 15, 2023
83d622b
bring back the workspace round up logic
amberhassaan Nov 15, 2023
4bd4253
formatting
amberhassaan Nov 15, 2023
6eef5fe
fix hip tidy issues. add more checks
amberhassaan Nov 16, 2023
f22aa2d
fix a bug with zeroing out rnn workspace
amberhassaan Nov 16, 2023
d6ae39f
fix hip-tidy error
amberhassaan Nov 16, 2023
385d816
Merge branch 'amber/add-workspace-check' into amber/workspace-abstrac…
amberhassaan Nov 21, 2023
815b73a
Merge remote-tracking branch 'origin/develop' into amber/workspace-ab…
amberhassaan Nov 21, 2023
98ec151
address review comments
amberhassaan Nov 21, 2023
437ea5c
Merge remote-tracking branch 'origin/develop' into amber/workspace-ab…
amberhassaan Dec 6, 2023
43bb198
Merge branch 'develop' into amber/workspace-abstraction
amberhassaan Dec 7, 2023
0b9a725
Merge branch 'develop' into amber/workspace-abstraction
amberhassaan Dec 12, 2023
9880704
Merge branch 'develop' into amber/workspace-abstraction
amberhassaan Dec 12, 2023
75b9f45
Merge branch 'develop' into amber/workspace-abstraction
amberhassaan Dec 14, 2023
f1e32d7
Merge branch 'develop' into amber/workspace-abstraction
junliume Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ static inline void ValidateGroupCount(const TensorDescriptor& x,
MIOPEN_THROW(miopenStatusBadParm, "Invalid group number");
}

static inline void ValidateWorkspace(Data_t workSpace, const size_t workSpaceSize)
{

[[maybe_unused]] bool x = (workSpace != nullptr);
[[maybe_unused]] bool y = (workSpaceSize != 0);

assert(((x && y) || (!x && !y)) && "workspace pointer and size don't match. Either both should "
"be zero or both should be non-zero");

/// \todo could add a check here that workSpace points to GPU memory
}

static Invoker PrepareInvoker(ExecutionContext ctx,
const conv::ProblemDescription& problem,
const NetworkConfig& config,
Expand Down Expand Up @@ -260,6 +272,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || w == nullptr || y == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -495,6 +508,7 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
ValidateTensors(tensors);
Expand Down Expand Up @@ -812,6 +826,7 @@ void ConvolutionDescriptor::ConvolutionForwardImmediate(Handle& handle,
const solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};

ValidateTensors(tensors);
Expand Down Expand Up @@ -846,6 +861,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(dx == nullptr || w == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -944,6 +960,7 @@ void ConvolutionDescriptor::ConvolutionBackwardData(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

Expand Down Expand Up @@ -1015,6 +1032,7 @@ void ConvolutionDescriptor::ConvolutionBackwardImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

ValidateTensors(tensors);
Expand Down Expand Up @@ -1055,6 +1073,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || dw == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -1151,6 +1170,7 @@ void ConvolutionDescriptor::ConvolutionBackwardWeights(const Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
decltype(auto) tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);
ValidateAlphaBeta(alpha, beta);
Expand Down Expand Up @@ -1218,6 +1238,7 @@ void ConvolutionDescriptor::ConvolutionWrwImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);

Expand Down
Loading
Loading