testing runner with GPU support #3274

Closed · martinjrobins opened this issue Aug 16, 2023 · 4 comments · Fixed by #3380

@martinjrobins (Contributor)

Description

Add a test runner with GPU support to test GPU-specific code.

Motivation

#3121 introduced some GPU-specific code, which we need to test.

Possible Implementation

We're already using a Mac mini for M1 testing and deployments; we could use it for this as well.

Additional context

@BradyPlanden

@BradyPlanden (Member)

Currently, JAX on Apple M-series GPUs is experimental and requires a build from source (see here), with a minimum JAX version of 0.4.11. Is there a reason PyBaMM pins JAX to 0.4.7 and 0.4.8? As similarly discussed in #3371, it appears that we can relax this dependency?

@jsbrittain (Contributor)

@BradyPlanden As far as I know, we should be able to relax this requirement, at least to the minor version, i.e. 0.4 (there were some substantial API changes between 0.3 and 0.4).
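For illustration, a relaxed constraint like that could be checked at runtime roughly as follows. This is only a sketch using importlib.metadata and the packaging library, not PyBaMM's actual dependency handling, and the ">=0.4,<0.5" range is an assumption based on the discussion above.

from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Hypothetical relaxed constraint: any 0.4.x release instead of an exact pin.
JAX_SPEC = SpecifierSet(">=0.4,<0.5")

def jax_version_ok() -> bool:
    # True if the installed jax falls inside the relaxed range.
    try:
        installed = Version(version("jax"))
    except PackageNotFoundError:
        return False
    return installed in JAX_SPEC

print(jax_version_ok())  # e.g. True for jax==0.4.10 or jax==0.4.11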

@agriyakhetarpal (Member)

agriyakhetarpal commented Sep 27, 2023

If it helps, I built jaxlib==0.4.11 from source on an M2 Air, following https://developer.apple.com/metal/jax/ as mentioned above. I then installed jax==0.4.10 and jax-metal==0.0.3 (i.e., the ones that do not currently cause version conflicts) and ran the unit tests with python run-tests.py --unit.

Verifying that JAX uses the GPU

The output of

import jax
jax.default_backend()

is

2023-09-27 23:54:55.469762: W pjrt_plugin/src/mps_client.cc:535] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB

'METAL'

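A slightly fuller check of what JAX is actually using can also be done from Python; the calls below are standard JAX APIs, and the float32 default is what the FP64 discussion further down refers to.

import jax
import jax.numpy as jnp

print(jax.default_backend())  # 'METAL' here; 'cpu' or 'gpu' on other setups
print(jax.devices())          # the devices JAX can see
print(jnp.ones(1).dtype)      # float32 unless jax_enable_x64 is enabled
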
Error logs

Note: the unit tests did not complete and exited early with this warning:

/opt/homebrew/Cellar/python@3.11/3.11.5/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Most JAX-related tests fail, all with the same error, e.g.:

test_evaluator_jax_jvp (test_expression_tree.test_operations.test_evaluate_python.TestEvaluate.test_evaluator_jax_jvp) ... "builtin.module"() ({
  "func.func"() ({
  ^bb0(%arg0: tensor<1x1xf64>):
    %0 = "mps.constant"() {value = dense<2.000000e+00> : tensor<1x1xf64>} : () -> tensor<1x1xf64>
    %1 = "mps.constant"() {value = dense<1.000000e+00> : tensor<1x1xf64>} : () -> tensor<1x1xf64>
    %2 = "mps.power"(%arg0, %1) : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
    %3 = "mps.multiply"(%2, %0) : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
    %4 = "mps.constant"() {value = dense<1> : tensor<4xsi64>} : () -> tensor<4xsi64>
    %5 = "mps.reshape"(%3, %4) : (tensor<1x1xf64>, tensor<4xsi64>) -> tensor<1x1x1x1xf64>
    "func.return"(%5) : (tensor<1x1x1x1xf64>) -> ()
  }) {arg_attrs = [{jax.arg_info = "y", mhlo.sharding = "{replicated}"}], function_type = (tensor<1x1xf64>) -> tensor<1x1x1x1xf64>, res_attrs = [{jax.result_info = ""}], sym_name = "main", sym_visibility = "public"} : () -> ()
}) {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, sym_name = "jit_evaluate_jax"} : () -> ()
ERROR

From a quick look, jax-ml/jax#16435 seems to be related. Given the instability of jax-metal in its current state, perhaps it would be best to test this GPU-specific code on either Windows + WSL or an Ubuntu machine? On a GPU-enabled self-hosted runner, we could use https://github.com/marketplace/actions/cuda-toolkit for drivers. It looks like GPU-enabled runners are on GitHub's roadmap too (github/roadmap#505).
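Whichever runner we end up with, GPU-specific tests could be gated so they only run when JAX actually reports a GPU backend. The sketch below is purely illustrative (the test class and guard are hypothetical, not existing PyBaMM tests) and follows the unittest style that the test names above suggest.

import unittest

import jax

# Hypothetical guard: only run GPU-specific tests when JAX reports a non-CPU backend.
HAS_GPU_BACKEND = jax.default_backend() != "cpu"

class TestGPUSpecific(unittest.TestCase):
    @unittest.skipUnless(HAS_GPU_BACKEND, "no GPU backend available to JAX")
    def test_runs_on_gpu(self):
        import jax.numpy as jnp

        # Trivial computation placed on the default (GPU) device.
        result = jnp.square(jnp.arange(4.0)).sum()
        self.assertAlmostEqual(float(result), 14.0, places=5)

if __name__ == "__main__":
    unittest.main()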

@BradyPlanden (Member)

BradyPlanden commented Sep 28, 2023

Right, so a bit of investigation confirmed that M-series GPUs don't support FP64. In this case, config.update("jax_enable_x64", True) needs to be skipped in evaluate_python.py and jax_bdf_solver.py. Once that's done, python run-tests.py --unit proceeds, but fails due to assertion tolerances (JAX defaults to FP32).

To proceed, we could add os.uname() logic around the config.update("jax_enable_x64", True) calls, along with GPU or CPU flags. Thoughts?
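For illustration, a guard along those lines might look like the sketch below. The function name and placement are hypothetical (this is not PyBaMM code), and the 'METAL' string matches the backend name printed earlier in this thread.

import platform

import jax
from jax import config

def configure_jax_precision():
    # Skip 64-bit mode on Apple Silicon when the experimental Metal backend is
    # active, since it has no FP64 support; otherwise enable it as before.
    on_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    using_metal = jax.default_backend() == "METAL"
    if on_apple_silicon and using_metal:
        return  # stay on JAX's float32 default; test tolerances would need loosening
    config.update("jax_enable_x64", True)

configure_jax_precision()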
