mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 03:34:57 +08:00
[CI/Build] Improve Tensor Schema tests speed by avoid engine core initialization (#23357)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
b55713683c
commit
ff0e59d83a
@ -566,8 +566,7 @@ steps:
|
|||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
- pytest -v -s models/multimodal/processing
|
||||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Standard)
|
- label: Multi-Modal Models Test (Standard)
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
|
|||||||
@ -1,30 +1,31 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import tempfile
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch.nn as nn
|
||||||
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||||
UserMessage)
|
UserMessage)
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
|
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel)
|
||||||
from vllm.inputs import InputProcessingContext
|
from vllm.inputs import InputProcessingContext
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
MultiModalKwargs)
|
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
|
from vllm.utils import is_list_of
|
||||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
|
||||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
|
||||||
|
|
||||||
from ....conftest import VllmRunner
|
|
||||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||||
from ...utils import dummy_hf_overrides
|
from ...utils import dummy_hf_overrides
|
||||||
|
|
||||||
@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
|
|||||||
return group_mm_kwargs_by_modality(items)
|
return group_mm_kwargs_by_modality(items)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
|
||||||
|
temp_file = tempfile.mkstemp()[1]
|
||||||
|
init_distributed_environment(
|
||||||
|
world_size=1,
|
||||||
|
rank=0,
|
||||||
|
distributed_init_method=f"file://{temp_file}",
|
||||||
|
local_rank=0,
|
||||||
|
backend="nccl",
|
||||||
|
)
|
||||||
|
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||||
|
vllm_config = VllmConfig(model_config=model_config)
|
||||||
|
with set_current_vllm_config(vllm_config=vllm_config):
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
model = model_cls(vllm_config=vllm_config)
|
||||||
|
yield model
|
||||||
|
|
||||||
|
del model
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
def get_model_id_to_test(
|
def get_model_id_to_test(
|
||||||
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
|
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
|
||||||
filtered_results = []
|
filtered_results = []
|
||||||
@ -155,8 +177,7 @@ def get_model_id_to_test(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_arch, model_id",
|
"model_arch, model_id",
|
||||||
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
||||||
def test_model_tensor_schema(model_arch: str, model_id: str,
|
def test_model_tensor_schema(model_arch: str, model_id: str):
|
||||||
vllm_runner: type[VllmRunner], monkeypatch):
|
|
||||||
if model_arch in ARCH_TO_SKIP:
|
if model_arch in ARCH_TO_SKIP:
|
||||||
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
|
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
|
||||||
if model_id in REPO_ID_TO_SKIP:
|
if model_id in REPO_ID_TO_SKIP:
|
||||||
@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
|||||||
tokenizer_mode=model_info.tokenizer_mode,
|
tokenizer_mode=model_info.tokenizer_mode,
|
||||||
revision=model_info.revision,
|
revision=model_info.revision,
|
||||||
trust_remote_code=model_info.trust_remote_code,
|
trust_remote_code=model_info.trust_remote_code,
|
||||||
hf_overrides=model_info.hf_overrides,
|
hf_overrides=hf_overrides_fn,
|
||||||
)
|
)
|
||||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||||
|
|
||||||
if not any(
|
inputs_parse_methods = []
|
||||||
hasattr(model_cls, f"_parse_and_validate_{m}_input")
|
for attr_name in dir(model_cls):
|
||||||
for m in ["image", "video", "audio"]):
|
attr = getattr(model_cls, attr_name)
|
||||||
|
if hasattr(attr, "__annotations__"):
|
||||||
|
return_type = attr.__annotations__.get("return", None)
|
||||||
|
if return_type is not None and "Input" in str(return_type):
|
||||||
|
inputs_parse_methods.append(attr_name)
|
||||||
|
|
||||||
|
if not any(inputs_parse_methods):
|
||||||
pytest.skip(f"{model_arch} does not support tensor schema validation.")
|
pytest.skip(f"{model_arch} does not support tensor schema validation.")
|
||||||
|
|
||||||
ctx = InputProcessingContext(
|
ctx = InputProcessingContext(
|
||||||
@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
|||||||
modality: 3 if limit is None else limit
|
modality: 3 if limit is None else limit
|
||||||
for modality, limit in supported_mm_limits.items()
|
for modality, limit in supported_mm_limits.items()
|
||||||
}
|
}
|
||||||
|
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
|
||||||
|
processor = factories.build_processor(ctx, cache=None)
|
||||||
|
|
||||||
# Avoid calling model.forward()
|
with initialize_dummy_model(model_cls, model_config) as model:
|
||||||
def _initialize_kv_caches_v0(self) -> None:
|
|
||||||
self.cache_config.num_gpu_blocks = 0
|
|
||||||
self.cache_config.num_cpu_blocks = 0
|
|
||||||
|
|
||||||
def _initialize_kv_caches_v1(self, vllm_config):
|
|
||||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
|
||||||
scheduler_kv_cache_config = get_kv_cache_config(
|
|
||||||
vllm_config,
|
|
||||||
kv_cache_specs[0],
|
|
||||||
10 * GiB_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
|
||||||
return 1, 0, scheduler_kv_cache_config
|
|
||||||
|
|
||||||
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
|
|
||||||
_initialize_kv_caches_v0),
|
|
||||||
patch.object(V1EngineCore, "_initialize_kv_caches",
|
|
||||||
_initialize_kv_caches_v1), monkeypatch.context() as m):
|
|
||||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
|
||||||
if model_info.v0_only:
|
|
||||||
m.setenv("VLLM_USE_V1", "0")
|
|
||||||
|
|
||||||
# TODO(Isotr0py): Can we avoid initializing engine?
|
|
||||||
with (
|
|
||||||
set_default_torch_num_threads(1),
|
|
||||||
vllm_runner(
|
|
||||||
model_id,
|
|
||||||
tokenizer_name=model_info.tokenizer,
|
|
||||||
tokenizer_mode=model_info.tokenizer_mode,
|
|
||||||
revision=model_info.revision,
|
|
||||||
trust_remote_code=model_info.trust_remote_code,
|
|
||||||
max_model_len=model_info.max_model_len,
|
|
||||||
load_format="dummy",
|
|
||||||
hf_overrides=hf_overrides_fn,
|
|
||||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
|
||||||
enforce_eager=True,
|
|
||||||
) as vllm_model,
|
|
||||||
):
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
|
||||||
llm_engine = vllm_model.llm.llm_engine
|
|
||||||
|
|
||||||
if hasattr(llm_engine, "processor"):
|
|
||||||
# v1 processor
|
|
||||||
mm_registry = llm_engine.processor.mm_registry
|
|
||||||
else:
|
|
||||||
# v0 input_preprocessor
|
|
||||||
mm_registry = llm_engine.input_preprocessor.mm_registry
|
|
||||||
|
|
||||||
processor = mm_registry.create_processor(model_config)
|
|
||||||
|
|
||||||
def validate_model_input(model, modality: str,
|
|
||||||
mm_kwargs: MultiModalKwargs):
|
|
||||||
method_name = f"_parse_and_validate_{modality}_input"
|
|
||||||
if hasattr(model, method_name):
|
|
||||||
getattr(model, method_name)(**mm_kwargs)
|
|
||||||
|
|
||||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||||
model_config, processor):
|
model_config, processor):
|
||||||
valid_func = partial(validate_model_input,
|
for method_name in inputs_parse_methods:
|
||||||
modality=modality,
|
print(f"Testing `{method_name}` with modality={modality} "
|
||||||
mm_kwargs=mm_kwargs)
|
f"and mm_kwargs{list(mm_kwargs.keys())}")
|
||||||
vllm_model.apply_model(valid_func)
|
getattr(model, method_name)(modality=modality, **mm_kwargs)
|
||||||
|
|||||||
@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration(
|
|||||||
|
|
||||||
raise ValueError("Only audio modality is supported")
|
raise ValueError("Only audio modality is supported")
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|||||||
@ -1371,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
output_tensor[i, :t.size(0)] = t
|
output_tensor[i, :t.size(0)] = t
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
def _parse_and_validate_image_input(
|
||||||
|
self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
|
||||||
# tensor with the same shape will be batched together by
|
# tensor with the same shape will be batched together by
|
||||||
# MultiModalKwargs.batch, so pixel_values here can be:
|
# MultiModalKwargs.batch, so pixel_values here can be:
|
||||||
# - list[torch.Tensor]:
|
# - list[torch.Tensor]:
|
||||||
|
|||||||
@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
|
|||||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inducator_tokens: torch.Tensor
|
indicator_tokens: torch.Tensor
|
||||||
"""
|
"""
|
||||||
Shape:
|
Shape:
|
||||||
`(batch_size * (num_patches + 1))`
|
`(batch_size * (num_patches + 1))`
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
""" PyTorch Ovis model."""
|
""" PyTorch Ovis model."""
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Union
|
from typing import Literal, Optional, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OvisVideoPatchInputs(TypedDict):
|
||||||
|
type: Literal["video_patches"]
|
||||||
|
flat_data: torch.Tensor
|
||||||
|
"""
|
||||||
|
Shape:
|
||||||
|
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
indicator_tokens: torch.Tensor
|
||||||
|
"""
|
||||||
|
Shape:
|
||||||
|
`(batch_size * (num_patches + 1))`
|
||||||
|
"""
|
||||||
|
|
||||||
|
patches_per_image: list[int]
|
||||||
|
"""
|
||||||
|
List of number of total patches for each frame in the video.
|
||||||
|
This is used to restore the first two dimensions of `flat_data`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _ovis2_5_field_config():
|
def _ovis2_5_field_config():
|
||||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||||
grids=MultiModalFieldConfig.batched("image"),
|
grids=MultiModalFieldConfig.batched("image"),
|
||||||
@ -429,14 +450,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.get_language_model().make_empty_intermediate_tensors)
|
self.get_language_model().make_empty_intermediate_tensors)
|
||||||
|
|
||||||
def _parse_and_validate_visual_input(
|
def _parse_and_validate_image_input(
|
||||||
self, is_video,
|
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||||
**kwargs: object) -> Optional[OvisImagePatchInputs]:
|
|
||||||
if is_video:
|
|
||||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
|
||||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
|
||||||
grids = kwargs.pop("video_grids", None)
|
|
||||||
else:
|
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||||
grids = kwargs.pop("grids", None)
|
grids = kwargs.pop("grids", None)
|
||||||
@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
|
def _parse_and_validate_video_input(
|
||||||
|
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||||
|
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||||
|
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||||
|
grids = kwargs.pop("video_grids", None)
|
||||||
|
if pixel_values is None and indicator_tokens is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values is not None and indicator_tokens is not None:
|
||||||
|
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
if not isinstance(indicator_tokens, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of indicator_tokens. "
|
||||||
|
f"Got type: {type(indicator_tokens)}")
|
||||||
|
|
||||||
|
return OvisVideoPatchInputs(
|
||||||
|
type="video_patches",
|
||||||
|
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||||
|
patches_per_image=[
|
||||||
|
x.shape[0] // (self.config.vit_config.hidden_stride**2)
|
||||||
|
for x in flatten_bn(pixel_values)
|
||||||
|
],
|
||||||
|
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
|
||||||
|
concat=True),
|
||||||
|
grids=flatten_bn(flatten_bn(grids), concat=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
|
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
|
||||||
|
) -> MultiModalEmbeddings:
|
||||||
image_patches_flat = image_input["flat_data"]
|
image_patches_flat = image_input["flat_data"]
|
||||||
patches_per_image = image_input["patches_per_image"]
|
patches_per_image = image_input["patches_per_image"]
|
||||||
indicator_tokens = image_input["indicator_tokens"]
|
indicator_tokens = image_input["indicator_tokens"]
|
||||||
@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
torch.cat(vision_embeddings_per_image, dim=0))
|
torch.cat(vision_embeddings_per_image, dim=0))
|
||||||
return tuple(vision_embeddings)
|
return tuple(vision_embeddings)
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
modalities = {}
|
||||||
embeddings = []
|
|
||||||
|
|
||||||
# NOTE: _parse_and_validate_visual_input has side-effects and pops
|
# Preserve the order of modalities if there are multiple of them
|
||||||
# keys from kwargs. We process images first, then videos.
|
# from the order of kwargs.
|
||||||
image_input = self._parse_and_validate_visual_input(False, **kwargs)
|
for input_key in kwargs:
|
||||||
if image_input:
|
if input_key in ("pixel_values", "indicator_tokens",
|
||||||
embeddings.extend(self._process_image_input(image_input))
|
"grids") and "images" not in modalities:
|
||||||
|
modalities["images"] = self._parse_and_validate_image_input(
|
||||||
|
**kwargs)
|
||||||
|
if input_key in ("video_pixel_values", "video_indicator_tokens",
|
||||||
|
"video_grids") and "videos" not in modalities:
|
||||||
|
modalities["videos"] = self._parse_and_validate_video_input(
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
video_input = self._parse_and_validate_visual_input(True, **kwargs)
|
return modalities
|
||||||
if video_input:
|
|
||||||
embeddings.extend(self._process_image_input(video_input))
|
|
||||||
|
|
||||||
return tuple(embeddings) if embeddings else None
|
def get_multimodal_embeddings(self,
|
||||||
|
**kwargs: object) -> MultiModalEmbeddings:
|
||||||
|
|
||||||
|
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||||
|
if not modalities:
|
||||||
|
return []
|
||||||
|
|
||||||
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
# to preserve the order of the modalities.
|
||||||
|
for modality in modalities:
|
||||||
|
if modality == "images":
|
||||||
|
image_input = modalities["images"]
|
||||||
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
|
multimodal_embeddings += vision_embeddings
|
||||||
|
if modality == "videos":
|
||||||
|
video_input = modalities["videos"]
|
||||||
|
video_embeddings = self._process_image_input(video_input)
|
||||||
|
multimodal_embeddings += video_embeddings
|
||||||
|
|
||||||
|
return multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1031,8 +1031,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
|||||||
]
|
]
|
||||||
return audio_embeds
|
return audio_embeds
|
||||||
|
|
||||||
def _parse_and_validate_image_input(self,
|
def _parse_and_validate_image_input(
|
||||||
**kwargs: object) -> Optional[dict]:
|
self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
|
||||||
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
||||||
if input_image_embeds is None:
|
if input_image_embeds is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user