-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrian_test.py
40 lines (27 loc) · 974 Bytes
/
trian_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import train
# Short module-level aliases used by the tests below: the process-global flag
# registry (read and overridden by the tests) and TF's re-export of the mock
# library (used to stub out train.data_provider).
FLAGS, mock = tf.flags.FLAGS, tf.test.mock
class TrainTest(tf.test.TestCase):
    """Smoke tests: `train.main` must run to completion with stubbed input data."""

    # Flags overridden by `_test_build_graph_helper`. Their previous values are
    # restored after every run so the overrides cannot leak into other tests
    # that share the process-global FLAGS object.
    _OVERRIDDEN_FLAGS = (
        'max_number_of_steps',
        'weight_factor',
        'batch_size',
        'patch_size',
    )

    def _test_build_graph_helper(self, weight_factor):
        """Call `train.main(None)` once with `train.data_provider` mocked out.

        The test passes if `train.main` returns without raising; no outputs
        are asserted.

        Args:
          weight_factor: Value written to `FLAGS.weight_factor` for this run.
            The test names suggest 0.0 disables the adversarial loss and 1.0
            enables it (NOTE(review): confirm against train.py).
        """
        # getattr on FLAGS assumes the flags have already been parsed, which
        # is the case when this file is run through tf.test.main().
        saved_flags = {name: getattr(FLAGS, name)
                       for name in self._OVERRIDDEN_FLAGS}
        try:
            batch_size = 3
            patch_size = 16
            # Presumably 0 steps lets train.main build the graph without
            # running any training iterations -- TODO confirm in train.py.
            FLAGS.max_number_of_steps = 0
            FLAGS.weight_factor = weight_factor
            FLAGS.batch_size = batch_size
            FLAGS.patch_size = patch_size

            # All-zero float32 array of shape
            # [batch_size, patch_size, patch_size, 3], served instead of real data.
            mock_imgs = np.zeros([batch_size, patch_size, patch_size, 3],
                                 dtype=np.float32)
            # Replace the `data_provider` attribute of the `train` module while
            # train.main runs, so data_provider.provide_data(...) calls made
            # through `train` return mock_imgs.
            with mock.patch.object(train, 'data_provider') as mock_data_provider:
                mock_data_provider.provide_data.return_value = mock_imgs
                train.main(None)
        finally:
            # Undo the global flag overrides even if train.main raised.
            for name, value in saved_flags.items():
                setattr(FLAGS, name, value)

    def test_build_graph_noadversarialloss(self):
        """Run train.main with FLAGS.weight_factor set to 0.0."""
        self._test_build_graph_helper(0.0)

    def test_build_graph_adversarialloss(self):
        """Run train.main with FLAGS.weight_factor set to 1.0."""
        self._test_build_graph_helper(1.0)
if __name__ == '__main__':
    # Run the tf.test.TestCase tests defined in this module when the file is
    # executed directly.
    tf.test.main()