diff --git a/eval.py b/eval.py
index 5647c30721..a28eda0ef0 100644
--- a/eval.py
+++ b/eval.py
@@ -31,7 +31,7 @@ except:
     lm_eval_available = False
 
 
-from generate import _load_inference_model, encode_tokens, model_forward
+from generate import _initialize_model, encode_tokens, model_forward
 
 if lm_eval_available:
     try:  # lm_eval version 0.4
@@ -236,7 +236,7 @@ def eval_main(args) -> None:
     precision = name_to_dtype(model_dtype)
     set_precision(precision)
 
-    model = _load_inference_model(
+    model = _initialize_model(
         checkpoint_path,
         checkpoint_dir,
         params_path,
@@ -247,6 +247,7 @@ def eval_main(args) -> None:
         quantize,
         device,
         precision,
+        setup_caches=False,
         use_tp=False
     )
 
diff --git a/export.py b/export.py
index 649ce68455..9a783e2289 100644
--- a/export.py
+++ b/export.py
@@ -25,7 +25,7 @@ from export_aoti import export_model as export_model_aoti
 
 from model import Transformer
 
-from generate import _load_model, decode_one_token
+from generate import _load_model, decode_one_token, _initialize_model
 from quantize import quantize_model, name_to_dtype
 
 from torch._export import capture_pre_autograd_graph
@@ -41,25 +41,6 @@ def device_sync(device):
         print(f"device={device} is not yet suppported")
 
 
-class model_wrapper(nn.Module):
-    def __init__(self, model, device):
-        super().__init__()
-
-        max_seq_length = 350
-        with torch.device(device):
-            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-
-        self.model = model
-        # init model here if necessary
-
-    def forward(self, idx, input_pos):
-        # input_pos: [B, 1]
-        # assert failed on symbolic shape during aot_compile?!
-        # but not for ET?
-        # assert input_pos.shape[-1] == 1
-        logits = self.model(idx, input_pos)
-        return logits  # sample(logits, **sampling_kwargs)
-
 
 def main(args):
     checkpoint_path = args.checkpoint_path
@@ -72,29 +53,24 @@ def main(args):
     precision = name_to_dtype(args.dtype)  # torch.float  # bfloat16
     set_precision(precision)
 
-    print("Loading model ...")
-    t0 = time.time()
-    model = _load_model(
-        checkpoint_path,
+    model = _initialize_model(
+        args.checkpoint_path,
         args.checkpoint_dir,
         args.params_path,
         args.params_table,
         args.gguf_path,
-        device=device,
-        precision=precision,
+        None,  # dso_path - cannot re-export exported model
+        None,  # pte_path - cannot re-export exported model
+        quantize,
+        device,
+        precision,
+        setup_caches=True,
         use_tp=False
     )
 
-    device_sync(device=device) # MKG
-    print(f"Time to load model: {time.time() - t0:.02f} seconds")
-
-    quantize_model(model, args.quantize)
-    # dtype:
-    if args.dtype:
-        model.to(dtype=name_to_dtype(args.dtype))
-
-    model = model_wrapper(model, device=device)
+    # if args.dtype:
+    #     model.to(dtype=name_to_dtype(args.dtype))
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
diff --git a/generate.py b/generate.py
index 8d0178f892..61d5185bbe 100644
--- a/generate.py
+++ b/generate.py
@@ -341,7 +341,7 @@ def _load_model(
 B_INST, E_INST = "[INST]", "[/INST]"
 
 
-def _load_inference_model(
+def _initialize_model(
     checkpoint_path,
     checkpoint_dir,
     params_path,
@@ -352,6 +352,7 @@ def _load_inference_model(
     quantize,
     device,
     precision,
+    setup_caches,
     use_tp # =False
 ):
     assert (
@@ -418,6 +419,11 @@ def _load_inference_model(
         device_sync(device=device) # MKG
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")
 
+    if setup_caches:
+        max_seq_length = 350
+        with torch.device(device):
+            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+
     model.to(dtype=precision)
     return model
 
@@ -470,7 +476,7 @@ def _main(
     is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
 
-    model = _load_inference_model(
+    model = _initialize_model(
         checkpoint_path,
         checkpoint_dir,
         params_path,
@@ -481,10 +487,11 @@ def _main(
         quantize,
         device,
         precision,
-        use_tp
+        False, # setup_caches
+        False, # use_tp
     )
 
-    # will add a version of _load_inference_model in future
+    # will add a version of _initialize_model in future
     # (need additional args)
     if is_speculative:
         draft_model = _load_model(
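
Usage sketch: the call below mirrors the _initialize_model call sites in this diff, with the positional order taken from the signature in generate.py. The checkpoint path, the bfloat16 dtype, and the None quantize spec are illustrative assumptions, not values required by this change.

# Sketch only: mirrors the _initialize_model call sites above; paths and dtype are placeholders.
import torch
from generate import _initialize_model

device = "cuda" if torch.cuda.is_available() else "cpu"
precision = torch.bfloat16          # assumed dtype for illustration

model = _initialize_model(
    "checkpoints/model.pth",        # checkpoint_path (hypothetical)
    None,                           # checkpoint_dir
    None,                           # params_path
    None,                           # params_table
    None,                           # gguf_path
    None,                           # dso_path - cannot re-export an exported model
    None,                           # pte_path - cannot re-export an exported model
    None,                           # quantize (assumed: no quantization requested)
    device,
    precision,
    setup_caches=True,              # export pre-allocates KV caches; eval and generate pass False
    use_tp=False,
)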