Skip to content

Commit

Permalink
refactor(eyeseg): work with .eye files for storing results instead of…
Browse files Browse the repository at this point in the history
… pickel

BREAKING CHANGE:
  • Loading branch information
Oli4 committed Sep 13, 2022
1 parent 7c394e9 commit 239c660
Show file tree
Hide file tree
Showing 18 changed files with 1,376 additions and 318 deletions.
67 changes: 59 additions & 8 deletions eyeseg/io_utils/input_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def _parse_function(input_proto):
other_features = {
"volume": tf.io.FixedLenFeature([], tf.string),
"bscan": tf.io.FixedLenFeature([], tf.string),
"layer_positions": tf.io.FixedLenFeature([], tf.string),
"image": tf.io.FixedLenFeature([], tf.string),
"group": tf.io.FixedLenFeature([], tf.string),
}
Expand All @@ -32,7 +33,12 @@ def _parse_function(input_proto):
data = tf.io.parse_single_example(input_proto, image_feature_description)

image = tf.io.parse_tensor(data["image"], tf.uint8)
image = tf.reshape(image, input_shape + (1,))
image = tf.reshape(image, input_shape + (1,), name="reshape_1")

layer_positions = tf.io.parse_tensor(data["layer_positions"], tf.float32)
layer_positions = tf.reshape(
layer_positions, input_shape + (len(mapping),), name="reshape_1.5"
)

# Sort mapping for guaranteed order
layerout = tf.stack(
Expand All @@ -42,13 +48,16 @@ def _parse_function(input_proto):
],
axis=-1,
)
layerout = tf.reshape(layerout, (input_shape[1], len(mapping)))
layerout = tf.reshape(
layerout, (input_shape[1], len(mapping)), name="reshape_2"
)

volume = data["volume"]
bscan = data["bscan"]
group = data["group"]

return {
"layer_positions": layer_positions,
"image": image,
"layerout": layerout,
"Volume": volume,
Expand Down Expand Up @@ -80,7 +89,11 @@ def _augment(in_data):
lambda: image,
)

return {"image": image, "layerout": layerout}
return {
"image": image,
"layerout": layerout,
"layer_positions": in_data["layer_positions"],
}

return _augment

Expand All @@ -97,14 +110,41 @@ def _normalize(in_data):
image = tf.cast(image, tf.float32)
image = image - tf.math.reduce_mean(image)
image = image / tf.math.reduce_std(image)
return {**in_data, **{"image": image, "layerout": layerout}}
return {
**in_data,
**{
"image": image,
"layerout": layerout,
"layer_positions": in_data["layer_positions"],
},
}


@tf.function
def _prepare_train(in_data):
image, layerout = in_data["image"], in_data["layerout"]
image, layerout, layer_positions = (
in_data["image"],
in_data["layerout"],
in_data["layer_positions"],
)

return image, {"layer_output": layerout, "columnwise_softmax": layer_positions}


@tf.function
def _prepare_test(in_data):
volume, bscan, group, image, layerout = (
in_data["Volume"],
in_data["Bscan"],
in_data["Group"],
in_data["image"],
in_data["layerout"],
)
return image, {
"layer_output": layerout,
"Volume": volume,
"Bscan": bscan,
"Group": group,
}


Expand All @@ -119,6 +159,7 @@ def _prepare_train(in_data):
)
return image, {
"layer_output": layerout,
"columnwise_softmax": in_data["layer_positions"],
}

return _prepare_train
Expand Down Expand Up @@ -228,13 +269,15 @@ def _transform(in_data):
tf.linalg.inv(combined_matrix)
),
interpolation="bilinear",
output_shape=input_shape,
)
# combined_matrix = tf.linalg.inv(combined_matrix)

# Warp 1D data
x_vals = (
tf.tile(
tf.reshape(tf.range(0, width, dtype=tf.float32), (1, width)),
tf.reshape(
tf.range(0, width, dtype=tf.float32), (1, width), name="reshape_3"
),
[num_classes, 1],
)
+ 0.5
Expand Down Expand Up @@ -326,13 +369,21 @@ def get_split(
parsed_data.shuffle(
14000, seed, reshuffle_each_iteration=True
) # .map(_augment)
.map(_transform)
# .map(_transform)
.map(_normalize)
.batch(batch_size)
.map(_prepare_train)
.repeat(epochs)
.prefetch(tf.data.experimental.AUTOTUNE)
)
elif split == "test":
dataset = (
parsed_data.map(_normalize)
.batch(batch_size)
.map(_prepare_test)
.repeat(epochs)
.prefetch(tf.data.experimental.AUTOTUNE)
)
else:
dataset = (
parsed_data.map(_normalize)
Expand Down
35 changes: 35 additions & 0 deletions eyeseg/io_utils/losses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
import tensorflow as tf


def layer_ce(true, pred):
# true : batch, columns, channels

# pred: batch, rows, columns, channels
all = tf.math.log(tf.reduce_sum(true * pred, axis=1))
mask = tf.logical_not(tf.math.is_inf(all))
all_clean = tf.ragged.boolean_mask(all, mask)

return tf.reduce_sum(-tf.reduce_sum(all_clean, axis=1), axis=-1)

ctrue = tf.cast(tf.round(true), dtype=tf.int32)

cols = tf.range(1024, dtype=tf.int32)
layers = tf.range(9, dtype=tf.int32)

batch_losses = []
for batch in tf.range(1, dtype=tf.int32):
layer_losses = []
for layer in layers:
layer_results = []
for c in cols:
row = ctrue[batch, c, layer]
# if True: #not tf.math.is_nan(true[batch, c, layer]):
value = pred[batch, row, c, layer]
layer_results.append(value)
layer_losses.append(
-tf.reduce_sum(
tf.math.log(tf.cast(tf.stack(layer_results), tf.float32))
)
)
batch_losses.append(tf.reduce_mean(layer_losses))

return tf.stack(batch_losses)


class MovingMeanFocalSSE(tf.keras.losses.Loss):
# initialize instance attributes
def __init__(self, window_size, curv_weight=0):
Expand Down
136 changes: 136 additions & 0 deletions eyeseg/io_utils/test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"text/plain": "<tf.Tensor: shape=(2,), dtype=float32, numpy=array([7.9214387, 7.9214387], dtype=float32)>"
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" #heights = tf.cast(tf.math.round(true), tf.int32)\n",
"\n",
" #return tf.reduce_mean(-tf.reduce_sum(tf.gather(pred, indices, )\n",
"\n",
"layer_ce(true, pred)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"outputs": [],
"source": [
"import tensorflow as tf\n",
"true = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(tf.range(10), 1), multiples=[1, 5]), 0), [2, 1,1])\n",
"pred = tf.tile(tf.expand_dims(tf.range(10), 1), [1,10])\n",
"pred = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(pred, -1) , [1,1,5]), 0), [2,1,1,1]) / 10"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 42,
"outputs": [],
"source": [
"x_indices = tf.reshape(tf.range(true.shape[1]), (1,10,1))\n",
"x_indices = tf.tile(x_indices, (2,1,5))\n",
"\n",
"indices = tf.stack([true, x_indices], axis=1)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 43,
"outputs": [
{
"data": {
"text/plain": "<tf.Tensor: shape=(2, 2, 10, 5), dtype=int32, numpy=\narray([[[[0, 0, 0, 0, 0],\n [1, 1, 1, 1, 1],\n [2, 2, 2, 2, 2],\n [3, 3, 3, 3, 3],\n [4, 4, 4, 4, 4],\n [5, 5, 5, 5, 5],\n [6, 6, 6, 6, 6],\n [7, 7, 7, 7, 7],\n [8, 8, 8, 8, 8],\n [9, 9, 9, 9, 9]],\n\n [[0, 0, 0, 0, 0],\n [1, 1, 1, 1, 1],\n [2, 2, 2, 2, 2],\n [3, 3, 3, 3, 3],\n [4, 4, 4, 4, 4],\n [5, 5, 5, 5, 5],\n [6, 6, 6, 6, 6],\n [7, 7, 7, 7, 7],\n [8, 8, 8, 8, 8],\n [9, 9, 9, 9, 9]]],\n\n\n [[[0, 0, 0, 0, 0],\n [1, 1, 1, 1, 1],\n [2, 2, 2, 2, 2],\n [3, 3, 3, 3, 3],\n [4, 4, 4, 4, 4],\n [5, 5, 5, 5, 5],\n [6, 6, 6, 6, 6],\n [7, 7, 7, 7, 7],\n [8, 8, 8, 8, 8],\n [9, 9, 9, 9, 9]],\n\n [[0, 0, 0, 0, 0],\n [1, 1, 1, 1, 1],\n [2, 2, 2, 2, 2],\n [3, 3, 3, 3, 3],\n [4, 4, 4, 4, 4],\n [5, 5, 5, 5, 5],\n [6, 6, 6, 6, 6],\n [7, 7, 7, 7, 7],\n [8, 8, 8, 8, 8],\n [9, 9, 9, 9, 9]]]], dtype=int32)>"
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"indices"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 47,
"outputs": [
{
"data": {
"text/plain": "<tf.Tensor: shape=(1, 1, 1), dtype=int32, numpy=array([[[5]]], dtype=int32)>"
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.gather_nd(pred, [[[[0,5,5,0]]]])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
12 changes: 6 additions & 6 deletions eyeseg/io_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def preprocess_split(volume_paths, savepath, split, excluded=None):
# Load volume
data = ep.Oct.from_duke_mat(p)
# Compute center of annotation
bm_annotation = (~np.isnan(data.layers["BM"])).astype(int)
bm_annotation = (~np.isnan(data.analyse["BM"])).astype(int)
height_center, width_center = [
int(c) for c in ndimage.measurements.center_of_mass(bm_annotation)
]
Expand All @@ -61,13 +61,13 @@ def preprocess_split(volume_paths, savepath, split, excluded=None):
image = bscan.scan[:, width_center - 256 : width_center + 256].astype(
np.uint8
)
bm = bscan.layers["BM"][width_center - 256 : width_center + 256].astype(
np.float32
)
rpe = bscan.layers["RPE"][
bm = bscan.analyse["BM"][
width_center - 256 : width_center + 256
].astype(np.float32)
rpe = bscan.analyse["RPE"][
width_center - 256 : width_center + 256
].astype(np.float32)
ilm = bscan.layers["ILM"][
ilm = bscan.analyse["ILM"][
width_center - 256 : width_center + 256
].astype(np.float32)

Expand Down
Loading

0 comments on commit 239c660

Please sign in to comment.