[CI/Build] Remove unnecessary flags from test registry (#27353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-23 22:42:40 +08:00 committed by GitHub
parent 237cf6d32a
commit fe2016de2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 89 additions and 123 deletions

View File

@ -374,8 +374,8 @@ th {
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |

View File

@ -244,7 +244,7 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
require_embed_inputs = model_info.require_embed_inputs
max_num_seqs = model_info.max_num_seqs
dtype = "float16"
@ -299,8 +299,14 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if require_embed_inputs:
common_args.extend(
[
"--skip-tokenizer-init",
"--enable-prompt-embeds",
"--enable-mm-embeds",
]
)
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

View File

@ -181,7 +181,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init
require_embed_inputs = model_info.require_embed_inputs
if load_format == "dummy":
# Avoid OOM
@ -233,8 +233,14 @@ def _compare_sp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if require_embed_inputs:
common_args.extend(
[
"--skip-tokenizer-init",
"--enable-prompt-embeds",
"--enable-mm-embeds",
]
)
compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,

View File

@ -114,7 +114,9 @@ def test_get_gen_prompt(
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

View File

@ -1742,7 +1742,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
@ -1842,7 +1844,9 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
@ -1903,7 +1907,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
@ -1961,7 +1967,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

View File

@ -71,8 +71,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init
if model_info.require_embed_inputs:
for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"):
vllm_runner_kwargs_[k] = model_info.require_embed_inputs
if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)

View File

@ -108,7 +108,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

View File

@ -218,7 +218,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

View File

@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

View File

@ -6,7 +6,6 @@ from dataclasses import dataclass, field
from typing import Any, Literal
import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
@ -33,6 +32,11 @@ class _HfExamplesInfo:
for speculative decoding.
"""
speculative_method: str | None = None
"""
The method to use for speculative decoding.
"""
min_transformers_version: str | None = None
"""
The minimum version of HF Transformers that is required to run this model.
@ -48,9 +52,10 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement.
"""
skip_tokenizer_init: bool = False
require_embed_inputs: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
If `True`, enables prompt and multi-modal embedding inputs while
disabling tokenization.
"""
dtype: ModelDType = "auto"
@ -168,10 +173,7 @@ class _HfExamplesInfo:
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"ApertusForCausalLM": _HfExamplesInfo(
"swiss-ai/Apertus-8B-Instruct-2509",
min_transformers_version="4.56.0",
),
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
@ -192,7 +194,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"BambaForCausalLM": _HfExamplesInfo(
"ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"},
),
"BloomForCausalLM": _HfExamplesInfo(
@ -212,11 +213,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"CohereForAI/c4ai-command-r7b-12-2024",
trust_remote_code=True,
),
"CwmForCausalLM": _HfExamplesInfo(
"facebook/cwm",
trust_remote_code=True,
is_available_online=False,
),
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo(
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
@ -232,18 +229,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"
),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54"
),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
"ExaoneForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
),
"Exaone4ForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54"
),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
@ -251,14 +242,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo(
"google/gemma-3n-E2B-it", min_transformers_version="4.53"
),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
"Glm4MoeForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.5", min_transformers_version="4.54"
),
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
"bigcode/starcoder",
@ -266,8 +253,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny": "bigcode/tiny_starcoder_py",
"santacoder": "bigcode/gpt_bigcode-santacoder",
},
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0",
),
"GPTJForCausalLM": _HfExamplesInfo(
"Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}
@ -279,8 +264,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo(
"ibm-granite/granite-4.0-tiny-preview",
min_transformers_version="4.55.3",
"ibm-granite/granite-4.0-tiny-preview"
),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo(
"ibm-research/moe-7b-1b-active-shared-experts"
@ -288,15 +272,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Grok1ModelForCausalLM": _HfExamplesInfo(
"hpcai-tech/grok-1", trust_remote_code=True
),
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
),
# TODO: Remove is_available_online once their config.json is fixed
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-7B-Instruct-0124",
trust_remote_code=True,
is_available_online=False,
),
"InternLMForCausalLM": _HfExamplesInfo(
"internlm/internlm-chat-7b", trust_remote_code=True
),
@ -312,15 +291,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo(
"ai21labs/AI21-Jamba-1.5-Mini",
min_transformers_version="4.55.3",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random",
},
),
"Lfm2ForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-1.2B", min_transformers_version="4.54"
),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
"Lfm2MoeForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"
),
@ -338,7 +314,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"Llama4ForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
is_available_online=False,
),
"LongcatFlashForCausalLM": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True
@ -346,7 +321,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo(
"mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
extras={
"random": "yujiepan/mamba2-codestral-v0.1-tiny-random",
},
@ -421,7 +395,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"SeedOssForCausalLM": _HfExamplesInfo(
"ByteDance-Seed/Seed-OSS-36B-Instruct",
trust_remote_code=True,
is_available_online=False,
),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
@ -488,7 +461,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"naver/splade-v3", is_available_online=False
"naver/splade-v3",
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
),
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
@ -499,18 +473,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
"PrithviGeoSpatialMAE": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
dtype="float16",
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
require_embed_inputs=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
dtype="float16",
enforce_eager=True,
skip_tokenizer_init=True,
require_embed_inputs=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
@ -598,10 +571,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo(
"google/gemma-3n-E2B-it",
min_transformers_version="4.53",
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-3.3-2b"
),
@ -611,9 +581,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-4.5V", min_transformers_version="4.56"
),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
@ -627,9 +595,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
),
"InternS1ForConditionalGeneration": _HfExamplesInfo(
"internlm/Intern-S1", trust_remote_code=True
@ -781,13 +747,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen/Qwen3-VL-4B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-Omni-30B-A3B-Instruct",
@ -799,9 +763,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Skywork/Skywork-R1V-38B", trust_remote_code=True
),
"SmolVLMForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
"Step3VLForConditionalGeneration": _HfExamplesInfo(
"stepfun-ai/step3", trust_remote_code=True
@ -817,7 +779,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
min_transformers_version="4.54",
# disable this temporarily until we support HF format
is_available_online=False,
),
@ -878,8 +839,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EagleMiniCPMForCausalLM": _HfExamplesInfo(
"openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
speculative_method="eagle",
tokenizer="openbmb/MiniCPM-2B-sft-bf16",
),
"ErnieMTPModel": _HfExamplesInfo(
@ -890,8 +851,6 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.56",
is_available_online=False,
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",

View File

@ -105,20 +105,19 @@ def can_initialize(
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
extra_args = {}
if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"):
extra_args["enable_mm_embeds"] = True
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
dtype=model_info.dtype,
speculative_config={
"model": model_info.speculative_model,
"method": model_info.speculative_method,
"num_speculative_tokens": 1,
}
if model_info.speculative_model
@ -133,7 +132,6 @@ def can_initialize(
else "vllm",
hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs,
**extra_args,
)

View File

@ -309,7 +309,9 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)

View File

@ -36,9 +36,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from packaging.version import Version
from transformers import BatchFeature
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
@ -1270,14 +1268,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
# backward compatibility for Transformers 4.55
unuse_metadata = ["do_sample_frames"]
if (
not hasattr(VideoMetadata, "frames_indices")
and "frames_indices" in metadata
):
unuse_metadata.append("frames_indices")
video_mm_data["video_metadata"] = [
[
VideoMetadata(
@ -1296,24 +1287,11 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
if not video_mm_kwargs["do_sample_frames"] and Version(
TRANSFORMERS_VERSION
) < Version("4.56.0"):
# Transformers v4.55 has incorrect timestamps issue for
# skip sampling. We construct the placeholder manually to
# get placeholders with correct timestamps.
placeholder = self.info._construct_video_placeholder(
video_array,
metadata,
video_outputs["video_grid_thw"].squeeze(0),
)
video_placeholder = processor.tokenizer.decode(placeholder)
else:
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id
)
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id
)
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
"<|begin_of_video|><|video|><|end_of_video|>",
video_placeholder,