Skip to content

Commit

Permalink
Allow running on GPU compute capability 6.x
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701248931
  • Loading branch information
Augustin-Zidek committed Nov 29, 2024
1 parent 38d599b commit 2eb6555
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,13 +637,21 @@ def main(_):
if _RUN_INFERENCE.value:
# Fail early on incompatible devices, but only if we're running inference.
gpu_devices = jax.local_devices(backend='gpu')
if gpu_devices and float(gpu_devices[0].compute_capability) < 8.0:
raise ValueError(
'There are currently known unresolved numerical issues with using'
' devices with compute capability less than 8.0. See '
' https://github.com/google-deepmind/alphafold3/issues/59 for'
' tracking.'
)
if gpu_devices:
compute_capability = float(gpu_devices[0].compute_capability)
if compute_capability < 6.0:
raise ValueError(
'AlphaFold 3 requires at least GPU compute capability 6.0 (see'
' https://developer.nvidia.com/cuda-gpus).'
)
elif 7.0 <= compute_capability < 8.0:
raise ValueError(
'There are currently known unresolved numerical issues with using'
' devices with GPU compute capability 7.x (see'
' https://developer.nvidia.com/cuda-gpus). Follow '
' https://github.com/google-deepmind/alphafold3/issues/59 for'
' tracking.'
)

notice = textwrap.wrap(
'Running AlphaFold 3. Please note that standard AlphaFold 3 model'
Expand Down

0 comments on commit 2eb6555

Please sign in to comment.