Skip to content

Commit

Permalink
skip scalars when broadcasting for batch dimension agreement
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 20, 2021
1 parent f8c36d9 commit 129f343
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
1 change: 1 addition & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,7 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
return all(aval.weak_type for aval in avals)

def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
# TODO(frostig,mattjj): only used with arity > 2 once, simplify
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(_broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ def broadcast_batcher(prim, args, dims, **params):
return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
else:
size, = {shape[d] for shape, d in shapes if d is not not_mapped}
args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
args = [bdim_at_front(x, d, size) if np.ndim(x) else x
for x, d in zip(args, dims)]
ndim = max(np.ndim(x) for x in args) # special-case scalar broadcasting
args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
out = prim.bind(*args, **params)
Expand Down
9 changes: 9 additions & 0 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import test_util as jtu
from jax import lax
from jax._src.lax import parallel
Expand Down Expand Up @@ -1240,5 +1241,13 @@ def testNonJaxTypedOutput(self):
TypeError, "Output from batched function.*is not a valid JAX type"):
vmap(lambda x: "hello")(np.arange(5))

def testIssue6096(self):
def f(x):
return jsp.special.betainc(jnp.ones(3), 1., x)

self.assertEquals(f(jnp.ones(3)).shape, (3,))
self.assertEquals(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 129f343

Please sign in to comment.