add tests for attention free models

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu 2025-12-09 00:32:57 -08:00
parent c327dffce1
commit a8cc81a695
3 changed files with 55 additions and 2 deletions

View File

@ -1,4 +1,55 @@
{
"state-spaces/mamba-130m-hf": {
"architectures": [
"MambaForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 768,
"total_num_hidden_layers": 24,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 50280,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.float32"
},
"mistralai/Mamba-Codestral-7B-v0.1": {
"architectures": [
"Mamba2ForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 4096,
"total_num_hidden_layers": 64,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 32768,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"architectures": [
"Terratorch"
],
"model_type": "timm_wrapper",
"text_model_type": "timm_wrapper",
"hidden_size": 0,
"total_num_hidden_layers": 0,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 0,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": true,
"dtype": "torch.float32"
},
"Zyphra/Zamba2-7B-instruct": {
"architectures": [
"Zamba2ForCausalLM"

View File

@ -15,6 +15,9 @@ def test_model_arch_config():
"meituan-longcat/LongCat-Flash-Chat",
]
models_to_test = [
"state-spaces/mamba-130m-hf",
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"Zyphra/Zamba2-7B-instruct",
"mosaicml/mpt-7b",
"databricks/dbrx-instruct",

View File

@ -381,8 +381,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"mamba": MambaModelArchConfigConvertor,
"mamba2": MambaModelArchConfigConvertor,
"terratorch": TerratorchModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,
"zamba2": Zamba2ModelArchConfigConvertor,
"mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor,