diff --git a/research/lamp-automated-model-parallelism/README.md b/research/lamp-automated-model-parallelism/README.md new file mode 100644 index 0000000000..321d7a2cdf --- /dev/null +++ b/research/lamp-automated-model-parallelism/README.md @@ -0,0 +1,53 @@ +# LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation + +
+ +
+ + +> If you use this work in your research, please cite the paper. + +A reimplementation of the LAMP system originally proposed by: + +Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) +"LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." +MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575) + + +## To run the demo: + +### Prerequisites +- install the latest version of MONAI: `git clone https://github.com/Project-MONAI/MONAI` and `pip install -e .` +- `pip install torchgpipe` + +### Data +```bash +mkdir ./data; +cd ./data; +``` +Head and Neck CT dataset + +Please download and unzip the images into `./data` folder. + +- `HaN.zip`: https://drive.google.com/file/d/1A2zpVlR3CkvtkJPvtAF3-MH0nr1WZ2Mn/view?usp=sharing +```bash +unzip HaN.zip; # unzip +``` + +Please find more details of the dataset at https://github.com/wentaozhu/AnatomyNet-for-anatomical-segmentation.git + + +### Minimal hardware requirements for full image training +- U-Net (`n_feat=32`): 2x 16Gb GPUs +- U-Net (`n_feat=64`): 4x 16Gb GPUs +- U-Net (`n_feat=128`): 2x 32Gb GPUs + + +### Commands +The number of features in the first block (`--n_feat`) can be 32, 64, or 128. +```bash +mkdir ./log; +python train.py --n_feat=128 --crop_size='64,64,64' --bs=16 --ep=4800 --lr=0.001 > ./log/YOURLOG.log +python train.py --n_feat=128 --crop_size='128,128,128' --bs=4 --ep=1200 --lr=0.001 --pretrain='./HaN_32_16_1200_64,64,64_0.001_*' > ./log/YOURLOG.log +python train.py --n_feat=128 --crop_size='-1,-1,-1' --bs=1 --ep=300 --lr=0.001 --pretrain='./HaN_32_16_1200_64,64,64_0.001_*' > ./log/YOURLOG.log +``` diff --git a/research/lamp-automated-model-parallelism/__init__.py b/research/lamp-automated-model-parallelism/__init__.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/research/lamp-automated-model-parallelism/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/research/lamp-automated-model-parallelism/data_utils.py b/research/lamp-automated-model-parallelism/data_utils.py new file mode 100644 index 0000000000..b4825c1910 --- /dev/null +++ b/research/lamp-automated-model-parallelism/data_utils.py @@ -0,0 +1,66 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from monai.transforms import DivisiblePad + +STRUCTURES = ( + "BrainStem", + "Chiasm", + "Mandible", + "OpticNerve_L", + "OpticNerve_R", + "Parotid_L", + "Parotid_R", + "Submandibular_L", + "Submandibular_R", +) + + +def get_filenames(path, maskname=STRUCTURES): + """ + create file names according to the predefined folder structure. + + Args: + path: data folder name + maskname: target structure names + """ + maskfiles = [] + for seg in maskname: + if os.path.exists(os.path.join(path, "./structures/" + seg + "_crp_v2.npy")): + maskfiles.append(os.path.join(path, "./structures/" + seg + "_crp_v2.npy")) + else: + # the corresponding mask is missing seg, path.split("/")[-1] + maskfiles.append(None) + return os.path.join(path, "img_crp_v2.npy"), maskfiles + + +def load_data_and_mask(data, mask_data): + """ + Load data filename and mask_data (list of file names) + into a dictionary of {'image': array, "label": list of arrays, "name": str}. + """ + pad_xform = DivisiblePad(k=32) + img = np.load(data) # z y x + img = pad_xform(img[None])[0] + item = dict(image=img, label=[]) + for idx, maskfnm in enumerate(mask_data): + if maskfnm is None: + ms = np.zeros(img.shape, np.uint8) + else: + ms = np.load(maskfnm).astype(np.uint8) + assert ms.min() == 0 and ms.max() == 1 + mask = pad_xform(ms[None])[0] + item["label"].append(mask) + assert len(item["label"]) == 9 + item["name"] = str(data) + return item diff --git a/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png b/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png new file mode 100644 index 0000000000..f8a8254832 Binary files /dev/null and b/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png differ diff --git a/research/lamp-automated-model-parallelism/test_unet_pipe.py b/research/lamp-automated-model-parallelism/test_unet_pipe.py new file mode 100644 index 0000000000..6783996480 --- /dev/null +++ b/research/lamp-automated-model-parallelism/test_unet_pipe.py @@ -0,0 +1,52 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from unet_pipe import UNetPipe + +TEST_CASES = [ + [ # 1-channel 3D, batch 12 + {"spatial_dims": 3, "out_channels": 2, "in_channels": 1, "depth": 3, "n_feat": 8}, + torch.randn(12, 1, 32, 64, 48), + (12, 2, 32, 64, 48), + ], + [ # 1-channel 3D, batch 16 + {"spatial_dims": 3, "out_channels": 2, "in_channels": 1, "depth": 3}, + torch.randn(16, 1, 32, 64, 48), + (16, 2, 32, 64, 48), + ], + [ # 4-channel 3D, batch 16, batch normalisation + {"spatial_dims": 3, "out_channels": 3, "in_channels": 2}, + torch.randn(16, 2, 64, 64, 64), + (16, 3, 64, 64, 64), + ], +] + + +class TestUNETPipe(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_shape): + net = UNetPipe(**input_param) + if torch.cuda.is_available(): + net = net.to(torch.device("cuda")) + input_data = input_data.to(torch.device("cuda")) + net.eval() + with torch.no_grad(): + result = net.forward(input_data.float()) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/research/lamp-automated-model-parallelism/train.py b/research/lamp-automated-model-parallelism/train.py new file mode 100644 index 0000000000..1f6f578591 --- /dev/null +++ b/research/lamp-automated-model-parallelism/train.py @@ -0,0 +1,242 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from argparse import ArgumentParser +import os + +import numpy as np +import torch +from monai.transforms import AddChannelDict, Compose, RandCropByPosNegLabeld, Rand3DElasticd, SpatialPadd +from monai.losses import DiceLoss, FocalLoss +from monai.metrics import compute_meandice +from monai.data import Dataset, list_data_collate +from monai.utils import first +from torchgpipe import GPipe +from torchgpipe.balance import balance_by_size + +from unet_pipe import UNetPipe, flatten_sequential +from data_utils import get_filenames, load_data_and_mask + +N_CLASSES = 10 +TRAIN_PATH = "./data/HaN/train/" # training data folder +VAL_PATH = "./data/HaN/test/" # validation data folder + +torch.backends.cudnn.enabled = True + + +class ImageLabelDataset: + """ + Load image and multi-class labels based on the predefined folder structure. + """ + + def __init__(self, path, n_class=10): + self.path = path + self.data = sorted(os.listdir(path)) + self.n_class = n_class + + def __getitem__(self, index): + data = os.path.join(self.path, self.data[index]) + train_data, train_masks_data = get_filenames(data) + data = load_data_and_mask(train_data, train_masks_data) # read into a data dict + # loading image + data["image"] = data["image"].astype(np.float32) # shape (H W D) + # loading labels + class_shape = (1,) + data["image"].shape + mask0 = np.zeros(class_shape) + mask_list = [] + flagvect = np.ones((self.n_class,), np.float32) + for i, mask in enumerate(data["label"]): + if mask is None: + mask = np.zeros(class_shape) + flagvect[0] = 0 + flagvect[i + 1] = 0 + mask0 = np.logical_or(mask0, mask) + mask_list.append(mask.reshape(class_shape)) + mask0 = 1 - mask0 + data["label"] = np.concatenate([mask0] + mask_list, axis=0).astype(np.uint8) # shape (C H W D) + # setting flags + data["with_complete_groundtruth"] = flagvect # flagvec is a boolean indicator for complete annotation + return data + + def __len__(self): + return len(self.data) + + +def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None): + model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_" + print(f"save the best model as '{model_name}' during training.") + + crop_size = [int(cz) for cz in crop_size.split(",")] + print(f"input image crop_size: {crop_size}") + + # starting training set loader + train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES) + if np.any([cz == -1 for cz in crop_size]): # using full image + train_transform = Compose( + [ + AddChannelDict(keys="image"), + Rand3DElasticd( + keys=("image", "label"), + spatial_size=crop_size, + sigma_range=(10, 50), # 30 + magnitude_range=[600, 1200], # 1000 + prob=0.8, + rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12), + shear_range=(np.pi / 18, np.pi / 18, np.pi / 18), + translate_range=(sz * 0.05 for sz in crop_size), + scale_range=(0.2, 0.2, 0.2), + mode=("bilinear", "nearest"), + padding_mode=("border", "zeros"), + ), + ] + ) + train_dataset = Dataset(train_images, transform=train_transform) + # when bs > 1, the loader assumes that the full image sizes are the same across the dataset + train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True) + else: + # draw balanced foreground/background window samples according to the ground truth label + train_transform = Compose( + [ + AddChannelDict(keys="image"), + SpatialPadd(keys=("image", "label"), spatial_size=crop_size), # ensure image size >= crop_size + RandCropByPosNegLabeld( + keys=("image", "label"), label_key="label", spatial_size=crop_size, num_samples=bs + ), + Rand3DElasticd( + keys=("image", "label"), + spatial_size=crop_size, + sigma_range=(10, 50), # 30 + magnitude_range=[600, 1200], # 1000 + prob=0.8, + rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12), + shear_range=(np.pi / 18, np.pi / 18, np.pi / 18), + translate_range=(sz * 0.05 for sz in crop_size), + scale_range=(0.2, 0.2, 0.2), + mode=("bilinear", "nearest"), + padding_mode=("border", "zeros"), + ), + ] + ) + train_dataset = Dataset(train_images, transform=train_transform) # each dataset item is a list of windows + train_dataloader = torch.utils.data.DataLoader( # stack each dataset item into a single tensor + train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate + ) + first_sample = first(train_dataloader) + print(first_sample["image"].shape) + + # starting validation set loader + val_transform = Compose([AddChannelDict(keys="image")]) + val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform) + val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1) + print(val_dataset[0]["image"].shape) + print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}") + + model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat) + model = flatten_sequential(model) + lossweight = torch.from_numpy(np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32)) + + if optimizer.lower() == "rmsprop": + optimizer = torch.optim.RMSprop(model.parameters(), lr=lr) # lr = 5e-4 + elif optimizer.lower() == "momentum": + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) # lr = 1e-4 for finetuning + else: + raise ValueError(f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum').") + + # config GPipe + x = first_sample["image"].float() + x = torch.autograd.Variable(x.cuda()) + partitions = torch.cuda.device_count() + print(f"partition: {partitions}, input: {x.size()}") + balance = balance_by_size(partitions, model, x) + model = GPipe(model, balance, chunks=4, checkpoint="always") + + # config loss functions + dice_loss_func = DiceLoss(softmax=True, reduction="none") + # use the same pipeline and loss in + # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy, + # Medical Physics, 2018. + focal_loss_func = FocalLoss(reduction="none") + + if pretrain: + print(f"loading from {pretrain}.") + pretrained_dict = torch.load(pretrain)["weight"] + model_dict = model.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + model_dict.update(pretrained_dict) + model.load_state_dict(pretrained_dict) + + b_time = time.time() + best_val_loss = [0] * (N_CLASSES - 1) # foreground + best_ave = -1 + for epoch in range(ep): + model.train() + trainloss = 0 + for b_idx, data_dict in enumerate(train_dataloader): + x_train = data_dict["image"] + y_train = data_dict["label"] + flagvec = data_dict["with_complete_groundtruth"] + + x_train = torch.autograd.Variable(x_train.cuda()) + y_train = torch.autograd.Variable(y_train.cuda().float()) + optimizer.zero_grad() + o = model(x_train).to(0, non_blocking=True).float() + + loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean() + loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean() + loss.backward() + optimizer.step() + trainloss += loss.item() + + if b_idx % 20 == 0: + print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}") + print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}") + + if epoch % 10 == 0: + model.eval() + # check validation dice + val_loss = [0] * (N_CLASSES - 1) + n_val = [0] * (N_CLASSES - 1) + for data_dict in val_dataloader: + x_val = data_dict["image"] + y_val = data_dict["label"] + with torch.no_grad(): + x_val = torch.autograd.Variable(x_val.cuda()) + o = model(x_val).to(0, non_blocking=True) + loss = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False) + val_loss = [l.item() + tl if l == l else tl for l, tl in zip(loss[0], val_loss)] + n_val = [n + 1 if l == l else n for l, n in zip(loss[0], n_val)] + val_loss = [l / n for l, n in zip(val_loss, n_val)] + print("validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(val_loss)) + for c in range(1, 10): + if best_val_loss[c - 1] < val_loss[c - 1]: + best_val_loss[c - 1] = val_loss[c - 1] + state = {"epoch": epoch, "weight": model.state_dict(), "score_" + str(c): best_val_loss[c - 1]} + torch.save(state, f"{model_name}" + str(c)) + print("best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(best_val_loss)) + + print("total time", time.time() - b_time) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--n_feat", type=int, default=32, dest="n_feat") + parser.add_argument("--crop_size", type=str, default="-1,-1,-1", dest="crop_size") + parser.add_argument("--bs", type=int, default=1, dest="bs") # batch size + parser.add_argument("--ep", type=int, default=150, dest="ep") # number of epochs + parser.add_argument("--lr", type=float, default=5e-4, dest="lr") # learning rate + parser.add_argument("--optimizer", type=str, default="rmsprop", dest="optimizer") # type of optimizer + parser.add_argument("--pretrain", type=str, default=None, dest="pretrain") + args = parser.parse_args() + + input_dict = vars(args) + print(input_dict) + train(**input_dict) diff --git a/research/lamp-automated-model-parallelism/unet_pipe.py b/research/lamp-automated-model-parallelism/unet_pipe.py new file mode 100644 index 0000000000..d563de8257 --- /dev/null +++ b/research/lamp-automated-model-parallelism/unet_pipe.py @@ -0,0 +1,171 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import List + +import torch +from monai.networks.blocks import Convolution, UpSample +from monai.networks.layers.factories import Act, Conv, Norm +from torch import nn +from torchgpipe.skip import Namespace, pop, skippable, stash + + +@skippable(stash=["skip"], pop=[]) +class Stash(nn.Module): + def forward(self, input: torch.Tensor): + yield stash("skip", input) + return input # noqa using yield together with return + + +@skippable(stash=[], pop=["skip"]) +class PopCat(nn.Module): + def forward(self, input: torch.Tensor): + skip = yield pop("skip") + if skip is not None: + input = torch.cat([input, skip], dim=1) + return input + + +def flatten_sequential(module: nn.Sequential): + """ + Recursively make all the submodules sequential. + + Args: + module: a torch sequential model. + """ + if not isinstance(module, nn.Sequential): + raise TypeError("module must be a nn.Sequential instance.") + + def _flatten(module): + for name, child in module.named_children(): + if isinstance(child, nn.Sequential): + for sub_name, sub_child in _flatten(child): + yield f"{name}_{sub_name}", sub_child + else: + yield name, child + + return nn.Sequential(OrderedDict(_flatten(module))) + + +class DoubleConv(nn.Module): + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + stride=2, + act_1=Act.LEAKYRELU, + norm_1=Norm.BATCH, + act_2=Act.LEAKYRELU, + norm_2=Norm.BATCH, + conv_only=True, + ): + """ + A sequence of Conv_1 + Norm_1 + Act_1 + Conv_2 (+ Norm_2 + Act_2). + + `norm_2` and `act_2` are ignored when `conv_only` is True. + `stride` is for `Conv_1`, typically stride=2 for 2x spatial downsampling. + + Args: + spatial_dims: number of the input spatial dimension. + in_channels: number of input channels. + out_channels: number of output channels. + stride: stride of the first conv., mainly used for 2x downsampling when stride=2. + act_1: activation type of the first convolution. + norm_1: normalization type of the first convolution. + act_2: activation type of the second convolution. + norm_2: normalization type of the second convolution. + conv_only: whether the second conv is convolution layer only. Default to True, + indicates that `act_2` and `norm_2` are not in use. + """ + super(DoubleConv, self).__init__() + self.conv = nn.Sequential( + Convolution(spatial_dims, in_channels, out_channels, strides=stride, act=act_1, norm=norm_1, bias=False,), + Convolution(spatial_dims, out_channels, out_channels, act=act_2, norm=norm_2, conv_only=conv_only), + ) + + def forward(self, x): + return self.conv(x) + + +class UNetPipe(nn.Sequential): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, n_feat: int = 32, depth: int = 4): + """ + A UNet-like architecture for model parallelism. + + Args: + spatial_dims: number of input spatial dimensions, + 2 for (B, in_channels, H, W), 3 for (B, in_channels, H, W, D). + in_channels: number of input channels. + out_channels: number of output channels. + n_feat: number of features in the first convolution. + depth: number of downsampling stages. + """ + super(UNetPipe, self).__init__() + n_enc_filter: List[int] = [n_feat] + for i in range(1, depth + 1): + n_enc_filter.append(min(n_enc_filter[-1] * 2, 1024)) + namespaces = [Namespace() for _ in range(depth)] + + # construct the encoder + encoder_layers: List[nn.Module] = [] + init_conv = Convolution( + spatial_dims, in_channels, n_enc_filter[0], strides=2, act=Act.LEAKYRELU, norm=Norm.BATCH, bias=False, + ) + encoder_layers.append( + nn.Sequential(OrderedDict([("Conv", init_conv,), ("skip", Stash().isolate(namespaces[0]))])) + ) + for i in range(1, depth + 1): + down_conv = DoubleConv(spatial_dims, n_enc_filter[i - 1], n_enc_filter[i]) + if i == depth: + layer_dict = OrderedDict([("Down", down_conv)]) + else: + layer_dict = OrderedDict([("Down", down_conv), ("skip", Stash().isolate(namespaces[i]))]) + encoder_layers.append(nn.Sequential(layer_dict)) + encoder = nn.Sequential(*encoder_layers) + + # construct the decoder + decoder_layers: List[nn.Module] = [] + for i in reversed(range(1, depth + 1)): + in_ch, out_ch = n_enc_filter[i], n_enc_filter[i - 1] + layer_dict = OrderedDict( + [ + ("Up", UpSample(spatial_dims, in_ch, out_ch, 2, True)), + ("skip", PopCat().isolate(namespaces[i - 1])), + ("Conv1x1x1", Conv[Conv.CONV, spatial_dims](out_ch * 2, in_ch, kernel_size=1)), + ("Conv", DoubleConv(spatial_dims, in_ch, out_ch, stride=1, conv_only=True)), + ] + ) + decoder_layers.append(nn.Sequential(layer_dict)) + in_ch = min(n_enc_filter[0] // 2, 32) + layer_dict = OrderedDict( + [ + ("Up", UpSample(spatial_dims, n_feat, in_ch, 2, True)), + ("RELU", Act[Act.LEAKYRELU](inplace=False)), + ("out", Conv[Conv.CONV, spatial_dims](in_ch, out_channels, kernel_size=3, padding=1),), + ] + ) + decoder_layers.append(nn.Sequential(layer_dict)) + decoder = nn.Sequential(*decoder_layers) + + # making a sequential model + self.add_module("encoder", encoder) + self.add_module("decoder", decoder) + + for m in self.modules(): + if isinstance(m, Conv[Conv.CONV, spatial_dims]): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, Norm[Norm.BATCH, spatial_dims]): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, Conv[Conv.CONVTRANS, spatial_dims]): + nn.init.kaiming_normal_(m.weight) diff --git a/setup.py b/setup.py index 83372856ea..5158fa1fb9 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), - packages=find_packages(exclude=("docs", "examples", "tests")), + packages=find_packages(exclude=("docs", "examples", "tests", "research")), zip_safe=False, package_data={"monai": ["py.typed"]}, )