mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:06:03 +08:00
[Core] Move multimodal placeholder from chat utils to model definition (#20355)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
cb97f2bfc5
commit
b024a42e93
@ -10,6 +10,22 @@ This document walks you through the steps to extend a basic model so that it acc
|
||||
It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic].
|
||||
Further update the model as follows:
|
||||
|
||||
- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.
|
||||
|
||||
??? Code
|
||||
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
```
|
||||
|
||||
- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
|
||||
|
||||
```diff
|
||||
|
||||
@ -33,7 +33,6 @@ class RequestOutput:
|
||||
class MockModelConfig:
|
||||
use_async_output_proc = True
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
@ -263,26 +263,6 @@ def test_media_io_kwargs_parser(arg, expected):
|
||||
assert args.media_io_kwargs == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
(None, dict()),
|
||||
('{"video":"<|video_placeholder|>"}', {
|
||||
"video": "<|video_placeholder|>"
|
||||
}),
|
||||
('{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}', {
|
||||
"video": "<|video_placeholder|>",
|
||||
"image": "<|image_placeholder|>"
|
||||
}),
|
||||
])
|
||||
def test_mm_placeholder_str_override_parser(arg, expected):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args(["--mm-placeholder-str-override", arg])
|
||||
|
||||
assert args.mm_placeholder_str_override == expected
|
||||
|
||||
|
||||
def test_compilation_config():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
|
||||
@ -41,7 +41,6 @@ class MockModelConfig:
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@ -350,8 +350,6 @@ class ModelConfig:
|
||||
"""Additional args passed to process media inputs, keyed by modalities.
|
||||
For example, to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
"""Optionally override placeholder string for given modalities."""
|
||||
use_async_output_proc: bool = True
|
||||
"""Whether to use async output processor."""
|
||||
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
|
||||
@ -661,7 +659,7 @@ class ModelConfig:
|
||||
return self._architecture
|
||||
|
||||
@property
|
||||
def model_info(self) -> dict[str, Any]:
|
||||
def model_info(self):
|
||||
return self._model_info
|
||||
|
||||
def maybe_pull_model_tokenizer_for_s3(self, model: str,
|
||||
@ -701,7 +699,6 @@ class ModelConfig:
|
||||
return MultiModalConfig(
|
||||
limit_per_prompt=self.limit_mm_per_prompt,
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
mm_placeholder_str_override=self.mm_placeholder_str_override,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
disable_mm_preprocessor_cache=self.
|
||||
disable_mm_preprocessor_cache)
|
||||
@ -3096,9 +3093,6 @@ class MultiModalConfig:
|
||||
For example, to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||
|
||||
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
|
||||
"""Optionally override placeholder string for given modalities."""
|
||||
|
||||
mm_processor_kwargs: Optional[dict[str, object]] = None
|
||||
"""
|
||||
Overrides for the multi-modal processor obtained from
|
||||
|
||||
@ -373,8 +373,6 @@ class EngineArgs:
|
||||
media_io_kwargs: dict[str, dict[str,
|
||||
Any]] = get_field(MultiModalConfig,
|
||||
"media_io_kwargs")
|
||||
mm_placeholder_str_override: dict[str, str] = \
|
||||
get_field(MultiModalConfig, "mm_placeholder_str_override")
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = \
|
||||
MultiModalConfig.mm_processor_kwargs
|
||||
disable_mm_preprocessor_cache: bool = \
|
||||
@ -759,9 +757,6 @@ class EngineArgs:
|
||||
**multimodal_kwargs["limit_per_prompt"])
|
||||
multimodal_group.add_argument("--media-io-kwargs",
|
||||
**multimodal_kwargs["media_io_kwargs"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-placeholder-str-override",
|
||||
**multimodal_kwargs["mm_placeholder_str_override"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-processor-kwargs",
|
||||
**multimodal_kwargs["mm_processor_kwargs"])
|
||||
@ -987,7 +982,6 @@ class EngineArgs:
|
||||
served_model_name=self.served_model_name,
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
mm_placeholder_str_override=self.mm_placeholder_str_override,
|
||||
use_async_output_proc=not self.disable_async_output_proc,
|
||||
config_format=self.config_format,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
|
||||
@ -6,7 +6,7 @@ import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from functools import cache, lru_cache, partial
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
@ -37,6 +37,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.utils import MediaConnector
|
||||
# yapf: disable
|
||||
@ -492,6 +494,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def model_config(self) -> ModelConfig:
|
||||
return self._model_config
|
||||
|
||||
@cached_property
|
||||
def model_cls(self):
|
||||
return get_model_cls(self.model_config)
|
||||
|
||||
@property
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
@ -500,89 +506,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def mm_registry(self):
|
||||
return MULTIMODAL_REGISTRY
|
||||
|
||||
@staticmethod
|
||||
@cache
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||
return tokenizer.decode(token_index)
|
||||
|
||||
def _placeholder_str(self, modality: ModalityStr,
|
||||
current_count: int) -> Optional[str]:
|
||||
if modality in self._model_config.mm_placeholder_str_override:
|
||||
return self._model_config.mm_placeholder_str_override[modality]
|
||||
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
hf_config = self._model_config.hf_config
|
||||
model_type = hf_config.model_type
|
||||
|
||||
if modality in ("image", "image_embeds"):
|
||||
if model_type == "chatglm":
|
||||
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
if model_type == "glm4v":
|
||||
return "<|begin_of_image|><|image|><|end_of_image|>"
|
||||
if model_type in ("phi3_v", "phi4mm"):
|
||||
return f"<|image_{current_count}|>"
|
||||
if model_type in ("minicpmo", "minicpmv"):
|
||||
return "(<image>./</image>)"
|
||||
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
|
||||
"pixtral", "mistral3"):
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
if model_type == "qwen":
|
||||
return f"Picture {current_count}: <img></img>"
|
||||
if model_type.startswith("llava"):
|
||||
return self._cached_token_str(self._tokenizer,
|
||||
hf_config.image_token_index)
|
||||
|
||||
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
|
||||
"internvl_chat", "ovis", "skywork_chat",
|
||||
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
|
||||
return "<image>"
|
||||
if model_type in ("mllama", "llama4"):
|
||||
return "<|image|>"
|
||||
if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if model_type == "qwen2_5_omni":
|
||||
return "<|vision_start|><|IMAGE|><|vision_end|>"
|
||||
if model_type == "molmo":
|
||||
return ""
|
||||
if model_type == "aria":
|
||||
return "<|fim_prefix|><|img|><|fim_suffix|>"
|
||||
if model_type == "gemma3":
|
||||
return "<start_of_image>"
|
||||
if model_type == "kimi_vl":
|
||||
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
|
||||
|
||||
raise TypeError(f"Unknown {modality} model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type in ("ultravox", "granite_speech"):
|
||||
return "<|audio|>"
|
||||
if model_type == "phi4mm":
|
||||
return f"<|audio_{current_count}|>"
|
||||
if model_type in ("qwen2_audio", "qwen2_5_omni"):
|
||||
return (f"Audio {current_count}: "
|
||||
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
|
||||
if model_type == "minicpmo":
|
||||
return "(<audio>./</audio>)"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
elif modality == "video":
|
||||
if model_type == "internvl_chat":
|
||||
return "<video>"
|
||||
if model_type == "glm4v":
|
||||
return "<|begin_of_video|><|video|><|end_of_video|>"
|
||||
if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
|
||||
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
if model_type == "qwen2_5_omni":
|
||||
return "<|vision_start|><|VIDEO|><|vision_end|>"
|
||||
if model_type in ("minicpmo", "minicpmv"):
|
||||
return "(<video>./</video>)"
|
||||
if model_type.startswith("llava"):
|
||||
return self._cached_token_str(self._tokenizer,
|
||||
hf_config.video_token_index)
|
||||
raise TypeError(f"Unknown {modality} model type: {model_type}")
|
||||
else:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
@ -590,6 +513,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
mm_registry = self.mm_registry
|
||||
model_config = self.model_config
|
||||
model_cls = cast(SupportsMultiModal, self.model_cls)
|
||||
|
||||
input_modality = modality.replace("_embeds", "")
|
||||
|
||||
@ -614,7 +538,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
self._items_by_modality[modality].append(item)
|
||||
|
||||
return self._placeholder_str(modality, current_count)
|
||||
return model_cls.get_placeholder_str(modality, current_count)
|
||||
|
||||
@abstractmethod
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
|
||||
@ -5,6 +5,7 @@ import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import cached_property
|
||||
from math import ceil
|
||||
from typing import Callable, Literal, Optional, TypeVar, Union, cast
|
||||
|
||||
@ -24,7 +25,8 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import get_model_architecture
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import PlaceholderModule
|
||||
@ -76,24 +78,29 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
self.model_sr = processor.feature_extractor.sampling_rate
|
||||
self.hop_length = processor.feature_extractor.hop_length
|
||||
self.task_type = task_type
|
||||
self.model_cls, _ = get_model_architecture(model_config)
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
@cached_property
|
||||
def model_cls(self):
|
||||
return get_model_cls(self.model_config)
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
model_cls = cast(SupportsTranscription, self.model_cls)
|
||||
|
||||
# Validate request
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||
lang = request.language or "en"
|
||||
self.model_cls.validate_language(lang) # type: ignore[attr-defined]
|
||||
model_cls.validate_language(lang)
|
||||
|
||||
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
@ -117,9 +124,8 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
self.model_cls.
|
||||
get_decoder_prompt( # type: ignore[attr-defined]
|
||||
lang, self.task_type, request.prompt)
|
||||
model_cls.get_decoder_prompt(lang, self.task_type,
|
||||
request.prompt)
|
||||
}
|
||||
prompts.append(cast(PromptType, prompt))
|
||||
return prompts, duration
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.sharded_state_loader import (
|
||||
ShardedStateLoader)
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
get_architecture_class_name, get_model_architecture, get_model_cls)
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
@ -65,6 +65,7 @@ __all__ = [
|
||||
"get_model_loader",
|
||||
"get_architecture_class_name",
|
||||
"get_model_architecture",
|
||||
"get_model_cls",
|
||||
"BaseModelLoader",
|
||||
"BitsAndBytesModelLoader",
|
||||
"GGUFModelLoader",
|
||||
|
||||
@ -13,7 +13,7 @@ import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, BinaryIO, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@ -24,12 +24,14 @@ from transformers import PretrainedConfig
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
TensorDeserializer, TensorSerializer)
|
||||
@ -503,7 +505,7 @@ def serialize_vllm_model(
|
||||
return model
|
||||
|
||||
|
||||
def tensorize_vllm_model(engine_args: EngineArgs,
|
||||
def tensorize_vllm_model(engine_args: "EngineArgs",
|
||||
tensorizer_config: TensorizerConfig,
|
||||
generate_keyfile: bool = True):
|
||||
"""Utility to load a model and then serialize it with Tensorizer
|
||||
|
||||
@ -253,6 +253,10 @@ def get_model_architecture(
|
||||
return model_cls, arch
|
||||
|
||||
|
||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
||||
return get_model_architecture(model_config)[0]
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
@ -2,9 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||
SupportsPP, SupportsV0Only, has_inner_state,
|
||||
supports_lora, supports_multimodal, supports_pp,
|
||||
supports_v0_only)
|
||||
SupportsPP, SupportsTranscription, SupportsV0Only,
|
||||
has_inner_state, supports_lora, supports_multimodal,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
|
||||
is_pooling_model, is_text_generation_model)
|
||||
from .registry import ModelRegistry
|
||||
@ -23,6 +23,8 @@ __all__ = [
|
||||
"supports_multimodal",
|
||||
"SupportsPP",
|
||||
"supports_pp",
|
||||
"SupportsTranscription",
|
||||
"supports_transcription",
|
||||
"SupportsV0Only",
|
||||
"supports_v0_only",
|
||||
]
|
||||
|
||||
@ -499,6 +499,13 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|fim_prefix|><|img|><|fim_suffix|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@ -304,6 +304,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: AyaVisionConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -507,6 +507,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
super().__init__()
|
||||
|
||||
@ -933,6 +933,13 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -315,6 +315,13 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"language.": "language_model.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -877,6 +877,13 @@ class Florence2MultiModalProcessor(
|
||||
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -254,6 +254,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -483,6 +483,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<start_of_image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -1257,6 +1257,15 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.visual.": "visual.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|begin_of_image|><|image|><|end_of_image|>"
|
||||
if modality.startswith("video"):
|
||||
return "<|begin_of_video|><|video|><|end_of_video|>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: Glm4vConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -540,6 +540,13 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
connector="transformer.vision.linear_proj",
|
||||
tower_model="transformer.vision.transformer")
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -533,6 +533,13 @@ class GraniteSpeechForConditionalGeneration(
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("audio"):
|
||||
return "<|audio|>"
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -591,6 +591,13 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -46,6 +46,13 @@ class SupportsMultiModal(Protocol):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
"""
|
||||
Get the placeholder text for the `i`th `modality` item in the prompt.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
"""
|
||||
|
||||
@ -1023,6 +1023,15 @@ class InternVLMultiModalProcessor(
|
||||
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -1343,6 +1343,15 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if modality.startswith("video"):
|
||||
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: PretrainedConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -264,6 +264,13 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
|
||||
dummy_inputs=KimiVLDummyInputsBuilder)
|
||||
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@ -511,6 +511,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -215,6 +215,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -281,6 +281,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -446,6 +446,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -511,6 +511,17 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "(<image>./</image>)"
|
||||
if modality.startswith("video"):
|
||||
return "(<video>./</video>)"
|
||||
if modality.startswith("audio"):
|
||||
return "(<audio>./</audio>)"
|
||||
|
||||
raise ValueError("Only image, video or audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.apm = self.init_audio_module(vllm_config=vllm_config,
|
||||
|
||||
@ -735,6 +735,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
instantiated.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "(<image>./</image>)"
|
||||
if modality.startswith("video"):
|
||||
return "(<video>./</video>)"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
@ -158,6 +158,13 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -401,6 +401,13 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -1276,6 +1276,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|image|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: MllamaConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -719,6 +719,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|image|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -1366,6 +1366,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
"merged_linear": ["gate_proj", "up_proj"] # image_projector
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -405,6 +405,13 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
|
||||
dummy_inputs=OvisDummyInputsBuilder)
|
||||
class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -240,6 +240,13 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -520,6 +520,13 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return f"<|image_{i}|>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -902,6 +902,15 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return f"<|image_{i}|>"
|
||||
if modality.startswith("audio"):
|
||||
return f"<|audio_{i}|>"
|
||||
|
||||
raise ValueError("Only image or audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -327,6 +327,13 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -118,6 +118,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
""" Prithvi Masked Autoencoder"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
|
||||
|
||||
# We might be able/need to support different tasks with this same model
|
||||
|
||||
@ -717,6 +717,17 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
"thinker.": "",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|vision_start|><|IMAGE|><|vision_end|>"
|
||||
if modality.startswith("video"):
|
||||
return "<|vision_start|><|VIDEO|><|vision_end|>"
|
||||
if modality.startswith("audio"):
|
||||
return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
|
||||
raise ValueError("Only image, video or audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
thinker_config: Qwen2_5OmniThinkerConfig = (
|
||||
|
||||
@ -835,6 +835,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if modality.startswith("video"):
|
||||
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -251,6 +251,13 @@ class Qwen2AudioMultiModalProcessor(
|
||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("audio"):
|
||||
return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -1096,6 +1096,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if modality.startswith("video"):
|
||||
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: Qwen2VLConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -675,6 +675,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
connector="transformer.visual.attn_pool",
|
||||
tower_model="transformer.visual.transformer")
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return f"Picture {i}: <img></img>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -648,6 +648,13 @@ class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
|
||||
dummy_inputs=SkyworkR1VDummyInputsBuilder)
|
||||
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -393,6 +393,13 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config: TarsierHfConfig = vllm_config.model_config.hf_config
|
||||
|
||||
@ -407,6 +407,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("audio"):
|
||||
return "<|audio|>"
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -761,6 +761,35 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
".fc2.": ".mlp.fc2."
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str) -> bool:
|
||||
if language in ISO639_1_SUPPORTED_LANGS:
|
||||
return True
|
||||
elif language in ISO639_1_OTHER_LANGS:
|
||||
logger.warning(
|
||||
"The selected language %s has limited accuracy with"
|
||||
" reported WER>=0.5. Results may be less accurate "
|
||||
"for this choice.", language)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}."
|
||||
"Language should be one of:" +
|
||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
@classmethod
|
||||
def get_decoder_prompt(cls, language: str, task_type: str,
|
||||
prompt: str) -> str:
|
||||
return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
|
||||
f"<|notimestamps|>{prompt}")
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("audio"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -840,28 +869,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
weights = _create_fake_bias_for_k_proj(weights)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str) -> bool:
|
||||
if language in ISO639_1_SUPPORTED_LANGS:
|
||||
return True
|
||||
elif language in ISO639_1_OTHER_LANGS:
|
||||
logger.warning(
|
||||
"The selected language %s has limited accuracy with"
|
||||
" reported WER>=0.5. Results may be less accurate "
|
||||
"for this choice.", language)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}."
|
||||
"Language should be one of:" +
|
||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
@classmethod
|
||||
def get_decoder_prompt(cls, language: str, task_type: str,
|
||||
prompt: str) -> str:
|
||||
return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
|
||||
f"<|notimestamps|>{prompt}")
|
||||
|
||||
|
||||
def _create_fake_bias_for_k_proj(
|
||||
weights: Iterable[tuple[str, torch.Tensor]]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user