-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_loss.py
41 lines (31 loc) · 1.14 KB
/
image_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tensorflow as tf
# Names of the VGG16 layers whose activations define the perceptual loss;
# vgg() uses the first layer, in network order, whose name is listed here.
content_layers = ['block4_conv1']
def vgg(input_shape, layer_names=None):
    """Build a frozen VGG16 feature extractor.

    Loads VGG16 with ImageNet weights and no classifier head, and returns a
    model that maps an input batch to the activations of the first layer
    (in network order) whose name appears in ``layer_names``.

    Args:
        input_shape: Shape of one input, without the batch dimension; passed
            to ``tf.keras.layers.Input``.
        layer_names: Candidate output layer names. Defaults to the
            module-level ``content_layers``.

    Returns:
        A ``tf.keras.models.Model`` whose weights are all non-trainable.

    Raises:
        ValueError: If no layer of VGG16 has a name in ``layer_names``.
    """
    if layer_names is None:
        layer_names = content_layers
    inputs = tf.keras.layers.Input(shape=input_shape)
    net = tf.keras.applications.vgg16.VGG16(
        input_tensor=inputs,
        weights='imagenet',
        include_top=False,
    )
    # First matching layer in network order, same as the original loop+break.
    output = next(
        (layer.output for layer in net.layers if layer.name in layer_names),
        None,
    )
    if output is None:
        raise ValueError(
            'None of {!r} is a VGG16 layer name'.format(list(layer_names))
        )
    model = tf.keras.models.Model(inputs, output)
    # Freeze the whole extractor. Previously only the matched output layer
    # was marked non-trainable, leaving all earlier conv layers trainable.
    model.trainable = False
    return model
class PerceptualLoss:
    """Perceptual (feature-reconstruction) loss on VGG16 activations.

    Both images are passed through the extractor built by ``vgg`` and the
    loss is the mean squared difference between the two activation tensors.
    """

    def __init__(self, image_shape, *, preprocess=False):
        """Build the feature extractor.

        Args:
            image_shape: Shape of one image without the batch dimension;
                forwarded to ``vgg``.
            preprocess: If True, ``__call__`` applies
                ``tf.keras.applications.vgg16.preprocess_input`` after the
                x255 rescale. The ImageNet VGG16 weights expect that
                preprocessing (RGB->BGR, mean subtraction). Defaults to
                False, which keeps the original loss values unchanged.
        """
        self.image_shape = image_shape
        self.preprocess = preprocess
        self.model = vgg(image_shape)

    def calculate_loss(self, original_image, image_tensor):
        """Return the mean squared error between the VGG activations of two batches.

        Inputs are fed to the model exactly as given; no rescaling or
        preprocessing is done here.
        """
        content_output = self.model(original_image)
        prediction = self.model(image_tensor)
        return tf.math.reduce_mean(tf.math.square(content_output - prediction))

    def __call__(self, original_image, image_tensor):
        """Scale both batches by 255 and return their perceptual loss.

        NOTE(review): the x255 factor presumably means inputs are in [0, 1];
        confirm against callers.
        """
        original_image = original_image * 255
        image_tensor = image_tensor * 255
        if self.preprocess:
            # Operates on the scaled copies above, so caller data is untouched.
            original_image = tf.keras.applications.vgg16.preprocess_input(original_image)
            image_tensor = tf.keras.applications.vgg16.preprocess_input(image_tensor)
        return self.calculate_loss(original_image, image_tensor)