[Bugfix] Fix issues for Pixtral-Large-Instruct-2411 (#11393)

Signed-off-by: ywang96 <ywang@example.com>
Co-authored-by: ywang96 <ywang@example.com>
This commit is contained in:
Roger Wang 2024-12-21 02:15:03 -08:00 committed by GitHub
parent 584f0ae40d
commit c2d1b075ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,8 +45,12 @@ try:
except ImportError:
USE_XFORMERS_OPS = False
PIXTRAL_IMAGE_BREAK_ID = 12
PIXTRAL_IMAGE_END_ID = 13
# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15
def get_max_pixtral_image_tokens(ctx: InputContext):
@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)
@ -237,8 +240,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices = torch.where(
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
@ -260,8 +264,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
PIXTRAL_IMAGE_BREAK_ID
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
])
return inputs_embeds