diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json index 6916ec50d9743..c6d321f6c3257 100644 --- a/tests/config/model_arch_groundtruth.json +++ b/tests/config/model_arch_groundtruth.json @@ -1,4 +1,55 @@ { + "state-spaces/mamba-130m-hf": { + "architectures": [ + "MambaForCausalLM" + ], + "model_type": "mamba", + "text_model_type": "mamba", + "hidden_size": 768, + "total_num_hidden_layers": 24, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 50280, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.float32" + }, + "mistralai/Mamba-Codestral-7B-v0.1": { + "architectures": [ + "Mamba2ForCausalLM" + ], + "model_type": "mamba", + "text_model_type": "mamba", + "hidden_size": 4096, + "total_num_hidden_layers": 64, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 32768, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16" + }, + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": { + "architectures": [ + "Terratorch" + ], + "model_type": "timm_wrapper", + "text_model_type": "timm_wrapper", + "hidden_size": 0, + "total_num_hidden_layers": 0, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 0, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": true, + "dtype": "torch.float32" + }, "Zyphra/Zamba2-7B-instruct": { "architectures": [ "Zamba2ForCausalLM" diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 43750753ea514..365cc1104ccaf 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -15,6 +15,9 @@ def test_model_arch_config(): "meituan-longcat/LongCat-Flash-Chat", ] models_to_test = [ + "state-spaces/mamba-130m-hf", + "mistralai/Mamba-Codestral-7B-v0.1", + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "Zyphra/Zamba2-7B-instruct", "mosaicml/mpt-7b", "databricks/dbrx-instruct", diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 16fd5b6b6dcbd..d1e28cbe558bb 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -381,8 +381,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { "mamba": MambaModelArchConfigConvertor, - "mamba2": MambaModelArchConfigConvertor, - "terratorch": TerratorchModelArchConfigConvertor, + "timm_wrapper": TerratorchModelArchConfigConvertor, "zamba2": Zamba2ModelArchConfigConvertor, "mpt": MPTModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor,