mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 03:34:26 +08:00
[Frontend][Model] Add 'float16' to possible mamba cache dtype values, override mamba SSM cache dtype value for NemotronH (#29978)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
This commit is contained in:
parent
60a66ea2dc
commit
6038b1b04b
@ -29,7 +29,7 @@ CacheDType = Literal[
|
|||||||
"fp8_inc",
|
"fp8_inc",
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
]
|
]
|
||||||
MambaDType = Literal["auto", "float32"]
|
MambaDType = Literal["auto", "float32", "float16"]
|
||||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
|
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
|
||||||
KVOffloadingBackend = Literal["native", "lmcache"]
|
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||||
|
|
||||||
|
|||||||
@ -485,6 +485,26 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
|||||||
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||||
|
|
||||||
|
|
||||||
|
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    """Config adjustment for NemotronH: resolves the mamba SSM cache dtype."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an ``"auto"`` mamba SSM cache dtype for NemotronH models.

        When ``cache_config.mamba_ssm_cache_dtype`` is ``"auto"``, it is
        replaced by the ``mamba_ssm_cache_dtype`` attribute of the HF config,
        or by ``"float16"`` when the HF config has no such attribute. Any
        other value is left untouched.

        Args:
            vllm_config: Configuration whose ``cache_config`` is updated in
                place; ``model_config.hf_config`` is only read.
        """
        cache_cfg = vllm_config.cache_config

        # Only "auto" is overridden; every other setting is left as-is.
        if cache_cfg.mamba_ssm_cache_dtype != "auto":
            return

        # Use the HF config's value if the attribute exists, else float16.
        resolved_dtype = getattr(
            vllm_config.model_config.hf_config,
            "mamba_ssm_cache_dtype",
            "float16",
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_cfg.mamba_ssm_cache_dtype = resolved_dtype
|
||||||
|
|
||||||
|
|
||||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||||
"GteModel": SnowflakeGteNewModelConfig,
|
"GteModel": SnowflakeGteNewModelConfig,
|
||||||
"GteNewModel": GteNewModelConfig,
|
"GteNewModel": GteNewModelConfig,
|
||||||
@ -502,4 +522,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
|||||||
"Mamba2ForCausalLM": MambaModelConfig,
|
"Mamba2ForCausalLM": MambaModelConfig,
|
||||||
"FalconMambaForCausalLM": MambaModelConfig,
|
"FalconMambaForCausalLM": MambaModelConfig,
|
||||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||||
|
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,6 +28,7 @@ else:
|
|||||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
"half": torch.half,
|
"half": torch.half,
|
||||||
|
"float16": torch.float16,
|
||||||
"bfloat16": torch.bfloat16,
|
"bfloat16": torch.bfloat16,
|
||||||
"float": torch.float,
|
"float": torch.float,
|
||||||
"fp8": torch.uint8,
|
"fp8": torch.uint8,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user