Skip to content

Commit

Permalink
remove temp pt (intel#884)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenwei-intel authored May 8, 2023
1 parent 0065db0 commit cd2cc66
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
parser.add_argument('--dtype', default=None, type=str)
parser.add_argument('--output_model', default="./ir", type=str)
parser.add_argument('--pt_file', default="./model.pt", type=str)
parser.add_argument('--pt_file', default="temp.pt", type=str)
args = parser.parse_args()
print(args)

Expand All @@ -28,6 +28,8 @@
past_key_value_torch = tuple([(torch.zeros([1,16,32,256]), torch.zeros([1,16,32,256])) for i in range(28)])
input_ids = input_ids[0:1].unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)

clean_model = False
if 'llama' in model_id:
past_key_value_torch = tuple([(torch.zeros([1,32,32,256]), torch.zeros([1,32,32,256])) for i in range(32)])
if os.path.exists(args.pt_file):
Expand All @@ -40,6 +42,7 @@
traced_model = torch.jit.trace(model, (input_ids, past_key_value_torch, attention_mask))
torch.jit.save(traced_model, args.pt_file)
print("Traced model is saved as {}".format(args.pt_file))
clean_model = True
else:
print("Model with {} can't be traced, please provide one.".format(args.dtype))
sys.exit(1)
Expand All @@ -64,3 +67,5 @@

graph.save(args.output_model)
print('Neural Engine ir is saved as {}'.format(args.output_model))
if clean_model:
os.remove(args.pt_file)

0 comments on commit cd2cc66

Please sign in to comment.