Kernel docs #274

Merged · 3 commits · May 25, 2024
59 changes: 54 additions & 5 deletions torchao/kernel/intmm.py
@@ -21,10 +21,24 @@
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype
def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
Performs a safe integer matrix multiplication, choosing between the
torch.compile, cuBLAS, and fallback paths as appropriate.

Args:
input (torch.Tensor): The input tensor of shape [i, j].
mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k].

Returns:
torch.Tensor: The result of the matrix multiplication.

Raises:
AssertionError: If the tensors are not on the same device.
"""
# torch.compile path
if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

# error checking for cublas path
assert (
mat2.device == input.device
@@ -39,13 +53,13 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
and j_is_nonzero_multiple_of_8
and k_is_nonzero_multiple_of_8
)

if device_cpu or bad_dimensions_for_cublas:
# fallback path
return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
input.device.type
)

# cublas paths
if not mat2.is_contiguous(): # silently gives incorrect result without this
mat2 = mat2.contiguous()
@@ -58,18 +72,53 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
else:
def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
Performs a fallback integer matrix multiplication for torch versions before 2.2.

Args:
input (torch.Tensor): The input tensor of shape [i, j].
mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k].

Returns:
torch.Tensor: The result of the matrix multiplication in int32.
"""
# We can improve on this by writing Triton code that works for older versions of Triton
# that ship with 2.1 or 2.0.
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
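
For reference, a minimal sketch of how safe_int_mm might be called, assuming the import path torchao.kernel.intmm from the file above. The inner and output dimensions (64 and 16) are nonzero multiples of 8, so the cuBLAS dimension checks pass on GPU, while CPU inputs take the fallback path:

    import torch
    from torchao.kernel.intmm import safe_int_mm

    # int8 operands; j=64 and k=16 are nonzero multiples of 8, making the
    # cuBLAS path eligible on GPU. On CPU the matmul fallback is used.
    a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
    b = torch.randint(-128, 127, (64, 16), dtype=torch.int8)

    out = safe_int_mm(a, b)
    assert out.dtype == torch.int32 and out.shape == (32, 16)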


def int_matmul(a, b):
def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Performs integer matrix multiplication using intmm_triton if available and autotuner is enabled,
otherwise falls back to safe_int_mm.

Args:
a (torch.Tensor): The first matrix to multiply.
b (torch.Tensor): The second matrix to multiply.

Returns:
torch.Tensor: The result of the matrix multiplication.
"""
if intmm_triton is not None and AUTOTUNER_ENABLE:
return torch.ops.torchao.int_matmul(a, b)
return safe_int_mm(a, b)
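
A similar sketch for int_matmul; whether the call routes to the Triton op or to safe_int_mm depends on intmm_triton availability and the AUTOTUNER_ENABLE flag, but the call site looks the same either way:

    import torch
    from torchao.kernel.intmm import int_matmul

    a = torch.randint(-128, 127, (64, 128), dtype=torch.int8)
    b = torch.randint(-128, 127, (128, 32), dtype=torch.int8)

    # Dispatches to torch.ops.torchao.int_matmul when the Triton kernel and
    # autotuner are enabled; otherwise falls back to safe_int_mm.
    out = int_matmul(a, b)
    assert out.shape == (64, 32)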


def int_scaled_matmul(a, b, scales1):
def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
"""
Performs scaled integer matrix multiplication.

Args:
a (torch.Tensor): The first matrix to multiply.
b (torch.Tensor): The second matrix to multiply.
scales1 (torch.Tensor): The scaling factors for the rows of the result.

Returns:
torch.Tensor: The result of the scaled matrix multiplication.

Raises:
AssertionError: If the dimensions of the input tensors do not match the expected shapes.
"""
M, K = a.shape
K, N = b.shape
assert M == scales1.size(0)
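
The remainder of int_scaled_matmul falls outside this hunk, so the following usage sketch rests on two assumptions: scales1 carries one scale per output row (shape (M, 1), satisfying the M == scales1.size(0) assert), and the scaled result comes back in the scales' floating-point dtype:

    import torch
    from torchao.kernel.intmm import int_scaled_matmul

    M, K, N = 32, 64, 16
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8)
    b = torch.randint(-128, 127, (K, N), dtype=torch.int8)
    scales1 = torch.rand(M, 1)  # one scale per output row (assumed layout)

    # Roughly (a @ b accumulated in int32) scaled row-wise by scales1; the
    # exact dtype and broadcast semantics live in the elided body.
    out = int_scaled_matmul(a, b, scales1)
    assert out.shape == (M, N)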