Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
tensorboard notebook fix & loading of pre-trained models fix (#397)
Browse files Browse the repository at this point in the history
Co-authored-by: Max Kaznady <max.kaznady@gmail.com>
  • Loading branch information
maxkazmsft and maxbikes authored Jul 10, 2020
1 parent a9454ca commit 6fedb0b
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 10 deletions.
8 changes: 6 additions & 2 deletions cv_lib/cv_lib/segmentation/models/patch_deconvnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn


Expand Down Expand Up @@ -304,5 +304,9 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Patch deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = patch_deconvnet(n_classes=cfg.DATASET.NUM_CLASSES)

# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
8 changes: 6 additions & 2 deletions cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn


Expand Down Expand Up @@ -304,5 +304,9 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Patch deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = patch_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)

# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
5 changes: 5 additions & 0 deletions cv_lib/cv_lib/segmentation/models/resnet_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,9 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 3
), f"SEResnet Unet deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 3 for cfg.MODEL.IN_CHANNELS"
model = Res34Unetv4(n_classes=cfg.DATASET.NUM_CLASSES)
# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
8 changes: 6 additions & 2 deletions cv_lib/cv_lib/segmentation/models/section_deconvnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn


Expand Down Expand Up @@ -304,5 +304,9 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = section_deconvnet(n_classes=cfg.DATASET.NUM_CLASSES)

# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
8 changes: 6 additions & 2 deletions cv_lib/cv_lib/segmentation/models/section_deconvnet_skip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn


Expand Down Expand Up @@ -304,5 +304,9 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = section_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)

# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
6 changes: 5 additions & 1 deletion cv_lib/cv_lib/segmentation/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,9 @@ def forward(self, x):

def get_seg_model(cfg, **kwargs):
model = UNet(cfg.MODEL.IN_CHANNELS, cfg.DATASET.NUM_CLASSES)

# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model
2 changes: 1 addition & 1 deletion environment/anaconda/local/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- torchvision>=0.5.0
- pandas==0.25.3
- scikit-learn==0.21.3
- tensorflow==2.1.0
- tensorflow==2.1
- opt-einsum>=2.3.2
- tqdm==4.39.0
- itkwidgets==0.23.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,9 @@ def getFeatures(self, x, layer_no):

def get_seg_model(cfg, **kwargs):
model = TextureNet(n_classes=cfg.DATASET.NUM_CLASSES)
# load the pre-trained model
if "PRETRAINED" in cfg.MODEL.keys():
trained_model = torch.load(cfg.MODEL.PRETRAINED)
trained_model = {k.replace("module.", ""): v for (k, v) in trained_model.items()}
model.load_state_dict(trained_model, strict=True)
return model

0 comments on commit 6fedb0b

Please sign in to comment.