Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 02:35:01 +08:00)
[Bugfix] Fix deepseek-vl2 inference with more than 2 images (#13818)
parent fa82074167
commit 6ff518626c
@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, ProcessingCache,
+                                        PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -138,18 +139,24 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

-    def get_num_image_tokens(self, *, image_width: int,
-                             image_height: int) -> int:
+    def get_num_image_tokens(self,
+                             *,
+                             image_width: int,
+                             image_height: int,
+                             cropping: bool = True) -> int:
         hf_processor = self.get_hf_processor()
         image_size = hf_processor.image_size
         patch_size = hf_processor.patch_size
         downsample_ratio = hf_processor.downsample_ratio

-        best_width, best_height = hf_processor.select_best_resolution(
-            (image_width, image_height))
+        if cropping:
+            best_width, best_height = hf_processor.select_best_resolution(
+                (image_width, image_height))
+            num_width_tiles, num_height_tiles = (best_width // image_size,
+                                                 best_height // image_size)
+        else:
+            num_width_tiles = num_height_tiles = 1

-        num_width_tiles, num_height_tiles = (best_width // image_size,
-                                             best_height // image_size)
         h = w = math.ceil((image_size // patch_size) / downsample_ratio)

         global_views_tokens = h * (w + 1)
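A worked illustration of the hunk above, as a self-contained sketch. The defaults image_size=384, patch_size=16, downsample_ratio=2 and the local_views_tokens formula are assumptions (only the global-view line is visible in this hunk); the real values come from the HF processor.

    import math

    def num_image_tokens_sketch(num_width_tiles: int,
                                num_height_tiles: int,
                                cropping: bool = True) -> int:
        # Assumed DeepSeek-VL2 defaults; real values come from hf_processor.
        image_size, patch_size, downsample_ratio = 384, 16, 2

        if not cropping:
            # New behavior for >2 images: skip tiling, keep only the global view.
            num_width_tiles = num_height_tiles = 1

        # Feature-map side length: ceil((384 // 16) / 2) = 12
        h = w = math.ceil((image_size // patch_size) / downsample_ratio)

        global_views_tokens = h * (w + 1)  # from the diff: h rows of w + 1
        # Assumed local-view shape, mirroring the global-view pattern:
        local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
        return global_views_tokens + local_views_tokens + 1

    print(num_image_tokens_sketch(2, 1, cropping=True))   # 156 + 300 + 1 = 457
    print(num_image_tokens_sketch(1, 1, cropping=False))  # 156 + 156 + 1 = 313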
@@ -169,10 +176,12 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
+        num_images = mm_counts.get("image", 0)
         max_image_size = self.get_image_size_with_most_features()
         max_image_tokens = self.get_num_image_tokens(
             image_height=max_image_size.height,
-            image_width=max_image_size.width)
+            image_width=max_image_size.width,
+            cropping=num_images <= 2)

         return {"image": max_image_tokens}

@@ -207,6 +216,30 @@ class DeepseekVL2DummyInputsBuilder(
 class DeepseekVL2MultiModalProcessor(
         BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):

+    def __init__(
+            self,
+            info: DeepseekVL2ProcessingInfo,
+            dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
+            *,
+            cache: Optional[ProcessingCache] = None,
+            enable_sanity_checks: bool = True) -> None:
+        super().__init__(
+            info,
+            dummy_inputs,
+            cache=cache,
+            enable_sanity_checks=enable_sanity_checks,
+        )
+
+        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
+        if self.cache is not None and mm_limit["image"] > 2:
+            # The processor output depends on the number of images passed,
+            # making it incompatible with processing cache which is supposed
+            # to be invariant of how many images are passed per prompt
+            self.cache = None
+            logger.warning_once(
+                f"{type(self).__name__} does not support processing cache with "
+                "image limit larger than 2.")
+
     def _call_hf_processor(
         self,
         prompt: str,
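For context, a hedged usage sketch of the effect: requesting an image limit above 2 now disables the processing cache at construction (with a one-time warning) instead of later producing an inconsistent token count. The model name and limit below are illustrative, not part of the diff.

    from vllm import LLM

    # With limit_mm_per_prompt above 2 images, the DeepseekVL2 processor
    # drops its cache at construction time, since its output varies with
    # the number of images in the prompt.
    llm = LLM(model="deepseek-ai/deepseek-vl2-tiny",
              limit_mm_per_prompt={"image": 4})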
@@ -271,6 +304,7 @@ class DeepseekVL2MultiModalProcessor(
         num_image_tokens = self.info.get_num_image_tokens(
             image_width=image_size.width,
             image_height=image_size.height,
+            cropping=len(images) <= 2,
         )
         return [image_token_id] * num_image_tokens

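The replacement above expands each image placeholder into a run of image tokens whose length now tracks the cropping decision. A toy expansion (the token id is hypothetical):

    image_token_id = 100015  # hypothetical placeholder id
    num_image_tokens = 313   # e.g. cropping disabled, per the sketch earlier
    replacement = [image_token_id] * num_image_tokens
    assert len(replacement) == num_image_tokens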
@@ -477,13 +477,15 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
             enable_sanity_checks=enable_sanity_checks,
         )

-        if self.cache is not None:
+        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
+        if self.cache is not None and mm_limit["image"] >= 2:
             # The processor output depends on the number of images passed,
             # making it incompatible with processing cache which is supposed
             # to be invariant of how many images are passed per prompt
             self.cache = None
             logger.warning_once(
-                f"{type(self).__name__} does not support processing cache.")
+                f"{type(self).__name__} does not support processing cache with "
+                "multi-image support enabled.")

     def _get_prompt_replacements(
         self,
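Both processors now follow the same pattern: read limit_per_prompt at construction and drop the cache only when multi-image prompts are actually allowed. A standalone sketch of that shared logic (the helper name and threshold parameter are illustrative, not part of the diff):

    from typing import Optional

    def maybe_disable_cache(cache: Optional[object],
                            image_limit: int,
                            max_cacheable_images: int) -> Optional[object]:
        # DeepseekVL2 caches up to 2 images per prompt; H2OVL only 1.
        if cache is not None and image_limit > max_cacheable_images:
            return None  # output varies with image count, so cache is unsafe
        return cache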