diff --git a/src/serialbox-python/compare/compare.py b/src/serialbox-python/compare/compare.py index f98c9988..aef44021 100644 --- a/src/serialbox-python/compare/compare.py +++ b/src/serialbox-python/compare/compare.py @@ -264,18 +264,23 @@ def compare_fields(serializers, field, savepoint, dim_bounds): "error": float('nan')}] # Compute error else: - abs_error = abs(value_2 - value_1) - rel_error = abs((value_2 - value_1) / value_2) if abs(value_2) > 1.0 else 0 - err = rel_error if abs(value_2) > 1.0 else abs_error - - - # Check error - if err > tol: - errors += [ - {"index": it_1.multi_index, "value_1": value_1, "value_2": value_2, - "error": err}] - max_abs_error = max(max_abs_error, abs_error) - max_rel_error = max(max_rel_error, rel_error) + if(value_1.dtype == 'bool'): + if(value_1 != value_2): + errors += [ + {"index": it_1.multi_index, "value_1": value_1, "value_2": value_2, + "error": 1.0}] + else: + abs_error = abs(value_2 - value_1) + rel_error = abs((value_2 - value_1) / value_2) if abs(value_2) > 1.0 else 0 + err = rel_error if abs(value_2) > 1.0 else abs_error + + # Check error + if err > tol: + errors += [ + {"index": it_1.multi_index, "value_1": value_1, "value_2": value_2, + "error": err}] + max_abs_error = max(max_abs_error, abs_error) + max_rel_error = max(max_rel_error, rel_error) it_1.iternext() it_2.iternext()