From b70a7b690cdee075788df2d1491a9ca5d1812b1d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 26 Apr 2024 11:20:47 -0400 Subject: [PATCH] [Thrust] Increase static workspace size This PR increases the thrust workspace size, since in practice we found that the current workspace size can still be insufficient. Thrust sort may require larger workspace when the number of elements being sorted is large (e.g., in Llama3 that is 128k). --- python/tvm/relax/backend/dispatch_sort_scan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index eb82e49d9a99..561c141efa8b 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -186,13 +186,13 @@ def estimate_thrust_workspace_size(self, call: relax.Call) -> int: int32_byte_per_elem = DataType("int32").bits // 8 num_elem = reduce(mul, input_shape, 1) input_size = num_elem * input_byte_per_elem - # Most GPU algorithms take O(n) space or less, we choose 8N + 4MB as a safe estimation + # Most GPU algorithms take O(n) space or less, we choose 8N + 8MB as a safe estimation # for algorithm workspace. # The current thrust sort implementation may need extra int64 and int32 arrays # for temporary data, so we further add this part to the workspace. return ( 8 * input_size - + 4 * 1024 * 1024 + + 8 * 1024 * 1024 + num_elem * (int64_byte_per_elem + int32_byte_per_elem) )