mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:45:01 +08:00
[Models] Prevent CUDA sync in Qwen2.5-VL (#24741)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
57f94e88ea
commit
b0d1213ac3
@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
|||||||
from vllm.platforms import _Backend
|
from vllm.platforms import _Backend
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.config import uses_mrope
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
|
from vllm.utils import is_pin_memory_available
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
@ -737,7 +738,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
|
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
|
||||||
# building the inverse permutation in O(n) time
|
# building the inverse permutation in O(n) time
|
||||||
inv = torch.empty_like(perm)
|
inv = torch.empty_like(perm, pin_memory=is_pin_memory_available())
|
||||||
inv[perm] = torch.arange(perm.numel(),
|
inv[perm] = torch.arange(perm.numel(),
|
||||||
device=perm.device,
|
device=perm.device,
|
||||||
dtype=perm.dtype)
|
dtype=perm.dtype)
|
||||||
@ -808,6 +809,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
window_index = window_index.to(device=hidden_states.device,
|
window_index = window_index.to(device=hidden_states.device,
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
reverse_indices = reverse_indices.to(device=hidden_states.device,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
hidden_states = hidden_states.reshape(
|
hidden_states = hidden_states.reshape(
|
||||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user