diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 91f083a5534ba..067799a44db30 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -29,7 +29,7 @@ CacheDType = Literal[
     "fp8_inc",
     "fp8_ds_mla",
 ]
-MambaDType = Literal["auto", "float32"]
+MambaDType = Literal["auto", "float32", "float16"]
 PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
 KVOffloadingBackend = Literal["native", "lmcache"]
 
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 4bca36aa4b7de..fbeb28a1c0b36 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -485,6 +485,26 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
         logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
 
 
+class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        """Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto'
+        (or not explicitly set), to the value specified in the HF config, or to
+        float16 if not specified.
+        """
+        cache_config = vllm_config.cache_config
+        if cache_config.mamba_ssm_cache_dtype == "auto":
+            hf_config = vllm_config.model_config.hf_config
+            mamba_ssm_cache_dtype = getattr(
+                hf_config, "mamba_ssm_cache_dtype", "float16"
+            )
+            logger.info(
+                "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
+                mamba_ssm_cache_dtype,
+            )
+            cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
+
+
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
@@ -502,4 +522,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "Mamba2ForCausalLM": MambaModelConfig,
     "FalconMambaForCausalLM": MambaModelConfig,
     "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
+    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
 }
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index f5c49ac169f0c..c97efce312b56 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -28,6 +28,7 @@ else:
 STR_DTYPE_TO_TORCH_DTYPE = {
     "float32": torch.float32,
     "half": torch.half,
+    "float16": torch.float16,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8": torch.uint8,