[Bugfix][Model] fix mllama multi-image (#14883)
Signed-off-by: yan ma <yan.ma@intel.com>
parent a164aea35d
commit ff6473980d
@@ -212,7 +212,7 @@ def _run_test(
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
-                     max_num_seqs=2,
+                     max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
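For context on the test tweak above: raising max_num_seqs from 2 to 3 lets the runner schedule prompts with different image counts into one batch, which is the situation the mllama fix below addresses. A minimal repro sketch against vLLM's public API, assuming the Llama 3.2 Vision checkpoint and the <|image|> prompt format from vLLM's multimodal examples (model name, prompt text, and image are illustrative):

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed checkpoint
    max_model_len=4096,
    max_num_seqs=3,                    # allow co-batching several sequences
    limit_mm_per_prompt={"image": 2},  # up to 2 images per prompt
)

img = Image.new("RGB", (560, 560))     # placeholder image
prompts = [
    # One prompt with a single image, one with two: their pixel_values
    # have different leading dims, so they cannot be stacked naively.
    {"prompt": "<|image|> Describe the image.",
     "multi_modal_data": {"image": img}},
    {"prompt": "<|image|><|image|> Compare the two images.",
     "multi_modal_data": {"image": [img, img]}},
]
outputs = llm.generate(prompts, SamplingParams(max_tokens=32))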
@@ -1235,11 +1235,34 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def unpack_data(self,
+                    image_data: Union[List[torch.Tensor], torch.Tensor],
+                    padding_value=0) -> torch.Tensor:
+        if isinstance(image_data, torch.Tensor):
+            # torch.Tensor
+            return image_data
+        else:
+            assert isinstance(
+                image_data[0],
+                torch.Tensor), "Image data is not properly batched."
+            # List[torch.Tensor]
+            bsz = len(image_data)
+            max_length = max(t.size(0) for t in image_data)
+            trailing_dims = image_data[0].shape[1:]
+            for data in image_data:
+                cur_trailing_dims = data.shape[1:]
+                assert cur_trailing_dims == trailing_dims
+            output_tensor = torch.full((bsz, max_length, *trailing_dims),
+                                       padding_value,
+                                       dtype=image_data[0].dtype,
+                                       device=image_data[0].device)
+            for i, t in enumerate(image_data):
+                output_tensor[i, :t.size(0)] = t
+            return output_tensor
+
     def _parse_and_validate_image_input(self, **kwargs: object):
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
         # - List[List[torch.Tensor]]:
         #       with shape (num_tiles, 3, image_res, image_res)
         # - List[torch.Tensor]:
         #       with shape (num_image, num_tiles, 3, image_res, image_res)
         # - torch.Tensor:
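The unpack_data helper added above right-pads a ragged List[torch.Tensor], whose first dimension (number of images or tiles) varies per prompt, into one rectangular batch tensor. A self-contained sketch of that behavior with illustrative shapes; the cross-check against torch.nn.utils.rnn.pad_sequence is an addition for verification, not part of the commit:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two per-prompt tensors whose leading dim differs (e.g. 2 vs. 5 tiles);
# trailing dims must match, mirroring the assert in unpack_data.
a = torch.ones(2, 3, 4)
b = torch.ones(5, 3, 4)
batch = [a, b]

bsz = len(batch)
max_length = max(t.size(0) for t in batch)
out = torch.full((bsz, max_length, *batch[0].shape[1:]), 0,
                 dtype=batch[0].dtype, device=batch[0].device)
for i, t in enumerate(batch):
    out[i, :t.size(0)] = t          # copy real data, leave zero padding

assert out.shape == (2, 5, 3, 4)
# Same result as PyTorch's built-in batch-first padding:
assert torch.equal(out, pad_sequence(batch, batch_first=True))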
@@ -1274,10 +1297,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return MllamaImagePixelInputs(
             type="pixel_values",
-            data=pixel_values,
-            aspect_ratio_ids=aspect_ratio_ids,
-            aspect_ratio_mask=aspect_ratio_mask,
-        )
+            data=self.unpack_data(pixel_values),
+            aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
+            aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
 
         if image_embeds is not None:
             raise NotImplementedError
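The call-site change above routes all three image inputs through unpack_data, since each can arrive as a ragged list when prompts in a batch carry different numbers of images. A shape sketch under assumed values (4 tiles per image, 560x560 resolution, and integer dtypes for ids/mask are illustrative, not taken from the commit):

import torch
from torch.nn.utils.rnn import pad_sequence

# Batch of two prompts: one image vs. two images (leading dim differs).
pixel_values = [torch.rand(1, 4, 3, 560, 560),
                torch.rand(2, 4, 3, 560, 560)]
aspect_ratio_ids = [torch.zeros(1, dtype=torch.long),
                    torch.zeros(2, dtype=torch.long)]
aspect_ratio_mask = [torch.ones(1, 4, dtype=torch.long),
                     torch.ones(2, 4, dtype=torch.long)]

# After padding (unpack_data is equivalent to batch-first zero padding),
# every field has a consistent (batch, max_num_images, ...) layout:
assert pad_sequence(pixel_values, batch_first=True).shape \
    == (2, 2, 4, 3, 560, 560)
assert pad_sequence(aspect_ratio_ids, batch_first=True).shape == (2, 2)
assert pad_sequence(aspect_ratio_mask, batch_first=True).shape == (2, 2, 4)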