Fix llama 8b fp8 benchmarking test (#894)
The fp8 benchmarking test was failing in the Export phase and getting killed,
but it was only expected to fail in the Benchmarking phase.

---------

Signed-off-by: aviator19941 <avinash.sharma@amd.com>
aviator19941 authored Feb 3, 2025
1 parent a2f8334 commit d1a3f3d
Showing 1 changed file with 11 additions and 51 deletions.
sharktank/tests/models/llama/benchmark_amdgpu_test.py: 11 additions & 51 deletions
@@ -66,10 +66,12 @@ class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        self.artifacts_dir = Path("/shark-dev/data/llama3.1/weights/8b")
-        self.artifacts_dir_2048 = Path("/shark-dev/8b")
-        self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa"
-        self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_8b_fp8.irpa"
+        self.artifacts_dir = Path("/shark-dev/8b")
+        self.weights_dir = self.artifacts_dir / "instruct/weights"
+        self.irpa_path = self.weights_dir / "llama3.1_8b_instruct_fp16.irpa"
+        self.irpa_path_fp8 = (
+            self.artifacts_dir / "fp8/native_fp8_e4m3fnuz_llama3_8b.irpa"
+        )
         self.tensor_parallelism_size = 1
         self.dir_path_8b = self.dir_path / "llama-8b"
         self.temp_dir_8b = Path(self.dir_path_8b)
@@ -83,15 +85,6 @@ def setUp(self):
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=32,
         )
-        self.llama8b_fp8_decomposed_artifacts = ExportArtifacts(
-            irpa_path=str(self.irpa_path_fp8),
-            batch_size=4,
-            iree_hip_target="gfx942",
-            iree_hal_target_device="hip",
-            attention_kernel="decomposed",
-            tensor_parallelism_size=self.tensor_parallelism_size,
-            block_seq_stride=32,
-        )
         self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts(
             irpa_path=str(self.irpa_path_fp8),
             batch_size=4,
@@ -104,16 +97,16 @@ def setUp(self):
             attention_dtype="float8_e4m3fnuz",
         )
         self.prefill_args_bs4_128_stride_32_f16 = (
-            self.artifacts_dir / "prefill_args_bs4_128_stride_32"
+            self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp1"
         )
         self.decode_args_bs4_128_stride_32_f16 = (
-            self.artifacts_dir / "decode_args_bs4_128_stride_32"
+            self.artifacts_dir / "decode_args_bs4_128_stride_32_tp1"
         )
         self.prefill_args_bs4_2048_stride_32_f16 = (
-            self.artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32"
+            self.artifacts_dir / "prefill_args_bs4_2048_stride_32"
         )
         self.decode_args_bs4_2048_stride_32_f16 = (
-            self.artifacts_dir_2048 / "decode_args_bs4_2048_stride_32"
+            self.artifacts_dir / "decode_args_bs4_2048_stride_32"
         )
         self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
         self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
@@ -169,39 +162,6 @@ def setUp(self):
             "--benchmark_repetitions=3",
         ]
 
-    def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self):
-        output_file_name = self.dir_path_8b / "f16_torch_prefill_128"
-        output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".mlir", prefix=output_file_name
-        )
-        output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".json", prefix=output_file_name
-        )
-        output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
-            suffix=".vmfb", prefix=output_file_name
-        )
-        export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
-            mlir_path=output_mlir,
-            json_path=output_json,
-            skip_decode=True,
-        )
-        self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
-            mlir_path=str(output_mlir),
-            vmfb_path=output_vmfb,
-            hal_dump_path=output_file_name,
-            cwd=self.repo_root,
-            args=self.compile_args,
-        )
-        # benchmark prefill
-        self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
-            hip_device_id=self.iree_device,
-            vmfb_name=output_vmfb,
-            irpa_path=self.irpa_path,
-            args=self.iree_run_prefill_nondecomposed_args_fp16,
-            cwd=self.repo_root,
-        )
-
     @skipif_run_quick_llama_test
     def testBenchmark8B_f16_Non_Decomposed_Input_Len_128(self):
         output_file_name = self.dir_path_8b / "f16_torch_128"
         output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
@@ -283,7 +243,7 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self):
 
     @pytest.mark.xfail(
         reason="Benchmark inputs not configured yet.",
-        strict=False,
+        strict=True,
         raises=IreeBenchmarkException,
     )
     def testBenchmark8B_fp8_Non_Decomposed(self):
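
Note on the xfail change in the last hunk: with strict=True and raises=IreeBenchmarkException, pytest records the fp8 benchmark as an expected failure only when that specific exception is raised; a different exception is reported as a real failure, and an unexpected pass fails the run instead of being reported as XPASS. Below is a minimal, self-contained sketch of the pattern, using a hypothetical BenchmarkError in place of the repository's IreeBenchmarkException.

import pytest


class BenchmarkError(RuntimeError):
    """Hypothetical stand-in for the repository's IreeBenchmarkException."""


@pytest.mark.xfail(
    reason="Benchmark inputs not configured yet.",
    strict=True,
    raises=BenchmarkError,
)
def test_fp8_benchmark_expected_failure():
    # Earlier steps (export, compile) are expected to succeed; only the
    # benchmark step is expected to raise. Raising BenchmarkError here makes
    # pytest report the test as XFAIL rather than as a failure.
    raise BenchmarkError("benchmark inputs not configured")

If the test raised a different exception type, or returned without raising, pytest would report a failure rather than an expected one.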
