diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 50660c6ecc223..b86cd9f001d61 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
         output = output.to(origin_dtype)
         return output
 
+    def _pre_process(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size, torch.dtype]:
+        origin_shape = x.shape
+        if len(origin_shape) == 3:
+            # x: [seq_len, num_heads, head_size]
+            x = x.unsqueeze(0)
+
+        origin_dtype = x.dtype
+        if self.enable_fp32_compute:
+            x = x.float()
+            cos = cos.float()
+            sin = sin.float()
+
+        return x, cos, sin, origin_shape, origin_dtype
+
+    def _post_process(
+        self,
+        output: torch.Tensor,
+        origin_shape: torch.Size,
+        origin_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        if len(origin_shape) == 3:
+            output = output.squeeze(0)
+        if self.enable_fp32_compute:
+            output = output.to(origin_dtype)
+        return output
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
     ) -> torch.Tensor:
         from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
-        origin_dtype = x.dtype
-        if self.enable_fp32_compute:
-            x = x.float()
-            cos = cos.float()
-            sin = sin.float()
-
-        origin_shape = x.shape
-        if len(origin_shape) == 3:
-            # x: [seq_len, num_heads, head_size]
-            x = x.unsqueeze(0)
+        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
 
         """
         Arguments of apply_rotary_emb() in vllm_flash_attn:
@@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
         interleaved = not self.is_neox_style
         output = apply_rotary_emb(x, cos, sin, interleaved)
 
-        if len(origin_shape) == 3:
-            output = output.squeeze(0)
-        if self.enable_fp32_compute:
-            output = output.to(origin_dtype)
+        output = self._post_process(output, origin_shape, origin_dtype)
         return output
 
     def forward_hip(
         self,
         x: torch.Tensor,
@@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
         sin: torch.Tensor,
     ) -> torch.Tensor:
         if self.apply_rotary_emb_flash_attn is not None:
-            origin_dtype = x.dtype
-            if self.enable_fp32_compute:
-                x = x.float()
-                cos = cos.float()
-                sin = sin.float()
-
-            origin_shape = x.shape
-            if len(origin_shape) == 3:
-                # x: [seq_len, num_heads, head_size]
-                x = x.unsqueeze(0)
+            x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
 
             """
             Arguments of apply_rotary() in flash_attn:
@@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
                 x, cos, sin, interleaved=interleaved
             ).type_as(x)
 
-            if len(origin_shape) == 3:
-                output = output.squeeze(0)
-            if self.enable_fp32_compute:
-                output = output.to(origin_dtype)
+            output = self._post_process(output, origin_shape, origin_dtype)
         else:
             # Falling back to PyTorch native implementation.
             output = self.forward_native(x, cos, sin)
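
For reviewers, a minimal standalone sketch (plain PyTorch, not vLLM code) of the round trip the new `_pre_process`/`_post_process` helpers implement: a 3-D `[seq_len, num_heads, head_size]` input is promoted to 4-D and optionally upcast to fp32, a rotation kernel runs on the promoted tensor, and the output is squeezed and cast back. The free functions below mirror the methods in the diff; `toy_neox_rotation` is a hypothetical stand-in for the real `apply_rotary_emb` kernel.

```python
import torch


def _pre_process(x, cos, sin, enable_fp32_compute):
    # Mirrors the helper in the diff: promote a 3-D [seq_len, num_heads,
    # head_size] input to 4-D and optionally upcast to fp32 for compute.
    origin_shape = x.shape
    if len(origin_shape) == 3:
        x = x.unsqueeze(0)

    origin_dtype = x.dtype
    if enable_fp32_compute:
        x, cos, sin = x.float(), cos.float(), sin.float()
    return x, cos, sin, origin_shape, origin_dtype


def _post_process(output, origin_shape, origin_dtype, enable_fp32_compute):
    # Undo the shape promotion and dtype upcast from _pre_process.
    if len(origin_shape) == 3:
        output = output.squeeze(0)
    if enable_fp32_compute:
        output = output.to(origin_dtype)
    return output


def toy_neox_rotation(x, cos, sin):
    # Hypothetical stand-in for apply_rotary_emb(): rotate the two halves
    # of the head dimension (NeoX-style, non-interleaved layout).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


if __name__ == "__main__":
    seq_len, num_heads, head_size = 5, 8, 64
    x = torch.randn(seq_len, num_heads, head_size).to(torch.bfloat16)
    cos = torch.randn(seq_len, 1, head_size // 2)
    sin = torch.randn(seq_len, 1, head_size // 2)

    x4d, cos32, sin32, origin_shape, origin_dtype = _pre_process(
        x, cos, sin, enable_fp32_compute=True
    )
    out = toy_neox_rotation(x4d, cos32, sin32)
    out = _post_process(out, origin_shape, origin_dtype, enable_fp32_compute=True)

    # The round trip restores both the original 3-D shape and the bf16 dtype.
    assert out.shape == origin_shape and out.dtype == origin_dtype
```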