From f5626da4e704b25d7bce9b7f9e82816de7e94f4e Mon Sep 17 00:00:00 2001
From: Jean Kossaifi
Date: Sun, 9 Jun 2024 10:01:46 -0700
Subject: [PATCH] FIX tests: input dtype

---
 tltorch/factorized_layers/tests/test_factorized_linear.py | 2 +-
 .../tests/test_tensor_contraction_layers.py               | 2 +-
 tltorch/factorized_layers/tests/test_trl.py               | 8 ++++----
 tltorch/functional/tests/test_factorized_linear.py        | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/tltorch/factorized_layers/tests/test_factorized_linear.py b/tltorch/factorized_layers/tests/test_factorized_linear.py
index e35c554..b19f367 100644
--- a/tltorch/factorized_layers/tests/test_factorized_linear.py
+++ b/tltorch/factorized_layers/tests/test_factorized_linear.py
@@ -16,7 +16,7 @@ def test_FactorizedLinear(factorization):
     in_shape = (3, 3)
     out_features = 16
     out_shape = (4, 4)
-    data = tl.tensor(rng.random_sample((batch_size, in_features)))
+    data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
 
     # Creat from a tensor factorization
     tensor = TensorizedTensor.new((out_shape, in_shape), rank='same', factorization=factorization)
diff --git a/tltorch/factorized_layers/tests/test_tensor_contraction_layers.py b/tltorch/factorized_layers/tests/test_tensor_contraction_layers.py
index 04a8f8e..b520643 100644
--- a/tltorch/factorized_layers/tests/test_tensor_contraction_layers.py
+++ b/tltorch/factorized_layers/tests/test_tensor_contraction_layers.py
@@ -10,7 +10,7 @@ def test_tcl():
     batch_size = 2
     in_shape = (4, 5, 6)
     out_shape = (2, 3, 5)
-    data = tl.tensor(rng.random_sample((batch_size, ) + in_shape))
+    data = tl.tensor(rng.random_sample((batch_size, ) + in_shape), dtype=tl.float32)
 
     expected_shape = (batch_size, ) + out_shape
     tcl = TCL(input_shape=in_shape, rank=out_shape, bias=False)
diff --git a/tltorch/factorized_layers/tests/test_trl.py b/tltorch/factorized_layers/tests/test_trl.py
index 257aec8..176c432 100644
--- a/tltorch/factorized_layers/tests/test_trl.py
+++ b/tltorch/factorized_layers/tests/test_trl.py
@@ -76,8 +76,8 @@ def test_trl(factorization, true_rank, rank):
     tol = 0.08
 
     # Generate a random tensor
-    samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1))
-    true_bias = tl.tensor(rng.uniform(size=output_shape))
+    samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1), dtype=tl.float32)
+    true_bias = tl.tensor(rng.uniform(size=output_shape), dtype=tl.float32)
 
     with torch.no_grad():
         true_weight = FactorizedTensor.new(shape=input_shape+output_shape,
@@ -130,7 +130,7 @@ def test_TuckerTRL(order, project_input, learn_pool):
     # fix the random seed for reproducibility and create random input
     random_state = 12345
     rng = tl.check_random_state(random_state)
-    data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order))
+    data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order), dtype=tl.float32)
 
     # Build a simple net with avg-pool, flatten + fully-connected
     if order == 2:
@@ -182,7 +182,7 @@ def test_TRL_from_linear(factorization, bias):
     # fix the random seed for reproducibility and create random input
     random_state = 12345
     rng = tl.check_random_state(random_state)
-    data = tl.tensor(rng.random_sample((batch_size, in_features)))
+    data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
     fc = nn.Linear(in_features, out_features, bias=bias)
     res_fc = fc(tl.copy(data))
     trl = TRL((in_features, ), (out_features, ), rank=10, bias=bias, factorization=factorization)
diff --git a/tltorch/functional/tests/test_factorized_linear.py b/tltorch/functional/tests/test_factorized_linear.py
index 7465cdd..17e6a44 100644
--- a/tltorch/functional/tests/test_factorized_linear.py
+++ b/tltorch/functional/tests/test_factorized_linear.py
@@ -21,7 +21,7 @@ def test_linear_tensor_dot_tucker(factorization, factorized_linear):
     rank = 3
     batch_size = 2
 
-    tensor = tl.randn((batch_size, in_dim))
+    tensor = tl.randn((batch_size, in_dim), dtype=tl.float32)
     fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank, factorization=factorization)
     fact_weight.normal_()