diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 4569c0de1..cd5922f6e 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -289,8 +289,8 @@ def move_to_cuda(model, rank): f'Moving {module.__class__.__name__} to CUDA device {rank}') module.to(f'cuda:{rank}') # Optionally, verify the device assignment - logger.debug(f'{module.__class__.__name__} is on device ' - f'{next(module.parameters()).device}') + logger.debug( + f'{module.__class__.__name__} is on device {module.device}') def get_model(model_key=None, rank=None): @@ -303,7 +303,7 @@ def get_model(model_key=None, rank=None): f'{model_key} not found in MODEL_ZOO ({mp.current_process().name})' ) MODEL_ZOO[model_key] = model_key() - if use_cuda(): - rank = 0 if rank is None else rank - move_to_cuda(MODEL_ZOO[model_key], rank) + if use_cuda(): + rank = 0 if rank is None else rank + move_to_cuda(MODEL_ZOO[model_key], rank) return MODEL_ZOO[model_key]