Skip to content

Commit

Permalink
[Kernel] W8A16 Int8 inside FusedMoE (#7415)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzusman authored Aug 16, 2024
1 parent e837b62 commit 7fc23be
Show file tree
Hide file tree
Showing 15 changed files with 412 additions and 136 deletions.
108 changes: 70 additions & 38 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,36 @@ def benchmark_config(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8 else dtype
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
Expand All @@ -52,7 +69,11 @@ def benchmark_config(
w2_scale = None
a1_scale = None
a2_scale = None
if use_fp8:
if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
Expand All @@ -76,7 +97,8 @@ def run():
renormalize=True,
inplace=True,
override_config=config,
use_fp8=use_fp8,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
Expand Down Expand Up @@ -155,11 +177,13 @@ def benchmark(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)

dtype_str = "float8" if use_fp8 else None
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
Expand All @@ -173,7 +197,8 @@ def benchmark(
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8,
use_int8_w8a16)
return config, kernel_time

def tune(
Expand All @@ -184,9 +209,10 @@ def tune(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
search_space: List[BenchmarkConfig],
) -> BenchmarkConfig:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
Expand All @@ -198,7 +224,8 @@ def tune(
hidden_size,
topk,
dtype,
use_fp8,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
Expand All @@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
}


def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
) -> None:
dtype_str = "float8" if use_fp8 else None
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)

print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
Expand All @@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
Expand All @@ -262,7 +293,8 @@ def main(args: argparse.Namespace):

hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8 = args.dtype == "fp8"
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"

if args.batch_size is None:
batch_sizes = [
Expand Down Expand Up @@ -294,21 +326,21 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8, search_space)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute("benchmark",
[(batch_size, E, shard_intermediate_size,
hidden_size, topk, dtype, use_fp8)
for batch_size in batch_sizes])
outputs = _distribute(
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
for batch_size in batch_sizes])

for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
Expand All @@ -323,7 +355,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
Expand Down
13 changes: 7 additions & 6 deletions tests/models/test_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
MODELS = ["ai21labs/Jamba-tiny-random"]


# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
# TODO: Fix this with trained model
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_models(
hf_runner,
vllm_runner,
Expand All @@ -17,8 +20,6 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
Expand All @@ -36,8 +37,8 @@ def test_models(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
def test_batching(
vllm_runner,
example_prompts,
Expand Down
28 changes: 28 additions & 0 deletions tests/quantization/test_experts_int8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# flake8: noqa
"""Tests experts_int8 quantization startup and generation,
doesn't test correctness
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
reason="ExpertsInt8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_experts_int8_startup(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype,
quantization="experts_int8") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
3 changes: 2 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ def _verify_quantization(self) -> None:
rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
"experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
if self.quantization is not None:
Expand Down
Loading

0 comments on commit 7fc23be

Please sign in to comment.