mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 12:25:32 +08:00
[Bugfix][VLM] Fix Fuyu batching inference with max_num_seqs>1 (#8892)
This commit is contained in:
parent
0e088750af
commit
6d792d2f31
@ -65,8 +65,8 @@ def run_test(
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=2560,
|
||||
max_num_seqs=1,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
@ -80,8 +80,6 @@ def run_test(
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
hf_model.model.language_model.get_output_embeddings()
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
|
||||
@ -42,7 +42,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .utils import merge_multimodal_embeddings
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
model_config.model)
|
||||
|
||||
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
|
||||
image_patches = torch.stack([
|
||||
image_patches = torch.cat([
|
||||
image_patch[0]
|
||||
for image_patch in model_image_input["image_patches"]
|
||||
])
|
||||
@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
|
||||
])
|
||||
|
||||
# image has been processed with prompt in input processor
|
||||
return MultiModalInputs({"image_patches": data})
|
||||
return MultiModalInputs({"pixel_values": data})
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
|
||||
@ -242,23 +242,42 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
h = w = self.config.patch_size
|
||||
num_channels = self.config.num_channels
|
||||
expected_dims = num_channels * h * w
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = d.size(-1)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = str(expected_dims)
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f" per patch is {expected_expr}. "
|
||||
f"You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data.to(self.vision_embed_tokens.weight.dtype)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if isinstance(image_patches, torch.Tensor):
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_patches = image_patches.squeeze(1)
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return FuyuImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
)
|
||||
|
||||
expected_feature_size = self.image_feature_size
|
||||
if image_patches.size(-1) != expected_feature_size:
|
||||
raise ValueError(
|
||||
f"Expected image patches to have the last dimension of "
|
||||
f"{expected_feature_size}, got {image_patches.size(-1)}")
|
||||
image_patches = image_patches.to(
|
||||
self.vision_embed_tokens.weight.dtype)
|
||||
return FuyuImagePixelInputs(type="pixel_values",
|
||||
data=image_patches)
|
||||
return None
|
||||
|
||||
def _process_image_input(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user