def reference_value(num_dims, stride):
    """Return the pre-computed reference result for a convolution setup.

    The values were computed using PyTorch beforehand and serve as the
    expected result when the reference cannot be computed at run time.

    Args:
        num_dims: Number of spatial dimensions of the convolution (2 or 3).
        stride: Stride applied along every spatial dimension (1 or 2).

    Returns:
        The reference value (a float) for the given configuration.

    Raises:
        ValueError: If no value has been pre-computed for the given
            (num_dims, stride) combination.
    """
    # Rows of (spatial dims, stride, reference value). The table is scanned
    # with plain equality checks, in order, rather than hashed.
    precomputed = (
        (2, 1, 11913.852660080756),
        (2, 2, 2850.678372506634),
        (3, 1, 9952.365297083174),
        (3, 2, 1222.3680410607026),
    )
    for dims, step, expected in precomputed:
        if num_dims == dims and stride == step:
            return expected

    raise ValueError('Does not have a pre-computed reference value for this configuration')
stride{}'.format(num_dims, s)) + ) # PyTorch implementation try: @@ -214,7 +227,7 @@ def construct_model(lbann): val = z except: # Precomputed value - val = reference_val + val = reference_value(num_dims, s) # val = 398.6956458317758 # _num_samples=8, 6 channels # val = 381.7401227915947 # _num_samples=23, 6 channels tol = 8 * val * np.finfo(np.float32).eps