diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index fb6882f3e7c3e..d073dd6d2dee1 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& input,  // [..., hidden_size]
               torch::Tensor& weight, // [hidden_size]
               double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
+
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 6f0a5f991908f..8079a63017177 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -186,7 +186,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
 # layer norm ops
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              epsilon: float) -> None:
-    torch.ops._C.rms_norm(out, input, weight, epsilon)
+    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
+    input_contiguous = input.contiguous()
+    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
 
 
 def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 0499f339b2465..fdcef8b9be8d2 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -190,8 +190,8 @@ class InternParallelAttention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
@@ -264,10 +264,8 @@ class InternSdpaAttention(nn.Module):
 
         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
+            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 75eebdacfdca0..42bbb77a22c07 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -438,8 +438,8 @@ class MolmoAttention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 44beae5726dc0..422b53d86f119 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -139,8 +139,8 @@ class Olmo2Attention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 73d2838f461ea..40e0ccc1bab6b 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -133,11 +133,11 @@ class Qwen3Attention(nn.Module):
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 97acbaa2ac340..fe6b303ba0b5a 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -225,12 +225,12 @@ class Qwen3MoeAttention(nn.Module):
 
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
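A minimal sketch (not part of this diff) of how the contiguity handling above can be exercised end to end: it slices a tensor along its last dimension to obtain a non-contiguous view, runs it through the `ops.rms_norm` wrapper patched in `vllm/_custom_ops.py`, and compares against a pure-PyTorch RMSNorm reference. The shapes, dtype, epsilon, and tolerances below are illustrative assumptions, not values taken from the vLLM test suite.

```python
import torch

from vllm import _custom_ops as ops


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float) -> torch.Tensor:
    # Pure-PyTorch reference: x * rsqrt(mean(x^2) + eps) * weight,
    # computed in float32 and cast back to the input dtype.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


hidden_size = 128
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
x = torch.randn(4, 2 * hidden_size, dtype=torch.float16, device="cuda")

# Slicing the last dimension produces a non-contiguous view.
x_noncontig = x[..., :hidden_size]
assert not x_noncontig.is_contiguous()

# The updated Python wrapper copies the input to a contiguous buffer before
# calling torch.ops._C.rms_norm, while the kernel itself now rejects
# non-contiguous tensors via TORCH_CHECK.
out = torch.empty_like(x_noncontig, memory_format=torch.contiguous_format)
ops.rms_norm(out, x_noncontig, weight, 1e-6)

torch.testing.assert_close(out, rms_norm_ref(x_noncontig, weight, 1e-6),
                           atol=1e-2, rtol=1e-2)
```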