[Neuron] Support quantization on neuron (#18283)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
This commit is contained in:
Satyajith Chilappagari 2025-05-27 15:10:33 -07:00 committed by GitHub
parent b48d5cca16
commit e0cbad4e30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 2 deletions

View File

@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
def test_get_supported_act_dtypes():
neuron_quant_config = NeuronQuantConfig()
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
target_list = ["any_dtype1", "any_dtype2"]
for dtype in target_list:
assert dtype in supported_act_dtypes

View File

@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import (
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
def __contains__(self, item):
return True
class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend."""
@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig):
return "neuron_quant"
def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST
# Neuron implements custom handling logic for quantization support
return AlwaysSupportedDtypes()
@classmethod
def get_min_capability(cls) -> int:

View File

@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
device_name: str = "neuron"
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant"]
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
@classmethod