fix cuda graph (#22721)

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Author: who who who
Date: 2025-08-20 14:24:37 +08:00 (committed by GitHub)
Parent: 8fd920924c
Commit: d983769c41


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    cudagraph_support = AttentionCGSupport.ALWAYS
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
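
The substance of the fix is the last hunk: the metadata builder stops advertising full CUDA graph support through the boolean `full_cudagraph_supported` class variable and instead declares it through the `AttentionCGSupport` enum imported from `vllm.v1.attention.backends.utils`. Below is a minimal sketch of that pattern, not the actual vLLM source; the member `ALWAYS` comes from the diff, while the other member names and the caller `can_capture_full_graph` are assumptions for illustration.

    # Sketch of an enum-based capability declaration, assuming member
    # names besides ALWAYS (only ALWAYS appears in the diff above).
    import enum


    class AttentionCGSupport(enum.Enum):
        """How far a backend's metadata builder supports CUDA graph capture."""
        NEVER = 0              # no CUDA graph capture (assumed member)
        PURE_DECODE_ONLY = 1   # capture pure-decode batches only (assumed member)
        ALWAYS = 2             # capture any batch shape


    class AiterFlashAttentionMetadataBuilder:
        # The commit replaces `full_cudagraph_supported: ClassVar[bool] = True`
        # with this enum-valued class attribute.
        cudagraph_support = AttentionCGSupport.ALWAYS


    def can_capture_full_graph(builder) -> bool:
        # Hypothetical caller: an enum lets the runner distinguish graded
        # support levels, which a single boolean could not express.
        return builder.cudagraph_support == AttentionCGSupport.ALWAYS


    assert can_capture_full_graph(AiterFlashAttentionMetadataBuilder())

The enum also explains the import churn in the first two hunks: `ClassVar` is no longer needed from `typing`, while `AttentionCGSupport` joins the existing imports from `vllm.v1.attention.backends.utils`.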