Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 20:15:01 +08:00)
[Bugfix] Fix issues for Pixtral-Large-Instruct-2411 (#11393)
Signed-off-by: ywang96 <ywang@example.com>
Co-authored-by: ywang96 <ywang@example.com>
commit c2d1b075ba
parent 584f0ae40d
@@ -45,8 +45,12 @@ try:
 except ImportError:
     USE_XFORMERS_OPS = False
 
-PIXTRAL_IMAGE_BREAK_ID = 12
-PIXTRAL_IMAGE_END_ID = 13
+# These token ids cannot be retrieved from model config
+# so we hardcode them here.
+PIXTRAL_12B_IMAGE_BREAK_ID = 12
+PIXTRAL_12B_IMAGE_END_ID = 13
+PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
+PIXTRAL_LARGE_IMAGE_END_ID = 15
 
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
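The `[IMG_BREAK]`/`[IMG_END]` ids differ between the Pixtral-12B and Pixtral-Large tokenizers and, per the comment above, cannot be read from the model config, so both sets are hardcoded. A minimal sketch of how the two variants line up; the lookup dict and helper below are illustrative only, not part of the patch:

```python
# Illustrative only (not in the patch): the hardcoded special-token ids per
# tokenizer variant, mirroring the module-level constants above.
PIXTRAL_SPECIAL_IMAGE_IDS = {
    "pixtral-12b": {"image_break": 12, "image_end": 13},
    "pixtral-large": {"image_break": 14, "image_end": 15},
}


def is_image_end(token_id: int) -> bool:
    """True if the id marks the end of an image for either variant,
    mirroring the OR condition the patch uses when splitting embeddings."""
    return token_id in {v["image_end"] for v in PIXTRAL_SPECIAL_IMAGE_IDS.values()}
```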
@@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
     for image_data in data_list:
         image = ImageChunk(image=image_data)
         encoding = tokenizer.instruct.mm_encoder(image)
-        image = torch.from_numpy(encoding.image).to(device="cuda",
-                                                    dtype=torch.float16)
+        image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
         images.append(image)
         image_tokens_list.append(encoding.tokens)
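The likely motivation here is that the input mapper should not assume a CUDA device: only the dtype is fixed at mapping time, and device placement is left to the runtime. A minimal sketch under that assumption, with a hypothetical stand-in for `encoding.image`:

```python
import numpy as np
import torch

# Hypothetical stand-in for `encoding.image` from the Mistral mm encoder:
# a (C, H, W) float32 numpy array.
encoding_image = np.random.rand(3, 64, 64).astype(np.float32)

# Convert dtype only; the tensor stays on CPU so the mapper also works on
# hosts without CUDA, and the executor can move it to the model's device later.
image = torch.from_numpy(encoding_image).to(dtype=torch.float16)
assert image.device.type == "cpu" and image.dtype == torch.float16
```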
@@ -237,8 +240,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the indices of `[IMG_END]` token.
-        split_indices = torch.where(
-            image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
+        image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
+            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
+        split_indices = torch.where(image_end_condition)[0] + 1
         if len(split_indices) <= 1:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
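A minimal, self-contained sketch of the split logic above, using hypothetical token values (13 is `[IMG_END]` for Pixtral-12B, 15 for Pixtral-Large); the `tensor_split` call here is only to illustrate where the split indices land, not the exact surrounding vLLM code:

```python
import torch

# Two images in one prompt, each terminated by an [IMG_END] token.
image_tokens = torch.tensor([10, 10, 13, 10, 10, 10, 15])
image_embeds = torch.randn(7, 4)  # one embedding row per image token

# Accept either variant's end-of-image id, as in the patch.
image_end_condition = (image_tokens == 13) | (image_tokens == 15)
split_indices = torch.where(image_end_condition)[0] + 1  # tensor([3, 7])

# Split after each [IMG_END]; drop the empty trailing chunk if one appears.
chunks = [c for c in torch.tensor_split(image_embeds, split_indices.tolist()) if len(c)]
print([c.shape for c in chunks])  # [torch.Size([3, 4]), torch.Size([4, 4])]
```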
@@ -260,8 +264,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings, [
-                    self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
-                    PIXTRAL_IMAGE_BREAK_ID
+                    self.vision_args.image_token_id,
+                    PIXTRAL_12B_IMAGE_END_ID,
+                    PIXTRAL_12B_IMAGE_BREAK_ID,
+                    PIXTRAL_LARGE_IMAGE_BREAK_ID,
+                    PIXTRAL_LARGE_IMAGE_END_ID,
                 ])
         return inputs_embeds
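The list passed to `merge_multimodal_embeddings` enumerates every placeholder token id whose position should receive image features; extending it with both variants' break/end ids is what makes Pixtral-Large prompts merge correctly. A minimal sketch of that placeholder-scatter idea with hypothetical ids (10 = `[IMG]`, 12/13 = 12B break/end, 14/15 = Large break/end); not the actual vLLM helper:

```python
import torch

# All image-related placeholder ids, covering both tokenizer variants.
placeholder_ids = torch.tensor([10, 12, 13, 14, 15])

input_ids = torch.tensor([1, 10, 10, 15, 2])  # text + image placeholders
inputs_embeds = torch.zeros(5, 4)             # text embeddings (stub)
image_embeds = torch.ones(3, 4)               # one row per placeholder position

# Overwrite every placeholder position with the corresponding image feature row.
mask = torch.isin(input_ids, placeholder_ids)
inputs_embeds[mask] = image_embeds
print(inputs_embeds)
```

If an id such as `PIXTRAL_LARGE_IMAGE_END_ID` were missing from the list, its position would keep the stale text embedding instead of an image feature, which is the failure mode this part of the fix addresses.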