mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 13:57:04 +08:00
[Qwen][ROCm] Flash Attention Rotary Embeddings (#24642)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
422f2cca4b
commit
06d102ecc8
@ -2,15 +2,21 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from functools import cache
|
||||||
|
from importlib.util import find_spec
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# common functions
|
# common functions
|
||||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
|
|||||||
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
    """Pick the rotary-embedding implementation for the current platform.

    Selection order:

    * CUDA: ``apply_rotary_emb`` (imported at module level on CUDA).
    * ROCm with ``flash_attn`` importable: ``apply_rotary`` from
      ``flash_attn.ops.triton.rotary``.
    * ROCm without ``flash_attn``: a warning is logged and the PyTorch
      implementation ``apply_rotary_emb_torch`` is used.
    * Any other platform: ``apply_rotary_emb_torch``.

    The function takes no arguments and is wrapped in ``functools.cache``,
    so the probing below (and the warning) runs only on the first call;
    later calls return the memoised callable.

    Returns:
        The callable that applies the rotary embedding.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb

    # Everything that is neither CUDA nor ROCm uses the PyTorch version.
    if not current_platform.is_rocm():
        return apply_rotary_emb_torch

    if find_spec("flash_attn") is None:
        logger.warning(
            "flash_attn is not installed. Falling back to PyTorch "
            "implementation for rotary embeddings.")
        return apply_rotary_emb_torch

    # Imported lazily: flash_attn is an optional dependency and is only
    # looked up once we know the package can be found.
    from flash_attn.ops.triton.rotary import apply_rotary
    return apply_rotary
|
||||||
|
|
||||||
|
|
||||||
# yarn functions
|
# yarn functions
|
||||||
# Inverse dim formula to find dim based on number of rotations
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
def yarn_find_correction_dim(num_rotations: int,
|
def yarn_find_correction_dim(num_rotations: int,
|
||||||
|
|||||||
@ -50,6 +50,8 @@ from vllm.model_executor.layers.activation import QuickGELU
|
|||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
dispatch_rotary_emb_function)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@ -63,7 +65,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate)
|
PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import _Backend
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
|
|||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
||||||
freqs: torch.Tensor) -> torch.Tensor:
|
freqs: torch.Tensor) -> torch.Tensor:
|
||||||
|
rotary_emb_function = dispatch_rotary_emb_function()
|
||||||
t_ = t.float()
|
t_ = t.float()
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
apply_rotary_emb = apply_rotary_emb_torch
|
output = rotary_emb_function(t_, cos, sin).type_as(t)
|
||||||
if current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
output = apply_rotary_emb(t_, cos, sin).type_as(t)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user