From 814b364778391c2ebba207345c5662b0b94ae80f Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 21 Nov 2024 10:42:42 +0000 Subject: [PATCH 01/16] Add perplexity pre-submit tests --- .github/workflows/ci_eval.yaml | 2 +- .github/workflows/ci_eval_short.yaml | 77 +++++++++++++ .../sharktank/evaluate/perplexity_vmfb.py | 102 +++++++++++------- sharktank/sharktank/utils/load_llm.py | 23 ++-- .../tests/evaluate/perplexity_vmfb_test.py | 20 ++-- 5 files changed, 162 insertions(+), 62 deletions(-) create mode 100644 .github/workflows/ci_eval_short.yaml diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 856d37c40..1379d72c1 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -74,7 +74,7 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-nightly-llama-tests --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json test_perplexity_torch: timeout-minutes: 1000 diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml new file mode 100644 index 000000000..df4416bc3 --- /dev/null +++ b/.github/workflows/ci_eval_short.yaml @@ -0,0 +1,77 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: Llama3.1 8B perplexity test + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_perplexity_iree: + name: "IREE-Perplexity short test" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. 
Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run perplexity test with vmfb + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --num-prompts=5 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py index 4f95ae1bd..4f63729a6 100644 --- a/sharktank/sharktank/evaluate/perplexity_vmfb.py +++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py @@ -91,9 +91,9 @@ def wrapper(*args, **kwargs): func_name = func.__name__ if func_name == "get_perplexity": - func_name = f"Total time to calculate perplexity" + func_name = f"Calculate perplexity" elif func_name == "compile_model": - func_name = f"Total time to export and compile" + func_name = f"Export & compile" logger.info(f" {func_name}: {time_taken}") return result @@ -119,7 +119,7 @@ def print_token_comparison(self, i): def compile_model(self, weight_path_str): self.weight_path_str = weight_path_str - logger.info(f"Compiling: {self.weight_path_str}") + logger.info(f" Compiling: {self.weight_path_str}") export_artifacts = ExportArtifacts( irpa_path=self.weight_path_str, @@ -135,7 +135,7 @@ def compile_model(self, weight_path_str): @timeit def load_model(self, weight_path, tokenizer, vmfb_path): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -145,18 +145,18 @@ def load_model(self, weight_path, tokenizer, vmfb_path): tensor_parallelism_size=self.tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - weight_path.root_theta = shard_theta(weight_path.root_theta, config) + if self.config.tensor_parallelism_size > 1: + weight_path.root_theta = shard_theta(weight_path.root_theta, self.config) theta = weight_path.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, tokenizer) @@ -169,7 +169,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.haldevice = self.runner.config.device @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" ] @@ -183,12 
+183,15 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] + ][0:num_prompts] + + self.test_prompts = test_prompts self.bs = len(test_prompts) - return test_prompts + logger.info(f" Batch size: {self.bs}") + @timeit def prefill_vmfb(self, token_batch, i): seq_block_ids = self.batch.pad_block_ids() @@ -244,25 +247,7 @@ def decode_vmfb(self, token_batch, i): return decode_logits @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) - - self.max_prompt_length = max(seq_lens) - - self.token_ids = torch.tensor(token_ids, device=self.torch_device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.torch_device) - ) + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -298,6 +283,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=self.seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.cache_state = ireert.asdevicearray( @@ -347,11 +333,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - self.test_prompts = test_prompts + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.torch_device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.torch_device) + ) - self.get_logits() + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -379,7 +385,9 @@ def run_perplexity( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): + start = time.time() perplexity = Perplexity( torch_device=torch_device, iree_device=iree_device, @@ -390,12 +398,19 @@ def run_perplexity( attention_kernel=attention_kernel, ) - test_prompts = perplexity.get_prompts() - logger.info(f" Total test prompts: {len(test_prompts)}") + perplexity.get_prompts(num_prompts=num_prompts) vmfb_path = perplexity.compile_model(weight_path_str) perplexity.load_model(weight_path, tokenizer, vmfb_path) - ppl = perplexity.get_perplexity(test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -404,7 +419,7 @@ def main(argv): parser = cli.create_parser() parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") parser.add_argument("--torch-device", help="Torch device (or default)") - parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'") + 
parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") parser.add_argument( "--iree-hip-target", action="store", @@ -429,6 +444,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) @@ -452,6 +473,7 @@ def main(argv): kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py index acf56eb1b..47d9f0244 100644 --- a/sharktank/sharktank/utils/load_llm.py +++ b/sharktank/sharktank/utils/load_llm.py @@ -23,24 +23,20 @@ def __init__( self, model: PagedLlamaModelV1, tokenizer: InferenceTokenizer, - page_cache_size: int = 8192, # Need to look at the model more for this. end_token: int = 2, ): self.model = model self.tokenizer = tokenizer - if model.cache.is_paged: - self.shared_cache_state = model.cache.paged.allocate(page_cache_size) - self.free_pages = list(range(1, page_cache_size)) - else: - self.shared_cache_state = None self.end_token = end_token @property def block_seq_stride(self) -> int: return self.model.cache.block_seq_stride - def begin_batch(self, prompts: list[str], add_start_token: bool): + def begin_batch( + self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128 + ): token_ids, seq_lens = self.tokenizer.encode( prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride, @@ -48,8 +44,10 @@ def begin_batch(self, prompts: list[str], add_start_token: bool): ) token_ids = torch.tensor(token_ids, device=self.model.device) seq_lens = torch.tensor(seq_lens, device=self.model.device) - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=len(prompts)) return Batch(self, token_ids, seq_lens, cache_state) @@ -59,10 +57,11 @@ def begin_eval_batch( token_batch: torch.tensor, seq_lens_batch: torch.tensor, bs: int, + page_cache_size: int = 128, ): - - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=bs) return Batch(self, token_batch, seq_lens_batch, cache_state) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 93ffbe61c..6c84c0fff 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -10,7 +10,10 @@ from sharktank.evaluate import perplexity_vmfb -longrun = pytest.mark.skipif("not config.getoption('longrun')") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")', + reason="Skipping large tests when --run-quick-llama-test is set", +) @pytest.mark.usefixtures( @@ -27,7 +30,6 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) - @longrun def 
test_llama3_8B_f16_decomposed(self): # Llama 3.1 8B decomposed @@ -62,7 +64,7 @@ def test_llama3_8B_f16_decomposed(self): @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -97,7 +99,7 @@ def test_llama3_8B_f16(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed @@ -132,7 +134,7 @@ def test_llama3_8B_fp8_decomposed(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed @@ -167,7 +169,7 @@ def test_llama3_8B_fp8(self): @pytest.mark.xfail( reason="Sharding is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_f16_decomposed(self): # Llama 3.1 405B decomposed @@ -202,7 +204,7 @@ def test_llama3_405B_f16_decomposed(self): @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed @@ -237,7 +239,7 @@ def test_llama3_405B_f16(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed @@ -272,7 +274,7 @@ def test_llama3_405B_fp8_decomposed(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed From 9e965a21dda9ba9031b646ecf1a139b7f0fa1e8b Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 21 Nov 2024 10:43:03 +0000 Subject: [PATCH 02/16] Fix log indentation --- sharktank/sharktank/utils/export_artifacts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index bd33e1a62..a22abe67e 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -180,13 +180,13 @@ def export_to_mlir( cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) - logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}") + logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True) if proc.returncode != 0: raise ExportMlirException(proc, cwd) else: - logger.info(f"Exported to mlir successfully:\n" f"{proc.stdout}") + logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}") return proc.returncode @@ -223,7 +223,7 @@ def compile_to_vmfb( compile_args += args cmd = subprocess.list2cmdline(compile_args) - logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) return_code = proc.returncode if return_code != 0: @@ -277,7 +277,7 @@ def iree_benchmark_vmfb( benchmark_args += devices benchmark_args += args cmd = subprocess.list2cmdline(benchmark_args) - logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) return_code = proc.returncode if return_code != 0: From 3089e6144b8a6c419d5343972d41c47d2f2123bf Mon Sep 17 
00:00:00 2001 From: archana-ramalingam Date: Thu, 21 Nov 2024 10:54:11 +0000 Subject: [PATCH 03/16] Add better logging --- .../tests/evaluate/perplexity_vmfb_test.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 6c84c0fff..1fb8f3a6e 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -9,6 +9,10 @@ import json from sharktank.evaluate import perplexity_vmfb +from sharktank.utils.export_artifacts import ( + ExportMlirException, + IreeCompileException, +) skipif_run_quick_llama_test = pytest.mark.skipif( 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")', @@ -46,6 +50,7 @@ def test_llama3_8B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts=5", ] ) @@ -61,10 +66,8 @@ def test_llama3_8B_f16_decomposed(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -96,10 +99,8 @@ def test_llama3_8B_f16(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed @@ -131,10 +132,8 @@ def test_llama3_8B_fp8_decomposed(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed @@ -166,10 +165,10 @@ def test_llama3_8B_fp8(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) + @skipif_run_quick_llama_test @pytest.mark.xfail( reason="Sharding is unsupported", ) - @skipif_run_quick_llama_test def test_llama3_405B_f16_decomposed(self): # Llama 3.1 405B decomposed @@ -201,10 +200,8 @@ def test_llama3_405B_f16_decomposed(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed @@ -236,10 +233,8 @@ def test_llama3_405B_f16(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed @@ -271,10 +266,8 @@ def test_llama3_405B_fp8_decomposed(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, 
raises=IreeCompileException) def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed From 8266b627783f50f870fb7a7bc719047333bc3a38 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 21 Nov 2024 10:59:14 +0000 Subject: [PATCH 04/16] Change hip device for pre-submits --- .github/workflows/ci_eval_short.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index df4416bc3..ccb48ee93 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -74,4 +74,4 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --num-prompts=5 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --num-prompts=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json From 93b6a2b97672d320ecfa4c6ca2eedd1b2ba5b08e Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 21 Nov 2024 23:54:55 +0000 Subject: [PATCH 05/16] Add is_mi300x marker to skip ppl tests --- sharktank/tests/evaluate/perplexity_vmfb_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 1fb8f3a6e..5762b9e8e 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -14,6 +14,7 @@ IreeCompileException, ) +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") skipif_run_quick_llama_test = pytest.mark.skipif( 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")', reason="Skipping large tests when --run-quick-llama-test is set", @@ -26,6 +27,7 @@ "tensor_parallelism_size", "baseline_perplexity_scores", ) +@is_mi300x class PerplexityTest(unittest.TestCase): def setUp(self): self.current_perplexity_all = {} From c5d788edc0871f28d228eded50b78902b9f4b8e9 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 00:01:50 +0000 Subject: [PATCH 06/16] Rename workflow and job for consistency --- .github/workflows/ci_eval_short.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index ccb48ee93..13d9a685c 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: Llama3.1 8B perplexity test +name: CI - sharktank IREE perplexity on: workflow_dispatch: @@ -23,7 +23,7 @@ concurrency: jobs: test_perplexity_iree: - name: "IREE-Perplexity short test" + name: "Llama3.1 8B FP16" strategy: matrix: version: [3.11] From 33d25e1fdc25a53525e9f4586f4d839d0400f146 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 02:52:45 +0000 Subject: [PATCH 07/16] Add --bs and adjust mean perplexity --- .github/workflows/ci_eval.yaml | 2 +- .github/workflows/ci_eval_short.yaml | 2 +- .../tests/evaluate/perplexity_vmfb_test.py | 106 +++++++++++------- 3 files changed, 67 insertions(+), 43 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index cde6f294e..e36c045bd 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -75,7 +75,7 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-nightly-llama-tests --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json test_perplexity_torch: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index 13d9a685c..74ec3b405 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -74,4 +74,4 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --num-prompts=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 5762b9e8e..3dd2a7ed5 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -7,6 +7,7 @@ import unittest import pytest import json +import numpy as np from sharktank.evaluate import perplexity_vmfb from sharktank.utils.export_artifacts import ( @@ -52,18 +53,20 @@ def test_llama3_8B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", - f"--num-prompts=5", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = 
round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -86,17 +89,20 @@ def test_llama3_8B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -119,17 +125,20 @@ def test_llama3_8B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -152,17 +161,20 @@ def test_llama3_8B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -187,17 +199,20 @@ def test_llama3_405B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - 
current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -220,17 +235,20 @@ def test_llama3_405B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -253,17 +271,20 @@ def test_llama3_405B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -286,17 +307,20 @@ def test_llama3_405B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.bs}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) From a2d7ac98345201893625f0e191c86fc845c39c31 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 04:00:38 +0000 Subject: [PATCH 08/16] Update workflow names to be consistent --- .github/workflows/ci_eval.yaml | 2 +- .github/workflows/ci_eval_short.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index e36c045bd..3ec6e3e4d 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - Perplexity +name: CI - sharktank perplexity on: workflow_dispatch: diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index 74ec3b405..33f8d1f80 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - sharktank IREE perplexity +name: CI - sharktank perplexity short on: workflow_dispatch: From 03bbcedf9c9adfa3d1f4167bf8230d77ec6c0151 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 04:00:46 +0000 Subject: [PATCH 09/16] Use batch_size fixture --- sharktank/tests/evaluate/perplexity_vmfb_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 3dd2a7ed5..431b05a42 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -27,6 +27,7 @@ "get_iree_flags", "tensor_parallelism_size", "baseline_perplexity_scores", + "batch_size", ) @is_mi300x class PerplexityTest(unittest.TestCase): From 13351669d44ed7920fe71f8b04fb8ca82bbf58be Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 04:01:30 +0000 Subject: [PATCH 10/16] Test longrun tests --- .github/workflows/ci_eval.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 3ec6e3e4d..8c055717c 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -7,6 +7,7 @@ name: CI - sharktank perplexity on: + pull_request: workflow_dispatch: schedule: # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT. 
From 6626fa1271a24147dc8f143aedbf77464caac2b9 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 04:16:08 +0000 Subject: [PATCH 11/16] Correct bs to batch_size --- .../tests/evaluate/perplexity_vmfb_test.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 431b05a42..0f1adcbe7 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -54,12 +54,12 @@ def test_llama3_8B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -73,7 +73,7 @@ def test_llama3_8B_f16_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -90,12 +90,12 @@ def test_llama3_8B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -109,7 +109,7 @@ def test_llama3_8B_f16(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed @@ -126,12 +126,12 @@ def test_llama3_8B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -145,7 +145,7 @@ def test_llama3_8B_fp8_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed @@ -162,12 +162,12 @@ def test_llama3_8B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -200,12 +200,12 @@ def test_llama3_405B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", 
f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -219,7 +219,7 @@ def test_llama3_405B_f16_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed @@ -236,12 +236,12 @@ def test_llama3_405B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -255,7 +255,7 @@ def test_llama3_405B_f16(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed @@ -272,12 +272,12 @@ def test_llama3_405B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) @@ -291,7 +291,7 @@ def test_llama3_405B_fp8_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed @@ -308,12 +308,12 @@ def test_llama3_405B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", - f"--num-prompts={self.bs}", + f"--num-prompts={self.batch_size}", ] ) baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6 + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) From 33c79956101c15dc4f0bbe23422fdbe18c16d9af Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 18:27:11 +0000 Subject: [PATCH 12/16] Add page_cache_size calc to torch ppl --- .../sharktank/evaluate/perplexity_torch.py | 101 +++++++++++------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index da5fc104a..258e8c9a0 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -85,7 +85,7 @@ def wrapper(*args, **kwargs): func_name = func.__name__ 
if func_name == "get_perplexity": - func_name = "Total time" + func_name = "Calculate perplexity" logger.info(f" {func_name}: {time_taken}") return result @@ -110,7 +110,7 @@ def print_token_comparison(self, i): @timeit def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -120,23 +120,23 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern tensor_parallelism_size=tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) + if self.config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, self.config) theta = dataset.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, tokenizer) @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" @@ -152,34 +152,16 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] - - logger.info(f" num_test_prompts: {len(test_prompts)}") - - return test_prompts + ][0:num_prompts] - @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) + self.test_prompts = test_prompts - self.max_prompt_length = max(seq_lens) + self.bs = len(test_prompts) - self.token_ids = torch.tensor(token_ids, device=self.device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.device) - ) + logger.info(f" Batch size: {self.bs}") - self.bs = len(self.test_prompts) + @timeit + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -212,6 +194,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.batch.prefill() @@ -268,10 +251,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - self.test_prompts = test_prompts - self.get_logits() + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.device) + self.attention_mask = ( + (self.token_ids != 
0).int().detach().clone().to(self.device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -295,12 +299,22 @@ def run_perplexity_torch( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): - perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + start = time.time() + perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + perplexity.get_prompts(num_prompts=num_prompts) perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) - test_prompts = perplexity.get_prompts() - ppl = perplexity.get_perplexity(test_prompts=test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -322,6 +336,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding.", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) @@ -339,6 +359,7 @@ def main(argv): kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") From 850c7757739215b1d907e1e842c8d343f923e7ee Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 18:27:50 +0000 Subject: [PATCH 13/16] Remove exceptions --- sharktank/tests/evaluate/perplexity_iree_test.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index bb28b6ba4..103a37c5b 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -10,10 +10,6 @@ import numpy as np from sharktank.evaluate import perplexity_iree -from sharktank.utils.export_artifacts import ( - ExportMlirException, - IreeCompileException, -) is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") skipif_run_quick_llama_test = pytest.mark.skipif( @@ -73,7 +69,7 @@ def test_llama3_8B_f16_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -109,7 +105,7 @@ def test_llama3_8B_f16(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed @@ -145,7 +141,7 @@ def test_llama3_8B_fp8_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed @@ -219,7 +215,7 @@ def test_llama3_405B_f16_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed @@ -255,7 +251,7 @@ def 
test_llama3_405B_f16(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed @@ -291,7 +287,7 @@ def test_llama3_405B_fp8_decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException) + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed From 29aadd3e4e2143032aed0b9d7ee506126734d0e2 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 19:05:14 +0000 Subject: [PATCH 14/16] update test name --- .github/workflows/ci_eval_short.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index 33f8d1f80..3037edc5c 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -74,4 +74,4 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --run-quick-llama-test --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-quick-llama-test --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json From 56de4e00aabda71a1a325e1f8416eb4a23a245f1 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 22 Nov 2024 22:36:24 +0000 Subject: [PATCH 15/16] Revert Test longrun tests --- .github/workflows/ci_eval.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index c8157fbc8..0164b6cdc 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -7,7 +7,6 @@ name: CI - sharktank perplexity on: - pull_request: workflow_dispatch: schedule: # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT. 
From 5dcac71b9da2082ef720ec60b141006afe7158fb Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Sat, 23 Nov 2024 00:10:35 +0000 Subject: [PATCH 16/16] Opt in large tests with --run-nightly-llama-tests flag --- .github/workflows/ci_eval_short.yaml | 2 +- sharktank/tests/evaluate/perplexity_iree_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index 3037edc5c..4622f5c57 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -74,4 +74,4 @@ jobs: iree-base-runtime - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-quick-llama-test --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index 103a37c5b..d10d9f5db 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -13,8 +13,8 @@ is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") skipif_run_quick_llama_test = pytest.mark.skipif( - 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")', - reason="Skipping large tests when --run-quick-llama-test is set", + 'not config.getoption("run-nightly-llama-tests")', + reason="Run large tests if --run-nightly-llama-tests is passed", )
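Note: the workflows and tests in this series rely on pytest options (--run-nightly-llama-tests, --run-quick-llama-test, --bs/--num-prompts) and a batch_size fixture that are registered outside these patches, in the project's conftest.py (not shown here). The sketch below only illustrates how that wiring typically looks; the option names mirror the flags used above, but the defaults and fixture details are assumptions rather than the repository's actual conftest.

# conftest.py -- illustrative sketch only; the real conftest is not part of this patch series
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--run-nightly-llama-tests",
        action="store_true",
        default=False,
        help="Opt in to the large nightly Llama perplexity tests",
    )
    parser.addoption(
        "--run-quick-llama-test",
        action="store_true",
        default=False,
        help="Run only the short pre-submit Llama perplexity test",
    )
    parser.addoption(
        "--bs",
        type=int,
        default=100,  # assumed default; CI passes --bs=100 (nightly) or --bs=5 (short)
        help="Number of prompts (batch size) used for the perplexity run",
    )


@pytest.fixture(scope="class")
def batch_size(request):
    # Expose --bs to the unittest-style test classes as self.batch_size,
    # matching how the tests above slice baseline_perplexity["perplexities"].
    request.cls.batch_size = request.config.getoption("--bs")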