Skip to content

Commit f0317d9

Browse files
committed
Sync server: context checkpointing for hybrid and recurrent models
1 parent 869b22f commit f0317d9

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

llama_cpp/llama_cpp.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,14 @@ def llama_model_is_recurrent(model: llama_model_p, /) -> bool:
16941694
...
16951695

16961696

# // Returns true if the model is hybrid (like Jamba, Granite, etc.)
# LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
@ctypes_function("llama_model_is_hybrid", [llama_model_p_ctypes], ctypes.c_bool)
def llama_model_is_hybrid(model: llama_model_p, /) -> bool:
    """Returns true if the model is hybrid (like Jamba, Granite, etc.)"""
    ...
16971705
# // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
16981706
# LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
16991707
@ctypes_function("llama_model_is_diffusion", [llama_model_p_ctypes], ctypes.c_bool)
@@ -2539,6 +2547,92 @@ def llama_state_seq_load_file(
25392547
...
25402548

25412549

# Flags for the llama_state_seq_*_ext family of functions.
# Both constants are 1: SWA_ONLY is the historical name kept for
# backwards-compat and aliases PARTIAL_ONLY upstream.

# // for backwards-compat
LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1

# // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1

# C-side type of the flags parameter (llama_state_seq_flags is a uint32_t).
llama_state_seq_flags = ctypes.c_uint32
2558+
# LLAMA_API size_t llama_state_seq_get_size_ext(
#            struct llama_context * ctx,
#                     llama_seq_id   seq_id,
#            llama_state_seq_flags   flags);
@ctypes_function(
    "llama_state_seq_get_size_ext",
    [
        llama_context_p_ctypes,
        llama_seq_id,
        llama_state_seq_flags,
    ],
    ctypes.c_size_t,
)
def llama_state_seq_get_size_ext(
    ctx: llama_context_p,
    seq_id: llama_seq_id,
    flags: llama_state_seq_flags,
    /,
) -> int:
    """Get the size in bytes needed to copy the state of sequence `seq_id`,
    honoring `flags` (e.g. LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY for partial
    states such as SWA KV cache or recurrent cache)."""
    ...
2579+
2580+
# LLAMA_API size_t llama_state_seq_get_data_ext(
#            struct llama_context * ctx,
#                        uint8_t * dst,
#                         size_t   size,
#                   llama_seq_id   seq_id,
#          llama_state_seq_flags   flags);
@ctypes_function(
    "llama_state_seq_get_data_ext",
    [
        llama_context_p_ctypes,
        ctypes.POINTER(ctypes.c_uint8),
        ctypes.c_size_t,
        llama_seq_id,
        llama_state_seq_flags,
    ],
    ctypes.c_size_t,
)
def llama_state_seq_get_data_ext(
    ctx: llama_context_p,
    dst: ctypes.POINTER(ctypes.c_uint8),
    size: Union[int, ctypes.c_size_t],
    seq_id: llama_seq_id,
    flags: llama_state_seq_flags,
    /,
) -> int:
    """Copy the state of sequence `seq_id` into the `dst` buffer of `size`
    bytes, honoring `flags`. Returns the number of bytes written.
    `dst` must be at least llama_state_seq_get_size_ext(ctx, seq_id, flags)
    bytes."""
    ...
2607+
2608+
# LLAMA_API size_t llama_state_seq_set_data_ext(
#            struct llama_context * ctx,
#                  const uint8_t * src,
#                         size_t   size,
#                   llama_seq_id   dest_seq_id,
#          llama_state_seq_flags   flags);
@ctypes_function(
    "llama_state_seq_set_data_ext",
    [
        llama_context_p_ctypes,
        ctypes.POINTER(ctypes.c_uint8),
        ctypes.c_size_t,
        llama_seq_id,
        llama_state_seq_flags,
    ],
    ctypes.c_size_t,
)
def llama_state_seq_set_data_ext(
    ctx: llama_context_p,
    src: ctypes.POINTER(ctypes.c_uint8),
    size: Union[int, ctypes.c_size_t],
    dest_seq_id: llama_seq_id,
    flags: llama_state_seq_flags,
    /,
) -> int:
    """Restore a previously saved sequence state from the `src` buffer of
    `size` bytes into sequence `dest_seq_id`, honoring `flags`. Returns the
    number of bytes read (0 on failure)."""
    ...
25422636
# //
25432637
# // Decoding
25442638
# //

0 commit comments

Comments
 (0)