[CI/Build] Remove unnecessary flags from test registry (#27353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-23 22:42:40 +08:00 committed by GitHub
parent 237cf6d32a
commit fe2016de2d
13 changed files with 89 additions and 123 deletions

View File

@@ -374,8 +374,8 @@ th {
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ |
-| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ |
-| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
+| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
 | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
 | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
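The rows above come from the supported text-generation models table in the docs. As a hedged sketch, loading one of the listed checkpoints with vLLM's offline `LLM` API looks like this (sampling settings are illustrative and not taken from the docs):

```python
from vllm import LLM, SamplingParams

# Load a model listed in the table above; trust_remote_code is only needed
# for repos that ship custom modeling code.
llm = LLM(model="tencent/Hunyuan-7B-Instruct", trust_remote_code=True)
sampling = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Briefly explain what a mixture-of-experts model is."], sampling)
print(outputs[0].outputs[0].text)
```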

View File

@@ -244,7 +244,7 @@ def _compare_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
     hf_config = get_config(model_id, trust_remote_code)
-    skip_tokenizer_init = model_info.skip_tokenizer_init
+    require_embed_inputs = model_info.require_embed_inputs
     max_num_seqs = model_info.max_num_seqs
     dtype = "float16"
@@ -299,8 +299,14 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
-    if skip_tokenizer_init:
-        common_args.append("--skip-tokenizer-init")
+    if require_embed_inputs:
+        common_args.extend(
+            [
+                "--skip-tokenizer-init",
+                "--enable-prompt-embeds",
+                "--enable-mm-embeds",
+            ]
+        )
     if max_num_seqs:
         common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
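Both `_compare_tp` above and `_compare_sp` below expand `require_embed_inputs` into the same three engine flags. A hypothetical shared helper (not part of this commit) that captures the expansion would look like this:

```python
# Sketch only: the flag names mirror the ones appended in the diffs above,
# but this helper itself does not exist in the test suite.
EMBED_INPUT_FLAGS = [
    "--skip-tokenizer-init",
    "--enable-prompt-embeds",
    "--enable-mm-embeds",
]


def add_embed_input_flags(common_args: list[str], require_embed_inputs: bool) -> None:
    """Append the engine flags implied by require_embed_inputs."""
    if require_embed_inputs:
        common_args.extend(EMBED_INPUT_FLAGS)
```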

View File

@@ -181,7 +181,7 @@ def _compare_sp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
-    skip_tokenizer_init = model_info.skip_tokenizer_init
+    require_embed_inputs = model_info.require_embed_inputs
     if load_format == "dummy":
         # Avoid OOM
@@ -233,8 +233,14 @@ def _compare_sp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
-    if skip_tokenizer_init:
-        common_args.append("--skip-tokenizer-init")
+    if require_embed_inputs:
+        common_args.extend(
+            [
+                "--skip-tokenizer-init",
+                "--enable-prompt-embeds",
+                "--enable-mm-embeds",
+            ]
+        )
     compilation_config = {
         "mode": CompilationMode.VLLM_COMPILE,

View File

@@ -114,7 +114,9 @@ def test_get_gen_prompt(
         trust_remote_code=model_info.trust_remote_code,
         revision=model_info.revision,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )

View File

@@ -1742,7 +1742,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )
@@ -1842,7 +1844,9 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )
@@ -1903,7 +1907,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )
@@ -1961,7 +1967,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )

View File

@@ -71,8 +71,9 @@ def run_test(
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
         vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
-    if model_info.skip_tokenizer_init:
-        vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init
+    if model_info.require_embed_inputs:
+        for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"):
+            vllm_runner_kwargs_[k] = model_info.require_embed_inputs
     if vllm_runner_kwargs:
         vllm_runner_kwargs_.update(vllm_runner_kwargs)

View File

@@ -108,7 +108,9 @@ def _test_processing_correctness(
         hf_overrides=model_info.hf_overrides,
         # Ensure that the cache can fit all of the data
         mm_processor_cache_gb=2048,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )

View File

@@ -218,7 +218,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=hf_overrides_fn,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )

View File

@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         dtype=model_info.dtype,
     )

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from typing import Any, Literal
 import pytest
-import torch
 from packaging.version import Version
 from transformers import __version__ as TRANSFORMERS_VERSION
@@ -33,6 +32,11 @@ class _HfExamplesInfo:
     for speculative decoding.
     """
+    speculative_method: str | None = None
+    """
+    The method to use for speculative decoding.
+    """
     min_transformers_version: str | None = None
     """
     The minimum version of HF Transformers that is required to run this model.
@@ -48,9 +52,10 @@ class _HfExamplesInfo:
     The reason for the minimum/maximum version requirement.
     """
-    skip_tokenizer_init: bool = False
+    require_embed_inputs: bool = False
     """
-    If true, skip initialization of tokenizer and detokenizer.
+    If `True`, enables prompt and multi-modal embedding inputs while
+    disabling tokenization.
     """
     dtype: ModelDType = "auto"
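At the engine level, the single registry field fans out into three keyword arguments, which is what the call sites elsewhere in this commit do. A hedged sketch of that mapping (the helper itself is illustrative and does not exist in the test suite; the kwarg names are the ones used throughout the diff):

```python
# Illustrative only: call sites in the tests pass these kwargs directly
# rather than going through a helper like this.
def embed_input_kwargs(model_info: "_HfExamplesInfo") -> dict[str, bool]:
    flag = model_info.require_embed_inputs
    return {
        "skip_tokenizer_init": flag,
        "enable_prompt_embeds": flag,
        "enable_mm_embeds": flag,
    }
```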
@@ -168,10 +173,7 @@ class _HfExamplesInfo:
 _TEXT_GENERATION_EXAMPLE_MODELS = {
     # [Decoder-only]
-    "ApertusForCausalLM": _HfExamplesInfo(
-        "swiss-ai/Apertus-8B-Instruct-2509",
-        min_transformers_version="4.56.0",
-    ),
+    "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
     "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
     "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
     "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
@@ -192,7 +194,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     ),
     "BambaForCausalLM": _HfExamplesInfo(
         "ibm-ai-platform/Bamba-9B-v1",
-        min_transformers_version="4.55.3",
         extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"},
     ),
     "BloomForCausalLM": _HfExamplesInfo(
@@ -212,11 +213,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         "CohereForAI/c4ai-command-r7b-12-2024",
         trust_remote_code=True,
     ),
-    "CwmForCausalLM": _HfExamplesInfo(
-        "facebook/cwm",
-        trust_remote_code=True,
-        is_available_online=False,
-    ),
+    "CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
     "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
     "DeciLMForCausalLM": _HfExamplesInfo(
         "nvidia/Llama-3_3-Nemotron-Super-49B-v1",
@@ -232,18 +229,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         trust_remote_code=True,
     ),
     "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
-    "Ernie4_5ForCausalLM": _HfExamplesInfo(
-        "baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"
-    ),
-    "Ernie4_5_MoeForCausalLM": _HfExamplesInfo(
-        "baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54"
-    ),
+    "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
+    "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
     "ExaoneForCausalLM": _HfExamplesInfo(
         "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
     ),
-    "Exaone4ForCausalLM": _HfExamplesInfo(
-        "LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54"
-    ),
+    "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
@@ -251,14 +242,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
-    "Gemma3nForCausalLM": _HfExamplesInfo(
-        "google/gemma-3n-E2B-it", min_transformers_version="4.53"
-    ),
+    "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
     "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
-    "Glm4MoeForCausalLM": _HfExamplesInfo(
-        "zai-org/GLM-4.5", min_transformers_version="4.54"
-    ),
+    "Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
     "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
     "GPTBigCodeForCausalLM": _HfExamplesInfo(
         "bigcode/starcoder",
@@ -266,8 +253,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
             "tiny": "bigcode/tiny_starcoder_py",
             "santacoder": "bigcode/gpt_bigcode-santacoder",
         },
-        min_transformers_version="4.55.1",
-        transformers_version_reason="HF model broken in 4.55.0",
     ),
     "GPTJForCausalLM": _HfExamplesInfo(
         "Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}
@@ -279,8 +264,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
     "GraniteMoeHybridForCausalLM": _HfExamplesInfo(
-        "ibm-granite/granite-4.0-tiny-preview",
-        min_transformers_version="4.55.3",
+        "ibm-granite/granite-4.0-tiny-preview"
     ),
     "GraniteMoeSharedForCausalLM": _HfExamplesInfo(
         "ibm-research/moe-7b-1b-active-shared-experts"
@@ -288,15 +272,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Grok1ModelForCausalLM": _HfExamplesInfo(
         "hpcai-tech/grok-1", trust_remote_code=True
     ),
+    "HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"),
     "HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
         "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
     ),
-    # TODO: Remove is_available_online once their config.json is fixed
-    "HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
-        "tencent/Hunyuan-7B-Instruct-0124",
-        trust_remote_code=True,
-        is_available_online=False,
-    ),
     "InternLMForCausalLM": _HfExamplesInfo(
         "internlm/internlm-chat-7b", trust_remote_code=True
     ),
@@ -312,15 +291,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
     "JambaForCausalLM": _HfExamplesInfo(
         "ai21labs/AI21-Jamba-1.5-Mini",
-        min_transformers_version="4.55.3",
         extras={
             "tiny": "ai21labs/Jamba-tiny-dev",
             "random": "ai21labs/Jamba-tiny-random",
         },
     ),
-    "Lfm2ForCausalLM": _HfExamplesInfo(
-        "LiquidAI/LFM2-1.2B", min_transformers_version="4.54"
-    ),
+    "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
     "Lfm2MoeForCausalLM": _HfExamplesInfo(
         "LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"
     ),
@@ -338,7 +314,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     ),
     "Llama4ForCausalLM": _HfExamplesInfo(
         "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-        is_available_online=False,
     ),
     "LongcatFlashForCausalLM": _HfExamplesInfo(
         "meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True
@@ -346,7 +321,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
     "Mamba2ForCausalLM": _HfExamplesInfo(
         "mistralai/Mamba-Codestral-7B-v0.1",
-        min_transformers_version="4.55.3",
         extras={
             "random": "yujiepan/mamba2-codestral-v0.1-tiny-random",
         },
@@ -421,7 +395,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "SeedOssForCausalLM": _HfExamplesInfo(
         "ByteDance-Seed/Seed-OSS-36B-Instruct",
         trust_remote_code=True,
-        is_available_online=False,
     ),
     "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
@@ -488,7 +461,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
     "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
-        "naver/splade-v3", is_available_online=False
+        "naver/splade-v3",
+        hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
     ),
     # [Multimodal]
     "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
@@ -499,18 +473,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
     "PrithviGeoSpatialMAE": _HfExamplesInfo(
         "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
-        dtype=torch.float16,
+        dtype="float16",
         enforce_eager=True,
-        skip_tokenizer_init=True,
-        # This is to avoid the model
-        # going OOM in CI
+        require_embed_inputs=True,
+        # This is to avoid the model going OOM in CI
         max_num_seqs=32,
     ),
     "Terratorch": _HfExamplesInfo(
         "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
-        dtype=torch.float16,
+        dtype="float16",
         enforce_eager=True,
-        skip_tokenizer_init=True,
+        require_embed_inputs=True,
         # This is to avoid the model going OOM in CI
         max_num_seqs=32,
     ),
@@ -598,10 +571,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     ),
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
-    "Gemma3nForConditionalGeneration": _HfExamplesInfo(
-        "google/gemma-3n-E2B-it",
-        min_transformers_version="4.53",
-    ),
+    "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
     "GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
         "ibm-granite/granite-speech-3.3-2b"
     ),
@@ -611,9 +581,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         hf_overrides={"architectures": ["GLM4VForCausalLM"]},
     ),
     "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
-    "Glm4vMoeForConditionalGeneration": _HfExamplesInfo(
-        "zai-org/GLM-4.5V", min_transformers_version="4.56"
-    ),
+    "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
     "H2OVLChatModel": _HfExamplesInfo(
         "h2oai/h2ovl-mississippi-800m",
         trust_remote_code=True,
@@ -627,9 +595,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     ),
     "Idefics3ForConditionalGeneration": _HfExamplesInfo(
         "HuggingFaceM4/Idefics3-8B-Llama3",
-        {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
-        min_transformers_version="4.56",
-        transformers_version_reason="HF model broken in 4.55",
+        extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
     ),
     "InternS1ForConditionalGeneration": _HfExamplesInfo(
         "internlm/Intern-S1", trust_remote_code=True
@@ -781,13 +747,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         "Qwen/Qwen3-VL-4B-Instruct",
         max_model_len=4096,
         min_transformers_version="4.57",
-        is_available_online=False,
     ),
     "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo(
         "Qwen/Qwen3-VL-30B-A3B-Instruct",
         max_model_len=4096,
         min_transformers_version="4.57",
-        is_available_online=False,
     ),
     "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
         "Qwen/Qwen3-Omni-30B-A3B-Instruct",
@@ -799,9 +763,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         "Skywork/Skywork-R1V-38B", trust_remote_code=True
     ),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo(
-        "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
-        min_transformers_version="4.56",
-        transformers_version_reason="HF model broken in 4.55",
+        "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
     ),
     "Step3VLForConditionalGeneration": _HfExamplesInfo(
         "stepfun-ai/step3", trust_remote_code=True
@@ -817,7 +779,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     ),
     "VoxtralForConditionalGeneration": _HfExamplesInfo(
         "mistralai/Voxtral-Mini-3B-2507",
-        min_transformers_version="4.54",
         # disable this temporarily until we support HF format
         is_available_online=False,
     ),
@@ -878,8 +839,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
     "EagleMiniCPMForCausalLM": _HfExamplesInfo(
         "openbmb/MiniCPM-1B-sft-bf16",
         trust_remote_code=True,
-        is_available_online=False,
         speculative_model="openbmb/MiniCPM-2B-sft-bf16",
+        speculative_method="eagle",
         tokenizer="openbmb/MiniCPM-2B-sft-bf16",
     ),
     "ErnieMTPModel": _HfExamplesInfo(
@@ -890,8 +851,6 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
     "Glm4MoeMTPModel": _HfExamplesInfo(
         "zai-org/GLM-4.5",
         speculative_model="zai-org/GLM-4.5",
-        min_transformers_version="4.56",
-        is_available_online=False,
     ),
     "LongCatFlashMTPModel": _HfExamplesInfo(
         "meituan-longcat/LongCat-Flash-Chat",

View File

@@ -105,20 +105,19 @@ def can_initialize(
     if model_arch == "WhisperForConditionalGeneration":
         m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    extra_args = {}
-    if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"):
-        extra_args["enable_mm_embeds"] = True
     LLM(
         model_info.default,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
         revision=model_info.revision,
         enforce_eager=model_info.enforce_eager,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         dtype=model_info.dtype,
         speculative_config={
             "model": model_info.speculative_model,
+            "method": model_info.speculative_method,
             "num_speculative_tokens": 1,
         }
         if model_info.speculative_model
@@ -133,7 +132,6 @@ def can_initialize(
         else "vllm",
         hf_overrides=hf_overrides_fn,
         max_num_seqs=model_info.max_num_seqs,
-        **extra_args,
    )
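With `speculative_method` now stored in the registry, the `speculative_config` assembled above for the Eagle MiniCPM entry would come out roughly as follows (the values are taken from the registry diff; spelling it out as a literal dict is only for illustration):

```python
# What can_initialize builds when model_info.speculative_model is set
# (values here are the EagleMiniCPMForCausalLM entry from the registry).
speculative_config = {
    "model": "openbmb/MiniCPM-2B-sft-bf16",
    "method": "eagle",
    "num_speculative_tokens": 1,
}
```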

View File

@@ -309,7 +309,9 @@ def build_model_context(
         limit_mm_per_prompt=limit_mm_per_prompt,
         mm_processor_cache_gb=mm_processor_cache_gb,
         hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        skip_tokenizer_init=model_info.require_embed_inputs,
+        enable_prompt_embeds=model_info.require_embed_inputs,
+        enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
         **model_config_kwargs,
     )

View File

@@ -36,9 +36,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from packaging.version import Version
 from transformers import BatchFeature
-from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor,
@@ -1270,14 +1268,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
             video_mm_data = dict()
             video_mm_data["videos"] = [[video_array]]
-            # backward compatibility for Transformers 4.55
             unuse_metadata = ["do_sample_frames"]
-            if (
-                not hasattr(VideoMetadata, "frames_indices")
-                and "frames_indices" in metadata
-            ):
-                unuse_metadata.append("frames_indices")
             video_mm_data["video_metadata"] = [
                 [
                     VideoMetadata(
@@ -1296,24 +1287,11 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
                 mm_kwargs=video_mm_kwargs,
                 tok_kwargs=tok_kwargs,
             )
-            if not video_mm_kwargs["do_sample_frames"] and Version(
-                TRANSFORMERS_VERSION
-            ) < Version("4.56.0"):
-                # Transformers v4.55 has incorrect timestamps issue for
-                # skip sampling. We construct the placeholder manually to
-                # get placeholders with correct timestamps.
-                placeholder = self.info._construct_video_placeholder(
-                    video_array,
-                    metadata,
-                    video_outputs["video_grid_thw"].squeeze(0),
-                )
-                video_placeholder = processor.tokenizer.decode(placeholder)
-            else:
-                input_ids = video_outputs.pop("input_ids")
-                input_ids[input_ids == processor.image_token_id] = (
-                    processor.video_token_id
-                )
-                video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
+            input_ids = video_outputs.pop("input_ids")
+            input_ids[input_ids == processor.image_token_id] = (
+                processor.video_token_id
+            )
+            video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
             prompt = prompt.replace(
                 "<|begin_of_video|><|video|><|end_of_video|>",
                 video_placeholder,
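The simplified branch above derives the video placeholder purely by remapping token ids and decoding. A minimal, self-contained illustration of that remapping step (the token ids below are made up; real values come from the HF processor):

```python
import torch

IMAGE_TOKEN_ID = 100  # placeholder value, not the real GLM-4.1V id
VIDEO_TOKEN_ID = 101  # placeholder value, not the real GLM-4.1V id

input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
# Every image-token id is rewritten in place to the video-token id,
# so decoding yields a video placeholder string of the right length.
input_ids[input_ids == IMAGE_TOKEN_ID] = VIDEO_TOKEN_ID
assert input_ids.tolist() == [[1, VIDEO_TOKEN_ID, VIDEO_TOKEN_ID, 2]]
```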