mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 10:49:38 +08:00
update neuron config (#16289)
Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
This commit is contained in:
parent
ec7da6fcf3
commit
24834f4894
@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
|
||||
|
||||
def _get_neuron_config_after_override(default_neuron_config,
|
||||
overridden_neuron_config):
|
||||
from transformers_neuronx.config import NeuronConfig
|
||||
from transformers_neuronx.config import (ContinuousBatchingConfig,
|
||||
GenerationConfig,
|
||||
KVCacheQuantizationConfig,
|
||||
NeuronConfig, QuantizationConfig,
|
||||
SparseAttnConfig)
|
||||
|
||||
overridden_neuron_config = overridden_neuron_config or {}
|
||||
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
|
||||
if sparse_attn:
|
||||
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
|
||||
**sparse_attn)
|
||||
|
||||
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
|
||||
if kv_cache_quant:
|
||||
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
|
||||
**kv_cache_quant)
|
||||
|
||||
continuous_batching = overridden_neuron_config.pop("continuous_batching",
|
||||
{})
|
||||
if continuous_batching:
|
||||
overridden_neuron_config[
|
||||
"continuous_batching"] = ContinuousBatchingConfig(
|
||||
**continuous_batching)
|
||||
|
||||
quant = overridden_neuron_config.pop("quant", {})
|
||||
if quant:
|
||||
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
|
||||
|
||||
on_device_generation = overridden_neuron_config.pop(
|
||||
"on_device_generation", {})
|
||||
if on_device_generation:
|
||||
overridden_neuron_config["on_device_generation"] = GenerationConfig(
|
||||
**on_device_generation)
|
||||
default_neuron_config.update(overridden_neuron_config)
|
||||
return NeuronConfig(**default_neuron_config)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user