Skip to content

Commit

Permalink
[Dy2St] transforms.RandomVerticalFlip Support static mode (#49024)
Browse files Browse the repository at this point in the history
* add static RandomVerticalFlip

* object => unittest.TestCase
  • Loading branch information
DrRyanHuang authored Dec 13, 2022
1 parent 3192269 commit 889e583
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 2 deletions.
102 changes: 102 additions & 0 deletions python/paddle/tests/test_transforms_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.vision.transforms import transforms

SEED = 2022


class TestTransformUnitTestBase(unittest.TestCase):
def setUp(self):
self.img = (np.random.rand(*self.get_shape()) * 255.0).astype(
np.float32
)
self.set_trans_api()

def get_shape(self):
return (64, 64, 3)

def set_trans_api(self):
self.api = transforms.Resize(size=16)

def dynamic_transform(self):
paddle.seed(SEED)

img_t = paddle.to_tensor(self.img)
return self.api(img_t)

def static_transform(self):
paddle.enable_static()
paddle.seed(SEED)

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(
shape=self.get_shape(), dtype=paddle.float32, name='img'
)
out = self.api(x)

exe = paddle.static.Executor()
res = exe.run(main_program, fetch_list=[out], feed={'img': self.img})

paddle.disable_static()
return res[0]

def test_transform(self):
dy_res = self.dynamic_transform()
st_res = self.static_transform()

np.testing.assert_almost_equal(dy_res, st_res)


class TestResize(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.Resize(size=(16, 16))


class TestResizeError(TestTransformUnitTestBase):
def test_transform(self):
pass

def test_error(self):
paddle.enable_static()
# Not support while w<=0 or h<=0, but received w=-1, h=-1
with self.assertRaises(NotImplementedError):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(
shape=[-1, -1, -1], dtype=paddle.float32, name='img'
)
self.api(x)

paddle.disable_static()


class TestRandomVerticalFlip0(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomVerticalFlip(prob=0)


class TestRandomVerticalFlip1(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomVerticalFlip(prob=1)


if __name__ == "__main__":
unittest.main()
6 changes: 5 additions & 1 deletion python/paddle/vision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle

from ...fluid.framework import Variable
from . import functional_cv2 as F_cv2
from . import functional_pil as F_pil
from . import functional_tensor as F_t
Expand All @@ -32,7 +33,10 @@ def _is_pil_image(img):


def _is_tensor_image(img):
return isinstance(img, paddle.Tensor)
"""
Return True if img is a Tensor for dynamic mode or Variable for static mode.
"""
return isinstance(img, (paddle.Tensor, Variable))


def _is_numpy_image(img):
Expand Down
12 changes: 11 additions & 1 deletion python/paddle/vision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import paddle
import paddle.nn.functional as F

from ...fluid.framework import Variable

__all__ = []


def _assert_image_tensor(img, data_format):
if (
not isinstance(img, paddle.Tensor)
not isinstance(img, (paddle.Tensor, Variable))
or img.ndim < 3
or img.ndim > 4
or not data_format.lower() in ('chw', 'hwc')
Expand Down Expand Up @@ -725,6 +727,14 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'):

if isinstance(size, int):
w, h = _get_image_size(img, data_format)
# TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape.
# We should consider to support this case in future.
if w <= 0 or h <= 0:
raise NotImplementedError(
"Not support while w<=0 or h<=0, but received w={}, h={}".format(
w, h
)
)
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
Expand Down
13 changes: 13 additions & 0 deletions python/paddle/vision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,23 @@ def __init__(self, prob=0.5, keys=None):
self.prob = prob

def _apply_image(self, img):
if paddle.in_dynamic_mode():
return self._dynamic_apply_image(img)
else:
return self._static_apply_image(img)

def _dynamic_apply_image(self, img):
if random.random() < self.prob:
return F.vflip(img)
return img

def _static_apply_image(self, img):
return paddle.static.nn.cond(
paddle.rand(shape=(1,)) < self.prob,
lambda: F.vflip(img),
lambda: img,
)


class Normalize(BaseTransform):
"""Normalize the input data with mean and standard deviation.
Expand Down

0 comments on commit 889e583

Please sign in to comment.