mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-06 05:41:48 +08:00
[platform] Add verify_quantization in platform. (#10757)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
3132aac043
commit
661175bc82
@ -393,17 +393,11 @@ class ModelConfig:
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = QUANTIZATION_METHODS
|
||||
rocm_supported_quantization = [
|
||||
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
||||
"fbgemm_fp8", "gguf"
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
|
||||
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
|
||||
"compressed-tensors", "experts_int8"
|
||||
]
|
||||
tpu_supported_quantization = ["tpu_int8"]
|
||||
neuron_supported_quantization = ["neuron_quant"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@ -438,32 +432,12 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if current_platform.is_rocm(
|
||||
) and self.quantization not in rocm_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in ROCm.")
|
||||
if current_platform.is_tpu(
|
||||
) and self.quantization not in tpu_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in TPU Backend.")
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
if self.quantization not in optimized_quantization_methods:
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.", self.quantization)
|
||||
if (self.quantization == "awq" and current_platform.is_rocm()
|
||||
and not envs.VLLM_USE_TRITON_AWQ):
|
||||
logger.warning(
|
||||
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
||||
" is not set, enabling VLLM_USE_TRITON_AWQ.")
|
||||
envs.VLLM_USE_TRITON_AWQ = True
|
||||
if current_platform.is_neuron(
|
||||
) and self.quantization not in neuron_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in Neuron Backend.")
|
||||
|
||||
def _verify_cuda_graph(self) -> None:
|
||||
if self.max_seq_len_to_capture is None:
|
||||
|
||||
@ -19,6 +19,7 @@ logger = init_logger(__name__)
|
||||
|
||||
class CpuPlatform(Platform):
|
||||
_enum = PlatformEnum.CPU
|
||||
device_name: str = "cpu"
|
||||
device_type: str = "cpu"
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
|
||||
@ -72,6 +72,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
|
||||
class CudaPlatformBase(Platform):
|
||||
_enum = PlatformEnum.CUDA
|
||||
device_name: str = "cuda"
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ else:
|
||||
|
||||
class HpuPlatform(Platform):
|
||||
_enum = PlatformEnum.HPU
|
||||
device_name: str = "hpu"
|
||||
device_type: str = "hpu"
|
||||
dispatch_key: str = "HPU"
|
||||
|
||||
|
||||
@ -56,11 +56,13 @@ class DeviceCapability(NamedTuple):
|
||||
|
||||
class Platform:
|
||||
_enum: PlatformEnum
|
||||
device_name: str
|
||||
device_type: str
|
||||
# available dispatch keys:
|
||||
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||
dispatch_key: str = "CPU"
|
||||
supported_quantization: list[str] = []
|
||||
|
||||
def is_cuda(self) -> bool:
|
||||
return self._enum == PlatformEnum.CUDA
|
||||
@ -171,6 +173,17 @@ class Platform:
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def verify_quantization(cls, quant: str) -> None:
|
||||
"""
|
||||
Verify whether the quantization is supported by the current platform.
|
||||
"""
|
||||
if cls.supported_quantization and \
|
||||
quant not in cls.supported_quantization:
|
||||
raise ValueError(
|
||||
f"{quant} quantization is currently not supported in "
|
||||
f"{cls.device_name}.")
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@ -10,7 +10,9 @@ else:
|
||||
|
||||
class NeuronPlatform(Platform):
|
||||
_enum = PlatformEnum.NEURON
|
||||
device_name: str = "neuron"
|
||||
device_type: str = "neuron"
|
||||
supported_quantization: list[str] = ["neuron_quant"]
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
|
||||
@ -23,6 +23,7 @@ except ImportError as e:
|
||||
|
||||
class OpenVinoPlatform(Platform):
|
||||
_enum = PlatformEnum.OPENVINO
|
||||
device_name: str = "openvino"
|
||||
device_type: str = "openvino"
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
@ -35,8 +36,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
device_name: str = "rocm"
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
supported_quantization: list[str] = [
|
||||
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
||||
"fbgemm_fp8", "gguf"
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
@ -79,3 +85,12 @@ class RocmPlatform(Platform):
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
@classmethod
|
||||
def verify_quantization(cls, quant: str) -> None:
|
||||
super().verify_quantization(quant)
|
||||
if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
|
||||
logger.warning(
|
||||
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
||||
" is not set, enabling VLLM_USE_TRITON_AWQ.")
|
||||
envs.VLLM_USE_TRITON_AWQ = True
|
||||
|
||||
@ -16,8 +16,10 @@ logger = init_logger(__name__)
|
||||
|
||||
class TpuPlatform(Platform):
|
||||
_enum = PlatformEnum.TPU
|
||||
device_name: str = "tpu"
|
||||
device_type: str = "tpu"
|
||||
dispatch_key: str = "XLA"
|
||||
supported_quantization: list[str] = ["tpu_int8"]
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
@ -16,6 +16,7 @@ logger = init_logger(__name__)
|
||||
|
||||
class XPUPlatform(Platform):
|
||||
_enum = PlatformEnum.XPU
|
||||
device_name: str = "xpu"
|
||||
device_type: str = "xpu"
|
||||
dispatch_key: str = "XPU"
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user