[platform] Add verify_quantization in platform. (#10757)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan 2024-11-29 23:22:21 +08:00 committed by GitHub
parent 3132aac043
commit 661175bc82
10 changed files with 38 additions and 27 deletions

vllm/config.py

@@ -393,17 +393,11 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = QUANTIZATION_METHODS
-        rocm_supported_quantization = [
-            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
-            "fbgemm_fp8", "gguf"
-        ]
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
             "compressed-tensors", "experts_int8"
         ]
-        tpu_supported_quantization = ["tpu_int8"]
-        neuron_supported_quantization = ["neuron_quant"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -438,32 +432,12 @@ class ModelConfig:
                 raise ValueError(
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
-            if current_platform.is_rocm(
-            ) and self.quantization not in rocm_supported_quantization:
-                raise ValueError(
-                    f"{self.quantization} quantization is currently not "
-                    f"supported in ROCm.")
-            if current_platform.is_tpu(
-            ) and self.quantization not in tpu_supported_quantization:
-                raise ValueError(
-                    f"{self.quantization} quantization is currently not "
-                    f"supported in TPU Backend.")
+            current_platform.verify_quantization(self.quantization)
             if self.quantization not in optimized_quantization_methods:
                 logger.warning(
                     "%s quantization is not fully "
                     "optimized yet. The speed can be slower than "
                     "non-quantized models.", self.quantization)
-            if (self.quantization == "awq" and current_platform.is_rocm()
-                    and not envs.VLLM_USE_TRITON_AWQ):
-                logger.warning(
-                    "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
-                    " is not set, enabling VLLM_USE_TRITON_AWQ.")
-                envs.VLLM_USE_TRITON_AWQ = True
-            if current_platform.is_neuron(
-            ) and self.quantization not in neuron_supported_quantization:
-                raise ValueError(
-                    f"{self.quantization} quantization is currently not "
-                    f"supported in Neuron Backend.")

     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:
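
Note: after this change, ModelConfig no longer hard-codes per-backend allow-lists; it delegates to the active platform object. A rough sketch of the new call path, assuming vLLM at this commit (the quantization value is illustrative):

    from vllm.platforms import current_platform

    quant = "awq"
    # Raises ValueError on platforms whose supported_quantization list
    # excludes "awq"; a no-op on platforms that leave the list empty.
    current_platform.verify_quantization(quant)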

vllm/platforms/cpu.py

@@ -19,6 +19,7 @@ logger = init_logger(__name__)
 class CpuPlatform(Platform):
     _enum = PlatformEnum.CPU
     device_name: str = "cpu"
+    device_type: str = "cpu"
     dispatch_key: str = "CPU"

vllm/platforms/cuda.py

@@ -72,6 +72,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 class CudaPlatformBase(Platform):
     _enum = PlatformEnum.CUDA
     device_name: str = "cuda"
+    device_type: str = "cuda"
     dispatch_key: str = "CUDA"

vllm/platforms/hpu.py

@@ -12,6 +12,7 @@ else:
 class HpuPlatform(Platform):
     _enum = PlatformEnum.HPU
     device_name: str = "hpu"
+    device_type: str = "hpu"
     dispatch_key: str = "HPU"

vllm/platforms/interface.py

@@ -56,11 +56,13 @@ class DeviceCapability(NamedTuple):
 class Platform:
     _enum: PlatformEnum
     device_name: str
+    device_type: str
     # available dispatch keys:
     # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
+    supported_quantization: list[str] = []

     def is_cuda(self) -> bool:
         return self._enum == PlatformEnum.CUDA

@@ -171,6 +173,17 @@ class Platform:
         """
         pass

+    @classmethod
+    def verify_quantization(cls, quant: str) -> None:
+        """
+        Verify whether the quantization is supported by the current platform.
+        """
+        if cls.supported_quantization and \
+                quant not in cls.supported_quantization:
+            raise ValueError(
+                f"{quant} quantization is currently not supported in "
+                f"{cls.device_name}.")
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
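
Note: a minimal, self-contained sketch of how the new hook behaves. ToyTpuPlatform is a hypothetical stand-in subclass, used so the example does not depend on TPU libraries being installed:

    from vllm.platforms.interface import Platform

    class ToyTpuPlatform(Platform):  # hypothetical subclass for illustration
        device_name = "tpu"
        device_type = "tpu"
        supported_quantization = ["tpu_int8"]

    ToyTpuPlatform.verify_quantization("tpu_int8")   # accepted, returns None
    try:
        ToyTpuPlatform.verify_quantization("awq")    # not in the allow-list
    except ValueError as err:
        print(err)  # awq quantization is currently not supported in tpu.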

vllm/platforms/neuron.py

@@ -10,7 +10,9 @@ else:
 class NeuronPlatform(Platform):
     _enum = PlatformEnum.NEURON
     device_name: str = "neuron"
+    device_type: str = "neuron"
+    supported_quantization: list[str] = ["neuron_quant"]

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:

vllm/platforms/openvino.py

@@ -23,6 +23,7 @@ except ImportError as e:
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
     device_name: str = "openvino"
+    device_type: str = "openvino"
     dispatch_key: str = "CPU"

vllm/platforms/rocm.py

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -35,8 +36,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
     device_name: str = "rocm"
+    device_type: str = "cuda"
     dispatch_key: str = "CUDA"
+    supported_quantization: list[str] = [
+        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
+        "fbgemm_fp8", "gguf"
+    ]

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

@@ -79,3 +85,12 @@ class RocmPlatform(Platform):
                 "vllm.spec_decode.spec_decode_worker.create_spec_worker"
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
+
+    @classmethod
+    def verify_quantization(cls, quant: str) -> None:
+        super().verify_quantization(quant)
+        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
+            logger.warning(
+                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
+                " is not set, enabling VLLM_USE_TRITON_AWQ.")
+            envs.VLLM_USE_TRITON_AWQ = True
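
Note: the ROCm override shows the intended extension pattern — run the base allow-list check first, then apply platform-specific side effects. A hedged sketch of how an out-of-tree platform might follow the same pattern (MyPlatform and all of its names are hypothetical):

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyPlatform(Platform):  # hypothetical out-of-tree platform
        _enum = PlatformEnum.UNSPECIFIED  # placeholder for illustration
        device_name: str = "myaccel"
        device_type: str = "myaccel"
        supported_quantization: list[str] = ["fp8", "awq"]

        @classmethod
        def verify_quantization(cls, quant: str) -> None:
            super().verify_quantization(quant)  # reject anything not listed
            if quant == "awq":
                # platform-specific fixups (env flags, kernel choice) go here
                pass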

vllm/platforms/tpu.py

@@ -16,8 +16,10 @@ logger = init_logger(__name__)
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
     device_name: str = "tpu"
+    device_type: str = "tpu"
     dispatch_key: str = "XLA"
+    supported_quantization: list[str] = ["tpu_int8"]

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

vllm/platforms/xpu.py

@@ -16,6 +16,7 @@ logger = init_logger(__name__)
 class XPUPlatform(Platform):
     _enum = PlatformEnum.XPU
     device_name: str = "xpu"
+    device_type: str = "xpu"
     dispatch_key: str = "XPU"