Commit

Revert "Replace deprecated assert_allclose (#405)" (#409)
This reverts commit e226606.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod authored Jan 23, 2025
1 parent 8b7ee88 commit 1847c33
Showing 5 changed files with 24 additions and 24 deletions.
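For context: #405 had replaced the deprecated torch.testing.assert_allclose with torch.testing.assert_close, which also compares dtypes by default; that is why the forward change added check_dtype=False at call sites where a low-precision kernel output is checked against a higher-precision reference, and why the revert below drops that flag again. A minimal sketch of the behavioral difference (shapes and tolerances here are illustrative only, not taken from the tests):

import torch
from torch.testing import assert_close

out = torch.randn(4, dtype=torch.float16)  # stand-in for a kernel output
ref = out.to(torch.float32)                # reference computed in float32

# assert_close checks dtype (and device) by default, so a mixed-precision
# comparison raises unless check_dtype=False is passed explicitly.
assert_close(out, ref, check_dtype=False, atol=1e-3, rtol=1e-3)

# The deprecated assert_allclose promoted both sides and compared values
# only, so the equivalent call needed no extra flag:
# torch.testing.assert_allclose(out, ref, atol=1e-3, rtol=1e-3)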
4 changes: 2 additions & 2 deletions tests/kernel/wave/attention/decode_attention_test.py
@@ -20,7 +20,7 @@
     get_decode_attention_kernels,
 )
 import os
-from torch.testing import assert_close
+from torch.testing import assert_allclose
 from ..common.utils import (
     require_e2e,
     enable_scheduling_barriers,
@@ -133,4 +133,4 @@ def testFlashDecoding(
         with open(filename, "w") as f:
             f.write(mb_sv.module_op.get_asm())

-    assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)
+    assert_allclose(output, torch_ref)
4 changes: 2 additions & 2 deletions tests/kernel/wave/attention/paged_attention_test.py
@@ -23,7 +23,7 @@
     paged_decode_attention_shape,
 )
 import os
-from torch.testing import assert_close
+from torch.testing import assert_allclose
 from ..common.utils import (
     require_e2e,
     require_cdna3,
@@ -314,4 +314,4 @@ def testPagedFlashDecoding(
     else:
         ref_vllm_output = torch.load(os.path.join(artifact_directory, "output.pt"))

-    assert_close(output, ref_vllm_output, rtol=1e-3, atol=1e-3)
+    assert_allclose(output, ref_vllm_output, rtol=1e-3, atol=1e-3)
8 changes: 4 additions & 4 deletions tests/kernel/wave/attention/vanilla_attention_test.py
@@ -22,7 +22,7 @@
 )
 from iree.turbine.kernel.wave.constraints import MMAType
 import os
-from torch.testing import assert_close
+from torch.testing import assert_allclose
 from ..common.utils import (
     require_e2e,
     require_cdna3,
@@ -227,7 +227,7 @@ def repeat(
         with open(filename, "w") as f:
             f.write(mb.module_op.get_asm())

-    assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3)
+    assert_allclose(output, torch_ref)


 @require_e2e
@@ -428,10 +428,10 @@ def repeat(
             f.write(mb.module_op.get_asm())

     if "gfx94" in config["target"]:
-        assert_close(output, torch_ref, atol=2e-3, rtol=5e-3, check_dtype=False)
+        assert_allclose(output, torch_ref, atol=2e-3, rtol=5e-3)
     else:
         # TODO: Determine why the error is higher on gfx90.
-        assert_close(output, torch_ref, atol=3e-3, rtol=8e-1, check_dtype=False)
+        assert_allclose(output, torch_ref, atol=3e-3, rtol=8e-1)


 @require_e2e
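Both helpers share the same elementwise criterion, |actual - expected| <= atol + rtol * |expected|, so the loose rtol=8e-1 on gfx90 above admits large relative drift. A quick, purely illustrative check of what that bound permits:

import torch

expected = torch.tensor(0.5)
atol, rtol = 3e-3, 8e-1  # the gfx90 tolerances used above
bound = atol + rtol * expected.abs()
print(bound)  # tensor(0.4030): any value within +/-0.403 of 0.5 passes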
14 changes: 7 additions & 7 deletions tests/kernel/wave/runtime/cache_test.py
@@ -13,7 +13,7 @@
 import copy
 import pytest
 import torch
-from torch.testing import assert_close
+from torch.testing import assert_allclose
 import math
 import iree.turbine.kernel as tk
 import iree.turbine.kernel.lang as tkl
@@ -203,7 +203,7 @@ def testSameConfig(request):
     # First run/call to kernel, this should compile from scratch.
     output = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
     mb = base_attention(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output)
-    assert_close(output, torch_ref)
+    assert_allclose(output, torch_ref)
     assert isinstance(
         mb, tk.compiler.builder.ModuleBuilder
     ), "Expected first call to not be cached."
@@ -216,7 +216,7 @@
     cached_kernel = base_attention(
         q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output
     )
-    assert_close(output, torch_ref)
+    assert_allclose(output, torch_ref)
     assert (
         len(cache_manager.session_cache) == 1
     ), "Expected to keep size of cache because we reuse same kernel."
@@ -327,7 +327,7 @@ def testDifferentDynamicSameBlock(request):
         v_shape_0.permute([0, 2, 1]),
         output_shape_0,
     )
-    assert_close(output_shape_0, torch_ref_shape_0)
+    assert_allclose(output_shape_0, torch_ref_shape_0)
     assert isinstance(
         mb, tk.compiler.builder.ModuleBuilder
     ), "Expected first call to not be cached."
@@ -378,7 +378,7 @@ def testDifferentDynamicSameBlock(request):
         v_shape_1.permute([0, 2, 1]),
         output_shape_1,
     )
-    assert_close(output_shape_1, torch_ref_shape_1)
+    assert_allclose(output_shape_1, torch_ref_shape_1)
     assert (
         len(cache_manager.session_cache) == 1
     ), "Expected to keep size of cache because we reuse same kernel."
@@ -475,7 +475,7 @@ def testSameSizeDifferentBlock(request):
     mb_config_0 = base_attention(
         q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output
     )
-    assert_close(output, torch_ref)
+    assert_allclose(output, torch_ref)
     assert isinstance(
         mb_config_0, tk.compiler.builder.ModuleBuilder
     ), "Expected first call to not be cached."
@@ -499,7 +499,7 @@
     mb_config_1 = base_attention(
         q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output
     )
-    assert_close(output, torch_ref)
+    assert_allclose(output, torch_ref)
     assert (
         len(cache_manager.session_cache) == 2
     ), "Expected cache size to increment, because we use different block size/config."
18 changes: 9 additions & 9 deletions tests/kernel/wave/wave_sim_test.py
@@ -9,7 +9,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.wave_sim import wave_sim
-from torch.testing import assert_close
+from numpy.testing import assert_allclose


 def test_eltwise():
@@ -47,7 +47,7 @@ def eltwise(
     b = torch.randn(128, 256, dtype=torch.float32)
     c = torch.zeros(128, 256, dtype=torch.float32)
     eltwise(a, b, c)
-    assert_close(c, a + b)
+    assert_allclose(c, a + b)


 def test_broadcast_1():
@@ -85,7 +85,7 @@ def eltwise(
     b = torch.randn(256, dtype=torch.float32)
     c = torch.zeros(128, 256, dtype=torch.float32)
     eltwise(a, b, c)
-    assert_close(c, a + b)
+    assert_allclose(c, a + b)


 def test_broadcast_2():
@@ -120,7 +120,7 @@ def eltwise(
     b = torch.randn(256, dtype=torch.float32)
     c = torch.zeros(128, 256, dtype=torch.float32)
     eltwise(b, c)
-    assert_close(c, b + torch.zeros(128, 256, dtype=torch.float32))
+    assert_allclose(c, b + torch.zeros(128, 256, dtype=torch.float32))


 def test_broadcast_3():
@@ -155,7 +155,7 @@ def eltwise(
     b = torch.randn(256, dtype=torch.float32)
     c = torch.zeros(128, 256, dtype=torch.float32)
     eltwise(b, c)
-    assert_close(c, b[0] + torch.zeros(128, 256, dtype=torch.float32))
+    assert_allclose(c, b[0] + torch.zeros(128, 256, dtype=torch.float32))


 def test_gemm():
@@ -215,7 +215,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     b = torch.randn(128, 256, dtype=torch.float16)
     c = torch.zeros(64, 128, dtype=torch.float32)
     gemm(a, b, c)
-    assert_close(c, a @ b.T, check_dtype=False)
+    assert_allclose(c, a @ b.T)


 def test_transpose_1():
@@ -256,7 +256,7 @@ def transpose(
     a = torch.randn(128, 256, dtype=torch.float32)
     c = torch.zeros(256, 128, dtype=torch.float32)
     transpose(a, c)
-    assert_close(c, a.T)
+    assert_allclose(c, a.T)


 def test_transpose_2():
@@ -302,7 +302,7 @@ def transpose(
     a = torch.randn(128, 256, dtype=torch.float32)
     c = torch.zeros(256, 128, dtype=torch.float32)
     transpose(a, c)
-    assert_close(c, a.T)
+    assert_allclose(c, a.T)


 @pytest.mark.parametrize("n", [1, 2, 4])
@@ -408,4 +408,4 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:

     out = torch.zeros_like(out_ref)
     conv(x, we, out)
-    assert_close(out, out_ref, rtol=1e-05, atol=1e-05)
+    assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)

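Note that wave_sim_test.py reverts to numpy.testing.assert_allclose rather than the torch variant. NumPy's helper is not deprecated and coerces its inputs to arrays, so CPU tensors of different dtypes (e.g. test_gemm's float32 output vs. its float16 matmul reference) compare by value with no check_dtype flag. A small sketch under that assumption (variable names are hypothetical):

import torch
from numpy.testing import assert_allclose

# NumPy accepts CPU torch tensors via np.asarray, so float32 vs. float16
# values are compared numerically rather than rejected on dtype.
c = torch.randn(4, dtype=torch.float32)
ref = c.to(torch.float16)  # rounds to half precision
assert_allclose(c, ref, rtol=1e-3, atol=1e-3)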