Add Flux transformer benchmarking #870
base: main
@@ -0,0 +1,87 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: Sharktank Nightly Tests

on:
  workflow_dispatch:
  schedule:
    # Weekdays at 10:00 AM UTC = 02:00 AM PST / 03:00 AM PDT
    - cron: "0 10 * * 1-5"

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  nightly-mi300x:
    if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
    name: "Nightly tests and benchmarks"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      VENV_DIR: ${{ github.workspace }}/.venv
      HF_HOME: "/data/huggingface"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Get Current Date
        id: date
        # ::set-output is deprecated; write to $GITHUB_OUTPUT instead.
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: Create Python venv
        run: python -m venv ${VENV_DIR}

      - name: Install pip deps
        run: |
          source ${VENV_DIR}/bin/activate
          python -m pip install --no-compile --upgrade pip

          # Note: We install in three steps in order to satisfy requirements
          # from non-default locations first.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install -r requirements-iree-unpinned.txt
          pip install --no-compile \
            -r sharktank/requirements-tests.txt \
            -e sharktank/

          pip freeze

      - name: Run benchmarks
        run: |
          source ${VENV_DIR}/bin/activate
          pytest \
            --verbose \
            --capture=no \
            --iree-hip-target=gfx942 \
            --iree-device=hip://0 \
            --with-flux-data \
            -m="benchmark" \
            --html=out/sharktank-nightly/benchmark/index.html \
            sharktank/tests

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
        with:
          github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
          publish_dir: ./out/sharktank-nightly
          destination_dir: ./sharktank-nightly
          keep_files: true

Review thread (on the "Deploy to GitHub Pages" step):
Comment: I believe […]
Reply: I wanted to have one step to push all sharktank_nightly artifacts. This workflow may in the future have more jobs that produce artifacts. I would be surprised if the action did not support whole directory trees.
Reply: Sure, were you able to test this nightly on this PR?
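For context, the -m="benchmark" filter selects only tests carrying the "benchmark" marker, so ordinary unit tests under sharktank/tests are deselected in this job. Below is a minimal sketch (not part of this PR) of a test that this invocation would pick up, assuming the marker is registered in the project's pytest configuration; the test name and body are hypothetical.

import pytest


@pytest.mark.benchmark
def test_example_model_benchmark(tmp_path):
    # A real benchmark test would export/compile a model and invoke
    # iree-benchmark-module; this stub only demonstrates the marker wiring
    # that the workflow's -m="benchmark" filter keys on.
    json_result = tmp_path / "benchmark.json"
    json_result.write_text("[]")
    assert json_result.exists()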
@@ -0,0 +1,3 @@
[submodule "third_party/benchmark"]
	path = third_party/benchmark
	url = https://github.com/google/benchmark
Review thread (on lines +1 to +3):
Comment: Is this actually used? You have […]. A source dependency on the C++ library could be added to shortfin as needed, probably via FetchContent here: shark-ai/shortfin/CMakeLists.txt, lines 326 to 340 in 4eac34e.
Reply: I wrongly left the pip package dependency. It is removed now. Unfortunately, the script to compare benchmark results is not a part of the pip package. That script is used in the python test.
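For reference, a sketch of how the vendored compare.py and the helpers added later in this PR (in sharktank/utils/benchmark.py, per the relative imports) might be combined. The "benchmarks" subcommand and the --dump_to_json flag come from upstream google/benchmark's tools/compare.py; treat the exact flags and file names here as assumptions to verify against the pinned submodule revision.

import json
from sharktank.utils.benchmark import (
    iree_benchmark_compare,
    iree_benchmark_assert_contender_is_not_worse,
)

# Compare a previous nightly's result (baseline) against this run's result
# (contender); both files come from iree-benchmark-module's --benchmark_out.
comparison_path = "comparison.json"
iree_benchmark_compare(
    [
        f"--dump_to_json={comparison_path}",
        "benchmarks",
        "baseline.json",
        "contender.json",
    ]
)
with open(comparison_path) as f:
    iree_benchmark_assert_contender_is_not_worse(json.load(f))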
@@ -0,0 +1,92 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import iree.compiler
import iree.runtime
import os
from iree.turbine.support.tools import iree_tool_prepare_input_args

from .export import (
    export_flux_transformer_from_hugging_face,
    flux_transformer_default_batch_sizes,
    iree_compile_flags,
)
from ...types import Dataset
from .flux import FluxModelV1, FluxParams
from ...utils.export_artifacts import ExportArtifacts
from ...utils.iree import flatten_for_iree_signature
from ...utils.benchmark import iree_benchmark_module


def iree_benchmark_flux_dev_transformer(
    artifacts_dir: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    mlir_path = artifacts_dir / "model.mlir"
    parameters_path = artifacts_dir / "parameters.irpa"
    # Export MLIR and parameters from Hugging Face unless cached copies exist.
    if (
        not caching
        or not os.path.exists(mlir_path)
        or not os.path.exists(parameters_path)
    ):
        export_flux_transformer_from_hugging_face(
            "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
            mlir_output_path=mlir_path,
            parameters_output_path=parameters_path,
        )
    return iree_benchmark_flux_transformer(
        mlir_path=mlir_path,
        parameters_path=parameters_path,
        artifacts_dir=artifacts_dir,
        iree_device=iree_device,
        json_result_output_path=json_result_output_path,
        caching=caching,
    )


def iree_benchmark_flux_transformer(
    artifacts_dir: Path,
    mlir_path: Path,
    parameters_path: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    dataset = Dataset.load(parameters_path)
    model = FluxModelV1(
        theta=dataset.root_theta,
        params=FluxParams.from_hugging_face_properties(dataset.properties),
    )
    # Materialize sample inputs as files and turn them into --input= flags.
    input_args = flatten_for_iree_signature(
        model.sample_inputs(batch_size=flux_transformer_default_batch_sizes[0])
    )
    cli_input_args = iree_tool_prepare_input_args(
        input_args, file_path_prefix=f"{artifacts_dir / 'arg'}"
    )
    cli_input_args = [f"--input={v}" for v in cli_input_args]

    # Compile to an IREE VM flatbuffer unless a cached module exists.
    iree_module_path = artifacts_dir / "model.vmfb"
    if not caching or not os.path.exists(iree_module_path):
        iree.compiler.compile_file(
            mlir_path,
            output_file=iree_module_path,
            extra_args=iree_compile_flags,
        )

    iree_benchmark_args = [
        f"--device={iree_device}",
        f"--module={iree_module_path}",
        f"--parameters=model={parameters_path}",
        f"--function=forward_bs{flux_transformer_default_batch_sizes[0]}",
        "--benchmark_repetitions=30",
        "--benchmark_min_warmup_time=1.0",
        "--benchmark_out_format=json",
        f"--benchmark_out={json_result_output_path}",
    ] + cli_input_args
    return iree_benchmark_module(iree_benchmark_args)
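A sketch of driving this module end to end. The module path sharktank.models.flux.benchmark is an assumption inferred from the relative imports above, and the paths are illustrative. Note that the export step pulls the gated FLUX.1-dev weights, so Hugging Face credentials and HF_HOME (as set in the workflow) must be in place.

from pathlib import Path
from sharktank.models.flux.benchmark import iree_benchmark_flux_dev_transformer

artifacts_dir = Path("/tmp/flux_benchmark")
artifacts_dir.mkdir(parents=True, exist_ok=True)
# caching=True reuses model.mlir, parameters.irpa, and model.vmfb if present,
# skipping the export and compile steps on repeated runs.
output = iree_benchmark_flux_dev_transformer(
    artifacts_dir=artifacts_dir,
    iree_device="hip://0",
    json_result_output_path=artifacts_dir / "result.json",
    caching=True,
)
print(output)  # google-benchmark style report from iree-benchmark-module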
@@ -0,0 +1,118 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any
import iree.runtime
import subprocess
import json
import sys
import pandas
from pathlib import Path
import os
from os import PathLike


def _run_program(
    args: list[str],
):
    process_result = subprocess.run(
        args=args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    out = process_result.stdout.decode()
    err = process_result.stderr.decode()

    if process_result.returncode != 0:
        raise RuntimeError(f"stderr:\n{err}\nstdout:\n{out}")

    if err != "":
        print(err, file=sys.stderr)

    return out


def iree_benchmark_module(
    cli_args: list[str],
):
    """Run iree-benchmark-module with the given CLI arguments and return stdout."""
    args = [iree.runtime.benchmark_exe()] + cli_args
    return _run_program(args=args)


def google_benchmark_compare_path() -> str:
    """Path to compare.py from the google/benchmark submodule."""
    return os.path.abspath(
        Path(__file__).parent.parent.parent.parent
        / "third_party"
        / "benchmark"
        / "tools"
        / "compare.py"
    )


def iree_benchmark_compare(cli_args: tuple[str]):
    """Run google/benchmark's compare.py with the given CLI arguments."""
    args = [google_benchmark_compare_path()] + cli_args
    return _run_program(args=args)


def _get_benchmark_comparison_aggregate_real_time(
    benchmark_comparison_result_json: list[dict[str, Any]], aggregate: str
) -> tuple[float, float, str]:
    real_time = [
        (
            benchmark["measurements"][0]["real_time"],
            benchmark["measurements"][0]["real_time_other"],
            benchmark["time_unit"],
        )
        for benchmark in benchmark_comparison_result_json
        if "aggregate_name" in benchmark and benchmark["aggregate_name"] == aggregate
    ]
    assert len(real_time) == 1
    return real_time[0]


def _assert_contender_aggregate_real_time_is_not_worse(
    benchmark_comparison_result_json: list[dict[str, Any]], aggregate: str
):
    real_time = _get_benchmark_comparison_aggregate_real_time(
        benchmark_comparison_result_json, aggregate
    )
    baseline_real_time, contender_real_time, time_unit = real_time
    if baseline_real_time < contender_real_time:
        raise AssertionError(
            f"Benchmark contender {aggregate} "
            f"real time {contender_real_time} {time_unit} "
            f"is worse than baseline {baseline_real_time} {time_unit}."
        )


def iree_benchmark_assert_contender_is_not_worse(
    benchmark_comparison_result_json: list[dict[str, Any]], alpha: float = 0.05
):
    """If the contender is not from the same distribution as the baseline,
    assert that its mean and median real time are not worse.

    Arguments
    ---------
    alpha: significance threshold for the probability that the two benchmark
        sample sets are from the same distribution, i.e. that they are not
        different."""
    time_pvalue = [
        b["utest"]["time_pvalue"]
        for b in benchmark_comparison_result_json
        if "utest" in b and "time_pvalue" in b["utest"]
    ]
    assert len(time_pvalue) == 1
    time_pvalue = time_pvalue[0]
    if alpha <= time_pvalue:
        # Cannot reject that the benchmarks are from the same distribution.
        return

    _assert_contender_aggregate_real_time_is_not_worse(
        benchmark_comparison_result_json, "mean"
    )
    _assert_contender_aggregate_real_time_is_not_worse(
        benchmark_comparison_result_json, "median"
    )
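To make the parsing above concrete, here is an illustrative, hand-written slice of the JSON diff report that compare.py emits and that iree_benchmark_assert_contender_is_not_worse consumes. The field layout (a "utest" entry with the Mann-Whitney U-test p-value, plus one entry per aggregate holding baseline "real_time" and contender "real_time_other") is an assumption based on the code above; the values are made up.

from sharktank.utils.benchmark import iree_benchmark_assert_contender_is_not_worse

comparison = [
    {
        # U-test entry: a p-value below alpha=0.05 means the two runs likely
        # differ, so the mean/median checks below are performed.
        "name": "forward_bs1",
        "time_unit": "ms",
        "measurements": [],
        "utest": {"time_pvalue": 0.001},
    },
    {
        "name": "forward_bs1_mean",
        "aggregate_name": "mean",
        "time_unit": "ms",
        "measurements": [{"real_time": 100.0, "real_time_other": 98.0}],
    },
    {
        "name": "forward_bs1_median",
        "aggregate_name": "median",
        "time_unit": "ms",
        "measurements": [{"real_time": 99.0, "real_time_other": 97.5}],
    },
]
# Passes: the contender's mean and median are not higher than the baseline's.
iree_benchmark_assert_contender_is_not_worse(comparison)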
Review thread (on the workflow):
Comment: I'd much prefer we consolidate existing nightly workflows (eval, sglang benchmark, llama large, etc.) before adding a new one. That being said, I like the "sharktank-nightly" name better than model-specific names...
Reply: I think the CI jobs are due for a refactoring, to reorganize them and to reduce code duplication. I want to do that next.
Reply: The advantage of keeping them separate is the ease of tracking the various nightlies across the LLM and Flux models and sharktank/shortfin regressions. If there are tools to do this, or a CI summary in GitHub Pages, that would be great.