Commit fe2b6f3 1 parent 3854ccd commit fe2b6f3 Copy full SHA for fe2b6f3
File tree 2 files changed +10
-5
lines changed
2 files changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -319,7 +319,7 @@ def _verify_cache_dtype(self) -> None:
319
319
pass
320
320
elif self .cache_dtype == "fp8_e5m2" :
321
321
nvcc_cuda_version = get_nvcc_cuda_version ()
322
- if nvcc_cuda_version < Version ("11.8" ):
322
+ if nvcc_cuda_version and nvcc_cuda_version < Version ("11.8" ):
323
323
raise ValueError (
324
324
"FP8 is not supported when cuda version is lower than 11.8."
325
325
)
Original file line number Diff line number Diff line change @@ -181,13 +181,18 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
181
181
os .environ ["CUDA_VISIBLE_DEVICES" ] = "," .join (map (str , device_ids ))
182
182
183
183
184
- def get_nvcc_cuda_version () -> Version :
184
+ def get_nvcc_cuda_version () -> Optional [ Version ] :
185
185
cuda_home = os .environ .get ('CUDA_HOME' )
186
186
if not cuda_home :
187
187
cuda_home = '/usr/local/cuda'
188
- logger .info (
189
- f'CUDA_HOME is not found in the environment. Using { cuda_home } as CUDA_HOME.'
190
- )
188
+ if os .path .isfile (cuda_home + '/bin/nvcc' ):
189
+ logger .info (
190
+ f'CUDA_HOME is not found in the environment. Using { cuda_home } as CUDA_HOME.'
191
+ )
192
+ else :
193
+ logger .warning (
194
+ f'Not found nvcc in { cuda_home } . Skip cuda version check!' )
195
+ return None
191
196
nvcc_output = subprocess .check_output ([cuda_home + "/bin/nvcc" , "-V" ],
192
197
universal_newlines = True )
193
198
output = nvcc_output .split ()
You can’t perform that action at this time.
0 commit comments