mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 10:46:05 +08:00)
[CI/Build] Remove unnecessary flags from test registry (#27353)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 237cf6d32a
commit fe2016de2d
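The change collapses the per-model `skip_tokenizer_init` test flag into a single `require_embed_inputs` flag, which the test helpers expand into the three engine arguments that embedding-input models need. A minimal sketch of that expansion, using only names that appear in this diff (the helper function itself is illustrative, not part of the change):

# Illustrative only: the diff applies this expansion inline in each test helper.
def embed_input_kwargs(require_embed_inputs: bool) -> dict[str, bool]:
    if not require_embed_inputs:
        return {}
    # One registry flag expands into the three engine arguments.
    return {
        "skip_tokenizer_init": True,
        "enable_prompt_embeds": True,
        "enable_mm_embeds": True,
    }

Call sites splat these into `LLM(...)` or map them onto the equivalent `--skip-tokenizer-init`, `--enable-prompt-embeds`, and `--enable-mm-embeds` CLI flags, as the hunks below do.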
@@ -374,8 +374,8 @@ th {
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
@@ -244,7 +244,7 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
require_embed_inputs = model_info.require_embed_inputs
max_num_seqs = model_info.max_num_seqs

dtype = "float16"

@@ -299,8 +299,14 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if require_embed_inputs:
common_args.extend(
[
"--skip-tokenizer-init",
"--enable-prompt-embeds",
"--enable-mm-embeds",
]
)
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
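For a registry entry with `require_embed_inputs=True` and `max_num_seqs=32` (for example the Prithvi entries further down in this diff), the branches above leave `common_args` ending in roughly the following (illustrative, not an exact transcript):

# Illustrative tail of common_args after the branches above have run.
common_args = [
    # ... model, load-format and other shared arguments ...
    "--skip-tokenizer-init",
    "--enable-prompt-embeds",
    "--enable-mm-embeds",
    "--max-num-seqs", "32",
]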
@@ -181,7 +181,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init
require_embed_inputs = model_info.require_embed_inputs

if load_format == "dummy":
# Avoid OOM

@@ -233,8 +233,14 @@ def _compare_sp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if require_embed_inputs:
common_args.extend(
[
"--skip-tokenizer-init",
"--enable-prompt-embeds",
"--enable-mm-embeds",
]
)

compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,
@@ -114,7 +114,9 @@ def test_get_gen_prompt(
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -1742,7 +1742,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -1842,7 +1844,9 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -1903,7 +1907,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -1961,7 +1967,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
@@ -71,8 +71,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init
if model_info.require_embed_inputs:
for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"):
vllm_runner_kwargs_[k] = model_info.require_embed_inputs

if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)

@@ -108,7 +108,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -218,7 +218,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)

@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
from typing import Any, Literal

import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

@@ -33,6 +32,11 @@ class _HfExamplesInfo:
for speculative decoding.
"""

speculative_method: str | None = None
"""
The method to use for speculative decoding.
"""

min_transformers_version: str | None = None
"""
The minimum version of HF Transformers that is required to run this model.

@@ -48,9 +52,10 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement.
"""

skip_tokenizer_init: bool = False
require_embed_inputs: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
If `True`, enables prompt and multi-modal embedding inputs while
disabling tokenization.
"""

dtype: ModelDType = "auto"
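With the new field, a registry entry that needs embedding inputs declares a single flag instead of `skip_tokenizer_init`; the updated Prithvi entry further down in this diff follows this shape (reproduced here as a sketch, wrapped in a dict so it stands alone):

# Sketch of a registry entry using the new field; _HfExamplesInfo is the
# dataclass defined above, and the values mirror the Prithvi entry below.
_EXAMPLE_ENTRY = {
    "PrithviGeoSpatialMAE": _HfExamplesInfo(
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        dtype="float16",
        enforce_eager=True,
        require_embed_inputs=True,  # replaces skip_tokenizer_init=True
        # This is to avoid the model going OOM in CI
        max_num_seqs=32,
    ),
}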
@@ -168,10 +173,7 @@ class _HfExamplesInfo:

_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"ApertusForCausalLM": _HfExamplesInfo(
"swiss-ai/Apertus-8B-Instruct-2509",
min_transformers_version="4.56.0",
),
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),

@@ -192,7 +194,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"BambaForCausalLM": _HfExamplesInfo(
"ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"},
),
"BloomForCausalLM": _HfExamplesInfo(

@@ -212,11 +213,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"CohereForAI/c4ai-command-r7b-12-2024",
trust_remote_code=True,
),
"CwmForCausalLM": _HfExamplesInfo(
"facebook/cwm",
trust_remote_code=True,
is_available_online=False,
),
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo(
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
@@ -232,18 +229,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"
),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54"
),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
"ExaoneForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
),
"Exaone4ForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54"
),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),

@@ -251,14 +242,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo(
"google/gemma-3n-E2B-it", min_transformers_version="4.53"
),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
"Glm4MoeForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.5", min_transformers_version="4.54"
),
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
"bigcode/starcoder",
@@ -266,8 +253,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"tiny": "bigcode/tiny_starcoder_py",
"santacoder": "bigcode/gpt_bigcode-santacoder",
},
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0",
),
"GPTJForCausalLM": _HfExamplesInfo(
"Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}

@@ -279,8 +264,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo(
"ibm-granite/granite-4.0-tiny-preview",
min_transformers_version="4.55.3",
"ibm-granite/granite-4.0-tiny-preview"
),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo(
"ibm-research/moe-7b-1b-active-shared-experts"
@@ -288,15 +272,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Grok1ModelForCausalLM": _HfExamplesInfo(
"hpcai-tech/grok-1", trust_remote_code=True
),
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
),
# TODO: Remove is_available_online once their config.json is fixed
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-7B-Instruct-0124",
trust_remote_code=True,
is_available_online=False,
),
"InternLMForCausalLM": _HfExamplesInfo(
"internlm/internlm-chat-7b", trust_remote_code=True
),

@@ -312,15 +291,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo(
"ai21labs/AI21-Jamba-1.5-Mini",
min_transformers_version="4.55.3",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random",
},
),
"Lfm2ForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-1.2B", min_transformers_version="4.54"
),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
"Lfm2MoeForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"
),
@@ -338,7 +314,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"Llama4ForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
is_available_online=False,
),
"LongcatFlashForCausalLM": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True

@@ -346,7 +321,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo(
"mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
extras={
"random": "yujiepan/mamba2-codestral-v0.1-tiny-random",
},

@@ -421,7 +395,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"SeedOssForCausalLM": _HfExamplesInfo(
"ByteDance-Seed/Seed-OSS-36B-Instruct",
trust_remote_code=True,
is_available_online=False,
),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
@@ -488,7 +461,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"naver/splade-v3", is_available_online=False
"naver/splade-v3",
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
),
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),

@@ -499,18 +473,17 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
"PrithviGeoSpatialMAE": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
dtype="float16",
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
require_embed_inputs=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
dtype="float16",
enforce_eager=True,
skip_tokenizer_init=True,
require_embed_inputs=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
@@ -598,10 +571,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo(
"google/gemma-3n-E2B-it",
min_transformers_version="4.53",
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-3.3-2b"
),

@@ -611,9 +581,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-4.5V", min_transformers_version="4.56"
),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,

@@ -627,9 +595,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
),
"InternS1ForConditionalGeneration": _HfExamplesInfo(
"internlm/Intern-S1", trust_remote_code=True
@@ -781,13 +747,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen/Qwen3-VL-4B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-Omni-30B-A3B-Instruct",

@@ -799,9 +763,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Skywork/Skywork-R1V-38B", trust_remote_code=True
),
"SmolVLMForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
"Step3VLForConditionalGeneration": _HfExamplesInfo(
"stepfun-ai/step3", trust_remote_code=True

@@ -817,7 +779,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
min_transformers_version="4.54",
# disable this temporarily until we support HF format
is_available_online=False,
),
@@ -878,8 +839,8 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EagleMiniCPMForCausalLM": _HfExamplesInfo(
"openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
speculative_method="eagle",
tokenizer="openbmb/MiniCPM-2B-sft-bf16",
),
"ErnieMTPModel": _HfExamplesInfo(

@@ -890,8 +851,6 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.56",
is_available_online=False,
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
@@ -105,20 +105,19 @@ def can_initialize(
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

extra_args = {}
if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"):
extra_args["enable_mm_embeds"] = True

LLM(
model_info.default,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
dtype=model_info.dtype,
speculative_config={
"model": model_info.speculative_model,
"method": model_info.speculative_method,
"num_speculative_tokens": 1,
}
if model_info.speculative_model

@@ -133,7 +132,6 @@ def can_initialize(
else "vllm",
hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs,
**extra_args,
)
@@ -309,7 +309,9 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
@@ -36,9 +36,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from packaging.version import Version
from transformers import BatchFeature
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,

@@ -1270,14 +1268,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]

# backward compatibility for Transformers 4.55
unuse_metadata = ["do_sample_frames"]
if (
not hasattr(VideoMetadata, "frames_indices")
and "frames_indices" in metadata
):
unuse_metadata.append("frames_indices")

video_mm_data["video_metadata"] = [
[
VideoMetadata(
@@ -1296,24 +1287,11 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
if not video_mm_kwargs["do_sample_frames"] and Version(
TRANSFORMERS_VERSION
) < Version("4.56.0"):
# Transformers v4.55 has incorrect timestamps issue for
# skip sampling. We construct the placeholder manually to
# get placeholders with correct timestamps.
placeholder = self.info._construct_video_placeholder(
video_array,
metadata,
video_outputs["video_grid_thw"].squeeze(0),
)
video_placeholder = processor.tokenizer.decode(placeholder)
else:
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id
)
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id
)
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
"<|begin_of_video|><|video|><|end_of_video|>",
video_placeholder,