Streaming conversion with no torch #176

Closed · wants to merge 6 commits

Changes from 5 commits
2 changes: 1 addition & 1 deletion README.md
@@ -139,7 +139,7 @@ ls ./models
65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model

# install Python dependencies
-python3 -m pip install torch numpy sentencepiece
+python3 -m pip install tqdm numpy sentencepiece

# convert the 7B model to ggml FP16 format
python3 convert-pth-to-ggml.py models/7B/ 1
118 changes: 96 additions & 22 deletions convert-pth-to-ggml.py
@@ -1,6 +1,5 @@
# Convert a LLaMA model checkpoint to a ggml compatible file
#
-# Load the model using Torch
# Iterate over all variables and write them to a binary file.
#
# For each variable, write the following:
@@ -17,11 +16,19 @@
# and vocabulary.
#

+from collections import defaultdict
import sys
import json
import struct
import numpy as np
-import torch
+from tqdm import tqdm
+import zipfile
+import pickle
+import concurrent.futures
+import io
+import threading
+import queue

from sentencepiece import SentencePieceProcessor

if len(sys.argv) < 3:
@@ -73,19 +80,66 @@ def get_n_parts(dim):

n_parts = get_n_parts(hparams["dim"])
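[Note] The body of get_n_parts is collapsed in this view. In the upstream script it maps the model's hidden dimension to the number of consolidated.*.pth checkpoint parts; a sketch of the assumed mapping (reconstructed for context, not part of this diff):

def get_n_parts(dim):
    # Assumed dim -> part-count mapping for the 7B/13B/30B/65B checkpoints.
    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
    if dim not in mappings:
        print("Invalid model dimension: " + str(dim))
        sys.exit(1)
    return mappings[dim]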

-print(hparams)
-print('n_parts = ', n_parts)
-
-for p in range(n_parts):
-    print('Processing part ', p)
-
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
+print(f'Model params.json: {hparams}')
+print(f'Parts to process: {n_parts}')


def load_model(fname):
    class Tensor:
        def __init__(self, shape, dtype, loadinfo):
            self.shape = shape
            self.dtype = dtype
            self.loadinfo = loadinfo

        def numpy(self):
            # Lazily read this tensor's raw storage out of the checkpoint
            # zip, so only one tensor needs to be materialized at a time.
            myzip, base_name, storage_offset, k, shape, dtype = self.loadinfo
            with myzip.open(f'{base_name}/data/{k}') as myfile:
                bytes_size = np.dtype(self.dtype).itemsize
                myfile.seek(storage_offset * bytes_size, 1)
                ret = np.empty(shape, dtype=dtype)
                myfile.readinto(ret.data)
                return ret

    def my_unpickle(datapkl, myzip, base_name):
        def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
            # `storage` is the persistent-load id tuple; index 1 is the dtype
            # resolved by find_class below, index 2 is the key of the raw
            # storage file under <base_name>/data/.
            storage_type = storage[1]
            obj_key = storage[2]
            return Tensor(shape=size, dtype=storage_type, loadinfo=(
                myzip, base_name, storage_offset,
                obj_key, size, storage_type
            ))

        class MyUnpickler(pickle.Unpickler):
            # Resolve only the handful of globals a LLaMA checkpoint pickle
            # refers to; anything else is rejected.
            def find_class(self, *p):
                if p == ('torch', 'HalfStorage'): return np.float16
                if p == ('torch', 'FloatStorage'): return np.float32
                if p == ('torch._utils', '_rebuild_tensor_v2'): return my_rebuild_tensor
                if p == ('collections', 'OrderedDict'): return dict
                raise ValueError(f'Unrecognized pickle {p}')

            def persistent_load(self, pid):
                # Hand the persistent id through untouched; my_rebuild_tensor unpacks it.
                return pid

        return MyUnpickler(datapkl).load()

    # A .pth checkpoint is a zip archive; derive its top-level directory name.
    myzip = zipfile.ZipFile(fname, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    with myzip.open(f'{base_name}/data.pkl') as myfile:
        model = my_unpickle(myfile, myzip, base_name)
    return model
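[Note] A quick way to see the checkpoint layout that load_model relies on is to list the archive directly; a minimal sketch (the checkpoint path is an assumption, and the top-level directory name varies, which is why load_model derives base_name from namelist()):

import zipfile

# A .pth checkpoint is a zip archive containing:
#   <base_name>/data.pkl    - pickled dict of tensor metadata
#   <base_name>/data/<key>  - raw storage bytes for each tensor
with zipfile.ZipFile("models/7B/consolidated.00.pth") as z:
    for entry in z.namelist()[:8]:
        print(entry)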

def get_fname(p):
fname = "/consolidated.0" + str(p) + ".pth"
return fname

def process_part(p):
fname = get_fname(p)
fname_model = sys.argv[1] + fname
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
    if p > 0:
        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)

-    model = torch.load(fname_model, map_location="cpu")
+    print(f"Processing part {fname}")

fout = open(fname_out, "wb")

@@ -123,19 +177,30 @@ def get_n_parts(dim):
fout.write(struct.pack("i", len(text)))
fout.write(text)

-    for k, v in model.items():
+    model = load_model(fname_model)

    # Bounded queue: at most two serialized tensors are buffered ahead of the writer thread.
    q = queue.Queue(maxsize=2)

    def writer():
        while True:
            item = q.get()
            fout.write(item.getvalue())
            q.task_done()

    threading.Thread(target=writer, daemon=True).start()

[Review comment on the `while True:` loop]
while True? Does this function ever return? I don't know if the function exists, but maybe something like while !q.atEnd()?
Please correct me if I'm wrong. I haven't worked with Python in a year or so.

[Author reply]
fixed!
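[Note] The author's fix lands in a later commit and is not shown in this 5-commit view. A minimal sketch of one conventional fix, assuming a None sentinel is pushed after the last tensor (an illustration, not necessarily the actual change):

    def writer():
        while True:
            item = q.get()
            if item is None:   # sentinel: producer is finished
                q.task_done()
                return
            fout.write(item.getvalue())
            q.task_done()

    # ...and after the tensor loop below: q.put(None) so the thread can exit.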

+    for k, v in (t := tqdm(model.items())):
+        t.set_description(f"Processing {k} with shape {tuple(v.shape)} and type {np.dtype(v.dtype)}")
name = k
shape = v.shape

# skip layers.X.attention.inner_attention.rope.freqs
if name[-5:] == "freqs":
continue

print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)

#data = tf.train.load_variable(dir_model, name).squeeze()
data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)

# for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w"
@@ -154,24 +219,33 @@ def get_n_parts(dim):
# default type is fp16
ftype_cur = 1
if ftype == 0 or n_dims == 1:
print(" Converting to float32")
# print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0

        # Build this tensor's record in memory; only the writer thread touches fout.
        memout = io.BytesIO()
# header
sname = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
memout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(sname);
memout.write(struct.pack("i", dshape[n_dims - 1 - i]))
memout.write(sname)

# data
-        data.tofile(fout)
+        memout.write(data.tobytes())
+        q.put(memout)

    q.join()  # wait for the writer thread to drain every queued tensor

# I hope this deallocates the memory ..
model = None

fout.close()

print("Done. Output file: " + fname_out + ", (part ", p, ")")
print("")

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = {executor.submit(process_part, p) for p in range(n_parts)}
for f in (concurrent.futures.as_completed(futures)):
if f.exception() is not None: raise f.exception()
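[Note] Future.result() re-raises any exception from the worker, so the explicit exception() check above can be written more compactly; an equivalent sketch:

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(process_part, p) for p in range(n_parts)]
    for f in concurrent.futures.as_completed(futures):
        f.result()  # re-raises any exception raised inside process_part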

print("All done.")