Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Replace HRNet with SEResNet model in the notebook #362

Merged
merged 7 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cv_lib/cv_lib/segmentation/models/patch_deconvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def init_vgg16_params(self, vgg16, copy_fc8=True):
l2.bias.data = l1.bias.data
i_layer = i_layer + 1


def get_seg_model(cfg, **kwargs):
assert (
cfg.MODEL.IN_CHANNELS == 1
Expand Down
1 change: 1 addition & 0 deletions cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Patch deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = patch_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)

return model
58 changes: 58 additions & 0 deletions cv_lib/cv_lib/segmentation/models/resnet_unet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

logger = logging.getLogger(__name__)


class FPAv2(nn.Module):
def __init__(self, input_dim, output_dim):
Expand Down Expand Up @@ -208,6 +213,21 @@ def forward(self, x):

return logit

def init_weights(
self, pretrained="",
):
# skip weight initialization - leave at default values

if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info("=> loading pretrained model {}".format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
for k, _ in pretrained_dict.items():
logger.info("=> loading {} pretrained model {}".format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)


# stage2 model
class Res34Unetv3(nn.Module):
Expand Down Expand Up @@ -299,6 +319,24 @@ def forward(self, x):

return logit, logit_pixel, logit_image.view(-1)

def init_weights(
self, pretrained="",
):
# skip weight initialization - leave at default values

if pretrained and not os.path.isfile(pretrained):
raise FileNotFoundError(f"The file {pretrained} was not found. Please supply correct path or leave empty")

if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info("=> loading pretrained model {}".format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
for k, _ in pretrained_dict.items():
logger.info("=> loading {} pretrained model {}".format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)


# stage3 model
class Res34Unetv5(nn.Module):
Expand Down Expand Up @@ -356,10 +394,30 @@ def forward(self, x):

return logit

def init_weights(
self, pretrained="",
):
# skip weight initialization - leave at default values

if pretrained and not os.path.isfile(pretrained):
raise FileNotFoundError(f"The file {pretrained} was not found. Please supply correct path or leave empty")

if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info("=> loading pretrained model {}".format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
for k, _ in pretrained_dict.items():
logger.info("=> loading {} pretrained model {}".format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)


def get_seg_model(cfg, **kwargs):
assert (
cfg.MODEL.IN_CHANNELS == 3
), f"SEResnet Unet deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 3 for cfg.MODEL.IN_CHANNELS"
model = Res34Unetv4(n_classes=cfg.DATASET.NUM_CLASSES)
if "PRETRAINED" in cfg.MODEL.keys():
model.init_weights(cfg.MODEL.PRETRAINED)
return model
1 change: 1 addition & 0 deletions cv_lib/cv_lib/segmentation/models/section_deconvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = section_deconvnet(n_classes=cfg.DATASET.NUM_CLASSES)

return model
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
cfg.MODEL.IN_CHANNELS == 1
), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
model = section_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)

return model
4 changes: 2 additions & 2 deletions cv_lib/cv_lib/segmentation/models/seg_hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,6 @@ def init_weights(

def get_seg_model(cfg, **kwargs):
model = HighResolutionNet(cfg, **kwargs)
model.init_weights(cfg.MODEL.PRETRAINED)

if "PRETRAINED" in cfg.MODEL.keys():
model.init_weights(cfg.MODEL.PRETRAINED)
return model
1 change: 1 addition & 0 deletions cv_lib/cv_lib/segmentation/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,5 @@ def forward(self, x):

def get_seg_model(cfg, **kwargs):
model = UNet(cfg.MODEL.IN_CHANNELS, cfg.DATASET.NUM_CLASSES)

return model
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"source": [
"# load an existing experiment configuration file\n",
"CONFIG_FILE = (\n",
" \"../../../experiments/interpretation/dutchf3_patch/local/configs/hrnet.yaml\"\n",
" \"../../../experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml\"\n",
")\n",
"# number of images to score\n",
"N_EVALUATE = 20\n",
Expand Down Expand Up @@ -239,7 +239,7 @@
"max_snapshots = config.TRAIN.SNAPSHOTS\n",
"papermill = False\n",
"dataset_root = config.DATASET.ROOT\n",
"model_pretrained = config.MODEL.PRETRAINED"
"model_pretrained = config.MODEL.PRETRAINED if \"PRETRAINED\" in config.MODEL.keys() else None"
]
},
{
Expand Down Expand Up @@ -859,9 +859,17 @@
"outputs": [],
"source": [
"# use the model which we just fine-tuned\n",
"opts = [\"TEST.MODEL_PATH\", path.join(output_dir, f\"model_f3_nb_seg_hrnet_{train_len}.pth\")]\n",
"if \"hrnet\" in config.MODEL.NAME:\n",
" model_snapshot_name = f\"model_f3_nb_seg_hrnet_{train_len}.pth\"\n",
"elif \"resnet\" in config.MODEL.NAME: \n",
" model_snapshot_name = f\"model_f3_nb_resnet_unet_{train_len}.pth\"\n",
"else:\n",
" raise NotImplementedError(\"We don't support testing this model in this notebook yet\")\n",
" \n",
"opts = [\"TEST.MODEL_PATH\", path.join(output_dir, model_snapshot_name)]\n",
"# uncomment the line below to use the pre-trained model instead\n",
"# opts = [\"TEST.MODEL_PATH\", config.MODEL.PRETRAINED]\n",
"\n",
"config.merge_from_list(opts)"
]
},
Expand Down
27 changes: 17 additions & 10 deletions examples/interpretation/notebooks/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,23 +328,23 @@ def download_pretrained_model(config):
raise NameError(
"Unknown dataset name. Only dutch f3 and penobscot are currently supported."
)

if "hrnet" in config.MODEL.NAME:
model = "hrnet"
elif "deconvnet" in config.MODEL.NAME:
model = "deconvnet"
elif "unet" in config.MODEL.NAME:
model = "unet"
elif "resnet" in config.MODEL.NAME:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be "seresnet" or "resnet_unet" ?
we might have other models that have a resent backbone

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disabled unet as we don't have / provide a pre-trained model with it

model = "seresnetunet"
else:
raise NameError(
"Unknown model name. Only hrnet, deconvnet, and unet are currently supported."
"Unknown model name. Only hrnet, deconvnet, and seresnet_unet are currently supported."
)

# check if the user already supplied a URL, otherwise figure out the URL
if validators.url(config.MODEL.PRETRAINED):
if "PRETRAINED" in config.MODEL.keys() and validators.url(config.MODEL.PRETRAINED):
url = config.MODEL.PRETRAINED
print(f"Will use user-supplied URL of '{url}'")
elif os.path.isfile(config.MODEL.PRETRAINED):
elif "PRETRAINED" in config.MODEL.keys() and os.path.isfile(config.MODEL.PRETRAINED):
url = None
print(f"Will use user-supplied file on local disk of '{config.MODEL.PRETRAINED}'")
else:
Expand All @@ -371,19 +371,20 @@ def download_pretrained_model(config):
and config.TRAIN.DEPTH == "none"
):
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnetskip_patch_no_depth.pth"

elif (
model == "deconvnet"
and "skip" not in config.MODEL.NAME
and config.TRAIN.DEPTH == "none"
):
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
elif model == "unet" and config.TRAIN.DEPTH == "section":
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
elif model == "seresnetunet" and config.TRAIN.DEPTH == "section":
url = "https://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
else:
raise NotImplementedError(
"We don't store a pretrained model for Dutch F3 for this model combination yet."
)


else:
raise NotImplementedError(
"We don't store a pretrained model for this dataset/model combination yet."
Expand Down Expand Up @@ -424,6 +425,11 @@ def download_pretrained_model(config):
# Update config MODEL.PRETRAINED
# TODO: Only HRNet uses a pretrained model currently.
# issue https://github.com/microsoft/seismic-deeplearning/issues/267

# now that we have a pre-trained model, we can set it
if "PRETRAINED" not in config.MODEL.keys():
config.MODEL["PRETRAINED"] = "dummy"

opts = [
Comment on lines 426 to 433
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it true that only HRnet has a pretrained model?
Also, is "dummy" here like a placeholder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a key placeholder because we need to add a PRETRAINED key - we check the path later in the code so this can literally be anything

"MODEL.PRETRAINED",
pretrained_model_path,
Expand All @@ -432,6 +438,7 @@ def download_pretrained_model(config):
"TEST.MODEL_PATH",
pretrained_model_path,
]

config.merge_from_list(opts)

return config
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ WORKERS: 4
PRINT_FREQ: 10
LOG_CONFIG: logging.conf
SEED: 2019
OPENCV_BORDER_CONSTANT: 0


DATASET:
NUM_CLASSES: 6
ROOT: /home/username/data/dutch/data
ROOT: "/home/username/data/dutch/data"
CLASS_WEIGHTS: [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852]
MIN: -1
MAX: 1
Expand Down