Skip to content

Commit

Permalink
[INTERNAL] Speed up dgan tests
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 962311af186cafc55b6020c97d6bc90c29c77703
  • Loading branch information
kboyd committed Oct 21, 2024
1 parent ae129c5 commit f95beb9
Showing 1 changed file with 4 additions and 16 deletions.
20 changes: 4 additions & 16 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def config() -> DGANConfig:
max_sequence_len=20,
sample_len=5,
batch_size=10,
epochs=10,
epochs=1,
)


Expand Down Expand Up @@ -162,10 +162,10 @@ def test_generate():
attributes, features, attributes_shape=(64, 3), features_shape=(64, 20, 2)
)

attributes, features = dg.generate_numpy(200)
attributes, features = dg.generate_numpy(50)

assert_attributes_features_shape(
attributes, features, attributes_shape=(200, 3), features_shape=(200, 20, 2)
attributes, features, attributes_shape=(50, 3), features_shape=(50, 20, 2)
)

attributes, features = dg.generate_numpy(1)
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_train_numpy_no_attributes_1(
def test_train_numpy_no_attributes_2(config: DGANConfig):
features = np.random.rand(100, 20, 2)
n_samples = 10
config.epochs = 1

model_attributes_blank = DGAN(config=config)
model_attributes_blank.train_numpy(features=features)
synthetic_attributes, synthetic_features = model_attributes_blank.generate_numpy(
Expand Down Expand Up @@ -331,7 +331,6 @@ def test_train_numpy_batch_size_of_1(config: DGANConfig):
# Check model trains when (# of examples) % batch_size == 1.

config.batch_size = 10
config.epochs = 1

features = np.random.rand(91, 20, 2)
attributes = np.random.randint(0, 3, (91, 1))
Expand Down Expand Up @@ -506,7 +505,6 @@ def test_train_dataframe_wide_no_attributes(config: DGANConfig):

config.max_sequence_len = 4
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)
dg.train_dataframe(df=df, df_style=DfStyle.WIDE)
Expand Down Expand Up @@ -1901,7 +1899,6 @@ def test_save_and_load(
attributes, attribute_types = attribute_data
features, feature_types = feature_data

config.epochs = 1
config.use_attribute_discriminator = use_attribute_discriminator
config.apply_example_scaling = apply_example_scaling
config.attribute_noise_dim = noise_dim
Expand Down Expand Up @@ -1957,7 +1954,6 @@ def test_save_and_load_no_attributes(
):
features, feature_types = feature_data

config.epochs = 1
config.use_attribute_discriminator = use_attribute_discriminator
config.apply_example_scaling = apply_example_scaling
config.attribute_noise_dim = noise_dim
Expand Down Expand Up @@ -2009,7 +2005,6 @@ def test_save_and_load_dataframe_with_attributes(config: DGANConfig, tmp_path):
)
config.max_sequence_len = 4
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand Down Expand Up @@ -2044,7 +2039,6 @@ def test_attribute_and_feature_overlap(config: DGANConfig):
)
config.max_sequence_len = 4
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand All @@ -2070,7 +2064,6 @@ def test_save_and_load_dataframe_no_attributes(config: DGANConfig, tmp_path):

config.max_sequence_len = 3
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand Down Expand Up @@ -2101,7 +2094,6 @@ def test_dataframe_long_no_continuous_features(config: DGANConfig):

config.max_sequence_len = 3
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand All @@ -2124,7 +2116,6 @@ def test_dataframe_wide_no_continuous_features(config: DGANConfig):

config.max_sequence_len = 3
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand All @@ -2146,7 +2137,6 @@ def test_dataframe_long_partial_example(config: DGANConfig):

config.max_sequence_len = 10
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand All @@ -2170,7 +2160,6 @@ def test_dataframe_long_one_and_partial_example(config: DGANConfig):

config.max_sequence_len = 5
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand Down Expand Up @@ -2203,7 +2192,6 @@ def test_dataframe_variable_sequences(config: DGANConfig):

config.max_sequence_len = 8
config.sample_len = 1
config.epochs = 1

dg = DGAN(config=config)

Expand Down

0 comments on commit f95beb9

Please sign in to comment.