Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Replace HRNet with SEResNet model in the notebook #362

Merged
merged 7 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"source": [
"# load an existing experiment configuration file\n",
"CONFIG_FILE = (\n",
" \"../../../experiments/interpretation/dutchf3_patch/local/configs/hrnet.yaml\"\n",
" \"../../../experiments/interpretation/dutchf3_patch/local/configs/seresnet_unet.yaml\"\n",
")\n",
"# number of images to score\n",
"N_EVALUATE = 20\n",
Expand Down Expand Up @@ -239,7 +239,7 @@
"max_snapshots = config.TRAIN.SNAPSHOTS\n",
"papermill = False\n",
"dataset_root = config.DATASET.ROOT\n",
"model_pretrained = config.MODEL.PRETRAINED"
"model_pretrained = config.MODEL.PRETRAINED if \"PRETRAINED\" in config.MODEL.keys() else None"
]
},
{
Expand Down Expand Up @@ -859,9 +859,18 @@
"outputs": [],
"source": [
"# use the model which we just fine-tuned\n",
"opts = [\"TEST.MODEL_PATH\", path.join(output_dir, f\"model_f3_nb_seg_hrnet_{train_len}.pth\")]\n",
"if \"hrnet\" in config.MODEL.NAME:\n",
" model_snapshot_name = f\"model_f3_nb_seg_hrnet_{train_len}.pth\"\n",
"elif \"resnet\" in config.MODEL.NAME: \n",
" model_snapshot_name = f\"model_f3_nb_resnet_unet_{train_len}.pth\"\n",
"else:\n",
" raise NotImplementedError(\"We don't support testing this model in this notebook yet\")\n",
" \n",
"opts = [\"TEST.MODEL_PATH\", path.join(output_dir, model_snapshot_name)]\n",
"\n",
"# uncomment the line below to use the pre-trained model instead\n",
"# opts = [\"TEST.MODEL_PATH\", config.MODEL.PRETRAINED]\n",
"\n",
"config.merge_from_list(opts)"
]
},
Expand Down
27 changes: 17 additions & 10 deletions examples/interpretation/notebooks/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,23 +328,23 @@ def download_pretrained_model(config):
raise NameError(
"Unknown dataset name. Only dutch f3 and penobscot are currently supported."
)

if "hrnet" in config.MODEL.NAME:
model = "hrnet"
elif "deconvnet" in config.MODEL.NAME:
model = "deconvnet"
elif "unet" in config.MODEL.NAME:
model = "unet"
elif "resnet" in config.MODEL.NAME:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be "seresnet" or "resnet_unet" ?
we might have other models that have a resent backbone

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disabled unet as we don't have / provide a pre-trained model with it

model = "seresnetunet"
else:
raise NameError(
"Unknown model name. Only hrnet, deconvnet, and unet are currently supported."
"Unknown model name. Only hrnet, deconvnet, and seresnet_unet are currently supported."
)

# check if the user already supplied a URL, otherwise figure out the URL
if validators.url(config.MODEL.PRETRAINED):
if "PRETRAINED" in config.MODEL.keys() and validators.url(config.MODEL.PRETRAINED):
url = config.MODEL.PRETRAINED
print(f"Will use user-supplied URL of '{url}'")
elif os.path.isfile(config.MODEL.PRETRAINED):
elif "PRETRAINED" in config.MODEL.keys() and os.path.isfile(config.MODEL.PRETRAINED):
url = None
print(f"Will use user-supplied file on local disk of '{config.MODEL.PRETRAINED}'")
else:
Expand All @@ -371,19 +371,20 @@ def download_pretrained_model(config):
and config.TRAIN.DEPTH == "none"
):
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnetskip_patch_no_depth.pth"

elif (
model == "deconvnet"
and "skip" not in config.MODEL.NAME
and config.TRAIN.DEPTH == "none"
):
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
elif model == "unet" and config.TRAIN.DEPTH == "section":
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
url = "http://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_deconvnet_patch_no_depth.pth"
elif model == "seresnetunet" and config.TRAIN.DEPTH == "section":
url = "https://deepseismicsharedstore.blob.core.windows.net/master-public-models/dutchf3_seresnetunet_patch_section_depth.pth"
else:
raise NotImplementedError(
"We don't store a pretrained model for Dutch F3 for this model combination yet."
)


else:
raise NotImplementedError(
"We don't store a pretrained model for this dataset/model combination yet."
Expand Down Expand Up @@ -424,6 +425,11 @@ def download_pretrained_model(config):
# Update config MODEL.PRETRAINED
# TODO: Only HRNet uses a pretrained model currently.
# issue https://github.com/microsoft/seismic-deeplearning/issues/267

# now that we have a pre-trained model, we can set it
if "PRETRAINED" not in config.MODEL.keys():
config.MODEL["PRETRAINED"] = "dummy"

opts = [
Comment on lines 426 to 433
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it true that only HRnet has a pretrained model?
Also, is "dummy" here like a placeholder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a key placeholder because we need to add a PRETRAINED key - we check the path later in the code so this can literally be anything

"MODEL.PRETRAINED",
pretrained_model_path,
Expand All @@ -432,6 +438,7 @@ def download_pretrained_model(config):
"TEST.MODEL_PATH",
pretrained_model_path,
]

config.merge_from_list(opts)

return config
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ WORKERS: 4
PRINT_FREQ: 10
LOG_CONFIG: logging.conf
SEED: 2019
OPENCV_BORDER_CONSTANT: 0


DATASET:
NUM_CLASSES: 6
ROOT: /home/username/data/dutch/data
ROOT: "/home/username/data/dutch/data"
CLASS_WEIGHTS: [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852]
MIN: -1
MAX: 1
Expand Down