utils.py
import torch.nn as nn
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class Logger(object):
    """Tensorboard logger."""

    def __init__(self, log_dir):
        """Initialize summary writer."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Add scalar summary."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
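
# Usage sketch (the log directory and values below are illustrative, not part of the original file):
#   logger = Logger('logs/experiment')
#   logger.scalar_summary('train/loss', 0.5, step=1)
# The resulting event files can then be inspected with `tensorboard --logdir logs/experiment`.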

class NormLayer(nn.Module):
    """Wraps a 2D normalization layer selected by `norm_type`."""

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=True)
        elif norm_type == 'none':
            # Identity: the input passes through unchanged.
            self.norm = lambda x: x
        else:
            raise NotImplementedError('Norm type {} is not supported.'.format(norm_type))

    def forward(self, x):
        return self.norm(x)
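
# Example (a sketch with illustrative values, not part of the original file):
#   norm = NormLayer(64, norm_type='in')        # wraps nn.InstanceNorm2d(64, affine=True)
#   identity = NormLayer(64, norm_type='none')  # passes the input through unchanged,
#                                               # so callers can apply it unconditionally.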

class ReluLayer(nn.Module):
    """Wraps an activation layer selected by `relu_type`."""

    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'none':
            # Identity: the input passes through unchanged.
            self.func = lambda x: x
        else:
            raise NotImplementedError('Relu type {} is not supported.'.format(relu_type))

    def forward(self, x):
        return self.func(x)

class ConvLayer(nn.Module):
    """Convolution block: optional 2x upsample, reflection padding, conv, norm, activation."""

    def __init__(self, in_channels, out_channels, kernel_size=3, scale='none',
                 norm_type='none', relu_type='none', use_pad=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad

        # Norm layers with their own affine parameters make the conv bias redundant.
        bias = True if norm_type in ['pixel', 'none'] else False
        stride = 2 if scale == 'down' else 1

        self.scale_func = lambda x: x
        if scale == 'up':
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')

        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out
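
# Minimal smoke test (a sketch, not part of the original file; shapes and settings are illustrative).
if __name__ == '__main__':
    import torch

    # 3 -> 64 channels, downsampling by 2 via stride, with batch norm and LeakyReLU.
    layer = ConvLayer(3, 64, kernel_size=3, scale='down', norm_type='bn', relu_type='leakyrelu')
    x = torch.randn(1, 3, 128, 128)
    y = layer(x)
    print(y.shape)  # expected: torch.Size([1, 64, 64, 64])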