diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index 6bf5446f8..6bb105979 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -141,12 +141,14 @@ jobs:
             --with-flux-data \
             --with-t5-data \
             --with-vae-data \
+            --with-quark-data \
             sharktank/tests/models/clip/clip_test.py \
             sharktank/tests/models/t5/t5_test.py \
             sharktank/tests/models/flux/flux_test.py \
             sharktank/tests/models/vae/vae_test.py \
+            sharktank/tests/models/llama/quark_parity_test.py \
             --durations=0 \
-            --timeout=600
+            --timeout=800

   test_integration:
diff --git a/sharktank/conftest.py b/sharktank/conftest.py
index 3416e50c9..5aae72d41 100644
--- a/sharktank/conftest.py
+++ b/sharktank/conftest.py
@@ -125,6 +125,15 @@ def pytest_addoption(parser):
         ),
     )

+    parser.addoption(
+        "--with-quark-data",
+        action="store_true",
+        default=False,
+        help=(
+            "Enable tests that use quark data such as models not part of the source code."
+        ),
+    )
+
     # TODO: Remove all hardcoded paths in CI tests
     parser.addoption(
         "--llama3-8b-tokenizer-path",
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 780f7bc13..5d338bd74 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -325,16 +325,17 @@ def main():
         intermediates_saver.save_file(
             args.save_intermediates_path + "_prefill.safetensors"
         )
-    counter = 0
-    while not batch.done:
-        batch.decode()
-        if args.save_intermediates_path:
-            intermediates_saver.save_file(
-                args.save_intermediates_path + f"_step_{counter}.safetensors"
-            )
-        print(f":: Result tokens: {batch.results}")
-        batch.print_current_results()
-        counter += 1
+    if not args.skip_decode:
+        counter = 0
+        while not batch.done:
+            batch.decode()
+            if args.save_intermediates_path:
+                intermediates_saver.save_file(
+                    args.save_intermediates_path + f"_step_{counter}.safetensors"
+                )
+            print(f":: Result tokens: {batch.results}")
+            batch.print_current_results()
+            counter += 1


 if __name__ == "__main__":
diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py
index 1ffdffd30..b45696fe4 100644
--- a/sharktank/tests/models/llama/quark_parity_test.py
+++ b/sharktank/tests/models/llama/quark_parity_test.py
@@ -3,24 +3,31 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
+import os
 from safetensors import safe_open
 import torch
 import unittest
 import pytest
+from pathlib import Path
+import subprocess
+
+with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')")


-@pytest.mark.skip(reason="need to generate values to compare against")
 class QuarkParityTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.path_prefix = Path("/shark-dev/quark_test")
+
+    @with_quark_data
     def test_compare_against_quark(self):
-        def both(key, index=None):
-            o = ours[key]
-            t = theirs[key]
-            if index is None:
-                return o, t
-            else:
-                return o[index], t[index]
+        sharktank_dir = str(
+            Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent
+        )
+        our_path = self.path_prefix / "ours_prefill.safetensors"
+        if os.path.exists(our_path):
+            os.remove(our_path)

         mapping = dict()
         for i in range(32):
@@ -41,18 +48,49 @@ def both(key, index=None):
             mapping[a] = b
             mapping[a + "_input_0"] = b + "_input_0"

+        command = [
+            "python",
+            "-m",
+            "sharktank.examples.paged_llm_v1",
+            "The capitol of Texas is",
+            f"--irpa-file={self.path_prefix}/fp8_bf16_weight.irpa",
+            f"--tokenizer-config-json=/data/llama3.1/8b/tokenizer.json",
+            "--fake-quant",
+            "--attention-kernel=torch",
+            "--activation-dtype=bfloat16",
+            f"--save_intermediates_path={self.path_prefix}/ours",
+            "--use-hf",
+            "--attention-dtype=bfloat16",
+            "--skip-decode",
+            "--block-seq-stride=16",
+        ]
+        command = subprocess.list2cmdline(command)
+        proc = subprocess.run(
+            command, shell=True, capture_output=True, cwd=sharktank_dir
+        )
+
         ours = dict()
-        with safe_open("../ours_newest_prefill.safetensors", "pytorch") as st:
+        with safe_open(our_path, "pytorch") as st:
             for key in st.keys():
                 ours[key] = st.get_tensor(key)

-        theirs = dict()
-        with safe_open("../theirs2.safetensors", "pytorch") as st:
+        golden = dict()
+        golden_path = self.path_prefix / "golden.safetensors"
+        with safe_open(golden_path, "pytorch") as st:
             for key in st.keys():
                 if key in mapping:
-                    theirs[mapping[key]] = st.get_tensor(key)
+                    golden[mapping[key]] = st.get_tensor(key)

         test_layers = [v for k, v in mapping.items()]
+
+        def both(key, index=None):
+            o = ours[key]
+            t = golden[key]
+            if index is None:
+                return o, t
+            else:
+                return o[index], t[index]
+
         for lyr in test_layers:
             name = lyr
             if name in ours.keys() and name != "freqs":