-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsrcnn.py
27 lines (22 loc) · 782 Bytes
/
srcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Any
import torch.nn as nn
import torch.nn.functional as F
from .srmodel import SRModel
class SRCNN(SRModel):
    """
    LightningModule for SRCNN, https://ieeexplore.ieee.org/document/7115171?arnumber=7115171
    https://arxiv.org/pdf/1501.00092.pdf.

    The network is three convolutions (9x9 -> 1x1 -> 5x5) with ReLU between
    them. Each convolution uses stride 1 and padding ``(kernel - 1) // 2``,
    so the spatial size of the feature maps is preserved; all upscaling is
    done by bicubic interpolation in ``forward`` before the network runs.

    NOTE(review): ``self._channels`` and ``self._scale_factor`` are not set
    in this class; presumably ``SRModel.__init__`` assigns them from
    ``kwargs`` -- confirm against the base class.
    """

    def __init__(self, **kwargs: Any):
        """Build the SRCNN layers.

        Args:
            **kwargs: Forwarded unchanged to ``SRModel.__init__``.
        """
        super().__init__(**kwargs)

        # Input and output channel counts are both ``self._channels``, so
        # the network maps an image to an image with the same channel count.
        self._net = nn.Sequential(
            nn.Conv2d(self._channels, 64, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, self._channels, kernel_size=5, padding=2),
        )

    def forward(self, x):
        """Upscale ``x`` with bicubic interpolation, then refine it.

        Args:
            x: Input tensor. Assumed to be a 4-D ``(N, C, H, W)`` batch with
                ``C == self._channels``, as required by bicubic
                ``F.interpolate`` and ``nn.Conv2d`` -- TODO confirm with
                callers.

        Returns:
            The output of the convolutional stack applied to the upscaled
            input; its spatial size equals that of the interpolated tensor.
        """
        upscaled = F.interpolate(
            x, scale_factor=self._scale_factor, mode='bicubic')
        return self._net(upscaled)