[Bugfix] Fix Idefics3 bug (#10778)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-11-29 21:56:46 +08:00 committed by GitHub
parent c82b432d4a
commit 3132aac043
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -267,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext,
n_images_in_text = []
text = inputs.get("prompt")
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, "
"or a list of strings")
if text is None:
prompt_token_ids = inputs.get("prompt_token_ids", [])
assert prompt_token_ids
text = tokenizer.decode(prompt_token_ids)
fake_image_token = processor.fake_image_token.content
image_token = processor.image_token.content
global_img_token = processor.global_image_tag
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, "
"or a list of strings")
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows,
image_cols):
n_images_in_text.append(sample.count(image_token))
fake_image_token = processor.fake_image_token.content
image_token = processor.image_token.content
global_img_token = processor.global_image_tag
# Replace the image token with fake tokens around the expanded
# image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = _get_image_prompt_string(
n_rows,
n_cols,
processor.image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
n_images_in_text.append(sample.count(image_token))
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError(
"The image token should be present in the text.")
# Replace the image token with fake tokens around the expanded
# image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = _get_image_prompt_string(
n_rows,
n_cols,
processor.image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=prompt_strings[0],
multi_modal_data=multi_modal_data,
)
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=prompt_strings[0],
multi_modal_data=multi_modal_data,
)
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: