@@ -49,13 +49,12 @@ def test_copy_blocks(
49
49
src_blocks = random .sample (range (num_blocks ), num_mappings )
50
50
remainig_blocks = list (set (range (num_blocks )) - set (src_blocks ))
51
51
dst_blocks = random .sample (remainig_blocks , 2 * num_mappings )
52
- copy_src = []
53
- copy_dst = []
52
+ block_mapping = {}
54
53
for i in range (num_mappings ):
55
- copy_src . append ( src_blocks [i ])
56
- copy_dst . append ( dst_blocks [2 * i ])
57
- copy_src . append ( src_blocks [ i ])
58
- copy_dst . append ( dst_blocks [ 2 * i + 1 ])
54
+ src = src_blocks [i ]
55
+ dst1 = dst_blocks [2 * i ]
56
+ dst2 = dst_blocks [ 2 * i + 1 ]
57
+ block_mapping [ src ] = [ dst1 , dst2 ]
59
58
60
59
# Create the KV caches.
61
60
key_caches , value_caches = kv_cache_factory (num_blocks , block_size ,
@@ -67,14 +66,15 @@ def test_copy_blocks(
67
66
cloned_value_caches = [value_cache .clone () for value_cache in value_caches ]
68
67
69
68
# Call the copy blocks kernel.
70
- cache_ops .copy_blocks (key_caches , value_caches , copy_src , copy_dst )
69
+ cache_ops .copy_blocks (key_caches , value_caches , block_mapping )
71
70
72
71
# Run the reference implementation.
73
- for src , dst in zip (copy_src , copy_dst ):
74
- for cloned_key_cache in cloned_key_caches :
75
- cloned_key_cache [dst ].copy_ (cloned_key_cache [src ])
76
- for cloned_value_cache in cloned_value_caches :
77
- cloned_value_cache [dst ].copy_ (cloned_value_cache [src ])
72
+ for src , dsts in block_mapping .items ():
73
+ for dst in dsts :
74
+ for cloned_key_cache in cloned_key_caches :
75
+ cloned_key_cache [dst ].copy_ (cloned_key_cache [src ])
76
+ for cloned_value_cache in cloned_value_caches :
77
+ cloned_value_cache [dst ].copy_ (cloned_value_cache [src ])
78
78
79
79
# Compare the results.
80
80
for key_cache , cloned_key_cache in zip (key_caches , cloned_key_caches ):
0 commit comments