mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:26:00 +08:00
[Model][Pixtral] Optimizations for input_processor_for_pixtral_hf (#9514)
This commit is contained in:
parent
263d8ee150
commit
8e3e7f2713
@ -701,63 +701,64 @@ def input_processor_for_pixtral_hf(
|
||||
new_prompt = inputs.get("prompt")
|
||||
new_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
|
||||
# Update new_prompt if present
|
||||
if new_prompt:
|
||||
replace_strings = []
|
||||
for image in image_data:
|
||||
w, h = image.size
|
||||
parts = new_prompt.split(image_token)
|
||||
assert len(parts) - 1 == len(image_data)
|
||||
new_parts = [parts[0]] # Start with the part before any image tokens
|
||||
|
||||
for image, next_part in zip(image_data, parts[1:]):
|
||||
w, h = image.size
|
||||
(num_width_tokens,
|
||||
num_height_tokens) = get_pixtral_hf_image_feature_size(
|
||||
hf_config, image_width=w, image_height=h)
|
||||
|
||||
replace_tokens = [[processor.image_token] * num_width_tokens +
|
||||
[processor.image_break_token]
|
||||
] * num_height_tokens
|
||||
# Flatten list
|
||||
replace_tokens = [
|
||||
item for sublist in replace_tokens for item in sublist
|
||||
replace_tokens = [image_token] * num_width_tokens + [
|
||||
image_break_token
|
||||
]
|
||||
replace_tokens[-1] = processor.image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
replace_strings.append(replace_str)
|
||||
new_prompt = new_prompt.replace(processor.image_token,
|
||||
"<placeholder>", 1)
|
||||
replace_tokens = replace_tokens * num_height_tokens
|
||||
replace_tokens[-1] = image_end_token
|
||||
|
||||
while "<placeholder>" in new_prompt:
|
||||
replace_str = replace_strings.pop(0)
|
||||
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
|
||||
new_parts.append("".join(replace_tokens))
|
||||
new_parts.append(next_part)
|
||||
|
||||
new_prompt = "".join(new_parts)
|
||||
|
||||
# Update new_token_ids
|
||||
image_token_id = 10
|
||||
image_break_id = 12
|
||||
image_end_id = 13
|
||||
convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
|
||||
image_token_id = convert_tokens_to_ids(image_token)
|
||||
image_break_id = convert_tokens_to_ids(image_break_token)
|
||||
image_end_id = convert_tokens_to_ids(image_end_token)
|
||||
placeholder_token_id = -999
|
||||
# Find all image token indices at once
|
||||
placeholder_indices = [
|
||||
idx for idx, token_id in enumerate(new_token_ids)
|
||||
if token_id == image_token_id
|
||||
]
|
||||
assert len(placeholder_indices) == len(image_data)
|
||||
replace_tokens_list = []
|
||||
for image in image_data:
|
||||
for placeholder_idx, image in zip(placeholder_indices, image_data):
|
||||
new_token_ids[placeholder_idx] = placeholder_token_id
|
||||
|
||||
w, h = image.size
|
||||
(num_width_tokens,
|
||||
num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
|
||||
image_width=w,
|
||||
image_height=h)
|
||||
|
||||
num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
|
||||
hf_config, image_width=w, image_height=h)
|
||||
|
||||
replace_tokens = [[image_token_id] * num_width_tokens +
|
||||
[image_break_id]] * num_height_tokens
|
||||
# Flatten list
|
||||
replace_tokens = [
|
||||
item for sublist in replace_tokens for item in sublist
|
||||
]
|
||||
replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
|
||||
replace_tokens = replace_tokens * num_height_tokens
|
||||
replace_tokens[-1] = image_end_id
|
||||
replace_tokens_list.append(replace_tokens)
|
||||
# Replace image id with placeholder id
|
||||
next_image_index = new_token_ids.index(image_token_id)
|
||||
new_token_ids[next_image_index] = placeholder_token_id
|
||||
|
||||
while placeholder_token_id in new_token_ids:
|
||||
replace_tokens = replace_tokens_list.pop(0)
|
||||
next_image_index = new_token_ids.index(placeholder_token_id)
|
||||
prefix = new_token_ids[:next_image_index]
|
||||
postfix = new_token_ids[next_image_index + 1:]
|
||||
new_token_ids = prefix + replace_tokens + postfix
|
||||
# Backward iteration for replacement without affecting known indices
|
||||
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
|
||||
reversed(replace_tokens_list)):
|
||||
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user