Skip to content

Commit

Permalink
Replace assertTrue(isinstance(...)) with assertIsInstance for easier …
Browse files Browse the repository at this point in the history
…debugging. (#831)

`assertIsInstance` prints out "x is not an instance of y".
`assertTrue(isinstance(...))` prints out "False is not True".
  • Loading branch information
hertschuh authored Sep 1, 2023
1 parent 339c15f commit fa547ec
Show file tree
Hide file tree
Showing 18 changed files with 105 additions and 105 deletions.
10 changes: 5 additions & 5 deletions keras_core/backend/common/compute_output_spec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ def test_basics(self):
out = backend.compute_output_spec(
example_fn, backend.KerasTensor((2, 3))
)
self.assertTrue(isinstance(out, backend.KerasTensor))
self.assertIsInstance(out, backend.KerasTensor)
self.assertEqual(out.shape, (2, 3, 2))

out = backend.compute_output_spec(
example_fn, backend.KerasTensor((None, 3))
)
self.assertTrue(isinstance(out, backend.KerasTensor))
self.assertIsInstance(out, backend.KerasTensor)
self.assertEqual(out.shape, (None, 3, 2))

out = backend.compute_output_spec(
example_fn, backend.KerasTensor((2, None))
)
self.assertTrue(isinstance(out, backend.KerasTensor))
self.assertIsInstance(out, backend.KerasTensor)
self.assertEqual(out.shape, (2, None, 2))

@pytest.mark.skipif(
Expand All @@ -51,14 +51,14 @@ def example_meta_fn(self, x):
out = backend.compute_output_spec(
instance.example_meta_fn, backend.KerasTensor((2, 3))
)
self.assertTrue(isinstance(out, backend.KerasTensor))
self.assertIsInstance(out, backend.KerasTensor)
self.assertTrue(instance.canary)
self.assertEqual(out.shape, (2, 3, 2))

instance = Container()
out = backend.compute_output_spec(
instance.example_meta_fn, backend.KerasTensor((2, None))
)
self.assertTrue(isinstance(out, backend.KerasTensor))
self.assertIsInstance(out, backend.KerasTensor)
self.assertTrue(instance.canary)
self.assertEqual(out.shape, (2, None, 2))
16 changes: 8 additions & 8 deletions keras_core/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def compute_output_shape(self, input_shape):

layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, tuple))
self.assertIsInstance(out, tuple)
self.assertEqual(len(out), 2)
self.assertEqual(out[0].shape, (2, 3))
self.assertEqual(out[1].shape, (2, 3))
Expand All @@ -51,7 +51,7 @@ def compute_output_shape(self, input_shape):

layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, list))
self.assertIsInstance(out, list)
self.assertEqual(len(out), 2)
self.assertEqual(out[0].shape, (2, 3))
self.assertEqual(out[1].shape, (2, 3))
Expand All @@ -66,7 +66,7 @@ def compute_output_shape(self, input_shape):

layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, dict))
self.assertIsInstance(out, dict)
self.assertEqual(len(out), 2)
self.assertEqual(out["1"].shape, (2, 3))
self.assertEqual(out["2"].shape, (2, 3))
Expand All @@ -85,14 +85,14 @@ def compute_output_shape(self, input_shape):

layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, tuple))
self.assertIsInstance(out, tuple)
self.assertEqual(len(out), 3)
self.assertEqual(out[0].shape, (2, 3))
self.assertTrue(isinstance(out[1], tuple))
self.assertIsInstance(out[1], tuple)
self.assertEqual(len(out[1]), 2)
self.assertEqual(out[1][0].shape, (2, 3))
self.assertEqual(out[1][1].shape, (2, 3))
self.assertTrue(isinstance(out[2], tuple))
self.assertIsInstance(out[2], tuple)
self.assertEqual(len(out[2]), 2)
self.assertEqual(out[2][0].shape, (2, 3))
self.assertEqual(out[2][1].shape, (2, 3))
Expand All @@ -110,10 +110,10 @@ def compute_output_shape(self, input_shape):

layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, dict))
self.assertIsInstance(out, dict)
self.assertEqual(len(out), 2)
self.assertEqual(out["1"].shape, (2, 3))
self.assertTrue(isinstance(out["2"], dict))
self.assertIsInstance(out["2"], dict)
self.assertEqual(len(out["2"]), 2)
self.assertEqual(out["2"]["11"].shape, (2, 3))
self.assertEqual(out["2"]["22"].shape, (2, 3))
Expand Down
6 changes: 3 additions & 3 deletions keras_core/metrics/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ def test_serialization(self):

def test_get_method(self):
metric = metrics_module.get("mse")
self.assertTrue(isinstance(metric, metrics_module.MeanSquaredError))
self.assertIsInstance(metric, metrics_module.MeanSquaredError)

metric = metrics_module.get("mean_squared_error")
self.assertTrue(isinstance(metric, metrics_module.MeanSquaredError))
self.assertIsInstance(metric, metrics_module.MeanSquaredError)

metric = metrics_module.get("categorical_accuracy")
self.assertTrue(isinstance(metric, metrics_module.CategoricalAccuracy))
self.assertIsInstance(metric, metrics_module.CategoricalAccuracy)

metric = metrics_module.get(None)
self.assertEqual(metric, None)
Expand Down
8 changes: 4 additions & 4 deletions keras_core/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_basic_flow_multi_input(self):
model.summary()

self.assertEqual(model.name, "basic")
self.assertTrue(isinstance(model, Functional))
self.assertTrue(isinstance(model, Model))
self.assertIsInstance(model, Functional)
self.assertIsInstance(model, Model)

# Eager call
in_val = [np.random.random((2, 3)), np.random.random((2, 3))]
Expand Down Expand Up @@ -72,14 +72,14 @@ def test_basic_flow_multi_output(self):
# Eager call
in_val = np.random.random((2, 3))
out_val = model(in_val)
self.assertTrue(isinstance(out_val, list))
self.assertIsInstance(out_val, list)
self.assertEqual(len(out_val), 2)
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))

# Symbolic call
out_val = model(Input(shape=(3,), batch_size=2))
self.assertTrue(isinstance(out_val, list))
self.assertIsInstance(out_val, list)
self.assertEqual(len(out_val), 2)
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))
Expand Down
32 changes: 16 additions & 16 deletions keras_core/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _get_model_multi_outputs_dict():
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
model = _get_model()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)

def test_json_serialization(self):
model = _get_model()
Expand Down Expand Up @@ -113,7 +113,7 @@ def call(self, x):
new_model = Model.from_config(
config, custom_objects={"CustomDense": CustomDense}
)
self.assertTrue(isinstance(new_model, Functional))
self.assertIsInstance(new_model, Functional)

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
Expand All @@ -131,7 +131,7 @@ def call(self, x):
)
def test_functional_single_output(self, model_fn, loss_type):
model = model_fn()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
loss = "mean_squared_error"
if loss_type == "list":
loss = [loss]
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_functional_single_output(self, model_fn, loss_type):

def test_functional_list_outputs_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_functional_list_outputs_list_losses(self):

def test_functional_list_outputs_list_losses_abbr(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_functional_list_outputs_list_losses_abbr(self):

def test_functional_list_outputs_nested_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_functional_list_outputs_nested_list_losses(self):

def test_functional_dict_outputs_dict_losses(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_functional_dict_outputs_dict_losses(self):

def test_functional_list_outputs_dict_losses_metrics(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_functional_list_outputs_dict_losses_metrics(self):

def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):

def test_functional_list_outputs_dict_losses_partial_metrics(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_functional_list_outputs_dict_losses_partial_metrics(self):

def test_functional_list_outputs_dict_losses_invalid_keys(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand All @@ -446,7 +446,7 @@ def test_functional_list_outputs_dict_losses_invalid_keys(self):

def test_functional_list_outputs_dict_losses_no_output_names(self):
model = _get_model_multi_outputs_list_no_output_names()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand All @@ -464,7 +464,7 @@ def test_functional_list_outputs_dict_losses_no_output_names(self):

def test_functional_list_outputs_dict_metrics_invalid_keys(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand All @@ -488,7 +488,7 @@ def test_functional_list_outputs_dict_metrics_invalid_keys(self):

def test_functional_dict_outputs_dict_losses_invalid_keys(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand All @@ -509,7 +509,7 @@ def test_functional_dict_outputs_dict_losses_invalid_keys(self):

def test_functional_dict_outputs_dict_metrics_invalid_keys(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand All @@ -533,7 +533,7 @@ def test_functional_dict_outputs_dict_metrics_invalid_keys(self):

def test_functional_list_outputs_invalid_nested_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
self.assertIsInstance(model, Functional)
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
Expand Down
4 changes: 2 additions & 2 deletions keras_core/optimizers/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def test_constraints_are_applied(self):

def test_get_method(self):
obj = optimizers.get("sgd")
self.assertTrue(isinstance(obj, optimizers.SGD))
self.assertIsInstance(obj, optimizers.SGD)
obj = optimizers.get("adamw")
self.assertTrue(isinstance(obj, optimizers.AdamW))
self.assertIsInstance(obj, optimizers.AdamW)

obj = optimizers.get(None)
self.assertEqual(obj, None)
Expand Down
4 changes: 2 additions & 2 deletions keras_core/random/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def train_step(x):
return x

x = train_step(x)
self.assertTrue(isinstance(x, jnp.ndarray))
self.assertIsInstance(x, jnp.ndarray)

def test_dropout_noise_shape(self):
inputs = ops.ones((2, 3, 5, 7))
Expand All @@ -152,7 +152,7 @@ def test_jax_rngkey_seed(self):
self.assertEqual(rng.shape, (2,))
self.assertEqual(rng.dtype, jnp.uint32)
x = random.randint((3, 5), 0, 10, seed=rng)
self.assertTrue(isinstance(x, jnp.ndarray))
self.assertIsInstance(x, jnp.ndarray)

@pytest.mark.skipif(
keras_core.backend.backend() != "jax",
Expand Down
4 changes: 2 additions & 2 deletions keras_core/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def run_output_asserts(layer, output, eager=False):
msg="Unexpected output shape",
)
elif isinstance(expected_output_shape, dict):
self.assertTrue(isinstance(output, dict))
self.assertIsInstance(output, dict)
self.assertEqual(
set(output.keys()),
set(expected_output_shape.keys()),
Expand All @@ -251,7 +251,7 @@ def run_output_asserts(layer, output, eager=False):
msg="Unexpected output shape",
)
elif isinstance(expected_output_shape, list):
self.assertTrue(isinstance(output, list))
self.assertIsInstance(output, list)
self.assertEqual(
len(output),
len(
Expand Down
14 changes: 7 additions & 7 deletions keras_core/trainers/compile_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def test_single_output_case(self):
y_true, y_pred, sample_weight=sample_weight
)
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 2)
self.assertAllClose(result["mean_squared_error"], 0.055833336)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0725)

compile_metrics.reset_state()
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 2)
self.assertAllClose(result["mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
Expand Down Expand Up @@ -98,14 +98,14 @@ def test_list_output_case(self):
y_true, y_pred, sample_weight=sample_weight
)
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 8)
self.assertAllClose(result["mean_squared_error"], 0.055833336)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0725)

compile_metrics.reset_state()
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 8)
self.assertAllClose(result["mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_dict_output_case(self):
y_true, y_pred, sample_weight=sample_weight
)
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 8)
# Result values obtained from `tf.keras`
# m = tf.keras.metrics.MeanSquaredError()
Expand All @@ -189,7 +189,7 @@ def test_dict_output_case(self):

compile_metrics.reset_state()
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 8)
self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
Expand All @@ -206,7 +206,7 @@ def test_name_conversions(self):
compile_metrics.build(y_true, y_pred)
compile_metrics.update_state(y_true, y_pred, sample_weight=None)
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertIsInstance(result, dict)
self.assertEqual(len(result), 3)
self.assertAllClose(result["acc"], 0.333333)
self.assertAllClose(result["accuracy"], 0.333333)
Expand Down
Loading

0 comments on commit fa547ec

Please sign in to comment.