From e0cbad4e30d7df4e9ee4634939ce19042639b735 Mon Sep 17 00:00:00 2001
From: Satyajith Chilappagari
Date: Tue, 27 May 2025 15:10:33 -0700
Subject: [PATCH] [Neuron] Support quantization on neuron (#18283)

Signed-off-by: Satyajith Chilappagari
---
 tests/neuron/1_core/test_neuron_quant.py | 11 +++++++++++
 .../layers/quantization/neuron_quant.py  |  9 ++++++++-
 vllm/platforms/neuron.py                 |  2 +-
 3 files changed, 20 insertions(+), 2 deletions(-)
 create mode 100644 tests/neuron/1_core/test_neuron_quant.py

diff --git a/tests/neuron/1_core/test_neuron_quant.py b/tests/neuron/1_core/test_neuron_quant.py
new file mode 100644
index 0000000000000..68f0cb8054b4f
--- /dev/null
+++ b/tests/neuron/1_core/test_neuron_quant.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.model_executor.layers.quantization.neuron_quant import (
+    NeuronQuantConfig)
+
+
+def test_get_supported_act_dtypes():
+    neuron_quant_config = NeuronQuantConfig()
+    supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
+    target_list = ["any_dtype1", "any_dtype2"]
+    for dtype in target_list:
+        assert dtype in supported_act_dtypes
diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py
index 38b374feea81d..b2d6bf5dbf9cc 100644
--- a/vllm/model_executor/layers/quantization/neuron_quant.py
+++ b/vllm/model_executor/layers/quantization/neuron_quant.py
@@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import (
 SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
 
 
+class AlwaysSupportedDtypes(list):
+
+    def __contains__(self, item):
+        return True
+
+
 class NeuronQuantConfig(QuantizationConfig):
     """Int8 Quantization Config class for Neuron Backend."""
 
@@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig):
         return "neuron_quant"
 
     def get_supported_act_dtypes(self) -> list[str]:
-        return SUPPORTED_QUANT_DTYPE_LIST
+        # Neuron implements custom handling logic for quantization support
+        return AlwaysSupportedDtypes()
 
     @classmethod
     def get_min_capability(cls) -> int:
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 9cd49fd348049..474c70d04140b 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
     device_name: str = "neuron"
     device_type: str = "neuron"
     ray_device_key: str = "neuron_cores"
-    supported_quantization: list[str] = ["neuron_quant"]
+    supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
     device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
 
     @classmethod