forked from thangvubk/SRCNN_Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
36 lines (29 loc) · 948 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from data_loader import SRCNN_dataset
from utils import *
class SRCNN(nn.Module):
    """Super-Resolution Convolutional Neural Network (SRCNN).

    Pipeline: Input -> Conv1 -> ReLU -> Conv2 -> ReLU -> Conv3.
    The MSE training loss is computed by the caller; this module only
    returns the Conv3 output.

    All convolutions use PyTorch's default ``padding=0`` ("valid"
    convolution), so each spatial dimension of the output is
    ``(F1 - 1) + (F2 - 1) + (F3 - 1)`` pixels smaller than the input
    (12 pixels with the default kernel sizes).

    Args:
        C1, C2, C3: number of output channels of Conv1, Conv2 and Conv3.
        F1, F2, F3: square kernel sizes of Conv1, Conv2 and Conv3.
        in_channels: number of channels of the input tensor (default 1,
            matching the previously hard-coded value). If the output should
            have the same number of channels, set ``C3`` accordingly.
    """

    def __init__(self,
                 C1=64, C2=32, C3=1,
                 F1=9, F2=1, F3=5,
                 in_channels=1):
        super().__init__()
        # nn.Conv2d(in_channels, out_channels, kernel_size); padding defaults to 0.
        self.conv1 = nn.Conv2d(in_channels, C1, F1)
        self.conv2 = nn.Conv2d(C1, C2, F2)
        self.conv3 = nn.Conv2d(C2, C3, F3)

    def forward(self, x):
        """Run the three-layer SRCNN on ``x``.

        Args:
            x: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, C3, H - m, W - m), where
            m = (F1 - 1) + (F2 - 1) + (F3 - 1). No activation is applied
            after Conv3.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.conv3(x)