mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 14:25:46 +08:00
[Model][Pixtral] Optimizations for input_processor_for_pixtral_hf (#9514)
This commit is contained in:
parent
263d8ee150
commit
8e3e7f2713
@ -701,63 +701,64 @@ def input_processor_for_pixtral_hf(
|
|||||||
new_prompt = inputs.get("prompt")
|
new_prompt = inputs.get("prompt")
|
||||||
new_token_ids = inputs["prompt_token_ids"]
|
new_token_ids = inputs["prompt_token_ids"]
|
||||||
|
|
||||||
|
image_token = processor.image_token
|
||||||
|
image_break_token = processor.image_break_token
|
||||||
|
image_end_token = processor.image_end_token
|
||||||
|
|
||||||
# Update new_prompt if present
|
# Update new_prompt if present
|
||||||
if new_prompt:
|
if new_prompt:
|
||||||
replace_strings = []
|
parts = new_prompt.split(image_token)
|
||||||
for image in image_data:
|
assert len(parts) - 1 == len(image_data)
|
||||||
w, h = image.size
|
new_parts = [parts[0]] # Start with the part before any image tokens
|
||||||
|
|
||||||
|
for image, next_part in zip(image_data, parts[1:]):
|
||||||
|
w, h = image.size
|
||||||
(num_width_tokens,
|
(num_width_tokens,
|
||||||
num_height_tokens) = get_pixtral_hf_image_feature_size(
|
num_height_tokens) = get_pixtral_hf_image_feature_size(
|
||||||
hf_config, image_width=w, image_height=h)
|
hf_config, image_width=w, image_height=h)
|
||||||
|
|
||||||
replace_tokens = [[processor.image_token] * num_width_tokens +
|
replace_tokens = [image_token] * num_width_tokens + [
|
||||||
[processor.image_break_token]
|
image_break_token
|
||||||
] * num_height_tokens
|
|
||||||
# Flatten list
|
|
||||||
replace_tokens = [
|
|
||||||
item for sublist in replace_tokens for item in sublist
|
|
||||||
]
|
]
|
||||||
replace_tokens[-1] = processor.image_end_token
|
replace_tokens = replace_tokens * num_height_tokens
|
||||||
replace_str = "".join(replace_tokens)
|
replace_tokens[-1] = image_end_token
|
||||||
replace_strings.append(replace_str)
|
|
||||||
new_prompt = new_prompt.replace(processor.image_token,
|
|
||||||
"<placeholder>", 1)
|
|
||||||
|
|
||||||
while "<placeholder>" in new_prompt:
|
new_parts.append("".join(replace_tokens))
|
||||||
replace_str = replace_strings.pop(0)
|
new_parts.append(next_part)
|
||||||
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
|
|
||||||
|
new_prompt = "".join(new_parts)
|
||||||
|
|
||||||
# Update new_token_ids
|
# Update new_token_ids
|
||||||
image_token_id = 10
|
convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
|
||||||
image_break_id = 12
|
image_token_id = convert_tokens_to_ids(image_token)
|
||||||
image_end_id = 13
|
image_break_id = convert_tokens_to_ids(image_break_token)
|
||||||
|
image_end_id = convert_tokens_to_ids(image_end_token)
|
||||||
placeholder_token_id = -999
|
placeholder_token_id = -999
|
||||||
|
# Find all image token indices at once
|
||||||
|
placeholder_indices = [
|
||||||
|
idx for idx, token_id in enumerate(new_token_ids)
|
||||||
|
if token_id == image_token_id
|
||||||
|
]
|
||||||
|
assert len(placeholder_indices) == len(image_data)
|
||||||
replace_tokens_list = []
|
replace_tokens_list = []
|
||||||
for image in image_data:
|
for placeholder_idx, image in zip(placeholder_indices, image_data):
|
||||||
|
new_token_ids[placeholder_idx] = placeholder_token_id
|
||||||
|
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
|
(num_width_tokens,
|
||||||
|
num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
|
||||||
|
image_width=w,
|
||||||
|
image_height=h)
|
||||||
|
|
||||||
num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
|
replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
|
||||||
hf_config, image_width=w, image_height=h)
|
replace_tokens = replace_tokens * num_height_tokens
|
||||||
|
|
||||||
replace_tokens = [[image_token_id] * num_width_tokens +
|
|
||||||
[image_break_id]] * num_height_tokens
|
|
||||||
# Flatten list
|
|
||||||
replace_tokens = [
|
|
||||||
item for sublist in replace_tokens for item in sublist
|
|
||||||
]
|
|
||||||
replace_tokens[-1] = image_end_id
|
replace_tokens[-1] = image_end_id
|
||||||
replace_tokens_list.append(replace_tokens)
|
replace_tokens_list.append(replace_tokens)
|
||||||
# Replace image id with placeholder id
|
|
||||||
next_image_index = new_token_ids.index(image_token_id)
|
|
||||||
new_token_ids[next_image_index] = placeholder_token_id
|
|
||||||
|
|
||||||
while placeholder_token_id in new_token_ids:
|
# Backward iteration for replacement without affecting known indices
|
||||||
replace_tokens = replace_tokens_list.pop(0)
|
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
|
||||||
next_image_index = new_token_ids.index(placeholder_token_id)
|
reversed(replace_tokens_list)):
|
||||||
prefix = new_token_ids[:next_image_index]
|
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
|
||||||
postfix = new_token_ids[next_image_index + 1:]
|
|
||||||
new_token_ids = prefix + replace_tokens + postfix
|
|
||||||
|
|
||||||
# NOTE: Create a defensive copy of the original inputs
|
# NOTE: Create a defensive copy of the original inputs
|
||||||
return token_inputs(prompt_token_ids=new_token_ids,
|
return token_inputs(prompt_token_ids=new_token_ids,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user