-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: pacman ghost valid action calculations result in NaNs #241
Conversation
Hi, thank you for spotting this bug! |
I've updated my commit to bump the version of PacMan to v1. Let me know if I've missed anything! |
I did a bit of digging, and essentially if the function is The NaN's don't show up in the |
Here is the script that I run, with the environment variable import jax
from jumanji.environments.routing.pac_man import PacMan
if __name__ == "__main__":
jax.disable_jit(True)
seed = 2024
key = jax.random.PRNGKey(seed)
reset_key, key = jax.random.split(key)
env = PacMan()
state, tstep = env.reset(reset_key)
next_state, tstep = env.step(state, 1) |
When I train with and without the fix, I get the exact same learning curves (same loss at every step), hinting that the behavior has not changed. I wonder if the NaN behavior depends on the version of JAX? |
Yes, the behavior would be the same, since according to this thread, XLA returns 0 instead of NaN for If this doesn't warrant a version bump, then I'm more than happy to change the version back to v0. |
Oh that makes complete sense. Since the behavior of the non-jitted environment changes, let's then bump the version. Thank you for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
In the PacMan environment, when trying to calculate all the valid actions a ghost could take (in
check_ghost_wall_collisions
in pac_man/utils.py) theinvert_mask * jnp.inf
call was producing an array full of NaN's where invert_mask == 1. This lead to all actions being valid for ghosts.Instead, what this line should be doing is a
jnp.where
call, that conditionally replaces all 1's ininvert_mask
withjnp.inf
.