[sharktank] Add Perplexity pre-submit test #579

Merged
merged 29 commits into from
Nov 23, 2024
Changes from all commits
Commits
29 commits
814b364
Add perplexity pre-submit tests
archana-ramalingam Nov 21, 2024
9e965a2
Fix log indentation
archana-ramalingam Nov 21, 2024
3089e61
Add better logging
archana-ramalingam Nov 21, 2024
c1f7acc
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 21, 2024
8266b62
Change hip device for pre-submits
archana-ramalingam Nov 21, 2024
8ea8a24
Merge branch 'perplexity-pre-submit' of https://github.com/nod-ai/sha…
archana-ramalingam Nov 21, 2024
93b6a2b
Add is_mi300x marker to skip ppl tests
archana-ramalingam Nov 21, 2024
e88a3ea
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 21, 2024
c5d788e
Rename workflow and job for consistency
archana-ramalingam Nov 22, 2024
cca5a14
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 22, 2024
1b1cec7
Merge branch 'perplexity-pre-submit' of https://github.com/nod-ai/sha…
archana-ramalingam Nov 22, 2024
33d25e1
Add --bs and adjust mean perplexity
archana-ramalingam Nov 22, 2024
453b4d2
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 22, 2024
71a520d
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 22, 2024
a2d7ac9
Update workflow names to be consistent
archana-ramalingam Nov 22, 2024
03bbced
Use batch_size fixture
archana-ramalingam Nov 22, 2024
1335166
Test longrun tests
archana-ramalingam Nov 22, 2024
2c6b191
Merge branch 'perplexity-pre-submit' of https://github.com/nod-ai/sha…
archana-ramalingam Nov 22, 2024
6626fa1
Correct bs to batch_size
archana-ramalingam Nov 22, 2024
1038878
Merge main changes
archana-ramalingam Nov 22, 2024
33c7995
Add page_cache_size calc to torch ppl
archana-ramalingam Nov 22, 2024
850c775
Remove exceptions
archana-ramalingam Nov 22, 2024
51be7c1
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 22, 2024
29aadd3
update test name
archana-ramalingam Nov 22, 2024
9886d6f
Merge branch 'perplexity-pre-submit' of https://github.com/nod-ai/sha…
archana-ramalingam Nov 22, 2024
56de4e0
Revert Test longrun tests
archana-ramalingam Nov 22, 2024
df33088
Merge branch 'main' into perplexity-pre-submit
archana-ramalingam Nov 22, 2024
5dcac71
Opt in large tests with --run-nightly-llama-tests flag
archana-ramalingam Nov 23, 2024
ec02ffb
Merge branch 'perplexity-pre-submit' of https://github.com/nod-ai/sha…
archana-ramalingam Nov 23, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ci_eval.yaml
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: CI - Perplexity
name: CI - sharktank perplexity

on:
workflow_dispatch:
@@ -75,7 +75,7 @@ jobs:
iree-base-runtime

- name: Run perplexity test with IREE
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
77 changes: 77 additions & 0 deletions .github/workflows/ci_eval_short.yaml
@@ -0,0 +1,77 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: CI - sharktank perplexity short

on:
workflow_dispatch:
pull_request:
push:
branches:
- main

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit). The workflow name is prepended to avoid conflicts between
# different workflows.
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

jobs:
test_perplexity_iree:
name: "Llama3.1 8B FP16"
strategy:
matrix:
version: [3.11]
runs-on: [llama-mi300x-3]
fail-fast: false
runs-on: ${{matrix.runs-on}}
defaults:
run:
shell: bash
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
steps:
- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Cache Pip Packages
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

- name: Install sharktank deps
run: |
python -m pip install --no-compile --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/

# Install latest iree-turbine.
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"

# Try with the latest IREE nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
iree-base-runtime

- name: Run perplexity test with vmfb
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
102 changes: 62 additions & 40 deletions sharktank/sharktank/evaluate/perplexity_iree.py
@@ -99,9 +99,9 @@ def wrapper(*args, **kwargs):

func_name = func.__name__
if func_name == "get_perplexity":
func_name = f"Total time to calculate perplexity"
func_name = f"Calculate perplexity"
elif func_name == "compile_model":
func_name = f"Total time to export and compile"
func_name = f"Export & compile"
logger.info(f" {func_name}: {time_taken}")
return result

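The relabeled strings above live inside the file's `@timeit` decorator. A minimal sketch of that pattern, assuming the decorator wraps methods, maps a few known function names to friendlier display labels, and logs elapsed time (`print` stands in for the module's `logger.info`):

```python
import time
from functools import wraps

def timeit(func):
    """Log how long a call took, using a display label for known methods."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - start
        # Map internal names to the labels used in this PR's log output.
        label = {
            "get_perplexity": "Calculate perplexity",
            "compile_model": "Export & compile",
        }.get(func.__name__, func.__name__)
        print(f" {label}: {elapsed:.2f}s")
        return result
    return wrapper

@timeit
def get_perplexity():
    return 42.0

ppl = get_perplexity()  # logs " Calculate perplexity: ...s"
```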
@@ -127,7 +127,7 @@ def print_token_comparison(self, i):
def compile_model(self, weight_path_str):
self.weight_path_str = weight_path_str

logger.info(f"Compiling: {self.weight_path_str}")
logger.info(f" Compiling: {self.weight_path_str}")

export_artifacts = ExportArtifacts(
irpa_path=self.weight_path_str,
@@ -143,7 +143,7 @@ def compile_model(self, weight_path_str):
@timeit
def load_model(self, weight_path, tokenizer, vmfb_path):

config = LlamaModelConfig(
self.config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(weight_path.properties),
block_seq_stride=16,
kv_cache_type=self.kv_cache_type,
@@ -153,18 +153,18 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
tensor_parallelism_size=self.tensor_parallelism_size,
)

if config.tensor_parallelism_size > 1:
weight_path.root_theta = shard_theta(weight_path.root_theta, config)
if self.config.tensor_parallelism_size > 1:
weight_path.root_theta = shard_theta(weight_path.root_theta, self.config)

theta = weight_path.root_theta

if config.hp.expert_count:
if config.hp.model_arch == "grok":
model = PagedGrokModelV1(theta, config)
if self.config.hp.expert_count:
if self.config.hp.model_arch == "grok":
model = PagedGrokModelV1(theta, self.config)
else:
model = PagedMixtralModelV1(theta, config)
model = PagedMixtralModelV1(theta, self.config)
else:
model = PagedLlamaModelV1(theta, config)
model = PagedLlamaModelV1(theta, self.config)

self.generator = TorchGenerator(model, tokenizer)

@@ -177,7 +177,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
self.haldevice = self.runner.config.device

@timeit
def get_prompts(self):
def get_prompts(self, num_prompts):
test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
"text"
]
@@ -191,12 +191,15 @@ def get_prompts(self):
s.replace("\n", "").rstrip()
for s in test_prompts
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
]
][0:num_prompts]

self.test_prompts = test_prompts

self.bs = len(test_prompts)

return test_prompts
logger.info(f" Batch size: {self.bs}")
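`get_prompts` now truncates the filtered wikitext prompts to `num_prompts` and records the batch size from the result. A self-contained sketch of that selection logic, with illustrative sample strings standing in for the real dataset:

```python
# Keep non-empty lines with at least 20 words and fewer than two '='
# signs (which filters out wikitext section headers), then truncate to
# num_prompts. Sample data below is illustrative, not actual wikitext.
raw = [
    " = Chapter 1 = \n",              # section header: two '=' signs, dropped
    "",                               # empty, dropped
    "too short\n",                    # under 20 words, dropped
    " ".join(["word"] * 25) + "\n",   # kept
    " ".join(["token"] * 30) + "\n",  # kept, but trimmed off by num_prompts
]
num_prompts = 1
test_prompts = [
    s.replace("\n", "").rstrip()
    for s in raw
    if s != "" and len(s.split()) >= 20 and s.count("=") < 2
][0:num_prompts]
bs = len(test_prompts)  # batch size follows the prompt count
```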

@timeit
def prefill_vmfb(self, token_batch, i):

seq_block_ids = self.batch.pad_block_ids()
@@ -252,25 +255,7 @@ def decode_vmfb(self, token_batch, i):
return decode_logits

@timeit
def get_logits(self):

token_ids, seq_lens = self.generator.tokenizer.encode(
self.test_prompts,
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
)

logger.info(f" Prompts for Evaluation:")
for idx, prompt in enumerate(self.test_prompts):
logger.info(
f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
)

self.max_prompt_length = max(seq_lens)

self.token_ids = torch.tensor(token_ids, device=self.torch_device)
self.attention_mask = (
(self.token_ids != 0).int().detach().clone().to(self.torch_device)
)
def get_logits(self, page_cache_size):

is_first_token = True
start = 0
@@ -306,6 +291,7 @@ def get_logits(self):
token_batch=token_batch,
seq_lens_batch=self.seq_lens_batch,
bs=self.bs,
page_cache_size=page_cache_size,
)

self.cache_state = ireert.asdevicearray(
@@ -355,11 +341,31 @@ def compute_perplexity(self):
}

@timeit
def get_perplexity(self, test_prompts):
def get_perplexity(self):

self.test_prompts = test_prompts
token_ids, seq_lens = self.generator.tokenizer.encode(
self.test_prompts,
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
)

self.page_cache_size = (
len(token_ids[0]) // self.config.block_seq_stride
) * self.bs + 1

logger.debug(f" Prompts for Evaluation:")
for idx, prompt in enumerate(self.test_prompts):
logger.debug(
f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
)

self.max_prompt_length = max(seq_lens)

self.token_ids = torch.tensor(token_ids, device=self.torch_device)
self.attention_mask = (
(self.token_ids != 0).int().detach().clone().to(self.torch_device)
)

self.get_logits()
self.get_logits(page_cache_size=self.page_cache_size)
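The `page_cache_size` computed above sizes the paged KV cache from the padded prompt length: enough cache pages to hold every prompt in the batch, plus one spare. A worked sketch of the formula under assumed values (the stride, prompt length, and batch size here are illustrative, not taken from the PR's runs):

```python
block_seq_stride = 16    # tokens per KV-cache page (matches LlamaModelConfig above)
padded_prompt_len = 128  # len(token_ids[0]) after padding (assumed)
bs = 5                   # batch size (assumed)

pages_per_prompt = padded_prompt_len // block_seq_stride  # 8 pages
page_cache_size = pages_per_prompt * bs + 1               # 40 pages + 1 spare
```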

self.out_logits = self.out_logits[..., :-1, :].contiguous()
self.token_ids = self.token_ids[..., 1:].contiguous()
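The two `.contiguous()` slices align each position's logits with the token that follows it: logits at position `t` are scored against the token at `t + 1`. A stdlib-only sketch of computing perplexity from that alignment (the tiny two-token vocabulary is illustrative; the real code applies the same shift to full vocab-sized logit tensors):

```python
import math

def perplexity(logits, token_ids):
    """Perplexity from per-position vocab logits and target token ids.

    logits: list of [vocab_size] rows for positions 0..T-1
    token_ids: list of T token ids
    Mirrors the diff: logits[..., :-1, :] scored against token_ids[..., 1:].
    """
    shifted_logits = logits[:-1]  # predictions for the *next* token
    targets = token_ids[1:]       # the tokens those positions should predict
    nll = 0.0
    for row, target in zip(shifted_logits, targets):
        # cross-entropy via log-softmax for the target token
        log_z = math.log(sum(math.exp(x) for x in row))
        nll += log_z - row[target]
    return math.exp(nll / len(targets))

# A model that puts nearly all its weight on the correct next token
# should score a perplexity close to 1.
confident = [[-10.0, 10.0], [10.0, -10.0], [0.0, 0.0]]  # last row unused
ppl = perplexity(confident, [0, 1, 0])
```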
@@ -387,7 +393,9 @@ def run_perplexity(
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
num_prompts,
):
start = time.time()
perplexity = Perplexity(
torch_device=torch_device,
iree_device=iree_device,
Expand All @@ -398,12 +406,19 @@ def run_perplexity(
attention_kernel=attention_kernel,
)

test_prompts = perplexity.get_prompts()
logger.info(f" Total test prompts: {len(test_prompts)}")
perplexity.get_prompts(num_prompts=num_prompts)

vmfb_path = perplexity.compile_model(weight_path_str)
perplexity.load_model(weight_path, tokenizer, vmfb_path)
ppl = perplexity.get_perplexity(test_prompts)
ppl = perplexity.get_perplexity()

end = time.time()
total_time = round(end - start, 2)
if total_time < 60:
total_time = str(total_time) + " secs"
else:
total_time = str(round(total_time / 60, 2)) + " mins"
logger.info(f" Total time taken: {total_time}")

return ppl

@@ -412,7 +427,7 @@ def main(argv):
parser = cli.create_parser()
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--torch-device", help="Torch device (or default)")
parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'")
parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
parser.add_argument(
"--iree-hip-target",
action="store",
@@ -437,6 +452,12 @@
default=1,
help="Number of devices for tensor parallel sharding",
)
parser.add_argument(
"--num-prompts",
type=int,
default=100,
help="Number of prompts for perplexity test",
)

cli.add_tokenizer_options(parser)
cli.add_input_dataset_options(parser)
@@ -460,6 +481,7 @@
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
)

logger.info(f"\n{json.dumps(ppl, indent=2)}")