Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 19:15:57 +08:00)
[Multi Modal] Configurable MM Profiling (#25631)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent fa29d31f0d
commit 2bcc745042
@@ -258,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
```

@@ -438,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
```
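(For context: the documentation hunks above only thread `mm_options` through to `_get_dummy_images`. Below is a minimal sketch of how such a helper might apply an `ImageDummyOptions` override; the helper name and the clamping policy are illustrative assumptions, not part of this commit.)

```python
from typing import Optional

from PIL import Image

from vllm.config.multimodal import ImageDummyOptions


def make_dummy_images(
    width: int,
    height: int,
    num_images: int,
    overrides: Optional[ImageDummyOptions] = None,
) -> list[Image.Image]:
    # Hypothetical helper: shrink the profiling image to any user override,
    # never exceeding the model's maximum supported resolution.
    if overrides is not None:
        if overrides.width:
            width = min(width, overrides.width)
        if overrides.height:
            height = min(height, overrides.height)
    return [Image.new("RGB", (width, height), color=(255, 255, 255))
            for _ in range(num_images)]
```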
@@ -12,6 +12,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

from vllm.config import ModelConfig
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
@@ -112,12 +114,26 @@ def _test_processing_correctness(

processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
limit_mm_per_prompt = {
# Keep integer limits for local data generation
limit_mm_per_prompt_ints = {
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}

model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
if modality == "video":
return VideoDummyOptions(count=count)
if modality == "image":
return ImageDummyOptions(count=count)
if modality == "audio":
return AudioDummyOptions(count=count)
return BaseDummyOptions(count=count)

# Assign normalized DummyOptions to the model config
model_config.get_multimodal_config().limit_per_prompt = {
modality: _to_dummy_options(modality, count)
for modality, count in limit_mm_per_prompt_ints.items()
}

baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
@@ -150,7 +166,7 @@ def _test_processing_correctness(
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(limit + 1))]
for k, limit in limit_mm_per_prompt.items()
for k, limit in limit_mm_per_prompt_ints.items()
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
@@ -17,23 +17,23 @@ def test_profiling(model_id: str, max_model_len: int):
model_config_kwargs = {
"max_model_len": max_model_len,
}
mm_counts = {"image": 1}
ctx = build_model_context(
model_id,
model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1},
limit_mm_per_prompt=mm_counts,
)

mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor)

decoder_dummy_data = profiler.get_decoder_dummy_data(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
mm_counts=mm_counts,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
mm_counts=mm_counts,
)

hf_config = ctx.get_hf_config(Llama4Config)
@@ -58,7 +58,7 @@ def test_profiling(model_id: str, max_model_len: int):

profiled_tokens = profiler.get_mm_max_contiguous_tokens(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
mm_counts=mm_counts,
)

assert total_tokens == profiled_tokens["image"]
@@ -15,6 +15,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
@@ -236,7 +238,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt

def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
if modality == "video":
return VideoDummyOptions(count=count)
if modality == "image":
return ImageDummyOptions(count=count)
if modality == "audio":
return AudioDummyOptions(count=count)
return BaseDummyOptions(count=count)

model_config.get_multimodal_config().limit_per_prompt = {
modality: _to_dummy_options(modality, count)
for modality, count in limit_mm_per_prompt.items()
}
processor = factories.build_processor(ctx, cache=None)

with initialize_dummy_model(model_cls, model_config) as model:
@@ -276,7 +276,9 @@ class ModelConfig:
multimodal_config: Optional[MultiModalConfig] = None
"""Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`."""
limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None
limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int,
dict[str,
int]]]]] = None
media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None
mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None
mm_processor_cache_gb: InitVar[Optional[float]] = None
@@ -4,15 +4,45 @@
import hashlib
from collections.abc import Mapping
from dataclasses import field
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union

from pydantic import ConfigDict, Field, field_validator
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.config.utils import config


@dataclass
class BaseDummyOptions:
"""Base options for generating dummy data during profiling."""
count: int = Field(999, ge=0)


@dataclass(config=ConfigDict(extra="forbid"))
class VideoDummyOptions(BaseDummyOptions):
"""Options for generating dummy video data during profiling."""
num_frames: Optional[int] = Field(None, gt=0)
width: Optional[int] = Field(None, gt=0)
height: Optional[int] = Field(None, gt=0)


@dataclass(config=ConfigDict(extra="forbid"))
class ImageDummyOptions(BaseDummyOptions):
"""Options for generating dummy image data during profiling."""
width: Optional[int] = Field(None, gt=0)
height: Optional[int] = Field(None, gt=0)


@dataclass(config=ConfigDict(extra="forbid"))
class AudioDummyOptions(BaseDummyOptions):
"""Options for generating dummy audio data during profiling."""
length: Optional[int] = Field(None, gt=0)


MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions,
AudioDummyOptions]


@config
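(A quick sketch of how these option dataclasses behave, assuming pydantic validates them at construction as the `Field` constraints above suggest; the exact exception raised for unknown keys is an assumption.)

```python
from pydantic import ValidationError

from vllm.config.multimodal import VideoDummyOptions

# Constrained fields: count >= 0; num_frames/width/height must be > 0 if given.
opts = VideoDummyOptions(count=2, num_frames=16, width=256, height=256)
assert opts.count == 2 and opts.num_frames == 16

# extra="forbid" rejects unknown keys, catching typos in user configs.
try:
    VideoDummyOptions(count=1, frames=16)  # "frames" is not a valid field
except (ValidationError, TypeError):
    pass
```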
@@ -20,12 +50,22 @@ MMCacheType = Literal["shm", "lru"]
class MultiModalConfig:
"""Controls the behavior of multimodal models."""

limit_per_prompt: dict[str, int] = field(default_factory=dict)
"""The maximum number of input items allowed per prompt for each modality.
Defaults to 1 (V0) or 999 (V1) for each modality.
limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict)
"""The maximum number of input items and options allowed per
prompt for each modality.
Defaults to 999 for each modality.

For example, to allow up to 16 images and 2 videos per prompt:
`{"image": 16, "video": 2}`"""
Legacy format (count only):
{"image": 16, "video": 2}

Configurable format (with options):
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
"image": {"count": 5, "width": 512, "height": 512}}

Mixed format (combining both):
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}}
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
@@ -84,6 +124,27 @@ class MultiModalConfig:
from each video to be pruned.
"""

@field_validator("limit_per_prompt", mode="before")
@classmethod
def _validate_limit_per_prompt(
cls, value: dict[str, Union[int,
dict[str,
int]]]) -> dict[str, DummyOptions]:
for k, v in value.items():
# Handle legacy format where only count is specified
if isinstance(v, int):
v = {"count": v}
# Convert to the appropriate DummyOptions subclass
if k == "video":
value[k] = VideoDummyOptions(**v)
elif k == "image":
value[k] = ImageDummyOptions(**v)
elif k == "audio":
value[k] = AudioDummyOptions(**v)
else:
value[k] = BaseDummyOptions(**v)
return value

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -106,12 +167,22 @@ class MultiModalConfig:
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
for the given modality (backward compatible).
"""
return self.limit_per_prompt.get(
modality,
999 if envs.VLLM_USE_V1 else 1,
)
limit_data = self.limit_per_prompt.get(modality)

if limit_data is None:
# Unspecified modality is set to 999 by default
return 999
return limit_data.count

def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]:
"""
Get the configurable dummy data options for a modality.
Returns None if no options are configured for this modality.
"""
# All values are now DummyOptions after normalization
return self.limit_per_prompt.get(modality)

def merge_mm_processor_kwargs(
self,
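(Illustrative only: a sketch of the normalization performed by `_validate_limit_per_prompt` and the two getters above, assuming `MultiModalConfig` runs its validators when constructed directly.)

```python
from vllm.config.multimodal import MultiModalConfig, VideoDummyOptions

# Mixed legacy/configurable input, as described in the docstring above.
cfg = MultiModalConfig(limit_per_prompt={
    "image": 16,  # legacy: count only
    "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
})

# After validation, every value is a DummyOptions instance.
assert cfg.get_limit_per_prompt("image") == 16
assert isinstance(cfg.get_dummy_options("video"), VideoDummyOptions)
assert cfg.get_limit_per_prompt("audio") == 999  # unspecified modality default
```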
@@ -376,7 +376,7 @@ class EngineArgs:
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \
limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \
get_field(MultiModalConfig, "limit_per_prompt")
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str,
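(With the widened `EngineArgs` type above, both formats can presumably be mixed when building engine arguments in Python; the model name and option values below are placeholders, not part of this commit.)

```python
from vllm.engine.arg_utils import EngineArgs

# Hypothetical example: legacy counts and per-modality dummy options side by side.
args = EngineArgs(
    model="<your-multimodal-model>",
    limit_mm_per_prompt={
        "image": 5,
        "video": {"count": 1, "num_frames": 32},
    },
)
```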
@ -10,6 +10,7 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||
from transformers.models.aria.processing_aria import AriaProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -431,17 +432,21 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
vision_config = self.info.get_vision_config()
|
||||
|
||||
max_image_size = vision_config.image_size
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
|
||||
get_optimal_tiled_canvas)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
@ -166,16 +167,20 @@ class AyaVisionDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
image_size = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=image_size.width,
|
||||
height=image_size.height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
|
||||
apply_chunking_to_forward)
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -435,6 +436,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
@ -442,11 +444,14 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
|
||||
max_image_size = vision_config.image_size
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -92,17 +93,21 @@ class ChameleonDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
config = self.info.get_hf_config()
|
||||
|
||||
width = height = config.vq_config.resolution
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=width,
|
||||
height=height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from transformers.models.cohere2_vision.processing_cohere2_vision import (
|
||||
Cohere2VisionProcessor)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import MulAndSilu
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -209,16 +210,20 @@ class Cohere2VisionDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
image_size = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=image_size.width,
|
||||
height=image_size.height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from einops import rearrange, repeat
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
@ -191,16 +192,20 @@ class DeepseekVL2DummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
max_image_size = self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size.width,
|
||||
height=max_image_size.height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
@ -91,17 +92,21 @@ class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501
|
||||
)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -38,6 +38,7 @@ from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
@ -1184,6 +1185,7 @@ class Ernie4_5_VLDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -1193,16 +1195,21 @@ class Ernie4_5_VLDummyInputsBuilder(
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video":
|
||||
self._get_dummy_videos(width=target_width,
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos)
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
FuyuProcessor)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -136,16 +137,20 @@ class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@ -241,17 +242,21 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
|
||||
from transformers.models.siglip import SiglipImageProcessorFast
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -153,6 +154,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
@ -163,13 +165,19 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
img_width = image_processor.size.get("width", 224)
|
||||
img_height = image_processor.size.get("height", 224)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=img_width,
|
||||
height=img_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
self._get_dummy_audios(length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union, override
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -50,6 +50,7 @@ from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import (check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
parallel_state)
|
||||
from vllm.distributed import utils as dist_utils
|
||||
@ -1110,6 +1111,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -1118,17 +1120,23 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
|
||||
self.info.get_image_size_with_most_features())
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1139,7 +1147,31 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
height: int,
num_frames: int,
num_videos: int,
overrides: Optional[VideoDummyOptions] = None,
) -> list[VideoItem]:
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames, num_frames)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, overrides.height)

video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = []
for i in range(num_videos):
@ -19,6 +19,7 @@ from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -465,6 +466,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
@ -472,11 +474,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
|
||||
target_width = target_height = vision_config["image_size"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -181,13 +182,17 @@ class GraniteSpeechDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"audio":
|
||||
self._get_dummy_audios(
|
||||
length=self.info.get_max_audio_len(),
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
@ -149,6 +150,7 @@ class HCXVisionDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -156,12 +158,17 @@ class HCXVisionDummyInputsBuilder(
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
target_num_frames = 32
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
@ -169,6 +176,7 @@ class HCXVisionDummyInputsBuilder(
|
||||
height=target_height - 1,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
|
||||
Idefics3Processor)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -292,17 +293,21 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
||||
longest_edge = image_processor.max_image_size['longest_edge']
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=longest_edge,
|
||||
height=longest_edge,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from transformers.models.internvl.video_processing_internvl import (
|
||||
InternVLVideoProcessor)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.interns1_vit import InternS1VisionModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@ -270,6 +271,7 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
@ -281,16 +283,21 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
|
||||
config = self.info.get_hf_config()
|
||||
image_size_h, image_size_w = config.vision_config.image_size
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video":
|
||||
self._get_dummy_videos(width=image_size_w,
|
||||
height=image_size_h,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos),
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
@ -747,16 +748,20 @@ class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
@ -913,21 +918,25 @@ class InternVLDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
|
||||
mm_counts=mm_counts)
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options)
|
||||
if self.info.supports_video:
|
||||
config = self.info.get_hf_config()
|
||||
image_size: int = config.vision_config.image_size
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
dummy_video = {
|
||||
"video":
|
||||
self._get_dummy_videos(width=image_size,
|
||||
height=image_size,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos)
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides)
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
|
||||
@ -20,6 +20,7 @@ from transformers.utils import torch_int
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -1170,6 +1171,7 @@ class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -1179,12 +1181,16 @@ class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
@ -1192,6 +1198,7 @@ class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@ -54,6 +54,7 @@ from transformers import BatchFeature
|
||||
from transformers.activations import GELUActivation
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
@ -212,14 +213,18 @@ class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=MaxImageTokenMeta.width,
|
||||
height=MaxImageTokenMeta.height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -195,17 +196,21 @@ class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from transformers import (BatchFeature, LlavaNextVideoConfig,
|
||||
LlavaNextVideoProcessor)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -150,6 +151,7 @@ class LlavaNextVideoDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
@ -158,6 +160,8 @@ class LlavaNextVideoDummyInputsBuilder(
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
@ -165,6 +169,7 @@ class LlavaNextVideoDummyInputsBuilder(
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
@ -254,6 +255,7 @@ class LlavaOnevisionDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -264,17 +266,22 @@ class LlavaOnevisionDummyInputsBuilder(
|
||||
self.info.get_num_frames_with_most_features(seq_len,
|
||||
mm_counts)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from torch.nn.functional import scaled_dot_product_attention
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -539,13 +540,17 @@ class MiDashengLMDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=self.info.get_max_audio_len(),
|
||||
num_audios=num_audios)
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from transformers.models.whisper.modeling_whisper import (ACT2FN,
|
||||
WhisperEncoder)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
NestedTensors)
|
||||
@ -237,18 +238,23 @@ class MiniCPMODummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_len = self.info.get_max_audio_chunks_with_most_features() * \
|
||||
self.info.get_default_audio_sampling_rate()
|
||||
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
audio_mm_data = {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
self._get_dummy_audios(length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides)
|
||||
}
|
||||
|
||||
return {
|
||||
**super().get_dummy_mm_data(seq_len, mm_counts),
|
||||
**super().get_dummy_mm_data(seq_len, mm_counts, mm_options),
|
||||
**audio_mm_data,
|
||||
}
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ from transformers import BatchFeature, PretrainedConfig
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||
@ -679,6 +680,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@ -690,15 +692,20 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
num_video_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=image_width,
|
||||
height=image_height,
|
||||
num_images=num_images),
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video": [
|
||||
self._get_dummy_images(width=video_width,
|
||||
height=video_height,
|
||||
num_images=num_video_frames)
|
||||
num_images=num_video_frames,
|
||||
overrides=video_overrides)
|
||||
] * num_videos,
|
||||
}
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig,
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -208,17 +209,21 @@ class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -689,17 +690,21 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
(target_width,
|
||||
target_height) = self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.attention import Attention
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@ -1226,16 +1227,20 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -809,6 +810,7 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
# Use default max_num_tiles for dummy data generation
|
||||
max_num_tiles = 12
|
||||
@ -816,11 +818,14 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
self.info.get_image_size_with_most_features(max_num_tiles))
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
}
|
||||
|
||||
|
||||
@ -837,21 +842,25 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
|
||||
mm_counts=mm_counts)
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options)
|
||||
if self.info.supports_video:
|
||||
config = self.info.get_hf_config()
|
||||
image_size: int = config.force_image_size
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
dummy_video = {
|
||||
"video":
|
||||
self._get_dummy_videos(width=image_size,
|
||||
height=image_size,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos)
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides)
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
|
||||
@ -14,6 +14,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
@ -86,16 +87,20 @@ class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -28,6 +28,7 @@ from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.aimv2 import AIMv2Model
@ -283,17 +284,21 @@ class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
}
return mm_data
@ -10,6 +10,7 @@ import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.ovis import (OvisImagePatchInputs,
@ -290,6 +291,7 @@ class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
@ -298,17 +300,23 @@ class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)

image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
)
}
return mm_data
@ -8,6 +8,7 @@ from torch import nn
from transformers import BatchFeature, PaliGemmaConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@ -106,6 +107,7 @@ class PaliGemmaDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
@ -113,11 +115,14 @@ class PaliGemmaDummyInputsBuilder(

num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -25,6 +25,7 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
ProcessorMixin)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -356,17 +357,21 @@ class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -17,6 +17,7 @@ from transformers.models.phi4_multimodal.modeling_phi4_multimodal import (
Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
@ -980,6 +981,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
@ -987,14 +989,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"audio":
self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
num_audios=num_audios),
num_audios=num_audios,
overrides=audio_overrides),
}

return mm_data
@ -11,6 +11,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
SequenceFeatureExtractor, SiglipVisionConfig)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -749,6 +750,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
@ -756,14 +758,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"audio":
self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
num_audios=num_audios),
num_audios=num_audios,
overrides=audio_overrides),
}

return mm_data
@ -24,6 +24,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
@ -228,28 +229,33 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = \
self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()

dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
@ -39,6 +39,7 @@ from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import (
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -212,6 +213,7 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
@ -228,19 +230,26 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)

image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None

mm_data = {
"audio":
self._get_dummy_audios(length=target_audio_length,
num_audios=num_audios),
num_audios=num_audios,
overrides=audio_overrides),
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"video":
self._get_dummy_videos(width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos),
num_videos=num_videos,
overrides=video_overrides),
}

return mm_data
@ -34,6 +34,7 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (AudioItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
@ -144,6 +145,7 @@ class Qwen2AudioDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()

@ -151,9 +153,13 @@ class Qwen2AudioDummyInputsBuilder(
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)

audio_overrides = mm_options.get("audio") if mm_options else None

return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
}
@ -45,6 +45,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
@ -1034,6 +1035,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
@ -1043,17 +1045,22 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)

image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
)
}
@ -47,6 +47,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
@ -736,6 +737,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
@ -750,17 +752,23 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
num_frames=target_num_frames,
image_processor=self.info.get_video_processor(),
)

image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
"video":
self._get_dummy_videos(
width=target_video_size.width,
height=target_video_size.height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
),
}
@ -24,6 +24,7 @@ from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
@ -567,6 +568,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.visual
@ -574,11 +576,14 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from typing import Optional

import torch
import torch.nn as nn
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict

@ -38,17 +40,21 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = (
self.info.get_image_size_with_most_features())

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
}
@ -17,6 +17,7 @@ from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
@ -505,16 +506,20 @@ class SkyworkR1VDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -17,6 +17,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType

from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -496,16 +497,20 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
num_images=num_images,
overrides=image_overrides)
}
@ -28,6 +28,8 @@ from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
@ -48,6 +50,8 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
SupportsMultiModal)
from .interfaces_base import default_pooling_type

logger = init_logger(__name__)


def _terratorch_field_names(pretrained_cfg: dict):
input_definition = InputDefinition(**pretrained_cfg["input"])
@ -97,9 +101,16 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
# Dummy data is generated based on the 'input' section
# defined in the HF configuration file

if mm_options:
logger.warning("Configurable multimodal profiling "
"options are not supported for Terratorch. "
"They are ignored for now.")

return self.dummy_data_generator.get_dummy_mm_data()
@ -33,6 +33,7 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
@ -285,16 +286,20 @@ class MultiModalDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = self.info.get_max_image_size()

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
num_images=num_images,
overrides=image_overrides),
}
@ -14,6 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader
@ -114,6 +115,7 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()

@ -122,9 +124,13 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
_MAX_ENCODER_BATCH_SIZE)
num_audios = mm_counts.get("audio", 0)

audio_overrides = mm_options.get("audio") if mm_options else None

return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
}
@ -21,6 +21,7 @@ from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -204,25 +205,31 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)

target_length = self.info.get_max_audio_array_len()

audio_overrides = mm_options.get("audio") if mm_options else None

return {
"audio":
self._get_dummy_audios(length=target_length, num_audios=num_audios)
self._get_dummy_audios(length=target_length,
num_audios=num_audios,
overrides=audio_overrides)
}

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()

dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_audios = dummy_mm_data.get("audio", [])

audio_chunks: list[AudioChunk] = []
@ -18,6 +18,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
@ -691,6 +692,7 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()

@ -698,9 +700,13 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)

audio_overrides = mm_options.get("audio") if mm_options else None

return {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
}
@ -10,6 +10,8 @@ import numpy.typing as npt
from PIL import Image

import vllm.envs as envs
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.logger import init_logger

from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
@ -73,10 +75,19 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
"""
Build the multimodal input which, after processing, results in
the maximum possible number of placeholder tokens.

Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional).
If None, use model defaults for backward compatibility.
If provided, models can use these to customize dummy
data generation.
"""
raise NotImplementedError
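To make the `mm_options` contract concrete, here is a minimal sketch of a concrete builder honoring it, mirroring the per-model changes above; the `MyProcessingInfo` name and its `get_image_size_with_most_features()` helper are illustrative stand-ins, not part of this diff:

```python
from collections.abc import Mapping
from typing import Optional

from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.profiling import BaseDummyInputsBuilder


class MyDummyInputsBuilder(BaseDummyInputsBuilder["MyProcessingInfo"]):

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        # Assumed helper on the (hypothetical) processing info class.
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        # Passing no options keeps the legacy, count-only behavior.
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images,
                                   overrides=image_overrides)
        }
```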
@ -84,13 +95,22 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs:
"""
Build the input which, after processing, results in
the maximum possible number of placeholder tokens.

Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional)
"""
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)

# Use the unified function for both legacy and configurable cases
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

tokenization_kwargs = {"truncation": False}

return ProcessorInputs(prompt=dummy_text,
@ -102,9 +122,17 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
*,
length: int,
num_audios: int,
overrides: Optional[AudioDummyOptions] = None,
) -> list[npt.NDArray]:
if num_audios == 0:
return []
if overrides and overrides.length:
if overrides.length > length:
logger.warning(
"audio.length override (%d) exceeds model's "
"maximum length (%d), will be ignored", overrides.length,
length)
length = min(length, overrides.length)
audio = np.zeros((length, ))
return [audio] * num_audios
@ -114,9 +142,25 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
width: int,
height: int,
num_images: int,
overrides: Optional[ImageDummyOptions] = None,
) -> list[Image.Image]:
if num_images == 0:
return []
if overrides:
if overrides.width:
if overrides.width > width:
logger.warning(
"image.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"image.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, overrides.height)
image = Image.new("RGB", (width, height), color=255)
return [image] * num_images

@ -127,9 +171,32 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height: int,
num_frames: int,
num_videos: int,
overrides: Optional[VideoDummyOptions] = None,
) -> list[npt.NDArray]:
if num_videos == 0:
return []
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames, num_frames)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255)
return [video] * num_videos
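The three helpers above share one clamping rule: an override may shrink the dummy data but never enlarge it past the model's profiling maximum, and an over-large value is ignored with a warning. A small, self-contained illustration of that rule (the dataclass is a stand-in for the real options class, modeling only the fields the helper reads):

```python
from dataclasses import dataclass
from typing import Optional


# Stand-in for ImageDummyOptions in this sketch: only width/height are modeled.
@dataclass
class _ImageOverrides:
    width: Optional[int] = None
    height: Optional[int] = None


overrides = _ImageOverrides(width=512, height=2048)

# Model's profiling maximum is 1024x1024.
width, height = 1024, 1024
if overrides.width:
    width = min(width, overrides.width)    # -> 512: shrinking is honored
if overrides.height:
    height = min(height, overrides.height)  # -> 1024: the 2048 request is
                                            #    ignored (warned in the helper)
```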
@ -162,13 +229,14 @@ class MultiModalProfiler(Generic[_I]):
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalInputs:
if mm_counts is None:
mm_counts = self.get_mm_limits()

factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
seq_len, mm_counts, mm_options)

return self.processor.apply(
prompt=processor_inputs.prompt,
@ -195,8 +263,9 @@ class MultiModalProfiler(Generic[_I]):
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> DummyEncoderData:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

# For encoder-decoder models, use encoder prompt token ids instead of
@ -228,8 +297,9 @@ class MultiModalProfiler(Generic[_I]):
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> DummyDecoderData:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
@ -274,7 +344,7 @@ class MultiModalProfiler(Generic[_I]):

`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.


This is important to take into account when profiling and
initializing the encoder cache size.
"""
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar

import torch.nn as nn

from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
@ -52,7 +53,7 @@ class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
...


class MultiModalProcessorFactory(Protocol[_I]):
class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc]
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
@ -95,6 +96,28 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]()

def _extract_mm_options(
self,
model_config: "ModelConfig",
) -> Optional[Mapping[str, BaseDummyOptions]]:
"""
Extract multimodal dummy options from model config.

Returns None if no configurable options are found, otherwise returns
a mapping of modality names to their dummy options.
"""
if not model_config.multimodal_config:
return None

mm_options = {
m: opt
for m in model_config.multimodal_config.limit_per_prompt
if (opt := model_config.multimodal_config.get_dummy_options(m)
) is not None
}

return mm_options if len(mm_options) > 0 else None
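A hedged sketch of the kind of mapping this helper returns and where it goes next; constructing `VideoDummyOptions` with a `num_frames` keyword is an assumption about the options class, whose `num_frames`/`width`/`height` fields are the only ones the dummy builders above read:

```python
from collections.abc import Mapping

from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions

# Possible result of _extract_mm_options(): only "video" carried advanced
# options in the multimodal config, so count-only modalities such as "image"
# are omitted and keep the legacy profiling behavior.
mm_options: Mapping[str, BaseDummyOptions] = {
    "video": VideoDummyOptions(num_frames=16),  # hypothetical construction
}

# The registry then threads this mapping through the profiler, e.g.
#   profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
# and each dummy builder picks out its modality via mm_options.get("video").
```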
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
@ -135,7 +158,7 @@ class MultiModalRegistry:
return {}

processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
profiler: MultiModalProfiler = MultiModalProfiler(processor)

seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
@ -189,7 +212,7 @@ class MultiModalRegistry:
return {}

processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()

def register_processor(
@ -285,8 +308,15 @@ class MultiModalRegistry:
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
profiler: MultiModalProfiler = MultiModalProfiler(processor)

# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)

dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts,
mm_options)

# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids
@ -311,8 +341,15 @@ class MultiModalRegistry:
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
profiler: MultiModalProfiler = MultiModalProfiler(processor)

# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)

dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts,
mm_options)

# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids