Select specific AFIDs for training #18

Open
wants to merge 5 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# misc folders
tensorflow/
jobs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
12 changes: 8 additions & 4 deletions README.md
@@ -1,10 +1,14 @@
# afids-CNN
Leveraging the recently released anatomical fiducial framework to develop an open software infrastructure that solves the landmark regression problem on 3D MRI images

## Processing imaging data for training
1 - skull stripping
2 - conforming image
3 - intensity normalization (i.e., WM to 110)
## Preparation
1 - install Poetry and configure its cache directory
2 - run `poetry install`, then `poetry shell` to activate the environment

## Processing imaging data for training (handled in https://github.com/afids/autoafids_prep)
1 - rigid registration to the MNI template
2 - conforming the image to 1 mm isotropic resolution
3 - intensity normalization (i.e., WM to 110) followed by min-max normalization (a sketch follows below)
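
A minimal sketch of step 3 (illustrative only; the actual preprocessing lives in autoafids_prep, and the file names and WM estimate below are placeholders):

import nibabel as nib
import numpy as np

img = nib.load("sub-01_T1w.nii.gz")  # conformed, skull-stripped image (hypothetical path)
data = img.get_fdata()
wm_estimate = np.percentile(data[data > 0], 90)  # crude stand-in for a WM intensity estimate
data *= 110.0 / wm_estimate  # rescale so WM sits near 110
data = (data - data.min()) / (data.max() - data.min())  # min-max normalization to [0, 1]
nib.save(nib.Nifti1Image(data, img.affine), "sub-01_desc-norm_T1w.nii.gz")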

## Processing landmark data (AFIDs)
1 - extract points from landmark file (.fcsv is supported)
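
A minimal sketch of step 1, assuming a Slicer-style .fcsv with three comment header lines and the usual id/x/y/z column layout (the column names and file name below are assumptions, not this package's API):

import pandas as pd

FCSV_COLUMNS = ["id", "x", "y", "z", "ow", "ox", "oy", "oz",
                "vis", "sel", "lock", "label", "desc", "associatedNodeID"]
fcsv_df = pd.read_csv("sub-01_afids.fcsv", header=None, skiprows=3, names=FCSV_COLUMNS)
afid01_world = fcsv_df.loc[0, ["x", "y", "z"]].to_numpy(dtype=float)  # AFID 1 in world coordinates (mm)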
71 changes: 34 additions & 37 deletions afids_cnn/apply.py
@@ -26,7 +26,7 @@
Path(__file__).parent / "resources" / "tpl-MNI152NLin2009cAsym_res-01_T1w.fcsv"
)
MNI_IMG = (
Path(__file__).parent / "resources" / "tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz"
Path(__file__).parent / "resources" / "tpl-MNI152NLin2009cAsym_res-256_T1w.nii.gz"
)


@@ -53,42 +53,24 @@ def get_fid(fcsv_df: pd.DataFrame, fid_label: int) -> NDArray:

def fid_voxel2world(fid_voxel: NDArray, nii_affine: NDArray) -> NDArray:
"""Transform fiducials in voxel coordinates to world coordinates."""
# Translation
fid_world = fid_voxel.T + nii_affine[:3, 3:4]
# Rotation
fid_world = np.diag(np.dot(fid_world, nii_affine[:3, :3]))

translation = nii_affine[:3, 3]
rotation = nii_affine[:3, :3]
fid_world = rotation.dot(fid_voxel) + translation
return fid_world.astype(float)


def fid_world2voxel(
fid_world: NDArray,
nii_affine: NDArray,
resample_size: int = 1,
padding: int | None = None,
) -> NDArray:
"""Transform fiducials in world coordinates to voxel coordinates.

Optionally, resample to match resampled image
"""
# Translation
fid_voxel = fid_world.T - nii_affine[:3, 3:4]
# Rotation
fid_voxel = np.dot(fid_voxel, np.linalg.inv(nii_affine[:3, :3]))

# Round to nearest voxel
fid_voxel = np.rint(np.diag(fid_voxel) * resample_size)

if padding:
fid_voxel = np.pad(fid_voxel, padding, mode="constant")

"""Transform fiducials in world coordinates to voxel coordinates."""
inv_affine = np.linalg.inv(nii_affine)
translation = inv_affine[:3, 3]
rotation = inv_affine[:3, :3]
fid_voxel = rotation.dot(fid_world) + translation
fid_voxel = np.rint(fid_voxel)
return fid_voxel.astype(int)
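# Illustrative usage (comments only, not part of the diff): for a nibabel image `nii`
# and a world-space point `pt` in mm, fid_world2voxel(pt, nii.affine) returns the
# nearest voxel index, and fid_voxel2world(idx, nii.affine) maps a voxel index back
# to world space; composing the two reproduces `pt` up to voxel-rounding error.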


def min_max_normalize(img: NDArray) -> NDArray:
return (img - img.min()) / (img.max() - img.min())


def gen_patch_slices(centre: NDArray, radius: int) -> tuple[slice, slice, slice]:
return tuple(slice(coord - radius, coord + radius + 1) for coord in centre[:3])
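# Worked example (not in the diff): for centre = np.array([60, 72, 48]) and radius = 31,
# gen_patch_slices returns (slice(29, 92), slice(41, 104), slice(17, 80)), i.e. a
# (63, 63, 63) patch, matching the fixed input shape introduced in train.py below.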

@@ -116,18 +98,19 @@ def process_distances(
radius: int,
) -> NDArray:
dim = (2 * radius) + 1
logger.debug("min distance: %s", distances.min())
arr_dis = np.reshape(distances[0], (dim, dim, dim))
new_pred = np.full((img.shape), 100, dtype=float)
slices = gen_patch_slices(mni_fid, radius)
new_pred[slices[0], slices[1], slices[2]] = arr_dis
transformed = np.exp(-0.5 * new_pred)
thresh = np.percentile(transformed, 99.999)
thresh = np.percentile(transformed, 99)
thresholded = transformed
thresholded[thresholded < thresh] = 0
thresholded = (thresholded * 1000).astype(int)
thresholded = (thresholded * 1000000).astype(int)
new = skimage.measure.regionprops(thresholded)
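# Intent of the block above (comments only): exp(-0.5 * d) turns the predicted
# distances into a peaked likelihood map, the 99th-percentile threshold keeps only
# the sharpest peak, and integer scaling lets skimage.measure.regionprops extract a
# region whose centroid serves as the fiducial estimate; if no region survives, the
# code falls back to the argmax voxel below.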
if not new:
logger.warning("No centroid found for this afid. Results will be suspect.")
logger.warning("No centroid found for this afid. Results may be suspect.")
return np.array(
np.unravel_index(
np.argmax(transformed, axis=None),
@@ -154,23 +137,37 @@ def apply_model(
mni_fid_resampled = fid_world2voxel(
mni_fid_world,
mni_img.affine,
resample_size=1,
padding=0,
)
normalized = min_max_normalize(img.get_fdata())
logger.info("iteration 1")
img_data = img.get_fdata()
distances = predict_distances(
radius,
model,
mni_fid_resampled,
normalized,
img_data,
)
fid_resampled = process_distances(
distances,
normalized,
img_data,
mni_fid_resampled,
radius,
)
return fid_voxel2world(fid_resampled, img.affine)
# Run a second pass, re-centred on the first prediction, to refine the estimate.
logger.info("iteration 2")
fid_pred = np.rint(fid_resampled).astype(int)
distances2 = predict_distances(
radius,
model,
fid_pred,
img_data,
)
fid_resampled2 = process_distances(
distances2,
img_data,
fid_pred,
radius,
)
return fid_voxel2world(fid_resampled2, img.affine)


def apply_all(
43 changes: 27 additions & 16 deletions afids_cnn/train.py
@@ -8,6 +8,7 @@
import pandas as pd
from numpy.typing import NDArray
from tensorflow import keras
from tensorflow.keras.layers import BatchNormalization, LeakyReLU

from afids_cnn.generator import customImageDataGenerator

@@ -42,43 +43,53 @@ def create_generator(


def gen_conv3d_layer(
filters: int,
filters: int,
kernel_size: tuple[int, int, int] = (3, 3, 3),
) -> keras.layers.Conv3D:
return keras.layers.Conv3D(filters, kernel_size, padding="same", activation="relu")

) -> keras.layers.Layer:
return keras.Sequential([
keras.layers.Conv3D(filters, kernel_size, padding="same", kernel_initializer="he_normal"),
BatchNormalization(),
LeakyReLU(alpha=0.1)
])

def gen_max_pooling_layer() -> keras.layers.MaxPooling3D:
return keras.layers.MaxPooling3D((2, 2, 2))


def gen_transpose_layer(filters: int) -> keras.layers.Conv3DTranspose:
return keras.layers.Conv3DTranspose(
filters,
kernel_size=2,
strides=2,
padding="same",
return keras.Sequential([
keras.layers.Conv3DTranspose(filters, kernel_size=2, strides=2, padding="same", kernel_initializer="he_normal"),
BatchNormalization(),
LeakyReLU(alpha=0.1)
])

def gen_dropout_layer() -> keras.layers.Dropout:
return keras.layers.Dropout(
rate=0.2,
noise_shape=None,
seed=None,
)


def gen_std_block(filters: int, input_):
x = gen_conv3d_layer(filters)(input_)
out_layer = gen_conv3d_layer(filters)(x)
x = gen_conv3d_layer(filters)(x)
out_layer = gen_dropout_layer()(x)
return out_layer, gen_max_pooling_layer()(out_layer)
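# Illustrative shapes (assuming the zero-padded 64**3 input built in gen_model below):
# gen_std_block(16, x) returns (out_layer, pooled) with out_layer of shape
# (64, 64, 64, 16) and pooled of shape (32, 32, 32, 16); out_layer is reused later as
# the skip connection by the matching gen_opposite_block on the decoder path.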


def gen_opposite_block(filters: int, input_, out_layer):
x = input_
for _ in range(3):
for _ in range(2):
x = gen_conv3d_layer(filters)(x)
x = gen_dropout_layer()(x)
next_filters = filters // 2
x = gen_transpose_layer(next_filters)(x)
x = gen_conv3d_layer(next_filters)(x)
return keras.layers.Concatenate(axis=4)([out_layer, x])


def gen_model() -> keras.Model:
input_layer = keras.layers.Input((None, None, None, 1))
input_layer = keras.layers.Input((63, 63, 63, 1))
x = keras.layers.ZeroPadding3D(padding=((1, 0), (1, 0), (1, 0)))(input_layer)

out_layer_1, x = gen_std_block(16, x) # block 1
@@ -90,16 +101,16 @@ def gen_model() -> keras.Model:
x = gen_conv3d_layer(256)(x)
x = gen_conv3d_layer(256)(x)
x = keras.layers.Conv3DTranspose(filters=128, kernel_size=2, strides=(2, 2, 2))(x)
x = gen_conv3d_layer(128, (2, 2, 2))(x)
x = keras.layers.Concatenate(axis=4)([out_layer_4, x])

x = gen_opposite_block(128, x, out_layer_3) # block 5 (opposite 4)
x = gen_opposite_block(64, x, out_layer_2) # block 6 (opposite 3)
x = gen_opposite_block(32, x, out_layer_1) # block 7 (opposite 2)

# block 8 (opposite 1)
for _ in range(3):
for _ in range(2):
x = gen_conv3d_layer(16)(x)
x = gen_dropout_layer()(x)

# output layer
x = keras.layers.Cropping3D(cropping=((1, 0), (1, 0), (1, 0)), data_format=None)(x)
@@ -189,7 +200,7 @@ def main():
)

callbacks = (
[keras.callbacks.EarlyStopping(monitor="val_loss", patience=100)]
[keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.05, patience=5)]
if args.do_early_stopping
else None
)
5 changes: 4 additions & 1 deletion afids_cnn/train_workflow/config/snakebids.yml
@@ -78,10 +78,13 @@ parse_args:
--afids_dir:
help: 'Path to a BIDS dataset with AFIDs placements corresponding to bids_dir.'
required: True
--afids_num:
help: 'AFID(s) to train. If not specified, will train on all AFIDs.'
nargs: '*'
--validation_bids_dir:
help: 'Path to a validation BIDS dataset.'
--validation_afids_dir:
help: 'Path to a BIDS dataset with AFIDs placements corresponding to validation_bids_dir'
help: 'Path to a BIDS dataset with AFIDs placements corresponding to validation_bids_dir.'
--frequency:
help: 'Augmentation frequency in voxels.'
default: 1000
8 changes: 6 additions & 2 deletions afids_cnn/train_workflow/workflow/Snakefile
@@ -250,6 +250,7 @@ rule train_model:
desc="num{num_augment}angle{angle_stdev}radius{radius}freq{frequency}",
suffix="afid-{afid}_loss.csv",
),

params:
num_channels=config["num_channels"],
epochs=config["epochs"],
@@ -269,12 +270,15 @@ rule train_model:
"{params.do_early_stopping} "
"{params.validation_arg} {input.combined_patch_validation} "


rule all:
input:
models=expand(
rules.train_model.output.model,
afid=[f"{afid:02}" for afid in range(1, 33)],
afid=[
f"{afid:02}" for afid in (
list(map(int, config["afids_num"])) if config["afids_num"] else range(1, 33)
)
],
num_augment=config["num_augment"],
angle_stdev=config["angle_stdev"],
radius=config["radius"],
4 changes: 4 additions & 0 deletions bash_scripts/train-cnn_scratch.sh
@@ -0,0 +1,4 @@
#!/bin/bash
set -eo pipefail
source /project/ctb-akhanf/ataha24/venv_archives/virtualenvs/afids-cnn-2o1BhLKg-py3.9/bin/activate
auto_afids_cnn_train_bids --frequency 300 --afids_dir /scratch/ataha24/afids-data/data/autoafids/train/ /scratch/ataha24/afids-data/data/autoafids/train/ /scratch/ataha24/afids-data/data/autoafids/train/training_20240205 --validation-afids-dir /scratch/ataha24/afids-data/data/autoafids/validation --validation-bids-dir /scratch/ataha24/afids-data/data/autoafids/validation participant --epochs 200 --steps-per-epoch 100 --validation-steps 100 --do_early_stopping --cores 2 --use-singularity --config containers='{c3d: "/project/6050199/akhanf/singularity/bids-apps/itksnap_v3.8.2.sif"}'
5 changes: 5 additions & 0 deletions bash_scripts/train-cnn_scratch_gpu.sh
@@ -0,0 +1,5 @@
#!/bin/bash
set -eo pipefail
module load cuda cudnn
source /scratch/ataha24/0_dev/0.afids-CNN/tensorflow/bin/activate
auto_afids_cnn_train_bids --frequency 300 --afids_dir /scratch/ataha24/afids-data/data/autoafids/train/ /scratch/ataha24/afids-data/data/autoafids/train/ /scratch/ataha24/afids-data/data/autoafids/train/training_20240209 --validation-afids-dir /scratch/ataha24/afids-data/data/autoafids/validation --validation-bids-dir /scratch/ataha24/afids-data/data/autoafids/validation participant --epochs 200 --steps-per-epoch 100 --validation-steps 20 --do_early_stopping --cores 1 --use-singularity --config containers='{c3d: "/project/6050199/akhanf/singularity/bids-apps/itksnap_v3.8.2.sif"}'