add et export with gguf with test (pytorch#245)
* add et export with gguf with test

* fix generate too

* add gguf path to generate
metascroy authored and malfet committed Jul 17, 2024
1 parent a45e86e commit 0906a11
Showing 5 changed files with 99 additions and 23 deletions.
19 changes: 18 additions & 1 deletion .github/workflows/et.yml
@@ -60,6 +60,13 @@ jobs:
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
popd
+mkdir gguf_files
+export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
+wget -O ${GGUF_TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
- name: Run inference
run: |
export MODEL_PATH=${PWD}/checkpoints/stories15M/stories15M.pt
@@ -75,7 +82,7 @@ jobs:
echo "Tests complete."
- name: Run inference
run: |
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
@@ -121,3 +128,13 @@ jobs:
echo "tests complete"
echo "******************************************"
+- name: Run GGUF export + inference
+  run: |
+    export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+    export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+    python torchchat.py export --gguf-path ${GGUF_PATH} --output-pte-path ${PWD}/${MODEL_NAME}.pte
+    python torchchat.py generate --gguf-path ${GGUF_PATH} --pte-path ${PWD}/${MODEL_NAME}.pte --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --max-new-tokens 20 > ${PWD}/output_et
+    cat ${PWD}/output_et
+    echo "Tests complete."
36 changes: 34 additions & 2 deletions build/builder.py
@@ -9,7 +9,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union

import torch
import torch._dynamo.config
@@ -29,6 +29,7 @@ class BuilderArgs:
params_path: Optional[Union[Path, str]] = None
params_table: Optional[str] = None
gguf_path: Optional[Union[Path, str]] = None
+gguf_kwargs: Optional[dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: str = "cpu"
@@ -91,6 +92,7 @@ def from_args(cls, args): # -> BuilderArgs:
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
+gguf_kwargs=None,
dso_path=args.dso_path,
pte_path=args.pte_path,
device=args.device,
@@ -174,9 +176,30 @@ def device_sync(device):
sys.path.append(str(wd))


+# TODO: remove these once ET supports _weight_int4pack_mm
+def _set_gguf_kwargs(builder_args, is_et, context: str):
+    assert context in ["export", "generate"]
+    assert builder_args.gguf_kwargs is None
+
+    if builder_args.gguf_path is None:
+        print("No gguf_path provided, so ignoring set_gguf_kwargs.")
+        return
+
+    builder_args.gguf_kwargs = {}
+    if is_et:
+        builder_args.gguf_kwargs["load_as_quantized"] = False
+
+
+def _unset_gguf_kwargs(builder_args):
+    builder_args.gguf_kwargs = None


def _load_model_gguf(builder_args):
    assert builder_args.gguf_path
-    model = Transformer.from_gguf(builder_args.gguf_path)
+    if builder_args.gguf_kwargs is None:
+        kwargs = {}
+    else:
+        kwargs = builder_args.gguf_kwargs
+    model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
    return model


@@ -254,6 +277,15 @@ def _initialize_model(
):
print("Loading model ...")
t0 = time.time()

+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+        print("Setting gguf_kwargs for generate.")
+        is_dso = builder_args.dso_path is not None
+        is_pte = builder_args.pte_path is not None
+        assert not (is_dso and is_pte)
+        assert builder_args.gguf_kwargs is None
+        _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
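
Taken together, the builder changes let a caller toggle how GGUF weights are materialized around a single call to _initialize_model. A minimal sketch of the intended pattern, mirroring the export.py changes further down (the helper name load_gguf_for_et is hypothetical; builder_args is assumed to have gguf_path set):

    from build.builder import (
        BuilderArgs,
        _initialize_model,
        _set_gguf_kwargs,
        _unset_gguf_kwargs,
    )

    def load_gguf_for_et(builder_args: BuilderArgs, quantize):
        # is_et=True makes _set_gguf_kwargs request a dequantized load
        # (gguf_kwargs = {"load_as_quantized": False}), since ExecuTorch
        # cannot run _weight_int4pack_mm yet.
        _set_gguf_kwargs(builder_args, is_et=True, context="export")
        model = _initialize_model(builder_args, quantize)
        # Clear gguf_kwargs so a later load (e.g. the AOTI path) can pick
        # its own settings; _set_gguf_kwargs asserts it starts from None.
        _unset_gguf_kwargs(builder_args)
        return model

Threading this through mutable state on builder_args rather than an explicit parameter is exactly what the TODO above marks for cleanup once ET supports _weight_int4pack_mm.
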
22 changes: 12 additions & 10 deletions build/gguf_loader.py
@@ -139,7 +139,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
return model


-def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
@@ -174,14 +174,14 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
in_features = mod.in_features
assert all(t.shape == (in_features, out_features))

-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-            state_dict[f"{fqn}.weight"] = weight_int4pack
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+            if load_state_dict:
+                q, s, z = Q4_0.unpack(t)
+                scales_and_zeros = pack_scales_and_zeros(s, z)
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                    q, inner_k_tiles
+                )
+                state_dict[f"{fqn}.weight"] = weight_int4pack
+                state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

parent = _fqn_lookup(_fqn_up(fqn), model)
setattr(
@@ -197,8 +197,10 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
),
)
        else:
-            state_dict[f"{fqn}.weight"] = to_float(t)
+            if load_state_dict:
+                state_dict[f"{fqn}.weight"] = to_float(t)

+    assert (state_dict == {}) == (not load_state_dict)
return model, state_dict


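
The reworked signature gives load_model_and_state_dict three call patterns. A sketch, assuming a local Q4_0 GGUF checkpoint at "model.gguf" (the path is illustrative):

    from build.gguf_loader import load_model_and_state_dict

    # Default path (the previous behavior): Q4_0 tensors are repacked via
    # torch.ops.aten._convert_weight_to_int4pack and returned in state_dict.
    model, state_dict = load_model_and_state_dict("model.gguf")
    model.load_state_dict(state_dict, assign=True)

    # ExecuTorch path: keep weights as plain floats (to_float) instead of
    # int4-packed, since ET does not support _weight_int4pack_mm yet.
    model, state_dict = load_model_and_state_dict("model.gguf", load_as_quantized=False)
    model.load_state_dict(state_dict, assign=True)

    # Structure only: skip building a state_dict entirely; the new assert
    # at the end of the function guarantees it comes back empty.
    model, state_dict = load_model_and_state_dict("model.gguf", load_state_dict=False)
    assert state_dict == {}
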
7 changes: 4 additions & 3 deletions build/model.py
@@ -246,10 +246,11 @@ def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))

@classmethod
-    def from_gguf(cls, gguf_path: str):
+    def from_gguf(cls, gguf_path: str, **kwargs):
        from build.gguf_loader import load_model_and_state_dict
-        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
-        model.load_state_dict(state_dict, assign=True)
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
+        if state_dict != {}:
+            model.load_state_dict(state_dict, assign=True)
return model


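
Since the kwargs are forwarded verbatim, Transformer.from_gguf keeps its old quantized-load behavior by default while allowing the dequantized ET load. A brief sketch ("model.gguf" is again an illustrative path):

    from build.model import Transformer

    # Default: int4-quantized load, as before (load_as_quantized=True and
    # inner_k_tiles=8 are now defaults of load_model_and_state_dict).
    model = Transformer.from_gguf("model.gguf")

    # Dequantized load for the ExecuTorch export path; the extra kwargs
    # are forwarded to load_model_and_state_dict unchanged, and an empty
    # state_dict (from load_state_dict=False) skips load_state_dict.
    model_et = Transformer.from_gguf("model.gguf", load_as_quantized=False)
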
38 changes: 31 additions & 7 deletions export.py
@@ -9,7 +9,7 @@

import torch

-from build.builder import _initialize_model, BuilderArgs
+from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

@@ -42,24 +42,48 @@ def main(args):
print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)


builder_args.dso_path = None
builder_args.pte_path = None
builder_args.setup_caches = True
-    model = _initialize_model(
-        builder_args,
-        quantize,
-    )

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path

+    # TODO: clean this up
+    # This mess is because ET does not support _weight_int4pack_mm right now
+    if not builder_args.gguf_path:
+        model = _initialize_model(
+            builder_args,
+            quantize,
+        )
+        model_to_pte = model
+        model_to_dso = model
+    else:
+        if output_pte_path:
+            _set_gguf_kwargs(builder_args, is_et=True, context="export")
+            model_to_pte = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)
+
+        if output_dso_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_dso = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)


with torch.no_grad():
if output_pte_path:
output_pte_path = str(os.path.abspath(output_pte_path))
print(f">{output_pte_path}<")
if executorch_export_available:
print(f"Exporting model using Executorch to {output_pte_path}")
-            export_model_et(model, builder_args.device, args.output_pte_path, args)
+            export_model_et(model_to_pte, builder_args.device, args.output_pte_path, args)
else:
print(
"Export with executorch requested but Executorch could not be loaded"
@@ -68,7 +92,7 @@ def main(args):
if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        export_model_aoti(model, builder_args.device, output_dso_path, args)
+        export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)


if __name__ == "__main__":
