diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 86536b21c33f..7c1eba103ae7 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -1367,6 +1367,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
+        # Use `is not None` (not truthiness) so a valid token id of 0 still builds the cache.
+        if getattr(self.config, "im_patch_id", None) is not None:
+            visual_token_ids = [
+                token_id
+                for token_id in [
+                    self.config.im_patch_id,
+                    getattr(self.config, "image_start_token_id", None),
+                    getattr(self.config, "image_end_token_id", None),
+                    getattr(self.config, "video_start_token_id", None),
+                    getattr(self.config, "video_end_token_id", None),
+                ]
+                if token_id is not None
+            ]
+            self._visual_token_ids_tensor_cache = torch.tensor(
+                visual_token_ids, dtype=torch.long
+            )
+        else:
+            self._visual_token_ids_tensor_cache = None
 
     def compute_logits(
         self,
@@ -1398,12 +1416,19 @@ class Ernie4_5_VLMoeForConditionalGeneration(
         return image_features
 
     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
-        if getattr(self.config, "im_patch_id", None) is not None:
-            self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(
-                -1, 1
-            )
-        else:
+        """Set mask for visual tokens (image/video patches and delimiters)."""
+        if self._visual_token_ids_tensor_cache is None:
             self.visual_token_mask = None
+            return
+        # Tensor.to returns the cached tensor unchanged when device/dtype already match.
+        visual_token_ids_tensor = self._visual_token_ids_tensor_cache.to(
+            device=input_ids.device,
+            dtype=input_ids.dtype,
+        )
+
+        self.visual_token_mask = torch.isin(input_ids, visual_token_ids_tensor).reshape(
+            -1, 1
+        )
 
     def get_mrope_input_positions(
         self,