[VLM] Avoid unnecessary tokenization (#12310)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-22 19:08:31 +08:00 committed by GitHub
parent 68ad4e3a8d
commit cd7b6f0857
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 71 additions and 40 deletions

View File

@@ -475,15 +475,23 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
image_token_id = vocab["image"]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target="</s>",
replacement=PromptReplacementDetails(
full="<image>" * num_image_tokens + "</s>",
features="<image>" * num_image_tokens,
full=image_tokens + [bos_token_id],
features=image_tokens,
),
)
]

View File

@@ -122,8 +122,9 @@ class ChameleonMultiModalProcessor(
) -> list[int]:
# HF processor adds sep token for chat mode
tokenizer = self.info.get_tokenizer()
sep_token_id: int = \
tokenizer.vocab[tokenizer.sep_token] # type: ignore
vocab = tokenizer.get_vocab()
sep_token_id = vocab[tokenizer.sep_token] # type: ignore
return prompt_tokens + [sep_token_id]
@@ -141,18 +142,22 @@ class ChameleonMultiModalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens = processor.image_token * self.info.get_num_image_tokens()
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
image_start_id = vocab[processor.image_start_token]
image_token_id = vocab[processor.image_token]
image_end_id = vocab[processor.image_end_token]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target="<image>",
target=[image_token_id],
replacement=PromptReplacementDetails(
full="".join([
processor.image_start_token,
image_tokens,
processor.image_end_token,
]),
full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens,
),
)

View File

@@ -249,8 +249,10 @@ class DeepseekVL2MultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor()
image_token_id: int = hf_processor.image_token_id
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id
assert isinstance(image_token_id, int)
def get_replacement_deepseek_vl2(item_idx: int):
images = mm_items.get_items(

View File

@@ -183,7 +183,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> list[int]:
# HF processor adds boa_token_id
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
vocab = tokenizer.get_vocab()
boa_token_id = vocab["<0x04>"]
return prompt_tokens + [boa_token_id]
@@ -202,6 +204,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id

View File

@@ -315,13 +315,14 @@ class PixtralHFMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
processor = self.info.get_hf_processor()
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
image_break_id = vocab[processor.image_break_token]
image_token_id = hf_config.image_token_index
image_end_id = vocab[processor.image_end_token]
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
@@ -336,10 +337,10 @@ class PixtralHFMultiModalProcessor(
image_height=image_size.height,
)
tokens = ([image_token] * ncols + [image_break_token]) * nrows
tokens[-1] = image_end_token
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return "".join(tokens)
return tokens
return [
PromptReplacement(

View File

@@ -188,7 +188,9 @@ class Qwen2AudioMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self.info.get_hf_processor()
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# Use getattr with default to be compatible with transformers<4.48
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
@@ -197,6 +199,10 @@ class Qwen2AudioMultiModalProcessor(
audio_eos_token = getattr(processor, "audio_eos_token",
"<|audio_eos|>")
audio_token_id = vocab[audio_token]
audio_bos_id = vocab[audio_bos_token]
audio_eos_id = vocab[audio_eos_token]
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
if feature_attention_mask is None:
audio_output_lengths = []
@@ -208,22 +214,18 @@ class Qwen2AudioMultiModalProcessor(
audio_output_lengths = audio_output_lens.tolist()
def get_replacement_qwen2_audio(item_idx: int):
num_placeholders = audio_output_lengths[item_idx]
if num_placeholders == 0:
num_features = audio_output_lengths[item_idx]
if num_features == 0:
audios = mm_items.get_items("audio", AudioProcessorItems)
audio = audios.get(item_idx)
raise ValueError(
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")
audio_tokens = audio_token * num_placeholders
audio_tokens = [audio_token_id] * num_features
return PromptReplacementDetails(
full="".join([
audio_bos_token,
audio_tokens,
audio_eos_token,
]),
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens,
)

View File

@@ -953,12 +953,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
placeholder = {
"image": hf_processor.image_token,
"video": hf_processor.video_token,
"image": vocab[hf_processor.image_token],
"video": vocab[hf_processor.video_token],
}
merge_length = image_processor.merge_size**2
@@ -967,13 +969,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
num_tokens = grid_thw.prod().item() // merge_length
return placeholder[modality] * num_tokens
num_tokens = int(grid_thw.prod()) // merge_length
return [placeholder[modality]] * num_tokens
return [
PromptReplacement(
modality=modality,
target=placeholder[modality],
target=[placeholder[modality]],
replacement=partial(get_replacement_qwen2vl,
modality=modality),
) for modality in ("image", "video")

View File

@@ -205,16 +205,20 @@ class UltravoxMultiModalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
placeholder = hf_processor.audio_token_replacement # type: ignore
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
replacement_id = vocab[
hf_processor.audio_token_replacement] # type: ignore
def get_replacement_ultravox(item_idx: int):
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
return placeholder * audio_token_len
return [replacement_id] * int(audio_token_len) # type: ignore
return [
PromptReplacement(
modality="audio",
target="<|audio|>",
target='<|audio|>',
replacement=get_replacement_ultravox,
)
]

View File

@@ -67,9 +67,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer_all_special_tokens_extended = (
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_vocab = tokenizer.get_vocab()
tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer.get_vocab().values())
max_token_id = max(tokenizer_vocab.values())
# Some tokenizers (e.g., QwenTokenizer) have special tokens that
# are added and included in the implementation of the vocab_size
# property, but not in get_vocab(); if there is an implementation
@@ -96,6 +97,9 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def max_token_id(self):
return max_token_id
def get_vocab(self):
return tokenizer_vocab
def __len__(self):
return tokenizer_len