mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:37:13 +08:00
[VLM] Avoid unnecessary tokenization (#12310)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
68ad4e3a8d
commit
cd7b6f0857
@ -475,15 +475,23 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
bos_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
image_token_id = vocab["image"]
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="</s>",
|
||||
replacement=PromptReplacementDetails(
|
||||
full="<image>" * num_image_tokens + "</s>",
|
||||
features="<image>" * num_image_tokens,
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@ -122,8 +122,9 @@ class ChameleonMultiModalProcessor(
|
||||
) -> list[int]:
|
||||
# HF processor adds sep token for chat mode
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
sep_token_id: int = \
|
||||
tokenizer.vocab[tokenizer.sep_token] # type: ignore
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
sep_token_id = vocab[tokenizer.sep_token] # type: ignore
|
||||
|
||||
return prompt_tokens + [sep_token_id]
|
||||
|
||||
@ -141,18 +142,22 @@ class ChameleonMultiModalProcessor(
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_tokens = processor.image_token * self.info.get_num_image_tokens()
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
image_start_id = vocab[processor.image_start_token]
|
||||
image_token_id = vocab[processor.image_token]
|
||||
image_end_id = vocab[processor.image_end_token]
|
||||
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
target=[image_token_id],
|
||||
replacement=PromptReplacementDetails(
|
||||
full="".join([
|
||||
processor.image_start_token,
|
||||
image_tokens,
|
||||
processor.image_end_token,
|
||||
]),
|
||||
full=([image_start_id] + image_tokens + [image_end_id]),
|
||||
features=image_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
@ -249,8 +249,10 @@ class DeepseekVL2MultiModalProcessor(
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
image_token_id: int = hf_processor.image_token_id
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
image_token_id = hf_processor.image_token_id
|
||||
assert isinstance(image_token_id, int)
|
||||
|
||||
def get_replacement_deepseek_vl2(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
|
||||
@ -183,7 +183,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
) -> list[int]:
|
||||
# HF processor adds boa_token_id
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
boa_token_id = vocab["<0x04>"]
|
||||
|
||||
return prompt_tokens + [boa_token_id]
|
||||
|
||||
@ -202,6 +204,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
|
||||
@ -315,13 +315,14 @@ class PixtralHFMultiModalProcessor(
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
processor = self.info.get_hf_processor()
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
image_break_id = vocab[processor.image_break_token]
|
||||
image_token_id = hf_config.image_token_index
|
||||
image_end_id = vocab[processor.image_end_token]
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
@ -336,10 +337,10 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
tokens = ([image_token] * ncols + [image_break_token]) * nrows
|
||||
tokens[-1] = image_end_token
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return "".join(tokens)
|
||||
return tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
|
||||
@ -188,7 +188,9 @@ class Qwen2AudioMultiModalProcessor(
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self.info.get_hf_processor()
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
# Use getattr with default to be compatible with transformers<4.48
|
||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||
@ -197,6 +199,10 @@ class Qwen2AudioMultiModalProcessor(
|
||||
audio_eos_token = getattr(processor, "audio_eos_token",
|
||||
"<|audio_eos|>")
|
||||
|
||||
audio_token_id = vocab[audio_token]
|
||||
audio_bos_id = vocab[audio_bos_token]
|
||||
audio_eos_id = vocab[audio_eos_token]
|
||||
|
||||
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||
if feature_attention_mask is None:
|
||||
audio_output_lengths = []
|
||||
@ -208,22 +214,18 @@ class Qwen2AudioMultiModalProcessor(
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
|
||||
def get_replacement_qwen2_audio(item_idx: int):
|
||||
num_placeholders = audio_output_lengths[item_idx]
|
||||
if num_placeholders == 0:
|
||||
num_features = audio_output_lengths[item_idx]
|
||||
if num_features == 0:
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio = audios.get(item_idx)
|
||||
raise ValueError(
|
||||
f"The audio {audio} (len={len(audio)}) is too short "
|
||||
"to be represented inside the model")
|
||||
|
||||
audio_tokens = audio_token * num_placeholders
|
||||
audio_tokens = [audio_token_id] * num_features
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full="".join([
|
||||
audio_bos_token,
|
||||
audio_tokens,
|
||||
audio_eos_token,
|
||||
]),
|
||||
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
|
||||
features=audio_tokens,
|
||||
)
|
||||
|
||||
|
||||
@ -953,12 +953,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_processor = self.info.get_image_processor(
|
||||
**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
|
||||
# image_token and video_token registered
|
||||
placeholder = {
|
||||
"image": hf_processor.image_token,
|
||||
"video": hf_processor.video_token,
|
||||
"image": vocab[hf_processor.image_token],
|
||||
"video": vocab[hf_processor.video_token],
|
||||
}
|
||||
|
||||
merge_length = image_processor.merge_size**2
|
||||
@ -967,13 +969,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
|
||||
assert isinstance(grid_thw, torch.Tensor)
|
||||
|
||||
num_tokens = grid_thw.prod().item() // merge_length
|
||||
return placeholder[modality] * num_tokens
|
||||
num_tokens = int(grid_thw.prod()) // merge_length
|
||||
return [placeholder[modality]] * num_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality=modality,
|
||||
target=placeholder[modality],
|
||||
target=[placeholder[modality]],
|
||||
replacement=partial(get_replacement_qwen2vl,
|
||||
modality=modality),
|
||||
) for modality in ("image", "video")
|
||||
|
||||
@ -205,16 +205,20 @@ class UltravoxMultiModalProcessor(
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
placeholder = hf_processor.audio_token_replacement # type: ignore
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
replacement_id = vocab[
|
||||
hf_processor.audio_token_replacement] # type: ignore
|
||||
|
||||
def get_replacement_ultravox(item_idx: int):
|
||||
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
|
||||
return placeholder * audio_token_len
|
||||
return [replacement_id] * int(audio_token_len) # type: ignore
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target="<|audio|>",
|
||||
target='<|audio|>',
|
||||
replacement=get_replacement_ultravox,
|
||||
)
|
||||
]
|
||||
|
||||
@ -67,9 +67,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
tokenizer_all_special_tokens_extended = (
|
||||
tokenizer.all_special_tokens_extended)
|
||||
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
|
||||
tokenizer_vocab = tokenizer.get_vocab()
|
||||
tokenizer_len = len(tokenizer)
|
||||
|
||||
max_token_id = max(tokenizer.get_vocab().values())
|
||||
max_token_id = max(tokenizer_vocab.values())
|
||||
# Some tokenizers (e.g., QwenTokenizer) have special tokens that
|
||||
# are added and included in the implementation of the vocab_size
|
||||
# property, but not in get_vocab(); if there is an implementation
|
||||
@ -96,6 +97,9 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
def max_token_id(self):
|
||||
return max_token_id
|
||||
|
||||
def get_vocab(self):
|
||||
return tokenizer_vocab
|
||||
|
||||
def __len__(self):
|
||||
return tokenizer_len
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user