update neuron config (#16289)

Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
This commit is contained in:
ajayvohra2005 2025-04-09 06:43:22 -04:00 committed by GitHub
parent ec7da6fcf3
commit 24834f4894
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.config import (ContinuousBatchingConfig,
GenerationConfig,
KVCacheQuantizationConfig,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)
overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
**sparse_attn)
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
if kv_cache_quant:
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
**kv_cache_quant)
continuous_batching = overridden_neuron_config.pop("continuous_batching",
{})
if continuous_batching:
overridden_neuron_config[
"continuous_batching"] = ContinuousBatchingConfig(
**continuous_batching)
quant = overridden_neuron_config.pop("quant", {})
if quant:
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
on_device_generation = overridden_neuron_config.pop(
"on_device_generation", {})
if on_device_generation:
overridden_neuron_config["on_device_generation"] = GenerationConfig(
**on_device_generation)
default_neuron_config.update(overridden_neuron_config)
return NeuronConfig(**default_neuron_config)