Commit 79b304c

Initial commit
committed Mar 23, 2023 · 0 parents

10 files changed, +736 -0 lines changed
 

‎.gitignore

+164
@@ -0,0 +1,164 @@
.envrc

models/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

‎LICENSE.md

+9
@@ -0,0 +1,9 @@
MIT License

Copyright (c) 2023 Andrei Betlen

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

‎README.md

+40
@@ -0,0 +1,40 @@
# `llama.cpp` Python Bindings

Simple Python bindings for @ggerganov's [`llama.cpp`](https://github.com/ggerganov/llama.cpp) library.

These bindings expose the low-level `llama.cpp` C API through a complete `ctypes` interface.
This module also exposes a high-level Python API that is more convenient to use and follows a familiar format.

# Install

```bash
pip install llama_cpp
```

# Usage

```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="models/7B/...")
>>> output = llm("Q: Name the planets in the solar system? A: ", max_tokens=32, stop=["Q:", "\n"], echo=True)
>>> print(output)
{
  "id": "cmpl-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
  "object": "text_completion",
  "created": 1679561337,
  "model": "models/7B/...",
  "choices": [
    {
      "text": "Q: Name the planets in the solar system? A: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune and Pluto.",
      "index": 0,
      "logprobs": None,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 14,
    "completion_tokens": 28,
    "total_tokens": 42
  }
}
```
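
Not part of this commit, but useful context for the README: the returned completion is a plain dict in the familiar text-completion shape shown above, so its fields can be read directly. A minimal sketch, assuming a model file exists at a placeholder path:

```python
# Illustrative sketch only (not in the commit); the model path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="models/7B/ggml-model.bin")

output = llm("Q: Name the planets in the solar system? A: ", max_tokens=32, stop=["Q:", "\n"])

# The result is an ordinary dict shaped like the README example above.
completion = output["choices"][0]["text"]
finish_reason = output["choices"][0]["finish_reason"]  # "stop" or "length"
total_tokens = output["usage"]["total_tokens"]

print(f"{completion!r} (finish_reason={finish_reason}, total_tokens={total_tokens})")
```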

‎examples/basic.py

+8
@@ -0,0 +1,8 @@
import json
from llama_cpp import Llama

llm = Llama(model_path="models/...")

output = llm("Q: Name the planets in the solar system? A: ", max_tokens=32, stop=["Q:", "\n"], echo=True)

print(json.dumps(output, indent=2))
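
As an aside (not part of the commit), `Llama.__call__` in llama_cpp/llama.py below also accepts sampling and formatting parameters beyond the ones used in this example. A hedged sketch with a placeholder model path:

```python
# Illustrative sketch only: exercises the remaining __call__ parameters
# defined in llama_cpp/llama.py. The model path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="models/7B/ggml-model.bin", seed=1337)

output = llm(
    "Q: Name the planets in the solar system? A: ",
    max_tokens=32,
    temperature=0.8,    # passed to top-p/top-k sampling as temp
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.1,
    stop=["Q:", "\n"],  # text is truncated at the first matching stop sequence
    echo=False,         # when True, the prompt is prepended to the returned text
    suffix="\n",        # when given, appended to the returned text
)

print(output["choices"][0]["text"])
```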

‎llama_cpp/__init__.py

+2
@@ -0,0 +1,2 @@
from .llama_cpp import *
from .llama import *

‎llama_cpp/llama.py

+131
@@ -0,0 +1,131 @@
import uuid
import time
import multiprocessing
from typing import List, Optional

from . import llama_cpp

class Llama:
    def __init__(
        self,
        model_path: str,
        n_ctx: int = 512,
        n_parts: int = -1,
        seed: int = 1337,
        f16_kv: bool = False,
        logits_all: bool = False,
        vocab_only: bool = False,
        n_threads: Optional[int] = None,
        model_name: Optional[str]=None,
    ):
        self.model_path = model_path
        self.model = model_name or model_path

        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = n_ctx
        self.params.n_parts = n_parts
        self.params.seed = seed
        self.params.f16_kv = f16_kv
        self.params.logits_all = logits_all
        self.params.vocab_only = vocab_only

        self.n_threads = n_threads or multiprocessing.cpu_count()

        self.tokens = (llama_cpp.llama_token * self.params.n_ctx)()

        self.ctx = llama_cpp.llama_init_from_file(
            self.model_path.encode("utf-8"), self.params
        )

    def __call__(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 16,
        temperature: float = 0.8,
        top_p: float = 0.95,
        echo: bool = False,
        stop: List[str] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
    ):
        text = ""
        finish_reason = "length"
        completion_tokens = 0

        prompt_tokens = llama_cpp.llama_tokenize(
            self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
        )

        if prompt_tokens + max_tokens > self.params.n_ctx:
            raise ValueError(
                f"Requested tokens exceed context window of {self.params.n_ctx}"
            )

        for i in range(prompt_tokens):
            llama_cpp.llama_eval(
                self.ctx, (llama_cpp.c_int * 1)(self.tokens[i]), 1, i, self.n_threads
            )

        for i in range(max_tokens):
            token = llama_cpp.llama_sample_top_p_top_k(
                self.ctx,
                self.tokens,
                prompt_tokens + completion_tokens,
                top_k=top_k,
                top_p=top_p,
                temp=temperature,
                repeat_penalty=repeat_penalty,
            )
            if token == llama_cpp.llama_token_eos():
                finish_reason = "stop"
                break
            text += llama_cpp.llama_token_to_str(self.ctx, token).decode("utf-8")
            self.tokens[prompt_tokens + i] = token
            completion_tokens += 1

            any_stop = [s for s in stop if s in text]
            if len(any_stop) > 0:
                first_stop = any_stop[0]
                text = text[: text.index(first_stop)]
                finish_reason = "stop"
                break

            llama_cpp.llama_eval(
                self.ctx,
                (llama_cpp.c_int * 1)(self.tokens[prompt_tokens + i]),
                1,
                prompt_tokens + completion_tokens,
                self.n_threads,
            )

        if echo:
            text = prompt + text

        if suffix is not None:
            text = text + suffix

        return {
            "id": f"cmpl-{str(uuid.uuid4())}",  # Likely to change
            "object": "text_completion",
            "created": int(time.time()),
            "model": self.model,  # Likely to change
            "choices": [
                {
                    "text": text,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    def __del__(self):
        llama_cpp.llama_free(self.ctx)
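
Not part of the commit: one behaviour of `__call__` worth illustrating is the up-front context check. The prompt is tokenized first, and if `prompt_tokens + max_tokens` exceeds the `n_ctx` chosen at construction time the call raises before any generation happens. A sketch with a placeholder model path:

```python
# Illustrative sketch only; the model path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="models/7B/ggml-model.bin", n_ctx=512)

try:
    # Any non-empty prompt plus max_tokens=512 exceeds the 512-token window.
    llm("Q: Name the planets in the solar system? A: ", max_tokens=512)
except ValueError as err:
    print(err)  # Requested tokens exceed context window of 512
```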

‎llama_cpp/llama_cpp.py

+157
@@ -0,0 +1,157 @@
import ctypes

from ctypes import c_int, c_float, c_double, c_char_p, c_void_p, c_bool, POINTER, Structure

import pathlib

# Load the library
libfile = pathlib.Path(__file__).parent.parent / "libllama.so"
lib = ctypes.CDLL(str(libfile))


# C types
llama_token = c_int
llama_token_p = POINTER(llama_token)

class llama_token_data(Structure):
    _fields_ = [
        ('id', llama_token),  # token id
        ('p', c_float),       # probability of the token
        ('plog', c_float),    # log probability of the token
    ]

llama_token_data_p = POINTER(llama_token_data)

class llama_context_params(Structure):
    _fields_ = [
        ('n_ctx', c_int),         # text context
        ('n_parts', c_int),       # -1 for default
        ('seed', c_int),          # RNG seed, 0 for random
        ('f16_kv', c_bool),       # use fp16 for KV cache
        ('logits_all', c_bool),   # the llama_eval() call computes all logits, not just the last one

        ('vocab_only', c_bool),   # only load the vocabulary, no weights
    ]

llama_context_params_p = POINTER(llama_context_params)

llama_context_p = c_void_p

# C functions
lib.llama_context_default_params.argtypes = []
lib.llama_context_default_params.restype = llama_context_params

lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
lib.llama_init_from_file.restype = llama_context_p

lib.llama_free.argtypes = [llama_context_p]
lib.llama_free.restype = None

lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
lib.llama_model_quantize.restype = c_int

lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_int]
lib.llama_eval.restype = c_int

lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
lib.llama_tokenize.restype = c_int

lib.llama_n_vocab.argtypes = [llama_context_p]
lib.llama_n_vocab.restype = c_int

lib.llama_n_ctx.argtypes = [llama_context_p]
lib.llama_n_ctx.restype = c_int

lib.llama_get_logits.argtypes = [llama_context_p]
lib.llama_get_logits.restype = POINTER(c_float)

lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
lib.llama_token_to_str.restype = c_char_p

lib.llama_token_bos.argtypes = []
lib.llama_token_bos.restype = llama_token

lib.llama_token_eos.argtypes = []
lib.llama_token_eos.restype = llama_token

lib.llama_sample_top_p_top_k.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_double, c_double, c_double]
lib.llama_sample_top_p_top_k.restype = llama_token

lib.llama_print_timings.argtypes = [llama_context_p]
lib.llama_print_timings.restype = None

lib.llama_reset_timings.argtypes = [llama_context_p]
lib.llama_reset_timings.restype = None

lib.llama_print_system_info.argtypes = []
lib.llama_print_system_info.restype = c_char_p

# Python functions
def llama_context_default_params() -> llama_context_params:
    params = lib.llama_context_default_params()
    return params

def llama_init_from_file(path_model: bytes, params: llama_context_params) -> llama_context_p:
    """Load a ggml llama model from file.
    Allocate (almost) all memory needed for the model.
    Return NULL on failure"""
    return lib.llama_init_from_file(path_model, params)

def llama_free(ctx: llama_context_p):
    """Free all allocated memory"""
    lib.llama_free(ctx)

def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int) -> c_int:
    """Returns 0 on success"""
    return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)

def llama_eval(ctx: llama_context_p, tokens: llama_token_p, n_tokens: c_int, n_past: c_int, n_threads: c_int) -> c_int:
    """Run the llama inference to obtain the logits and probabilities for the next token.
    tokens + n_tokens is the provided batch of new tokens to process
    n_past is the number of tokens to use from previous eval calls
    Returns 0 on success"""
    return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)

def llama_tokenize(ctx: llama_context_p, text: bytes, tokens: llama_token_p, n_max_tokens: c_int, add_bos: c_bool) -> c_int:
    """Convert the provided text into tokens.
    The tokens pointer must be large enough to hold the resulting tokens.
    Returns the number of tokens on success, no more than n_max_tokens
    Returns a negative number on failure - the number of tokens that would have been returned"""
    return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)

def llama_n_vocab(ctx: llama_context_p) -> c_int:
    return lib.llama_n_vocab(ctx)

def llama_n_ctx(ctx: llama_context_p) -> c_int:
    return lib.llama_n_ctx(ctx)

def llama_get_logits(ctx: llama_context_p):
    """Token logits obtained from the last call to llama_eval()
    The logits for the last token are stored in the last row
    Can be mutated in order to change the probabilities of the next token
    Rows: n_tokens
    Cols: n_vocab"""
    return lib.llama_get_logits(ctx)

def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
    """Token Id -> String. Uses the vocabulary in the provided context"""
    return lib.llama_token_to_str(ctx, token)

def llama_token_bos() -> llama_token:
    return lib.llama_token_bos()

def llama_token_eos() -> llama_token:
    return lib.llama_token_eos()

def llama_sample_top_p_top_k(ctx: llama_context_p, last_n_tokens_data: llama_token_p, last_n_tokens_size: c_int, top_k: c_int, top_p: c_double, temp: c_double, repeat_penalty: c_double) -> llama_token:
    return lib.llama_sample_top_p_top_k(ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty)

def llama_print_timings(ctx: llama_context_p):
    lib.llama_print_timings(ctx)

def llama_reset_timings(ctx: llama_context_p):
    lib.llama_reset_timings(ctx)

def llama_print_system_info() -> bytes:
    """Print system information"""
    return lib.llama_print_system_info()
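
Not part of the commit: the README's claim that the low-level C API is exposed through `ctypes` can be illustrated directly against the bindings above, without the high-level `Llama` class. A minimal sketch, assuming a model file at a placeholder path and omitting error handling:

```python
# Illustrative sketch only: drive the raw bindings from llama_cpp/llama_cpp.py.
import multiprocessing

from llama_cpp import llama_cpp

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/7B/ggml-model.bin", params)  # placeholder path

# Tokenize the prompt into a caller-owned buffer of at most n_ctx tokens.
tokens = (llama_cpp.llama_token * params.n_ctx)()
n_prompt = llama_cpp.llama_tokenize(
    ctx, b"Q: Name the planets in the solar system? A: ", tokens, params.n_ctx, True
)

# Evaluate the whole prompt as one batch, then sample and evaluate one token at a time.
n_threads = multiprocessing.cpu_count()
llama_cpp.llama_eval(ctx, tokens, n_prompt, 0, n_threads)

n_past = n_prompt
for _ in range(16):
    token = llama_cpp.llama_sample_top_p_top_k(
        ctx, tokens, n_past, top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1
    )
    if token == llama_cpp.llama_token_eos():
        break
    print(llama_cpp.llama_token_to_str(ctx, token).decode("utf-8", errors="ignore"), end="", flush=True)
    tokens[n_past] = token
    llama_cpp.llama_eval(ctx, (llama_cpp.llama_token * 1)(token), 1, n_past, n_threads)
    n_past += 1

print()
llama_cpp.llama_print_timings(ctx)
llama_cpp.llama_free(ctx)
```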

‎poetry.lock

+159
Generated lock file; contents not rendered in the diff view.

‎pyproject.toml

+24
@@ -0,0 +1,24 @@
[tool.poetry]
name = "llama_cpp"
version = "0.1.0"
description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT"
readme = "README.md"
homepage = "https://github.com/abetlen/llama_cpp_python"
repository = "https://github.com/abetlen/llama_cpp_python"
packages = [{include = "llama_cpp"}]
include = [
    "LICENSE.md",
]

[tool.poetry.dependencies]
python = "^3.8.1"


[tool.poetry.group.dev.dependencies]
black = "^23.1.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

‎setup.py

+42
@@ -0,0 +1,42 @@
import os
import subprocess
from setuptools import setup, Extension

from distutils.command.build_ext import build_ext


class build_ext_custom(build_ext):
    def run(self):
        build_dir = os.path.join(os.getcwd(), "build")
        src_dir = os.path.join(os.getcwd(), "vendor", "llama.cpp")

        os.makedirs(build_dir, exist_ok=True)

        cmake_flags = [
            "-DLLAMA_STATIC=Off",
            "-DBUILD_SHARED_LIBS=On",
            "-DCMAKE_CXX_FLAGS=-fPIC",
            "-DCMAKE_C_FLAGS=-fPIC",
        ]
        subprocess.check_call(["cmake", src_dir, *cmake_flags], cwd=build_dir)
        subprocess.check_call(["cmake", "--build", "."], cwd=build_dir)

        # Move the shared library to the root directory
        lib_path = os.path.join(build_dir, "libllama.so")
        target_path = os.path.join(os.getcwd(), "libllama.so")
        os.rename(lib_path, target_path)


setup(
    name="llama_cpp",
    description="A Python wrapper for llama.cpp",
    version="0.0.1",
    author="Andrei Betlen",
    author_email="abetlen@gmail.com",
    license="MIT",
    py_modules=["llama_cpp"],
    ext_modules=[
        Extension("libllama", ["vendor/llama.cpp"]),
    ],
    cmdclass={"build_ext": build_ext_custom},
)
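
Not part of the commit, but it shows how the build and the loader fit together: `build_ext_custom` renames the CMake output to `libllama.so` at the repository root, and `llama_cpp/llama_cpp.py` resolves exactly that path via `pathlib.Path(__file__).parent.parent / "libllama.so"`. A tiny sketch of the path arithmetic for a hypothetical checkout location:

```python
# Illustrative sketch only: mirrors the loader's path logic for a hypothetical checkout.
import pathlib

module_file = pathlib.Path("/home/user/llama_cpp_python/llama_cpp/llama_cpp.py")
libfile = module_file.parent.parent / "libllama.so"
print(libfile)  # /home/user/llama_cpp_python/libllama.so, the same path setup.py renames the build output to
```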
