[Bugfix][VLM] Fix Fuyu batching inference with max_num_seqs>1 (#8892)

This commit is contained in:
Isotr0py 2024-09-27 16:15:58 +08:00 committed by GitHub
parent 0e088750af
commit 6d792d2f31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 20 deletions

View File

@ -65,8 +65,8 @@ def run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=2560,
max_num_seqs=1,
max_model_len=2048,
max_num_seqs=2,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
@ -80,8 +80,6 @@ def run_test(
]
with hf_runner(model, dtype=dtype) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,

View File

@ -42,7 +42,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
from .utils import flatten_bn, merge_multimodal_embeddings
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.stack([
image_patches = torch.cat([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
])
# image has been processed with prompt in input processor
return MultiModalInputs({"image_patches": data})
return MultiModalInputs({"pixel_values": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@ -242,23 +242,42 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
cache_config=cache_config,
quant_config=quant_config)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.patch_size
num_channels = self.config.num_channels
expected_dims = num_channels * h * w
def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)
pixel_values = kwargs.pop("pixel_values", None)
if isinstance(image_patches, torch.Tensor):
# Remove the N dimension until multiple images are supported.
image_patches = image_patches.squeeze(1)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(pixel_values)}")
return FuyuImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(
f"Expected image patches to have the last dimension of "
f"{expected_feature_size}, got {image_patches.size(-1)}")
image_patches = image_patches.to(
self.vision_embed_tokens.weight.dtype)
return FuyuImagePixelInputs(type="pixel_values",
data=image_patches)
return None
def _process_image_input(