[1.x] Backporting #18779 to v1.x (#18894)
* initial commit

* Support extra inputs for subgraph ops (#18779)

Support additional inputs to custom subgraph ops that are not direct dependencies of ops in the subgraph. This enables various use cases: custom control-flow ops, custom ops that maintain state that should be saved/loaded, etc.

Highlights:

* Added test that uses a graph pass (addInputPass) to add a new custom input to the subgraph op

* Added new optional argument (clear) to hybridize & optimize_for APIs in Gluon Block to enable multiple optimizations

* refactored lib_api.h JSON utilities

* added new Graph data structure utilities to simplify custom graph passes

* refactored custom op registration

* enhanced custom subgraph op to support additional inputs to the subgraph op that are not inputs to ops in the subgraph

* updated subgraph & graph pass READMEs

* Added error messaging from the external library (see the sketch before the file diffs below)

* changed messages

* changed to pointers and types

* added cast

* updated cast

* fixed signed int

* whitespace

* fixed pass resource

Co-authored-by: Ubuntu <ubuntu@ip-172-31-6-220.us-west-2.compute.internal>
samskalicky and Ubuntu authored Aug 18, 2020
1 parent 9981e84 commit d1ac7c8
Showing 16 changed files with 1,489 additions and 783 deletions.
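
The recurring change across the file diffs below is the library's error-reporting path: failure messages that previously went to std::cout are now streamed into MX_ERROR_MSG so that MXNet can surface the reason reported by the external library. Here is a minimal sketch of the resulting pattern, using only names that appear in these diffs (lib_api.h, mxnet::ext, MXReturnValue, MX_SUCCESS, MX_FAIL, MX_ERROR_MSG); how MXNet ultimately reports the collected message is assumed from the commit description rather than shown here:

#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

MXReturnValue initialize(int version) {
  if (version >= 10700) {
    // informational output can still go to std::cout
    std::cout << "MXNet version " << version << " supported" << std::endl;
    return MX_SUCCESS;
  }
  // failure path: stream the reason into MX_ERROR_MSG (no std::endl needed)
  // and return MX_FAIL so MXNet can report the message to the user
  MX_ERROR_MSG << "MXNet version " << version << " not supported";
  return MX_FAIL;
}

The same substitution appears in each initialize, inferType, inferShape, and inferSType validation path in the diffs that follow.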
4 changes: 3 additions & 1 deletion example/extensions/lib_api/init_lib.cc
@@ -26,12 +26,14 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

MXReturnValue initialize(int version) {
if (version >= 10700) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
16 changes: 8 additions & 8 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
@@ -127,12 +129,12 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int> *outtypes) {
// validate inputs
if (intypes->size() != 2) {
std::cout << "Expected 2 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferType";
return MX_FAIL;
}
for (unsigned i = 0; i < intypes->size(); i++) {
if (intypes->at(i) != kFloat32) {
std::cout << "Expected input " << i << " to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input " << i << " to have float32 type";
return MX_FAIL;
}
}
@@ -146,11 +148,11 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 2) {
std::cout << "Expected 2 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferShape";
return MX_FAIL;
}
if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) {
std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape";
return MX_FAIL;
}

@@ -159,7 +161,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
unsigned kk = inshapes->at(1)[0];
unsigned m = inshapes->at(1)[1];
if (k != kk) {
std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl;
MX_ERROR_MSG << "Exected first input axis 1 equals to second input axis 0";
return MX_FAIL;
}

Expand Down Expand Up @@ -195,8 +197,6 @@ class MyStatefulGemm : public CustomStatefulOp {
return backward(attrs_, inputs, outputs, op_res);
}

~MyStatefulGemm() {}

private:
int count;
const std::unordered_map<std::string, std::string> attrs_;
@@ -230,7 +230,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
4 changes: 3 additions & 1 deletion example/extensions/lib_custom_op/relu_lib.cu
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
@@ -263,7 +265,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
@@ -70,11 +72,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}

@@ -101,11 +103,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

@@ -117,7 +119,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kCSRStorage) {
std::cout << "Expected storage type is kCSRStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kCSRStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
@@ -129,7 +131,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

@@ -194,7 +196,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
@@ -73,11 +75,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}
transpose(inputs->at(0), outputs->at(0), res);
@@ -103,11 +105,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

@@ -119,7 +121,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kRowSparseStorage) {
std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
@@ -131,7 +133,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

@@ -196,7 +198,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}