This repository experiments with techniques for dense, volumetric semantic segmentation. The model follows the U-net architectural style and includes a variational autoencoder (for regularization), residual blocks, spatial and channel squeeze-excitation layers, and dense connections.

The architecture is a variation of U-net with variational autoencoder regularization, plus several enhancements:
- Spatial and channel squeeze-excitation layers in the ResNet blocks.
- Dense connections between encoder ResNet blocks at the same spatial resolution level.
- Convolutional layers ordered as `[Conv3D, GroupNorm, ReLU]`, except for all pointwise and output layers (see the sketch after this list).
- He normal initialization for all layer kernels, except those with sigmoid activations, which are initialized with Glorot normal.
- Convolutional downsampling and upsampling operations.
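As a rough illustration, here is a minimal sketch of the repeated convolutional unit described above. The class name `ConvBlock` and the group count are assumptions, not the repo's, and `tf.keras.layers.GroupNormalization` only ships with recent TF releases (the pinned 2.0 alpha would need the tensorflow-addons equivalent instead):

```python
import tensorflow as tf

class ConvBlock(tf.keras.layers.Layer):
    """Conv3D -> GroupNorm -> ReLU, per the list above (hypothetical name/params)."""
    def __init__(self, filters, groups=8, **kwargs):
        super().__init__(**kwargs)
        self.conv = tf.keras.layers.Conv3D(
            filters, kernel_size=3, padding='same',
            kernel_initializer='he_normal')  # He normal init, per the list above
        self.norm = tf.keras.layers.GroupNormalization(groups=groups)
        self.relu = tf.keras.layers.ReLU()

    def call(self, x):
        return self.relu(self.norm(self.conv(x)))
```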
Dependencies are only supported for Python 3 and can be found in `requirements.txt` (`numpy==1.15` for preprocessing and `tensorflow==2.0.0-alpha0` for the model architecture, which uses `tf.keras.Model` and `tf.keras.Layer` subclassing).
The model can be found in `model/model.py` and contains an `inference` mode in addition to the `training` mode that `tf.keras.Model` supports.
- Specify `training=False, inference=True` to receive only the decoder output, as desired at test time.
- Specify `training=False, inference=False` to receive both the decoder and variational autoencoder outputs, so that loss and metrics can be computed, as desired at validation time.
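A hedged usage sketch of these call-time flags; the model's class name and constructor arguments are assumptions here, so see `model/model.py` and `args.py` for the real ones:

```python
# `model` is the subclassed tf.keras.Model from model/model.py, already built
# with the repo's own constructor arguments; `x` is a batch of preprocessed crops.
y_train = model(x, training=True)                         # training mode
y_dec, y_vae = model(x, training=False, inference=False)  # validation: decoder + VAE outputs
y_dec = model(x, training=False, inference=True)          # test: decoder output only
```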
The BraTS 2017/2018 dataset is not publicly available, so download scripts are not provided. Once you have downloaded the data, run preprocessing on the original data format, which should look something like this:
```
BraTS17TrainingData/*/*/*[t1,t1ce,t2,flair,seg].nii.gz
```
For each example, there are 4 modalities and 1 label, each of shape `240 x 240 x 155`. Preprocessing steps consist of:
- Concatenate the `t1ce` and `flair` modalities along the channel dimension.
- Compute the per-channel, image-wise `mean` and `std` over the training set and normalize each channel (see the sketch after this list).
- Crop as much background as possible across all images. Final image sizes are `155 x 190 x 147`.
- Serialize to `TFRecord` format for convenience in training.
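A minimal sketch of the normalization step, assuming a channels-last `(H, W, D, 2)` layout after concatenating the two modalities; the function name and the epsilon guard are our additions:

```python
import numpy as np

def normalize(volume, mean, std):
    """Per-channel z-score normalization of one (H, W, D, 2) example.

    `mean` and `std` are per-channel statistics computed over the training set
    (preprocess.py stores them in prepro.npy; the exact layout is the repo's).
    """
    return (volume - mean) / (std + 1e-8)  # broadcasts over the channel axis
```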
```bash
python preprocess.py \
  --in_locs /path/to/BraTS17TrainingData \
  --modalities t1ce,flair \
  --truth seg \
  --create_val
```
All command-line arguments can be found in `args.py`.
There are 285 training examples in the BraTS 2017/2018 training sets. Since there is no held-out validation set, the `--create_val` flag creates a 10:1 split, resulting in 260 training and 25 validation examples.
Most hyperparameters proposed in the paper are used in training. Each input is randomly flipped across the spatial axes with probability 0.5 and randomly cropped to `128 x 128 x 128` (making the training data stochastic). The validation set is dynamically created each epoch in a similar fashion.
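A minimal sketch of this augmentation, assuming channels-last `(H, W, D, C)` tensors and eager execution; the function name and the way labels are handled are ours, not the repo's:

```python
import tensorflow as tf

def augment(image, label, crop_size=(128, 128, 128)):
    """Random flips across the three spatial axes (p=0.5 each), then a random crop."""
    for axis in range(3):
        if tf.random.uniform(()) < 0.5:
            image = tf.reverse(image, axis=[axis])
            label = tf.reverse(label, axis=[axis])
    # Crop image and label together so they stay spatially aligned.
    stacked = tf.concat([image, label], axis=-1)
    cropped = tf.image.random_crop(
        stacked, size=list(crop_size) + [stacked.shape[-1]])
    n = image.shape[-1]
    return cropped[..., :n], cropped[..., n:]
```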
```bash
python train.py \
  --train_loc /path/to/train \
  --val_loc /path/to/val \
  --prepro_file /path/to/prepro/prepro.npy \
  --save_folder checkpoint \
  --crop_size 128,128,128
```
Use the `--gpu` flag to run on a GPU.
The testing script `test.py` runs inference on unlabeled input data, generating sample labels over the whole image, which is padded to a size compatible with the network's downsampling. The VAE branch is not run during inference, so the model is effectively fully convolutional.
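For illustration, a sketch of such padding, assuming a total downsampling factor of 8 (three stride-2 stages; the real factor depends on the network depth):

```python
import numpy as np

def pad_to_multiple(volume, multiple=8):
    """Symmetrically zero-pad the spatial dims up to the nearest multiple."""
    pads = []
    for dim in volume.shape[:3]:
        extra = (-dim) % multiple
        pads.append((extra // 2, extra - extra // 2))
    pads.append((0, 0))  # leave the channel axis untouched
    return np.pad(volume, pads, mode='constant')
```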
```bash
python test.py \
  --in_locs /path/to/test \
  --modalities t1ce,flair \
  --prepro_loc /path/to/prepro/prepro.npy \
  --tumor_model checkpoint
```
Training arguments are saved in the checkpoint folder, which bypasses the need for manual model initialization at test time.
The `Interpolator` class interpolates voxel sizes during rescaling so that all inputs can be resampled to `1 mm^3` isotropic resolution.
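In the same spirit, a minimal resampling sketch (not the repo's `Interpolator`), using `scipy.ndimage.zoom` with voxel spacings read from the NIfTI header, e.g. nibabel's `img.header.get_zooms()`:

```python
from scipy.ndimage import zoom

def to_isotropic(volume, spacing, order=1):
    """Resample a (H, W, D) volume from `spacing` (mm per voxel) to 1 mm^3."""
    factors = [s / 1.0 for s in spacing[:3]]  # target spacing of 1 mm per axis
    return zoom(volume, factors, order=order)  # order=1 is linear interpolation
```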
**NOTE:** `test.py` is not fully debugged and may not be fully functional. If needed, please open an issue.
Because BraTS contains skull-stripped images, which are uncommon in actual applications, we also support training and inference of skull-stripping models. The same pipeline generalizes, using the NFBS skull-stripping dataset instead. Note that the number of output channels `--out_ch` used in model initialization and training differs between the two tasks.
If the testing data has not been skull-stripped, run skull stripping and tumor segmentation sequentially at inference time by specifying the `--skull_model` flag. All preprocessing and training should work for both tasks as is.
We run training on a V100 32GB GPU with a batch size of 1. Each epoch takes roughly 12 minutes. Below is a sample training curve, using all default model parameters.
| Epoch | Training Loss | Training Dice Score | Validation Loss | Validation Dice Score |
|---|---|---|---|---|
| 0 | 1.000 | 0.134 | 0.732 | 0.248 |
| 50 | 0.433 | 0.598 | 0.413 | 0.580 |
| 100 | 0.386 | 0.651 | 0.421 | 0.575 |
| 150 | 0.356 | 0.676 | 0.393 | 0.594 |
| 200 | 0.324 | 0.692 | 0.349 | 0.642 |
| 250 | 0.295 | 0.716 | 0.361 | 0.630 |
| 300 | 0.282 | 0.729 | 0.352 | 0.644 |