[Fix] optimize visual token mask with caching and multi-token support (#28374)
Signed-off-by: Ferrebo <itachi971009@gmail.com>
Signed-off-by: kebo01 <kebo01@baidu.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 15be507c86
commit 912744d066
@@ -1367,6 +1367,23 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
+        if getattr(self.config, "im_patch_id", None):
+            visual_token_ids = [
+                token_id
+                for token_id in [
+                    self.config.im_patch_id,
+                    getattr(self.config, "image_start_token_id", None),
+                    getattr(self.config, "image_end_token_id", None),
+                    getattr(self.config, "video_start_token_id", None),
+                    getattr(self.config, "video_end_token_id", None),
+                ]
+                if token_id is not None
+            ]
+            self._visual_token_ids_tensor_cache = torch.tensor(
+                visual_token_ids, dtype=torch.long
+            )
+        else:
+            self._visual_token_ids_tensor_cache = None

     def compute_logits(
         self,
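For readers skimming the hunk above: the comprehension collects every visual-token ID the config actually defines and silently drops delimiters the checkpoint lacks, so the cached tensor is built once at init instead of on every forward pass. A minimal standalone sketch of that pattern, using a hypothetical SimpleNamespace config with made-up token IDs (not values from vLLM):

from types import SimpleNamespace

import torch

# Hypothetical config: the video delimiter IDs are deliberately absent.
config = SimpleNamespace(
    im_patch_id=100295,
    image_start_token_id=100288,
    image_end_token_id=100289,
)

# getattr(..., None) plus the `is not None` filter lets missing
# attributes fall out of the list instead of raising AttributeError.
visual_token_ids = [
    token_id
    for token_id in [
        config.im_patch_id,
        getattr(config, "image_start_token_id", None),
        getattr(config, "image_end_token_id", None),
        getattr(config, "video_start_token_id", None),
        getattr(config, "video_end_token_id", None),
    ]
    if token_id is not None
]
# Built once, reused by every later mask computation.
cache = torch.tensor(visual_token_ids, dtype=torch.long)
print(cache)  # tensor([100295, 100288, 100289])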
@@ -1398,12 +1415,19 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         return image_features

     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
-        if getattr(self.config, "im_patch_id", None) is not None:
-            self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(
+        """Set mask for visual tokens (image/video patches and delimiters)."""
+        if self._visual_token_ids_tensor_cache is None:
+            self.visual_token_mask = None
+            return
+        # Create tensor on the correct device
+        visual_token_ids_tensor = self._visual_token_ids_tensor_cache.to(
+            device=input_ids.device,
+            dtype=input_ids.dtype,
+        )
+
+        self.visual_token_mask = torch.isin(input_ids, visual_token_ids_tensor).reshape(
             -1, 1
         )
-        else:
-            self.visual_token_mask = None

     def get_mrope_input_positions(
         self,
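Why torch.isin in the second hunk: the old mask was an equality test against the single im_patch_id, while the new mask must flag any of several delimiter IDs as well. A minimal sketch of the behavioral difference, with made-up token IDs chosen only for illustration:

import torch

input_ids = torch.tensor([1, 100288, 100295, 100295, 100289, 7])
visual_token_ids = torch.tensor([100295, 100288, 100289])

# Old behavior: equality against one token ID only.
single_mask = (input_ids == 100295).reshape(-1, 1)

# New behavior: membership test against the cached ID tensor,
# moved to the input's device/dtype before the comparison.
multi_mask = torch.isin(
    input_ids, visual_token_ids.to(device=input_ids.device, dtype=input_ids.dtype)
).reshape(-1, 1)

print(single_mask.squeeze())  # tensor([False, False,  True,  True, False, False])
print(multi_mask.squeeze())   # tensor([False,  True,  True,  True,  True, False])

The early `return` when the cache is None replaces the old else-branch, and moving the ID list construction into __init__ keeps the per-call work down to one .to() transfer plus the isin membership test.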