Register model-arch config convertors for the "mamba", "mamba2" and
"terratorch" model types.

Adds MambaModelArchConfigConvertor and TerratorchModelArchConfigConvertor.
Both subclass ModelArchConfigConvertorBase and override get_head_size()
and get_total_num_kv_heads() to return 0. In MODEL_ARCH_CONFIG_CONVERTORS,
"mamba" and "mamba2" share the Mamba convertor, while "terratorch" maps to
its own class.

NOTE(review): presumably 0 signals that these models have no attention KV
heads -- confirm against the callers of the base-class getters.
NOTE(review): the two new classes have identical bodies.

diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py
index db5899eab5c12..999f46083c6e0 100644
--- a/vllm/transformers_utils/model_arch_config_convertor.py
+++ b/vllm/transformers_utils/model_arch_config_convertor.py
@@ -269,6 +269,22 @@ class ModelArchConfigConvertorBase:
         return model_arch_config
 
 
+class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
+    def get_head_size(self) -> int:
+        return 0
+
+    def get_total_num_kv_heads(self) -> int:
+        return 0
+
+
+class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase):
+    def get_head_size(self) -> int:
+        return 0
+
+    def get_total_num_kv_heads(self) -> int:
+        return 0
+
+
 class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
     def get_head_size(self) -> int:
         return getattr(self.hf_text_config, "attention_head_dim", 0)
@@ -357,6 +373,9 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
 
 # hf_config.model_type -> convertor class
 MODEL_ARCH_CONFIG_CONVERTORS = {
+    "mamba": MambaModelArchConfigConvertor,
+    "mamba2": MambaModelArchConfigConvertor,
+    "terratorch": TerratorchModelArchConfigConvertor,
     "zamba2": Zamba2ModelArchConfigConvertor,
     "mpt": MPTModelArchConfigConvertor,
     "dbrx": DbrxModelArchConfigConvertor,