[ROCm] Split AITER unified attention into its own backend (#25507)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 2161efe978
commit f231e5bc21
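Usage note (not part of the commit): after this change, the AITER unified attention path is selected either by naming the new ROCM_AITER_UNIFIED_ATTN backend or by turning on the VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION flag added below. A minimal sketch, assuming a ROCm V1 build of vLLM; the model name is only a placeholder reused from the test file in this diff:

    import os

    # Either request the backend explicitly ...
    os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_UNIFIED_ATTN"
    # ... or enable it via the AITER flags (VLLM_ROCM_USE_AITER must also be on):
    # os.environ["VLLM_ROCM_USE_AITER"] = "1"
    # os.environ["VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM("amd/Llama-3.1-8B-Instruct-FP8-KV", max_model_len=2048)
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)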
@@ -7,9 +7,7 @@ import pytest
 import torch._dynamo
 
 from tests.compile.backend import LazyInitPass, TestBackend
-from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
-from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.backends.registry import _Backend
@@ -31,7 +29,6 @@ from vllm.config import (
 )
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
 )
@@ -48,132 +45,6 @@ backend: Optional[TestBackend] = None
 backend_unfused: Optional[TestBackend] = None
 
 
-@pytest.mark.parametrize(
-    "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
-)
-@pytest.mark.parametrize("use_triton_fa", [True, False])
-@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
-@pytest.mark.skipif(
-    not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
-)
-def test_attention_fusion_v0(
-    example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
-):
-    # Clean Dynamo cache to avoid reusing other test cases
-    # (for some reason the reset at the end is not enough)
-    torch._dynamo.reset()
-
-    # Use global backends
-    global backend, backend_unfused
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
-
-    # Prompt 4 seems too open-ended, differs between fused and unfused
-    # (both outputs look reasonable though)
-    prompts = example_prompts[:4] + example_prompts[5:]
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend_unfused",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
-
-    llm = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
-
-    unfused_output = llm.generate(prompts, sampling_params)
-    backend_unfused = None  # Reset backend to make sure llm gets released
-    del llm
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-
-    # AttnFusionPass needs attention layers to be registered in config upon init
-    # so we initialize it during compilation.
-    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
-    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
-    llm2 = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    # check support
-    attn_fusion_supported = [
-        layer.impl.fused_output_quant_supported(quant_key)
-        for key, layer in compile_config.static_forward_context.items()
-    ]
-
-    print(f"{attn_fusion_supported=}")
-    if any(attn_fusion_supported):
-        # Check quant ops
-        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
-
-    # attention ops present in both, just output_scale param changes
-    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
-    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
-    assert len(attn_nodes_pre) == len(attn_nodes_post)
-
-    for i in range(len(attn_nodes_pre)):
-        assert attn_nodes_pre[i].kwargs["output_scale"] is None
-        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
-        assert fused == attn_fusion_supported[i], (
-            f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
-        )
-
-    # check outputs
-    fused_output = llm2.generate(prompts, sampling_params)
-
-    # transform outputs to format expected by check_outputs_equal
-    sample_outs = lambda s: (list(s.token_ids), s.text)
-    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
-
-    check_outputs_equal(
-        outputs_0_lst=outs_lst(unfused_output),
-        outputs_1_lst=outs_lst(fused_output),
-        name_0="unfused",
-        name_1="fused",
-    )
-
-    # Clean Dynamo cache to avoid polluting other case(s)
-    torch._dynamo.reset()
-
-    # Reset backend to make sure llm2 gets released
-    backend = None
-
-
 class AttentionQuantPatternModel(torch.nn.Module):
     """Base model for AttentionQuantPattern fusion."""
 
@@ -221,7 +92,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
+    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         """Initialize attention metadata."""
 
         # Create common attn metadata
@@ -232,30 +103,57 @@ class AttentionQuantPatternModel(torch.nn.Module):
 
         max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
         num_blocks = batch_size * max_blocks
+        backend = self.attn.backend
 
-        # Create dummy KV cache for FlashInfer TRTLLM
-        # - NHD: [num_blocks, block_size, num_kv_heads, head_size]
-        # - HND: [num_blocks, num_kv_heads, block_size, head_size]
-        kv_cache = torch.zeros(
-            num_blocks,
-            2,
-            self.num_kv_heads,
-            self.block_size,
-            self.head_size,
-            dtype=self.kv_cache_dtype,
-            device=self.device,
-        )
-        if current_platform.is_rocm():
+        # Create dummy KV cache for the selected backend
+        if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimention
-            if use_hnd:
-                kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
-            else:
-                kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
-        else:
+            # HND: [num_blocks, num_kv_heads, block_size, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+            # k/v as 1st dimention
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.block_size,
+                self.num_kv_heads,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.TRITON_ATTN:
             # k/v as 2nd dimention
-            # Create kv_cache in HND layout and permute to NHD layout
-            # (later will be permuted back to HND layout in forward pass)
-            kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.FLASHINFER:
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            ).permute(0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")
         self.attn.kv_cache = [kv_cache]
 
         # Build attn metadata
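For orientation, a small standalone sketch (not part of the commit) of the dummy KV-cache shapes the branches above produce, using made-up sizes:

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 128

    rocm_attn = torch.zeros(2, num_blocks, num_kv_heads, block_size, head_size)      # k/v first, HND
    aiter_unified = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)  # k/v first, NHD
    triton_attn = torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size)    # k/v second
    flashinfer = triton_attn.permute(0, 1, 3, 2, 4)  # permuted view, as in the FLASHINFER branch

    print(flashinfer.shape)  # torch.Size([8, 2, 16, 4, 128])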
@@ -375,10 +273,9 @@ else:
 @pytest.mark.parametrize("model_name, model_class", MODELS)
 @pytest.mark.parametrize(
     "backend",
-    [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
-)
-@pytest.mark.parametrize(
-    "split_attention", [False, True] if current_platform.is_rocm() else [False]
+    [_Backend.FLASHINFER]
+    if current_platform.is_cuda()
+    else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
 )
 # TODO(boyuan): test inductor graph partition on rocm
 @pytest.mark.parametrize(
@@ -405,7 +302,6 @@ def test_attention_quant_pattern(
     model_name: str,
     model_class: type[AttentionQuantPatternModel],
     backend: _Backend,
-    split_attention: bool,
     use_inductor_graph_partition: bool,
     monkeypatch,
     dist_init,
@@ -417,8 +313,6 @@ def test_attention_quant_pattern(
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    if split_attention:
-        monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
 
     device = torch.device("cuda:0")
     torch.manual_seed(42)
@@ -466,9 +360,7 @@ def test_attention_quant_pattern(
     model_unfused = model_unfused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
 
     # Run model directly without compilation and fusion
     result_unfused = model_unfused(q, k, v)
@@ -494,9 +386,7 @@ def test_attention_quant_pattern(
     model_fused = model_fused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_fused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
 
     # Create test backend with fusion passes enabled
     noop_pass = NoOpEliminationPass(vllm_config)
@@ -25,3 +25,4 @@ class _Backend(enum.Enum):
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
     ROCM_ATTN = enum.auto()
+    ROCM_AITER_UNIFIED_ATTN = enum.auto()
@@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager(
     finally:
         # Revert the original global backend override, if any
         global_force_attn_backend(original_value)
+        _cached_get_attn_backend.cache_clear()
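Illustration (assumed usage, not taken from the diff): clearing _cached_get_attn_backend matters because the resolved backend class is memoized, so a forced override would otherwise leak past the context manager. A sketch, assuming the context manager lives in vllm.attention.selector:

    from vllm.attention.backends.registry import _Backend
    from vllm.attention.selector import global_force_attn_backend_context_manager

    with global_force_attn_backend_context_manager(_Backend.ROCM_AITER_UNIFIED_ATTN):
        ...  # code here resolves attention backends against the forced value
    # On exit the override is reverted and the cached backend lookup is cleared.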
@@ -1623,6 +1623,7 @@ class EngineArgs:
             "TREE_ATTN",
             "XFORMERS",
             "ROCM_ATTN",
+            "ROCM_AITER_UNIFIED_ATTN",
         ]
         if (
             envs.is_set("VLLM_ATTENTION_BACKEND")
vllm/envs.py
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
-    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -109,6 +108,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
     VLLM_ROCM_USE_TRITON_ROPE: bool = False
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
+    VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -475,10 +475,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
         os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
         in ("true", "1")
     ),
-    # Use AITER triton unified attention for V1 attention
-    "VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: (
-        os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
-    ),
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int(
@@ -896,6 +892,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
     ),
+    # Use AITER triton unified attention for V1 attention
+    "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
+        in ("true", "1")
+    ),
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
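For clarity, a standalone sketch (not part of the diff) of how the new entry parses the flag, matching the other ROCm toggles ("true"/"1", case-insensitive, default off):

    import os

    os.environ["VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"] = "True"
    enabled = os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
    print(enabled)  # True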
@@ -1434,7 +1435,6 @@ def compute_hash() -> str:
         "VLLM_FUSED_MOE_CHUNK_SIZE",
         "VLLM_FLASHINFER_MOE_BACKEND",
         "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
-        "VLLM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ATTENTION_BACKEND",
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
@@ -1462,6 +1462,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
         "VLLM_ROCM_USE_TRITON_ROPE",
         "VLLM_ROCM_USE_AITER_FP8BMM",
+        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",
         "VLLM_ROCM_MOE_PADDING",
@@ -276,25 +276,33 @@ class RocmPlatform(Platform):
             )
 
         if envs.VLLM_USE_V1:
-            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
-                logger.info("Using Flash Attention backend on V1 engine.")
+            if (
+                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
+            ) or selected_backend == _Backend.ROCM_AITER_FA:
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
                 return (
                     "vllm.v1.attention.backends."
                     "rocm_aiter_fa.AiterFlashAttentionBackend"
                 )
-            elif (
-                (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
-                or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
+            if (
+                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+            ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                return (
+                    "vllm.v1.attention.backends."
+                    "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
+                )
+            if (
+                envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
                 or selected_backend == _Backend.ROCM_ATTN
             ):
                 # rocm specific backend, with aiter and/or
                 # triton prefix-prefill
-                logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
+                logger.info("Using Rocm Attention backend on V1 engine.")
                 return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
-            else:
-                # default case, using triton unified attention
-                logger.info("Using Triton Attention backend on V1 engine.")
-                return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
+            # default case, using triton unified attention
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
         raise RuntimeError(
             "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
             "to select a supported backend."
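For reference, a condensed sketch (not the actual platform method) of the V1 backend precedence the hunk above establishes; the flag and backend names come from this diff, and the gfx9 check is omitted for brevity:

    from types import SimpleNamespace

    def pick_rocm_v1_backend(selected_backend: str | None, env) -> str:
        # 1. AITER flash attention: AITER + MHA flags, or explicitly requested.
        if (env.VLLM_ROCM_USE_AITER and env.VLLM_ROCM_USE_AITER_MHA) or selected_backend == "ROCM_AITER_FA":
            return "rocm_aiter_fa.AiterFlashAttentionBackend"
        # 2. AITER unified attention: new flag, or explicitly requested.
        if (env.VLLM_ROCM_USE_AITER and env.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION) or selected_backend == "ROCM_AITER_UNIFIED_ATTN":
            return "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
        # 3. ROCm prefill/decode attention.
        if env.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or selected_backend == "ROCM_ATTN":
            return "rocm_attn.RocmAttentionBackend"
        # 4. Default: Triton unified attention.
        return "triton_attn.TritonAttentionBackend"

    env = SimpleNamespace(VLLM_ROCM_USE_AITER=True, VLLM_ROCM_USE_AITER_MHA=False,
                          VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=True,
                          VLLM_V1_USE_PREFILL_DECODE_ATTENTION=False)
    print(pick_rocm_v1_backend(None, env))  # rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend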
vllm/v1/attention/backends/rocm_aiter_unified_attn.py (new file, 203 lines)
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention layer with PagedAttention and Triton prefix prefill."""
+
+from typing import Optional
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey,
+    kFp8StaticTensorSym,
+)
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.rocm_attn import (
+    RocmAttentionBackend,
+    RocmAttentionImpl,
+    RocmAttentionMetadata,
+    RocmAttentionMetadataBuilder,
+)
+
+logger = init_logger(__name__)
+
+
+class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
+    accept_output_buffer: bool = True
+
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_AITER_UNIFIED_ATTN"
+
+    @staticmethod
+    def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
+        return RocmAiterUnifiedAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["AttentionMetadata"]:
+        return RocmAttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        cache_dtype_str: str = "auto",
+    ) -> tuple[int, ...]:
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def use_cascade_attention(*args, **kwargs) -> bool:
+        return False
+
+    @staticmethod
+    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
+        return RocmAttentionMetadataBuilder
+
+
+class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
+    def fused_output_quant_supported(self, quant_key: QuantKey):
+        return quant_key == kFp8StaticTensorSym
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[list[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        logits_soft_cap: Optional[float] = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__(
+            num_heads,
+            head_size,
+            scale,
+            num_kv_heads,
+            alibi_slopes,
+            sliding_window,
+            kv_cache_dtype,
+            logits_soft_cap,
+            attn_type,
+            kv_sharing_target_layer_name,
+            sinks,
+        )
+        logger.info_once(
+            "Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
+        )
+        from aiter.ops.triton.unified_attention import unified_attention
+
+        self.unified_attention = unified_attention
+
+    def forward(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache: shape =
+                [2, num_blocks, block_size, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert output is not None, "Output tensor must be provided."
+
+        if output_block_scale is not None:
+            raise NotImplementedError(
+                "fused block_scale output quantization is not yet supported"
+                " for RocmAttentionImpl"
+            )
+
+        if attn_metadata is None:
+            # Profiling run.
+            return output
+
+        assert attn_metadata.use_cascade is False
+
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
+            assert layer._q_scale_float == 1.0, (
+                "A non 1.0 q_scale is not currently supported."
+            )
+
+        cu_seqlens_q = attn_metadata.query_start_loc
+        seqused_k = attn_metadata.seq_lens
+        max_seqlen_q = attn_metadata.max_query_len
+        max_seqlen_k = attn_metadata.max_seq_len
+        block_table = attn_metadata.block_table
+
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+        self.unified_attention(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=seqused_k,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            q_descale=None,  # Not supported
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+            sinks=self.sinks,
+            output_scale=output_scale,
+        )
+
+        return output
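A quick check of the new backend's KV-cache layout (illustration only, assuming a ROCm build of vLLM is installed so the module above can be imported):

    from vllm.v1.attention.backends.rocm_aiter_unified_attn import (
        RocmAiterUnifiedAttentionBackend,
    )

    shape = RocmAiterUnifiedAttentionBackend.get_kv_cache_shape(
        num_blocks=8, block_size=16, num_kv_heads=4, head_size=128
    )
    print(RocmAiterUnifiedAttentionBackend.get_name())  # ROCM_AITER_UNIFIED_ATTN
    print(shape)  # (2, 8, 16, 4, 128): k/v first, then [num_blocks, block_size, num_kv_heads, head_size]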
@@ -3,13 +3,10 @@
 """Attention layer with PagedAttention and Triton prefix prefill."""
 
 from dataclasses import dataclass
-from functools import cache
 from typing import ClassVar, Optional
 
 import torch
 
-from vllm import _custom_ops as ops
-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -96,12 +93,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
             # slow, so here we set it to 1.
             attn_metadata.seq_lens.fill_(1)
 
-            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
-                # Here we set the query start locs to 0. This is to
-                # cover up an invalid memory access in the prefix_prefil kernel
-                # that we run into during graph capture (#25985)
-                common_attn_metadata.query_start_loc.zero_()
-                common_attn_metadata.query_start_loc_cpu.zero_()
+            # Here we set the query start locs to 0. This is to
+            # cover up an invalid memory access in the prefix_prefil kernel
+            # that we run into during graph capture (#25985)
+            common_attn_metadata.query_start_loc.zero_()
+            common_attn_metadata.query_start_loc_cpu.zero_()
 
         return attn_metadata
 
@@ -211,14 +207,6 @@ class RocmAttentionBackend(AttentionBackend):
         return RocmAttentionMetadataBuilder
 
 
-@cache
-def use_aiter_unified_attention() -> bool:
-    """Check if aiter unified attention should be used."""
-    # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
-    # to 1 as default
-    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
-
-
 class RocmAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
@@ -268,23 +256,6 @@ class RocmAttentionImpl(AttentionImpl):
         )
 
         self.fp8_dtype = current_platform.fp8_dtype()
-        self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-
-        if not self.force_prefill_decode_attn:
-            # If not using prefill decode attention, we use the Triton
-            # unified attention implementation.
-            if use_aiter_unified_attention():
-                logger.info_once("Using aiter unified attention for RocmAttentionImpl")
-                from aiter.ops.triton.unified_attention import unified_attention
-
-                self.unified_attention = unified_attention
-            else:
-                logger.info_once("Using vllm unified attention for RocmAttentionImpl")
-                from vllm.attention.ops.triton_unified_attention import (
-                    unified_attention,
-                )
-
-                self.unified_attention = unified_attention
 
         self.sinks = sinks
         if sinks is not None:
@@ -341,58 +312,32 @@ class RocmAttentionImpl(AttentionImpl):
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.
 
-        use_prefill_decode_attn = self.force_prefill_decode_attn
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        if use_prefill_decode_attn:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size
-            )
-        else:
-            key_cache, value_cache = kv_cache.unbind(0)
+        key_cache, value_cache = PagedAttention.split_kv_cache(
+            kv_cache, self.num_kv_heads, self.head_size
+        )
 
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            if use_prefill_decode_attn:
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
             )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
@@ -400,53 +345,27 @@ class RocmAttentionImpl(AttentionImpl):
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        if use_prefill_decode_attn:
-            # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(
-                query=query[:num_actual_tokens],
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                output=output[:num_actual_tokens],
-                kv_cache_dtype=self.kv_cache_dtype,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_table=block_table,
-                query_start_loc=cu_seqlens_q,
-                seq_lens=seqused_k,
-                max_seq_len=max_seqlen_k,
-                max_query_len=max_seqlen_q,
-                k_scale=layer._k_scale,
-                v_scale=layer._v_scale,
-                alibi_slopes=self.alibi_slopes,
-                sliding_window=self.sliding_window[0],
-                sm_scale=self.scale,
-                output_scale=output_scale,
-                sinks=self.sinks,
-            )
-
-        else:
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-
-            self.unified_attention(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                q_descale=None,  # Not supported
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                sinks=self.sinks,
-                output_scale=output_scale,
-            )
+        # Compute attention and update output up to `num_actual_tokens`.
+        chunked_prefill_paged_decode(
+            query=query[:num_actual_tokens],
+            key=key[:num_actual_tokens],
+            value=value[:num_actual_tokens],
+            output=output[:num_actual_tokens],
+            kv_cache_dtype=self.kv_cache_dtype,
+            key_cache=key_cache,
+            value_cache=value_cache,
+            block_table=block_table,
+            query_start_loc=cu_seqlens_q,
+            seq_lens=seqused_k,
+            max_seq_len=max_seqlen_k,
+            max_query_len=max_seqlen_q,
+            k_scale=layer._k_scale,
+            v_scale=layer._v_scale,
+            alibi_slopes=self.alibi_slopes,
+            sliding_window=self.sliding_window[0],
+            sm_scale=self.scale,
+            output_scale=output_scale,
+            sinks=self.sinks,
+        )
 
         return output