Skip to content

Commit

Permalink
change llama2's example input inside script
Browse files Browse the repository at this point in the history
  • Loading branch information
haowhsu-quic committed Mar 1, 2024
1 parent a392269 commit 1af3bdf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/models/llama2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def get_example_inputs(self):
else:
return (
torch.tensor(
[[1, 2, 3]], dtype=torch.int32
[[1, 2, 3]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
)

Expand All @@ -609,10 +609,10 @@ def get_example_inputs_kvcache(self):
cache_v = torch.zeros(cache_sizes)
return (
torch.tensor(
[[1]], dtype=torch.int32
[[1]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
torch.tensor(
0, dtype=torch.int32
0, dtype=torch.long
), # start_pos, what token of output are we on.
cache_k, # key caches
cache_v, # value caches
Expand Down
8 changes: 3 additions & 5 deletions examples/qualcomm/scripts/dummy_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,19 @@


def create_device_inputs(example_inputs, use_kv_cache):
    """Cast the example inputs to int32 and dump them as raw files.

    Args:
        example_inputs: Iterable of example inputs. Each item is a tensor,
            or a list of tensors that is stacked into one tensor first.
        use_kv_cache: When true, every input is written to its own
            ``input_0_<i>.raw`` file; otherwise only the first input is
            written (as ``input_0_0.raw``).

    Returns:
        A ``(inputs, input_list)`` pair. ``inputs`` is a tuple of the int32
        tensors; ``input_list`` is a single newline-terminated line with the
        space-separated names of the raw files written under
        ``args.artifact``.
    """
    # Stack list-valued inputs before casting: a plain list has no ``.to``.
    # NOTE(review): the int32 cast is applied to every input, including any
    # floating-point tensors (e.g. kv caches) -- confirm that is intended.
    inputs = [
        (torch.stack(inp) if isinstance(inp, list) else inp).to(torch.int32)
        for inp in example_inputs
    ]

    # Without kv cache only the first input is dumped.
    dumped = inputs if use_kv_cache else [inputs[0]]

    file_names = []
    for i, tensor in enumerate(dumped):
        # One file per input; the index keeps later inputs from overwriting
        # earlier ones and matches the name recorded in the input list.
        file_name = f"input_0_{i}.raw"
        tensor.numpy().tofile(f"{args.artifact}/{file_name}")
        file_names.append(file_name)

    input_list = " ".join(file_names) + "\n"
    return tuple(inputs), input_list


if __name__ == "__main__":
Expand Down Expand Up @@ -94,7 +92,7 @@ def create_device_inputs(example_inputs, use_kv_cache):
use_fp16 = False if args.ptq else True
build_executorch_binary(
instance.get_eager_model().eval(),
instance.get_example_inputs(),
inputs,
args.model,
f"{args.artifact}/{pte_filename}",
inputs,
Expand Down

0 comments on commit 1af3bdf

Please sign in to comment.