diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py deleted file mode 100644 index 196ec8cfce88..000000000000 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ /dev/null @@ -1,239 +0,0 @@ -import argparse -import json -import os -import sys - -import torch -import torch.nn.functional as F -import triton -from tqdm import tqdm - -from vllm.model_executor.layers.fused_moe import (fused_moe, - get_config_file_name) - - -def main(model, tp_size, gpu, dtype: str): - os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) - method = fused_moe - for bs in [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 - ]: - run_grid(bs, - model=model, - method=method, - gpu=gpu, - tp_size=tp_size, - dtype=dtype) - - -def run_grid(bs, model, method, gpu, tp_size, dtype: str): - if model == '8x7B': - d_model = 4096 - model_intermediate_size = 14336 - num_layers = 32 - elif model == '8x22B': - d_model = 6144 - model_intermediate_size = 16384 - num_layers = 56 - else: - raise ValueError(f'Unsupported Mixtral model {model}') - num_total_experts = 8 - top_k = 2 - # tp_size = 2 - num_calls = 100 - - num_warmup_trials = 1 - num_trials = 1 - - configs = [] - - for block_size_n in [32, 64, 128, 256]: - for block_size_m in [16, 32, 64, 128, 256]: - for block_size_k in [64, 128, 256]: - for group_size_m in [1, 16, 32, 64]: - for num_warps in [4, 8]: - for num_stages in [2, 3, 4, 5]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": num_stages, - }) - - best_config = None - best_time_us = 1e20 - - print(f'{tp_size=} {bs=}') - - for config in tqdm(configs): - # warmup - try: - for _ in range(num_warmup_trials): - run_timing( - num_calls=num_calls, - bs=bs, - d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - except triton.runtime.autotuner.OutOfResources: - continue - - # trial - for _ in range(num_trials): - kernel_dur_ms = run_timing( - num_calls=num_calls, - bs=bs, - d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - - kernel_dur_us = 1000 * kernel_dur_ms - model_dur_ms = kernel_dur_ms * num_layers - - if kernel_dur_us < best_time_us: - best_config = config - best_time_us = kernel_dur_us - - tqdm.write( - f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') - - print("best_time_us", best_time_us) - print("best_config", best_config) - - # holds Dict[str, Dict[str, int]] - filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size, - "float8" if dtype == "float8" else None) - print(f"writing config to file {filename}") - existing_content = {} - if os.path.exists(filename): - with open(filename, "r") as f: - existing_content = json.load(f) - existing_content[str(bs)] = best_config - with open(filename, "w") as f: - json.dump(existing_content, f, indent=4) - f.write("\n") - - -def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, - top_k: int, tp_size: int, model_intermediate_size: int, method, - config, dtype: str) -> float: - shard_intermediate_size = model_intermediate_size // tp_size - - hidden_states = torch.rand( - (bs, d_model), - device="cuda:0", - dtype=torch.float16, - ) - - w1 = torch.rand( - (num_total_experts, 2 * shard_intermediate_size, d_model), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w2 = torch.rand( - (num_total_experts, d_model, shard_intermediate_size), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w1_scale = None - w2_scale = None - a1_scale = None - a2_scale = None - - if dtype == "float8": - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - w1_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - w2_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - a1_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - a2_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - - gating_output = F.softmax(torch.rand( - (num_calls, bs, num_total_experts), - device=hidden_states.device, - dtype=torch.float32, - ), - dim=-1) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for i in range(num_calls): - hidden_states = method( - hidden_states=hidden_states, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - gating_output=gating_output[i], - topk=2, - renormalize=True, - inplace=True, - override_config=config, - use_fp8=dtype == "float8", - ) - end_event.record() - end_event.synchronize() - - dur_ms = start_event.elapsed_time(end_event) / num_calls - return dur_ms - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog='benchmark_mixtral_moe', - description='Benchmark and tune the fused_moe kernel', - ) - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['float8', 'float16'], - help='Data type used for fused_moe kernel computations', - ) - parser.add_argument('--model', - type=str, - default='8x7B', - choices=['8x7B', '8x22B'], - help='The Mixtral model to benchmark') - parser.add_argument('--tp-size', - type=int, - default=2, - help='Tensor paralleli size') - parser.add_argument('--gpu', - type=int, - default=0, - help="GPU ID for benchmarking") - args = parser.parse_args() - sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype)) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py new file mode 100644 index 000000000000..d6fa39a4d30e --- /dev/null +++ b/benchmarks/kernels/benchmark_moe.py @@ -0,0 +1,319 @@ +import argparse +import time +from datetime import datetime +from typing import Any, Dict, List, Tuple + +import ray +import torch +import triton +from ray.experimental.tqdm_ray import tqdm +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe.fused_moe import * + + +def benchmark_config( + config: Dict[str, int], + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + num_iters: int = 100, +) -> float: + init_dtype = torch.float16 if use_fp8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + w1 = torch.randn(num_experts, + shard_intermediate_size, + hidden_size, + dtype=init_dtype) + w2 = torch.randn(num_experts, + hidden_size, + shard_intermediate_size // 2, + dtype=init_dtype) + gating_output = torch.randn(num_iters, + num_tokens, + num_experts, + dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_fp8: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + fused_moe( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + override_config=config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + return configs + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(seed) + self.seed = seed + + def benchmark( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + ) -> Tuple[Dict[str, int], float]: + torch.cuda.manual_seed_all(self.seed) + + dtype_str = "float8" if use_fp8 else None + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, + dtype_str) + if op_config is None: + config = get_default_config(num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype_str) + else: + config = op_config[min(op_config.keys(), + key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config(config, num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype, use_fp8) + return config, kernel_time + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + search_space: List[Dict[str, int]], + ) -> Dict[str, int]: + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config(config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8, + num_iters=10) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + return best_config + + +def sort_config(config: Dict[str, int]) -> Dict[str, int]: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + + +def save_configs( + configs: Dict[int, Dict[str, int]], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, +) -> None: + dtype_str = "float8" if use_fp8 else None + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + filename = get_config_file_name(num_experts, shard_intermediate_size // 2, + dtype_str) + print(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained(args.model) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + else: + # Default: Mixtral. + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + + hidden_size = config.hidden_size + dtype = config.torch_dtype + use_fp8 = args.dtype == "fp8" + + if args.batch_size is None: + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: List[Any]) -> List[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + if args.tune: + search_space = get_configs_compute_bound() + print(f"Start tuning over {len(search_space)} configurations...") + + start = time.time() + configs = _distribute( + "tune", [(batch_size, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8, search_space) + for batch_size in batch_sizes]) + best_configs = { + M: sort_config(config) + for M, config in zip(batch_sizes, configs) + } + save_configs(best_configs, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8) + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + else: + outputs = _distribute("benchmark", + [(batch_size, E, shard_intermediate_size, + hidden_size, topk, dtype, use_fp8) + for batch_size in batch_sizes]) + + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}, config: {config}") + print(f"Kernel time: {kernel_time:.2f} us") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--tp-size", "-tp", type=int, default=2) + parser.add_argument("--dtype", + type=str, + choices=["auto", "fp8"], + default="auto") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--tune", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 20a3c9f6f893..1c6947137a1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,6 +308,30 @@ def get_moe_configs(E: int, N: int, return None +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + if M <= E: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + return config + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -382,20 +406,9 @@ def fused_experts(hidden_states: torch.Tensor, config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - } - - if M <= E: - config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 - } + config = get_default_config(M, E, N, w1.shape[2], + topk_ids.shape[1], + "float8" if use_fp8 else None) intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device,