Export new (pytorch#177)
* move model init

* move model init

* typo

* name change to initialize_model
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 592a307 commit 9d46c31
Showing 3 changed files with 25 additions and 41 deletions.
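The hunks below rename _load_inference_model to _initialize_model, give it a new setup_caches argument, and drop the model_wrapper class from export.py. A minimal sketch of the resulting call pattern, assembled from the new export.py call site; every argument value here is a placeholder for illustration and is not code from the commit:

import torch

from generate import _initialize_model

# Placeholder call for illustration only; real callers pass their own checkpoint
# paths, quantization config, device, and precision.
model = _initialize_model(
    "checkpoints/model.pth",  # checkpoint_path (placeholder)
    None,                     # checkpoint_dir
    None,                     # params_path
    None,                     # params_table
    None,                     # gguf_path
    None,                     # dso_path - cannot re-export an exported model
    None,                     # pte_path - cannot re-export an exported model
    None,                     # quantize
    "cpu",                    # device (placeholder)
    torch.bfloat16,           # precision
    setup_caches=True,        # export.py passes True; eval.py and generate.py pass False
    use_tp=False,
)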
5 changes: 3 additions & 2 deletions eval.py
@@ -31,7 +31,7 @@
except:
lm_eval_available = False

from generate import _load_inference_model, encode_tokens, model_forward
from generate import _initialize_model, encode_tokens, model_forward

if lm_eval_available:
try: # lm_eval version 0.4
@@ -236,7 +236,7 @@ def eval_main(args) -> None:
precision = name_to_dtype(model_dtype)
set_precision(precision)

model = _load_inference_model(
model = _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
@@ -247,6 +247,7 @@
quantize,
device,
precision,
setup_caches=False,
use_tp=False
)

46 changes: 11 additions & 35 deletions export.py
@@ -25,7 +25,7 @@
from export_aoti import export_model as export_model_aoti

from model import Transformer
from generate import _load_model, decode_one_token
from generate import _load_model, decode_one_token, _initialize_model
from quantize import quantize_model, name_to_dtype
from torch._export import capture_pre_autograd_graph

@@ -41,25 +41,6 @@ def device_sync(device):
print(f"device={device} is not yet suppported")


class model_wrapper(nn.Module):
def __init__(self, model, device):
super().__init__()

max_seq_length = 350
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

self.model = model
# init model here if necessary

def forward(self, idx, input_pos):
# input_pos: [B, 1]
# assert failed on symbolic shape during aot_compile?!
# but not for ET?
# assert input_pos.shape[-1] == 1
logits = self.model(idx, input_pos)
return logits # sample(logits, **sampling_kwargs)


def main(args):
checkpoint_path = args.checkpoint_path
@@ -72,29 +53,24 @@ def main(args):
precision = name_to_dtype(args.dtype) # torch.float # bfloat16
set_precision(precision)

print("Loading model ...")
t0 = time.time()
model = _load_model(
checkpoint_path,
model = _initialize_model(
args.checkpoint_path,
args.checkpoint_dir,
args.params_path,
args.params_table,
args.gguf_path,
device=device,
precision=precision,
None, # dso_path - cannot re-export exported model
None, # pte_path - cannot re-export exported model
quantize,
device,
precision,
setup_caches=True,
use_tp=False
)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

quantize_model(model, args.quantize)

# dtype:
if args.dtype:
model.to(dtype=name_to_dtype(args.dtype))

model = model_wrapper(model, device=device)
# if args.dtype:
# model.to(dtype=name_to_dtype(args.dtype))

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path
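In export.py, the removed model_wrapper class was where the KV caches were allocated before export (it also just forwarded idx and input_pos to the wrapped model). That cache setup now runs inside _initialize_model when the caller passes setup_caches=True, as the generate.py hunk below shows, so export.py no longer needs to wrap the model. The moved snippet, quoted for reference with its dimensions unchanged:

# Previously run in export.py's model_wrapper.__init__, now run inside
# _initialize_model when the caller passes setup_caches=True:
max_seq_length = 350
with torch.device(device):
    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)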
15 changes: 11 additions & 4 deletions generate.py
@@ -341,7 +341,7 @@ def _load_model(

B_INST, E_INST = "[INST]", "[/INST]"

def _load_inference_model(
def _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
@@ -352,6 +352,7 @@ def _load_inference_model(
quantize,
device,
precision,
setup_caches,
use_tp # =False
):
assert (
@@ -418,6 +419,11 @@
device_sync(device=device) # MKG
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

if setup_caches:
max_seq_length = 350
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

model.to(dtype=precision)

return model
@@ -470,7 +476,7 @@ def _main(
is_speculative = draft_checkpoint_path is not None
is_chat = "chat" in str(checkpoint_path)

model = _load_inference_model(
model = _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
@@ -481,10 +487,11 @@
quantize,
device,
precision,
use_tp
False, # setup_caches
False, # use_tp
)

# will add a version of _load_inference_model in future
# will add a version of _initialize_model in future
# (need additional args)
if is_speculative:
draft_model = _load_model(
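Taken together, the three call sites touched by this commit differ only in the new setup_caches flag; only export.py opts into cache setup, consistent with it previously doing so via the removed model_wrapper. A quick summary assembled from the hunks above:

# Call sites of _initialize_model after this commit (from the hunks above):
#   eval.py     -> _initialize_model(..., setup_caches=False, use_tp=False)
#   export.py   -> _initialize_model(..., setup_caches=True,  use_tp=False)
#   generate.py -> _initialize_model(..., False, False)  # setup_caches, use_tp (positional)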
