mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 11:11:19 +08:00
[Bugfix] Fix ChunkedLocalAttention CUDA Graph setting (#28739)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
e5c78956c0
commit
bf3ffb61e6
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import functools
|
import functools
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -12,11 +11,16 @@ from vllm.config.vllm import VllmConfig
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
make_local_attention_virtual_batches,
|
make_local_attention_virtual_batches,
|
||||||
subclass_attention_backend,
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
AttentionSpec,
|
||||||
|
ChunkedLocalAttentionSpec,
|
||||||
|
KVCacheSpec,
|
||||||
|
)
|
||||||
|
|
||||||
from ..layer import Attention
|
from ..layer import Attention
|
||||||
|
|
||||||
@ -30,9 +34,18 @@ def create_chunked_local_attention_backend(
|
|||||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||||
|
|
||||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
assert issubclass(underlying_builder, AttentionMetadataBuilder)
|
||||||
|
|
||||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AttentionMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override is omitted only because of a mypy limitation caused by the type variable.
|
||||||
|
return AttentionCGSupport.NEVER
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user