From cd2cc6624d85b24f9aaeb3bf4c841f0dd72a85b8 Mon Sep 17 00:00:00 2001
From: liuzhenwei <109187816+zhenwei-intel@users.noreply.github.com>
Date: Mon, 8 May 2023 16:56:29 +0800
Subject: [PATCH] remove temp pt (#884)

---
 .../pytorch/text-generation/deployment/gen_ir.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/huggingface/pytorch/text-generation/deployment/gen_ir.py b/examples/huggingface/pytorch/text-generation/deployment/gen_ir.py
index e8dec926b5f..059b8c0d42c 100644
--- a/examples/huggingface/pytorch/text-generation/deployment/gen_ir.py
+++ b/examples/huggingface/pytorch/text-generation/deployment/gen_ir.py
@@ -12,7 +12,7 @@
 )
 parser.add_argument('--dtype', default=None, type=str)
 parser.add_argument('--output_model', default="./ir", type=str)
-parser.add_argument('--pt_file', default="./model.pt", type=str)
+parser.add_argument('--pt_file', default="temp.pt", type=str)
 args = parser.parse_args()
 print(args)
 
@@ -28,6 +28,8 @@
 past_key_value_torch = tuple([(torch.zeros([1,16,32,256]), torch.zeros([1,16,32,256])) for i in range(28)])
 input_ids = input_ids[0:1].unsqueeze(0)
 attention_mask = attention_mask.unsqueeze(0)
+
+clean_model = False
 if 'llama' in model_id:
     past_key_value_torch = tuple([(torch.zeros([1,32,32,256]), torch.zeros([1,32,32,256])) for i in range(32)])
 if os.path.exists(args.pt_file):
@@ -40,6 +42,7 @@
         traced_model = torch.jit.trace(model, (input_ids, past_key_value_torch, attention_mask))
         torch.jit.save(traced_model, args.pt_file)
         print("Traced model is saved as {}".format(args.pt_file))
+        clean_model = True
     else:
         print("Model with {} can't be traced, please provide one.".format(args.dtype))
         sys.exit(1)
@@ -64,3 +67,5 @@
 graph.save(args.output_model)
 print('Neural Engine ir is saved as {}'.format(args.output_model))
 
+if clean_model:
+    os.remove(args.pt_file)
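
For context, a minimal sketch of the pattern this patch introduces in gen_ir.py: trace the model to a temporary temp.pt only when no .pt file already exists, run the IR compile step, then delete the traced file only if this run created it. The helper name trace_compile_cleanup and the compile_fn callback are illustrative stand-ins, not part of the repository; only the torch.jit and os calls mirror the script itself.

    import os
    import torch

    def trace_compile_cleanup(model, example_inputs, compile_fn,
                              pt_file="temp.pt", output_dir="./ir"):
        """Trace to a temporary TorchScript file, compile, and clean up.

        compile_fn is a hypothetical callable standing in for the Neural Engine
        IR compile step (pt_file -> output_dir); it is not a real API here.
        """
        clean_model = False
        if os.path.exists(pt_file):
            # A user-provided .pt file: reuse it and leave it on disk afterwards.
            print('PT model exists, compile will be executed.')
        else:
            # No .pt file was provided: trace and save a temporary one,
            # and remember that this run owns it.
            traced_model = torch.jit.trace(model, example_inputs)
            torch.jit.save(traced_model, pt_file)
            print("Traced model is saved as {}".format(pt_file))
            clean_model = True

        compile_fn(pt_file, output_dir)

        # Remove the traced file only if this run created it, mirroring the patch.
        if clean_model:
            os.remove(pt_file)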