Skip to content

Commit

Permalink
[Thrust] Increase static workspace size (#16937)
Browse files Browse the repository at this point in the history
This PR increases the thrust workspace size, since in practice
we found that the current workspace size can still be insufficient.
Thrust sort may require larger workspace when the number of elements
being sorted is large (e.g., in Llama3 that is 128k).
  • Loading branch information
MasterJH5574 authored Apr 27, 2024
1 parent 3ff3daa commit 63e0a0f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relax/backend/dispatch_sort_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ def estimate_thrust_workspace_size(self, call: relax.Call) -> int:
int32_byte_per_elem = DataType("int32").bits // 8
num_elem = reduce(mul, input_shape, 1)
input_size = num_elem * input_byte_per_elem
# Most GPU algorithms take O(n) space or less, we choose 8N + 4MB as a safe estimation
# Most GPU algorithms take O(n) space or less, we choose 8N + 8MB as a safe estimation
# for algorithm workspace.
# The current thrust sort implementation may need extra int64 and int32 arrays
# for temporary data, so we further add this part to the workspace.
return (
8 * input_size
+ 4 * 1024 * 1024
+ 8 * 1024 * 1024
+ num_elem * (int64_byte_per_elem + int32_byte_per_elem)
)

Expand Down

0 comments on commit 63e0a0f

Please sign in to comment.