-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsubpixel.py
31 lines (25 loc) · 1015 Bytes
/
subpixel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
from keras.layers import Lambda
def SubpixelConv2D(input_shape, scale=4, name='subpixel'):
    """
    Keras layer to do subpixel convolution (pixel shuffle).

    NOTE: TensorFlow backend only. Uses ``tf.nn.depth_to_space`` with its
    default channels-last layout.

    Ref:
        [1] Real-Time Single Image and Video Super-Resolution Using an
            Efficient Sub-Pixel Convolutional Neural Network.
            Shi et al.
            https://arxiv.org/abs/1609.05158

    :param input_shape: tensor shape, (batch, height, width, channel).
        Not read by this function (kept for backward compatibility); the
        output shape is derived from the shape Keras supplies to the layer.
    :param scale: upsampling factor applied to height and width. The channel
        count must be divisible by ``scale ** 2``. Default=4
    :param name: layer name. Keras requires unique layer names within a
        model, so pass a distinct name when using several subpixel layers.
        Default='subpixel'
    :return: a ``Lambda`` layer mapping (batch, H, W, C) to
        (batch, H * scale, W * scale, C // scale ** 2).
    :raises ValueError: when the layer's output shape is computed, if a
        known channel count is not divisible by ``scale ** 2``.
    """
    block = scale ** 2

    # tf.depth_to_space is not available at the top level in TF 2.x; prefer
    # tf.nn.depth_to_space and fall back only for old TF 1.x releases.
    depth_to_space = getattr(tf.nn, 'depth_to_space', None) or tf.depth_to_space

    def subpixel_shape(shape):
        # Dimensions may be None (unknown until run time), so only scale the
        # ones that are known instead of failing with a TypeError.
        batch, height, width, channels = shape
        if channels is not None and channels % block:
            raise ValueError(
                'SubpixelConv2D: channel count %d is not divisible by '
                'scale ** 2 = %d' % (channels, block))
        return (batch,
                None if height is None else height * scale,
                None if width is None else width * scale,
                None if channels is None else channels // block)

    def subpixel(x):
        # Move each group of scale**2 channels into a scale x scale
        # spatial block.
        return depth_to_space(x, scale)

    return Lambda(subpixel, output_shape=subpixel_shape, name=name)