Skip to content

Commit

Permalink
[Neuron] Adding support for adding/ overriding neuron configuration a…
Browse files Browse the repository at this point in the history
…nd adding support for neuron model quantization configuration.
  • Loading branch information
Harsha Bikki committed Sep 4, 2024
1 parent d331156 commit 1e256ad
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 42 deletions.
50 changes: 50 additions & 0 deletions examples/offline_inference_neuron_int8_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os

from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8"

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_num_seqs=8,
# The max_model_len and block_size arguments are required to be same as
# max sequence length when targeting neuron device.
# Currently, this is a known limitation in continuous batching support
# in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
quantization="neuron_quant",
override_neuron_config={
"cast_logits_dtype": "bfloat16",
},
tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
69 changes: 41 additions & 28 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
Type, Union)
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
Optional, Tuple, Type, Union)

import torch
from transformers import PretrainedConfig
Expand Down Expand Up @@ -115,35 +115,39 @@ class ModelConfig:
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
"""

def __init__(
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None:
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
Expand Down Expand Up @@ -227,6 +231,9 @@ def __init__(
limit_mm_per_prompt)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()

self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
self._verify_quantization()
self._verify_cuda_graph()
Expand Down Expand Up @@ -275,6 +282,7 @@ def _verify_quantization(self) -> None:
"experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

Expand Down Expand Up @@ -329,6 +337,11 @@ def _verify_quantization(self) -> None:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in Neuron Backend.")

def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
Expand Down
17 changes: 14 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)

import torch

Expand Down Expand Up @@ -149,6 +149,7 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None

def __post_init__(self):
if self.tokenizer is None:
Expand Down Expand Up @@ -742,6 +743,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
parser.add_argument(
'--override-neuron-config',
type=lambda configs: {
str(key): value
for key, value in
(config.split(':') for config in configs.split(','))
},
default=None,
help="override or set neuron device configuration.")

return parser

@classmethod
Expand Down Expand Up @@ -802,7 +813,7 @@ def create_engine_config(self) -> EngineConfig:
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
)
override_neuron_config=self.override_neuron_config)
cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
Expand Down
2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
Expand All @@ -232,6 +233,7 @@ def __init__(
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
Expand All @@ -46,6 +48,7 @@
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
}


Expand Down
67 changes: 67 additions & 0 deletions vllm/model_executor/layers/quantization/neuron_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)

SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']


class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend."""

def __init__(
self,
dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError(
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
f"the quantization datatype should match one of the below types"
f"{SUPPORTED_QUANT_DTYPE_LIST}")
self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method

def get_name(self) -> str:
return "neuron_quant"

def get_supported_act_dtypes(self) -> List[str]:
return SUPPORTED_QUANT_DTYPE_LIST

@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"This function should not be called with Neuron Backend")

@staticmethod
def get_config_filenames() -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,
quantize_method=quantize_method)

def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
if find_spec("transformers_neuronx") is not None:
return self.get_quantization_config()
else:
raise NotImplementedError(
"Neuron Quantization is only supported through"
" transformers_neuronx.")

def get_scaled_act_names(self) -> List[str]:
return []

def get_quantization_config(self):
from transformers_neuronx.config import QuantizationConfig
return QuantizationConfig(quant_dtype=self.quant_dtype,
dequant_dtype=self.dequant_dtype,
quantize_method=self.quantize_method)
Loading

0 comments on commit 1e256ad

Please sign in to comment.