enable accelerate test on npu
jihuazhong authored and statelesshz committed Jul 12, 2023
1 parent 5a65c65 commit 3cb2c3c
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/accelerate/test_utils/scripts/test_script.py
@@ -33,6 +33,7 @@
     gather,
     is_bf16_available,
     is_ipex_available,
+    is_npu_available,
     is_xpu_available,
     set_seed,
     synchronize_rng_states,
@@ -358,7 +359,7 @@ def training_check():

     accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or is_npu_available():
         # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
         print("FP16 training check.")
         AcceleratorState._reset_state()
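With this change, the FP16 training check in test_script.py runs whenever either CUDA or an Ascend NPU is available, instead of on CUDA only. A minimal sketch of the same gate applied in user code to pick a mixed-precision mode (a hypothetical usage example, assuming is_npu_available is re-exported from accelerate.utils as the import block above suggests):

# Sketch: request fp16 only when a CUDA GPU or an Ascend NPU is present.
import torch
from accelerate import Accelerator
from accelerate.utils import is_npu_available

mixed_precision = "fp16" if (torch.cuda.is_available() or is_npu_available()) else "no"
accelerator = Accelerator(mixed_precision=mixed_precision)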
2 changes: 2 additions & 0 deletions src/accelerate/utils/imports.py
@@ -103,6 +103,8 @@ def is_bf16_available(ignore_tpu=False):
         return not ignore_tpu
     if torch.cuda.is_available():
         return torch.cuda.is_bf16_supported()
+    if is_npu_available():
+        return False
     return True


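Without the new branch, an NPU-only machine would fall through to "return True" and wrongly report bf16 support; the added check makes is_bf16_available() conservatively return False there. The is_npu_available() helper itself is not part of this diff; a rough sketch of how such a check is typically written (an assumed implementation with a hypothetical name, not the code from accelerate/utils/imports.py):

# Sketch (assumption): detect an Ascend NPU by checking that torch_npu is
# importable and that torch then reports an available "npu" device.
import importlib.util

def npu_is_present() -> bool:  # hypothetical helper, not the commit's function
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch
    import torch_npu  # noqa: F401  (importing it registers the npu device with torch)
    return hasattr(torch, "npu") and torch.npu.is_available()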
