From 24834f48943d4d106cca0764ffe3d68c0eb57500 Mon Sep 17 00:00:00 2001
From: ajayvohra2005
Date: Wed, 9 Apr 2025 06:43:22 -0400
Subject: [PATCH] update neuron config (#16289)

Signed-off-by: Ajay Vohra
---
 vllm/model_executor/model_loader/neuron.py | 33 +++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index d900fb3a7d397..67aaad10fcfe9 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
 
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
-    from transformers_neuronx.config import NeuronConfig
+    from transformers_neuronx.config import (ContinuousBatchingConfig,
+                                             GenerationConfig,
+                                             KVCacheQuantizationConfig,
+                                             NeuronConfig, QuantizationConfig,
+                                             SparseAttnConfig)
+
     overridden_neuron_config = overridden_neuron_config or {}
+    sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
+    if sparse_attn:
+        overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
+            **sparse_attn)
+
+    kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
+    if kv_cache_quant:
+        overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
+            **kv_cache_quant)
+
+    continuous_batching = overridden_neuron_config.pop("continuous_batching",
+                                                       {})
+    if continuous_batching:
+        overridden_neuron_config[
+            "continuous_batching"] = ContinuousBatchingConfig(
+                **continuous_batching)
+
+    quant = overridden_neuron_config.pop("quant", {})
+    if quant:
+        overridden_neuron_config["quant"] = QuantizationConfig(**quant)
+
+    on_device_generation = overridden_neuron_config.pop(
+        "on_device_generation", {})
+    if on_device_generation:
+        overridden_neuron_config["on_device_generation"] = GenerationConfig(
+            **on_device_generation)
     default_neuron_config.update(overridden_neuron_config)
     return NeuronConfig(**default_neuron_config)
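
For context, a minimal sketch of how the updated override path might be exercised through vLLM's
override_neuron_config engine argument. This is not part of the patch: the model name and the
nested field values are illustrative assumptions, and the exact keyword arguments accepted by the
transformers_neuronx config classes depend on the installed transformers-neuronx version.

    # Sketch only: nested sections of override_neuron_config are now promoted to
    # the corresponding transformers_neuronx config objects (QuantizationConfig,
    # GenerationConfig, ...) before NeuronConfig is constructed.
    from vllm import LLM

    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model choice
        max_num_seqs=4,
        override_neuron_config={
            # handled by QuantizationConfig(**{...}); field names assumed
            "quant": {"quant_dtype": "s8", "dequant_dtype": "f16"},
            # handled by GenerationConfig(**{...}); field names assumed
            "on_device_generation": {"max_length": 128, "do_sample": True},
        },
    )

Previously the whole override dict was passed to NeuronConfig as plain nested dicts; with this
change each recognized section is converted to its typed config object first, so overrides such as
the ones above take effect the same way as when constructing the transformers_neuronx configs
directly.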