Skip to content

Latest commit

 

History

History
48 lines (46 loc) · 1.79 KB

Torch_保存模型和从文件中加载模型.org

File metadata and controls

48 lines (46 loc) · 1.79 KB

如何将模型保存到文件并从文件中加载模型?

第一种方式:仅保存和加载模型参数(官方的推荐方式)

保存:

PATH = 'path/to/dir/my_model.pth'
torch.save(trained_model.state_dict(), PATH)

之后以下面方式加载:

untrained_model = YourModelClass(*args, **kwargs) # 要先实例化你的模型,并和要加载的模型结构相同
untrained_model.load_state_dict(torch.load(PATH))

第二种方式:保存和加载整个模型(包括中间参数)

保存:

torch.save(trained_model, PATH)

加载:

untrained_model = torch.load(PATH)

在不同的情景下你该怎么用呢?

  • 模型已经训练完成,你需要保存以备以后使用,如推断等
    torch.save(model.state_dict(), PATH) # 保存模型
    untrained_model = YourModelClass(*args, **kwargs) # 要先实例化你的模型,并和要加载的模型结构相同
    untrained_model.load_state_dict(torch.load(PATH)) # 加载到实例化了的模型
    untrained_model.eval() # 加载之后默认为train mode,推断时先转换到eval mode.
        
  • 训练未完成,你需要保存以从断点处恢复训练进度。除了要保存模型,你还要保存optimizer, epoch, score等状态。
    state = { # 保存恢复训练所需要的所有状态
        'epoch': epoch,
        'model_state_dit': your_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # ...
        }
    torch.save(state, PATH) # 加载训练的所有状态
    # 加载训练进度
    state = torch.load(PATH)
    # 恢复进度
    epoch = state['epoch']
    your_model.load_state_dict(state['model_state_dit'])
    optimizer = state['optimizer']