add executor.prepare (#9022)
optimize executor.run
jacquesqiao authored Mar 20, 2018
1 parent 30b7032 commit 37a272e
Showing 4 changed files with 116 additions and 93 deletions.
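In user-visible terms, the optimization is that Executor.run no longer has to clone the program and re-insert feed/fetch operators on every call: with use_program_cache=True, the augmented program is built once per (feed names, fetch targets) pair and reused. A minimal sketch of the cached path, assuming the 2018-era paddle.fluid API (the mul setup mirrors the unit test touched below):

import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data

a = data(name='a', shape=[784], dtype='float32')
b = data(name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
out = mul(x=a, y=b)

exe = Executor(core.CPUPlace())
a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')

for _ in range(10):
    # Same feed keys and fetch targets every step, so after the first
    # call the feed/fetch-augmented clone is served from program_caches.
    outs = exe.run(feed={'a': a_np, 'b': b_np},
                   fetch_list=[out],
                   use_program_cache=True)

print(outs[0].shape)  # (100, 100) if return_numpy is left at its default

Without use_program_cache the augmented clone is rebuilt on every call, which is the per-step cost this commit removes from the cached path.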
28 changes: 11 additions & 17 deletions paddle/fluid/framework/executor.cc
@@ -14,12 +14,8 @@ limitations under the License. */

#include "paddle/fluid/framework/executor.h"

-#include <set>
-
#include "gflags/gflags.h"
-#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
-#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -40,14 +36,13 @@ namespace {
int kProgramId = -1;
} // namespace

-struct ExecutorPrepareContext {
-  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
-      : prog_(prog), block_id_(block_id) {}
+ExecutorPrepareContext::ExecutorPrepareContext(
+    const framework::ProgramDesc& prog, size_t block_id)
+    : prog_(prog), block_id_(block_id) {}

-  const framework::ProgramDesc& prog_;
-  size_t block_id_;
-  std::vector<std::unique_ptr<OperatorBase>> ops_;
-};
+ExecutorPrepareContext::~ExecutorPrepareContext() {
+  VLOG(5) << "destroy ExecutorPrepareContext";
+}

Executor::Executor(const platform::Place& place) : place_(place) {}

@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars) {
  platform::RecordBlock b(block_id);
-  auto* ctx = Prepare(pdesc, block_id);
-  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
-  delete ctx;
+  auto ctx = Prepare(pdesc, block_id);
+  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}

// Check whether the block already has feed operators and feed_holder.
@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}

-ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
-                                          int block_id) {
+std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
+    const ProgramDesc& program, int block_id) {
  auto* ctx = new ExecutorPrepareContext(program, block_id);
  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
  auto& block = program.Block(block_id);
  for (auto& op_desc : block.AllOps()) {
    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
-  return ctx;
+  return std::unique_ptr<ExecutorPrepareContext>(ctx);
}

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
15 changes: 12 additions & 3 deletions paddle/fluid/framework/executor.h
@@ -22,7 +22,16 @@ limitations under the License. */

namespace paddle {
namespace framework {
-struct ExecutorPrepareContext;

+struct ExecutorPrepareContext {
+  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
+  ~ExecutorPrepareContext();
+
+  const framework::ProgramDesc& prog_;
+  size_t block_id_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+};

class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
@@ -47,8 +56,8 @@ class Executor {
           const std::string& feed_holder_name = "feed",
           const std::string& fetch_holder_name = "fetch");

-  static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
-                                         int block_id);
+  static std::unique_ptr<ExecutorPrepareContext> Prepare(
+      const ProgramDesc& program, int block_id);

  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                          bool create_local_scope = true,
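executor.h now exposes the split between building operators (Prepare) and executing them (RunPreparedContext). A purely illustrative, self-contained Python analogue of that split; every name below is hypothetical and only stands in for the C++ machinery above:

class ToyOp(object):
    def __init__(self, desc):
        self.desc = desc  # stand-in for OpRegistry::CreateOp(*op_desc)

    def run(self, scope):
        scope[self.desc] = scope.get(self.desc, 0) + 1

def prepare(op_descs):
    # like Executor::Prepare: construct the op objects once
    return [ToyOp(d) for d in op_descs]

def run_prepared(ops, scope):
    # like Executor::RunPreparedContext: just execute the prebuilt ops
    for op in ops:
        op.run(scope)

ops = prepare(['fc', 'relu', 'softmax'])
scope = {}
for _ in range(3):  # many steps, a single prepare
    run_prepared(ops, scope)

Returning std::unique_ptr<ExecutorPrepareContext> also tightens ownership in the old Run path, which previously deleted the raw pointer by hand.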
165 changes: 93 additions & 72 deletions python/paddle/fluid/executor.py
@@ -235,6 +235,77 @@ def parselod(data):
        tensor.set_lod(lod)
        return tensor

+    def _get_program_cache(self, program_cache_key):
+        return self.program_caches.get(program_cache_key, None)
+
+    def _add_program_cache(self, program_cache_key, program):
+        self.program_caches[program_cache_key] = program
+
+    def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
+                            fetch_var_name):
+        tmp_program = program.clone()
+
+        global_block = tmp_program.global_block()
+
+        if feed_var_name in global_block.vars:
+            feed_var = global_block.var(feed_var_name)
+        else:
+            feed_var = global_block.create_var(
+                name=feed_var_name,
+                type=core.VarDesc.VarType.FEED_MINIBATCH,
+                persistable=True)
+
+        if fetch_var_name in global_block.vars:
+            fetch_var = global_block.var(fetch_var_name)
+        else:
+            fetch_var = global_block.create_var(
+                name=fetch_var_name,
+                type=core.VarDesc.VarType.FETCH_LIST,
+                persistable=True)
+
+        # prepend feed operators
+        if not has_feed_operators(global_block, feed, feed_var_name):
+            for i, name in enumerate(feed):
+                out = global_block.var(name)
+                global_block.prepend_op(
+                    type='feed',
+                    inputs={'X': [feed_var]},
+                    outputs={'Out': [out]},
+                    attrs={'col': i})
+
+        # append fetch_operators
+        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
+            for i, var in enumerate(fetch_list):
+                assert isinstance(var, Variable) or isinstance(var, str), (
+                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
+                global_block.append_op(
+                    type='fetch',
+                    inputs={'X': [var]},
+                    outputs={'Out': [fetch_var]},
+                    attrs={'col': i})
+
+        return tmp_program
+
+    def _feed_data(self, program, feed, feed_var_name, scope):
+        # feed var to framework
+        for op in program.global_block().ops:
+            if op.desc.type() == 'feed':
+                feed_target_name = op.desc.output('Out')[0]
+                cur_feed = feed[feed_target_name]
+                if not isinstance(cur_feed, core.LoDTensor):
+                    cur_feed = self.aslodtensor(cur_feed)
+                idx = op.desc.attr('col')
+                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
+            else:
+                break
+
+    def _fetch_data(self, fetch_list, fetch_var_name, scope):
+        outs = [
+            core.get_fetch_variable(scope, fetch_var_name, i)
+            for i in xrange(len(fetch_list))
+        ]
+        return outs
+
    def run(self,
            program=None,
            feed=None,
@@ -268,7 +339,6 @@ def run(self,
            raise TypeError("feed should be a map")
        if fetch_list is None:
            fetch_list = []
-
        if program is None:
            program = default_main_program()

@@ -278,79 +348,30 @@
        if scope is None:
            scope = global_scope()

-        program_cache = None
-        program_cache_key = get_program_cache_key(feed, fetch_list)
-
+        cache_key = get_program_cache_key(feed, fetch_list)
        if use_program_cache:
-            # find program cache by cache_key
-            program_cache = self.program_caches.get(program_cache_key, None)
-            # TODO(qiao): Should check program_cache and program are exactly the same.
+            cached_program = self._get_program_cache(cache_key)
+            if cached_program is None:
+                cached_program = self._add_feed_fetch_ops(
+                    program=program,
+                    feed=feed,
+                    fetch_list=fetch_list,
+                    feed_var_name=feed_var_name,
+                    fetch_var_name=fetch_var_name)
+                self._add_program_cache(cache_key, cached_program)
+            program = cached_program
        else:
-            self.program_caches.pop(program_cache_key, None)
-
-        if program_cache is None:
-            program_cache = program.clone()
-
-            if use_program_cache:
-                self.program_caches[program_cache_key] = program_cache
-
-            global_block = program_cache.global_block()
-
-            if feed_var_name in global_block.vars:
-                feed_var = global_block.var(feed_var_name)
-            else:
-                feed_var = global_block.create_var(
-                    name=feed_var_name,
-                    type=core.VarDesc.VarType.FEED_MINIBATCH,
-                    persistable=True)
-
-            if fetch_var_name in global_block.vars:
-                fetch_var = global_block.var(fetch_var_name)
-            else:
-                fetch_var = global_block.create_var(
-                    name=fetch_var_name,
-                    type=core.VarDesc.VarType.FETCH_LIST,
-                    persistable=True)
-
-            # prepend feed operators
-            if not has_feed_operators(global_block, feed, feed_var_name):
-                for i, name in enumerate(feed):
-                    out = global_block.var(name)
-                    global_block.prepend_op(
-                        type='feed',
-                        inputs={'X': [feed_var]},
-                        outputs={'Out': [out]},
-                        attrs={'col': i})
-
-            # append fetch_operators
-            if not has_fetch_operators(global_block, fetch_list,
-                                       fetch_var_name):
-                for i, var in enumerate(fetch_list):
-                    assert isinstance(var, Variable) or isinstance(var, str), (
-                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
-                    global_block.append_op(
-                        type='fetch',
-                        inputs={'X': [var]},
-                        outputs={'Out': [fetch_var]},
-                        attrs={'col': i})
-
-        # feed var to framework
-        for op in program_cache.global_block().ops:
-            if op.desc.type() == 'feed':
-                feed_target_name = op.desc.output('Out')[0]
-                cur_feed = feed[feed_target_name]
-                if not isinstance(cur_feed, core.LoDTensor):
-                    cur_feed = self.aslodtensor(cur_feed)
-                idx = op.desc.attr('col')
-                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
-            else:
-                break
-
-        self.executor.run(program_cache.desc, scope, 0, True, True)
-        outs = [
-            core.get_fetch_variable(scope, fetch_var_name, i)
-            for i in xrange(len(fetch_list))
-        ]
+            self.program_caches.pop(cache_key, None)
+            program = self._add_feed_fetch_ops(
+                program=program,
+                feed=feed,
+                fetch_list=fetch_list,
+                feed_var_name=feed_var_name,
+                fetch_var_name=fetch_var_name)
+
+        self._feed_data(program, feed, feed_var_name, scope)
+        self.executor.run(program.desc, scope, 0, True, True)
+        outs = self._fetch_data(fetch_list, fetch_var_name, scope)
        if return_numpy:
            outs = as_numpy(outs)
        return outs
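To make the helper pipeline concrete: _add_feed_fetch_ops clones the program, prepends one 'feed' op per input and appends one 'fetch' op per output, each bound to a column of the feed/fetch holder variables; _feed_data and _fetch_data then move tensors through those columns. A hedged sketch that exercises the (private, subject-to-change) helper directly, again assuming the 2018-era fluid layout:

import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data

a = data(name='a', shape=[784], dtype='float32')
b = data(name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
out = mul(x=a, y=b)

exe = Executor(core.CPUPlace())
augmented = exe._add_feed_fetch_ops(
    program=fluid.default_main_program(),
    feed={'a': None, 'b': None},  # only the keys matter for op insertion
    fetch_list=[out],
    feed_var_name='feed',
    fetch_var_name='fetch')

# The clone now starts with a 'feed' op per input and ends with a 'fetch' op:
print([op.desc.type() for op in augmented.global_block().ops])
# roughly: ['feed', 'feed', 'mul', 'fetch']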
1 change: 0 additions & 1 deletion python/paddle/fluid/tests/unittests/test_executor_and_mul.py
@@ -16,7 +16,6 @@

import numpy
import paddle.fluid.core as core
-
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data

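One detail the diff leaves implicit: the cache key comes from get_program_cache_key(feed, fetch_list), which, as far as this change shows, depends on which variables are fed and fetched rather than on their values. A sketch of that assumption (fetch entries may be Variables or plain names, per the assert in _add_feed_fetch_ops):

from paddle.fluid.executor import get_program_cache_key

# Assumption, not visible in this diff: the key is derived from feed
# names and fetch targets, so new values under the same names reuse
# the same cache entry.
k1 = get_program_cache_key({'a': 1, 'b': 2}, ['out'])
k2 = get_program_cache_key({'a': 3, 'b': 4}, ['out'])
assert k1 == k2

Note that the removed TODO(qiao) line (checking that the cached program matches the passed-in program) is not re-added elsewhere in this diff, so the name-based key remains the only guard on cache hits.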
