From 86955370cd868b5d4f46f2f80f7632fd864773e3 Mon Sep 17 00:00:00 2001 From: Manu Seth <22492939+mseth10@users.noreply.github.com> Date: Thu, 23 Apr 2020 01:14:44 -0700 Subject: [PATCH] add logic for no batch size while getting data arrays from executors (#17772) (#18122) Co-authored-by: Ubuntu Co-authored-by: Ubuntu --- python/mxnet/module/executor_group.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index d47665d6d509..f2cb62fc8396 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -308,8 +308,16 @@ def decide_slices(self, data_shapes): def _collect_arrays(self): """Collect internal arrays from executors.""" # convenient data structures - self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)] - for name, _ in self.data_shapes] + + # check if self.slices is populated, if not then that means that there is no batch size + if self.slices: + # based on batch size, slice up data for the given contexts (self.execs) + self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)] + for name, _ in self.data_shapes] + else: + # just use the context index as index into the data + self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)] + for name, _ in self.data_shapes] self.state_arrays = [[e.arg_dict[name] for e in self.execs] for name in self.state_names]