Skip to content

Commit

Permalink
Fix threads parameter
Browse files Browse the repository at this point in the history
See #8
  • Loading branch information
marella committed May 25, 2023
1 parent 75ba8ef commit be827d0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
29 changes: 24 additions & 5 deletions ctransformers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,41 @@ def load_library(path: Optional[str] = None) -> Any:
path = find_library(path)
lib = CDLL(path)

lib.ctransformers_llm_create.argtypes = [c_char_p, c_char_p]
lib.ctransformers_llm_create.argtypes = [
c_char_p, # model_path
c_char_p, # model_type
]
lib.ctransformers_llm_create.restype = llm_p

lib.ctransformers_llm_delete.argtypes = [llm_p]
lib.ctransformers_llm_delete.restype = None

lib.ctransformers_llm_tokenize.argtypes = [llm_p, c_char_p, c_int_p]
lib.ctransformers_llm_tokenize.argtypes = [
llm_p,
c_char_p, # text
c_int_p, # output
]
lib.ctransformers_llm_tokenize.restype = c_int

lib.ctransformers_llm_detokenize.argtypes = [llm_p, c_int]
lib.ctransformers_llm_detokenize.argtypes = [
llm_p,
c_int, # token
]
lib.ctransformers_llm_detokenize.restype = c_char_p

lib.ctransformers_llm_is_eos_token.argtypes = [llm_p, c_int]
lib.ctransformers_llm_is_eos_token.argtypes = [
llm_p,
c_int, # token
]
lib.ctransformers_llm_is_eos_token.restype = c_bool

lib.ctransformers_llm_batch_eval.argtypes = [llm_p, c_int_p, c_int, c_int]
lib.ctransformers_llm_batch_eval.argtypes = [
llm_p,
c_int_p, # tokens
c_int, # n_tokens
c_int, # batch_size
c_int, # threads
]
lib.ctransformers_llm_batch_eval.restype = c_bool

lib.ctransformers_llm_sample.argtypes = [
Expand Down
1 change: 1 addition & 0 deletions models/llm.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class LLM {
if (threads < 0) {
threads = std::min((int)std::thread::hardware_concurrency(), 4);
}
threads = std::max(threads, 1);
const int n_past =
std::min(ContextLength() - (int)tokens.size(), previous_tokens_.Size());
if (!Eval(tokens, threads, n_past)) {
Expand Down

0 comments on commit be827d0

Please sign in to comment.