add executor.prepare #9022
Changes from 11 commits
@@ -414,8 +414,21 @@ All parameter, weight, gradient are variables in Paddle.
        self.set_falsenet(net.Clone());
      });

  py::class_<ExecutorPrepareContext>(m, "ExecutorPrepareContext");

  py::class_<framework::Executor>(m, "Executor")
      .def(py::init<const platform::Place &>())
      .def_static("prepare",
                  [](const ProgramDesc &pdesc,
                     int block_id) -> std::unique_ptr<ExecutorPrepareContext> {
                    return Executor::Prepare(pdesc, block_id);
                  })
      .def("run_prepared_ctx",
Review comment: should this be _run_prepared_ctx?
Author reply: same as above (see the reply to the _prepare comment at the end of this page).
           [](Executor &self, ExecutorPrepareContext *handle, Scope *scope,
              bool create_local_scope, bool create_vars) {
             self.RunPreparedContext(handle, scope, create_local_scope,
                                     create_vars);
           })
      .def("run",
           (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
               Executor::Run);
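For readers following the diff, these are the bindings the Python Executor drives through self.executor. A minimal usage sketch, not part of this change; it assumes paddle.fluid with this patch applied and a trivial program built on the Python side:

import paddle.fluid as fluid

# Build a trivial program so that block 0 has something to execute.
prog = fluid.Program()
with fluid.program_guard(prog):
    fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)

# The Python Executor keeps the bound framework::Executor as self.executor.
exe = fluid.Executor(fluid.CPUPlace())
core_exe = exe.executor

# prepare() parses block 0 of the ProgramDesc once and returns an
# ExecutorPrepareContext handle that can be reused across runs.
ctx_handle = core_exe.prepare(prog.desc, 0)

# run_prepared_ctx(handle, scope, create_local_scope, create_vars) executes the
# prepared block without re-reading the ProgramDesc.
core_exe.run_prepared_ctx(ctx_handle, fluid.global_scope(), True, True)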
@@ -179,6 +179,16 @@ def to_name_str(var):
    return str(feed_var_names + fetch_var_names)


class PreparedContext(object):
    def __init__(self, handle, program, fetch_list, feed_var_name,
                 fetch_var_name):
        self.handle = handle
        self.program = program
        self.fetch_list = fetch_list
        self.feed_var_name = feed_var_name
        self.fetch_var_name = fetch_var_name
Review comment: Are all of them public members?
Author reply: Yes.
Review comment: Remove this class if it is no longer used?


class Executor(object):
    def __init__(self, places):
        if not isinstance(places, list) and not isinstance(places, tuple):

@@ -235,6 +245,119 @@ def parselod(data):
        tensor.set_lod(lod)
        return tensor

    def _get_program_cache(self, feed, fetch_list):
        program_cache_key = get_program_cache_key(feed, fetch_list)
        program_cache = self.program_caches.get(program_cache_key, None)
Review comment: nit: return directly at this line?
Author reply: Done.
        return program_cache

    def _add_program_cache(self, feed, fetch_list, program):
        program_cache_key = get_program_cache_key(feed, fetch_list)
        self.program_caches[program_cache_key] = program

    def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
                            fetch_var_name):
        tmp_program = program.clone()

        global_block = tmp_program.global_block()

        if feed_var_name in global_block.vars:
            feed_var = global_block.var(feed_var_name)
        else:
            feed_var = global_block.create_var(
                name=feed_var_name,
                type=core.VarDesc.VarType.FEED_MINIBATCH,
                persistable=True)

        if fetch_var_name in global_block.vars:
            fetch_var = global_block.var(fetch_var_name)
        else:
            fetch_var = global_block.create_var(
                name=fetch_var_name,
                type=core.VarDesc.VarType.FETCH_LIST,
                persistable=True)

        # prepend feed operators
        if not has_feed_operators(global_block, feed, feed_var_name):
            for i, name in enumerate(feed):
                out = global_block.var(name)
                global_block.prepend_op(
                    type='feed',
                    inputs={'X': [feed_var]},
                    outputs={'Out': [out]},
                    attrs={'col': i})

        # append fetch_operators
        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
            for i, var in enumerate(fetch_list):
                assert isinstance(var, Variable) or isinstance(var, str), (
                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
                global_block.append_op(
                    type='fetch',
                    inputs={'X': [var]},
                    outputs={'Out': [fetch_var]},
                    attrs={'col': i})

        return tmp_program
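As a quick illustration of what this helper produces (not part of the diff; the toy program and variable names below are assumptions), the decorated clone should start with a 'feed' op and end with a 'fetch' op, leaving the original program untouched:

import numpy
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[1], append_batch_size=False)
    y = fluid.layers.scale(x, scale=2.0)

exe = fluid.Executor(fluid.CPUPlace())

# _add_feed_fetch_ops clones the program, prepends one 'feed' op per feed key
# and appends one 'fetch' op per fetch target; prog itself is not modified.
decorated = exe._add_feed_fetch_ops(
    program=prog,
    feed={'x': numpy.array([1.0], dtype='float32')},
    fetch_list=[y],
    feed_var_name='feed',
    fetch_var_name='fetch')

print([op.desc.type() for op in decorated.global_block().ops])
# expected op order, roughly: ['feed', 'scale', 'fetch']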

    def feed_data(self, program, feed, feed_var_name, scope):
Review comment: should this be a private member?
Author reply: Done.
        # feed var to framework
        for op in program.global_block().ops:
            if op.desc.type() == 'feed':
                feed_target_name = op.desc.output('Out')[0]
                cur_feed = feed[feed_target_name]
                if not isinstance(cur_feed, core.LoDTensor):
                    cur_feed = self.aslodtensor(cur_feed)
                idx = op.desc.attr('col')
                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
            else:
                break

    def fetch_data(self, fetch_list, fetch_var_name, scope):
        outs = [
            core.get_fetch_variable(scope, fetch_var_name, i)
            for i in xrange(len(fetch_list))
        ]
        return outs

    def prepare(self,
Review comment: should this be a private member?
                program=None,
                feed=None,
                fetch_list=None,
                feed_var_name='feed',
                fetch_var_name='fetch'):
        if feed is None:
            feed = {}
        if not isinstance(feed, dict):
            raise TypeError("feed should be a map")
        if fetch_list is None:
            fetch_list = []
        if program is None:
            program = default_main_program()

        if not isinstance(program, Program):
            raise TypeError()

        program = self._add_feed_fetch_ops(
            program=program,
            feed=feed,
            fetch_list=fetch_list,
            feed_var_name=feed_var_name,
            fetch_var_name=fetch_var_name)
        handle = self.executor.prepare(program.desc, 0)
        return PreparedContext(handle, program, fetch_list, feed_var_name,
                               fetch_var_name)

    def run_prepared_ctx(self, ctx, feed=None, scope=None, return_numpy=True):
Review comment: should this be private?
Author reply: Done.
        if scope is None:
            scope = global_scope()

        self.feed_data(ctx.program, feed, ctx.feed_var_name, scope)
        self.executor.run_prepared_ctx(ctx.handle, scope, True, True)
        outs = self.fetch_data(ctx.fetch_list, ctx.fetch_var_name, scope)
        if return_numpy:
            outs = as_numpy(outs)
        return outs
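Taken together, prepare and run_prepared_ctx let callers decorate and parse a program once and then execute it repeatedly. A minimal usage sketch, not from this diff; it assumes paddle.fluid with this patch applied and a toy y = 2x program:

import numpy
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[1], append_batch_size=False)
    y = fluid.layers.scale(x, scale=2.0)

exe = fluid.Executor(fluid.CPUPlace())

# Decorate and parse the program once; prepare only uses the feed variable
# names, so a placeholder dict keyed by 'x' is enough at this point.
ctx = exe.prepare(program=prog, feed={'x': None}, fetch_list=[y])

# Run the prepared context repeatedly with fresh feed data.
for step in range(3):
    out, = exe.run_prepared_ctx(
        ctx, feed={'x': numpy.array([step], dtype='float32')})
    print(out)  # roughly [0.], [2.], [4.]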

    def run(self,
            program=None,
            feed=None,

@@ -268,7 +391,6 @@ def run(self,
            raise TypeError("feed should be a map")
        if fetch_list is None:
            fetch_list = []

        if program is None:
            program = default_main_program()
@@ -278,79 +400,28 @@ def run(self,
        if scope is None:
            scope = global_scope()

        program_cache = None
        program_cache_key = get_program_cache_key(feed, fetch_list)

        if use_program_cache:
            # find program cache by cache_key
            program_cache = self.program_caches.get(program_cache_key, None)
            # TODO(qiao): Should check program_cache and program are exactly the same.
            cached_program = self._get_program_cache(feed, fetch_list)
            if cached_program is None:
                cached_program = self._add_feed_fetch_ops(
                    program=program,
                    feed=feed,
                    fetch_list=fetch_list,
                    feed_var_name=feed_var_name,
                    fetch_var_name=fetch_var_name)
                self._add_program_cache(feed, fetch_list, cached_program)
            program = cached_program
        else:
            self.program_caches.pop(program_cache_key, None)

        if program_cache is None:
            program_cache = program.clone()

            if use_program_cache:
                self.program_caches[program_cache_key] = program_cache

            global_block = program_cache.global_block()

            if feed_var_name in global_block.vars:
                feed_var = global_block.var(feed_var_name)
            else:
                feed_var = global_block.create_var(
                    name=feed_var_name,
                    type=core.VarDesc.VarType.FEED_MINIBATCH,
                    persistable=True)

            if fetch_var_name in global_block.vars:
                fetch_var = global_block.var(fetch_var_name)
            else:
                fetch_var = global_block.create_var(
                    name=fetch_var_name,
                    type=core.VarDesc.VarType.FETCH_LIST,
                    persistable=True)

            # prepend feed operators
            if not has_feed_operators(global_block, feed, feed_var_name):
                for i, name in enumerate(feed):
                    out = global_block.var(name)
                    global_block.prepend_op(
                        type='feed',
                        inputs={'X': [feed_var]},
                        outputs={'Out': [out]},
                        attrs={'col': i})

            # append fetch_operators
            if not has_fetch_operators(global_block, fetch_list,
                                       fetch_var_name):
                for i, var in enumerate(fetch_list):
                    assert isinstance(var, Variable) or isinstance(var, str), (
                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
                    global_block.append_op(
                        type='fetch',
                        inputs={'X': [var]},
                        outputs={'Out': [fetch_var]},
                        attrs={'col': i})

        # feed var to framework
        for op in program_cache.global_block().ops:
            if op.desc.type() == 'feed':
                feed_target_name = op.desc.output('Out')[0]
                cur_feed = feed[feed_target_name]
                if not isinstance(cur_feed, core.LoDTensor):
                    cur_feed = self.aslodtensor(cur_feed)
                idx = op.desc.attr('col')
                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
            else:
                break

        self.executor.run(program_cache.desc, scope, 0, True, True)
        outs = [
            core.get_fetch_variable(scope, fetch_var_name, i)
            for i in xrange(len(fetch_list))
        ]
            program = self._add_feed_fetch_ops(
Review comment: why is it no longer popping the cached program here?
Author reply: Done, added back.
                program=program,
                feed=feed,
                fetch_list=fetch_list,
                feed_var_name=feed_var_name,
                fetch_var_name=fetch_var_name)

        self.feed_data(program, feed, feed_var_name, scope)
        self.executor.run(program.desc, scope, 0, True, True)
        outs = self.fetch_data(fetch_list, fetch_var_name, scope)
        if return_numpy:
            outs = as_numpy(outs)
        return outs
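For comparison with the prepared-context path, a hedged sketch of the cached path through the refactored run (again assuming a toy y = 2x program; with use_program_cache=True the decorated clone is built once and then reused through the new cache helpers):

import numpy
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[1], append_batch_size=False)
    y = fluid.layers.scale(x, scale=2.0)

exe = fluid.Executor(fluid.CPUPlace())

# With use_program_cache=True the first call builds the feed/fetch-decorated
# clone via _add_feed_fetch_ops and stores it with _add_program_cache; later
# calls with the same feed/fetch key reuse it via _get_program_cache.
for step in range(3):
    out, = exe.run(program=prog,
                   feed={'x': numpy.array([step], dtype='float32')},
                   fetch_list=[y],
                   use_program_cache=True)
    print(out)  # roughly [0.], [2.], [4.]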
Review comment: should prepare be renamed to _prepare?
Author reply: I think the visibility can be controlled on the Python side.