Commit

FIX tests: input dtype
JeanKossaifi committed Jun 9, 2024
1 parent 6b93e33 commit f5626da
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tltorch/factorized_layers/tests/test_factorized_linear.py
@@ -16,7 +16,7 @@ def test_FactorizedLinear(factorization):
in_shape = (3, 3)
out_features = 16
out_shape = (4, 4)
- data = tl.tensor(rng.random_sample((batch_size, in_features)))
+ data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)

# Creat from a tensor factorization
tensor = TensorizedTensor.new((out_shape, in_shape), rank='same', factorization=factorization)
2 changes: 1 addition & 1 deletion tltorch/factorized_layers/tests/test_tcl.py
@@ -10,7 +10,7 @@ def test_tcl():
batch_size = 2
in_shape = (4, 5, 6)
out_shape = (2, 3, 5)
- data = tl.tensor(rng.random_sample((batch_size, ) + in_shape))
+ data = tl.tensor(rng.random_sample((batch_size, ) + in_shape), dtype=tl.float32)

expected_shape = (batch_size, ) + out_shape
tcl = TCL(input_shape=in_shape, rank=out_shape, bias=False)
8 changes: 4 additions & 4 deletions tltorch/factorized_layers/tests/test_trl.py
@@ -76,8 +76,8 @@ def test_trl(factorization, true_rank, rank):
tol = 0.08

# Generate a random tensor
- samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1))
- true_bias = tl.tensor(rng.uniform(size=output_shape))
+ samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1), dtype=tl.float32)
+ true_bias = tl.tensor(rng.uniform(size=output_shape), dtype=tl.float32)

with torch.no_grad():
true_weight = FactorizedTensor.new(shape=input_shape+output_shape,
@@ -130,7 +130,7 @@ def test_TuckerTRL(order, project_input, learn_pool):
# fix the random seed for reproducibility and create random input
random_state = 12345
rng = tl.check_random_state(random_state)
- data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order))
+ data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order), dtype=tl.float32)

# Build a simple net with avg-pool, flatten + fully-connected
if order == 2:
@@ -182,7 +182,7 @@ def test_TRL_from_linear(factorization, bias):
# fix the random seed for reproducibility and create random input
random_state = 12345
rng = tl.check_random_state(random_state)
- data = tl.tensor(rng.random_sample((batch_size, in_features)))
+ data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
fc = nn.Linear(in_features, out_features, bias=bias)
res_fc = fc(tl.copy(data))
trl = TRL((in_features, ), (out_features, ), rank=10, bias=bias, factorization=factorization)
2 changes: 1 addition & 1 deletion tltorch/functional/tests/test_factorized_linear.py
@@ -21,7 +21,7 @@ def test_linear_tensor_dot_tucker(factorization, factorized_linear):
rank = 3
batch_size = 2

- tensor = tl.randn((batch_size, in_dim))
+ tensor = tl.randn((batch_size, in_dim), dtype=tl.float32)
fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank,
factorization=factorization)
fact_weight.normal_()
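
Why the change helps (not stated in the commit message, so this is an inference): the NumPy samplers used in these tests (random_sample, normal, uniform) return float64 arrays, and depending on the tensorly backend defaults, tl.tensor can keep that dtype, so the test inputs end up float64 while the layers under test hold float32 parameters (the torch.nn default). Passing dtype=tl.float32 pins the inputs to float32. A minimal standalone sketch of the mismatch, assuming the PyTorch backend and a backend that preserves the NumPy dtype (shapes and variable names below are made up for illustration):

import tensorly as tl
import torch

tl.set_backend('pytorch')
rng = tl.check_random_state(12345)

# NumPy samplers return float64; if the backend keeps that dtype, so does the tensor
x64 = tl.tensor(rng.random_sample((2, 9)))
x32 = tl.tensor(rng.random_sample((2, 9)), dtype=tl.float32)

fc = torch.nn.Linear(9, 4)   # nn.Linear parameters are float32 by default
out = fc(x32)                # dtypes match, runs fine
# fc(x64) would raise a RuntimeError if x64 ended up float64 (double input vs. float weights)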

0 comments on commit f5626da
