Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(2048): environment performance improvements #172

Merged
merged 15 commits into from
Jun 20, 2023

Conversation

aar65537
Copy link
Contributor

This PR improves the performance of the Game2048 environment. The improvements include

  • Minimizing logic inside of jax.lax.cond and jax.lax.switch
  • Using jax.vmap over jax.lax.scan where possible
  • A new move implementation
  • A can_move implementation that validates an action without mutating the board
no vmap vmap 103 vmap 106
cpu 64.36% 201.80% 392.29%
cuda 900.12% 1923.08% 706.87%

The above figure shows the total performance improvement measured as percent increase in steps/sec. For more detailed benchmarking, see here.

@clement-bonnet
Copy link
Collaborator

Hi @aar65537, thanks a lot for your suggestions for speed improvement! We will look into them, check that the environment's behavior has not changed, and get back to you shortly.

clement-bonnet
clement-bonnet previously approved these changes Jun 20, 2023
Copy link
Collaborator

@clement-bonnet clement-bonnet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for the code and speed improvements! I have left a few comments for which I don't feel very strongly about.

jumanji/environments/logic/game_2048/utils.py Outdated Show resolved Hide resolved
jumanji/environments/logic/game_2048/utils.py Outdated Show resolved Hide resolved
jumanji/environments/logic/game_2048/utils.py Outdated Show resolved Hide resolved
jumanji/environments/logic/game_2048/utils.py Show resolved Hide resolved
jumanji/environments/logic/game_2048/utils.py Outdated Show resolved Hide resolved
jumanji/environments/logic/game_2048/utils.py Outdated Show resolved Hide resolved
@clement-bonnet
Copy link
Collaborator

I have checked that this version is equivalent (in terms of environment behavior) to the current version. I obtained the same learning curves.
Moreover, on a TPU-v4 with 8 cores, I got the performances below (orange is 2048 in main, pink is this updated version):

Environment steps per second (random agent): higher is better

x12 improvement when randomly rolling out.
Screenshot from 2023-06-20 15-17-22

Train epoch time (a2c agent): lower is better

x12 improvement when training.
Screenshot from 2023-06-20 15-27-13

The other curves (learning metrics or episode returns) are completely equivalent in both versions.

@clement-bonnet
Copy link
Collaborator

If that's okay, I will resolve the comments by applying suggestions and will merge.

clement-bonnet
clement-bonnet previously approved these changes Jun 20, 2023
@clement-bonnet clement-bonnet enabled auto-merge (squash) June 20, 2023 14:38
@clement-bonnet clement-bonnet merged commit 32685cb into instadeepai:main Jun 20, 2023
@aar65537 aar65537 deleted the perf-improvements-2048 branch January 18, 2024 20:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants