Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Initializer.__eq__ (#16680)
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu authored Nov 1, 2019
1 parent 6c42992 commit 33d108b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def _init_default(self, name, _):
'"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \
'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name)

def __eq__(self, other):
if not isinstance(other, Initializer):
return NotImplemented
# pylint: disable=unidiomatic-typecheck
return type(self) is type(other) and self._kwargs == other._kwargs

# pylint: disable=invalid-name
_register = registry.get_register_func(Initializer, 'initializer')
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,21 @@ def forward(self, x):
shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
block(mx.nd.ones(shape))

def test_shared_parameters_with_non_default_initializer():
class MyBlock(gluon.HybridBlock):
def __init__(self, **kwargs):
super(MyBlock, self).__init__(**kwargs)

with self.name_scope():
self.param = self.params.get("param", shape=(1, ), init=mx.init.Constant(-10.0))

bl = MyBlock()
bl2 = MyBlock(params=bl.collect_params())
assert bl.param is bl2.param
bl3 = MyBlock()
assert bl.param is not bl3.param
assert bl.param.init == bl3.param.init

@with_seed()
def test_reqs_switching_training_inference():
class Foo(gluon.HybridBlock):
Expand Down

0 comments on commit 33d108b

Please sign in to comment.