Skip to content

Commit

Permalink
add permute cols opcheck
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 13, 2024
1 parent c452a86 commit 7bc8316
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/kernels/test_permute_cols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import pytest
from vllm._custom_ops import permute_cols
from tests.kernels.utils import opcheck


@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
x = torch.randn(shape, dtype=dtype).cuda()
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
opcheck(torch.ops._C.permute_cols, (x, perm))
y = permute_cols(x, perm)
torch.testing.assert_close(y, x[:, perm])
12 changes: 12 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)


# TODO: has to be a better way to do this
try:
torch.ops._C.permute_cols # noqa B018

@torch.library.register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
except Exception:
pass


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
return torch.ops._C.permute_cols(a, perm)

Expand Down

0 comments on commit 7bc8316

Please sign in to comment.