-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
move net_design to framework #2553
Changes from 19 commits
5886959
a00900a
4fb581b
bb33f7a
a7dbfe0
8cf2d60
2933da0
7d5b1b2
b6450da
a46506d
3c7bf55
ebe143d
4e5a359
9be4c74
7102b80
2953215
58fdcd0
417b279
b68c8bf
62823b1
deff7cc
740f3a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Network Design | ||
|
||
`Network` is the container and controller of a set of operators, | ||
users can build a real network from a `NetDef` in protobuf message | ||
and use `Network.Run()` to run all the operators in the network. | ||
|
||
The `Network` will | ||
|
||
- manage all the operators contained in the network. | ||
- not own any `Variable`. | ||
|
||
# API | ||
|
||
## NetworkBase | ||
To make the `Network` extendable, a base class is defined like this | ||
|
||
```c++ | ||
// operator's index stored in a network. | ||
typedef int OpIndex; | ||
|
||
// The minimum a network should be implemented. | ||
class NetworkBase { | ||
public: | ||
// `def` is a proto message that describe the structure of a network. | ||
NetworkBase(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think class NetworkBase {
public:
using OpIndex = size_t;
OpIndex addOperator(const OpDesc& op);
}; So that bool Run(Scope* scope, OpIndex begin = -1UL, OpIndex end = -1UL); It will be very helpful if a network can run partially. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @reyoung Network maybe run from multiple start-operators to multiple end-operators? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't we use the OpName instead of the OpIndex? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with @wanghaoshuang. It seems that operators within [begin_index, end_index] cannot define arbitrary sub-DAG. |
||
// Infer the shapes of variables required by operators in the network. The | ||
// `scope` will be mutated according to the inferred shapes. | ||
virtual bool InferShape(Scope *scope) = 0; | ||
|
||
// run all the operators and return success(true) or not, all the | ||
// variables are located in `scope`. `begin` and `end` specify the scope of | ||
// `ops_` to run, If no positive indexes are provided, all operators in `ops_` | ||
// will run. | ||
virtual bool Run(Scope *scope, OpIndex begin = -1, | ||
OpIndex end = -1) const = 0; | ||
}; | ||
``` | ||
|
||
All network implementations should build networks from a protobuf message which | ||
describes the structure of a real network; `Run` method should be implemented by | ||
all implementations to offer a universal method to forward or backward compute a network. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be implemented by all implementations --> should be implemented by all its inheriting class |
||
|
||
A method of factory pattern can be defined like | ||
|
||
```c++ | ||
std::unique<NetworkBase> CreateNet(const NetDef& def) { | ||
switch (def.model_type()) { | ||
case NN: | ||
return new Network(def); | ||
case Recursive: | ||
return new RecursiveNet(def); | ||
case Recurrent: | ||
return new RecurrentNet(def); | ||
} | ||
return nullptr; | ||
} | ||
``` | ||
|
||
Network is designed as the container of operators, to make it more extendable, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. --> Network is designed as a container for operators, with related variable resources decoupled from it to make it more extendable. |
||
we decoupling it from the related variable resources. | ||
|
||
`Run(Scope* scope)` takes the scope as a argument so that it can run in different scopes. | ||
|
||
Finally, `NetworkBase` can be used as followed | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as followed --> as follows: |
||
|
||
```c++ | ||
Scope default_scope; | ||
auto net = CreateNet(def); | ||
|
||
if (net) { | ||
net.Run(&default_scope); | ||
} | ||
``` | ||
|
||
|
||
## A Simple Network Implementation | ||
|
||
A very basic implementation is as followed, all it does is simply to run every operators in sequence. | ||
|
||
```c++ | ||
class ScratchNet final : public NetworkBase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I don't understand why the basic implementation is called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
public: | ||
// Create a network describe by `def`. NetDef is the definition of a network. | ||
ScratchNet(const NetDef &def); | ||
|
||
virtual bool InferShape(Scope *scope) override; | ||
|
||
// Run all the operators with the `scope`, if no scope is provided, default | ||
// scope will be used instead. | ||
virtual bool Run(Scope *scope = nullptr, OpIndex begin, | ||
OpIndex end) const override; | ||
|
||
const std::vector<Operator> &GetOps() const; | ||
|
||
std::vector<Operator> *MutableOps(); | ||
|
||
protected: | ||
// Create operators accordding to `def`. | ||
bool CreateNet(const NetDef &def); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What`t the difference ScratchNet::CreateNet and CreateNet in line#43. Or shall we rename ScratchNet::CreateNet to a better name, such as: CreateOps ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure Maybe we should rename it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe |
||
|
||
// Add a operator which is identified as `type` and has attributes described | ||
// in `attrs`, the `inputs` are the keys of readonly input variables, | ||
// `outputs` are keys of mutable output variables. An `OpIndex` will be | ||
// returned which indicates the offset of the new operator in `ops_`. | ||
OpIndex AddOp(const std::string &type, const std::vector<string> &inputs, | ||
const std::vector<string> &outputs, | ||
const OprAttr &attrs = OprAttr()); | ||
|
||
private: | ||
// the operators owned by `Network`. | ||
std::vector<Operator> ops_; | ||
}; | ||
``` | ||
|
||
`ScratchNet` will create operators so that a private member `ops_` is defined, | ||
the operators are created by `CreateNet`, and each operator is created by `AddOp`. | ||
|
||
|
||
## Usage | ||
`ScratchNet` can be used to define and run a network as followed | ||
|
||
```c++ | ||
// create an empty scope located on CPU device. | ||
Scope scope(CPUPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think scope do not take There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sample code |
||
|
||
// create and init variables described in `net_desc`. | ||
scope.CreateVariables(net_desc); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Because the variables' information is located in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe there should be a high-level concept called NetworkBase network;
Scope scope;
Variable* image = scope.CreateVariable("Image");
Variable* label = scope.CreateVariable("Label");
NetworkBuilder builder(&network, &scope);
Variable* fc_out = builder.FCLayer(input=image, size=100, activation="Sigmoid");
Variable* prediction = builder.FCLayer(input=fc_out, size=10, activation="Sigmoid");
Variable* loss = builder.CrossEntropy(input=prediction, label=label);
Variable* avg_loss = builder.Mean(loss);
auto allParams = builder.Parameters();
builder.BackwardFrom(avg_loss)
builder.AddOptimization(1e-4, "adam");
// train one mini-batch
network.run(&scope); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we merge If we merge NetBuilder's functions into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe the logic of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I create another PR with NetBuilder Users can create Both of them are not coupled with each other. |
||
scope.InitVariables(net_desc); | ||
|
||
// create a network according to `net_desc` | ||
auto net = CreateNet(net_desc); | ||
|
||
// run the network providing the `scope`. | ||
net.Run(&scope); | ||
``` | ||
|
||
## Compatibility with RNN | ||
|
||
Benefit from the decoupling of `ScratchNet.Run` and `Scope`, `ScratchNet` is compatible with future RNN design, | ||
for example we can implement a simple recurrent neural network as followed | ||
|
||
```c++ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, this part should be rewritten.
Scope parent;
// in RNN
NetworkBase& stepNet;
std::vector<Scope> timestepScopes;
for (size_t i=0; i < seqLen; ++i) {
timestepScopes.push_back(Scope(&parent));
stepNet.Run(×tepScopes);
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will give another commit soon to fix this. |
||
// copy some `vars` form `source` to `target` | ||
void Copy(const Scope &source, Scope &target, | ||
const std::vector<std::string> &vars); | ||
|
||
Scope default_scope; | ||
// some initial mutations on `default_scope` here. | ||
|
||
auto rnn_step_net = ScratchNet(rnn_step_net_def); | ||
|
||
// Create rnn's states, the last scope is used to store rnn outputs. | ||
Scope *rnn_states = new Scope[num_states + 1]; | ||
|
||
for (int i = 0; i < num_states + 1; i++) { | ||
// Initialize all rnn state scopes, copy parameters and so on. | ||
rnn_states[i].CreateVars(rnn_step_net_def); | ||
Copy(default_scope, rnn_states[i], rnn_related_vars); | ||
// Prepare rnn's inlinks, just copy inlink variables to each state. | ||
Copy(default_scope, rnn_states[i], inlink_vars); | ||
} | ||
|
||
// Run the rnn. | ||
for (int i = 0; i < num_states; i++) { | ||
rnn_step_net.Run(rnn_states[i]); | ||
// Copy current state's state variables to next state, the related variables | ||
// are named like "previous_state_xxx". | ||
Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars) | ||
} | ||
|
||
// Copy rnn's final outputs to `default_scope`. | ||
Copy(rnn_states[num_states], default_scope, outlink_vars); | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or use AddOp to modify this Net