Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Triton up to [68aa962e67baa191cec5aac173255abdba80db1a](https://github.com/openai/triton/commits/68aa962e67baa191cec5aac173255abdba80db1a) #311

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,11 @@ def get_or_create_triton_kernel(
# `JITFunction._get_config` to get the specialization_attr.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
backend = backend_init_func(device, compute_capability)
for i, _, v in scalar_args:
args_for_specialization_attr[i] = v
specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access

specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr) # pylint: disable=protected-access
constants = dict(metaparams)
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
Expand All @@ -382,7 +383,7 @@ def get_or_create_triton_kernel(
cache_key = (
fn,
tuple(signature.items()),
tuple(vars(specialization_attr).values()),
tuple(specialization_attr.arg_properties),
tuple(constants.items()),
num_warps,
num_stages,
Expand All @@ -402,7 +403,6 @@ def get_or_create_triton_kernel(
"enable_fp_fusion": enable_fp_fusion,
}

backend = backend_init_func(device, compute_capability)
options = backend.parse_options(opts)

kernel_hash = abs(hash(cache_key))
Expand Down Expand Up @@ -645,7 +645,7 @@ def prune_configs(configs, named_args, **kwargs):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
16 if (i in specialization_attr.divisible_by_16) else 0,
16 if (i in specialization_attr.divisibility_16) else 0,
)
)
elif i not in specialization_attr.equal_to_1:
Expand Down
4 changes: 2 additions & 2 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,10 @@ def test_specialization(self):
# Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
# However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
# specialize", leaving `b_ptr`, `N`, and `stride_cm`.
self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
self.assertEqual(specialization.attrs.divisibility_16, [1, 3, 9])
# `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.attrs.equal_to_1, (8, 10))
self.assertEqual(specialization.attrs.equal_to_1, [8, 10])


if __name__ == "__main__":
Expand Down
Loading