Format eval output and enable CUDA support (#2569)
Summary:
Pull Request resolved: #2569

When using eval_llama_lib, evaluation runs much faster with CUDA enabled when a GPU is available. This diff enables that.
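
A minimal sketch of the device-selection pattern this applies; the standalone names below (device, model) are illustrative stand-ins, not the library's API:

    import torch

    # Prefer CUDA when a GPU is present; otherwise fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Any torch.nn.Module follows the same pattern the diff uses for manager.model.
    model = torch.nn.Linear(8, 2)
    model = model.eval().to(device)  # eval() switches to inference behavior before moving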

In addition, it reformats the eval output into a more digestible per-task summary.
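
A sketch of the per-task printing this adds; the nested "results" dict mirrors the shape the new loop iterates over, and the task names and metrics below are made up for illustration:

    # Hypothetical eval output shaped like eval_results in the diff below.
    eval_results = {
        "results": {
            "wikitext": {"word_perplexity": 12.34},
            "hellaswag": {"acc": 0.56},
        }
    }

    # One line per task is easier to scan than dumping the whole dict.
    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")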

Reviewed By: jerryzh168

Differential Revision: D55208754

fbshipit-source-id: 8744d58064b6bcab5567a62bb2bf99fe69507aa1
Jack-Khuu authored and facebook-github-bot committed Mar 23, 2024
1 parent 725c590 commit 579ccce
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions examples/models/llama2/eval_llama_lib.py
@@ -38,7 +38,9 @@ def __init__(
         super().__init__()
         self._model = model
         self._tokenizer = tokenizer
-        self._device = torch.device("cpu")
+        self._device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
 
     @property
@@ -153,12 +155,18 @@ def eval_llama(
     tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
 
     # Evaluate the model
+    model = (
+        manager.model.eval().to(device="cuda")
+        if torch.cuda.is_available()
+        else manager.model.to(device="cpu")
+    )
     eval_results = eval(
-        manager.model.to(device="cpu"),
+        model,
         tokenizer,
         args.tasks,
         args.limit,
         args.max_seq_length,
     )
 
-    print("Results: ", eval_results)
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
