mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:15:01 +08:00
[Model] Refactor Ovis2 to support original tokenizer (#17537)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
6768ff4a22
commit
88c8304104
@ -730,11 +730,9 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "AIDC-AI/Ovis2-1B"
|
||||
tokenizer = "Isotr0py/Ovis2-tokenizer"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
tokenizer=tokenizer,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
|
||||
@ -439,11 +439,9 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
# Ovis2
|
||||
def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "AIDC-AI/Ovis2-1B"
|
||||
tokenizer = "Isotr0py/Ovis2-tokenizer"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
tokenizer=tokenizer,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
|
||||
@ -349,7 +349,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
|
||||
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
|
||||
"Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
|
||||
tokenizer="Isotr0py/Ovis2-tokenizer",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501
|
||||
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
|
||||
|
||||
@ -46,8 +46,7 @@ from .utils import merge_multimodal_embeddings
|
||||
|
||||
# Cannot find the following number from hf config.
|
||||
IMAGE_TOKEN = "<image>"
|
||||
IMAGE_ATOM_TOKEN_ID = 151666
|
||||
IMAGE_PAD_TOKEN_ID = 151672
|
||||
IMAGE_PAD_TOKEN_ID = 151655
|
||||
NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
|
||||
|
||||
|
||||
@ -59,6 +58,12 @@ class Ovis2ImagePatchInputs(TypedDict):
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
inducator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
"""
|
||||
|
||||
patches_per_image: List[int]
|
||||
"""
|
||||
List of number of total patches for each image in the batch.
|
||||
@ -138,6 +143,21 @@ class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
|
||||
|
||||
class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
|
||||
|
||||
def image_indicators_to_visual_tokens(
|
||||
self,
|
||||
image_indicators: list[int],
|
||||
) -> list[int]:
|
||||
"""
|
||||
Filter image indicators placeholders and convert them to corresponding
|
||||
tokens in visual tokenizer.
|
||||
For example, [-301, -300, -302, -300, -303, -300, -304, -300, -305]
|
||||
should return [vocab_size-1, vocab_size-2, ..., vocab_size-5]
|
||||
"""
|
||||
hf_config = self.info.get_hf_config()
|
||||
vte_vocab_size = hf_config.visual_tokenizer_config.vocab_size
|
||||
# -300 is image_atom token, filter them out
|
||||
return [vte_vocab_size + x + 300 for x in image_indicators if x < -300]
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -156,6 +176,16 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
image_indicators = [
|
||||
hf_processor.construct_image_indicators(grid)
|
||||
for grid in processed_outputs["grids"]
|
||||
]
|
||||
indicator_tokens = [
|
||||
self.image_indicators_to_visual_tokens(indicator)
|
||||
for indicator in image_indicators
|
||||
]
|
||||
processed_outputs["indicator_tokens"] = indicator_tokens
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
@ -171,7 +201,8 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
grids=MultiModalFieldConfig.batched("image"))
|
||||
grids=MultiModalFieldConfig.batched("image"),
|
||||
indicator_tokens=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@ -230,20 +261,28 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is None:
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if pixel_values is not None and indicator_tokens is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(indicator_tokens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of indicator_tokens. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Ovis2ImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||
patches_per_image=[
|
||||
x.shape[0] for x in flatten_bn(pixel_values)
|
||||
],
|
||||
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
|
||||
concat=True),
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -252,15 +291,33 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
indicator_tokens = image_input["indicator_tokens"]
|
||||
|
||||
indicator_per_image = list(
|
||||
map(lambda x: x + 1 if x > 1 else x + 2, patches_per_image))
|
||||
|
||||
target_dtype = self.visual_tokenizer.dtype
|
||||
visual_tokens = self.visual_tokenizer(
|
||||
image_patches_flat.to(target_dtype))
|
||||
visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq.
|
||||
|
||||
return tuple(
|
||||
x.flatten(0, 1)
|
||||
for x in visual_embeds.split(patches_per_image, dim=0))
|
||||
indicator_embeds = self.vte(indicator_tokens)
|
||||
indicator_embeds_per_image = indicator_embeds.split(
|
||||
indicator_per_image)
|
||||
|
||||
visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
|
||||
vision_embeddings = []
|
||||
for indicator, visual in zip(indicator_embeds_per_image,
|
||||
visual_embeds_per_image):
|
||||
vision_embeddings_per_image = []
|
||||
for i in range(visual.shape[0]):
|
||||
vision_embeddings_per_image.append(
|
||||
torch.cat([indicator[i:i + 1], visual[i]], dim=0))
|
||||
vision_embeddings_per_image.append(indicator[i + 1:])
|
||||
vision_embeddings.append(
|
||||
torch.cat(vision_embeddings_per_image, dim=0))
|
||||
|
||||
return tuple(vision_embeddings)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
@ -281,7 +338,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
[IMAGE_ATOM_TOKEN_ID, IMAGE_PAD_TOKEN_ID])
|
||||
[IMAGE_PAD_TOKEN_ID])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
||||
@ -69,20 +69,21 @@ class OvisProcessor(ProcessorMixin):
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "Qwen2Tokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs):
|
||||
self.image_token = "<image>"
|
||||
self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
|
||||
self.extra_special_tokens = {
|
||||
"image_token": "<image>",
|
||||
"image_atom": "<image_atom>",
|
||||
"image_start": "<img>",
|
||||
"image_prefix": "<pre>",
|
||||
"image_col_sep": "<col>",
|
||||
"image_row_sep": "<row>",
|
||||
"image_end": "</img>",
|
||||
'image_pad': '<image_pad>',
|
||||
"image_token": -200,
|
||||
"image_atom": -300,
|
||||
"image_start": -301,
|
||||
"image_prefix": -302,
|
||||
"image_col_sep": -303,
|
||||
"image_row_sep": -304,
|
||||
"image_end": -305,
|
||||
'image_pad': self.image_pad_token_id,
|
||||
}
|
||||
|
||||
def __call__(
|
||||
@ -157,58 +158,44 @@ class OvisProcessor(ProcessorMixin):
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
tokenized_batched_text = self.tokenizer.batch_encode_plus(
|
||||
text,
|
||||
**output_kwargs["text_kwargs"]
|
||||
)
|
||||
tokenized_batched_text = self._tokenize_with_image_symbol(text)
|
||||
image_token_id = self.get_token_value("image_token")
|
||||
replaced_ids_list = []
|
||||
replaced_attn_mask_list = []
|
||||
idx = 0
|
||||
for ids_tensor, attn_mask in zip(tokenized_batched_text['input_ids'],
|
||||
tokenized_batched_text['attention_mask']):
|
||||
for ids_tensor in tokenized_batched_text:
|
||||
if image_token_id in ids_tensor and "image_placeholders" in image_features:
|
||||
if idx < len(image_features["image_placeholders"]):
|
||||
# Converts in list for ease of use
|
||||
ids_list = ids_tensor.tolist()
|
||||
attn_list = attn_mask.tolist()
|
||||
|
||||
new_ids = []
|
||||
new_attn = []
|
||||
|
||||
# replace placeholders
|
||||
for i, token_id in enumerate(ids_list):
|
||||
if token_id == image_token_id:
|
||||
placeholder_ids = image_features["image_placeholders"][idx]
|
||||
new_ids.extend(placeholder_ids)
|
||||
new_attn.extend([1] * len(placeholder_ids))
|
||||
idx += 1
|
||||
else:
|
||||
new_ids.append(token_id)
|
||||
new_attn.append(attn_list[i])
|
||||
|
||||
# Converts back to tensors
|
||||
ids_tensor = torch.tensor(new_ids, dtype=torch.long)
|
||||
attn_mask = torch.tensor(new_attn, dtype=torch.long)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Mismatch between the images you provided and the number of placeholder present in the text')
|
||||
|
||||
replaced_ids_list.append(ids_tensor)
|
||||
replaced_attn_mask_list.append(attn_mask)
|
||||
|
||||
if replaced_ids_list:
|
||||
replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
|
||||
replaced_and_tokenized_attn_mask = torch.stack(replaced_attn_mask_list)
|
||||
else:
|
||||
replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
|
||||
replaced_and_tokenized_attn_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
# Create the output with text features
|
||||
output = BatchFeature(
|
||||
data={
|
||||
"input_ids": replaced_and_tokenized_ids,
|
||||
"attention_mask": replaced_and_tokenized_attn_mask,
|
||||
}
|
||||
)
|
||||
|
||||
@ -219,10 +206,22 @@ class OvisProcessor(ProcessorMixin):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# If only images were provided
|
||||
return BatchFeature(data=image_features)
|
||||
|
||||
def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor:
|
||||
batch_token_ids = []
|
||||
for text in text_list:
|
||||
text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
|
||||
text.split(self.image_token)]
|
||||
token_ids = []
|
||||
num_chuck = len(text_chunks)
|
||||
for i, chunk in enumerate(text_chunks):
|
||||
token_ids.extend(chunk)
|
||||
if i < num_chuck - 1:
|
||||
token_ids.append(self.get_token_value("image_token"))
|
||||
batch_token_ids.append(token_ids)
|
||||
return torch.tensor(batch_token_ids, dtype=torch.long)
|
||||
|
||||
def get_image_size(self):
|
||||
height = self.image_processor.crop_size["height"]
|
||||
@ -230,10 +229,9 @@ class OvisProcessor(ProcessorMixin):
|
||||
return height, width
|
||||
|
||||
def get_token_value(self, tok):
|
||||
return self.tokenizer.get_vocab()[self.extra_special_tokens[tok]]
|
||||
|
||||
def construct_image_placeholders(self, grid):
|
||||
return self.extra_special_tokens[tok]
|
||||
|
||||
def construct_image_indicators(self, grid):
|
||||
image_placeholders = [self.get_token_value('image_start'),
|
||||
self.get_token_value('image_atom'),
|
||||
self.get_token_value('image_prefix')]
|
||||
@ -246,7 +244,11 @@ class OvisProcessor(ProcessorMixin):
|
||||
if r < grid[0] - 1:
|
||||
image_placeholders.append(self.get_token_value('image_row_sep'))
|
||||
image_placeholders.append(self.get_token_value('image_end'))
|
||||
# return image_placeholders
|
||||
return image_placeholders
|
||||
|
||||
def construct_image_placeholders(self, grid):
|
||||
|
||||
image_placeholders = self.construct_image_indicators(grid)
|
||||
|
||||
image_atom_token_id = self.get_token_value('image_atom')
|
||||
# Extract the padding token ID from tokenizer
|
||||
@ -255,7 +257,7 @@ class OvisProcessor(ProcessorMixin):
|
||||
# Create a new list with padding tokens inserted
|
||||
padded_placeholder_tokens = []
|
||||
for token in image_placeholders:
|
||||
padded_placeholder_tokens.append(token)
|
||||
padded_placeholder_tokens.append(image_padding_token_id)
|
||||
if token == image_atom_token_id:
|
||||
# Add 255 padding tokens after each image atom token
|
||||
padded_placeholder_tokens.extend([image_padding_token_id] * 255)
|
||||
@ -394,4 +396,4 @@ class OvisProcessor(ProcessorMixin):
|
||||
return names_from_processor + ["second_per_grid_ts"]
|
||||
|
||||
|
||||
AutoProcessor.register("OvisProcessor", OvisProcessor)
|
||||
AutoProcessor.register("OvisProcessor", OvisProcessor)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user