[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
This commit is contained in:
parent 1fdd5c42d7
commit 4aa23892d6
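For context on the bug this fixes: CustomOp routes each call to a platform-specific implementation (forward_native, forward_cuda, and so on), so a subclass that overrides forward() directly bypasses that routing and always runs its own single code path. Below is a minimal sketch of the dispatch pattern, assuming a simplified stand-in class — CustomOpSketch and MyReLU are illustrative names, and the plain CUDA check stands in for vLLM's platform detection; this is not vLLM's actual CustomOp implementation.

import torch
import torch.nn as nn


class CustomOpSketch(nn.Module):
    """Simplified stand-in for a platform-dispatching custom op (illustration only)."""

    def __init__(self) -> None:
        super().__init__()
        # Assumption: a plain CUDA check replaces vLLM's real platform detection
        # (CUDA/ROCm/CPU/TPU/HPU/...) purely for illustration.
        self._forward_method = (self.forward_cuda
                                if torch.cuda.is_available()
                                else self.forward_native)

    def forward(self, *args, **kwargs):
        # Platform routing lives in the base class's forward(). A subclass that
        # overrides forward() directly skips this dispatch, so its platform-specific
        # variants are never reached.
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Default: fall back to the native implementation.
        return self.forward_native(x)


class MyReLU(CustomOpSketch):
    # Correct pattern: override forward_native (and optionally forward_cuda),
    # never forward() itself.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

The hunks below apply that pattern: each direct forward() override is renamed to forward_native, and where needed a thin forward_cuda that delegates to it is added, so the base class's platform routing takes effect again.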
@@ -454,7 +454,7 @@ class XIELU(CustomOp):
             )
         return result.view(original_shape)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
             )
         return self._xielu_python(input)
 
+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
             return (shared_output[..., :og_hidden_states],
                     fused_output[..., :og_hidden_states])
 
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
@@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)
@@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp):
                                              device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp):
                               dim=-1)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:
@@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
                                               self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)
@@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)
@@ -8,7 +8,6 @@ import numpy as np
 import torch
 from transformers import PretrainedConfig
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,
@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
             param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
             output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"