export and run with bfloat16 weight matrices #407

Open · wants to merge 2 commits into master
4 changes: 4 additions & 0 deletions Makefile
@@ -7,6 +7,10 @@ CC = gcc
run: run.c
	$(CC) -O3 -o run run.c -lm

.PHONY: runbf16
runbf16: runbf16.c
	$(CC) -O3 -o runbf16 runbf16.c -lm -fopenmp

# useful for a debug build, can then e.g. analyze with valgrind, example:
# $ valgrind --leak-check=full ./run out/model.bin -n 3
rundebug: run.c
66 changes: 66 additions & 0 deletions export.py
@@ -37,6 +37,16 @@ def serialize_fp32(file, tensor):
    b = struct.pack(f'{len(d)}f', *d)
    file.write(b)

def serialize_bf16(file, tensor):
    """ writes one bf16 tensor to file that is open in wb mode """
    d = tensor.detach().cpu().view(-1)
    if d.dtype == torch.bfloat16:
        d = d.view(torch.int16).numpy()
    else:
        d = d.to(torch.bfloat16).view(torch.int16).numpy()
    b = struct.pack(f'{len(d)}h', *d)
    file.write(b)
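For reference: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so the int16 values written above can be widened back to float32 with a 16-bit left shift. A minimal decode sketch (not part of this diff; the helper name and NumPy-based approach are illustrative assumptions):

# Illustrative decode helper (hypothetical, not in this PR): recover float32
# values from the int16-packed bf16 data written by serialize_bf16.
import numpy as np

def bf16_to_fp32(raw_int16):
    bits = raw_int16.view(np.uint16).astype(np.uint32) << 16  # bf16 bits go in the high half
    return bits.view(np.float32)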

def serialize_int8(file, tensor):
    """ writes one int8 tensor to file that is open in wb mode """
    d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
@@ -126,6 +136,60 @@ def legacy_export(model, filepath):
    out_file.close()
    print(f"wrote {filepath}")

def legacy_export_bf16(model, filepath):
    """ Legacy llama2.c v0 file layout, but with the weight matrices stored as bf16 (export version -1) """
    out_file = open(filepath, 'wb')

    # first write out the header
    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
    p = model.params
    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
    # legacy format uses negative/positive vocab size as a shared classifier flag
    if not shared_classifier:
        p.vocab_size = -p.vocab_size
    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                         n_kv_heads, p.vocab_size, p.max_seq_len)
    out_file.write(header)

    # next write out the embedding weights
    serialize_fp32(out_file, model.tok_embeddings.weight)

    # now all the layers
    # attention weights
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention_norm.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.attention.wq.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.attention.wk.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.attention.wv.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.attention.wo.weight)
    # ffn weights
    for layer in model.layers:
        serialize_fp32(out_file, layer.ffn_norm.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.feed_forward.w1.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.feed_forward.w2.weight)
    for layer in model.layers:
        serialize_bf16(out_file, layer.feed_forward.w3.weight)
    # final rmsnorm
    serialize_fp32(out_file, model.norm.weight)
    # freqs_cis
    serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
    serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])

    # final classifier weights
    if not shared_classifier:
        serialize_bf16(out_file, model.output.weight)

    # write to binary file
    out_file.close()
    print(f"wrote {filepath}")

# -----------------------------------------------------------------------------
# new version

@@ -408,6 +472,8 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
def model_export(model, filepath, version):
    if version == 0:
        legacy_export(model, filepath)
    elif version == -1:
        legacy_export_bf16(model, filepath)
    elif version == 1:
        version1_export(model, filepath)
    elif version == 2:
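Hypothetical usage, assuming export.py's existing --version flag accepts -1 and that runbf16 mirrors run's command line: export the checkpoint with version -1, build the runner with `make runbf16`, then run `./runbf16` on the resulting file. Called programmatically, the dispatch above amounts to:

# Hypothetical direct call (assumes `model` is an already-loaded Transformer instance):
model_export(model, 'model_bf16.bin', version=-1)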