Skip to content

Commit

Permalink
[sharktank][llama] enable quark parity test on mi300x (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey authored Feb 1, 2025
1 parent 4eac34e commit a2f8334
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 24 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,14 @@ jobs:
--with-flux-data \
--with-t5-data \
--with-vae-data \
--with-quark-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
sharktank/tests/models/flux/flux_test.py \
sharktank/tests/models/vae/vae_test.py \
sharktank/tests/models/llama/quark_parity_test.py \
--durations=0 \
--timeout=600
--timeout=800
test_integration:
Expand Down
9 changes: 9 additions & 0 deletions sharktank/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def pytest_addoption(parser):
),
)

# Opt-in flag for tests that need the external quark model data
# (large artifacts not shipped with the source tree).
parser.addoption(
    "--with-quark-data",
    action="store_true",
    default=False,
    help=(
        "Enable tests that use quark data such as models not part of the source code."
    ),
)

# TODO: Remove all hardcoded paths in CI tests
parser.addoption(
"--llama3-8b-tokenizer-path",
Expand Down
21 changes: 11 additions & 10 deletions sharktank/sharktank/examples/paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,17 @@ def main():
intermediates_saver.save_file(
args.save_intermediates_path + "_prefill.safetensors"
)
counter = 0
while not batch.done:
batch.decode()
if args.save_intermediates_path:
intermediates_saver.save_file(
args.save_intermediates_path + f"_step_{counter}.safetensors"
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1
if not args.skip_decode:
counter = 0
while not batch.done:
batch.decode()
if args.save_intermediates_path:
intermediates_saver.save_file(
args.save_intermediates_path + f"_step_{counter}.safetensors"
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1


if __name__ == "__main__":
Expand Down
64 changes: 51 additions & 13 deletions sharktank/tests/models/llama/quark_parity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os

from safetensors import safe_open
import torch
import unittest
import pytest
from pathlib import Path
import subprocess

with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')")


@pytest.mark.skip(reason="need to generate values to compare against")
class QuarkParityTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.path_prefix = Path("/shark-dev/quark_test")

@with_quark_data
def test_compare_against_quark(self):
def both(key, index=None):
o = ours[key]
t = theirs[key]
if index is None:
return o, t
else:
return o[index], t[index]
sharktank_dir = str(
Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent
)
our_path = self.path_prefix / "ours_prefill.safetensors"
if os.path.exists(our_path):
os.remove(our_path)

mapping = dict()
for i in range(32):
Expand All @@ -41,18 +48,49 @@ def both(key, index=None):
mapping[a] = b
mapping[a + "_input_0"] = b + "_input_0"

# Invoke the paged LLM example as a subprocess so its intermediate
# activations get written to {path_prefix}/ours_*.safetensors.
command = [
    "python",
    "-m",
    "sharktank.examples.paged_llm_v1",
    "The capitol of Texas is",
    f"--irpa-file={self.path_prefix}/fp8_bf16_weight.irpa",
    "--tokenizer-config-json=/data/llama3.1/8b/tokenizer.json",
    "--fake-quant",
    "--attention-kernel=torch",
    "--activation-dtype=bfloat16",
    f"--save_intermediates_path={self.path_prefix}/ours",
    "--use-hf",
    "--attention-dtype=bfloat16",
    "--skip-decode",
    "--block-seq-stride=16",
]
# Pass the argv list directly (shell=False): avoids the lossy
# list2cmdline -> shell re-parse round trip and any quoting hazards.
proc = subprocess.run(command, capture_output=True, cwd=sharktank_dir)

ours = dict()
with safe_open("../ours_newest_prefill.safetensors", "pytorch") as st:
with safe_open(our_path, "pytorch") as st:
for key in st.keys():
ours[key] = st.get_tensor(key)

theirs = dict()
with safe_open("../theirs2.safetensors", "pytorch") as st:
golden = dict()
golden_path = self.path_prefix / "golden.safetensors"
with safe_open(golden_path, "pytorch") as st:
for key in st.keys():
if key in mapping:
theirs[mapping[key]] = st.get_tensor(key)
golden[mapping[key]] = st.get_tensor(key)

test_layers = [v for k, v in mapping.items()]

def both(key, index=None):
    # Fetch the tensor for `key` from our run and from the golden
    # reference; optionally select a single element from each.
    mine, ref = ours[key], golden[key]
    if index is None:
        return mine, ref
    return mine[index], ref[index]

for lyr in test_layers:
name = lyr
if name in ours.keys() and name != "freqs":
Expand Down

0 comments on commit a2f8334

Please sign in to comment.