Skip to content

Commit

Permalink
[Refactor] Refactor random degradation (#1583)
Browse files Browse the repository at this point in the history
* fix jpeg config

* fix random degradation
  • Loading branch information
Z-Fran authored Jan 16, 2023
1 parent ad24a5a commit b154794
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 19 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/datasets/decompression_test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
imdecode_backend='cv2'),
dict(
type='RandomJPEGCompression',
params=dict(quality=[quality, quality], color_type='color'),
params=dict(quality=[quality, quality]),
bgr2rgb=True,
keys=['img']),
dict(type='PackEditInputs')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
color_type='color',
channel_order='rgb'),
dict(type='SetValues', dictionary=dict(scale=scale)),
dict(type='RescaleToZeroOne', keys=['gt']),
dict(type='CopyValues', src_keys=['gt'], dst_keys=['img']),
dict(
type='RandomBlur',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='SetValues', dictionary=dict(scale=scale)),
dict(type='FixedCrop', keys=['gt'], crop_size=(256, 256)),
dict(type='RescaleToZeroOne', keys=['gt']),
dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='horizontal'),
dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['gt'], transpose_ratio=0.5),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
keys=['gt'],
crop_size=(gt_crop_size, gt_crop_size),
random_crop=True),
dict(type='RescaleToZeroOne', keys=['gt']),
dict(
type='UnsharpMasking',
keys=['gt'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
else:
test_pipeline[0]['color_type'] = 'grayscale'
test_pipeline[1]['color_type'] = 'grayscale'
test_pipeline[2]['params']['color_type'] = 'grayscale'
test_pipeline[2]['color_type'] = 'grayscale'

# optimizer
optim_wrapper = dict(
Expand Down
5 changes: 3 additions & 2 deletions mmedit/datasets/transforms/aug_pixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,11 @@ def _unsharp_masking(self, imgs):

outputs = []
for img in imgs:
img = img.astype(np.float32)
residue = img - cv2.filter2D(img, -1, self.kernel)
mask = np.float32(np.abs(residue) * 255 > self.threshold)
mask = np.float32(np.abs(residue) > self.threshold)
soft_mask = cv2.filter2D(mask, -1, self.kernel)
sharpened = np.clip(img + self.weight * residue, 0, 1)
sharpened = np.clip(img + self.weight * residue, 0, 255)

outputs.append(soft_mask * sharpened + (1 - soft_mask) * img)

Expand Down
16 changes: 8 additions & 8 deletions mmedit/datasets/transforms/random_degradations.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ class RandomJPEGCompression:
bgr2rgb (bool): Whether to change the channel order. Default: False.
"""

def __init__(self, params, keys, bgr2rgb=False):
def __init__(self, params, keys, color_type='color', bgr2rgb=False):
self.keys = keys
self.params = params
self.color_type = color_type
self.bgr2rgb = bgr2rgb

def _apply_random_compression(self, imgs):
Expand All @@ -178,19 +179,18 @@ def _apply_random_compression(self, imgs):

# determine initial compression level and the step size
quality = self.params['quality']
color_type = self.params['color_type']
quality_step = self.params.get('quality_step', 0)
jpeg_param = round(np.random.uniform(quality[0], quality[1]))

# apply jpeg compression
outputs = []
for img in imgs:
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_param]
if self.bgr2rgb and color_type == 'color':
if self.bgr2rgb and self.color_type == 'color':
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
_, img_encoded = cv2.imencode('.jpg', img, encode_param)

if color_type == 'color':
if self.color_type == 'color':
img_encoded = cv2.imdecode(img_encoded, 1)
if self.bgr2rgb:
img_encoded = cv2.cvtColor(img_encoded, cv2.COLOR_BGR2RGB)
Expand Down Expand Up @@ -281,7 +281,7 @@ def _apply_poisson_noise(self, imgs):

outputs = []
for img in imgs:
noise = img.copy()
noise = np.float32(img.copy())
if is_gray_noise:
noise = cv2.cvtColor(noise[..., [2, 1, 0]], cv2.COLOR_BGR2GRAY)
noise = noise[..., np.newaxis]
Expand Down Expand Up @@ -498,7 +498,7 @@ def _apply_random_compression(self, imgs):
stream.bit_rate = bitrate

for img in imgs:
img = (255 * img).astype(np.uint8)
img = img.astype(np.uint8)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame.pict_type = 'NONE'
for packet in stream.encode(frame):
Expand All @@ -512,8 +512,8 @@ def _apply_random_compression(self, imgs):
with av.open(buf, 'r', 'mp4') as container:
if container.streams.video:
for frame in container.decode(**{'video': 0}):
outputs.append(
frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
outputs.append(frame.to_rgb().to_ndarray().astype(
np.float32))

return outputs

Expand Down
8 changes: 4 additions & 4 deletions mmedit/models/editors/real_esrgan/real_esrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,16 @@ def extract_gt_data(self, data_samples):
gt_imgs = [data_sample.gt_img.data for data_sample in data_samples]
gt = torch.stack(gt_imgs)
gt_unsharp = [
data_sample.gt_unsharp.data for data_sample in data_samples
data_sample.gt_unsharp.data / 255. for data_sample in data_samples
]
gt_unsharp = torch.stack(gt_unsharp)

gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
if self.is_use_sharpened_gt_in_pixel:
gt_pixel = gt_unsharp
gt_pixel = gt_unsharp.clone()
if self.is_use_sharpened_gt_in_percep:
gt_percep = gt_unsharp
gt_percep = gt_unsharp.clone()
if self.is_use_sharpened_gt_in_gan:
gt_gan = gt_unsharp
gt_gan = gt_unsharp.clone()

return gt_pixel, gt_percep, gt_gan

0 comments on commit b154794

Please sign in to comment.