diff --git a/cv_lib/cv_lib/segmentation/models/patch_deconvnet.py b/cv_lib/cv_lib/segmentation/models/patch_deconvnet.py
index 4ee1ed59..522dd8ce 100644
--- a/cv_lib/cv_lib/segmentation/models/patch_deconvnet.py
+++ b/cv_lib/cv_lib/segmentation/models/patch_deconvnet.py
@@ -298,7 +298,6 @@ def init_vgg16_params(self, vgg16, copy_fc8=True):
                 l2.bias.data = l1.bias.data
                 i_layer = i_layer + 1
 
-
 def get_seg_model(cfg, **kwargs):
     assert (
         cfg.MODEL.IN_CHANNELS == 1
diff --git a/cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py b/cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py
index d5506b84..223cf74f 100644
--- a/cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py
+++ b/cv_lib/cv_lib/segmentation/models/patch_deconvnet_skip.py
@@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
         cfg.MODEL.IN_CHANNELS == 1
     ), f"Patch deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
     model = patch_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)
+
     return model
diff --git a/cv_lib/cv_lib/segmentation/models/resnet_unet.py b/cv_lib/cv_lib/segmentation/models/resnet_unet.py
index 05badb64..5ec7e444 100644
--- a/cv_lib/cv_lib/segmentation/models/resnet_unet.py
+++ b/cv_lib/cv_lib/segmentation/models/resnet_unet.py
@@ -1,11 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+import logging
+import os
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 
+logger = logging.getLogger(__name__)
+
 
 class FPAv2(nn.Module):
     def __init__(self, input_dim, output_dim):
@@ -208,6 +213,21 @@ def forward(self, x):
 
         return logit
 
+    def init_weights(
+        self, pretrained="",
+    ):
+        # skip weight initialization - leave at default values
+
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained)
+            logger.info("=> loading pretrained model {}".format(pretrained))
+            model_dict = self.state_dict()
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
+            for k, _ in pretrained_dict.items():
+                logger.info("=> loading {} pretrained model {}".format(k, pretrained))
+            model_dict.update(pretrained_dict)
+            self.load_state_dict(model_dict)
+
 
 # stage2 model
 class Res34Unetv3(nn.Module):
@@ -299,6 +319,24 @@ def forward(self, x):
 
         return logit, logit_pixel, logit_image.view(-1)
 
+    def init_weights(
+        self, pretrained="",
+    ):
+        # skip weight initialization - leave at default values
+
+        if pretrained and not os.path.isfile(pretrained):
+            raise FileNotFoundError(f"The file {pretrained} was not found. Please supply correct path or leave empty")
+
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained)
+            logger.info("=> loading pretrained model {}".format(pretrained))
+            model_dict = self.state_dict()
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
+            for k, _ in pretrained_dict.items():
+                logger.info("=> loading {} pretrained model {}".format(k, pretrained))
+            model_dict.update(pretrained_dict)
+            self.load_state_dict(model_dict)
+
 
 # stage3 model
 class Res34Unetv5(nn.Module):
@@ -356,10 +394,30 @@ def forward(self, x):
 
         return logit
 
+    def init_weights(
+        self, pretrained="",
+    ):
+        # skip weight initialization - leave at default values
+
+        if pretrained and not os.path.isfile(pretrained):
+            raise FileNotFoundError(f"The file {pretrained} was not found. Please supply correct path or leave empty")
+
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained)
+            logger.info("=> loading pretrained model {}".format(pretrained))
+            model_dict = self.state_dict()
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
+            for k, _ in pretrained_dict.items():
+                logger.info("=> loading {} pretrained model {}".format(k, pretrained))
+            model_dict.update(pretrained_dict)
+            self.load_state_dict(model_dict)
+
 
 def get_seg_model(cfg, **kwargs):
     assert (
         cfg.MODEL.IN_CHANNELS == 3
     ), f"SEResnet Unet deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 3 for cfg.MODEL.IN_CHANNELS"
     model = Res34Unetv4(n_classes=cfg.DATASET.NUM_CLASSES)
+    if "PRETRAINED" in cfg.MODEL.keys():
+        model.init_weights(cfg.MODEL.PRETRAINED)
     return model
diff --git a/cv_lib/cv_lib/segmentation/models/section_deconvnet.py b/cv_lib/cv_lib/segmentation/models/section_deconvnet.py
index 7234b1ee..20d583b5 100644
--- a/cv_lib/cv_lib/segmentation/models/section_deconvnet.py
+++ b/cv_lib/cv_lib/segmentation/models/section_deconvnet.py
@@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
         cfg.MODEL.IN_CHANNELS == 1
     ), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
     model = section_deconvnet(n_classes=cfg.DATASET.NUM_CLASSES)
+
     return model
diff --git a/cv_lib/cv_lib/segmentation/models/section_deconvnet_skip.py b/cv_lib/cv_lib/segmentation/models/section_deconvnet_skip.py
index cb8b2ecb..fd172d2a 100644
--- a/cv_lib/cv_lib/segmentation/models/section_deconvnet_skip.py
+++ b/cv_lib/cv_lib/segmentation/models/section_deconvnet_skip.py
@@ -304,4 +304,5 @@ def get_seg_model(cfg, **kwargs):
         cfg.MODEL.IN_CHANNELS == 1
     ), f"Section deconvnet is not implemented to accept {cfg.MODEL.IN_CHANNELS} channels. Please only pass 1 for cfg.MODEL.IN_CHANNELS"
     model = section_deconvnet_skip(n_classes=cfg.DATASET.NUM_CLASSES)
+
     return model
diff --git a/cv_lib/cv_lib/segmentation/models/seg_hrnet.py b/cv_lib/cv_lib/segmentation/models/seg_hrnet.py
index 6671603f..40b4c6b6 100644
--- a/cv_lib/cv_lib/segmentation/models/seg_hrnet.py
+++ b/cv_lib/cv_lib/segmentation/models/seg_hrnet.py
@@ -445,6 +445,6 @@ def init_weights(
 
 def get_seg_model(cfg, **kwargs):
     model = HighResolutionNet(cfg, **kwargs)
-    model.init_weights(cfg.MODEL.PRETRAINED)
-
+    if "PRETRAINED" in cfg.MODEL.keys():
+        model.init_weights(cfg.MODEL.PRETRAINED)
     return model
diff --git a/cv_lib/cv_lib/segmentation/models/unet.py b/cv_lib/cv_lib/segmentation/models/unet.py
index c6ae6813..6eea78d7 100644
--- a/cv_lib/cv_lib/segmentation/models/unet.py
+++ b/cv_lib/cv_lib/segmentation/models/unet.py
@@ -113,4 +113,5 @@ def forward(self, x):
 
 def get_seg_model(cfg, **kwargs):
     model = UNet(cfg.MODEL.IN_CHANNELS, cfg.DATASET.NUM_CLASSES)
+
     return model
diff --git a/examples/interpretation/notebooks/Dutch_F3_patch_model_training_and_evaluation.ipynb b/examples/interpretation/notebooks/Dutch_F3_patch_model_training_and_evaluation.ipynb
index a1fdd642..4d42ca7b 100644
--- a/examples/interpretation/notebooks/Dutch_F3_patch_model_training_and_evaluation.ipynb
+++ b/examples/interpretation/notebooks/Dutch_F3_patch_model_training_and_evaluation.ipynb
@@ -59,7 +59,7 @@
    "source": [
     "# load an existing experiment configuration file\n",
     "CONFIG_FILE = (\n",
-    "    \"../../../experiments/interpretation/dutchf3_patch/local/configs/hrnet.yaml\"\n",
+    "    \"../../../experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml\"\n",
     ")\n",
     "# number of images to score\n",
     "N_EVALUATE = 20\n",
@@ -239,7 +239,7 @@
     "max_snapshots = config.TRAIN.SNAPSHOTS\n",
     "papermill = False\n",
     "dataset_root = config.DATASET.ROOT\n",
-    "model_pretrained = config.MODEL.PRETRAINED"
+    "model_pretrained = config.MODEL.PRETRAINED if \"PRETRAINED\" in config.MODEL.keys() else None"
    ]
   },
   {
@@ -859,9 +859,17 @@
    "outputs": [],
    "source": [
     "# use the model which we just fine-tuned\n",
-    "opts = [\"TEST.MODEL_PATH\", path.join(output_dir, f\"model_f3_nb_seg_hrnet_{train_len}.pth\")]\n",
+    "if \"hrnet\" in config.MODEL.NAME:\n",
+    "    model_snapshot_name = f\"model_f3_nb_seg_hrnet_{train_len}.pth\"\n",
+    "elif \"resnet\" in config.MODEL.NAME: \n",
+    "    model_snapshot_name = f\"model_f3_nb_resnet_unet_{train_len}.pth\"\n",
+    "else:\n",
+    "    raise NotImplementedError(\"We don't support testing this model in this notebook yet\")\n",
+    "    \n",
+    "opts = [\"TEST.MODEL_PATH\", path.join(output_dir, model_snapshot_name)]\n",
     "# uncomment the line below to use the pre-trained model instead\n",
     "# opts = [\"TEST.MODEL_PATH\", config.MODEL.PRETRAINED]\n",
+    "\n",
     "config.merge_from_list(opts)"
    ]
   },
diff --git a/examples/interpretation/notebooks/utilities.py b/examples/interpretation/notebooks/utilities.py
index f0d3b9e3..068200bd 100644
--- a/examples/interpretation/notebooks/utilities.py
+++ b/examples/interpretation/notebooks/utilities.py
@@ -328,23 +328,23 @@ def download_pretrained_model(config):
         raise NameError(
             "Unknown dataset name. Only dutch f3 and penobscot are currently supported."
         )
-    
+
     if "hrnet" in config.MODEL.NAME:
         model = "hrnet"
     elif "deconvnet" in config.MODEL.NAME:
         model = "deconvnet"
-    elif "unet" in config.MODEL.NAME:
-        model = "unet"
+    elif "resnet" in config.MODEL.NAME:
+        model = "seresnetunet"
     else:
         raise NameError(
-            "Unknown model name. Only hrnet, deconvnet, and unet are currently supported."
+            "Unknown model name. Only hrnet, deconvnet, and seresnet_unet are currently supported."
        )
 
     # check if the user already supplied a URL, otherwise figure out the URL
-    if validators.url(config.MODEL.PRETRAINED):
+    if "PRETRAINED" in config.MODEL.keys() and validators.url(config.MODEL.PRETRAINED):
         url = config.MODEL.PRETRAINED
         print(f"Will use user-supplied URL of '{url}'")
-    elif os.path.isfile(config.MODEL.PRETRAINED):
+    elif "PRETRAINED" in config.MODEL.keys() and os.path.isfile(config.MODEL.PRETRAINED):
         url = None
         print(f"Will use user-supplied file on local disk of '{config.MODEL.PRETRAINED}'")
     else:
@@ -371,19 +371,20 @@ def download_pretrained_model(config):
             and config.TRAIN.DEPTH == "none"
         ):
             url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnetskip_patch_no_depth.pth"
-
         elif (
             model == "deconvnet"
             and "skip" not in config.MODEL.NAME
             and config.TRAIN.DEPTH == "none"
         ):
-            url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
-        elif model == "unet" and config.TRAIN.DEPTH == "section":
-            url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
+            url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
+        elif model == "seresnetunet" and config.TRAIN.DEPTH == "section":
+            url = "https://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
         else:
             raise NotImplementedError(
                 "We don't store a pretrained model for Dutch F3 for this model combination yet."
             )
+
+
     else:
         raise NotImplementedError(
             "We don't store a pretrained model for this dataset/model combination yet."
@@ -424,6 +425,11 @@ def download_pretrained_model(config):
     # Update config MODEL.PRETRAINED
     # TODO: Only HRNet uses a pretrained model currently.
     # issue https://github.com/microsoft/seismic-deeplearning/issues/267
+
+    # now that we have a pre-trained model, we can set it
+    if "PRETRAINED" not in config.MODEL.keys():
+        config.MODEL["PRETRAINED"] = "dummy"
+
     opts = [
         "MODEL.PRETRAINED",
         pretrained_model_path,
@@ -432,6 +438,7 @@ def download_pretrained_model(config):
         "TEST.MODEL_PATH",
         pretrained_model_path,
     ]
+
     config.merge_from_list(opts)
 
     return config
diff --git a/experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml b/experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml
index 448da775..b3ddc43d 100644
--- a/experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml
+++ b/experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml
@@ -9,11 +9,12 @@ WORKERS: 4
 PRINT_FREQ: 10
 LOG_CONFIG: logging.conf
 SEED: 2019
+OPENCV_BORDER_CONSTANT: 0
 
 
 DATASET:
   NUM_CLASSES: 6
-  ROOT: /home/username/data/dutch/data
+  ROOT: "/home/username/data/dutch/data"
   CLASS_WEIGHTS: [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852]
   MIN: -1
   MAX: 1