Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Deferred computation does not work in HybridSequential #18593

Closed
xidulu opened this issue Jun 19, 2020 · 2 comments · Fixed by #19470
Closed

Deferred computation does not work in HybridSequential #18593

xidulu opened this issue Jun 19, 2020 · 2 comments · Fixed by #19470
Assignees

Comments

@xidulu
Copy link
Contributor

xidulu commented Jun 19, 2020

Description

import mxnet as mx
from mxnet import np, npx
from mxnet.gluon.nn import HybridBlock

class normalBlock(HybridBlock):  # minimal block that overrides forward() only (no hybrid_forward)
    def forward(self, x):
        return (x + 1)  # adds 1 to the input

shape = (4, 4)
for hybridize in [True, False]:  # the hybridized case runs first
    initial_value = np.ones(shape)
    net = mx.gluon.nn.HybridSequential()  # container wrapping the two forward()-based children
    net.add(normalBlock())
    net.add(normalBlock())
    if hybridize:
        net.hybridize()
    mx_out = net(initial_value).asnumpy()  # the reported traceback (NotImplementedForSymbol) points at this call
NotImplementedForSymbol                   Traceback (most recent call last)
<ipython-input-321-a136cd061b1d> in <module>
     15     if hybridize:
     16         net.hybridize()
---> 17     mx_out = net(initial_value).asnumpy()
     18 

~/mxnet_master_develop/python/mxnet/gluon/block.py in __call__(self, x, *args)
   1322         if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
   1323             # Gluon 1 based on F:  hybrid_forward is defined by user
-> 1324             return super().__call__(x, *args)
   1325         else:  # Gluon 2 based on deferred compute mode
   1326             assert self.forward is not HybridBlock.forward, (

~/mxnet_master_develop/python/mxnet/gluon/block.py in __call__(self, *args)
    703             hook(self, args)
    704 
--> 705         out = self.forward(*args)
    706 
    707         for hook in self._forward_hooks.values():

~/mxnet_master_develop/python/mxnet/gluon/block.py in forward(self, x, *args)
   1367                                      'Find all contexts = {}'.format(ctx_set))
   1368                 with ctx:
-> 1369                     return self._call_cached_op(x, *args)
   1370             with ctx:
   1371                 try:

~/mxnet_master_develop/python/mxnet/gluon/block.py in _call_cached_op(self, *args)
   1055     def _call_cached_op(self, *args):
   1056         if self._cached_op is None:
-> 1057             self._build_cache(*args)
   1058         assert self._cached_op, "Gluon failed to build the cache. " \
   1059                                 "This should never happen. " \

~/mxnet_master_develop/python/mxnet/gluon/block.py in _build_cache(self, *args)
    984 
    985     def _build_cache(self, *args):
--> 986         data, out = self._get_graph(*args)
    987         data_names = {data.name: i for i, data in enumerate(data)}
    988         params = self.collect_params()

~/mxnet_master_develop/python/mxnet/gluon/block.py in _get_graph(self, *args)
    978         if not self._cached_graph:
    979             if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:  # Gluon 1
--> 980                 return self._get_graph_v1(*args)
    981             else:  # Gluon 2 based on deferred compute mode
    982                 return self._get_graph_v2(*args)

~/mxnet_master_develop/python/mxnet/gluon/block.py in _get_graph_v1(self, *args)
    942             params = {i: j.var() for i, j in self._reg_params.items()}
    943             with self.name_scope():
--> 944                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    945             out, self._out_format = _flatten(out, "output")
    946 

~/mxnet_master_develop/python/mxnet/gluon/nn/basic_layers.py in hybrid_forward(self, F, x, *args)
    126     def hybrid_forward(self, F, x, *args):
    127         for block in self._children.values():
--> 128             x = block()(x, *args)
    129             args = []
    130             if isinstance(x, (tuple, list)):

~/mxnet_master_develop/python/mxnet/gluon/block.py in __call__(self, x, *args)
   1343                 return super().__call__(x, *args)
   1344 
-> 1345             return self._call_cached_op(x, *args)
   1346 
   1347     def forward(self, x, *args):

~/mxnet_master_develop/python/mxnet/gluon/block.py in _call_cached_op(self, *args)
   1055     def _call_cached_op(self, *args):
   1056         if self._cached_op is None:
-> 1057             self._build_cache(*args)
   1058         assert self._cached_op, "Gluon failed to build the cache. " \
   1059                                 "This should never happen. " \

~/mxnet_master_develop/python/mxnet/gluon/block.py in _build_cache(self, *args)
    984 
    985     def _build_cache(self, *args):
--> 986         data, out = self._get_graph(*args)
    987         data_names = {data.name: i for i, data in enumerate(data)}
    988         params = self.collect_params()

~/mxnet_master_develop/python/mxnet/gluon/block.py in _get_graph(self, *args)
    980                 return self._get_graph_v1(*args)
    981             else:  # Gluon 2 based on deferred compute mode
--> 982                 return self._get_graph_v2(*args)
    983         return self._cached_graph
    984 

~/mxnet_master_develop/python/mxnet/gluon/block.py in _get_graph_v2(self, *args)
    952         if not self._cached_graph:
    953             flatten_args, self._in_format = _flatten(args, "input")
--> 954             flatten_args = [ele.detach() if ele is not None else None for ele in flatten_args]
    955             real_args = [ele for ele in flatten_args if ele is not None]
    956             if len(real_args) == 0:

~/mxnet_master_develop/python/mxnet/gluon/block.py in <listcomp>(.0)
    952         if not self._cached_graph:
    953             flatten_args, self._in_format = _flatten(args, "input")
--> 954             flatten_args = [ele.detach() if ele is not None else None for ele in flatten_args]
    955             real_args = [ele for ele in flatten_args if ele is not None]
    956             if len(real_args) == 0:

~/mxnet_master_develop/python/mxnet/symbol/symbol.py in detach(self)
   2791 
   2792     def detach(self):
-> 2793         raise NotImplementedForSymbol(self.detach, None)
   2794 
   2795     def backward(self):

NotImplementedForSymbol: Function detach is not implemented for Symbol and only available in NDArray.
@xidulu xidulu added the Bug label Jun 19, 2020
@leezu
Copy link
Contributor

leezu commented Jun 19, 2020

It's "by design", as you can't call the new API from the old API: calling a HybridBlock.forward block from HybridBlock.hybrid_forward is not supported. Note that all the MXNet-provided HybridBlocks currently implement hybrid_forward for compatibility.

We can add a temporary hack to HybridSequential to support this use case if needed, as HybridSequential only passes through the arguments. In general, HybridSequential will be switched to the new API together with all the other HybridBlocks in the near future.

@leezu leezu self-assigned this Jun 19, 2020
@leezu leezu added the v2.0 label Jun 19, 2020
@xidulu
Copy link
Contributor Author

xidulu commented Jun 19, 2020

@leezu
Thanks for your response, I am not in urgent need of this use case.
But I believe it's crucial to let HybridSequential support the new API.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants