[Models] Prevent CUDA sync in Qwen2.5-VL (#24741)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Lukas Geiger 2025-09-12 17:03:55 +01:00 committed by GitHub
parent 57f94e88ea
commit b0d1213ac3


@@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
+from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -737,7 +738,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
     @staticmethod
     def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
         # building the inverse permutation in O(n) time
-        inv = torch.empty_like(perm)
+        inv = torch.empty_like(perm, pin_memory=is_pin_memory_available())
         inv[perm] = torch.arange(perm.numel(),
                                  device=perm.device,
                                  dtype=perm.dtype)
@@ -808,6 +809,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
                                            non_blocking=True)
         window_index = window_index.to(device=hidden_states.device,
                                        non_blocking=True)
+        reverse_indices = reverse_indices.to(device=hidden_states.device,
+                                             non_blocking=True)
         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
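
The pattern the diff applies is: build the small index tensor on the CPU in pinned (page-locked) memory, then copy it to the GPU with non_blocking=True so the transfer does not force a CUDA synchronization. Below is a minimal standalone sketch of that pattern, not part of the commit; torch.cuda.is_available() is used here only as a stand-in for vLLM's is_pin_memory_available() helper.

import torch


def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
    # Allocate the result in pinned host memory so a later
    # .to(..., non_blocking=True) copy can run asynchronously.
    # torch.cuda.is_available() is a stand-in for vLLM's
    # is_pin_memory_available().
    inv = torch.empty_like(perm, pin_memory=torch.cuda.is_available())
    inv[perm] = torch.arange(perm.numel(),
                             device=perm.device,
                             dtype=perm.dtype)
    return inv


if torch.cuda.is_available():
    perm = torch.randperm(16)  # index tensor built on the CPU
    reverse_indices = invert_permutation(perm)
    # Host-to-device copy from pinned memory does not block the host thread,
    # so no CUDA synchronization is forced here.
    reverse_indices = reverse_indices.to(device="cuda", non_blocking=True)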