From 49b8fad0f08b73ce8360daef816b3bae0da0cb31 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 17 Feb 2023 18:20:18 -0800 Subject: [PATCH] add map location to huggingface utils --- composer/models/huggingface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index 48d5676fdb..243f9e6007 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -205,7 +205,7 @@ def hf_from_composer_checkpoint( get_file(checkpoint_path, str(local_checkpoint_save_location)) # load the state dict in - loaded_state_dict = torch.load(local_checkpoint_save_location) + loaded_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu') hf_state = loaded_state_dict['state']['integrations']['huggingface'] hf_model_state = hf_state['model'] @@ -501,7 +501,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( # download the checkpoint file get_file(str(checkpoint_path), str(local_checkpoint_save_location)) - composer_state_dict = torch.load(local_checkpoint_save_location) + composer_state_dict = torch.load(local_checkpoint_save_location, map_location='cpu') config = get_hf_config_from_composer_state_dict(composer_state_dict) config.save_pretrained(output_folder)