Skip to content

Commit

Permalink
Improve error messages in FanInSum/FanInConcat
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 391669059
  • Loading branch information
romanngg committed Aug 19, 2021
1 parent 50517f4 commit 8ca8b98
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,16 +1607,20 @@ def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
if not all([k.shape1 == ks[0].shape1 and
k.shape2 == ks[0].shape2 for k in ks[1:]]):
raise ValueError('All shapes should be equal in `FanInSum/FanInProd`.')
raise ValueError('All shapes should be equal in `FanInSum/FanInProd`, '
f'got `x1.shape`s of {[k.shape1 for k in ks]}, '
f'`x2.shape`s of {[k.shape2 for k in ks]}.')

is_gaussian = all(k.is_gaussian for k in ks)
if not is_gaussian and len(ks) != 1:
# TODO(xlc): FanInSum/FanInConcat could allow non-Gaussian inputs, but
# we need to propagate the mean of the random variables as well.
raise NotImplementedError('`FanInSum` layer along the non-channel axis is'
' only implemented for the case if all input'
' layers guaranteed to be mean-zero Gaussian,'
' i.e. having all `is_gaussian set to `True`.')
raise NotImplementedError('`FanInSum` or `FanInConcat` along the '
'non-channel axis is only implemented for the '
'case where all input layers guaranteed to be '
'mean-zero Gaussian, i.e. having all '
'`is_gaussian` set to `True`, got '
f'{[k.is_gaussian for k in ks]}.')

_mats_sum = lambda mats: None if mats[0] is None else sum(mats)

Expand Down

0 comments on commit 8ca8b98

Please sign in to comment.