diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 75ad094fa1382..90a5e54736cf3 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 200 +MAX_SIZE_MB = 100 def print_top_10_largest_files(zip_file): diff --git a/.buildkite/download-images.sh b/.buildkite/download-images.sh index 360a7584bccf1..389a12956c3c3 100644 --- a/.buildkite/download-images.sh +++ b/.buildkite/download-images.sh @@ -8,6 +8,10 @@ set -o pipefail # aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/ mkdir -p images cd images +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml deleted file mode 100644 index fa6ea236ef04f..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5 -model_name: "meta-llama/Meta-Llama-3-70B-Instruct" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.892 - - name: "exact_match,flexible-extract" - value: 0.892 -limit: 250 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml deleted file mode 100644 index 02668702b83af..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 -model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.756 - - name: "exact_match,flexible-extract" - value: 0.752 -limit: 250 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml deleted file mode 100644 index fb4b4915ab955..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1 -model_name: "meta-llama/Meta-Llama-3-8B-Instruct" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.756 - - name: "exact_match,flexible-extract" - value: 0.752 -limit: 250 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml deleted file mode 100644 index dec9164d1b84e..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml +++ 
/dev/null @@ -1,11 +0,0 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4 -model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.616 - - name: "exact_match,flexible-extract" - value: 0.632 -limit: 250 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt deleted file mode 100644 index 127ec5d97bcff..0000000000000 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ /dev/null @@ -1,2 +0,0 @@ -Meta-Llama-3-70B-Instruct.yaml -Mixtral-8x7B-Instruct-v0.1.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt deleted file mode 100644 index 273c5482db264..0000000000000 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ /dev/null @@ -1,2 +0,0 @@ -Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh deleted file mode 100644 index fdb8ec5393b36..0000000000000 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -# We can use this script to compute baseline accuracy on GSM for transformers. -# -# Make sure you have lm-eval-harness installed: -# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10 - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using huggingface transformers." - echo "This pathway is intended to be used to create baselines for " - echo "our automated nm-test-accuracy workflow" - echo - echo "usage: ${0} " - echo - echo " -m - huggingface stub or local directory of the model" - echo " -b - batch size to run the evaluation at" - echo " -l - limit number of samples to run" - echo " -f - number of fewshot samples to use" - echo -} - -while getopts "m:b:l:f:" OPT; do - case ${OPT} in - m ) - MODEL="$OPTARG" - ;; - b ) - BATCH_SIZE="$OPTARG" - ;; - l ) - LIMIT="$OPTARG" - ;; - f ) - FEWSHOT="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -lm_eval --model hf \ - --model_args pretrained=$MODEL,parallelize=True \ - --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ - --batch_size $BATCH_SIZE diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh deleted file mode 100644 index a2876bade8893..0000000000000 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -# We can use this script to compute baseline accuracy on GSM for vllm. -# We use this for fp8, which HF does not support. -# -# Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.2 - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using huggingface transformers." 
- echo "This pathway is intended to be used to create baselines for " - echo "our automated nm-test-accuracy workflow" - echo - echo "usage: ${0} " - echo - echo " -m - huggingface stub or local directory of the model" - echo " -b - batch size to run the evaluation at" - echo " -l - limit number of samples to run" - echo " -f - number of fewshot samples to use" - echo " -t - tensor parallel size to run at" - echo -} - -while getopts "m:b:l:f:t:" OPT; do - case ${OPT} in - m ) - MODEL="$OPTARG" - ;; - b ) - BATCH_SIZE="$OPTARG" - ;; - l ) - LIMIT="$OPTARG" - ;; - f ) - FEWSHOT="$OPTARG" - ;; - t ) - TP_SIZE="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \ - --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ - --batch_size $BATCH_SIZE diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh deleted file mode 100644 index b4fdde6dab425..0000000000000 --- a/.buildkite/lm-eval-harness/run-tests.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using vllm and compares to " - echo "precomputed baseline (measured by HF transformers.)" - echo - echo "usage: ${0} " - echo - echo " -c - path to the test data config (e.g. configs/small-models.txt)" - echo " -t - tensor parallel size" - echo -} - -SUCCESS=0 - -while getopts "c:t:" OPT; do - case ${OPT} in - c ) - CONFIG="$OPTARG" - ;; - t ) - TP_SIZE="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -# Parse list of configs. -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG - -for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" -do - LOCAL_SUCCESS=0 - - echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" - - export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} - export LM_EVAL_TP_SIZE=$TP_SIZE - pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? - - if [[ $LOCAL_SUCCESS == 0 ]]; then - echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" - else - echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" - fi - - SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) - -done - -if [ "${SUCCESS}" -eq "0" ]; then - exit 0 -else - exit 1 -fi diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py deleted file mode 100644 index 975841dad1c29..0000000000000 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -LM eval harness on model to compare vs HF baseline computed offline. -Configs are found in configs/$MODEL.yaml - -* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml -* export LM_EVAL_TP_SIZE=4 -* pytest -s test_lm_eval_correctness.py -""" - -import os -from pathlib import Path - -import lm_eval -import numpy -import yaml - -RTOL = 0.02 -TEST_DATA_FILE = os.environ.get( - "LM_EVAL_TEST_DATA_FILE", - ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") - -TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) - - -def launch_lm_eval(eval_config): - model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={TP_SIZE}" - - results = lm_eval.simple_evaluate( - model="vllm", - model_args=model_args, - tasks=[task["name"] for task in eval_config["tasks"]], - num_fewshot=eval_config["num_fewshot"], - limit=eval_config["limit"], - batch_size="auto") - - return results - - -def test_lm_eval_correctness(): - eval_config = yaml.safe_load( - Path(TEST_DATA_FILE).read_text(encoding="utf-8")) - - # Launch eval requests. 
-    results = launch_lm_eval(eval_config)
-
-    # Confirm scores match ground truth.
-    for task in eval_config["tasks"]:
-        for metric in task["metrics"]:
-            ground_truth = metric["value"]
-            measured_value = results["results"][task["name"]][metric["name"]]
-            print(f'{task["name"]} | {metric["name"]}: '
-                  f'ground_truth={ground_truth} | measured={measured_value}')
-            assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md
deleted file mode 100644
index 4036b32a46bf7..0000000000000
--- a/.buildkite/nightly-benchmarks/README.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# vLLM benchmark suite
-
-## Introduction
-
-This directory contains the performance benchmarking CI for vLLM.
-The goal is to help developers understand the impact of their PRs on vLLM's performance.
-
-This benchmark will be *triggered* upon:
-- A PR being merged into vLLM.
-- Every commit for those PRs with the `perf-benchmarks` label.
-
-**Benchmarking Coverage**: latency, throughput and fixed-QPS serving on A100 (support for more GPUs is coming later), with different models.
-
-**Benchmarking Duration**: about 1 hour.
-
-**For benchmarking developers**: please try your best to constrain the benchmarking duration to less than 1.5 hours so that it won't take forever to run.
-
-
-## Configuring the workload
-
-The benchmarking workload contains three parts:
-- Latency tests in `latency-tests.json`.
-- Throughput tests in `throughput-tests.json`.
-- Serving tests in `serving-tests.json`.
-
-See [descriptions.md](tests/descriptions.md) for detailed descriptions.
-
-### Latency test
-
-Here is an example of one test inside `latency-tests.json`:
-
-```json
-[
-    {
-        "test_name": "latency_llama8B_tp1",
-        "parameters": {
-            "model": "meta-llama/Meta-Llama-3-8B",
-            "tensor_parallel_size": 1,
-            "load_format": "dummy",
-            "num_iters_warmup": 5,
-            "num_iters": 15
-        }
-    },
-]
-```
-
-In this example:
-- The `test_name` attribute is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
-- The `parameters` attribute controls the command line arguments used for `benchmark_latency.py`. Note that you should use an underscore `_` instead of a dash `-` when specifying the arguments; `run-benchmarks-suite.sh` converts the underscores to dashes when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
-
-Note that the performance numbers are highly sensitive to the values of the parameters. Please make sure the parameters are set correctly.
-
-WARNING: The benchmarking script saves json results by itself, so please do not set the `--output-json` parameter in the json file.
-
-
-### Throughput test
-The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except that the parameters are forwarded to `benchmark_throughput.py`.
-
-The resulting number of this test is also stable across runs -- even a slight change in this number may indicate a real difference in performance.
-
-### Serving test
-We test the serving throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead.
-The corresponding parameters are in `serving-tests.json`, and here is an example:
-
-```
-[
-    {
-        "test_name": "serving_llama8B_tp1_sharegpt",
-        "qps_list": [1, 4, 16, "inf"],
-        "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3-8B",
-            "tensor_parallel_size": 1,
-            "swap_space": 16,
-            "disable_log_stats": "",
-            "disable_log_requests": "",
-            "load_format": "dummy"
-        },
-        "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3-8B",
-            "backend": "vllm",
-            "dataset_name": "sharegpt",
-            "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-            "num_prompts": 200
-        }
-    },
-]
-```
-
-Inside this example:
-- The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`.
-- The `server_parameters` attribute includes the command line arguments for the vLLM server.
-- The `client_parameters` attribute includes the command line arguments for `benchmark_serving.py`.
-- The `qps_list` attribute controls the list of QPS values to test. It is used to configure the `--request-rate` parameter in `benchmark_serving.py`.
-
-The resulting number of this test is less stable than the latency and throughput benchmarks (due to randomized ShareGPT dataset sampling inside `benchmark_serving.py`), but a large change in this number (e.g. a 5% change) still indicates a real difference in serving performance.
-
-WARNING: The benchmarking script saves json results by itself, so please do not set `--save-result` or other result-saving-related parameters in `serving-tests.json`.
-
-## Visualizing the results
-The `convert-results-json-to-markdown.py` script puts the benchmarking results into a markdown table by formatting [descriptions.md](tests/descriptions.md) with the real benchmarking results.
-You can find the resulting table inside the `buildkite/performance-benchmark` job page.
-If you do not see the table, please wait until the benchmark finishes running.
-The json version of the table (together with the json version of the benchmark) will also be attached to the markdown file.
-The raw benchmarking results (as json files) are available in the `Artifacts` tab of the benchmarking job.
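The underscore-to-dash conversion described in the README above is handled by the `json2args` helper in `run-benchmarks-suite.sh` (implemented with `jq`, shown later in this diff). As a rough Python sketch of the same idea -- the helper name here is illustrative, not part of the suite:

```python
import json

def params_to_cli_args(parameters: dict) -> str:
    # Each key becomes a "--flag" with underscores replaced by dashes,
    # followed by its value; empty-string values yield a bare flag.
    args = []
    for key, value in parameters.items():
        flag = "--" + key.replace("_", "-")
        args.append((flag + " " + str(value)).strip())
    return " ".join(args)

params = json.loads("""
{
    "model": "meta-llama/Meta-Llama-3-8B",
    "tensor_parallel_size": 1,
    "load_format": "dummy",
    "num_iters_warmup": 5,
    "num_iters": 15
}
""")
print(params_to_cli_args(params))
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15
```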
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml deleted file mode 100644 index 2b25c954b5c5c..0000000000000 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ /dev/null @@ -1,62 +0,0 @@ -steps: - - label: "Wait for container to be ready" - agents: - queue: A100 - plugins: - - kubernetes: - podSpec: - containers: - - image: badouralix/curl-jq - command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - - wait - - label: "A100 Benchmark" - agents: - queue: A100 - plugins: - - kubernetes: - podSpec: - priorityClassName: perf-benchmark - containers: - - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - resources: - limits: - nvidia.com/gpu: 8 - volumeMounts: - - name: devshm - mountPath: /dev/shm - env: - - name: VLLM_USAGE_SOURCE - value: ci-test - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB - volumes: - - name: devshm - emptyDir: - medium: Memory - # - label: "H100: NVIDIA SMI" - # agents: - # queue: H100 - # plugins: - # - docker#v5.11.0: - # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - # command: - # - bash - # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - # mount-buildkite-agent: true - # propagate-environment: true - # propagate-uid-gid: false - # ipc: host - # gpus: all - # environment: - # - VLLM_USAGE_SOURCE - # - HF_TOKEN - diff --git a/.buildkite/nightly-benchmarks/kickoff-pipeline.sh b/.buildkite/nightly-benchmarks/kickoff-pipeline.sh deleted file mode 100755 index 15d411febcee1..0000000000000 --- a/.buildkite/nightly-benchmarks/kickoff-pipeline.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env bash - -# NOTE(simon): this script runs inside a buildkite agent with CPU only access. -set -euo pipefail - -# Install system packages -apt update -apt install -y curl jq - -# Install minijinja for templating -curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh -source $HOME/.cargo/env - -# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq -if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then - PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name') - - if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then - echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks." - else - echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks." - exit 0 - fi -fi - -# Upload sample.yaml -buildkite-agent pipeline upload .buildkite/nightly-benchmarks/benchmark-pipeline.yaml diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh deleted file mode 100644 index 021473f76d0e5..0000000000000 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ /dev/null @@ -1,358 +0,0 @@ -#!/bin/bash - -# This script should be run inside the CI process -# This script assumes that we are already inside the vllm/ directory -# Benchmarking results will be available inside vllm/benchmarks/results/ - -# Do not set -e, as the mixtral 8x22B model tends to crash occasionally -# and we still want to see other benchmarking results even when mixtral crashes. 
-set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -check_hf_token() { - # check if HF_TOKEN is available and valid - if [[ -z "$HF_TOKEN" ]]; then - echo "Error: HF_TOKEN is not set." - exit 1 - elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then - echo "Error: HF_TOKEN does not start with 'hf_'." - exit 1 - else - echo "HF_TOKEN is set and valid." - fi -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl localhost:8000/v1/completions; do - sleep 1 - done' && return 0 || return 1 -} - -kill_gpu_processes() { - # kill all processes on GPU. - pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader) - if [ -z "$pids" ]; then - echo "No GPU processes found." - else - for pid in $pids; do - kill -9 "$pid" - echo "Killed process with PID: $pid" - done - - echo "All GPU processes have been killed." - fi - - # waiting for GPU processes to be fully killed - sleep 10 - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < $RESULTS_FOLDER/benchmark_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - -run_latency_tests() { - # run latency tests using `benchmark_latency.py` - # $1: a json file specifying latency test cases - - local latency_test_file - latency_test_file=$1 - - # Iterate over latency tests - jq -c '.[]' "$latency_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^latency_ ]]; then - echo "In latency-test.json, test_name must start with \"latency_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - # get arguments - latency_params=$(echo "$params" | jq -r '.parameters') - latency_args=$(json2args "$latency_params") - - # check if there is enough GPU to run the test - tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." - continue - fi - - latency_command="python3 benchmark_latency.py \ - --output-json $RESULTS_FOLDER/${test_name}.json \ - $latency_args" - - echo "Running test case $test_name" - echo "Latency command: $latency_command" - - # recoding benchmarking command ang GPU command - jq_output=$(jq -n \ - --arg latency "$latency_command" \ - --arg gpu "$gpu_type" \ - '{ - latency_command: $latency, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" - - # run the benchmark - eval "$latency_command" - - kill_gpu_processes - - done -} - - -run_throughput_tests() { - # run throughput tests using `benchmark_throughput.py` - # $1: a json file specifying throughput test cases - - local throughput_test_file - throughput_test_file=$1 - - # Iterate over throughput tests - jq -c '.[]' "$throughput_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^throughput_ ]]; then - echo "In throughput-test.json, test_name must start with \"throughput_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # get arguments - throughput_params=$(echo "$params" | jq -r '.parameters') - throughput_args=$(json2args "$throughput_params") - - # check if there is enough GPU to run the test - tp=$(echo $throughput_params | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." - continue - fi - - throughput_command="python3 benchmark_throughput.py \ - --output-json $RESULTS_FOLDER/${test_name}.json \ - $throughput_args" - - echo "Running test case $test_name" - echo "Throughput command: $throughput_command" - # recoding benchmarking command ang GPU command - jq_output=$(jq -n \ - --arg command "$throughput_command" \ - --arg gpu "$gpu_type" \ - '{ - throughput_command: $command, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" - - # run the benchmark - eval "$throughput_command" - - kill_gpu_processes - - done -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - if [[ ! "$test_name" =~ ^serving_ ]]; then - echo "In serving-test.json, test_name must start with \"serving_\"." - exit 1 - fi - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." 
- continue - fi - - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.server_parameters') - client_params=$(echo "$params" | jq -r '.client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." - continue - fi - - # check if server model and client model is aligned - server_model=$(echo "$server_params" | jq -r '.model') - client_model=$(echo "$client_params" | jq -r '.model') - if [[ $server_model != "$client_model" ]]; then - echo "Server model and client model must be the same. Skip testcase $testname." - continue - fi - - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - $server_args" - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "vllm server is up and running." - else - echo "" - echo "vllm failed to start within the timeout period." - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu - }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - done -} - -main() { - check_gpus - check_hf_token - - # dependencies - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get update && apt-get -y install jq) - - # get the current IP address, required by benchmark_serving.py - export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') - # turn of the reporting of the status of each request, to clean up the terminal output - export VLLM_LOG_LEVEL="WARNING" - - # prepare for benchmarking - cd benchmarks || exit 1 - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - # benchmarking - run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json - run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json - run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - - - # postprocess benchmarking results - pip install tabulate pandas - python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py - - upload_to_buildkite -} - -main "$@" diff --git 
a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py deleted file mode 100644 index 534ecf17930e9..0000000000000 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ /dev/null @@ -1,192 +0,0 @@ -import json -import os -from pathlib import Path - -import pandas as pd -from tabulate import tabulate - -results_folder = Path("results/") - -# latency results and the keys that will be printed into markdown -latency_results = [] -latency_column_mapping = { - "test_name": "Test name", - "gpu_type": "GPU", - "avg_latency": "Mean latency (ms)", - # "P10": "P10 (s)", - # "P25": "P25 (s)", - "P50": "Median latency (ms)", - # "P75": "P75 (s)", - # "P90": "P90 (s)", - "P99": "P99 latency (ms)", -} - -# throughput tests and the keys that will be printed into markdown -throughput_results = [] -throughput_results_column_mapping = { - "test_name": "Test name", - "gpu_type": "GPU", - # "num_requests": "# of req.", - # "total_num_tokens": "Total # of tokens", - # "elapsed_time": "Elapsed time (s)", - "requests_per_second": "Tput (req/s)", - # "tokens_per_second": "Tput (tok/s)", -} - -# serving results and the keys that will be printed into markdown -serving_results = [] -serving_column_mapping = { - "test_name": "Test name", - "gpu_type": "GPU", - # "completed": "# of req.", - "request_throughput": "Tput (req/s)", - # "input_throughput": "Input Tput (tok/s)", - # "output_throughput": "Output Tput (tok/s)", - "mean_ttft_ms": "Mean TTFT (ms)", - "median_ttft_ms": "Median TTFT (ms)", - "p99_ttft_ms": "P99 TTFT (ms)", - # "mean_tpot_ms": "Mean TPOT (ms)", - # "median_tpot_ms": "Median", - # "p99_tpot_ms": "P99", - "mean_itl_ms": "Mean ITL (ms)", - "median_itl_ms": "Median ITL (ms)", - "p99_itl_ms": "P99 ITL (ms)", -} - - -def read_markdown(file): - if os.path.exists(file): - with open(file, "r") as f: - return f.read() + "\n" - else: - return f"{file} not found.\n" - - -def results_to_json(latency, throughput, serving): - return json.dumps({ - 'latency': latency.to_dict(), - 'throughput': throughput.to_dict(), - 'serving': serving.to_dict() - }) - - -if __name__ == "__main__": - - # collect results - for test_file in results_folder.glob("*.json"): - - with open(test_file, "r") as f: - raw_result = json.loads(f.read()) - - if "serving" in str(test_file): - # this result is generated via `benchmark_serving.py` - - # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) - raw_result.update(command) - - # update the test name of this result - raw_result.update({"test_name": test_file.stem}) - - # add the result to raw_result - serving_results.append(raw_result) - continue - - elif "latency" in f.name: - # this result is generated via `benchmark_latency.py` - - # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) - raw_result.update(command) - - # update the test name of this result - raw_result.update({"test_name": test_file.stem}) - - # get different percentiles - for perc in [10, 25, 50, 75, 90, 99]: - # Multiply 1000 to convert the time unit from s to ms - raw_result.update( - {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) - raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 - - # add the result to raw_result - latency_results.append(raw_result) - continue - - elif "throughput" in f.name: - # 
this result is generated via `benchmark_throughput.py` - - # attach the benchmarking command to raw_result - with open(test_file.with_suffix(".commands"), "r") as f: - command = json.loads(f.read()) - raw_result.update(command) - - # update the test name of this result - raw_result.update({"test_name": test_file.stem}) - - # add the result to raw_result - throughput_results.append(raw_result) - continue - - print(f"Skipping {test_file}") - - latency_results = pd.DataFrame.from_dict(latency_results) - serving_results = pd.DataFrame.from_dict(serving_results) - throughput_results = pd.DataFrame.from_dict(throughput_results) - - raw_results_json = results_to_json(latency_results, throughput_results, - serving_results) - - # remapping the key, for visualization purpose - if not latency_results.empty: - latency_results = latency_results[list( - latency_column_mapping.keys())].rename( - columns=latency_column_mapping) - if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) - if not throughput_results.empty: - throughput_results = throughput_results[list( - throughput_results_column_mapping.keys())].rename( - columns=throughput_results_column_mapping) - - processed_results_json = results_to_json(latency_results, - throughput_results, - serving_results) - - # get markdown tables - latency_md_table = tabulate(latency_results, - headers='keys', - tablefmt='pipe', - showindex=False) - serving_md_table = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) - throughput_md_table = tabulate(throughput_results, - headers='keys', - tablefmt='pipe', - showindex=False) - - # document the result - with open(results_folder / "benchmark_results.md", "w") as f: - - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") - results = results.format( - latency_tests_markdown_table=latency_md_table, - throughput_tests_markdown_table=throughput_md_table, - serving_tests_markdown_table=serving_md_table, - benchmarking_results_in_json_string=processed_results_json) - f.write(results) - - # document benchmarking results in json - with open(results_folder / "benchmark_results.json", "w") as f: - - results = latency_results.to_dict( - orient='records') + throughput_results.to_dict( - orient='records') + serving_results.to_dict(orient='records') - f.write(json.dumps(results)) diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh deleted file mode 100644 index c785e6a0da628..0000000000000 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/sh -TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) -URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" - -retries=0 -while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then - exit 0 - fi - - echo "Waiting for image to be available..." 
- - retries=$((retries + 1)) - sleep 5 -done - -exit 1 \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/tests/descriptions.md deleted file mode 100644 index 891e4917070d9..0000000000000 --- a/.buildkite/nightly-benchmarks/tests/descriptions.md +++ /dev/null @@ -1,67 +0,0 @@ - -## Latency tests - -This test suite aims to test vllm's end-to-end latency under a controlled setup. - -- Input length: 32 tokens. -- Output length: 128 tokens. -- Batch size: fixed (8). -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: end-to-end latency (mean, median, p99). - -### Latency benchmarking results - -{latency_tests_markdown_table} - -## Throughput tests - -This test suite aims to test vllm's throughput. - -- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 200 prompts. -- Batch size: dynamically determined by vllm to achieve maximum throughput. -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: throughput. - -### Throughput benchmarking results - -{throughput_tests_markdown_table} - -## Serving tests - -This test suite aims to test vllm's real serving metrics. - -- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 200 prompts. -- Batch size: dynamically determined by vllm and the arrival pattern of the requests. -- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). - -### Serving benchmarking results - -{serving_tests_markdown_table} - -## json version of the benchmarking tables - -This section contains the data of the markdown tables above in JSON format. -You can load the benchmarking tables into pandas dataframes as follows: - -```python -import json -import pandas as pd - -benchmarking_results_json = """The json string""" -benchmarking_results = json.loads(benchmarking_results_json) -latency_results = pd.DataFrame.from_dict(benchmarking_results["latency"]) -throughput_results = pd.DataFrame.from_dict(benchmarking_results["throughput"]) -serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) -``` - -The json string for all benchmarking tables: -```json -{benchmarking_results_in_json_string} -``` - -You can also check the raw experiment data in the Artifact tab of the Buildkite page. 
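The serving test description above notes that, for finite QPS values, request arrival times follow a random Poisson process with a fixed seed. As a rough illustration of what that means (not necessarily how `benchmark_serving.py` implements it), arrival times can be generated from exponentially distributed inter-arrival gaps:

```python
import numpy as np

def poisson_arrival_times(num_requests: int, qps: float, seed: int = 0) -> np.ndarray:
    # In a Poisson process with rate `qps`, the gaps between consecutive
    # requests are exponentially distributed with mean 1 / qps seconds.
    rng = np.random.default_rng(seed)  # fixed seed -> reproducible arrival pattern
    gaps = rng.exponential(scale=1.0 / qps, size=num_requests)
    return np.cumsum(gaps)  # arrival time of each request, in seconds

# 200 prompts at an average of 4 QPS take roughly 50 seconds to submit.
print(poisson_arrival_times(200, qps=4.0)[-1])
```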
- diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests.json b/.buildkite/nightly-benchmarks/tests/latency-tests.json deleted file mode 100644 index 06488cd79110a..0000000000000 --- a/.buildkite/nightly-benchmarks/tests/latency-tests.json +++ /dev/null @@ -1,32 +0,0 @@ -[ - { - "test_name": "latency_llama8B_tp1", - "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", - "tensor_parallel_size": 1, - "load_format": "dummy", - "num_iters_warmup": 5, - "num_iters": 15 - } - }, - { - "test_name": "latency_llama70B_tp4", - "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", - "tensor_parallel_size": 4, - "load_format": "dummy", - "num-iters-warmup": 5, - "num-iters": 15 - } - }, - { - "test_name": "latency_mixtral8x7B_tp2", - "parameters": { - "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "tensor_parallel_size": 2, - "load_format": "dummy", - "num-iters-warmup": 5, - "num-iters": 15 - } - } -] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json deleted file mode 100644 index 86a0fefa339f7..0000000000000 --- a/.buildkite/nightly-benchmarks/tests/serving-tests.json +++ /dev/null @@ -1,59 +0,0 @@ -[ - { - "test_name": "serving_llama8B_tp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "server_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", - "tensor_parallel_size": 1, - "swap_space": 16, - "disable_log_stats": "", - "disable_log_requests": "", - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama70B_tp4_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "server_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", - "tensor_parallel_size": 4, - "swap_space": 16, - "disable_log_stats": "", - "disable_log_requests": "", - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_mixtral8x7B_tp2_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "server_parameters": { - "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "tensor_parallel_size": 2, - "swap_space": 16, - "disable_log_stats": "", - "disable_log_requests": "", - "load_format": "dummy" - }, - "client_parameters": { - "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - } -] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json deleted file mode 100644 index 41ac135748704..0000000000000 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "test_name": "throughput_llama8B_tp1", - "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", - "tensor_parallel_size": 1, - "load_format": "dummy", - "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200, - "backend": "vllm" - } - }, - { - "test_name": "throughput_llama70B_tp4", - "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", - "tensor_parallel_size": 4, - "load_format": "dummy", - "dataset": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200, - "backend": "vllm" - } - }, - { - "test_name": "throughput_mixtral8x7B_tp2", - "parameters": { - "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "tensor_parallel_size": 2, - "load_format": "dummy", - "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200, - "backend": "vllm" - } - } -] \ No newline at end of file diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml deleted file mode 100644 index 1959f9752069f..0000000000000 --- a/.buildkite/release-pipeline.yaml +++ /dev/null @@ -1,21 +0,0 @@ -steps: - - block: "Build wheels" - - - label: "Build wheel - Python {{matrix.python_version}}, CUDA {{matrix.cuda_version}}" - agents: - queue: cpu_queue - commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --build-arg PYTHON_VERSION={{matrix.python_version}} --tag vllm-ci:build-image --target build --progress plain ." - - "mkdir artifacts" - - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image cp -r dist /artifacts_host" - - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/" - matrix: - setup: - cuda_version: - - "11.8.0" - - "12.1.0" - python_version: - - "3.8" - - "3.9" - - "3.10" - - "3.11" diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index bde8ab6184d3c..ce508e4748aba 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,38 +1,10 @@ -# This script runs test inside the corresponding ROCm docker container. +# This script build the ROCm docker image and runs test inside it. set -ex # Print ROCm version echo "--- ROCm info" rocminfo -# cleanup older docker images -cleanup_docker() { - # Get Docker's root directory - docker_root=$(docker info -f '{{.DockerRootDir}}') - if [ -z "$docker_root" ]; then - echo "Failed to determine Docker root directory." - exit 1 - fi - echo "Docker root directory: $docker_root" - # Check disk usage of the filesystem where Docker's root directory is located - disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') - # Define the threshold - threshold=70 - if [ "$disk_usage" -gt "$threshold" ]; then - echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." - # Remove dangling images (those that are not tagged and not used by any container) - docker image prune -f - # Remove unused volumes - docker volume prune -f - echo "Docker images and volumes cleanup completed." - else - echo "Disk usage is below $threshold%. No cleanup needed." - fi -} - -# Call the cleanup docker function -cleanup_docker - echo "--- Resetting GPUs" echo "reset" > /opt/amdgpu/etc/gpu_state @@ -47,16 +19,15 @@ done echo "--- Building container" sha=$(git rev-parse --short HEAD) -image_name=rocm_${sha} -container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) +container_name=rocm_${sha} docker build \ - -t ${image_name} \ + -t ${container_name} \ -f Dockerfile.rocm \ --progress plain \ . 
remove_docker_container() { - docker rm -f ${container_name} || docker image rm -f ${image_name} || true + docker rm -f ${container_name} || docker image rm -f ${container_name} || true } trap remove_docker_container EXIT @@ -68,6 +39,6 @@ docker run \ --rm \ -e HF_TOKEN \ --name ${container_name} \ - ${image_name} \ + ${container_name} \ /bin/bash -c "${@}" diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index cbf6dda677c53..7fbad1c4bd950 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." (which wget && which curl) || (apt-get update && apt-get install -y wget curl) # run python-based benchmarks and upload the result to buildkite -python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt +python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? -python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite @@ -50,16 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line echo "" >> benchmark_results.md echo '```' >> benchmark_results.md -tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines +tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines echo '```' >> benchmark_results.md # if the agent binary is not found, skip uploading the results, exit 0 -if [ ! -f /usr/bin/buildkite-agent ]; then +if [ ! -f /workspace/buildkite-agent ]; then exit 0 fi # upload the results to buildkite -buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md +/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md # exit with the exit code of the benchmarks if [ $bench_latency_exit_code -ne 0 ]; then @@ -74,5 +74,4 @@ if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi -rm ShareGPT_V3_unfiltered_cleaned_split.json -buildkite-agent artifact upload "*.json" +/workspace/buildkite-agent artifact upload openai-*.json diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f4fa24be1f20f..f187d1f181724 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -4,23 +4,11 @@ set -ex # Try building the docker image docker build -t cpu-test -f Dockerfile.cpu . -docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . 
# Setup cleanup -remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } +remove_docker_container() { docker rm -f cpu-test || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image -docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test -docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2 - -# offline inference -docker exec cpu-test bash -c "python3 examples/offline_inference.py" -docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" - -# Run basic model test -docker exec cpu-test bash -c "cd tests; - pip install pytest Pillow protobuf - cd ../ - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" +# Run the image and launch offline inference +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh deleted file mode 100755 index 70e56596c4a86..0000000000000 --- a/.buildkite/run-openvino-test.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This script build the OpenVINO docker image and run the offline inference inside the container. -# It serves a sanity check for compilation and basic model usage. -set -ex - -# Try building the docker image -docker build -t openvino-test -f Dockerfile.openvino . - -# Setup cleanup -remove_docker_container() { docker rm -f openvino-test || true; } -trap remove_docker_container EXIT -remove_docker_container - -# Run the image and launch offline inference -docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh deleted file mode 100644 index 22a7e76937a76..0000000000000 --- a/.buildkite/run-xpu-test.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This script build the CPU docker image and run the offline inference inside the container. -# It serves a sanity check for compilation and basic model usage. -set -ex - -# Try building the docker image -docker build -t xpu-test -f Dockerfile.xpu . - -# Setup cleanup -remove_docker_container() { docker rm -f xpu-test || true; } -trap remove_docker_container EXIT -remove_docker_container - -# Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d96e3c6d192e2..cee5e7e9d2a73 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1,23 +1,17 @@ # In this file, you can add more tests to run either by adding a new step or # adding a new command to an existing step. See different options here for examples. - -# This script will be feed into Jinja template in `test-template-aws.j2` at -# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 -# to generate the final pipeline yaml file. - +# This script will be feed into Jinja template in `test-template.j2` to generate +# the final pipeline yaml file. 
steps: - label: Regression Test - mirror_hardwares: [amd] command: pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: AsyncEngine Test - #mirror_hardwares: [amd] command: pytest -v -s async_engine - label: Basic Correctness Test - mirror_hardwares: [amd] commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py @@ -27,99 +21,68 @@ steps: - label: Core Test mirror_hardwares: [amd] - commands: - - pytest -v -s core - - pytest -v -s distributed/test_parallel_state.py + command: pytest -v -s core - label: Distributed Comm Ops Test - #mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" + command: pytest -v -s test_comm_ops.py + working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 - commands: - - pytest -v -s distributed/test_comm_ops.py - - pytest -v -s distributed/test_shm_broadcast.py -- label: Distributed Tests (2 GPUs) +- label: Distributed Tests + working_dir: "/vllm-workspace/tests/distributed" + + num_gpus: 2 # only support 1 or 2 for now. mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 + commands: - - bash ../.buildkite/download-images.sh - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py - -- label: Distributed Tests (4 GPUs) - #mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" + - pytest -v -s test_pynccl_library.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - 
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py + +- label: Distributed Tests (Multiple Groups) + working_dir: "/vllm-workspace/tests/distributed" num_gpus: 4 commands: - - pytest -v -s distributed/test_pynccl.py - # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. - # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py + - pytest -v -s test_pynccl.py - label: Engine Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test - mirror_hardwares: [amd] - commands: - - pytest -v -s entrypoints/llm - - pytest -v -s entrypoints/openai + # these tests have to be separated, because each one will allocate all posible GPU memory + - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py + - pytest -v -s entrypoints/test_server_oot_registration.py - label: Examples Test working_dir: "/vllm-workspace/examples" mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - # install tensorizer for tensorize_vllm_model.py - - pip install awscli tensorizer + - pip install awscli - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py - - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - -- label: Inputs Test - #mirror_hardwares: [amd] - commands: - - bash ../.buildkite/download-images.sh - - pytest -v -s test_inputs.py - - pytest -v -s multimodal - label: Kernels Test %N - #mirror_hardwares: [amd] command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Models Test #mirror_hardwares: [amd] commands: - - pytest -v -s models -m \"not vlm\" + - bash ../.buildkite/download-images.sh + - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py -- label: Vision Language Models Test - mirror_hardwares: [amd] +- label: Llava Test + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models -m vlm + - pytest -v -s models/test_llava.py - label: Prefix Caching Test mirror_hardwares: [amd] @@ -127,63 +90,33 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - #mirror_hardwares: [amd] command: pytest -v -s samplers - label: LogitsProcessor Test mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py -- label: Utils Test - command: pytest -v -s test_utils.py - - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker - label: Speculative decoding tests #mirror_hardwares: [amd] - commands: - # See https://github.com/vllm-project/vllm/issues/5152 - - export VLLM_ATTENTION_BACKEND=XFORMERS - - pytest -v -s spec_decode + command: pytest -v -s spec_decode - label: LoRA Test %N - 
#mirror_hardwares: [amd] - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 -- label: LoRA Long Context (Distributed) - #mirror_hardwares: [amd] - num_gpus: 4 - # This test runs llama 13B, so it is required to run on 4 GPUs. - commands: - # FIXIT: find out which code initialize cuda before running the test - # before the fix, we need to use spawn to test it - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s -x lora/test_long_context.py - - label: Tensorizer Test - #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test - mirror_hardwares: [amd] command: pytest -v -s metrics - label: Quantization Test - #mirror_hardwares: [amd] command: pytest -v -s quantization -- label: Tracing Test - commands: - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" - - pytest -v -s tracing - - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] @@ -191,39 +124,9 @@ steps: - pip install aiohttp - bash run-benchmarks.sh -- label: LM Eval Small Models - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" - commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 - -- label: LM Eval Large Models - gpu: a100 - num_gpus: 4 - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" - commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 - - label: Documentation Build working_dir: "/vllm-workspace/test_docs/docs" no_gpu: True commands: - pip install -r requirements-docs.txt - SPHINXOPTS=\"-W\" make html - -- label: Distributed Tests (A100) - gpu: a100 - num_gpus: 4 - commands: - # NOTE: don't test llama model here, it seems hf implementation is buggy - # see https://github.com/vllm-project/vllm/pull/5689 for details - - pytest -v -s distributed/test_custom_all_reduce.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s -x lora/test_mixtral.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 new file mode 100644 index 0000000000000..174c756ae74a3 --- /dev/null +++ b/.buildkite/test-template.j2 @@ -0,0 +1,94 @@ +{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %} +{% set default_num_gpu = 1 %} +{% set default_working_dir = "/vllm-workspace/tests" %} + +steps: + + - label: ":docker: build image" + commands: + - "docker build --build-arg max_jobs=16 --tag {{ 
docker_image }} --target test --progress plain ." + - "docker push {{ docker_image }}" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + - wait + + - group: "AMD Tests" + depends_on: ~ + steps: + {% for step in steps %} + {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + agents: + queue: amd + command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}" + env: + DOCKER_BUILDKIT: "1" + {% endif %} + {% endfor %} + + - label: "Neuron Test" + depends_on: ~ + agents: + queue: neuron + command: bash .buildkite/run-neuron-test.sh + soft_fail: true + + - label: "Intel Test" + depends_on: ~ + command: bash .buildkite/run-cpu-test.sh + + {% for step in steps %} + - label: "{{ step.label }}" + agents: + queue: kubernetes + soft_fail: {{ step.soft_fail or false }} + {% if step.parallelism %} + parallelism: {{ step.parallelism }} + {% endif %} + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + plugins: + - kubernetes: + podSpec: + {% if step.num_gpus %} + priorityClassName: gpu-priority-cls-{{ step.num_gpus }} + {% endif %} + volumes: + - name: dshm + emptyDir: + medium: Memory + containers: + - image: "{{ docker_image }}" + command: ["bash"] + args: + - '-c' + - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + {% if not step.no_gpu %} + resources: + requests: + nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" + limits: + nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" + {% endif %} + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + volumeMounts: + - mountPath: /dev/shm + name: dshm + {% endfor %} diff --git a/.clang-format b/.clang-format deleted file mode 100644 index 7f9e6d720fae5..0000000000000 --- a/.clang-format +++ /dev/null @@ -1,26 +0,0 @@ -BasedOnStyle: Google -UseTab: Never -IndentWidth: 2 -ColumnLimit: 80 - -# Force pointers to the type for C++. -DerivePointerAlignment: false -PointerAlignment: Left - -# Reordering #include statements can (and currently will) introduce errors -SortIncludes: false - -# Style choices -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -IndentPPDirectives: BeforeHash - -IncludeCategories: - - Regex: '^<' - Priority: 4 - - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' - Priority: 3 - - Regex: '^"(qoda|\.\.)/' - Priority: 2 - - Regex: '.*' - Priority: 1 diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index ce980c3f4a01d..08120ad8e5a60 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -59,8 +59,6 @@ body: Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. - Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. - If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . 
All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml deleted file mode 100644 index e9b6e28fa6bcb..0000000000000 --- a/.github/workflows/clang-format.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: clang-format - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - clang-format: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install clang-format==18.1.5 - - name: Running clang-format - run: | - EXCLUDES=( - 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' - ) - find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ - | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ - | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 62f0dbcd93eff..a20753d8a7702 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -37,7 +37,6 @@ jobs: mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml - mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml @@ -47,5 +46,5 @@ jobs: mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml - mypy tests --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 773def58fd966..e71033f828006 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 + pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2 - name: Analysing the code with ruff run: | ruff . diff --git a/CMakeLists.txt b/CMakeLists.txt index ede9192cd1dbb..f817f3382c5e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,7 @@ cmake_minimum_required(VERSION 3.21) project(vllm_extensions LANGUAGES CXX) -# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... 
(used by setup.py) -set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") +option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") @@ -33,7 +32,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # versions are derived from Dockerfile.rocm # set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0") +set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") +set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") # # Try to find python package with an executable that exactly matches @@ -66,6 +66,19 @@ endif() # find_package(Torch REQUIRED) +# +# Normally `torch.utils.cpp_extension.CUDAExtension` would add +# `libtorch_python.so` for linking against an extension. Torch's cmake +# configuration does not include this library (presumably since the cmake +# config is used for standalone C++ binaries that link against torch). +# The `libtorch_python.so` library defines some of the glue code between +# torch/python via pybind and is required by VLLM extensions for this +# reason. So, add it manually with `find_library` using torch's +# installed library path. +# +find_library(torch_python_LIBRARY torch_python PATHS + "${TORCH_INSTALL_PREFIX}/lib") + # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -98,11 +111,18 @@ elseif(HIP_FOUND) # .hip extension automatically, HIP must be enabled explicitly. enable_language(HIP) - # ROCm 5.X and 6.X - if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} " - "expected for ROCm build, saw ${Torch_VERSION} instead.") + # ROCm 5.x + if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND + NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} " + "expected for ROCm 5.x build, saw ${Torch_VERSION} instead.") + endif() + + # ROCm 6.x + if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND + NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} " + "expected for ROCm 6.x build, saw ${Torch_VERSION} instead.") endif() else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") @@ -147,47 +167,19 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" - "csrc/quantization/fp8/common.cu" + "csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" - "csrc/torch_bindings.cpp") + "csrc/pybind.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - include(FetchContent) - SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) - FetchContent_Declare( - cutlass - GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.0 - GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc - ) - FetchContent_MakeAvailable(cutlass) - list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/custom_all_reduce.cu" - 
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") - - # - # The CUTLASS kernels for Hopper require sm90a to be enabled. - # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. - # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") - endif() - + "csrc/custom_all_reduce.cu") endif() define_gpu_extension_target( @@ -197,8 +189,6 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} - USE_SABI 3 WITH_SOABI) # @@ -206,7 +196,7 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC - "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_ops.cpp" "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( @@ -216,7 +206,6 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - USE_SABI 3 WITH_SOABI) # @@ -230,8 +219,7 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cu" - "csrc/punica/torch_bindings.cpp") + "csrc/punica/punica_ops.cc") # # Copy GPU compilation flags+update for punica @@ -255,9 +243,6 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -elseif(${VLLM_GPU_LANG} STREQUAL "HIP") - set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -268,7 +253,6 @@ if (VLLM_PUNICA_GPU_ARCHES) SOURCES ${VLLM_PUNICA_EXT_SRC} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} - USE_SABI 3 WITH_SOABI) else() message(WARNING "Unable to create _punica_C target because none of the " @@ -293,7 +277,9 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) +endif() +if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) diff --git a/Dockerfile b/Dockerfile index d031d98c5b7e4..90be3a30f89b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,35 +5,18 @@ # docs/source/dev/dockerfile/dockerfile.rst and # docs/source/assets/dev/dockerfile-stages-dependency.png -ARG CUDA_VERSION=12.4.1 #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base - -ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3 - -ENV DEBIAN_FRONTEND=noninteractive - -RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ - && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ - && apt-get update -y \ - && apt-get install -y ccache software-properties-common \ - && add-apt-repository ppa:deadsnakes/ppa \ - && apt-get update -y \ - && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python3-pip \ - && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install 
/usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ - && python3 --version \ - && python3 -m pip --version +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev RUN apt-get update -y \ - && apt-get install -y python3-pip git curl sudo + && apt-get install -y python3-pip git # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ WORKDIR /workspace @@ -41,7 +24,12 @@ WORKDIR /workspace COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-cuda.txt + pip install -r requirements-cuda.txt + +# install development dependencies +COPY requirements-dev.txt requirements-dev.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-dev.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -51,16 +39,14 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} #################### BASE BUILD IMAGE #################### -#################### WHEEL BUILD IMAGE #################### -FROM base AS build -ARG PYTHON_VERSION=3 +#################### WHEEL BUILD IMAGE #################### +FROM dev AS build # install build dependencies COPY requirements-build.txt requirements-build.txt - RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-build.txt + pip install -r requirements-build.txt # install compiler cache to speed up compilation leveraging local or remote caching RUN apt-get update -y && apt-get install -y ccache @@ -84,50 +70,43 @@ ENV NVCC_THREADS=$nvcc_threads # make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 -ARG USE_SCCACHE -# if USE_SCCACHE is set, use sccache to speed up compilation -RUN --mount=type=cache,target=/root/.cache/pip \ - if [ "$USE_SCCACHE" = "1" ]; then \ - echo "Installing sccache..." \ - && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ - && tar -xzf sccache.tar.gz \ - && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ - && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ - && export SCCACHE_BUCKET=vllm-build-sccache \ - && export SCCACHE_REGION=us-west-2 \ - && sccache --show-stats \ - && python3 setup.py bdist_wheel --dist-dir=dist \ - && sccache --show-stats; \ - fi - ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ - if [ "$USE_SCCACHE" != "1" ]; then \ - python3 setup.py bdist_wheel --dist-dir=dist; \ - fi + python3 setup.py bdist_wheel --dist-dir=dist # check the size of the wheel, we cannot upload wheels larger than 100MB COPY .buildkite/check-wheel-size.py check-wheel-size.py RUN python3 check-wheel-size.py dist +# the `vllm_nccl` package must be installed from source distribution +# pip is too smart to store a wheel in the cache, and other CI jobs +# will directly use the wheel from the cache, which is not what we want. 
+# we need to remove it manually +RUN --mount=type=cache,target=/root/.cache/pip \ + pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### DEV IMAGE #################### -FROM base as dev +#################### FLASH_ATTENTION Build IMAGE #################### +FROM dev as flash-attn-builder +# max jobs used for build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} +# flash attention version +ARG flash_attn_version=v2.5.8 +ENV FLASH_ATTN_VERSION=${flash_attn_version} -COPY requirements-lint.txt requirements-lint.txt -COPY requirements-test.txt requirements-test.txt -COPY requirements-dev.txt requirements-dev.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-dev.txt +WORKDIR /usr/src/flash-attention-v2 + +# Download the wheel or build it if a pre-compiled release doesn't exist +RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ + --no-build-isolation --no-deps --no-cache-dir -#################### DEV IMAGE #################### +#################### FLASH_ATTENTION Build IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base -ARG CUDA_VERSION=12.4.1 +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base WORKDIR /vllm-workspace RUN apt-get update -y \ @@ -137,12 +116,16 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install dist/*.whl --verbose + pip install dist/*.whl --verbose + +RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ + --mount=type=cache,target=/root/.cache/pip \ + pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### @@ -155,7 +138,7 @@ ADD . /vllm-workspace/ # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -r requirements-dev.txt + pip install -r requirements-dev.txt # doc requires source code # we hide them inside `test_docs/` , so that this source code @@ -172,7 +155,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer modelscope ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 6e55203decc56..4251fddd6cc3b 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,19 +1,13 @@ # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. 
-FROM ubuntu:22.04 AS cpu-test-1 +FROM ubuntu:22.04 RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 -RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc - -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl - RUN pip install --upgrade pip \ - && pip install wheel packaging ninja "setuptools>=49.4.0" numpy - -FROM cpu-test-1 AS build + && pip install wheel packaging ninja setuptools>=49.4.0 numpy COPY ./ /workspace/vllm @@ -21,14 +15,6 @@ WORKDIR /workspace/vllm RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu -# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... -ARG VLLM_CPU_DISABLE_AVX512 -ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} - RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install -WORKDIR /workspace/ - -RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks - CMD ["/bin/bash"] diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 010f23a143010..fe42b4ef393f1 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt RUN cd /app/vllm \ && python3 -m pip install -U -r requirements-neuron.txt -ENV VLLM_TARGET_DEVICE neuron +ENV VLLM_BUILD_WITH_NEURON 1 RUN cd /app/vllm \ && pip install -e . \ && cd .. diff --git a/Dockerfile.openvino b/Dockerfile.openvino deleted file mode 100644 index 9861997b451a9..0000000000000 --- a/Dockerfile.openvino +++ /dev/null @@ -1,26 +0,0 @@ -# The vLLM Dockerfile is used to construct vLLM image that can be directly used -# to run the OpenAI compatible server. 
- -FROM ubuntu:22.04 AS dev - -RUN apt-get update -y && \ - apt-get install -y python3-pip git -WORKDIR /workspace - -# copy requirements -COPY requirements-build.txt /workspace/vllm/ -COPY requirements-common.txt /workspace/vllm/ -COPY requirements-openvino.txt /workspace/vllm/ - -COPY vllm/ /workspace/vllm/vllm -COPY setup.py /workspace/vllm/ - -# install build requirements -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt -# build vLLM with OpenVINO backend -RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ - -COPY examples/ /workspace/vllm/examples -COPY benchmarks/ /workspace/vllm/benchmarks - -CMD ["/bin/bash"] diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le deleted file mode 100644 index d4e4c483cada8..0000000000000 --- a/Dockerfile.ppc64le +++ /dev/null @@ -1,22 +0,0 @@ -FROM mambaorg/micromamba -ARG MAMBA_DOCKERFILE_ACTIVATE=1 -USER root - -RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 - -# Some packages in requirements-cpu are installed here -# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba -# Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes - -COPY ./ /workspace/vllm - -WORKDIR /workspace/vllm - -# These packages will be in rocketce eventually -RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing - -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install - -WORKDIR /vllm-workspace -ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 1b89b892bbf1c..d04bb9915e2ab 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,35 +1,35 @@ -# Default ROCm 6.1 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" - -# Tested and supported base rocm/pytorch images -ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \ - ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \ - ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" - -# Default ROCm ARCHes to build vLLM for. -ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" - -# Whether to build CK-based flash-attention -# If 0, will not build flash attention -# This is useful for gfx target where flash-attention is not supported -# (i.e. those that do not appear in `FA_GFX_ARCHS`) -# Triton FA is used by default on ROCm now so this is unnecessary. 
-ARG BUILD_FA="1" +# default base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" + +FROM $BASE_IMAGE + +ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" + +RUN echo "Base image is $BASE_IMAGE" + +# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" +# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" + + ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="ae7928c" +RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" -# Whether to build triton on rocm -ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="0ef1848" +ARG FA_BRANCH="ae7928c" +RUN echo "FA_BRANCH is $FA_BRANCH" -### Base image build stage -FROM $BASE_IMAGE AS base +# whether to build flash-attention +# if 0, will not build flash attention +# this is useful for gfx target where flash-attention is not supported +# In that case, we need to use the python reference attention implementation in vllm +ARG BUILD_FA="1" -# Import arg(s) defined before this build stage -ARG PYTORCH_ROCM_ARCH +# whether to build triton on rocm +ARG BUILD_TRITON="1" # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y + +# Install some basic utilities RUN apt-get update && apt-get install -y \ curl \ ca-certificates \ @@ -40,165 +40,68 @@ RUN apt-get update && apt-get install -y \ build-essential \ wget \ unzip \ + nvidia-cuda-toolkit \ tmux \ - ccache \ && rm -rf /var/lib/apt/lists/* -# When launching the container, mount the code directory to /vllm-workspace +### Mount Point ### +# When launching the container, mount the code directory to /app ARG APP_MOUNT=/vllm-workspace +VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} -RUN pip install --upgrade pip -# Remove sccache so it doesn't interfere with ccache -# TODO: implement sccache support across components -RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" -# Install torch == 2.4.0 on ROCm -RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-5.7"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ - torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ - torchvision==0.19.0.dev20240612 \ - --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \ - *"rocm-6.0"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ - torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ - torchvision==0.19.0.dev20240612 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \ - *"rocm-6.1"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ - torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ - torchvision==0.19.0.dev20240612 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ - *) ;; esac +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: -ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} -ENV CCACHE_DIR=/root/.cache/ccache - - -### AMD-SMI build stage -FROM base AS build_amdsmi -# Build amdsmi wheel always -RUN cd /opt/rocm/share/amd_smi \ - && pip wheel . 
--wheel-dir=/install - - -### Flash-Attention wheel build stage -FROM base AS build_fa -ARG BUILD_FA -ARG FA_GFX_ARCHS -ARG FA_BRANCH -# Build ROCm flash-attention wheel if `BUILD_FA = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_FA" = "1" ]; then \ - mkdir -p libs \ +# Install ROCm flash-attention +RUN if [ "$BUILD_FA" = "1" ]; then \ + mkdir libs \ && cd libs \ && git clone https://github.com/ROCm/flash-attention.git \ && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ + && git checkout ${FA_BRANCH} \ && git submodule update --init \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-5.7"*) \ - export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \ - && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \ - *) ;; esac \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ + && export GPU_ARCHS=${FA_GFX_ARCHS} \ + && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ + patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ + && python3 setup.py install \ + && cd ..; \ fi +# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. +# Manually removed it so that later steps of numpy upgrade can continue +RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ + rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi -### Triton wheel build stage -FROM base AS build_triton -ARG BUILD_TRITON -ARG TRITON_BRANCH -# Build triton wheel if `BUILD_TRITON = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_TRITON" = "1" ]; then \ +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ - && git clone https://github.com/OpenAI/triton.git \ - && cd triton \ - && git checkout "${TRITON_BRANCH}" \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=/install; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCm/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ fi - -### Final vLLM build stage -FROM base AS final -# Import the vLLM development directory from the build context +WORKDIR /vllm-workspace COPY . . -# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. 
-# Manually remove it so that later steps of numpy upgrade can continue -RUN case "$(which python3)" in \ - *"/opt/conda/envs/py_3.9"*) \ - rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ - *) ;; esac +RUN python3 -m pip install --upgrade pip numba -# Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --upgrade numba scipy huggingface-hub[cli] - -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 -# Workaround for ray >= 2.10.0 -ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 -# Silences the HF Tokenizers warning -ENV TOKENIZERS_PARALLELISM=false - -RUN --mount=type=cache,target=${CCACHE_DIR} \ - --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.0"*) \ - patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ - *"rocm-6.1"*) \ - # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \ - && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \ - # Prevent interference if torch bundles its own HIP runtime - && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ - *) ;; esac \ - && python3 setup.py clean --all \ - && python3 setup.py develop - -# Copy amdsmi wheel into final image -RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ - mkdir -p libs \ - && cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y amdsmi; - -# Copy triton wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y triton; fi + && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ + && python3 setup.py install \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cd .. -# Copy flash-attn wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y flash-attn; fi - -# Install wheels that were built to the final image -RUN --mount=type=cache,target=/root/.cache/pip \ - if ls libs/*.whl; then \ - pip install libs/*.whl; fi +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 CMD ["/bin/bash"] diff --git a/Dockerfile.tpu b/Dockerfile.tpu deleted file mode 100644 index 931c844c08dce..0000000000000 --- a/Dockerfile.tpu +++ /dev/null @@ -1,19 +0,0 @@ -ARG NIGHTLY_DATE="20240601" -ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" - -FROM $BASE_IMAGE - -WORKDIR /workspace -COPY . /workspace/vllm - -ENV VLLM_TARGET_DEVICE="tpu" -# Install aiohttp separately to avoid build errors. -RUN pip install aiohttp -# Install the TPU and Pallas dependencies. 
-RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - -# Build vLLM. -RUN cd /workspace/vllm && python setup.py develop - -CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu deleted file mode 100644 index c39e551672d20..0000000000000 --- a/Dockerfile.xpu +++ /dev/null @@ -1,22 +0,0 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04 - -RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ - chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ - wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ - echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ - chmod 644 /usr/share/keyrings/intel-graphics.gpg - -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip - -COPY ./ /workspace/vllm - -WORKDIR /workspace/vllm - -RUN pip install -v -r requirements-xpu.txt - -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install - -CMD ["/bin/bash"] diff --git a/README.md b/README.md index d6957a7f5ee3a..9b180877a5a82 100644 --- a/README.md +++ b/README.md @@ -14,19 +14,7 @@ Easy, fast, and cheap LLM serving for everyone

---- - -**Ray Summit CPF is Open (June 4th to June 20th)!** - -There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year! -If you have cool projects related to vLLM or LLM inference, we would love to see your proposals. -This will be a great chance for everyone in the community to get together and learn. -Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite) - ---- - *Latest News* 🔥 -- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). - [2024/05] vLLM-fork specific: Added Intel® Gaudi® 2 support with SynapseAI 1.16.0. For more information, please refer to Intel® Gaudi® README. - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). @@ -60,18 +48,45 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD GPUs, Intel CPUs and GPUs +- Support NVIDIA GPUs and AMD GPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support -vLLM seamlessly supports most popular open-source models on HuggingFace, including: -- Transformer-like LLMs (e.g., Llama) -- Mixture-of-Expert LLMs (e.g., Mixtral) -- Multi-modal LLMs (e.g., LLaVA) - -Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). - -## Getting Started +vLLM seamlessly supports many Hugging Face models, including the following architectures: + +- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) +- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) +- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) +- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) +- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) +- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) +- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) +- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) +- GPT-2 (`gpt2`, `gpt2-xl`, etc.) +- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) +- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) +- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) +- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) +- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) +- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) 
+- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) +- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) +- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) +- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) +- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) +- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) +- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) +- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) +- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) +- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) +- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) +- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) +- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) +- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) +- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) +- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): @@ -79,7 +94,9 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get pip install vllm ``` -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. +## Getting Started + +Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) @@ -89,34 +106,6 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. -## Sponsors - -vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! - - - - -- a16z -- AMD -- Anyscale -- AWS -- Crusoe Cloud -- Databricks -- DeepInfra -- Dropbox -- Lambda Lab -- NVIDIA -- Replicate -- Roblox -- RunPod -- Sequoia Capital -- Trainy -- UC Berkeley -- UC San Diego -- ZhenFund - -We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. 
- ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fe29c67086158..f9d167590fe47 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -4,13 +4,10 @@ import time import traceback from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import List, Optional import aiohttp -import huggingface_hub.constants from tqdm.asyncio import tqdm -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -71,13 +68,9 @@ async def async_request_tgi( chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - #NOTE: Sometimes TGI returns a ping response without - # any data, we should skip it. - if chunk_bytes.startswith(":"): - continue - chunk = remove_prefix(chunk_bytes, "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() @@ -96,9 +89,6 @@ async def async_request_tgi( output.latency = most_recent_timestamp - st output.success = True output.generated_text = data["generated_text"] - else: - output.error = response.reason or "" - output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -225,8 +215,8 @@ async def async_request_openai_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "completions" - ), "OpenAI Completions API URL must end with 'completions'." + "v1/completions" + ), "OpenAI Completions API URL must end with 'v1/completions'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -265,9 +255,6 @@ async def async_request_openai_completions( else: data = json.loads(chunk) - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token @@ -276,8 +263,12 @@ async def async_request_openai_completions( output.ttft = ttft # Decoding phase - output.itl.append(timestamp - - most_recent_timestamp) + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # do not want to include as inter-token-latency + elif data.get("usage", None) is None: + output.itl.append(timestamp - + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] @@ -285,9 +276,6 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency - else: - output.error = response.reason or "" - output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -304,8 +292,8 @@ async def async_request_openai_chat_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "chat/completions" - ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + "v1/chat/completions" + ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'." 
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -390,30 +378,6 @@ def remove_prefix(text: str, prefix: str) -> str: return text -def get_model(pretrained_model_name_or_path: str): - if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': - from modelscope import snapshot_download - else: - from huggingface_hub import snapshot_download - - model_path = snapshot_download( - model_id=pretrained_model_name_or_path, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) - return model_path - - -def get_tokenizer( - pretrained_model_name_or_path: str, trust_remote_code: bool -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - if pretrained_model_name_or_path is not None and not os.path.exists( - pretrained_model_name_or_path): - pretrained_model_name_or_path = get_model( - pretrained_model_name_or_path) - return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=trust_remote_code) - - ASYNC_REQUEST_FUNCS = { "tgi": async_request_tgi, "vllm": async_request_openai_completions, @@ -422,5 +386,4 @@ def get_tokenizer( "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "tensorrt-llm": async_request_trt_llm, - "scalellm": async_request_openai_completions, } diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 16802d879c0ca..e8530c2761acf 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,19 +1,15 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse -import json import time from pathlib import Path -from typing import List, Optional +from typing import Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser def main(args: argparse.Namespace): @@ -21,33 +17,20 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM( - model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - speculative_draft_tensor_parallel_size=\ - args.speculative_draft_tensor_parallel_size, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - max_model_len=args.max_model_len, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization, - load_format=args.load_format, - distributed_executor_backend=args.distributed_executor_backend, - otlp_traces_endpoint=args.otlp_traces_endpoint, - enable_prefix_caching=args.enable_prefix_caching, - ) + llm = LLM(model=args.model, + tokenizer=args.tokenizer, + quantization=args.quantization, + tensor_parallel_size=args.tensor_parallel_size, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, + device=args.device, + ray_workers_use_nsight=args.ray_workers_use_nsight, + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size) sampling_params = SamplingParams( n=args.n, @@ -61,9 +44,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptStrictInputs] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -74,13 +55,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(prompt_token_ids=dummy_prompt_token_ids, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(prompt_token_ids=dummy_prompt_token_ids, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() @@ -106,34 +87,18 @@ def run_to_completion(profile_dir: Optional[str] = None): for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): latencies.append(run_to_completion(profile_dir=None)) latencies = np.array(latencies) - percentages = [10, 25, 50, 75, 90, 99] + percentages = [10, 25, 50, 75, 90] percentiles = np.percentile(latencies, percentages) print(f'Avg latency: {np.mean(latencies)} seconds') for percentage, percentile in zip(percentages, percentiles): print(f'{percentage}% percentile latency: {percentile} seconds') - # Output JSON results if specified - if args.output_json: - results = { - "avg_latency": np.mean(latencies), - "latencies": latencies.tolist(), - "percentiles": dict(zip(percentages, percentiles.tolist())), - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - if __name__ == '__main__': - parser = FlexibleArgumentParser( + parser = argparse.ArgumentParser( description='Benchmark the latency of processing a single 
batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') - parser.add_argument('--speculative-model', type=str, default=None) - parser.add_argument('--num-speculative-tokens', type=int, default=None) - parser.add_argument('--speculative-draft-tensor-parallel-size', - '-spec-draft-tp', - type=int, - default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', @@ -159,12 +124,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') parser.add_argument( '--dtype', type=str, @@ -178,13 +137,15 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - '--kv-cache-dtype', + "--kv-cache-dtype", type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + choices=['auto', 'fp8'], + default='auto', + help= + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') parser.add_argument( '--quantization-param-path', type=str, @@ -208,10 +169,9 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--device", type=str, - default="auto", - choices=["auto", "cuda", "cpu", "hpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, HPU, ' - 'OpenVINO and CPU.') + default="cuda", + choices=["cuda", "cpu", "hpu"], + help='device type for vLLM execution, supporting CUDA, CPU and HPU.') parser.add_argument('--block-size', type=int, default=16, @@ -221,10 +181,6 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') - parser.add_argument("--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching") - parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', @@ -235,51 +191,5 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the latency results in JSON format.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' 
- 'If unspecified, will use the default value of 0.9.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--otlp-traces-endpoint', - type=str, - default=None, - help='Target URL to which OpenTelemetry traces will be sent.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 395107a5ec747..089966986984f 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,7 +1,7 @@ +import argparse import time from vllm import LLM, SamplingParams -from vllm.utils import FlexibleArgumentParser PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. 
Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 @@ -44,7 +44,7 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( + parser = argparse.ArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') parser.add_argument('--model', diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 42867fc40edd2..2c2d69da4a7d1 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -17,10 +17,6 @@ --dataset-path \ --request-rate \ # By default is inf --num-prompts # By default is 1000 - - when using tgi backend, add - --endpoint /generate_stream - to the end of the command above. """ import argparse import asyncio @@ -31,7 +27,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import AsyncGenerator, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, @@ -39,15 +35,7 @@ from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -try: - from vllm.transformers_utils.tokenizer import get_tokenizer -except ImportError: - from backend_request_func import get_tokenizer - -try: - from vllm.utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser +from vllm.transformers_utils.tokenizer import get_tokenizer @dataclass @@ -64,9 +52,6 @@ class BenchmarkMetrics: mean_tpot_ms: float median_tpot_ms: float p99_tpot_ms: float - mean_itl_ms: float - median_itl_ms: float - p99_itl_ms: float def sample_sharegpt_requests( @@ -208,37 +193,24 @@ def calculate_metrics( dur_s: float, tokenizer: PreTrainedTokenizerBase, ) -> Tuple[BenchmarkMetrics, List[int]]: - actual_output_lens: List[int] = [] + actual_output_lens = [] total_input = 0 completed = 0 - itls: List[float] = [] - tpots: List[float] = [] - ttfts: List[float] = [] + tpots = [] + ttfts = [] for i in range(len(outputs)): if outputs[i].success: - # We use the tokenizer to count the number of output tokens for all - # serving backends instead of looking at len(outputs[i].itl) since - # multiple output tokens may be bundled together - # Note: this may inflate the output token count slightly - output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + output_len = len(tokenizer(outputs[i].generated_text).input_ids) actual_output_lens.append(output_len) total_input += input_requests[i][1] if output_len > 1: tpots.append( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) - itls += outputs[i].itl ttfts.append(outputs[i].ttft) completed += 1 else: actual_output_lens.append(0) - if completed == 0: - warnings.warn( - "All requests failed. 
This is likely due to a misconfiguration " - "on the benchmark arguments.", - stacklevel=2) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -250,12 +222,9 @@ def calculate_metrics( 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, - mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, - mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots) * 1000, + median_tpot_ms=np.median(tpots) * 1000, + p99_tpot_ms=np.percentile(tpots, 99) * 1000, ) return metrics, actual_output_lens @@ -273,34 +242,16 @@ async def benchmark( disable_tqdm: bool, ): if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[backend] + request_func = ASYNC_REQUEST_FUNCS.get(backend) else: raise ValueError(f"Unknown backend: {backend}") - print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] - test_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) - test_output = await request_func(request_func_input=test_input) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") - else: - print("Initial test run completed. Starting main benchmark run...") print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] + tasks = [] async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len = request request_func_input = RequestFuncInput( @@ -318,7 +269,7 @@ async def benchmark( pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) - if pbar is not None: + if not disable_tqdm: pbar.close() benchmark_duration = time.perf_counter() - benchmark_start_time @@ -355,10 +306,6 @@ async def benchmark( print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) print("=" * 50) result = { @@ -375,9 +322,6 @@ async def benchmark( "mean_tpot_ms": metrics.mean_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms, "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -474,7 +418,7 @@ def main(args: argparse.Namespace): # Save config and results to json if args.save_result: - result_json: Dict[str, Any] = {} + result_json = {} # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") @@ -507,8 +451,6 @@ def main(args: argparse.Namespace): # Save to file base_model_id = 
model_id.split("/")[-1] file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa - if args.result_filename: - file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) with open(file_name, "w") as outfile: @@ -516,7 +458,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser( + parser = argparse.ArgumentParser( description="Benchmark the online serving throughput.") parser.add_argument( "--backend", @@ -649,15 +591,6 @@ def main(args: argparse.Namespace): help="Specify directory to save benchmark json results." "If not specified, results are saved in the current directory.", ) - parser.add_argument( - "--result-filename", - type=str, - default=None, - help="Specify the filename to save benchmark json results." - "If not specified, results will be saved in " - "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" - " format.", - ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index ff33e3dced66f..2e8cfd3f2ca3e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,9 +10,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser def sample_requests( @@ -80,10 +78,8 @@ def run_vllm( enable_prefix_caching: bool, enable_chunked_prefill: bool, max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -104,13 +100,11 @@ def run_vllm( download_dir=download_dir, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, ) # Add the requests to the engine. 
- prompts: List[str] = [] - sampling_params: List[SamplingParams] = [] + prompts = [] + sampling_params = [] for prompt, _, output_len in requests: prompts.append(prompt) sampling_params.append( @@ -231,8 +225,8 @@ def main(args: argparse.Namespace): args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.download_dir, args.load_format) + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -248,21 +242,9 @@ def main(args: argparse.Namespace): print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") - # Output JSON results if specified - if args.output_json: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": total_num_tokens / elapsed_time, - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - if __name__ == "__main__": - parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], @@ -329,13 +311,15 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - '--kv-cache-dtype', + "--kv-cache-dtype", type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=["auto", "fp8"], default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + help= + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') parser.add_argument( '--quantization-param-path', type=str, @@ -349,10 +333,9 @@ def main(args: argparse.Namespace): parser.add_argument( "--device", type=str, - default="auto", - choices=["auto", "cuda", "cpu", "hpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, HPU, ' - 'OpenVINO and CPU.') + default="cuda", + choices=["cuda", "cpu", "hpu"], + help='device type for vLLM execution, supporting CUDA, CPU and HPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', @@ -370,41 +353,6 @@ def main(args: argparse.Namespace): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. 
When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py deleted file mode 100644 index 377f8683c021f..0000000000000 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ /dev/null @@ -1,353 +0,0 @@ -import argparse -import copy -import itertools -import pickle as pkl -import time -from typing import Callable, Iterable, List, Tuple - -import torch -import torch.utils.benchmark as TBenchmark -from torch.utils.benchmark import Measurement as TMeasurement -from weight_shapes import WEIGHT_SHAPES - -from vllm import _custom_ops as ops -from vllm.utils import FlexibleArgumentParser - -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] -DEFAULT_TP_SIZES = [1] - -# helpers - - -def to_fp8(tensor: torch.tensor) -> torch.tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.tensor) -> torch.tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.tensor, torch.tensor]: - - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - - -# impl - - -def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: - return torch.mm(a, b) - - -def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype) - - -def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, - scale_a: torch.tensor, scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - use_fast_accum=True) - - -def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: - 
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) - - -# bench -def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, out_dtype: torch.dtype, label: str, - sub_label: str, fn: Callable, description: str) -> TMeasurement: - - min_run_time = 1 - - globals = { - "a": a, - "b": b, - "scale_a": scale_a, - "scale_b": scale_b, - "out_dtype": out_dtype, - "fn": fn, - } - return TBenchmark.Timer( - stmt="fn(a, b, scale_a, scale_b, out_dtype)", - globals=globals, - label=label, - sub_label=sub_label, - description=description, - ).blocked_autorange(min_run_time=min_run_time) - - -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.int8 - a, b = make_rand_tensors(torch.int8, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - - timers = [] - # pytorch impl - timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) - - # cutlass impl - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) - - return timers - - -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.float8_e4m3fn - a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - - timers = [] - - # pytorch impl w. 
bf16 - timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) - - # cutlass impl: bf16 output - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) - # cutlass impl: fp16 output - timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) - return timers - - -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) - if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) - raise ValueError("unsupported type") - - -# runner -def print_timers(timers: Iterable[TMeasurement]): - compare = TBenchmark.Compare(timers) - compare.print() - - -def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - - results = [] - for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") - print_timers(timers) - results.extend(timers) - - return results - - -# output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], - base_description: str, - timestamp=None): - - print(f"== All Results {base_description} ====") - print_timers(data) - - # pickle all the results - timestamp = int(time.time()) if timestamp is None else timestamp - with open(f"{base_description}-{timestamp}.pkl", "wb") as f: - pkl.dump(data, f) - - -# argparse runners - - -def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) - MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"square_bench-{args.dtype}") - - -def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"range_bench-{args.dtype}") - - -def run_model_bench(args): - - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: - KNs = [] - for KN, tp_split_dim 
in copy.deepcopy(WEIGHT_SHAPES[model_name]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KNs.append(KN) - return KNs - - model_bench_data = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - Ms = args.batch_sizes - KNs = model_shapes(model, tp_size) - MKNs = [] - for m in Ms: - for k, n in KNs: - MKNs.append((m, k, n)) - - data = run(args.dtype, MKNs) - model_bench_data.append(data) - - # Print all results - for data, model_tp in zip(model_bench_data, models_tps): - model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") - print_timers(data) - - timestamp = int(time.time()) - - all_data = [] - for d in model_bench_data: - all_data.extend(d) - # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) - - -if __name__ == '__main__': - - def to_torch_dtype(dt): - if dt == "int8": - return torch.int8 - if dt == "fp8": - return torch.float8_e4m3fn - raise ValueError("unsupported dtype") - - parser = FlexibleArgumentParser( - description=""" -Benchmark Cutlass GEMM. - - To run square GEMMs: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 - - To run constant N and K and sweep M: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 - - To run dimensions from a model: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 - - Output: - - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
- """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") - subparsers = parser.add_subparsers(dest="cmd") - - square_parser = subparsers.add_parser("square_bench") - square_parser.add_argument("--dim-start", type=int, required=True) - square_parser.add_argument("--dim-end", type=int, required=True) - square_parser.add_argument("--dim-increment", type=int, required=True) - square_parser.set_defaults(func=run_square_bench) - - range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) - range_parser.set_defaults(func=run_range_bench) - - model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - model_parser.set_defaults(func=run_model_bench) - - args = parser.parse_args() - args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py deleted file mode 100644 index 25ec9d6028627..0000000000000 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ /dev/null @@ -1,43 +0,0 @@ -# Weight Shapes are in the format -# ([K, N], TP_SPLIT_DIM) -# Example: -# A shape of ([14336, 4096], 0) indicates the following GEMM shape, -# - TP1 : K = 14336, N = 4096 -# - TP2 : K = 7168, N = 4096 -# A shape of ([4096, 6144], 1) indicates the following GEMM shape, -# - TP1 : K = 4096, N = 6144 -# - TP4 : K = 4096, N = 1536 - -# TP1 shapes -WEIGHT_SHAPES = { - "mistralai/Mistral-7B-v0.1": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-2-7b-hf": [ - ([4096, 12288], 1), - ([4096, 4096], 0), - ([4096, 22016], 1), - ([11008, 4096], 0), - ], - "meta-llama/Llama-3-8b": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-2-13b-hf": [ - ([5120, 15360], 1), - ([5120, 5120], 0), - ([5120, 27648], 1), - ([13824, 5120], 0), - ], - "meta-llama/Llama-2-70b-hf": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 57344], 1), - ([28672, 8192], 0), - ], -} diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 601c4ea439aea..59392947b15c8 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -1,3 +1,4 @@ +import argparse import os import sys from typing import Optional @@ -9,7 +10,6 @@ from vllm.model_executor.layers.quantization.aqlm import ( dequantize_weight, generic_dequantize_gemm, get_int_dtype, optimized_dequantize_gemm) -from vllm.utils import FlexibleArgumentParser os.environ['CUDA_VISIBLE_DEVICES'] = '0' @@ -86,9 +86,9 @@ def dequant_no_scale( # Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against # the generic pytorch version. # Just visual comparison. 
-def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: +def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) + n = parts.sum().item() device = torch.device('cuda:0') @@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") + parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") # Add arguments parser.add_argument("--nbooks", @@ -204,7 +204,7 @@ def main(): sys.stdout = sys.__stdout__ -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, +def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, methods): # I didn't see visible improvements from increasing these, but feel free :) @@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, print('') -def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, +def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, method) -> float: - n = int(parts.sum().item()) + n = parts.sum().item() device = torch.device('cuda:0') diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py deleted file mode 100644 index 261f5829631ee..0000000000000 --- a/benchmarks/kernels/benchmark_marlin.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import List - -import torch -import torch.utils.benchmark as benchmark -from benchmark_shapes import WEIGHT_SHAPES - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, marlin_24_quantize, marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) -from vllm.utils import FlexibleArgumentParser - -DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] - -ACT_ORDER_OPTS = [False, True] -K_FULL_OPTS = [False, True] - - -def bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): - label = "Quant Matmul" - - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) - - print(f"Testing: {sub_label}") - - a = torch.randn(size_m, size_k).to(torch.half).cuda() - b = torch.rand(size_k, size_n).to(torch.half).cuda() - - a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) - - # Marlin quant - ( - marlin_w_ref, - marlin_q_w, - marlin_s, - marlin_g_idx, - marlin_sort_indices, - marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) - - # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) - - # GPTQ quant - (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) - - # 
For act_order, sort the "weights" and "g_idx" - # so that group ids are increasing - repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) - if act_order: - (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) - - # Prepare - marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) - - globals = { - # Gen params - "num_bits": num_bits, - "group_size": group_size, - "size_m": size_m, - "size_n": size_n, - "size_k": size_k, - "a": a, - "a_tmp": a_tmp, - # Marlin params - "marlin_w_ref": marlin_w_ref, - "marlin_q_w": marlin_q_w, - "marlin_s": marlin_s, - "marlin_g_idx": marlin_g_idx, - "marlin_sort_indices": marlin_sort_indices, - "marlin_rand_perm": marlin_rand_perm, - "marlin_workspace": marlin_workspace, - "is_k_full": is_k_full, - # Marlin_24 params - "marlin_24_w_ref": marlin_24_w_ref, - "marlin_24_q_w_comp": marlin_24_q_w_comp, - "marlin_24_meta": marlin_24_meta, - "marlin_24_s": marlin_24_s, - "marlin_24_workspace": marlin_24_workspace, - # GPTQ params - "q_w_gptq": q_w_gptq, - "repack_sort_indices": repack_sort_indices, - # Kernels - "gptq_marlin_gemm": ops.gptq_marlin_gemm, - "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, - "gptq_marlin_repack": ops.gptq_marlin_repack, - } - - min_run_time = 1 - - # Warmup pytorch - for i in range(5): - torch.matmul(a, marlin_w_ref) - - results.append( - benchmark.Timer( - stmt="torch.matmul(a, marlin_w_ref)", - globals=globals, - label=label, - sub_label=sub_label, - description="pytorch_gemm", - ).blocked_autorange(min_run_time=min_run_time)) - - results.append( - benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_gemm", - ).blocked_autorange(min_run_time=min_run_time)) - - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): - results.append( - benchmark.Timer( - stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_24_gemm", - ).blocked_autorange(min_run_time=min_run_time)) - - results.append( - benchmark.Timer( - stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time)) - - -def main(args): - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - results: List[benchmark.Measurement] = [] - - for model in args.models: - for layer in WEIGHT_SHAPES[model]: - size_k = layer[0] - size_n = layer[1] - - if len(args.limit_k) > 0 and size_k not in args.limit_k: - continue - - if len(args.limit_n) > 0 and size_n not in args.limit_n: - continue - - for act_order in ACT_ORDER_OPTS: - if len(args.limit_act_order - ) > 0 and act_order not in args.limit_act_order: - continue - - for is_k_full in K_FULL_OPTS: - if len(args.limit_k_full - ) > 0 and is_k_full not in args.limit_k_full: - continue - - for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - 
) > 0 and num_bits not in args.limit_num_bits: - continue - - for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: - if len( - args.limit_group_size - ) > 0 and group_size not in args.limit_group_size: - continue - - # For act_order, the group_size must be less than - # size_k - if act_order and (group_size == size_k - or group_size == -1): - continue - - for size_m in args.batch_sizes: - bench_run(results, model, act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) - - compare = benchmark.Compare(results) - compare.print() - - -# For quick benchmarking use: -# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 -# -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") - parser.add_argument( - "--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys(), - ) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - parser.add_argument("--limit-k", nargs="+", type=int, default=[]) - parser.add_argument("--limit-n", nargs="+", type=int, default=[]) - parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) - parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) - parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) - parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) - - args = parser.parse_args() - main(args) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py new file mode 100644 index 0000000000000..5280b214144c9 --- /dev/null +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -0,0 +1,215 @@ +import argparse +import json +import os +import sys + +import torch +import torch.nn.functional as F +import triton +from tqdm import tqdm + +from vllm.model_executor.layers.fused_moe import (fused_moe, + get_config_file_name) + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def main(dtype: str): + method = fused_moe + for bs in [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ]: + run_grid(bs, method=method, dtype=dtype) + + +def run_grid(bs, method, dtype: str): + d_model = 4096 + num_total_experts = 8 + top_k = 2 + tp_size = 2 + model_intermediate_size = 14336 + num_layers = 32 + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + configs = [] + + for block_size_n in [32, 64, 128, 256]: + for block_size_m in [16, 32, 64, 128, 256]: + for block_size_k in [64, 128, 256]: + for group_size_m in [1, 16, 32, 64]: + for num_warps in [4, 8]: + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + best_config = None + best_time_us = 1e20 + + print(f'{tp_size=} {bs=}') + + for config in tqdm(configs): + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + dtype=dtype, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # trial + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + 
d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=config, + dtype=dtype, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') + + print("best_time_us", best_time_us) + print("best_config", best_config) + + # holds Dict[str, Dict[str, int]] + filename = get_config_file_name(num_total_experts, + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, + top_k: int, tp_size: int, model_intermediate_size: int, method, + config, dtype: str) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + + gating_output = F.softmax(torch.rand( + (num_calls, bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for i in range(num_calls): + hidden_states = method( + hidden_states=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + gating_output=gating_output[i], + topk=2, + renormalize=True, + inplace=True, + override_config=config, + use_fp8=dtype == "float8", + ) + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + args = parser.parse_args() + sys.exit(main(args.dtype)) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py deleted file mode 100644 index e00696d6d43cb..0000000000000 --- a/benchmarks/kernels/benchmark_moe.py +++ 
/dev/null @@ -1,333 +0,0 @@ -import argparse -import time -from datetime import datetime -from typing import Any, Dict, List, Tuple, TypedDict - -import ray -import torch -import triton -from ray.experimental.tqdm_ray import tqdm -from transformers import AutoConfig - -from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser - - -class BenchmarkConfig(TypedDict): - BLOCK_SIZE_M: int - BLOCK_SIZE_N: int - BLOCK_SIZE_K: int - GROUP_SIZE_M: int - num_warps: int - num_stages: int - - -def benchmark_config( - config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8: bool, - num_iters: int = 100, -) -> float: - init_dtype = torch.float16 if use_fp8 else dtype - x = torch.randn(num_tokens, hidden_size, dtype=dtype) - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) - - w1_scale = None - w2_scale = None - a1_scale = None - a2_scale = None - if use_fp8: - w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) - a1_scale = torch.randn(1, dtype=torch.float32) - a2_scale = torch.randn(1, dtype=torch.float32) - - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) - - def prepare(i: int): - input_gating.copy_(gating_output[i]) - - def run(): - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - override_config=config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) - - # JIT compilation & warmup - run() - torch.cuda.synchronize() - - # Capture 10 invocations with CUDA graph - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - for _ in range(10): - run() - torch.cuda.synchronize() - - # Warmup - for _ in range(5): - graph.replay() - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - latencies: List[float] = [] - for i in range(num_iters): - prepare(i) - torch.cuda.synchronize() - - start_event.record() - graph.replay() - end_event.record() - end_event.synchronize() - latencies.append(start_event.elapsed_time(end_event)) - avg = sum(latencies) / (num_iters * 10) * 1000 # us - graph.reset() - return avg - - -def get_configs_compute_bound() -> List[Dict[str, int]]: - # Reduced search space for faster tuning. - # TODO(woosuk): Increase the search space and use a performance model to - # prune the search space. 
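
get_configs_compute_bound here, like the run_grid sweep in the new benchmark_mixtral_moe.py, enumerates a plain Cartesian product of Triton tile and scheduling parameters. An equivalent sketch using itertools (the config_grid name is ours; the value lists are copied from the scripts):

    import itertools
    from typing import Dict, List

    def config_grid() -> List[Dict[str, int]]:
        """Cartesian product of the tile/scheduling knobs swept by the tuners."""
        block_m = [16, 32, 64, 128, 256]
        block_n = [32, 64, 128, 256]
        block_k = [64, 128, 256]
        group_m = [1, 16, 32, 64]
        num_warps = [4, 8]
        num_stages = [2, 3, 4, 5]
        return [
            {
                "BLOCK_SIZE_M": m,
                "BLOCK_SIZE_N": n,
                "BLOCK_SIZE_K": k,
                "GROUP_SIZE_M": g,
                "num_warps": w,
                "num_stages": s,
            }
            for m, n, k, g, w, s in itertools.product(
                block_m, block_n, block_k, group_m, num_warps, num_stages)
        ]

    print(len(config_grid()))  # 5 * 4 * 3 * 4 * 2 * 4 = 1920 candidate configs
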
- configs: List[BenchmarkConfig] = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128, 256]: - for block_k in [64, 128, 256]: - for block_n in [32, 64, 128, 256]: - for num_warps in [4, 8]: - for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) - return configs - - -@ray.remote(num_gpus=1) -class BenchmarkWorker: - - def __init__(self, seed: int) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) - self.seed = seed - - def benchmark( - self, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8: bool, - ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) - - dtype_str = "float8" if use_fp8 else None - # NOTE(woosuk): The current naming convention uses w2.shape[2], which - # is the intermediate size after silu_and_mul. - op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, - dtype_str) - if op_config is None: - config = get_default_config(num_tokens, num_experts, - shard_intermediate_size, hidden_size, - topk, dtype_str) - else: - config = op_config[min(op_config.keys(), - key=lambda x: abs(x - num_tokens))] - kernel_time = benchmark_config(config, num_tokens, num_experts, - shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) - return config, kernel_time - - def tune( - self, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8: bool, - search_space: List[BenchmarkConfig], - ) -> BenchmarkConfig: - best_config = None - best_time = float("inf") - for config in tqdm(search_space): - try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8, - num_iters=10) - except triton.runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. - continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config - now = datetime.now() - print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") - assert best_config is not None - return best_config - - -def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: - return { - "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": config["GROUP_SIZE_M"], - "num_warps": config["num_warps"], - "num_stages": config["num_stages"], - } - - -def save_configs( - configs: Dict[int, BenchmarkConfig], - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8: bool, -) -> None: - dtype_str = "float8" if use_fp8 else None - # NOTE(woosuk): The current naming convention uses w2.shape[2], which - # is the intermediate size after silu_and_mul. 
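
Both MoE benchmarks time kernels with CUDA events rather than host wall-clock time, so only GPU execution is measured. A stripped-down version of that pattern (the time_gpu_op helper and the matmul workload are illustrative stand-ins for the fused_moe call):

    import torch

    def time_gpu_op(fn, num_iters: int = 100) -> float:
        """Average GPU time per call in milliseconds, measured with CUDA events."""
        fn()                        # warm-up call so compilation is not timed
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iters):
            fn()
        end.record()
        end.synchronize()           # wait for the GPU before reading the timer
        return start.elapsed_time(end) / num_iters

    if torch.cuda.is_available():
        a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        print(f"{time_gpu_op(lambda: a @ b):.3f} ms per matmul")
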
- filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str) - print(f"Writing best config to {filename}...") - with open(filename, "w") as f: - json.dump(configs, f, indent=4) - f.write("\n") - - -def main(args: argparse.Namespace): - print(args) - - config = AutoConfig.from_pretrained(args.model) - if config.architectures[0] == "DbrxForCausalLM": - E = config.ffn_config.moe_num_experts - topk = config.ffn_config.moe_top_k - intermediate_size = config.ffn_config.ffn_hidden_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size - else: - # Default: Mixtral. - E = config.num_local_experts - topk = config.num_experts_per_tok - intermediate_size = config.intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size - - hidden_size = config.hidden_size - dtype = config.torch_dtype - use_fp8 = args.dtype == "fp8" - - if args.batch_size is None: - batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 - ] - else: - batch_sizes = [args.batch_size] - - ray.init() - num_gpus = int(ray.available_resources()["GPU"]) - workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] - - def _distribute(method: str, inputs: List[Any]) -> List[Any]: - outputs = [] - worker_idx = 0 - for input_args in inputs: - worker = workers[worker_idx] - worker_method = getattr(worker, method) - output = worker_method.remote(*input_args) - outputs.append(output) - worker_idx = (worker_idx + 1) % num_gpus - return ray.get(outputs) - - if args.tune: - search_space = get_configs_compute_bound() - print(f"Start tuning over {len(search_space)} configurations...") - - start = time.time() - configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8, search_space) - for batch_size in batch_sizes]) - best_configs = { - M: sort_config(config) - for M, config in zip(batch_sizes, configs) - } - save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) - end = time.time() - print(f"Tuning took {end - start:.2f} seconds") - else: - outputs = _distribute("benchmark", - [(batch_size, E, shard_intermediate_size, - hidden_size, topk, dtype, use_fp8) - for batch_size in batch_sizes]) - - for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): - print(f"Batch size: {batch_size}, config: {config}") - print(f"Kernel time: {kernel_time:.2f} us") - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", "-tp", type=int, default=2) - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8"], - default="auto") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--batch-size", type=int, required=False) - parser.add_argument("--tune", action="store_true") - args = parser.parse_args() - - main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 16de60477c305..ca7967c1ab0d2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,12 +1,12 @@ +import argparse import random import time -from typing import List, Optional +from typing import Optional import torch from vllm import _custom_ops as ops -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) +from vllm.utils import 
STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -54,17 +54,14 @@ def main( # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables_lst: List[List[int]] = [] + block_tables = [] for _ in range(num_seqs): block_table = [ random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] - block_tables_lst.append(block_table) - - block_tables = torch.tensor(block_tables_lst, - dtype=torch.int, - device=device) + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, @@ -161,19 +158,19 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: if __name__ == '__main__': - parser = FlexibleArgumentParser( + parser = argparse.ArgumentParser( description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 128, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") @@ -186,11 +183,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], + choices=["auto", "fp8"], default="auto", - help="Data type for kv cache storage. If 'auto', will use model " - "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " - "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + help= + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 78736c7a7ba6f..9188e811e2982 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,12 +1,11 @@ +import argparse from itertools import accumulate -from typing import List, Optional +from typing import Optional import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, - get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.model_executor.layers.rotary_embedding import get_rope def benchmark_rope_kernels_multi_lora( @@ -38,7 +37,7 @@ def benchmark_rope_kernels_multi_lora( }) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior - non_batched_ropes: List[RotaryEmbedding] = [] + non_batched_ropes = [] for scaling_factor in scaling_factors: non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, @@ -86,7 +85,7 @@ def benchmark_rope_kernels_multi_lora( if __name__ == '__main__': - parser = FlexibleArgumentParser( + parser = argparse.ArgumentParser( description="Benchmark the rotary embedding kernels.") parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) @@ -94,7 +93,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 128, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py deleted file mode 100644 index 4eeeca35a37cc..0000000000000 --- a/benchmarks/kernels/benchmark_shapes.py +++ /dev/null @@ -1,75 +0,0 @@ -WEIGHT_SHAPES = { - "ideal": [[4 * 256 * 32, 256 * 32]], - "mistralai/Mistral-7B-v0.1/TP1": [ - [4096, 6144], - [4096, 4096], - [4096, 28672], - [14336, 4096], - ], - "mistralai/Mistral-7B-v0.1/TP2": [ - [4096, 3072], - [2048, 4096], - [4096, 14336], - [7168, 4096], - ], - "mistralai/Mistral-7B-v0.1/TP4": [ - [4096, 1536], - [1024, 4096], - [4096, 7168], - [3584, 4096], - ], - "meta-llama/Llama-2-7b-hf/TP1": [ - [4096, 12288], - [4096, 4096], - [4096, 22016], - [11008, 4096], - ], - "meta-llama/Llama-2-7b-hf/TP2": [ - [4096, 6144], - [2048, 4096], - [4096, 11008], - [5504, 4096], - ], - "meta-llama/Llama-2-7b-hf/TP4": [ - [4096, 3072], - [1024, 4096], - [4096, 5504], - [2752, 4096], - ], - "meta-llama/Llama-2-13b-hf/TP1": [ - [5120, 15360], - [5120, 5120], - [5120, 27648], - [13824, 5120], - ], - "meta-llama/Llama-2-13b-hf/TP2": [ - [5120, 7680], - [2560, 5120], - [5120, 13824], - [6912, 5120], - ], - "meta-llama/Llama-2-13b-hf/TP4": [ - [5120, 3840], - [1280, 5120], - [5120, 6912], - [3456, 5120], - ], - "meta-llama/Llama-2-70b-hf/TP1": [ - [8192, 10240], - [8192, 8192], - [8192, 57344], - [28672, 8192], - ], - "meta-llama/Llama-2-70b-hf/TP2": [ - [8192, 5120], - [4096, 8192], - [8192, 28672], - [14336, 8192], - ], - "meta-llama/Llama-2-70b-hf/TP4": [ - [8192, 2560], - [2048, 8192], - [8192, 14336], - [7168, 8192], - ], -} diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index f491c90d0683e..64d3c4f4b3889 100755 --- a/benchmarks/launch_tgi_server.sh +++ 
b/benchmarks/launch_tgi_server.sh @@ -4,7 +4,7 @@ PORT=8000 MODEL=$1 TOKENS=$2 -docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ +docker run --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ ghcr.io/huggingface/text-generation-inference:1.4.0 \ --model-id $MODEL \ diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py deleted file mode 100644 index 203699e9a8d06..0000000000000 --- a/benchmarks/overheads/benchmark_hashing.py +++ /dev/null @@ -1,63 +0,0 @@ -import cProfile -import pstats - -from vllm import LLM, SamplingParams -from vllm.utils import FlexibleArgumentParser - -# A very long prompt, total number of tokens is about 15k. -LONG_PROMPT = ["You are an expert in large language models, aren't you?" - ] * 1000 -LONG_PROMPT = ' '.join(LONG_PROMPT) - - -def main(args): - llm = LLM( - model=args.model, - enforce_eager=True, - enable_prefix_caching=True, - tensor_parallel_size=args.tensor_parallel_size, - use_v2_block_manager=args.use_v2_block_manager, - ) - - sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - profiler = cProfile.Profile() - - print("------warm up------") - for i in range(3): - output = llm.generate(LONG_PROMPT, sampling_params) - print(output[0].outputs[0].text) - - print("------start generating------") - for i in range(3): - profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', - globals(), locals()) - - # analyze the runtime of hashing function - stats = pstats.Stats(profiler) - stats.sort_stats('cumulative') - total_time = 0 - total_calls = 0 - for func in stats.stats: - if 'hash_of_block' in func[2]: - total_time = stats.stats[func][3] - total_calls = stats.stats[func][0] - percentage = (total_time / stats.total_tt) * 100 - print(f"Hashing took {total_time:.2f} seconds," - f"{percentage:.2f}% of the total runtime.") - - -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Benchmark the performance of hashing function in' - 'automatic prefix caching.') - parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') - args = parser.parse_args() - main(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 690559ee265e9..0cf37769a6960 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS +list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") @@ -33,23 +33,9 @@ function (find_isa CPUINFO TARGET OUT) endif() endfunction() -function (is_avx512_disabled OUT) - set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512}) - if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true") - set(${OUT} ON PARENT_SCOPE) - else() - set(${OUT} OFF PARENT_SCOPE) - endif() -endfunction() - -is_avx512_disabled(AVX512_DISABLED) - -find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "avx512f" AVX512_FOUND) -find_isa(${CPUINFO} "POWER10" POWER10_FOUND) -find_isa(${CPUINFO} "POWER9" POWER9_FOUND) -if (AVX512_FOUND AND NOT AVX512_DISABLED) +if (AVX512_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx512f" "-mavx512vl" @@ -58,8 +44,8 @@ if 
(AVX512_FOUND AND NOT AVX512_DISABLED) find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") else() message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") @@ -67,18 +53,8 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) else() message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() -elseif (AVX2_FOUND) - list(APPEND CXX_COMPILE_FLAGS "-mavx2") - message(WARNING "vLLM CPU backend using AVX2 ISA") -elseif (POWER9_FOUND OR POWER10_FOUND) - message(STATUS "PowerPC detected") - # Check for PowerPC VSX support - list(APPEND CXX_COMPILE_FLAGS - "-mvsx" - "-mcpu=native" - "-mtune=native") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -97,7 +73,7 @@ set(VLLM_EXT_SRC "csrc/cpu/cache.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp") + "csrc/cpu/pybind.cpp") define_gpu_extension_target( _C @@ -105,10 +81,10 @@ define_gpu_extension_target( LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - USE_SABI 3 - WITH_SOABI + WITH_SOABI ) add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) + diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4869cad541135..7c71673e36f29 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -5,7 +5,7 @@ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) file(REAL_PATH ${EXECUTABLE} EXECUTABLE) set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + find_package(Python COMPONENTS Interpreter Development.Module) if (NOT Python_FOUND) message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") endif() @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8") + list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8" + "-DENABLE_FP8_E4M3" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") @@ -147,23 +147,16 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if (${GPU_LANG} STREQUAL "HIP") # # `GPU_ARCHES` controls the `--offload-arch` flags. + # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled + # via the `PYTORCH_ROCM_ARCH` env variable. 
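To make the architecture-selection logic in override_gpu_arches easier to follow, here is a small Python sketch of the same intersection step: the detected ROCm architectures (taken from the PYTORCH_ROCM_ARCH environment variable on the removed side, or from the torch-provided CMAKE_HIP_ARCHITECTURES on the added side) are filtered against the supported list, and an empty intersection is a fatal error. The gfx names below are illustrative examples, not part of the patch.

import os

def select_rocm_arches(detected, supported):
    """Mirror of the CMake intersection in override_gpu_arches (illustrative)."""
    # Keep only detected architectures that this build actually supports.
    chosen = [arch for arch in detected if arch in supported]
    if not chosen:
        raise RuntimeError(
            f"None of the detected ROCm architectures: {detected} is supported. "
            f"Supported ROCm architectures are: {supported}.")
    return chosen

# The removed branch preferred the PYTORCH_ROCM_ARCH env var over the
# torch-provided CMAKE_HIP_ARCHITECTURES; example values are made up.
detected = os.environ.get("PYTORCH_ROCM_ARCH", "gfx90a;gfx942").split(";")
print(select_rocm_arches(detected, supported=["gfx908", "gfx90a", "gfx942"]))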
# - # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, - # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling - # "rocm_agent_enumerator" in "enable_language(HIP)" - # (in file Modules/CMakeDetermineHIPCompiler.cmake) - # - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) - else() - set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) - endif() + # # Find the intersection of the supported + detected architectures to # set the module architecture flags. # set(${GPU_ARCHES}) - foreach (_ARCH ${HIP_ARCHITECTURES}) + foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES}) if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) list(APPEND ${GPU_ARCHES} ${_ARCH}) endif() @@ -171,7 +164,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if(NOT ${GPU_ARCHES}) message(FATAL_ERROR - "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" + "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() @@ -301,7 +294,6 @@ endmacro() # INCLUDE_DIRECTORIES - Extra include directories. # LIBRARIES - Extra link libraries. # WITH_SOABI - Generate library with python SOABI suffix name. -# USE_SABI - Use python stable api # # Note: optimization level/debug info is set via cmake build type. # @@ -309,7 +301,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 GPU "WITH_SOABI" - "DESTINATION;LANGUAGE;USE_SABI" + "DESTINATION;LANGUAGE" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. @@ -323,11 +315,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) set(GPU_WITH_SOABI) endif() - if (GPU_USE_SABI) - Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") - else() - Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") - endif() + Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. diff --git a/collect_env.py b/collect_env.py index 083cb768f5399..1ecfeb8e22e2f 100644 --- a/collect_env.py +++ b/collect_env.py @@ -64,7 +64,6 @@ "triton", "optree", "nccl", - "transformers", } DEFAULT_PIP_PATTERNS = { @@ -76,7 +75,6 @@ "optree", "onnx", "nccl", - "transformers", } @@ -603,11 +601,6 @@ def get_version_or_na(cfg, prefix): {conda_packages} """.strip() -# both the above code and the following code use `strip()` to -# remove leading/trailing whitespaces, so we need to add a newline -# in between to separate the two sections -env_info_fmt += "\n" - env_info_fmt += """ ROCM Version: {rocm_version} Neuron SDK Version: {neuron_sdk_version} diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 5ed1dc3b8f792..24d972702c858 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. 
-template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,66 +23,72 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T)(((float)x) / (1.0f + expf((float)-x))); + return (T) (((float) x) / (1.0f + expf((float) -x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float)x; + const float f = (float) x; constexpr float ALPHA = M_SQRT1_2; - return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float)x; + const float f = (float) x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T)(0.5f * f * (1.0f + ::tanhf(inner))); + return (T) (0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. 
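As a rough functional reference for the fused act_and_mul kernels above, the sketch below assumes the usual layout in which the last dimension packs the activation input and its gate as [..., 2*d], with the first half activated and the second half used as the multiplier; the tanh-GELU constants are transcribed from gelu_tanh_kernel.

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference for act_and_mul_kernel with ACT = silu (sketch).

    x: [..., 2 * d]; the first half is activated, the second half gates it
    (assumed layout, matching how the kernel indexes input[token_idx * 2 * d + idx]).
    """
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Same gating, with the tanh-approximated GELU used by gelu_tanh_kernel."""
    d = x.shape[-1] // 2
    a = x[..., :d]
    inner = (2.0 / torch.pi) ** 0.5 * (a + 0.044715 * a.pow(3))  # BETA, KAPPA
    return 0.5 * a * (1.0 + torch.tanh(inner)) * x[..., d:]

out = silu_and_mul_ref(torch.randn(4, 2 * 128))   # -> shape [4, 128]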
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); - -void silu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), \ + "act_and_mul_kernel", \ + [&] { \ + vllm::act_and_mul_kernel><<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + d); \ + }); + +void silu_and_mul( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -90,11 +96,11 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -102,61 +108,54 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. 
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ - vllm::activation_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), \ + "activation_kernel", \ + [&] { \ + vllm::activation_kernel><<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float)(x * x * x); - const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); - return ((T)0.5) * x * (((T)1.0) + t); + const float x3 = (float) (x * x * x); + const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); + return ((T) 0.5) * x * (((T) 1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float)x; - const T t = - (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); - return ((T)0.5) * x * (((T)1.0) + t); + const float f = (float) x; + const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); + return ((T) 0.5) * x * (((T) 1.0) + t); } -template -__device__ __forceinline__ T gelu_quick_kernel(const T& x) { - // x * sigmoid(1.702 * x) - return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); -} - -} // namespace vllm +} // namespace vllm -void gelu_new(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } - -void gelu_quick(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] -{ - LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); -} diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 62409c0cce93e..31fb401cbe2c1 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,6 +1,5 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -23,31 +22,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. 
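The gelu_new_kernel and gelu_fast_kernel device functions in activation_kernels.cu above use two closely related tanh approximations that differ only in how the cubic correction is folded in; a plain-Python transcription for checking values (not a performance reference):

import math

def gelu_new_ref(x: float) -> float:
    # 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + math.tanh(0.79788456 * (x + 0.044715 * x ** 3)))

def gelu_fast_ref(x: float) -> float:
    # 0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    return 0.5 * x * (1.0 + math.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

print(gelu_new_ref(1.0), gelu_fast_ref(1.0))  # both approximately 0.841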
-template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -62,4 +61,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 91083481705cb..8b1b5e098015f 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,6 +1,5 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -17,26 +16,30 @@ * limitations under the License. */ -#include +#include #include #include -#include #include "attention_dtypes.h" #include "attention_utils.cuh" +#if defined(ENABLE_FP8_E5M2) +#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "../quantization/fp8/amd_detail/quant_utils.cuh" +#endif + +#include + #ifdef USE_ROCM #include - #include "../quantization/fp8/amd/quant_utils.cuh" -typedef __hip_bfloat16 __nv_bfloat16; -#else - #include "../quantization/fp8/nvidia/quant_utils.cuh" + typedef __hip_bfloat16 __nv_bfloat16; #endif #ifndef USE_ROCM - #define WARP_SIZE 32 +#define WARP_SIZE 32 #else - #define WARP_SIZE warpSize +#define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -46,7 +49,7 @@ typedef __hip_bfloat16 __nv_bfloat16; namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -83,31 +86,31 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template // Zero means no partitioning. +template< + typename scalar_t, + typename cache_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + bool IS_FP8_KV_CACHE, + int PARTITION_SIZE = 0> // Zero means no partitioning. 
__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -119,29 +122,22 @@ __device__ void paged_attention_kernel( } const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = - USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = - USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = - MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. 
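The partition bookkeeping at the top of paged_attention_kernel boils down to a few integer divisions; the sketch below reproduces that arithmetic to show how a sequence is split into KV blocks and then into partitions (values in the example are illustrative):

def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def partition_ranges(seq_len: int, block_size: int, partition_size: int):
    """Per-partition [start_block, end_block) ranges, as in paged_attention_kernel."""
    num_seq_blocks = divide_round_up(seq_len, block_size)
    use_partitioning = partition_size > 0
    blocks_per_partition = (partition_size // block_size
                            if use_partitioning else num_seq_blocks)
    num_partitions = divide_round_up(num_seq_blocks, blocks_per_partition)
    ranges = []
    for partition_idx in range(num_partitions):
        start = partition_idx * blocks_per_partition if use_partitioning else 0
        end = min(start + blocks_per_partition, num_seq_blocks)
        ranges.append((start, end))
    return ranges

# e.g. a 4096-token sequence, 16-token blocks, PARTITION_SIZE = 512
print(partition_ranges(4096, 16, 512))  # 8 partitions of 32 blocks each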
const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = - MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = - NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE - // divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = - DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -151,18 +147,19 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = - alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread - // group fetch or compute 16 bytes at a time. For example, if the size of a - // thread group is 4 and the data type is half, then the vector size is 16 / - // (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; +#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -172,21 +169,18 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the query, and the second thread - // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because - // q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. 
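The thread-group vector sizing described in the comment above is easy to sanity-check numerically; a minimal sketch of the THREAD_GROUP_SIZE / VEC_SIZE arithmetic (element sizes in bytes):

WARP_SIZE = 32

def vector_layout(block_size: int, head_size: int, elem_bytes: int):
    """THREAD_GROUP_SIZE / VEC_SIZE arithmetic from paged_attention_kernel."""
    thread_group_size = max(WARP_SIZE // block_size, 1)
    # Each thread group fetches or computes 16 bytes of the key/query at a time.
    vec_size = max(16 // (thread_group_size * elem_bytes), 1)
    elems_per_thread = head_size // thread_group_size
    vecs_per_thread = elems_per_thread // vec_size
    return thread_group_size, vec_size, vecs_per_thread

# half precision (2 bytes), block_size=16, head_size=128:
print(vector_layout(16, 128, 2))   # (2, 4, 16)
# the worked example in the comment: group size 4, half -> VEC_SIZE == 2
print(max(16 // (4 * 2), 1))       # 2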
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; - i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = - *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a - // memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -205,94 +199,51 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - - // blocksparse specific vars - int bs_block_offset; - int q_bs_block_id; - if constexpr (IS_BLOCK_SPARSE) { - // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, - // blocksparse_block_size); - q_bs_block_id = (seq_len - 1) / blocksparse_block_size; - if (blocksparse_head_sliding_step >= 0) - // sliding on q heads - bs_block_offset = - (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; - else - // sliding on kv heads - bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * - (-blocksparse_head_sliding_step) + - 1; - } - - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - const bool is_remote = - ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); - const bool is_local = - (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); - if (!is_remote && !is_local) { - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - - if (thread_group_offset == 0) { - // NOTE(linxihui): assign very large number to skipped tokens to - // avoid contribution to the sumexp softmax normalizer. This will - // not be used at computing sum(softmax*v) as the blocks will be - // skipped. - logits[token_idx - start_token_idx] = -FLT_MAX; - } - } - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by large numbers + // (e.g., kv_block_stride). + const int64_t physical_block_number = static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. 
- // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the key, and the second thread - // has 1, 5, 9, ... th vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = - k_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride + physical_block_offset * x; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } else { + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); // Vector conversion from Quant_vec to K_vec. - Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert( - k_vec_quant, kv_scale); + k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); +#elif defined(ENABLE_FP8_E4M3) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k + // cache vec to k vec in higher precision (FP16, BFloat16, etc.) + k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); +#else + assert(false); +#endif + } else { + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot( - q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -347,12 +298,13 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. 
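Setting aside the vectorized loads and the FP8 dequantization branches, the per-token score produced in this loop is a scaled dot product plus an optional ALiBi bias; a scalar sketch of the same formula:

import numpy as np

def token_score(q, k, scale, alibi_slope, token_idx, seq_len):
    """qk = scale * dot(q, k) + alibi_slope * (token_idx - seq_len + 1)."""
    qk = scale * float(np.dot(np.asarray(q, np.float32), np.asarray(k, np.float32)))
    if alibi_slope != 0.0:
        qk += alibi_slope * (token_idx - seq_len + 1)
    return qk

print(token_score([1.0, 0.5], [0.2, -0.1], scale=0.125,
                  alibi_slope=0.0, token_idx=3, seq_len=8))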
if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; *exp_sums_ptr = exp_sum; } @@ -360,13 +312,14 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; +#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = - DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -377,51 +330,44 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && - !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by large numbers + // (e.g., kv_block_stride). 
+ const int64_t physical_block_number = static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - v_vec = *reinterpret_cast(v_ptr + offset); - } else { - V_quant_vec v_quant_vec = - *reinterpret_cast(v_ptr + offset); + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, - kv_scale); + v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); +#elif defined(ENABLE_FP8_E4M3) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert + // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) + v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); +#else + assert(false); +#endif + } else { + v_vec = *reinterpret_cast(v_ptr + offset); } if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the - // context, we should explicitly zero out the values since they may - // contain NaNs. See - // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + // NOTE(woosuk): When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { @@ -444,8 +390,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for - // logits is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -482,9 +428,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = - out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -496,84 +442,79 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). 
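Taken together, the math for one (sequence, head) pair within a single partition is ordinary attention over key/value blocks gathered through the block table. The NumPy sketch below illustrates that under simplifying assumptions (no FP8 cache, no partitioning, no ALiBi) and uses a simplified [num_blocks, block_size, head_size] cache layout rather than the tiled layout the real kernels read:

import numpy as np

def paged_attention_ref(q, k_cache, v_cache, block_table, seq_len, scale):
    """q: [head_size]; k_cache/v_cache: [num_blocks, block_size, head_size]."""
    block_size = k_cache.shape[1]
    keys, values = [], []
    for i in range((seq_len + block_size - 1) // block_size):
        phys = block_table[i]                       # logical -> physical block
        keys.append(k_cache[phys])
        values.append(v_cache[phys])
    k = np.concatenate(keys)[:seq_len]              # [seq_len, head_size]
    v = np.concatenate(values)[:seq_len]
    logits = scale * (k @ q)                        # [seq_len]
    logits -= logits.max()                          # the kernel's qk_max trick
    probs = np.exp(logits) / np.exp(logits).sum()
    return probs @ v                                # [head_size]

rng = np.random.default_rng(0)
out = paged_attention_ref(rng.normal(size=64),
                          rng.normal(size=(8, 16, 64)),
                          rng.normal(size=(8, 16, 64)),
                          block_table=[3, 0, 5], seq_len=40, scale=0.125)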
-template +template< + typename scalar_t, + typename cache_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + bool IS_FP8_KV_CACHE> __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, - v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). 
-template +template< + typename scalar_t, + typename cache_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + bool IS_FP8_KV_CACHE, + int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). 
-template +template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -581,11 +522,9 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -604,9 +543,8 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -635,11 +573,9 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = - reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -652,52 +588,61 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. 
- const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * - inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + vllm::paged_attention_v1_kernel<<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. -template +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + bool IS_FP8_KV_CACHE, + int NUM_THREADS = 128> void paged_attention_v1_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& seq_lens, + int max_seq_len, + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -710,10 +655,9 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -723,8 +667,7 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -754,9 +697,6 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; - case 192: - LAUNCH_PAGED_ATTENTION_V1(192); - break; case 256: LAUNCH_PAGED_ATTENTION_V1(256); break; @@ -766,94 +706,128 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v1_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); - -#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ - } +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V1_LAUNCHER_BLOCK_SIZE) + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, + float kv_scale) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ - value_cache_ptr, 
num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ - max_num_partitions); - -template +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + seq_lens_ptr, \ + max_num_partitions); + +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + bool IS_FP8_KV_CACHE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> void paged_attention_v2_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& seq_lens, + int max_seq_len, + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -866,10 +840,9 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -915,9 +888,6 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; - case 192: - LAUNCH_PAGED_ATTENTION_V2(192); - break; case 256: LAUNCH_PAGED_ATTENTION_V2(256); break; @@ -927,66 +897,81 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v2_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); - -#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ - } +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& - tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - 
CALL_V2_LAUNCHER_BLOCK_SIZE) + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, + float kv_scale) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } #undef WARP_SIZE diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index cdcee42748998..ff64c4bd8f80c 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,6 +1,5 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -27,7 +26,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). 
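(Editor's illustration, not part of the diff.) The paged_attention_v1/v2 entry points above turn two runtime values into compile-time template parameters: the kv_cache_dtype string plus the query dtype select the cache element type and the IS_FP8_KV_CACHE flag, and a switch on block_size selects the remaining constant. The sketch below mirrors that two-level dispatch in plain C++; launch_impl, dispatch_paged_attention and the integral_constant helper are hypothetical names, not vLLM API.

#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>

template <typename CACHE_T, int BLOCK_SIZE, bool IS_FP8_KV_CACHE>
void launch_impl() {
  // A real launcher would instantiate and launch the CUDA kernel for this
  // combination of compile-time parameters.
}

void dispatch_paged_attention(const std::string& kv_cache_dtype,
                              int block_size) {
  // Turn the runtime block size into a compile-time constant, mirroring the
  // CALL_*_LAUNCHER_BLOCK_SIZE switch above.
  auto with_block_size = [&](auto call) {
    switch (block_size) {
      case 8:  call(std::integral_constant<int, 8>{});  break;
      case 16: call(std::integral_constant<int, 16>{}); break;
      case 32: call(std::integral_constant<int, 32>{}); break;
      default: throw std::runtime_error("Unsupported block size");
    }
  };
  if (kv_cache_dtype == "auto") {
    // Cache stored in the model dtype (fp16 carried as uint16_t, as above).
    with_block_size([](auto bs) {
      launch_impl<uint16_t, decltype(bs)::value, false>();
    });
  } else if (kv_cache_dtype == "fp8") {
    // Cache stored as raw fp8 bytes (uint8_t), IS_FP8_KV_CACHE = true.
    with_block_size([](auto bs) {
      launch_impl<uint8_t, decltype(bs)::value, true>();
    });
  } else {
    throw std::runtime_error("Unsupported kv cache dtype: " + kv_cache_dtype);
  }
}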
@@ -46,12 +45,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 3cdcb95e08099..31e0cee01d2e1 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,8 +1,6 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,8 +28,8 @@ #include #include -typedef __hip_bfloat162 __nv_bfloat162; -typedef __hip_bfloat16 __nv_bfloat16; + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -52,37 +50,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template <> +template<> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template <> +template<> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template <> +template<> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template <> +template<> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template <> +template<> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template <> +template<> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template <> +template<> struct FloatVec { using Type = Float4_; }; -template <> +template<> struct FloatVec { using Type = Float8_; }; @@ -110,9 +108,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -163,7 +161,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. 
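(Editor's illustration, not part of the diff.) The dtype_bfloat16.cuh hunk that follows is a long list of Vec and FloatVec specializations whose template parameter lists were partly lost in extraction (bare "template" with the angle-bracket list stripped). The sketch below restates the shape of that trait pattern with stand-in types; Half, Half2 and Float2 are placeholders for the real CUDA types.

// Stand-ins for the real CUDA types used by the header.
struct Half  {};                  // e.g. __nv_bfloat16 or a raw uint16_t half
struct Half2 { Half x, y; };      // packed 2-lane type, e.g. __nv_bfloat162
struct Float2 { float x, y; };    // fp32 accumulator pair, e.g. float2

// Primary templates: (scalar type, lane count) -> packed vector type, and
// packed vector type -> fp32 accumulator type.
template <typename T, int VEC_SIZE> struct Vec {};
template <typename V> struct FloatVec {};

// Specializations, mirroring Vec<__nv_bfloat16, 2>, FloatVec<__nv_bfloat162>,
// etc. in the real header.
template <> struct Vec<Half, 1>  { using Type = Half;  };
template <> struct Vec<Half, 2>  { using Type = Half2; };
template <> struct FloatVec<Half>  { using Type = float;  };
template <> struct FloatVec<Half2> { using Type = Float2; };

// Usage: pick the K/V vector type and its fp32 accumulator at compile time.
using KVec   = Vec<Half, 2>::Type;       // -> Half2
using AccVec = FloatVec<KVec>::Type;     // -> Float2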
-template <> +template<> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -172,7 +170,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template <> +template<> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -181,12 +179,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template <> +template<> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template <> +template<> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -194,7 +192,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template <> +template<> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -203,7 +201,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template <> +template<> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -213,7 +211,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template <> +template<> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -224,26 +222,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template <> +template<> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template <> +template<> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template <> +template<> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template <> +template<> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -251,7 +249,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template <> +template<> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -260,7 +258,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template <> +template<> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -270,7 +268,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template <> +template<> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -282,8 +280,7 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. 
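(Editor's illustration, not part of the diff.) The mul specializations in this hunk follow a three-parameter convention in which the first template argument names the result/accumulator type, so mul<float2, __nv_bfloat162, __nv_bfloat162>(a, b) multiplies bf16 lanes while widening to fp32. The primary template is presumably declared in attention_generic.cuh, which is not shown in this diff; the stand-in types below only illustrate the convention.

struct Pair16 { short x, y; };   // stand-in for a packed bf16/fp16 pair
struct PairF  { float x, y; };   // stand-in for CUDA's float2

// Primary template: the first parameter names the result/accumulator type.
template <typename Acc, typename A, typename B>
inline Acc mul(A a, B b);

// A specialization that widens to fp32 while multiplying, analogous to
// mul<float2, __nv_bfloat162, __nv_bfloat162>(a, b) above.
template <>
inline PairF mul<PairF, Pair16, Pair16>(Pair16 a, Pair16 b) {
  return {static_cast<float>(a.x) * static_cast<float>(b.x),
          static_cast<float>(a.y) * static_cast<float>(b.y)};
}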
-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, - __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -291,8 +288,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, - __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -383,23 +379,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template <> +template<> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template <> +template<> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template <> +template<> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template <> +template<> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -452,4 +448,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 3a1815f0ed4fc..d3271e69cd69d 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,8 +1,6 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -32,37 +30,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template <> +template<> struct Vec { using Type = uint16_t; }; -template <> +template<> struct Vec { using Type = uint32_t; }; -template <> +template<> struct Vec { using Type = uint2; }; -template <> +template<> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. 
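(Editor's illustration, not part of the diff.) The bf16 helpers around this point are compiled behind a __CUDA_ARCH__ >= 800 guard because the packed bf16 arithmetic intrinsics only exist on Ampere and newer; on older architectures the body is just assert(false), since these kernels are never launched there. A minimal sketch of that guard, assuming the CUDA __hfma2 intrinsic; the function name bf16_fma2 is mine.

#include <cassert>
#include <cuda_bf16.h>

__device__ __nv_bfloat162 bf16_fma2(__nv_bfloat162 a, __nv_bfloat162 b,
                                    __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // Pre-Ampere: no packed bf16 FMA intrinsic; the real code simply asserts
  // because these kernels are never launched on such devices.
  assert(false);
  return c;
#else
  return __hfma2(a, b, c);  // d = a * b + c on both bf16 lanes
#endif
}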
-template <> +template<> struct FloatVec { using Type = float; }; -template <> +template<> struct FloatVec { using Type = float2; }; -template <> +template<> struct FloatVec { using Type = Float4_; }; -template <> +template<> struct FloatVec { using Type = Float8_; }; @@ -75,8 +73,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -132,12 +130,10 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" - : "=r"(tmp.u32) - : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -205,7 +201,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template <> +template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -216,7 +212,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template <> +template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -227,12 +223,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template <> +template<> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template <> +template<> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -240,7 +236,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template <> +template<> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -249,7 +245,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template <> +template<> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -259,7 +255,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template <> +template<> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -270,26 +266,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template <> +template<> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template <> +template<> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template <> +template<> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template <> +template<> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -297,7 +293,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template <> +template<> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -306,7 +302,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template <> +template<> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -316,7 +312,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template <> +template<> inline 
__device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -331,13 +327,9 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(d) - : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" - : "=v"(d) - : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -431,24 +423,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template <> +template<> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template <> +template<> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template <> +template<> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template <> +template<> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -478,9 +470,13 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { return half_to_float(u); } +inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} -inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -499,6 +495,8 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index 7c6a686db3ba9..b200d2d226eb0 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,8 +1,6 @@ /* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -40,35 +38,37 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template <> +template<> struct Vec { using Type = float; }; -template <> +template<> struct Vec { using Type = float2; }; -template <> +template<> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template <> +template<> struct FloatVec { using Type = float; }; -template <> +template<> struct FloatVec { using Type = float2; }; -template <> +template<> struct FloatVec { using Type = float4; }; // Vector addition. 
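(Editor's illustration, not part of the diff.) The fp16 path above keeps two half values packed in a single uint32_t register and issues one fma.rn.f16x2 PTX instruction for both lanes (the ROCm branch uses v_pk_fma_f16 instead). The sketch below restates just that inline-asm FMA; the function name is mine.

#include <cstdint>

// Two half values are packed into one 32-bit register; a single PTX
// instruction computes d = a * b + c for both lanes.
__device__ uint32_t packed_half2_fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(d)
               : "r"(a), "r"(b), "r"(c));
  return d;
}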
-inline __device__ float add(float a, float b) { return a + b; } +inline __device__ float add(float a, float b) { + return a + b; +} inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template <> +template<> inline __device__ float mul(float a, float b) { return a * b; } -template <> +template<> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template <> +template<> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template <> +template<> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template <> +template<> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,7 +129,9 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { return a * b + c; } +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -180,33 +182,35 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template <> +template<> inline __device__ float sum(float v) { return v; } -template <> +template<> inline __device__ float sum(float2 v) { return v.x + v.y; } -template <> +template<> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template <> +template<> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template <> +template<> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { return a * b; } +inline __device__ float dot(float a, float b) { + return a * b; +} inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -228,24 +232,42 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { dst = src; } +inline __device__ void from_float(float& dst, float src) { + dst = src; +} -inline __device__ void from_float(float2& dst, float2 src) { dst = src; } +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} -inline __device__ void from_float(float4& dst, float4 src) { dst = src; } +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} // From float to float. -inline __device__ float to_float(float u) { return u; } +inline __device__ float to_float(float u) { + return u; +} -inline __device__ float2 to_float(float2 u) { return u; } +inline __device__ float2 to_float(float2 u) { + return u; +} -inline __device__ float4 to_float(float4 u) { return u; } +inline __device__ float4 to_float(float4 u) { + return u; +} -inline __device__ Float4_ to_float(Float4_ u) { return u; } +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} -inline __device__ Float8_ to_float(Float8_ u) { return u; } +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} // Zero-out a variable. 
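(Editor's illustration, not part of the diff.) The fp32 overloads in this hunk compose in a fixed way: dot is the element-wise mul followed by the horizontal sum, which is the same decomposition qk_dot_ in attention_utils.cuh builds on. A self-contained restatement with a stand-in float2 type:

struct MyFloat2 { float x, y; };   // stand-in for CUDA's float2

inline MyFloat2 mul(MyFloat2 a, MyFloat2 b) {  // element-wise product
  return {a.x * b.x, a.y * b.y};
}

inline float sum(MyFloat2 v) {                 // horizontal reduction
  return v.x + v.y;
}

inline float dot(MyFloat2 a, MyFloat2 b) {     // dot = sum(mul(a, b))
  MyFloat2 c = mul(a, b);
  return sum(c);
}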
-inline __device__ void zero(float& dst) { dst = 0.f; } +inline __device__ void zero(float& dst) { + dst = 0.f; +} -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index e714e321b0beb..d11dee91ebe87 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,39 +3,33 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8 - #ifndef USE_ROCM - #include - #endif // USE_ROCM -#endif // ENABLE_FP8 +#ifdef ENABLE_FP8_E5M2 +#include +#endif namespace vllm { - -enum class Fp8KVCacheDataType { - kAuto = 0, - kFp8E4M3 = 1, - kFp8E5M2 = 2, -}; - +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) // fp8 vector types for quantization of kv cache -template <> + +template<> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template <> +template<> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template <> +template<> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template <> +template<> struct Vec { - using Type = uint2; + using Type = uint2; }; +#endif // ENABLE_FP8_E5M2 -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 86caa9345361d..10871b3670bac 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,32 +1,38 @@ #pragma once -#include +#include #include #include -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping); +void swap_blocks( + torch::Tensor& src, + torch::Tensor& dst, + const std::map& block_mapping); -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping); +void copy_blocks( + std::vector& key_caches, + std::vector& value_caches, + torch::Tensor& block_mapping); -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const double kv_scale); +void reshape_and_cache( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const float kv_scale); -void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); // Just for unittest -void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const double scale, const std::string& kv_cache_dtype); +void convert_fp8( + torch::Tensor& src_cache, + torch::Tensor& dst_cache); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 72041076ae009..1e02f7fcbae4c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,14 +1,13 @@ -#include +#include #include #include #include "cuda_compat.h" #include "dispatch_utils.h" - -#ifdef USE_ROCM - #include "quantization/fp8/amd/quant_utils.cuh" -#else - #include "quantization/fp8/nvidia/quant_utils.cuh" +#if defined(ENABLE_FP8_E5M2) +#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif 
defined(ENABLE_FP8_E4M3) +#include "quantization/fp8/amd_detail/quant_utils.cuh" #endif #include @@ -18,17 +17,20 @@ #ifdef USE_ROCM #include -typedef __hip_bfloat16 __nv_bfloat16; + typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { +void swap_blocks( + torch::Tensor& src, + torch::Tensor& dst, + const std::map& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK(src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -38,44 +40,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, TORCH_CHECK(false, "Invalid device combination"); } - // NOTE(youkaichao): keep in mind that `block_mapping` should be - // a cpu tensor, otherwise every `item` call will require a gpu-cpu - // synchronization. - TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); - - char* src_ptr = static_cast(src.data_ptr()); - char* dst_ptr = static_cast(dst.data_ptr()); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard( - src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. 
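(Editor's illustration, not part of the diff.) swap_blocks, whose body straddles this point, picks a cudaMemcpyKind from the src/dst device placement and then issues one cudaMemcpyAsync per (src_block, dst_block) pair, offset by the block number times the per-block byte size. A sketch of that core loop, using the map-based signature from the "+" side of this hunk; swap_blocks_sketch is my name.

#include <cstdint>
#include <map>
#include <cuda_runtime.h>

void swap_blocks_sketch(char* src_ptr, char* dst_ptr,
                        int64_t block_size_in_bytes,
                        const std::map<int64_t, int64_t>& block_mapping,
                        cudaMemcpyKind memcpy_type, cudaStream_t stream) {
  for (const auto& pair : block_mapping) {
    const int64_t src_offset = pair.first * block_size_in_bytes;
    const int64_t dst_offset = pair.second * block_size_in_bytes;
    // One async copy per block; all copies are ordered on `stream`.
    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                    block_size_in_bytes, memcpy_type, stream);
  }
}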
- const int64_t num_blocks = block_mapping.size(0); - for (size_t i = 0; i < num_blocks; i++) { - int64_t src_block_number = block_mapping[i][0].item(); - int64_t dst_block_number = block_mapping[i][1].item(); + for (const auto& pair : block_mapping) { + int64_t src_block_number = pair.first; + int64_t dst_block_number = pair.second; int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, - block_size_in_bytes, memcpy_type, stream); + cudaMemcpyAsync( + dst_ptr + dst_offset, + src_ptr + src_offset, + block_size_in_bytes, + memcpy_type, + stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel( + int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = - reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -93,14 +92,12 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } } -} // namespace vllm +} // namespace vllm -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { +void copy_blocks( + std::vector& key_caches, + std::vector& value_caches, + torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -114,10 +111,8 @@ void copy_blocks(std::vector const& key_caches, int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = - reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = - reinterpret_cast(value_caches[layer_idx].data_ptr()); + key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); } // block_mapping is a 2D tensor with shape (num_pairs, 2). @@ -125,12 +120,10 @@ void copy_blocks(std::vector const& key_caches, // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = - torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) - .to(cache_device); - torch::Tensor value_cache_ptrs_tensor = - torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) - .to(cache_device); + torch::Tensor key_cache_ptrs_tensor = torch::from_blob( + key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); + torch::Tensor value_cache_ptrs_tensor = torch::from_blob( + value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); // Launch the kernel. 
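(Editor's illustration, not part of the diff.) copy_blocks above launches a single kernel over all layers by packing the per-layer cache pointers into an int64_t array, wrapping it with torch::from_blob, and copying it to the cache device (which, as the NOTE in the hunk says, synchronizes CPU and GPU). A sketch of that packing step; pack_ptrs_for_kernel is my name.

#include <torch/extension.h>
#include <cstdint>
#include <vector>

torch::Tensor pack_ptrs_for_kernel(std::vector<torch::Tensor>& caches,
                                   const torch::Device& cache_device) {
  const int64_t num_layers = static_cast<int64_t>(caches.size());
  std::vector<int64_t> ptrs(num_layers);
  for (int64_t i = 0; i < num_layers; ++i) {
    ptrs[i] = reinterpret_cast<int64_t>(caches[i].data_ptr());
  }
  // from_blob does not own the host buffer; .to() materializes a device copy
  // that the kernel can index as an array of raw device pointers.
  return torch::from_blob(ptrs.data(), {num_layers}, torch::kInt64)
      .to(cache_device);
}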
const int numel_per_block = key_caches[0][0].numel(); @@ -139,28 +132,31 @@ void copy_blocks(std::vector const& key_caches, const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping.data_ptr(), numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), + numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, - // block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, - // block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -181,39 +177,47 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = - block_idx * num_heads * (head_size / x) * block_size * x + - head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + - block_offset * x + x_offset; - const int64_t tgt_value_idx = - block_idx * num_heads * head_size * block_size + - head_idx * head_size * block_size + head_offset * block_size + - block_offset; + const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + if constexpr (is_fp8_kv_cache) { +#if defined(ENABLE_FP8_E5M2) + key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); + value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +#elif defined(ENABLE_FP8_E4M3) + key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); +#else + assert(false); +#endif + } else { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; - } else { - 
key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, kv_scale); - value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, kv_scale); } } } -template +template __global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, - // head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, - // head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -228,37 +232,40 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; k_cache[tgt_value_idx] = key[src_key_idx]; v_cache[tgt_value_idx] = value[src_value_idx]; } } -} // namespace vllm - -// KV_T is the stored data type of kv-cache. -// CACHE_T is the data type of key and value tensors. -// KV_DTYPE is the real data type of kv-cache. 
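(Editor's illustration, not part of the diff.) reshape_and_cache_kernel above splits the flat slot index from slot_mapping into a (block, offset-in-block) pair and then linearizes a (head, element) coordinate into the blocked key-cache layout [num_blocks, num_heads, head_size/x, block_size, x]. The sketch below restates exactly that index arithmetic; the helper names are mine.

#include <cstdint>

struct KeySlot { int64_t block_idx; int64_t block_offset; };

// Split the flat slot index from slot_mapping into (block, offset).
inline KeySlot split_slot(int64_t slot_idx, int block_size) {
  return {slot_idx / block_size, slot_idx % block_size};
}

// Linear index into a key cache laid out as
// [num_blocks, num_heads, head_size / x, block_size, x].
inline int64_t key_cache_index(int64_t block_idx, int64_t block_offset,
                               int num_heads, int head_size, int block_size,
                               int x, int head_idx, int head_offset) {
  const int x_idx = head_offset / x;     // which x-wide chunk of the head dim
  const int x_offset = head_offset % x;  // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x +
         x_idx * block_size * x +
         block_offset * x +
         x_offset;
}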
-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_kernel \ - <<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, kv_scale); +} // namespace vllm + +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x, \ + kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double kv_scale) { + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, + const float kv_scale) +{ int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -272,18 +279,35 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, - CALL_RESHAPE_AND_CACHE) + if (kv_cache_dtype == "auto") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, float, false); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); + } + } else if (kv_cache_dtype == "fp8") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, uint8_t, true); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) { + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& 
kv_cache_dtype) +{ // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); @@ -303,47 +327,63 @@ void reshape_and_cache_flash( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), "reshape_and_cache_flash", [&] { - vllm::reshape_and_cache_flash_kernel - <<>>( - key.data_ptr(), value.data_ptr(), - k_cache.data_ptr(), v_cache.data_ptr(), - slot_mapping.data_ptr(), block_stride, key_stride, - value_stride, num_heads, head_size, block_size); - }); + key.scalar_type(), + "reshape_and_cache_flash", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + k_cache.data_ptr(), + v_cache.data_ptr(), + slot_mapping.data_ptr(), + block_stride, + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); } namespace vllm { -template -__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const float kv_scale, - const int64_t block_stride) { +template +__global__ void convert_fp8_kernel( + const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; - dst_cache[idx] = - fp8::scaled_convert(src_cache[idx], kv_scale); +#if defined(ENABLE_FP8_E5M2) + dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +#elif defined(ENABLE_FP8_E4M3) + dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); +#else + assert(false); +#endif } } -} // namespace vllm +} // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); +#define CALL_CONVERT_FP8(Tout, Tin) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + block_stride); -// Only for testing. 
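(Editor's illustration, not part of the diff.) convert_fp8_kernel above assigns one CUDA block per cache block and lets its threads stride over the block_stride contiguous elements, converting each one; the convert_fp8 entry point that follows launches it with grid = num_blocks and block = min(block_stride, 512). A sketch of that launch shape, with a trivial placeholder standing in for the real fp8 conversion helpers.

#include <cstdint>

// Trivial placeholder for the real per-element fp8 conversion
// (the fp8_e5m2_unscaled / fp8_e4m3 helpers referenced in this hunk).
template <typename Tout, typename Tin>
__device__ Tout convert_elem(Tin x) {
  return static_cast<Tout>(x);
}

template <typename Tout, typename Tin>
__global__ void convert_cache_kernel(const Tin* __restrict__ src,
                                     Tout* __restrict__ dst,
                                     const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;  // one CUDA block per cache block
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_idx * block_stride + i;
    dst[idx] = convert_elem<Tout, Tin>(src[idx]);
  }
}
// Host side, as in convert_fp8 below: grid = num_blocks,
// block = min(block_stride, 512), launched on the current CUDA stream.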
-void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const double kv_scale, const std::string& kv_cache_dtype) { +void convert_fp8( + torch::Tensor& src_cache, + torch::Tensor& dst_cache) +{ torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK(src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); at::cuda::OptionalCUDAGuard device_guard(src_device); int64_t num_blocks = src_cache.size(0); @@ -353,37 +393,17 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } - } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } - } else { - TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); } } diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 039b8d5c30d46..1bd24eb79d129 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int 
num_tokens, int d, scalar_t* __restrict__ input, - scalar_t* __restrict__ output) { +void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, + scalar_t *__restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,21 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) { - const vec_op::FP32Vec8 zeros(0.0); - const vec_op::FP32Vec8 ones(1.0); - const vec_op::FP32Vec8 w1(1.702f); - return x / (ones + (zeros - w1 * x).exp()); -} - -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -82,36 +75,40 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { +void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel( - num_tokens, d, input.data_ptr(), out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel( - num_tokens, d, input.data_ptr(), out.data_ptr()); 
- CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -126,7 +123,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] }); } -void gelu_new(torch::Tensor& out, torch::Tensor& input) { +void gelu_new(torch::Tensor &out, torch::Tensor &input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -138,7 +135,7 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input) { }); } -void gelu_fast(torch::Tensor& out, torch::Tensor& input) { +void gelu_fast(torch::Tensor &out, torch::Tensor &input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -149,15 +146,3 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) { CPU_KERNEL_GUARD_OUT(gelu_fast_impl) }); } - -void gelu_quick(torch::Tensor& out, torch::Tensor& input) { - int num_tokens = input.numel() / input.size(-1); - int d = input.size(-1); - - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_quick_impl) - activation_kernel( - num_tokens, d, input.data_ptr(), out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_quick_impl) - }); -} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..c1d765be05598 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,8 +2,7 @@ namespace { -template -struct KernelVecType { +template struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -12,8 +11,7 @@ struct KernelVecType { using v_load_vec_type = void; }; -template <> -struct KernelVecType { +template <> struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -23,8 +21,7 @@ struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> -struct KernelVecType { +template <> struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -33,8 +30,7 @@ struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> -struct KernelVecType { +template <> struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -45,7 +41,7 @@ struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -71,11 +67,10 @@ FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, } template -FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, - const int capacity, - const float alibi_slope, - const int start_index, - const int seq_len) { +FORCE_INLINE std::pair +reduceSoftmaxAlibi(T *data, const int size, const int capacity, + const float alibi_slope, const int start_index, + const int seq_len) { data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; 
++i) { @@ -103,7 +98,7 @@ FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, } template -FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -137,9 +132,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, - float* __restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t *__restrict__ q, + const scalar_t *__restrict__ k_block, + float *__restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -201,8 +196,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, - acc_t&& acc) { +FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, + acc_t &&acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -214,27 +209,27 @@ FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; @@ -248,31 +243,32 @@ struct paged_attention_v1_impl { size_t logits_bytes = parallel_work_item_num * max_seq_len_padded * sizeof(float); - float* logits = (float*)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. 
+ // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { int seq_len = seq_lens[seq_idx]; - const int* seq_block_table = + const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = + const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; - float* __restrict__ thread_block_logits = + const int last_block_token_num = + seq_len - (block_num - 1) * BLOCK_SIZE; + float *__restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = + const scalar_t *__restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = + float *__restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -286,7 +282,8 @@ struct paged_attention_v1_impl { block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, seq_len); } else { - reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, + block_num * BLOCK_SIZE); } // Compute value @@ -296,14 +293,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = + scalar_t *__restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = + const float *__restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = + const scalar_t *__restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -314,7 +311,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = + const scalar_t *__restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -343,16 +340,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int 
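While one value block is being reduced, the kernel already resolves the next physical block through the block table, which pairs with the vec_op::prefetch helper defined later in cpu_types.hpp. Below is an x86-only sketch of that software-prefetch pattern using _mm_prefetch directly; the block-table indirection is kept, the arithmetic workload is a stand-in.

#include <xmmintrin.h>   // _mm_prefetch / _MM_HINT_T1
#include <cstdio>
#include <vector>

// Accumulate over fixed-size blocks scattered through a large buffer, issuing a
// prefetch for the next block while the current one is being processed.
float sum_blocks(const float* base, const std::vector<int>& block_table, int block_size) {
  float acc = 0.f;
  for (size_t i = 0; i < block_table.size(); ++i) {
    const float* block = base + size_t(block_table[i]) * block_size;
    if (i + 1 < block_table.size()) {
      const float* next = base + size_t(block_table[i + 1]) * block_size;
      _mm_prefetch(reinterpret_cast<const char*>(next), _MM_HINT_T1);
    }
    for (int j = 0; j < block_size; ++j) acc += block[j];
  }
  return acc;
}

int main() {
  std::vector<float> buf(16 * 64, 1.f);
  std::printf("%f\n", sum_blocks(buf.data(), {3, 0, 7, 2}, 64));
}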
num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes) { + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &seq_lens, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -362,74 +359,68 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = + const float *alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - 
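The launcher turns the runtime head_size into a compile-time template argument with a plain switch, one fully specialised instantiation per supported size, and rejects anything else. A stripped-down sketch of that dispatch shape follows; the supported sizes are taken from the updated switch in the diff and the kernel body is a placeholder.

#include <cstdio>
#include <stdexcept>

template <typename T, int HEAD_SIZE, int BLOCK_SIZE>
struct attention_impl {
  // Placeholder for the real kernel; HEAD_SIZE and BLOCK_SIZE being compile-time
  // constants is what lets inner loops be unrolled and vector widths be fixed.
  static void call() { std::printf("head_size=%d block_size=%d\n", HEAD_SIZE, BLOCK_SIZE); }
};

template <typename T, int BLOCK_SIZE>
void launch_for_head_size(int head_size) {
  switch (head_size) {
    case 64:  attention_impl<T, 64,  BLOCK_SIZE>::call(); break;
    case 80:  attention_impl<T, 80,  BLOCK_SIZE>::call(); break;
    case 96:  attention_impl<T, 96,  BLOCK_SIZE>::call(); break;
    case 112: attention_impl<T, 112, BLOCK_SIZE>::call(); break;
    case 128: attention_impl<T, 128, BLOCK_SIZE>::call(); break;
    case 256: attention_impl<T, 256, BLOCK_SIZE>::call(); break;
    default:  throw std::runtime_error("Unsupported head size");
  }
}

int main() { launch_for_head_size<float, 16>(128); }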
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { +} // namespace + +void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, + int num_kv_heads, float scale, + torch::Tensor &block_tables, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -443,24 +434,23 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float + *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] + const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -478,7 +468,8 @@ struct paged_attention_v2_impl { const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seq_len) continue; + if (start_token_idx >= seq_len) + continue; const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; @@ -486,14 +477,15 @@ struct paged_attention_v2_impl { const int token_num = (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = + (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int 
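paged_attention_v2 splits every sequence into fixed-size partitions and treats each (sequence, head, partition) triple as an independent work item, so the bookkeeping is plain ceiling-division arithmetic. A small sketch of that arithmetic with the diff's variable names follows; the PARTITION_SIZE and BLOCK_SIZE values here are only illustrative.

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int PARTITION_SIZE = 512;   // illustrative; a template parameter in the kernel
  constexpr int BLOCK_SIZE = 16;
  const int seq_len = 1234;

  const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
  for (int partition_idx = 0; partition_idx < partition_num; ++partition_idx) {
    const int start_token_idx = partition_idx * PARTITION_SIZE;
    if (start_token_idx >= seq_len) continue;   // nothing to do for this partition

    // Tokens covered by this partition, and how many KV blocks they span.
    const int token_num =
        std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx;
    const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
    const int last_block_token_num = token_num - (block_num - 1) * BLOCK_SIZE;

    std::printf("partition %d: tokens=%d blocks=%d last_block_tokens=%d\n",
                partition_idx, token_num, block_num, last_block_token_num);
  }
}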
last_block_token_num = token_num - (block_num - 1) * BLOCK_SIZE; - const int* seq_block_table = block_tables + + const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = + const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -501,10 +493,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = + const scalar_t *__restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = + float *__restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -518,13 +510,13 @@ struct paged_attention_v2_impl { logits, token_num, block_num * BLOCK_SIZE, alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = - reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); + max_and_sum = reduceSoftmax(logits, token_num, + block_num * BLOCK_SIZE); } - auto&& [max_logit, exp_sum] = max_and_sum; + auto &&[max_logit, exp_sum] = max_and_sum; - scalar_t* __restrict__ output_buffer = nullptr; + scalar_t *__restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -546,13 +538,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = + scalar_t *__restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = + const float *__restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = + const scalar_t *__restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -563,7 +555,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = + const scalar_t *__restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -595,7 +587,8 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) continue; + if (partition_num == 1) + continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -610,11 +603,11 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some - // HEAD_SIZE didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE + // didn't align with 
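Because each partition ran its softmax with its own local max and exp-sum, the second pass has to reweight the partial outputs by exp(partition_max - global_max) and renormalise by the combined sum before accumulating them; that is what reducePartitonSoftmax plus the final weighted accumulation compute. A scalar sketch of the merge follows (array names are illustrative).

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// partial_out[p] was normalised inside partition p only, so it is reweighted by
// exp_sums[p] * exp(max_logits[p] - global_max) / global_sum before accumulation.
std::vector<float> merge_partitions(const std::vector<float>& max_logits,
                                    const std::vector<float>& exp_sums,
                                    const std::vector<std::vector<float>>& partial_out) {
  const size_t P = max_logits.size();
  float global_max = max_logits[0];
  for (size_t p = 1; p < P; ++p) global_max = std::max(global_max, max_logits[p]);

  std::vector<float> rescale(P);
  float global_sum = 0.f;
  for (size_t p = 0; p < P; ++p) {
    rescale[p] = exp_sums[p] * std::exp(max_logits[p] - global_max);
    global_sum += rescale[p];
  }

  std::vector<float> out(partial_out[0].size(), 0.f);
  for (size_t p = 0; p < P; ++p) {
    const float w = rescale[p] / global_sum;
    for (size_t d = 0; d < out.size(); ++d) out[d] += w * partial_out[p][d];
  }
  return out;
}

int main() {
  auto out = merge_partitions({2.f, 3.f}, {1.5f, 2.5f}, {{1.f, 0.f}, {0.f, 1.f}});
  std::printf("%f %f\n", out[0], out[1]);
}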
64 bytes static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float* __restrict__ rescale_factors = exp_sums; + const float *__restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { @@ -623,16 +616,17 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) continue; + if (partition_num == 1) + continue; - const float* __restrict__ seq_head_rescale_factors = + const float *__restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t* __restrict__ seq_head_tmp_out = + const scalar_t *__restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t* __restrict__ seq_head_output = + scalar_t *__restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -651,21 +645,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes) { + torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, + torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -676,79 +670,73 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = + const float *alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); + T *out_ptr = reinterpret_cast(out.data_ptr()); + float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ - alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, \ + max_seq_len, alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& 
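The launchers unpack every tensor into a typed raw pointer up front and hand only plain pointers and strides to the OpenMP kernels. The sketch below shows that unpacking in isolation with libtorch; scale_inplace and its checks are invented for the example, only the data_ptr-based access pattern is taken from the diff.

#include <torch/torch.h>
#include <cstdio>

// Pull a typed raw pointer out of a contiguous CPU tensor before looping over it,
// the way the paged-attention launchers do for out/query/key_cache/value_cache.
void scale_inplace(torch::Tensor& t, float factor) {
  TORCH_CHECK(t.device().is_cpu() && t.is_contiguous());
  float* ptr = t.data_ptr<float>();    // dtype-checked accessor
  const int64_t n = t.numel();
  for (int64_t i = 0; i < n; ++i) ptr[i] *= factor;
}

int main() {
  torch::Tensor t = torch::ones({2, 3});
  scale_inplace(t, 2.f);
  std::printf("%f\n", t[0][0].item<float>());
}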
kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { +} // namespace + +void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2b5c3bd6ee70b..95e3f11900fde 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,26 +5,25 @@ namespace { template -void copy_blocks_cpu_impl(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, - const int layer_num) { +void copy_blocks_cpu_impl( + std::vector &key_caches, + std::vector &value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, const int layer_num) { const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = - element_num_per_block * mapping_pairs[pair][0].item(); + int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t* source_ptr = key_cache_ptr + source_offset; - scalar_t* target_ptr = key_cache_ptr + target_offset; + scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t *source_ptr = key_cache_ptr + source_offset; + scalar_t *target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -34,9 +33,9 @@ void copy_blocks_cpu_impl(std::vector const& key_caches, template void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, + const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, + scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -49,14 +48,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * 
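copy_blocks_cpu_impl treats every layer's key and value caches as flat arrays of fixed-size blocks and copies each (source, target) pair with std::memcpy, with the layer and pair loops collapsed into a single OpenMP work pool. A self-contained sketch of the same pattern follows, with the tensor plumbing replaced by plain vectors and all sizes illustrative.

#include <omp.h>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

int main() {
  const int layer_num = 2, block_num = 8, element_num_per_block = 4;
  // One flat cache per layer; block b occupies [b * element_num_per_block, ...).
  std::vector<std::vector<float>> key_caches(
      layer_num, std::vector<float>(block_num * element_num_per_block, 0.f));
  for (auto& cache : key_caches)
    for (size_t i = 0; i < cache.size(); ++i) cache[i] = float(i);

  // Each pair is {source_block, target_block}, like the mapping_pairs tensor.
  const std::vector<std::pair<int, int>> mapping_pairs = {{0, 5}, {2, 7}};
  const size_t pair_num = mapping_pairs.size();
  const size_t block_bytes = sizeof(float) * element_num_per_block;

#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      float* cache = key_caches[layer].data();
      const float* src = cache + mapping_pairs[pair].first * element_num_per_block;
      float* dst = cache + mapping_pairs[pair].second * element_num_per_block;
      std::memcpy(dst, src, block_bytes);
    }
  }
  std::printf("block 5 now starts with %f\n", key_caches[0][5 * element_num_per_block]);
}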
head_size; - const scalar_t* src_key_head_ptr = key + src_key_head_idx; - const scalar_t* src_value_head_ptr = value + src_value_head_idx; + const scalar_t *src_key_head_ptr = key + src_key_head_idx; + const scalar_t *src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + + scalar_t *target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + + scalar_t *target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -80,15 +79,12 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { - unsigned num_layers = key_caches.size(); +void copy_blocks(std::vector &key_caches, + std::vector &value_caches, + torch::Tensor& block_mapping) { + int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; @@ -104,10 +100,10 @@ void copy_blocks(std::vector const& key_caches, }); } -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double kv_scale) { +void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -131,7 +127,7 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, }); } -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { +void swap_blocks(torch::Tensor &src, torch::Tensor &dst, + const std::map &block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 0213be09105ed..c1d3ec058b991 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -2,14 +2,351 @@ #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP -#if defined(__x86_64__) - //x86 implementation - #include "cpu_types_x86.hpp" -#elif defined(__POWER9_VECTOR__) - //ppc implementation - #include "cpu_types_vsx.hpp" +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) #else - #warning "unsupported vLLM cpu implementation" +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
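reshape_and_cache scatters each token's key/value heads into the paged cache using nothing but integer arithmetic on the slot index: block_index = slot / block_size, block_offset = slot % block_size, and a negative slot marks a padded token to skip. The sketch below shows that addressing; the cache layout used here is deliberately simplified and is not the exact [num_blocks, num_kv_heads, head_size/x, block_size, x] layout of the real key cache.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_heads = 2, head_size = 4, block_size = 16, num_blocks = 8;
  const int block_elem_num = num_heads * head_size * block_size;
  std::vector<float> value_cache(size_t(num_blocks) * block_elem_num, 0.f);

  // slot_mapping[token] is the global slot for that token; -1 marks padding.
  const std::vector<int64_t> slot_mapping = {35, -1, 4};
  const std::vector<float> value(slot_mapping.size() * num_heads * head_size, 1.f);

  for (size_t token_idx = 0; token_idx < slot_mapping.size(); ++token_idx) {
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) continue;                        // padded token, nothing to cache
    const int64_t block_index  = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;
    std::printf("token %zu -> block %lld, offset %lld\n",
                token_idx, (long long)block_index, (long long)block_offset);
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const float* src = value.data() + (token_idx * num_heads + head_idx) * head_size;
      float* dst = value_cache.data() + block_index * block_elem_num +
                   head_idx * block_size * head_size + block_offset * head_size;
      for (int i = 0; i < head_size; ++i) dst[i] = src[i];
    }
  }
}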
<< std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; #endif +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + 
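The vec_op helpers are built on a small compile-time unroller: unroll_loop<T, N>(f) expands to f(0), f(1), ..., f(N-1) through a fold over std::integer_sequence, so the per-lane loops carry no runtime loop overhead. A simplified, standalone version of that utility and the way the reduce_sum-style code uses it follows; the enable_if constraint from the header is dropped here.

#include <cstdio>
#include <utility>

namespace sketch {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  // Fold over the index pack: expands to f(0), f(1), ..., f(N-1) at compile time.
  (f(std::integral_constant<T, indexes>{}), ...);
}

template <typename T, T count, typename F>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
}  // namespace sketch

int main() {
  float values[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float result = 0.f;
  // Same usage shape as FP32Vec8::reduce_sum() in the diff.
  sketch::unroll_loop<int, 8>([&](int i) { result += values[i]; });
  std::printf("sum = %f\n", result);
}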
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef 
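reduce_sub_sum<group_size>(idx) reduces just one contiguous group of lanes: it builds a 16-bit lane mask (group_size ones shifted to position idx * group_size) and hands it to the masked AVX-512 reduction. A scalar reference of what that computes, using the same mask construction, is sketched below.

#include <cstdint>
#include <cstdio>

// Sum lanes [idx*group_size, (idx+1)*group_size) of a 16-lane vector, mirroring
// FP32Vec16::reduce_sub_sum<group_size>(idx) without the AVX-512 intrinsics.
template <int group_size>
float reduce_sub_sum_ref(const float (&lanes)[16], int idx) {
  static_assert(16 % group_size == 0, "group_size must divide the lane count");
  const uint32_t base_mask = 0xFFFFu >> (16 - group_size);   // group_size low bits set
  const uint32_t mask = base_mask << (idx * group_size);     // moved to the selected group
  float sum = 0.f;
  for (int i = 0; i < 16; ++i)
    if (mask & (1u << i)) sum += lanes[i];
  return sum;
}

int main() {
  float lanes[16];
  for (int i = 0; i < 16; ++i) lanes[i] = float(i);
  std::printf("group 1 of 4: %f\n", reduce_sub_sum_ref<4>(lanes, 1));  // 4+5+6+7 = 22
}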
__AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#endif + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + #endif diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp deleted file mode 100644 index b50bdadc5713d..0000000000000 --- a/csrc/cpu/cpu_types_vsx.hpp +++ /dev/null @@ -1,491 +0,0 @@ - -#ifndef CPU_TYPES_VSX_HPP -#define CPU_TYPES_VSX_HPP - -#include -#include -#include - -namespace vec_op { - -// FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) - -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) - -#ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) -#else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
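Without AVX512-BF16 the fallback stores a float as bfloat16 by reinterpreting the float's upper 16 bits, i.e. truncation rather than round-to-nearest-even; the byte shifts in the vector conversions do the same thing lane-wise. A scalar sketch of that trick follows (it assumes a little-endian target and uses uint16_t as a stand-in for c10::BFloat16; memcpy keeps it within strict-aliasing rules).

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncating float -> bfloat16: keep the sign, exponent and top 7 mantissa bits.
uint16_t float_to_bf16_truncate(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));   // on little-endian, the high half is the second word
  return static_cast<uint16_t>(bits >> 16);
}

float bf16_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;   // widen back, low mantissa bits are zero
  float v;
  std::memcpy(&v, &bits, sizeof(v));
  return v;
}

int main() {
  float x = 1.2345678f;
  uint16_t b = float_to_bf16_truncate(x);
  std::printf("%f -> 0x%04x -> %f\n", x, (unsigned)b, bf16_to_float(b));
}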
<< std::endl; -#endif - -#define FORCE_INLINE __attribute__((always_inline)) inline - -namespace { -template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { - (f(std::integral_constant{}), ...); -} -}; // namespace - -template >> -constexpr void unroll_loop(F &&f) { - unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); -} - -template struct Vec { - constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } -}; - -typedef struct ss16x8x2_t { - __vector signed short val[2]; -} ss16x8x2_t; - -typedef struct ss16x8x4_t { - __vector signed short val[4]; -} ss16x8x4_t; - -typedef struct f32x4x2_t { - __vector float val[2]; -} f32x4x2_t; - -typedef struct f32x4x4_t { - __vector float val[4]; -} f32x4x4_t; - -struct FP32Vec8; -struct FP32Vec16; - -struct BF16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __vector signed short reg; - - explicit BF16Vec8(const void *ptr) - : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} - - explicit BF16Vec8(const FP32Vec8 &); - - void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } -}; - -struct BF16Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - ss16x8x2_t reg; - - explicit BF16Vec16(const void *ptr) { - // Load 256 bits in two parts - reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); - reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); - } - - explicit BF16Vec16(const FP32Vec16 &); - - void save(void *ptr) const { - // Save 256 bits in two parts - vec_xst(reg.val[0], 0, (signed short *)ptr); - vec_xst(reg.val[1], 16, (signed short *)ptr); - } -}; - -const static __vector signed short zero = vec_splats((signed short)0); - -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - ss16x8x4_t reg; - explicit BF16Vec32(const void *ptr) - : reg(*reinterpret_cast(ptr)) {} - - explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} - - explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ - vec8_data.reg, - vec8_data.reg, - vec8_data.reg, - vec8_data.reg - }) {} - - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } -}; - -struct FP32Vec4 : public Vec { - constexpr static int VEC_ELEM_NUM = 4; - union AliasReg { - __vector float reg; - float values[VEC_ELEM_NUM]; - }; - - __vector float reg; - - explicit FP32Vec4(float v) : reg(vec_splats(v)) {} - - explicit FP32Vec4() : reg(vec_splats(0.0f)) {} - - explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} - - explicit FP32Vec4(__vector float data) : reg(data) {} - - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} -}; - -struct FP32Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - union AliasReg { - f32x4x2_t reg; - float values[VEC_ELEM_NUM]; - }; - - f32x4x2_t reg; - - explicit FP32Vec8(float v) { - reg.val[0] = vec_splats(v); - reg.val[1] = vec_splats(v); - } - - explicit FP32Vec8() { - reg.val[0] = vec_splats(0.0f); - reg.val[1] = vec_splats(0.0f); - } - - explicit FP32Vec8(const float *ptr) { - reg.val[0] = vec_xl(0, ptr); - reg.val[1] = vec_xl(16, ptr); - } - - explicit FP32Vec8(f32x4x2_t data) : reg(data) {} - - explicit FP32Vec8(const FP32Vec8 &data) { - reg.val[0] = data.reg.val[0]; - reg.val[1] = data.reg.val[1]; - } - - explicit FP32Vec8(const BF16Vec8 &v) { - reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); - reg.val[1] = (__vector float)vec_mergel(zero, v.reg); - } - - float reduce_sum() const { - AliasReg ar; - ar.reg = reg; - float result = 0; - unroll_loop([&result, &ar](int i) { 
result += ar.values[i]; }); - - return result; - } - - FP32Vec8 exp() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::exp(ar.values[0]); - ret.val[0][1] = std::exp(ar.values[1]); - ret.val[0][2] = std::exp(ar.values[2]); - ret.val[0][3] = std::exp(ar.values[3]); - ret.val[1][0] = std::exp(ar.values[4]); - ret.val[1][1] = std::exp(ar.values[5]); - ret.val[1][2] = std::exp(ar.values[6]); - ret.val[1][3] = std::exp(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); - } - - FP32Vec8 tanh() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::tanh(ar.values[0]); - ret.val[0][1] = std::tanh(ar.values[1]); - ret.val[0][2] = std::tanh(ar.values[2]); - ret.val[0][3] = std::tanh(ar.values[3]); - ret.val[1][0] = std::tanh(ar.values[4]); - ret.val[1][1] = std::tanh(ar.values[5]); - ret.val[1][2] = std::tanh(ar.values[6]); - ret.val[1][3] = std::tanh(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); - } - - FP32Vec8 er() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::erf(ar.values[0]); - ret.val[0][1] = std::erf(ar.values[1]); - ret.val[0][2] = std::erf(ar.values[2]); - ret.val[0][3] = std::erf(ar.values[3]); - ret.val[1][0] = std::erf(ar.values[4]); - ret.val[1][1] = std::erf(ar.values[5]); - ret.val[1][2] = std::erf(ar.values[6]); - ret.val[1][3] = std::erf(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); - } - - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); - } - - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); - } - - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); - } - - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); - } - - void save(float *ptr) const { - vec_xst(reg.val[0], 0, ptr); - vec_xst(reg.val[1], 16, ptr); - } -}; - -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - union AliasReg { - f32x4x4_t reg; - float values[VEC_ELEM_NUM]; - }; - - f32x4x4_t reg; - - explicit FP32Vec16(float v) { - reg.val[0] = vec_splats(v); - reg.val[1] = vec_splats(v); - reg.val[2] = vec_splats(v); - reg.val[3] = vec_splats(v); - } - - explicit FP32Vec16() { - reg.val[0] = vec_splats(0.0f); - reg.val[1] = vec_splats(0.0f); - reg.val[2] = vec_splats(0.0f); - reg.val[3] = vec_splats(0.0f); - } - - explicit FP32Vec16(const float *ptr) { - reg.val[0] = vec_xl(0, ptr); - reg.val[1] = vec_xl(16, ptr); - reg.val[2] = vec_xl(32, ptr); - reg.val[3] = vec_xl(48, ptr); - } - - explicit FP32Vec16(f32x4x4_t data) : reg(data) {} - - explicit FP32Vec16(const FP32Vec16 &data) { - reg.val[0] = data.reg.val[0]; - reg.val[1] = data.reg.val[1]; - reg.val[2] = data.reg.val[2]; - reg.val[3] = data.reg.val[3]; - } - - explicit FP32Vec16(const FP32Vec4 &data) { - reg.val[0] = data.reg; - reg.val[1] = data.reg; - reg.val[2] = data.reg; - reg.val[3] = data.reg; - } - - explicit FP32Vec16(const FP32Vec8 &data) { - reg.val[0] = data.reg.val[0]; - reg.val[1] = data.reg.val[1]; - reg.val[2] = data.reg.val[0]; - reg.val[3] = data.reg.val[1]; - } - - explicit FP32Vec16(const BF16Vec16 &v) { - reg.val[0] = (__vector float)vec_mergeh(zero, 
v.reg.val[0]); - reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); - reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); - reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); - } - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_mul(reg.val[0], b.reg.val[0]), - vec_mul(reg.val[1], b.reg.val[1]), - vec_mul(reg.val[2], b.reg.val[2]), - vec_mul(reg.val[3], b.reg.val[3])})); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_add(reg.val[0], b.reg.val[0]), - vec_add(reg.val[1], b.reg.val[1]), - vec_add(reg.val[2], b.reg.val[2]), - vec_add(reg.val[3], b.reg.val[3])})); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_sub(reg.val[0], b.reg.val[0]), - vec_sub(reg.val[1], b.reg.val[1]), - vec_sub(reg.val[2], b.reg.val[2]), - vec_sub(reg.val[3], b.reg.val[3])})); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_div(reg.val[0], b.reg.val[0]), - vec_div(reg.val[1], b.reg.val[1]), - vec_div(reg.val[2], b.reg.val[2]), - vec_div(reg.val[3], b.reg.val[3])})); - } - - float reduce_sum() const { - AliasReg ar; - ar.reg = reg; - float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); - - return result; - } - - template float reduce_sub_sum(int idx) { - static_assert(VEC_ELEM_NUM % group_size == 0); - - AliasReg ar; - ar.reg = reg; - float result = 0; - const int start = idx * group_size; - unroll_loop( - [&result, &start, ar](int i) { result += ar.values[start + i]; }); - - return result; - } - - void save(float *ptr) const { - vec_xst(reg.val[0], 0, ptr); - vec_xst(reg.val[1], 16, ptr); - vec_xst(reg.val[2], 32, ptr); - vec_xst(reg.val[3], 48, ptr); - } -}; - -template struct VecType { using vec_type = void; }; - -template using vec_t = typename VecType::vec_type; - -template <> struct VecType { using vec_type = FP32Vec8; }; - -template <> struct VecType { using vec_type = BF16Vec8; }; - -template void storeFP32(float v, T *ptr) { *ptr = v; } - -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { - acc = acc + a * b; -} - -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); - *ptr = *(v_ptr + 1); -} - -#ifndef __VEC_CLASS_FP_NAN -#define __VEC_CLASS_FP_NAN (1 << 6) -#endif - -const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; -#ifndef _ARCH_PWR10 -const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; -const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; -const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; -const static __vector unsigned int one = { 1, 1, 1, 1 }; -#endif - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { -#ifdef _ARCH_PWR10 - __vector signed short ret[2]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); - reg = vec_perm(ret[0], ret[1], omask); -#elif defined(_ARCH_PWR9) - __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); - __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); - __vector unsigned int lsb0 = vec_sr(inp0, sh16); - __vector unsigned int lsb1 = vec_sr(inp1, sh16); - lsb0 
= vec_and(lsb0, one); - lsb1 = vec_and(lsb1, one); - __vector unsigned int rnd0 = vec_add(lsb0, bias); - __vector unsigned int rnd1 = vec_add(lsb1, bias); - inp0 = vec_add(inp0, rnd0); - inp1 = vec_add(inp1, rnd1); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); - inp0 = vec_sel(inp0, nan, sel0); - inp1 = vec_sel(inp1, nan, sel1); - inp0 = vec_sr(inp0, sh16); - inp1 = vec_sr(inp1, sh16); - reg = (__vector signed short)vec_perm(inp0, inp1, omask); -#endif -} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { -#ifdef _ARCH_PWR10 - __vector signed short ret[4]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); - ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); - ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); - reg.val[0] = vec_perm(ret[0], ret[1], omask); - reg.val[1] = vec_perm(ret[2], ret[3], omask); -#elif defined(_ARCH_PWR9) - __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); - __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); - __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); - __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); - __vector unsigned int lsb0 = vec_sr(inp0, sh16); - __vector unsigned int lsb1 = vec_sr(inp1, sh16); - __vector unsigned int lsb2 = vec_sr(inp2, sh16); - __vector unsigned int lsb3 = vec_sr(inp3, sh16); - lsb0 = vec_and(lsb0, one); - lsb1 = vec_and(lsb1, one); - lsb2 = vec_and(lsb2, one); - lsb3 = vec_and(lsb3, one); - __vector unsigned int rnd0 = vec_add(lsb0, bias); - __vector unsigned int rnd1 = vec_add(lsb1, bias); - __vector unsigned int rnd2 = vec_add(lsb2, bias); - __vector unsigned int rnd3 = vec_add(lsb3, bias); - inp0 = vec_add(inp0, rnd0); - inp1 = vec_add(inp1, rnd1); - inp2 = vec_add(inp2, rnd2); - inp3 = vec_add(inp3, rnd3); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); - __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); - __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); - inp0 = vec_sel(inp0, nan, sel0); - inp1 = vec_sel(inp1, nan, sel1); - inp2 = vec_sel(inp2, nan, sel2); - inp3 = vec_sel(inp3, nan, sel3); - inp0 = vec_sr(inp0, sh16); - inp1 = vec_sr(inp1, sh16); - inp2 = vec_sr(inp2, sh16); - inp3 = vec_sr(inp3, sh16); - reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); - reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); -#endif -} - -inline void prefetch(const void *addr) { - __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); -} - -}; // namespace vec_op - -#endif diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp deleted file mode 100644 index f50620a5287d4..0000000000000 --- a/csrc/cpu/cpu_types_x86.hpp +++ /dev/null @@ -1,515 +0,0 @@ - -#ifndef CPU_TYPES_X86_HPP -#define CPU_TYPES_X86_HPP - -#include -#include - -#ifndef __AVX2__ -static_assert(false, "AVX2 must be supported for the current implementation."); -#endif - -namespace vec_op { - -// FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
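The Power9 path converts float to bfloat16 in software with the classic round-to-nearest-even bit trick: add 0x7FFF plus the least significant kept bit to the 32-bit pattern, force NaNs to a quiet-NaN pattern, then keep the upper 16 bits. The same algorithm in scalar form is below (IEEE-754 float assumed, uint16_t used as the bf16 container).

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bf16_rne(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  if (std::isnan(v)) return 0x7FC0;            // quiet NaN, matching the vector path's constant
  const uint32_t lsb = (bits >> 16) & 1u;      // least significant bit that will be kept
  bits += 0x7FFFu + lsb;                       // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  // 1.00390625f sits exactly halfway between two bf16 values, so ties-to-even applies.
  for (float x : {1.0f, 1.00390625f, 3.14159265f}) {
    std::printf("%.8f -> 0x%04x\n", x, (unsigned)float_to_bf16_rne(x));
  }
}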
\ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) - -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) - -#ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) -#else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; -#endif - -#define FORCE_INLINE __attribute__((always_inline)) inline - -namespace { -template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { - (f(std::integral_constant{}), ...); -} -}; // namespace - -template >> -constexpr void unroll_loop(F &&f) { - unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); -} - -template struct Vec { - constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } -}; - -struct FP32Vec8; -struct FP32Vec16; - -#ifdef __AVX512FP16__ -struct FP16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128h reg; - - explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} - - explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} - - explicit FP16Vec8(__m128h data) : reg(data) {} - - FP16Vec8 operator*(const FP16Vec8 &b) const { - return FP16Vec8(_mm_mul_ph(reg, b.reg)); - } - - FP16Vec8 operator+(const FP16Vec8 &b) const { - return FP16Vec8(_mm_add_ph(reg, b.reg)); - } - - FP16Vec8 operator-(const FP16Vec8 &b) const { - return FP16Vec8(_mm_sub_ph(reg, b.reg)); - } - - FP16Vec8 operator/(const FP16Vec8 &b) const { - return FP16Vec8(_mm_div_ph(reg, b.reg)); - } - - void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } -}; -#endif - -struct BF16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128i reg; - - explicit BF16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} - - explicit BF16Vec8(const FP32Vec8 &); - - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } -}; - -struct BF16Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - __m256i reg; - - explicit BF16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} - - explicit BF16Vec16(const FP32Vec16 &); - - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } -}; - -#ifdef __AVX512F__ -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - __m512i reg; - - explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} - - explicit BF16Vec32(__m512i data) : reg(data) {} - - explicit BF16Vec32(BF16Vec8 &vec8_data) - : reg((__m512i)_mm512_inserti32x4( - _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( - (__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1), - (__m128i)vec8_data.reg, 2), - (__m128i)vec8_data.reg, 3)) {} - - void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } -}; -#else -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - __m256i reg_low; - __m256i reg_high; - - explicit BF16Vec32(const void *ptr) - : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), - reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} - - explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), - reg_high(high) {} - - explicit BF16Vec32(BF16Vec8 &vec8_data) - : reg_low((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)), - reg_high((__m256i)_mm256_inserti32x4( - 
_mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)) {} - - void save(void *ptr) const { - *reinterpret_cast<__m256i *>(ptr) = reg_low; - *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; - } -}; -#endif - -struct FP32Vec4 : public Vec { - constexpr static int VEC_ELEM_NUM = 4; - union AliasReg { - __m128 reg; - float values[VEC_ELEM_NUM]; - }; - - __m128 reg; - - explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} - - explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} - - explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} - - explicit FP32Vec4(__m128 data) : reg(data) {} - - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} -}; - -struct FP32Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - union AliasReg { - __m256 reg; - float values[VEC_ELEM_NUM]; - }; - - __m256 reg; - - explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} - - explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} - - explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} - - explicit FP32Vec8(__m256 data) : reg(data) {} - - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} - -#ifdef __AVX512FP16__ - explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} -#endif - - explicit FP32Vec8(const BF16Vec8 &v) - : reg(_mm256_castsi256_ps( - _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} - - float reduce_sum() const { - AliasReg ar; - ar.reg = reg; - float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); - - return result; - } - - FP32Vec8 exp() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), - expf(ar.values[5]), expf(ar.values[4]), - expf(ar.values[3]), expf(ar.values[2]), - expf(ar.values[1]), expf(ar.values[0]))); - } - - FP32Vec8 tanh() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), - tanhf(ar.values[5]), tanhf(ar.values[4]), - tanhf(ar.values[3]), tanhf(ar.values[2]), - tanhf(ar.values[1]), tanhf(ar.values[0]))); - } - - FP32Vec8 er() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), - erf(ar.values[5]), erf(ar.values[4]), - erf(ar.values[3]), erf(ar.values[2]), - erf(ar.values[1]), erf(ar.values[0]))); - } - - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_mul_ps(reg, b.reg)); - } - - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_add_ps(reg, b.reg)); - } - - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_sub_ps(reg, b.reg)); - } - - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_div_ps(reg, b.reg)); - } - - void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } -}; - -#ifdef __AVX512F__ -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - union AliasReg { - __m512 reg; - float values[VEC_ELEM_NUM]; - }; - - __m512 reg; - - explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} - - explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} - - explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} - - explicit FP32Vec16(__m512 data) : reg(data) {} - - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - - explicit FP32Vec16(const FP32Vec4 &data) - : reg((__m512)_mm512_inserti32x4( - _mm512_inserti32x4( - _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), - (__m128i)data.reg, 1), - (__m128i)data.reg, 2), - (__m128i)data.reg, 3)) {} - - explicit 
FP32Vec16(const FP32Vec8 &data) - : reg((__m512)_mm512_inserti32x8( - _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} - - explicit FP32Vec16(const BF16Vec16 &v) - : reg(_mm512_castsi512_ps( - _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_mul_ps(reg, b.reg)); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_add_ps(reg, b.reg)); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_sub_ps(reg, b.reg)); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_div_ps(reg, b.reg)); - } - - float reduce_sum() const { return _mm512_reduce_add_ps(reg); } - - template float reduce_sub_sum(int idx) { - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); - return _mm512_mask_reduce_add_ps(mask, reg); - } - - void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } -}; -#else -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - union AliasReg { - __m256 reg; - float values[8]; - }; - - __m256 reg_low; - __m256 reg_high; - - explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), - reg_high(_mm256_set1_ps(v)) {} - - explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), - reg_high(_mm256_set1_ps(0.0)) {} - - explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), - reg_high(_mm256_loadu_ps(ptr + 8)) {} - - explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} - - explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), - reg_high(data.reg_high) {} - - explicit FP32Vec16(const FP32Vec4 &data) - : reg_low((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)), - reg_high((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)) {} - - explicit FP32Vec16(const FP32Vec8 &data) - : reg_low(data.reg), reg_high(data.reg) {} - - explicit FP32Vec16(const BF16Vec16 &v) { - __m128i low = _mm256_extractf128_si256(v.reg, 0); - __m128i high = _mm256_extractf128_si256(v.reg, 1); - - __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); - __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); - - __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); - __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); - - reg_low = _mm256_castsi256_ps(v_low_shifted); - reg_high = _mm256_castsi256_ps(v_high_shifted); - } - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), - _mm256_mul_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), - _mm256_add_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), - _mm256_sub_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), - _mm256_div_ps(reg_high, b.reg_high)); - } - - float reduce_sum() const { - FP32Vec8 low = FP32Vec8(reg_low); - FP32Vec8 high = FP32Vec8(reg_high); - return low.reduce_sum() + high.reduce_sum(); - } - - template float reduce_sub_sum(int 
idx) { - float sum = 0.0; - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - uint32_t mask = base_mask << (idx * group_size); - - AliasReg ar; - - auto func = [&sum, &mask, &ar](int i) { - int flag = mask & 0x1; - mask = mask >> 1; - if (flag != 0) sum += ar.values[i]; - }; - - ar.reg = reg_low; - unroll_loop(func); - - ar.reg = reg_high; - unroll_loop(func); - - return sum; - } - - void save(float *ptr) const { - _mm256_storeu_ps(ptr, reg_low); - _mm256_storeu_ps(ptr + 8, reg_high); - } -}; -#endif - -template struct VecType { using vec_type = void; }; - -template using vec_t = typename VecType::vec_type; - -template <> struct VecType { using vec_type = FP32Vec8; }; - -#ifdef __AVX512FP16__ -template <> struct VecType { using vec_type = FP16Vec16; }; -#endif - -template <> struct VecType { using vec_type = BF16Vec8; }; - -template void storeFP32(float v, T *ptr) { *ptr = v; } - -#ifdef __AVX512FP16__ -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<_Float16 *>(ptr) = v; -} -#endif - -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { - acc = acc + a * b; -} - -#ifdef __AVX512BF16__ -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} - -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { - acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); -} -#else -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); - *ptr = *(v_ptr + 1); -} - -#ifdef __AVX512F__ -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg(_mm256_cvtepi32_epi16( - _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg(_mm512_cvtepi32_epi16( - _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#else -namespace{ -__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { - __m256i ai = _mm256_castps_si256(a); - ai = _mm256_srli_epi32(ai, 16); - ai = _mm256_packus_epi32(ai, ai); - ai = _mm256_permute4x64_epi64(ai, 0b00111001); - return _mm256_extracti128_si256(ai, 0); -} -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { - BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); - BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); - reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); -} -#endif // __AVX512F__ -#endif // __AVX512BF16__ - -inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } - -}; // namespace vec_op - -#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index a76ad08928a2c..467f0dc84982c 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t* __restrict__ out, - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int 
VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t* __restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t* __restrict__ input, - scalar_t* __restrict__ residual, - const scalar_t* __restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t *__restrict__ input, + scalar_t *__restrict__ residual, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - double epsilon) { +void rms_norm(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, double epsilon) { +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 96bce7dda0132..e9b3992204bb2 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,107 +4,107 @@ namespace { template void rotary_embedding_impl( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, - /// head_size] or [num_tokens, num_heads, - /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; - bool flag = (embed_dim % VEC_ELEM_NUM == 0); - const int loop_upper = flag ? 
embed_dim : embed_dim - VEC_ELEM_NUM; + TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); - auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, - scalar_t* qk) { - int j = 0; - for (; j < loop_upper; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); - const scalar_vec_t q_x(qk + out_x); - const scalar_vec_t q_y(qk + out_y); + const scalar_vec_t q_x(query + out_x); + const scalar_vec_t q_y(query + out_y); - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); - vec_op::FP32Vec8 fp32_q_x(q_x); - vec_op::FP32Vec8 fp32_q_y(q_y); + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); - auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; - scalar_vec_t(out1).save(qk + out_x); + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(query + out_x); - auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; - scalar_vec_t(out2).save(qk + out_y); + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(query + out_y); + } } - if (!flag) { - for (; j < embed_dim; ++j) { - const int x_index = j; - const int y_index = embed_dim + j; + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; const int64_t out_x = token_head + x_index; const int64_t out_y = token_head + y_index; - const float fp32_cos = cache_ptr[x_index]; - const float fp32_sin = cache_ptr[y_index]; - - const float fp32_q_x = qk[out_x]; - const float fp32_q_y = qk[out_y]; + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); - qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; - qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; - } - } - }; + const scalar_vec_t k_x(key + out_x); + const scalar_vec_t k_y(key + out_y); -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); - for (int i = 0; i < num_heads; ++i) { - const int head_idx = i; - const int64_t token_head = - token_idx * query_stride + head_idx * head_size; - compute_loop(token_head, cache_ptr, query); - } + vec_op::FP32Vec8 fp32_k_x(k_x); + vec_op::FP32Vec8 fp32_k_y(k_y); - for (int i = 0; 
i < num_kv_heads; ++i) { - const int head_idx = i; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - compute_loop(token_head, cache_ptr, key); + auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; + scalar_vec_t(out1).save(key + out_x); + auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; + scalar_vec_t(out2).save(key + out_y); + } } } } template void rotary_embedding_gptj_impl( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, - /// head_size] or [num_tokens, num_heads, - /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -114,13 +114,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t* cos_cache_ptr = cache_ptr; - const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t* head_query = token_head + query; + scalar_t *head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -142,12 +142,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t* cos_cache_ptr = cache_ptr; - const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t* head_key = key + token_head; + scalar_t *head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -165,11 +165,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int 
num_heads = query.size(-1) / head_size; diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp new file mode 100644 index 0000000000000..bba044087f37c --- /dev/null +++ b/csrc/cpu/pybind.cpp @@ -0,0 +1,73 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_and_mul", + &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def( + "gelu_tanh_and_mul", + &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + "fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp deleted file mode 100644 index 39e8cf3ed3c10..0000000000000 --- a/csrc/cpu/torch_bindings.cpp +++ /dev/null @@ -1,110 +0,0 @@ -#include "cache.h" -#include "ops.h" -#include "registration.h" - -#include - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - // vLLM custom ops - - // Attention ops - // Compute the attention between an input query and the cached keys/values - // using PagedAttention. - ops.def( - "paged_attention_v1(" - " Tensor! out, Tensor query, Tensor key_cache," - " Tensor value_cache, int num_kv_heads, float scale," - " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," - " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); - ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); - - // PagedAttention V2. - ops.def( - "paged_attention_v2(" - " Tensor! out, Tensor exp_sums, Tensor max_logits," - " Tensor tmp_out, Tensor query, Tensor key_cache," - " Tensor value_cache, int num_kv_heads, float scale," - " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," - " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); - ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); - - // Activation ops - - // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); - - // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); - - // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); - - // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_new", torch::kCPU, &gelu_new); - - // Approximate GELU implementation. - ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_fast", torch::kCPU, &gelu_fast); - - // Quick GELU implementation. - ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_quick", torch::kCPU, &gelu_quick); - - // Layernorm - // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def( - "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " - "()"); - ops.impl("rms_norm", torch::kCPU, &rms_norm); - - // In-place fused Add and RMS Normalization. - ops.def( - "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " - "float epsilon) -> ()"); - ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); - - // Rotary embedding - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def( - "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," - " Tensor cos_sin_cache, bool is_neox) -> ()"); - ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); -} - -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { - // Cache ops - // Swap in (out) the cache blocks from src to dst. - cache_ops.def( - "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); - cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); - - // Copy the cache blocks from src to dst. - cache_ops.def( - "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " - "block_mapping) -> ()"); - cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); - - // Reshape the key and value tensors and cache them. - cache_ops.def( - "reshape_and_cache(Tensor key, Tensor value," - " Tensor! key_cache, Tensor! 
value_cache," - " Tensor slot_mapping," - " str kv_cache_dtype," - " float kv_scale) -> ()"); - cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 82e55613d915a..c711d8d1b24b9 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM - #include +#include #endif #ifndef USE_ROCM @@ -17,14 +17,9 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ - __shfl_xor_sync(uint32_t(-1), var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ - __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) - #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ - __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM @@ -33,13 +28,6 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif -#ifndef USE_ROCM - #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ - __shfl_down_sync(uint32_t(-1), var, lane_delta) -#else - #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) -#endif - #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) @@ -47,3 +35,4 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif + diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 73944f4c14890..1483484faeb4a 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,10 @@ #pragma once -int64_t get_device_attribute(int64_t attribute, int64_t device_id); +#include -int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); +int get_device_attribute( + int attribute, + int device_id); + +int get_max_shared_memory_per_block_device_attribute( + int device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index d6f9eb646fad5..1a443ef3620cc 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,28 +2,34 @@ #include #include #endif -int64_t get_device_attribute(int64_t attribute, int64_t device_id) { - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), - device); - return value; +int get_device_attribute( + int attribute, + int device_id) +{ + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } + else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), device); + return value; } -int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { - int64_t attribute; - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html - // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 + +int get_max_shared_memory_per_block_device_attribute( + int device_id) +{ +int attribute; +// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html +// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = 
cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16..3906dcfc80dbf 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,17 +1,17 @@ #include #include #include -#include +#include #include "custom_all_reduce.cuh" -// fake pointer type, must match fptr_t type in ops.h -using fptr_t = int64_t; -static_assert(sizeof(void*) == sizeof(fptr_t)); +// fake pointer type +using fptr_t = uint64_t; +static_assert(sizeof(void *) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor& t) { +bool _is_weak_contiguous(torch::Tensor &t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,27 +67,28 @@ bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, +void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -97,7 +98,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { const 
at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -105,8 +106,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -121,33 +122,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } -int64_t meta_size() { return sizeof(vllm::Signal); } +int meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } -std::tuple> get_graph_buffer_ipc_meta( +std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handles = - torch::empty({static_cast(handle_bytes.size())}, options); - std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); - return {handles, std::move(offsets)}; + auto fa = reinterpret_cast(_fa); + return fa->get_graph_buffer_ipc_meta(); } -void register_graph_buffers(fptr_t _fa, const std::vector& handles, - const std::vector>& offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9cae..750e68d42f6c6 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { volatile Signal *signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half& assign_add(half& a, half b) { +DINLINE half &assign_add(half &a, half b) { a = __hadd(a, b); return a; } -DINLINE float& assign_add(float& a, float b) { return a += b; } +DINLINE float &assign_add(float &a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { +DINLINE nv_bfloat16 
&assign_add(nv_bfloat16 &a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +DINLINE array_t &packed_assign_add(array_t &a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, +DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,7 +137,8 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); + while (!self_sg->start[blockIdx.x][threadIdx.x]) + ; } __syncthreads(); } @@ -146,13 +147,13 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, +DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -161,13 +162,14 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + while (!self_sg->end[blockIdx.x][threadIdx.x]) + ; } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P* ptrs[], int idx) { +DINLINE P packed_reduce(const P *ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -178,8 +180,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, + cross_device_reduce_1stage(RankData *_dp, RankSignals sg, + volatile Signal *self_sg, T *__restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -190,20 +192,21 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + ((P *)result)[idx] = + packed_reduce((const P **)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { - return (P*)(((Signal*)sg) + 1); +DINLINE P *get_tmp_buf(volatile Signal *sg) { + return (P *)(((Signal *)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, + 
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, + volatile Signal *self_sg, T *__restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -213,12 +216,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P* ptrs[ngpus]; - P* tmps[ngpus]; + const P *ptrs[ngpus]; + P *tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P*)_dp->ptrs[target]; + ptrs[i] = (const P *)_dp->ptrs[target]; tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -240,7 +243,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P*)result)[dst_idx] = tmps[i][idx]; + ((P *)result)[dst_idx] = tmps[i][idx]; } } } @@ -258,14 +261,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal* self_sg_; + std::unordered_map buffers_; + Signal *self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -276,22 +279,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t* handles, - const std::vector& offsets, int rank, + CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t *handles, + const std::vector &offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal* rank_sg; + Signal *rank_sg; if (i != rank_) { - char* handle = open_ipc_handle(&handles[i]); + char *handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal*)handle; + rank_sg = (Signal *)handle; } else { rank_sg = self_sg_; } @@ -299,13 +302,13 @@ class CustomAllreduce { } } - char* open_ipc_handle(const void* ipc_handle) { + char *open_ipc_handle(const void *ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); if (new_handle) { - char* ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, - *((const cudaIpcMemHandle_t*)ipc_handle), + char *ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, + *((const cudaIpcMemHandle_t *)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -320,7 +323,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void* base_ptr; + void *base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -328,8 +331,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char*)ptr) - ((char*)base_ptr); + (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char *)ptr) - ((char *)base_ptr); } return std::make_pair(handles, offsets); } @@ -341,13 +344,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector& handles, - const std::vector& offsets, void* self) { + void register_buffer(const std::vector &handles, + const std::vector &offsets, void 
*self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char* handle = open_ipc_handle(handles[i].data()); + char *handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -368,17 +371,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector& handles, - const std::vector>& offsets) { + const std::vector &handles, + const std::vector> &offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto& rd = rank_data[i]; + auto &rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char* handle = + char *handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -402,7 +405,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T* input, T* output, int size, + void allreduce(cudaStream_t stream, T *input, T *output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -415,7 +418,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData* ptrs; + RankData *ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076cd..c34a50389c21c 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T* data, int size, int myRank) { +__global__ void set_data(T *data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T* data, int size, int myRank) { } template -__global__ void convert_data(const T* data1, const T* data2, double* fdata1, - double* fdata2, int size) { +__global__ void convert_data(const T *data1, const T *data2, double *fdata1, + double *fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T* data1, const T* data2, double* fdata1, } } -__global__ void init_rand(curandState_t* state, int size, int nRanks) { +__global__ void init_rand(curandState_t *state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t* state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, +__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, +void run(int 
myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, int data_size, bool performance_test) { - T* result; + T *result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal* buffer; - T* self_data_copy; + vllm::Signal *buffer; + T *self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void* rank_data; + void *rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto* self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto *self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char* begin = (char*)&data_handles[i]; - char* end = (char*)&data_handles[i + 1]; + char *begin = (char *)&data_handles[i]; + char *end = (char *)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double* ground_truth; + double *ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t* states; + curandState_t *states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char** argv) { +int main(int argc, char **argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char** argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3d4886..91abd9e85b4bb 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,32 +4,34 @@ */ #pragma once -#include +#include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ - VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ca1c04bd880d9..e56b4d2204005 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -11,24 +11,26 @@ #include #include -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; + using __nv_bfloat16 = __hip_bfloat16; + using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. 
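// For reference, a minimal single-threaded sketch of the per-token computation that
// rms_norm_kernel below performs; the names in this sketch are illustrative and are
// not part of the diff:
//
//   for (int token = 0; token < num_tokens; ++token) {
//     const float* x = input + token * hidden_size;
//     float sum_sq = 0.0f;
//     for (int i = 0; i < hidden_size; ++i) sum_sq += x[i] * x[i];
//     const float inv_rms = 1.0f / sqrtf(sum_sq / hidden_size + epsilon);
//     for (int i = 0; i < hidden_size; ++i)
//       out[token * hidden_size + i] = (x[i] * inv_rms) * weight[i];
//   }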
-template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, const int num_tokens, const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * hidden_size + idx]; + const float x = (float) input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -38,12 +40,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + float x = (float) input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; } } + /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -52,68 +54,51 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. 
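   As a usage sketch (device code; the name cvt below is illustrative and not part of
   the diff), a fully defined specialization is called like:

     using cvt = _typeConvert<c10::Half>;
     __half  h = cvt::convert(1.0f);                      // float  -> __half, round-to-nearest
     float   f = cvt::convert(h);                         // __half -> float
     __half2 p = cvt::convert(make_float2(1.0f, 2.0f));   // packed float2 -> __half2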
*/ -template -struct _typeConvert { - static constexpr bool exists = false; -}; +template +struct _typeConvert { static constexpr bool exists = false; }; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template <> +template<> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { - return __half22float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2half_rn(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22half2_rn(x); - } + __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } + __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } + __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }; - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template <> +template<> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { - return __bfloat162float(x); - } - __device__ static inline float2 convert(packed_hip_type x) { - return __bfloat1622float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2bfloat16(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22bfloat162_rn(x); - } + __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } + __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } + __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } + __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; - #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= - // 12000)) +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. 
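   For example, with scalar_t = half (2 bytes per element) and width = 8, one _f16Vec
   holds 8 * 2 = 16 bytes, so a single aligned load or store of the struct moves the
   full 128 bits mentioned above; the static_asserts in the vectorized kernel below
   check exactly this sizeof(scalar_t) * width relation.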
*/ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -123,49 +108,51 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; + T2 temp{data[i], data[i+1]}; + temp += T2{other.data[i], other.data[i+1]}; data[i] = temp.x; - data[i + 1] = temp.y; + data[i+1] = temp.y; } } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] += other.data[i]; + #pragma unroll + for (int i = 0; i < width; ++i) + data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; + T2 temp{data[i], data[i+1]}; + temp *= T2{other.data[i], other.data[i+1]}; data[i] = temp.x; - data[i + 1] = temp.y; + data[i+1] = temp.y; } } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + #pragma unroll + for (int i = 0; i < width; ++i) + data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i + 1] = temp.y; + data[i+1] = temp.y; } } else { -#pragma unroll + #pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -177,13 +164,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); + float2 z = Converter::convert(T2{data[i], data[i+1]}); result += z.x * z.x + z.y * z.y; } } else { -#pragma unroll + #pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -197,13 +184,15 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. 
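   The choice between this specialization and the generic kernel further below is made
   through the return type: std::enable_if_t<cond> is void when cond is true and is
   ill-formed otherwise, so for a given (scalar_t, width) only one overload of
   fused_add_rms_norm_kernel survives overload resolution. A minimal sketch of the same
   pattern, with illustrative names and a simplified condition:

     template <typename T, int width>
     __global__ std::enable_if_t<(width > 0)>  kernel(T* data);   // vectorized path
     template <typename T, int width>
     __global__ std::enable_if_t<(width == 0)> kernel(T* data);   // generic fallback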
*/ -template -__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> -fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, const int num_tokens, const int hidden_size) { +template +__global__ std::enable_if_t< + (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -214,12 +203,9 @@ fused_add_rms_norm_kernel( /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = - reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = - reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = - reinterpret_cast*>(weight); + auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -229,11 +215,10 @@ fused_add_rms_norm_kernel( residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + } else variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -248,50 +233,52 @@ fused_add_rms_norm_kernel( } } + /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. 
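The vectorized kernel above and the generic kernel that follows share one name and are selected purely through their `std::enable_if_t` return types: exactly one overload survives substitution for a given scalar type and width, so the launcher can simply instantiate the packed width or width 0. A stripped-down sketch of the same mechanism, with hypothetical trait and kernel names, is shown here for orientation only.

#include <cuda_fp16.h>
#include <type_traits>

// Hypothetical trait standing in for _typeConvert<T>::exists.
template <typename T>
struct has_packed_ops : std::false_type {};
template <>
struct has_packed_ops<__half> : std::true_type {};

// Packed overload: participates only when width > 0 and the type has
// packed conversions.
template <typename T, int width>
__global__ std::enable_if_t<(width > 0) && has_packed_ops<T>::value>
demo_kernel(T* data) { /* vectorized body */ }

// Fallback overload: participates in all remaining cases.
template <typename T, int width>
__global__ std::enable_if_t<(width == 0) || !has_packed_ops<T>::value>
demo_kernel(T* data) { /* scalar body */ }
// (illustrative sketch only; the patch continues below)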
*/ -template -__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> -fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, const int num_tokens, const int hidden_size) { +template +__global__ std::enable_if_t< + (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float)z; + float x = (float) z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + } else variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + float x = (float) residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - double epsilon) { +void rms_norm( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -299,27 +286,40 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "rms_norm_kernel", + [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>(input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), epsilon, \ - num_tokens, hidden_size); \ - }); - -void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - double epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), \ + "fused_add_rms_norm_kernel", \ 
+ [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>( \ + input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size); \ + }); + +void fused_add_rms_norm( + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = - inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ + && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp new file mode 100644 index 0000000000000..35c328499a22d --- /dev/null +++ b/csrc/moe/moe_ops.cpp @@ -0,0 +1,7 @@ +#include "moe_ops.h" + +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a251730aa765a..a01be3e426d72 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,7 +1,9 @@ #pragma once -#include +#include -void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax( + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b602524..8c65f40fe836a 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,25 +16,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include -#include "../cuda_compat.h" -#ifndef USE_ROCM - #include - #include -#else - #include - #include -#endif - -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#include +#include namespace vllm { namespace moe { +static constexpr int WARP_SIZE = 32; + /// Aligned array type template < typename T, @@ -272,7 +265,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); } // From this point, thread max in all the threads have the max within the row. 
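The reduction loop in the hunk above is a standard warp-level butterfly: each step shuffles with a shrinking XOR mask, so after log2(THREADS_PER_ROW) exchanges every lane in the row group holds the row maximum (the next hunk uses the same pattern for the row sum). A standalone sketch of the pattern, with a hypothetical helper name, is given here as a reference.

// Hypothetical helper showing the butterfly reduction used for the per-row
// maximum above (and the row sum in the following hunk).
__device__ float warp_row_max(float v, int threads_per_row) {
#pragma unroll
  for (int mask = threads_per_row / 2; mask > 0; mask /= 2) {
    v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, mask, threads_per_row));
  }
  return v;  // every lane in the row group now holds the row maximum
}
// (illustrative sketch only; the patch continues below)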
@@ -289,7 +282,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -339,8 +332,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); - int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -390,7 +383,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -403,7 +396,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp deleted file mode 100644 index 243752b9a9e8c..0000000000000 --- a/csrc/moe/torch_bindings.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "registration.h" -#include "moe_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - // Apply topk softmax to the gating outputs. - m.def( - "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" - "token_expert_indices, Tensor gating_output) -> ()"); - m.impl("topk_softmax", torch::kCUDA, &topk_softmax); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index 1f8d75da83bb8..e01b23685ef4e 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -7,128 +7,119 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} } -} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* expert_ids, - int32_t* total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = - shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = - shared_mem + (num_experts + 1) * - num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. 
- if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; +__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, + int32_t *sorted_token_ids, + int32_t *expert_ids, + int32_t *total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are assigned + * to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding blocks + * and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { + expert_ids[i / block_size] = threadIdx.x; } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. 
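The kernel comments above walk through the algorithm, and the worked example they give (topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4) can be reproduced with a single-threaded CPU sketch, which may be easier to follow than the shared-memory version. This is illustrative only: the kernel assigns one thread per expert and builds the same cumulative sums in shared memory, and -1 stands in for the padding value the comment writes as *.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int num_experts = 5, block_size = 4, pad = -1;

  // Per-expert counts, then round each expert's segment up to block_size.
  std::vector<int> cnt(num_experts, 0);
  for (int e : topk_ids) ++cnt[e];
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] = cumsum[e] + (cnt[e] + block_size - 1) / block_size * block_size;

  // Scatter token indices into their expert's padded segment.
  std::vector<int> sorted(cumsum[num_experts], pad);
  std::vector<int> fill(num_experts, 0);
  for (int i = 0; i < (int)topk_ids.size(); ++i) {
    const int e = topk_ids[i];
    sorted[cumsum[e] + fill[e]++] = i;
  }
  // Prints: 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  for (int v : sorted) std::printf("%d ", v);
  std::printf("\n");
}
// (illustrative sketch only; the patch continues below)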
- */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } + + /** + * Each thread processes a token shard, calculating the index of each token after + * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and + * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], + * where * represents a padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] + * stores the indices of the tokens processed by the expert with expert_id within + * the current thread's token shard. + */ + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} } -} // namespace vllm - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t shared_mem = - ((num_experts + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); + +void moe_align_block_size( + torch::Tensor topk_ids, + int num_experts, + int block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors + const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); + AT_CUDA_CHECK( + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 8a92afdc81a9b..9541adcb3de88 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,152 +1,206 @@ #pragma once -#include -#include +#include void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const 
int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& seq_lens, + int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, + float kv_scale); void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); - -void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - double epsilon); - -void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, double epsilon); - -void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox); - -void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox, - int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); - -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); - -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); - -void gelu_new(torch::Tensor& out, torch::Tensor& input); - -void gelu_fast(torch::Tensor& out, torch::Tensor& input); - -void gelu_quick(torch::Tensor& out, torch::Tensor& input); + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& seq_lens, + int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, + float kv_scale); + +void rms_norm( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm( + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + float epsilon); + +void rotary_embedding( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache, + bool is_neox); + +void batched_rotary_embedding( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache, + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_and_mul( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_tanh_and_mul( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_new( + torch::Tensor& out, + torch::Tensor& input); + +void gelu_fast( + torch::Tensor& out, + torch::Tensor& input); #ifndef USE_ROCM -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - 
const torch::Tensor& codebook_partition_sizes, - const std::optional& bias); - -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes); - -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters); - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy); - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); - -void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - c10::optional const& bias); - +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +); + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +); + +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); + +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + int thx, + int thy); + +torch::Tensor marlin_gemm( + torch::Tensor& a, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + torch::Tensor& workspace, + int64_t size_m, + int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm( + torch::Tensor &a, + torch::Tensor &b_q_weight, + torch::Tensor &b_scales, + torch::Tensor &g_idx, + torch::Tensor &perm, + torch::Tensor &workspace, + int64_t num_bits, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full); + +torch::Tensor gptq_marlin_repack( + torch::Tensor &b_q_weight, + torch::Tensor &perm, + int64_t size_k, + int64_t size_n, + int64_t num_bits); #endif -void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); - -void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); - -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); - -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); - -void 
static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scale); - -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scale); - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void squeezellm_gemm( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama, + int bit); + +void gptq_shuffle( + torch::Tensor q_weight, + torch::Tensor q_perm, + int bit); + +void static_scaled_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + +void moe_align_block_size( + torch::Tensor topk_ids, + int num_experts, + int block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM -using fptr_t = int64_t; -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, +using fptr_t = uint64_t; +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out); +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out); void dispose(fptr_t _fa); -int64_t meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets); -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, - const std::vector>& offsets); +int meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets); #endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 97184a8735593..d80cb6973fad6 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -7,10 +7,14 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) +{ int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -33,17 +37,19 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline 
__device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* cache_ptr, const int head_size, const int num_heads, - const int num_kv_heads, const int rot_dim, const int token_idx, - const int64_t query_stride, const int64_t key_stride) { + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -53,8 +59,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( - query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -62,74 +68,62 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( - key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. 
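For orientation, `apply_token_rotary_embedding` earlier in this file applies one 2-D rotation per (x, y) pair; the only difference between the two styles is which elements form a pair: GPT-NeoX pairs index i with i + embed_dim inside a head, GPT-J pairs 2i with 2i + 1. A scalar sketch of one pair follows, with hypothetical naming; the y-update matches the `arr[y_index] = y * cos + x * sin` line visible in the hunk above.

// Hypothetical scalar helper for one rotary pair.
template <typename scalar_t, bool IS_NEOX>
void rotate_pair_ref(scalar_t* head, float cos_v, float sin_v,
                     int rot_offset, int embed_dim) {
  // GPT-NeoX pairs (i, i + embed_dim); GPT-J pairs (2i, 2i + 1).
  const int x_index = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int y_index = IS_NEOX ? rot_offset + embed_dim : 2 * rot_offset + 1;
  const float x = static_cast<float>(head[x_index]);
  const float y = static_cast<float>(head[y_index]);
  head[x_index] = static_cast<scalar_t>(x * cos_v - y * sin_v);
  head[y_index] = static_cast<scalar_t>(y * cos_v + x * sin_v);
}
// (illustrative sketch only; the patch continues below)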
const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - // or [num_tokens] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. 
const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = - cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -138,24 +132,39 @@ void rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, - query_stride, key_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), + "rotary_embedding", + [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } + }); } /* @@ -163,15 +172,14 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. 
*/ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -180,24 +188,39 @@ void batched_rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), + "rotary_embedding", + [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } + }); } diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2c8d007d8719f..19c058cacfbc4 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -16,68 +16,44 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ - f(in_T, out_T, W_T, narrow, 896) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ - f(in_T, out_T, W_T, narrow, 1216) \ f(in_T, out_T, W_T, narrow, 1280) \ f(in_T, out_T, W_T, narrow, 1536) \ - f(in_T, out_T, W_T, narrow, 1664) \ f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 2048) \ - f(in_T, out_T, W_T, narrow, 2240) \ f(in_T, out_T, W_T, narrow, 2304) \ - f(in_T, out_T, W_T, narrow, 2368) \ - f(in_T, out_T, W_T, narrow, 2432) \ 
f(in_T, out_T, W_T, narrow, 2560) \ f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 3072) \ - f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ - f(in_T, out_T, W_T, narrow, 3712) \ f(in_T, out_T, W_T, narrow, 4096) \ - f(in_T, out_T, W_T, narrow, 4480) \ f(in_T, out_T, W_T, narrow, 4608) \ - f(in_T, out_T, W_T, narrow, 4736) \ - f(in_T, out_T, W_T, narrow, 4864) \ f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ - f(in_T, out_T, W_T, narrow, 5888) \ f(in_T, out_T, W_T, narrow, 6144) \ - f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ - f(in_T, out_T, W_T, narrow, 7424) \ f(in_T, out_T, W_T, narrow, 8192) \ - f(in_T, out_T, W_T, narrow, 8960) \ f(in_T, out_T, W_T, narrow, 9216) \ - f(in_T, out_T, W_T, narrow, 9472) \ f(in_T, out_T, W_T, narrow, 10240) \ f(in_T, out_T, W_T, narrow, 11008) \ - f(in_T, out_T, W_T, narrow, 11264) \ f(in_T, out_T, W_T, narrow, 12288) \ f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ - f(in_T, out_T, W_T, narrow, 14784) \ - f(in_T, out_T, W_T, narrow, 14848) \ f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ - f(in_T, out_T, W_T, narrow, 18944) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ - f(in_T, out_T, W_T, narrow, 22528) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ - f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ - f(in_T, out_T, W_T, narrow, 29568) \ - f(in_T, out_T, W_T, narrow, 29696) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32512) \ @@ -86,9 +62,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ - f(in_T, out_T, W_T, narrow, 49408) \ - f(in_T, out_T, W_T, narrow, 60544) \ - f(in_T, out_T, W_T, narrow, 60672) \ f(in_T, out_T, W_T, narrow, 64000) \ f(in_T, out_T, W_T, narrow, 64256) \ f(in_T, out_T, W_T, narrow, 64512) \ @@ -98,14 +71,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128000) \ f(in_T, out_T, W_T, narrow, 128256) \ f(in_T, out_T, W_T, narrow, 128512) \ - - // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // and vllm/tests/lora/test_punica.py -// Used for defining kernels going from the variety of +// Used for defining kernels going from the variety of // dim in to the narrow dim out - // Using it for the fully sharded column + // Using it for the fully sharded column // parallel LoRA A which splits the rank dim #define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ f(in_T, out_T, W_T, 128, narrow) \ @@ -113,68 +84,44 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 512, narrow) \ f(in_T, out_T, W_T, 640, narrow) \ f(in_T, out_T, W_T, 768, narrow) \ - f(in_T, out_T, W_T, 896, narrow) \ f(in_T, out_T, W_T, 1024, narrow) \ f(in_T, out_T, W_T, 1152, narrow) \ - f(in_T, out_T, W_T, 1216, narrow) \ f(in_T, out_T, W_T, 1280, narrow) \ f(in_T, out_T, W_T, 1536, narrow) \ - f(in_T, out_T, W_T, 1664, narrow) \ f(in_T, out_T, W_T, 1728, narrow) \ f(in_T, out_T, W_T, 1792, narrow) \ f(in_T, out_T, W_T, 2048, 
narrow) \ - f(in_T, out_T, W_T, 2240, narrow) \ f(in_T, out_T, W_T, 2304, narrow) \ - f(in_T, out_T, W_T, 2368, narrow) \ - f(in_T, out_T, W_T, 2432, narrow) \ f(in_T, out_T, W_T, 2560, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \ - f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ - f(in_T, out_T, W_T, 3712, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ - f(in_T, out_T, W_T, 4480, narrow) \ f(in_T, out_T, W_T, 4608, narrow) \ - f(in_T, out_T, W_T, 4736, narrow) \ - f(in_T, out_T, W_T, 4864, narrow) \ f(in_T, out_T, W_T, 5120, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ - f(in_T, out_T, W_T, 5888, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \ - f(in_T, out_T, W_T, 6400, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ - f(in_T, out_T, W_T, 7424, narrow) \ f(in_T, out_T, W_T, 8192, narrow) \ - f(in_T, out_T, W_T, 8960, narrow) \ f(in_T, out_T, W_T, 9216, narrow) \ - f(in_T, out_T, W_T, 9472, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ f(in_T, out_T, W_T, 11008, narrow) \ - f(in_T, out_T, W_T, 11264, narrow) \ f(in_T, out_T, W_T, 12288, narrow) \ f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ f(in_T, out_T, W_T, 14336, narrow) \ - f(in_T, out_T, W_T, 14784, narrow) \ - f(in_T, out_T, W_T, 14848, narrow) \ f(in_T, out_T, W_T, 15360, narrow) \ f(in_T, out_T, W_T, 16384, narrow) \ - f(in_T, out_T, W_T, 18944, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \ - f(in_T, out_T, W_T, 22528, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ - f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ - f(in_T, out_T, W_T, 29568, narrow) \ - f(in_T, out_T, W_T, 29696, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32512, narrow) \ @@ -183,9 +130,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 36864, narrow) \ f(in_T, out_T, W_T, 43264, narrow) \ f(in_T, out_T, W_T, 49152, narrow) \ - f(in_T, out_T, W_T, 49408, narrow) \ - f(in_T, out_T, W_T, 60544, narrow) \ - f(in_T, out_T, W_T, 60672, narrow) \ f(in_T, out_T, W_T, 64000, narrow) \ f(in_T, out_T, W_T, 64256, narrow) \ f(in_T, out_T, W_T, 64512, narrow) \ diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 8a3b8403b4a6f..dad8805c750cb 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,14 +1,8 @@ #pragma once #include -#ifndef USE_ROCM #include -#else -#include -#endif -#ifndef USE_ROCM #include -#endif #include #include #include @@ -17,24 +11,6 @@ namespace cg = cooperative_groups; -#ifdef USE_ROCM -template -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; -#pragma unroll - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} -#endif - -#ifndef USE_ROCM - // nthrs = (32, 4) template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = 
blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - size_t j = blockIdx.x; - constexpr size_t tile_size = tx * ty * vec_size; - constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; - __shared__ float y_warpwise[ty]; - - float y = 0; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - x_vec.load(X + (batch_idx * feat_in) + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - } - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += VLLM_SHFL_DOWN_SYNC(sum, offset); - } - - __syncthreads(); - - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - y += sum; - } - } - - if (threadIdx.x == 0) { - y_warpwise[threadIdx.y] = y; - } - __syncthreads(); - - float y_write = 0.f; -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y_write += y_warpwise[i]; - } - - // write Y; - if (threadIdx.x == 0 && threadIdx.y == 0) { - size_t y_idx = batch_idx * full_y_size + y_offset + j; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); - } -} - -#endif - // nthrs = (2, 16, 4) template @@ -271,11 +172,7 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { -#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; -#else - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; -#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -286,14 +183,8 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { -#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); -#else - size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); -#endif } } @@ -345,7 +236,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { -#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -389,50 +279,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } -#else - constexpr size_t rocm_warp_size = warpSize; - -#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ - feat_in % (rocm_warp_size * vec_size_) == 0 - -#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ - if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ - constexpr size_t vec_size_shrink = vec_size_; \ - constexpr int tx = tx_; \ - constexpr int ty = ty_; \ - dim3 nblks(feat_out, batch_size); \ - dim3 nthrs(tx, ty); \ - bgmv_shrink_kernel \ - <<>>(Y, X, W, indicies, y_offset, \ - full_y_size, num_layers, layer_idx, \ - scale); \ - } - - static_assert(CHECK_INPUT_TILEABLE_BY(32) || - CHECK_INPUT_TILEABLE_BY(16) || - CHECK_INPUT_TILEABLE_BY( 8) || - CHECK_INPUT_TILEABLE_BY( 4) || - CHECK_INPUT_TILEABLE_BY( 2) || - CHECK_INPUT_TILEABLE_BY( 1)); 
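For orientation, the bgmv shrink and expand kernels touched in this file implement the same batched gather-then-matvec contract: each sequence picks one LoRA weight slice by index and accumulates a scaled matrix-vector product into its output row. The CPU reference below is a simplification under stated assumptions: a single layer slice, float tensors, and a hypothetical function name; the real kernels index W by `indices[b] * num_layers + layer_idx` and skip rows whose index is negative.

#include <cstdint>

// Hypothetical CPU reference, simplified to one layer slice:
//   Y[b, :] += scale * W[indices[b], :, :] @ X[b, :]
void bgmv_ref(float* Y, const float* X, const float* W,
              const int64_t* indices, int batch_size,
              int feat_in, int feat_out, float scale) {
  for (int b = 0; b < batch_size; ++b) {
    const int64_t lora = indices[b];
    if (lora < 0) continue;  // no LoRA applied to this sequence
    const float* w = W + lora * (int64_t)feat_out * feat_in;
    for (int o = 0; o < feat_out; ++o) {
      float acc = 0.f;
      for (int i = 0; i < feat_in; ++i) {
        acc += w[(int64_t)o * feat_in + i] * X[(int64_t)b * feat_in + i];
      }
      Y[(int64_t)b * feat_out + o] += scale * acc;
    }
  }
}
// (illustrative sketch only; the patch continues below)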
- - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) - -#undef CHECK_INPUT_TILEABLE_BY -#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM -#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index 2738892e6dc4a..cf00d869cf635 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,6 +1,8 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ +#include +#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -8,9 +10,6 @@ #include -#include "../type_convert.h" -#include "../../cuda_compat.h" - #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cc similarity index 97% rename from csrc/punica/punica_ops.cu rename to csrc/punica/punica_ops.cc index dd29820144b34..8797fde85744a 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cc @@ -1,11 +1,12 @@ -#include +#include +#include +#include #include #include -#include "type_convert.h" -#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" +namespace { //====== utils ====== @@ -88,7 +89,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, } void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale) { + torch::Tensor indicies, int64_t layer_idx, float scale) { CHECK_INPUT(y); CHECK_INPUT(x); CHECK_INPUT(w); @@ -320,7 +321,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, + float scale, int64_t h_in, int64_t h_out, int64_t y_offset) { CHECK_INPUT(y); CHECK_INPUT(x); @@ -567,3 +568,15 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } + +} // namespace + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h deleted file mode 100644 index 5d625d0564f75..0000000000000 --- a/csrc/punica/punica_ops.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale); - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp deleted file mode 100644 index 894e229b6d9db..0000000000000 --- a/csrc/punica/torch_bindings.cpp +++ /dev/null @@ -1,18 +0,0 @@ 
-#include "registration.h" -#include "punica_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); - m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - - m.def( - "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," - "Tensor indicies, int layer_idx," - "float scale, int h_in, int h_out," - "int y_offset) -> ()"); - m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h deleted file mode 100644 index dff7ce49283d7..0000000000000 --- a/csrc/punica/type_convert.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ -#define CSRC__PUNICA__TYPE_CONVERT_H__ - -#ifndef USE_ROCM - -#include -#include - -#else - -#include -#include - -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ - -typedef __half nv_half; -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { - return __hip_bfloat162{val, val}; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { - return __hip_bfloat162{vall, valr}; -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T_dst convert_type(T_src val) { - return static_cast(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__half, float>(__half val) { - return __half2float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half convert_type(float val) { - return __float2half(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 convert_type(float val) { - return __float2bfloat16(val); -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T vllm_add(T a, T b) { - return a + b; -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half vllm_add<__half>(__half a, __half b) { - return __hadd(a, b); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { - return __hadd(a, b); -} - -#undef __TYPE_CONVERT__HOST_DEVICE__ - -#endif // USE_ROCM - -#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp new file mode 100644 index 0000000000000..173e0b1732e13 --- /dev/null +++ b/csrc/pybind.cpp @@ -0,0 +1,136 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_and_mul", + &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def( + "gelu_tanh_and_mul", + &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in 
GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + "fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + ops.def( + "batched_rotary_embedding", + &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + +// Quantization ops +#ifndef USE_ROCM + ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); + ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); + ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); + ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); +#endif + + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); + ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); + ops.def( + "moe_align_block_size", + &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); + cache_ops.def( + "convert_fp8", + &convert_fp8, + "Convert the key and value cache to fp8 data type"); + + // Cuda utils + pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def( + "get_device_attribute", + &get_device_attribute, + "Gets the specified device attribute."); + + cuda_utils.def( + "get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute, + "Gets the maximum shared memory per block device attribute."); + +#ifndef USE_ROCM + // Custom all-reduce kernels + pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); + custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); + custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); + custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); + custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); + custom_ar.def("dispose", &dispose, "dispose"); + custom_ar.def("meta_size", &meta_size, "meta_size"); + custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); + custom_ar.def("get_graph_buffer_ipc_meta", 
&get_graph_buffer_ipc_meta, + "get_graph_buffer_ipc_meta"); + custom_ar.def("register_graph_buffers", ®ister_graph_buffers, + "register_graph_buffers"); +#endif + +} diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 8fb9856800867..4415316e1e8cd 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -18,35 +18,39 @@ #include #include #include -#include +#include #include #include #include #include + namespace vllm { namespace aqlm { __global__ void Code1x16MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; } } @@ -63,7 +67,8 @@ __global__ void Code1x16MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -71,19 +76,22 @@ __global__ void Code1x16MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. - asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. 
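// Illustrative sketch, not part of this patch: the `ld.cg.global.v4.u32` inline PTX
// used here is simply a 128-bit load of one codebook entry that bypasses the L1 cache;
// functionally it matches a plain indexed load, and the half2 FMA loop then accumulates
// 8 fp16 products per 16-bit code. A scalar host-side equivalent for clarity
// (`CodebookEntry` and `decode_and_dot` are made-up names):

#include <vector>
#include <cstdint>

struct CodebookEntry { float w[8]; };  // one 1x16 codebook row: 8 weights per 16-bit code

float decode_and_dot(const std::vector<CodebookEntry>& codebook,
                     const uint16_t* codes, const float* activations,
                     int num_groups) {
  float acc = 0.f;
  for (int g = 0; g < num_groups; ++g) {          // one 16-bit code covers 8 weights
    const CodebookEntry& e = codebook[codes[g]];  // the lookup the PTX load performs
    for (int j = 0; j < 8; ++j)
      acc += e.w[j] * activations[g * 8 + j];     // the __hfma2 accumulation, scalarized
  }
  return acc;
}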
+ asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); half2* a = reinterpret_cast(&dec); half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; -#pragma unroll - for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); + #pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(a[j], b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); b_sh_rd++; } @@ -92,33 +100,37 @@ __global__ void Code1x16MatVec( } if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } __global__ void Code2x8MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. 
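// Illustrative sketch, not part of this patch: `codebook_a_sizes` packs up to four
// cumulative row counts into one int4, and the loop here walks its fields (&.x, then
// .y, .z, .w) until it finds the partition owning this thread's row, advancing
// `codebook` by one stride per partition skipped. The same idea in scalar C++
// (`int4_like` and the function name are made up):

struct int4_like { int x, y, z, w; };  // stand-in for CUDA's built-in int4

// Returns how many whole codebooks precede the partition that owns `row`,
// mirroring the pointer walk over &codebook_a_sizes.x in the kernel.
int codebooks_to_skip(const int4_like& cumulative_sizes, int row) {
  const int* size = &cumulative_sizes.x;  // fields are contiguous, as the kernel assumes
  int skipped = 0;
  while (row >= *size) {  // row lies past this partition's cumulative end
    ++size;
    ++skipped;
  }
  return skipped;  // caller advances the codebook pointer by skipped * codebook_stride
}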
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; } } @@ -136,8 +148,9 @@ __global__ void Code2x8MatVec( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -148,7 +161,8 @@ __global__ void Code2x8MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -156,15 +170,13 @@ __global__ void Code2x8MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); @@ -175,31 +187,36 @@ __global__ void Code2x8MatVec( } if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } + __global__ void Code1x16Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. 
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; } } @@ -214,15 +231,17 @@ __global__ void Code1x16Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. - asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); C[a_gl_rd * 8 + i] = chunk; } @@ -231,25 +250,28 @@ __global__ void Code1x16Dequant( } } + __global__ void Code2x8Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. - const int codebook_stride // as int4 + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. 
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; } } @@ -268,8 +290,9 @@ __global__ void Code2x8Dequant( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -279,14 +302,12 @@ __global__ void Code2x8Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); -#pragma unroll + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + #pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); C[a_gl_rd * 8 + i] = chunk; @@ -296,15 +317,22 @@ __global__ void Code2x8Dequant( } } -inline int ceildiv(int a, int b) { return (a + b - 1) / b; } +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} const int THREAD_M = 16; -void code1x16_matvec_cuda(const void* __restrict__ A, - const void* __restrict__ B, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { +void code1x16_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -317,16 +345,28 @@ void code1x16_matvec_cuda(const void* __restrict__ A, int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); + Code1x16MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); } -void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { +void code2x8_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -339,20 +379,30 @@ void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute(Code2x8MatVec, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + cudaFuncSetAttribute( + Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code2x8MatVec<<>>( - (const 
int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); } void code1x16_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -367,21 +417,25 @@ void code1x16_dequant_cuda( int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code1x16Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long. - codebook_stride // as int4. + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + codebook_stride // as int4. ); } // Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. - const int codebook_stride // as int4 +void code2x8_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -397,50 +451,74 @@ void code2x8_dequant_cuda( int shared = 16 * (2 * 256 * 8 + 32 * 9); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cudaFuncSetAttribute(Code2x8Dequant, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + cudaFuncSetAttribute( + Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); Code2x8Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, codebook_stride); + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); } -int codebook_stride(const torch::Tensor& codebooks) { +int codebook_stride(const torch::Tensor& codebooks) +{ return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); } void code1x16_matvec( - const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each - // codebook, at most 3 long. + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. 
) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - codebook_stride(codebook)); + code1x16_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride(codebook) + ); } -torch::Tensor code1x16_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code1x16_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); + code1x16_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); } flat_output *= scales.flatten().unsqueeze(0); @@ -455,35 +533,55 @@ torch::Tensor code1x16_matmat(const torch::Tensor& input, return output; } -void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, - torch::Tensor& C, const torch::Tensor& codebook, - const int4 codebook_a_sizes) { +void code2x8_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes +) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - 2 * codebook_stride(codebook)); + code2x8_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + 2 * codebook_stride(codebook) + ); } -torch::Tensor code2x8_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code2x8_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias +) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - 
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); + code2x8_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); } flat_output *= scales.flatten().unsqueeze(0); if (bias.has_value()) { @@ -498,56 +596,64 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input, } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) +{ int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; int i = 0; int last = 0; assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) + { *cumulative_size = codebook_partition_sizes[i].item() + last; last = *cumulative_size; } // fill in the rest with unreachable. - for (; i < 4; ++i, ++cumulative_size) { - *cumulative_size = last * 10; + for (; i < 4; ++i, ++cumulative_size) + { + *cumulative_size = last*10; } return cumulative_sizes; } -} // namespace aqlm -} // namespace vllm +} // namespace aqlm +} // namespace vllm + -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); - if (nbooks == 1 && entries == (1 << 16)) { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); + if (nbooks == 1 && entries == (1 << 16)) + { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); } - if (nbooks == 2 && entries == (1 << 8)) { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); + if (nbooks == 2 && entries == (1 << 8)) + { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") return {}; } -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); @@ -562,37 +668,45 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, assert(out_features = codebook_partition_sizes.sum().item()); auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - 
.dtype(codebooks.dtype()) - .device(codebooks.device())); - - if (nbooks == 1 && entries == (1 << 16)) { - vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation.) weights *= - // scales.index({"...", 0, 0}); + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device()) + ); - return weights; + if (nbooks == 1 && entries == (1 << 16)) + { + vllm::aqlm::code1x16_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) + // weights *= scales.index({"...", 0, 0}); + + return weights; } - if (nbooks == 2 && entries == (1 << 8)) { - vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation) weights *= - // scales.index({"...", 0, 0}); - - return weights; + if (nbooks == 2 && entries == (1 << 8)) + { + vllm::aqlm::code2x8_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) + // weights *= scales.index({"...", 0, 0}); + + return weights; } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") return {}; } diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index 813ec6716cf54..d1d926de18d78 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: -https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and -Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, -Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} } */ @@ -14,88 +14,74 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) +{ #if defined(__CUDA_ARCH__) 
&& __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is - // thanks to the register packing format and the fact that we force our - // integers to be unsigned, and account for this in the fp16 subtractions. In - // addition, I exploit the fact that sub and fma have the same throughput in - // order to convert elt_23 and elt_67 to fp16 without having to shift them to - // the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW - // dependency if we issue immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), - "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), - "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), - "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), - "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
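// Illustrative check, not part of this patch: `immLut` above encodes the three-input
// boolean function for lop3.b32. With the conventional truth-table constants
// 0xF0 (a), 0xCC (b), 0xAA (c), evaluating the desired expression on them,
// (0xF0 & 0xCC) | 0xAA == 0xEA, yields the lookup table for f(a, b, c) = (a & b) | c,
// i.e. (i4s & MASK) | MAGIC_NUM. The identity can be verified in plain C++:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t lut = (0xF0 & 0xCC) | 0xAA;  // == 0xEA, the immLut value
  for (int a = 0; a <= 1; ++a)
    for (int b = 0; b <= 1; ++b)
      for (int c = 0; c <= 1; ++c) {
        int idx = (a << 2) | (b << 1) | c;   // LUT index per the lop3 convention
        int expected = (a & b) | c;          // the function the kernel wants
        assert(((lut >> idx) & 1u) == static_cast<uint32_t>(expected));
      }
  return 0;
}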
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit - // float2half instructions if I use the half2 ctor. In this case, I chose - // performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(h[0]) - : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(h[1]) - : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(h[2]) - : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(h[3]) - : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. 
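// Illustrative check, not part of this patch: the magic constants above rely on the
// fp16 bit layout. For a 4-bit nibble q, the pattern 0x6400 | q is the half value
// 1024 + q, so subtracting FP16_TOP_MAGIC_NUM (1024) recovers q; a nibble sitting four
// bits higher decodes to 1024 + 16*q, hence the fma with ONE_SIXTEENTH (1/16) and
// NEG_64 (-64). The arithmetic identity, checked in plain C++:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t q = 0; q < 16; ++q) {
    // low nibble: (1024 + q) - 1024 == q
    float low = (1024.0f + static_cast<float>(q)) - 1024.0f;
    // high nibble: (1024 + 16*q) * (1/16) - 64 == q
    float high = (1024.0f + 16.0f * static_cast<float>(q)) * (1.0f / 16.0f) - 64.0f;
    assert(low == static_cast<float>(q));
    assert(high == static_cast<float>(q));
  }
  return 0;
}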
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 6d6da5f3d8746..5aefb0bd16aef 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,13 +1,15 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and -Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, -Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} } */ -#include + +#include #include #include "dequantize.cuh" @@ -18,20 +20,26 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned __pack_half2(const half x, - const half y) { - unsigned v0 = *((unsigned short*)&x); - unsigned v1 = *((unsigned short*)&y); +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) - gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, - half* __restrict__ A, int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, int M, int IC, - int OC, half* __restrict__ C) { +template +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + int M, + int IC, + int OC, + half* __restrict__ C) +{ // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -62,46 +70,43 @@ __global__ void __launch_bounds__(64) static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = - (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + - threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = - A + - (((int)blockIdx_y) / j_factors1 * 16 + - (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * - IC + - (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + - (((int)threadIdx.x) / (N / 8)) * (OC / 8) + - (((int)blockIdx_y) % j_factors1) * (N / 8) + - (((int)threadIdx.x) % (N / 8)) * 1; - // Why * 1 in the above line? 
- - half* A_shared_ptr = A_shared + - ((int)threadIdx.y) * row_stride_warp * (32 + 8) + - (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + - (((int)threadIdx.x) % (32 / 8)) * 8; - - half* B_shared_ptr = B_shared + - ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + - (((int)threadIdx.x) / (N / 8)) * (N + 8) + - (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + - ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors + - (((int)blockIdx_y) % j_factors1) * N + - (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = - C + - static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + - (((int)threadIdx.x) % 4) * 2; + half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; +// Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. 
and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -110,83 +115,57 @@ __global__ void __launch_bounds__(64) int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) { + if (ld_A_flag) + { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } else { + } + else + { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = - *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && - threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, - B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, - B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus - // zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * - // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) - // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * - // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * - // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * - // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = - *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / - // 8)) * 8); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x - // % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = - // q * scale - zero * scale. 
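// Illustrative sketch, not part of this patch: the TODO here refers to the AWQ
// dequantization formula. Each 4-bit weight q is recovered as deq = (q - zero) * scale,
// which the kernel realizes as a sub.f16x2 followed by an fma; rewriting it as
// q * scale - zero * scale would let zero * scale be precomputed once per group.
// A scalar reference (hypothetical helper, names made up):

float awq_dequant_one(int q /* 0..15, one unpacked 4-bit weight */,
                      float scale, float zero) {
  return (static_cast<float>(q) - zero) * scale;  // algebraically == q*scale - zero*scale
}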
- asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == - 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", - B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = - B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; } __syncthreads(); @@ -194,179 +173,112 @@ __global__ void __launch_bounds__(64) { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " - "addr; }\n" - : "=r"(addr) - : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + - (((((int)threadIdx.x) & 15) * 40) + - ((((int)threadIdx.x) >> 4) * 8))))); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned*)(A_shared_warp + 0))[0]), - "=r"(((unsigned*)(A_shared_warp + 0))[1]), - "=r"(((unsigned*)(A_shared_warp + 0))[2]), - "=r"(((unsigned*)(A_shared_warp + 0))[3]) - 
: "r"(addr)); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " - "addr; }\n" - : "=r"(addr) - : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + - (((int)threadIdx.y) * (N / 2))) + - (ax1_0 * 16))])) + - (((((int)threadIdx.x) & 15) * (N + 8)) + - ((((int)threadIdx.x) >> 4) * 8))))); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), - "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr)); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 
4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } - #else +#else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, 
%7}, {%8, %9}, {%10, %11, %12, " - "%13};\n" - : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), - "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), - "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " - "%13};\n" - : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned*)(A_shared_warp + 0))[0]), - "r"(((unsigned*)(A_shared_warp + 0))[1]), - "r"(((unsigned*)(A_shared_warp + 0))[2]), - "r"(((unsigned*)(A_shared_warp + 0))[3]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), - "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), - "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); } - #endif +#endif } } } - // TODO: Shang: Hoist loop invariance. +// TODO: Shang: Hoist loop invariance. 
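// Illustrative sketch, not part of this diff: the sm80 tensor-core primitive used in
// the non-sm75 branch above, wrapped as a standalone helper so the fragment shapes are
// explicit. Per thread, A carries four .f16x2 registers, B two, and the accumulator
// four floats (PTX mma.sync m16n8k16 f16/f32 layout). The helper name is hypothetical.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
__device__ __forceinline__ void mma_m16n8k16_f16f32(const unsigned (&a)[4],
                                                    const unsigned (&b)[2],
                                                    float (&c)[4]) {
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
#endif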
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + - ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + - local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) - dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, - int* __restrict__ zeros, half* __restrict__ C, int G) { +__global__ void __launch_bounds__(64) dequantize_weights( + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + half* __restrict__ C, + int G +) +{ int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -398,30 +310,14 @@ __global__ void __launch_bounds__(64) uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.x) - : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.y) - : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.z) - : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" - : "=r"(B_loaded_fp16.w) - : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -430,57 +326,58 @@ __global__ void __launch_bounds__(64) } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { - int in_c = _kernel.size(0); - int qout_c 
= _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx == 0) { - x_thread = qout_c; - } - if (thy == 0) { - y_thread = in_c; - } - if (thx == 0 && thy == 0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + int thx, + int thy) +{ + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx==0) { + x_thread = qout_c; + } + if (thy==0) { + y_thread = in_c; + } + if (thx==0 && thy==0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -489,61 +386,61 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters) { - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions() - .dtype(_in_feats.dtype()) - .device(_in_feats.device()); - at::Tensor _out_feats = - torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = - 
reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> - <<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - num_in_feats, num_in_channels, num_out_channels, out_feats); - } else if (num_out_channels % 64 == 0) { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * - split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> - <<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 
= num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, + num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu deleted file mode 100644 index aa9511daa2772..0000000000000 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ /dev/null @@ -1,115 +0,0 @@ -#include -#include -#include - -#include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" - -static inline __device__ int8_t float_to_int8_rn(float x) { -#ifdef USE_ROCM - static const float i8_min = - static_cast(std::numeric_limits::min()); - static const float i8_max = - static_cast(std::numeric_limits::max()); - // round - float dst = std::nearbyint(x); - // saturate - dst = std::clamp(dst, i8_min, i8_max); - return static_cast(dst); -#else - // CUDA path - uint32_t dst; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); - return reinterpret_cast(dst); -#endif -} - -namespace vllm { - -template -__global__ void static_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const int hidden_size) { - int const tid = threadIdx.x; - int const token_idx = blockIdx.x; - scale_type const scale = *scale_ptr; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) / scale); - } -} - -template -__global__ void dynamic_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const int hidden_size) { - int const tid = threadIdx.x; - int const token_idx = blockIdx.x; - float absmax_val = 0.0f; - float const zero = 0.0f; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[token_idx * hidden_size + i]); - val = val > zero ? val : -val; - absmax_val = val > absmax_val ? 
val : absmax_val; - } - - float const block_absmax_val_maybe = blockReduceMax(absmax_val); - __shared__ float block_absmax_val; - if (tid == 0) { - block_absmax_val = block_absmax_val_maybe; - scale[token_idx] = block_absmax_val / 127.0f; - } - __syncthreads(); - - float const tmp_scale = 127.0f / block_absmax_val; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) * tmp_scale); - } -} - -} // namespace vllm - -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); - }); -} - -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); - }); -} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp deleted file mode 100644 index c4c6b18654eed..0000000000000 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ /dev/null @@ -1,346 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -// -// This file is a modified excerpt of -// include/cutlass/epilogue/fusion/visitor_load.hpp from -// https://github.com/NVIDIA/cutlass v3.5.0 -// It has been modified to support either -// row/column or scalar broadcasting where the tensor being loaded from is -// always passed in via a device pointer. This lets one compiled kernel handle -// all cases of per-tensor or per-channel/per-token quantization. -// -// This interface also allows the scales to be passed in as tensors that -// consistently reside on the device, which avoids an issue with a previous -// implementation where scalars needed to be on the CPU since they -// were passed in via float values. This created a potential performance hazard -// if scales were initially on the device, and caused torch.compile graph -// breaks when moving scales to the CPU. -// -#pragma once - -// Turn off clang-format for the entire file to keep it close to upstream -// clang-format off - -#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" -#include "cute/tensor.hpp" - -namespace cutlass::epilogue::threadblock { - -using namespace cute; -using namespace detail; - -template< - class ThreadMap, - class Element, - class StrideMNL -> -struct VisitorRowOrScalarBroadcast { - - // This struct has been modified to have a bool indicating that ptr_row is a - // scalar that must be broadcast. 
- struct Arguments { - Element const* ptr_row = nullptr; - bool row_broadcast = true; - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - struct SharedStorage {}; - - // Global load type - static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; - using VecType = uint_bit_t; - static int constexpr VecLength = sizeof(VecType) / sizeof(Element); - - CUTLASS_HOST_DEVICE - VisitorRowOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - template - struct Callbacks : EmptyCallbacks { - CUTLASS_DEVICE - Callbacks( - GTensor&& tC_gRow, - RTensor&& tC_rRow, - CTensor&& tC_cRow, - ProblemShape problem_shape, - Params const* params_ptr - ): - tC_gRow(cute::forward(tC_gRow)), - tC_rRow(cute::forward(tC_rRow)), - tC_cRow(cute::forward(tC_cRow)), - n(get<1>(problem_shape)), - params_ptr(params_ptr) { } - - GTensor tC_gRow; - RTensor tC_rRow; - CTensor tC_cRow; - Params const* params_ptr; - int n; - - // This function is modified from VisitorRowBroadcast - CUTLASS_DEVICE void - begin_epilogue() { - clear(tC_rRow); - auto src_v = filter(tC_gRow); - auto coord_v = filter(tC_cRow); - auto dst_v = filter(tC_rRow); - - if (params_ptr->row_broadcast) { - // In this case we are loading from a row vector and broadcasting - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - bool guard = get<1>(coord_v(i)) < n; - cutlass::arch::global_load( - dst_v(i), (void const*)&src_v(i), guard); - } - } else { - // In this case we are loading from a scalar and broadcasting - VecType filled_vec; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < VecLength; i++) { - reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - if (get<1>(coord_v(i)) < n) { - dst_v(i) = filled_vec; - } - } - } - } - - template - CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, - Array const& frg_acc) { - Tensor rRow_frg = recast>(coalesce(tC_rRow)); - return rRow_frg(column_idx); - } - }; - - template - CUTLASS_DEVICE auto - get_callbacks( - gemm::GemmCoord threadblock_tile_offset, - int thread_idx, - ProblemShape problem_shape - ) { - Tensor mRow = make_tensor( - make_gmem_ptr(params_ptr->ptr_row), - problem_shape, - params_ptr->dRow); - - // VECTOR, FRAGMENT_COLUMN - Tensor tC_gRow = recast( - ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) - )(_,_,_0{},_0{},_0{},_0{}); - Tensor tC_rRow = make_tensor_like(tC_gRow); - - // Generate the pred tensor - Tensor cRow = make_identity_tensor(mRow.shape()); - Tensor tC_cRow = outer_partition( - ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), - Shape>{}, - (_0{}) - ); - - return Callbacks< - decltype(tC_gRow), decltype(tC_rRow), - decltype(tC_cRow), ProblemShape>( - cute::move(tC_gRow), - cute::move(tC_rRow), - cute::move(tC_cRow), - problem_shape, - params_ptr - ); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - class ThreadMap, - class Element, - class 
StrideMNL = Stride<_1,_0,_0> -> -struct VisitorColOrScalarBroadcast { - - // This struct has been modified to have a bool indicating that ptr_col is a - // scalar that must be broadcast. - struct Arguments { - Element const* ptr_col = nullptr; - bool col_broadcast = true; - StrideMNL dCol = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - struct SharedStorage { }; - - CUTLASS_HOST_DEVICE - VisitorColOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - template - struct Callbacks : EmptyCallbacks { - CUTLASS_DEVICE - Callbacks( - GTensor&& tC_gCol, - RTensor&& tC_rCol, - CTensor&& tC_cCol, - ProblemShape problem_shape, - Params const* params_ptr - ): - tC_gCol(cute::forward(tC_gCol)), - tC_rCol(cute::forward(tC_rCol)), - tC_cCol(cute::forward(tC_cCol)), - m(get<0>(problem_shape)), - params_ptr(params_ptr) { } - - GTensor tC_gCol; - RTensor tC_rCol; - CTensor tC_cCol; - Params const* params_ptr; - int m; - - // This function is modified from VisitorColBroadcast - CUTLASS_DEVICE void - begin_epilogue() { - clear(tC_rCol); - - Tensor pred = make_tensor(shape(tC_gCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tC_cCol(i)) < m; - } - - if (params_ptr->col_broadcast) { - // In this case we are loading from a column vector and broadcasting - copy_if(pred, tC_gCol, tC_rCol); - } else { - // In this case we are loading from a scalar and broadcasting - auto dst_v = filter(tC_rCol); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(dst_v); ++i) { - if (pred(i)) { - dst_v(i) = *(params_ptr->ptr_col); - } - } - } - } - - template - CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, - Array const& frg_acc) { - Array frg_col; - frg_col.fill(tC_rCol(row_idx,iter_idx)); - return frg_col; - } - }; - - template - CUTLASS_DEVICE auto - get_callbacks( - gemm::GemmCoord threadblock_tile_offset, - int thread_idx, - ProblemShape problem_shape - ) { - Tensor mCol = make_tensor( - make_gmem_ptr(params_ptr->ptr_col), - problem_shape, - params_ptr->dCol); - - // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER - Tensor tC_gCol = group_modes<1,4>( - ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); - Tensor tC_rCol = make_tensor_like(tC_gCol); - - // Generate the pred tensor - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tC_cCol = group_modes<1,4>( - ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); - - return Callbacks< - decltype(tC_gCol), decltype(tC_rCol), - decltype(tC_cCol), ProblemShape>( - cute::move(tC_gCol), - cute::move(tC_rCol), - cute::move(tC_cCol), - problem_shape, - params_ptr - ); - } -}; - -} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp deleted file mode 100644 index 877a9f5b9e5de..0000000000000 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp +++ /dev/null @@ -1,389 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -// -// This file is a modified excerpt of -// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp -// from https://github.com/NVIDIA/cutlass v3.5.0 -// It has been modified to support either row/column or scalar broadcasting -// where the tensor being loaded from is always passed in via a device pointer. -// This lets one compiled kernel handle all cases of per-tensor or -// per-channel/per-token quantization. -// -// This interface also allows the scales to be passed in as tensors that -// consistently reside on the device, which avoids an issue with a previous -// implementation where scalars needed to be on the CPU since they -// were passed in via float values. This created a potential performance hazard -// if scales were initially on the device, and caused torch.compile graphs -// breaks when moving scales to the CPU. 
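// A minimal sketch of the "row-or-scalar" idea the header comment above describes
// (assumption-labelled, not taken from the diff): one device pointer plus a flag
// replaces the old "scalar on the host vs. vector on the device" split, so a single
// compiled kernel covers per-tensor and per-channel/per-token scales. Names here are
// hypothetical; the real visitors below carry the same Arguments shape.
struct RowOrScalarArgs {
  const float* ptr = nullptr;  // device pointer to either N scales or a single scale
  bool row_broadcast = true;   // true: index per column, false: splat the scalar
};

__device__ __forceinline__ float load_scale(const RowOrScalarArgs& args, int col) {
  // The branch is uniform across the thread block, so there is no divergence; the
  // scalar case simply rereads the same element instead of staging it on the CPU.
  return args.row_broadcast ? args.ptr[col] : *args.ptr;
}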
-// -#pragma once - -// Turn off clang-format for the entire file to keep it close to upstream -// clang-format off - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" - -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -// Row vector broadcast -template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races - int Stages, - class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v -> -struct Sm90RowOrScalarBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast - - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; - }; - - // This struct has been modified to have a bool indicating that ptr_row is a - // scalar that must be broadcast, instead of containing a scalar that is - // valid if ptr_row is null. - struct Arguments { - Element const* ptr_row = nullptr; - bool row_broadcast = true; - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90RowOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } - - Params params; - Element* smem_row; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (!params.row_broadcast && *(params.ptr_row) == Element(0)); - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if (!params.row_broadcast) { - return; - } - - if (issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over 
stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if (!params.row_broadcast) { - fill(tCrRow, *(params.ptr_row)); - return; - } - - if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); - } - - return frg_row; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... 
Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v -> -struct Sm90ColOrScalarBroadcast { - static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias - (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias - - // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem - struct SharedStorage { }; - - // This struct has been modified to have a bool indicating that ptr_col is a - // scalar that must be broadcast, instead of containing a scalar that is - // valid if ptr_col is null. - struct Arguments { - Element const* ptr_col = nullptr; - bool col_broadcast = true; - StrideMNL dCol = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (!params.col_broadcast && *(params.ptr_col) == Element(0)); - } - - CUTLASS_HOST_DEVICE - Sm90ColOrScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) - : tCgCol(cute::forward(tCgCol)), - tCrCol(cute::forward(tCrCol)), - params(params) {} - - GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - Params const& params; - - CUTLASS_DEVICE void - begin() { - if (!params.col_broadcast) { - fill(tCrCol, *(params.ptr_col)); - return; - } - - // Filter so we don't 
issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCgCol), filter(tCrCol)); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); - } - - return frg_col; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - return ConsumerStoreCallbacks( - cute::move(tCgCol), cute::move(tCrCol), params); - } -}; - -} diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu deleted file mode 100644 index 6ce25c5ac897b..0000000000000 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ /dev/null @@ -1,609 +0,0 @@ -#include -#include - -#include - -// clang-format will break include orders -// clang-format off -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/util/device_memory.h" - -#include "cutlass/cutlass.h" -#include "cutlass/gemm_coord.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" -#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" - -#include "broadcast_load_epilogue_c2x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; - -/* - This file defines quantized GEMM operations using the CUTLASS 2.x API, for - NVIDIA GPUs with SM versions prior to sm90 (Hopper). - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm80EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. 
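// Quick sanity check of next_pow_2 from common.hpp above (illustrative only, not part
// of the diff): the bit trick rounds a size up to the next power of two and leaves
// exact powers of two unchanged, e.g. __builtin_clz(99) = 25, so 100 maps to
// 1 << (32 - 25) = 128.
#include <cassert>
inline void next_pow_2_examples() {
  assert(next_pow_2(1) == 1);      // num <= 1 early-out
  assert(next_pow_2(100) == 128);  // rounded up
  assert(next_pow_2(128) == 128);  // already a power of two
}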
-*/ - -namespace { - -// Wrappers for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm75_to_sm80 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm80_to_sm89 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm89_to_sm90 : Kernel { - template - CUTLASS_DEVICE static void invoke(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900 - Kernel::invoke(std::forward(args)...); -#endif - } -}; - -/* - * This class provides the common ScaleA and ScaleB descriptors for the - * ScaledEpilogue and ScaledEpilogueBias classes. - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; - - using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. 
-*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args}; - return evt_compute_args; - } -}; - -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::ScaleA; - using ScaleB = typename SUPER::ScaleB; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementD, Stride, Int<1>, Int<0>>>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - using ScaleAArgs = typename ScaleA::Arguments; - using ScaleBArgs = typename ScaleB::Arguments; - using BiasArgs = typename Bias::Arguments; - - ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; - ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; - BiasArgs bias_args{static_cast(bias.data_ptr()), {}}; - - typename EVTCompute0::Arguments evt0_compute_args{b_args}; - - typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args, - bias_args}; - return evt_compute_args; - } -}; - -template typename ArchGuard, - typename ElementAB_, typename ElementD_, - template typename Epilogue_, typename TileShape, - typename WarpShape, typename InstructionShape, int32_t MainLoopStages> -struct cutlass_2x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using Operator = - typename std::conditional, - cutlass::arch::OpMultiplyAddSaturate, - cutlass::arch::OpMultiplyAdd>::type; - - using OutputTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - TileShape, WarpShape, float, 4, 1 /* epilogue stages */ - >; - - using Epilogue = Epilogue_; - using EVTCompute = typename Epilogue::EVTCompute; - - using D = 
cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, - Stride, Int<0>>>; - - using EVTD = cutlass::epilogue::threadblock::Sm80EVT; - - // clang-format off - using RowMajor = typename cutlass::layout::RowMajor; - using ColumnMajor = typename cutlass::layout::ColumnMajor; - using KernelType = - ArchGuard::GemmKernel>; - // clang-format on - - using Op = cutlass::gemm::device::GemmUniversalAdapter; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - cutlass::gemm::GemmCoord problem_size{m, n, k}; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideC = Stride, Int<0>>; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); - - typename Gemm::D::Arguments d_args{c_ptr, c_stride}; - - using Epilogue = typename Gemm::Epilogue; - auto evt_args = - Epilogue::prepare_args(std::forward(epilogue_params)...); - - typename Gemm::EVTD::Arguments epilogue_args{ - evt_args, - d_args, - }; - - typename Gemm::Op::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode - problem_size, // problem size - 1, // batch count - epilogue_args, - a_ptr, - b_ptr, - nullptr, - nullptr, - 0, - 0, - 0, - 0, - lda, - ldb, - ldc, - ldc}; - - // Launch the CUTLASS GEMM kernel. - typename Gemm::Op gemm_op; - size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get(), stream); - CUTLASS_CHECK(status); -} - -template -void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - // In some cases, the GPU isn't able to accommodate the - // shared memory requirements of the Gemm. In such cases, use - // the FallbackGemm instead. 
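// Sketch of the check fallback_cutlass_gemm_caller performs (illustrative; it reuses
// the helper from common.hpp above, and the quoted limits are orientation values only):
// a Gemm whose SharedStorage exceeds the device's opt-in shared-memory limit cannot
// launch, e.g. the 122880-byte M64 config fits an sm_80 part (~163 KB opt-in) but not
// an sm_86 part (~99 KB), which is when the smaller-tile fallback is taken instead.
inline bool gemm_fits_in_smem(size_t shared_storage_bytes, int device = 0) {
  return shared_storage_bytes <=
         static_cast<size_t>(get_cuda_max_shared_memory_per_block_opt_in(device));
}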
- static const int max_shared_mem_per_block_opt_in = - get_cuda_max_shared_memory_per_block_opt_in(0); - - size_t const gemm_shared_mem_size = - sizeof(typename Gemm::KernelType::SharedStorage); - size_t const fallback_gemm_shared_mem_size = - sizeof(typename FallbackGemm::KernelType::SharedStorage); - - if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { - return cutlass_gemm_caller(out, a, b, - std::forward(args)...); - } else { - TORCH_CHECK(fallback_gemm_shared_mem_size <= - max_shared_mem_per_block_opt_in); - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue> -struct sm80_config_default { - // This config is used in 2 cases, - // - M in (128, inf) - // - M in (64, 128] and N >= 8192 - // Shared Memory required by this Gemm - 81920 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M64 { - // This config is used in 2 cases, - // - M in (32, 64] - // - M in (64, 128] and N < 8192 - // Shared Memory required by this Gemm - 122880 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M32 { - // M in (16, 32] - // Shared Memory required by this Gemm - 61440 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -template typename Epilogue> -struct sm80_config_M16 { - // M in [1, 16] - // Shared Memory required by this Gemm - 51200 bytes - static_assert(std::is_same()); - using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; - using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; - using Cutlass2xGemm = - cutlass_2x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass2xGemmDefault = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128BigN = - typename sm80_config_default::Cutlass2xGemm; - using Cutlass2xGemmM128SmallN = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM64 = - typename sm80_config_M64::Cutlass2xGemm; - using Cutlass2xGemmM32 = - typename sm80_config_M32::Cutlass2xGemm; - using Cutlass2xGemmM16 = - typename sm80_config_M16::Cutlass2xGemm; - - // Due to shared memory requirements, some Gemms may fail to run on some - // GPUs. As the name indicates, the Fallback Gemm is used as an alternative - // in such cases. - // sm80_config_M16 has the least shared-memory requirement. 
However, - // based on some profiling, we select sm80_config_M32 as a better alternative - // performance wise. - using FallbackGemm = - typename sm80_config_M32::Cutlass2xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(16), next_pow_2(m)); // next power of 2 - if (mp2 <= 16) { - // M in [1, 16] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 32) { - // M in (16, 32] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 64) { - // M in (32, 64] - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // M in (64, 128] - uint32_t const n = out.size(1); - bool const small_n = n < 8192; - if (small_n) { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else { - // M in (128, inf) - return fallback_cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template