mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 12:25:01 +08:00
[Bugfix][Multi Modal] Fix incorrect Molmo token processing (#26873)
Signed-off-by: sanghol <sanghol@allenai.org>
This commit is contained in:
parent
f0862eae43
commit
8865da157b
@@ -1264,13 +1264,16 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
) -> list[int]:
|
||||
processor = self.info.get_hf_processor()
|
||||
|
||||
# Apply the chat template to the tokens
|
||||
# The chat template is already applied to the prompt tokens
|
||||
# Use message_format="none" to avoid applying it again
|
||||
# Prepend an empty space if `always_start_with_space` is True
|
||||
tokens = processor.processor.get_tokens_input( # type: ignore
|
||||
self.info.get_tokenizer().decode(prompt_tokens),
|
||||
- message_format=processor.message_format,
|
||||
+ message_format="none",
|
||||
always_start_with_space=processor.always_start_with_space,
|
||||
)
|
||||
|
||||
# Prepend a BOS token id to the tokens
|
||||
processed_data = self.info.ctx.call_hf_processor(
|
||||
processor, # type: ignore
|
||||
dict(tokens=tokens),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user