diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 1344840af6a56..b8a232c8447bb 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -1,11 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionMetadata,
+    MultipleOf,
+)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
@@ -47,6 +51,10 @@ class DeepseekV32IndexerBackend(AttentionBackend):
     def get_kv_cache_stride_order() -> tuple[int, ...]:
         return (0, 1, 2)
 
+    @classmethod
+    def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
+        return [64]
+
 
 @dataclass
 class DeepseekV32IndexerPrefillChunkMetadata:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index dce8c650e0eb3..d1e41e6f6f12e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4242,9 +4242,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_group_id, kv_cache_group in enumerate(
             kv_cache_config.kv_cache_groups
         ):
-            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+            kv_cache_spec = kv_cache_group.kv_cache_spec
+            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
+                # All layers in the UniformTypeKVCacheSpecs have the same type;
+                # pick an arbitrary one to dispatch.
+                kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
+            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                 continue
-            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
+            elif isinstance(kv_cache_spec, AttentionSpec):
                 # This is an attention backend that supports virtual
                 # block splitting. Get the supported block sizes from
                 # all backends in the group.
@@ -4254,10 +4259,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     kv_manager_block_size, attn_groups
                 )
                 kernel_block_sizes.append(selected_kernel_size)
-            elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
+            elif isinstance(kv_cache_spec, MambaSpec):
                 # This is likely Mamba or other non-attention cache,
                 # no splitting.
-                kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
+                kernel_block_sizes.append(kv_cache_spec.block_size)
             else:
                 raise NotImplementedError(
                     f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
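
For context, a minimal standalone sketch of how a kernel block size could be reconciled against a backend's advertised sizes (such as the [64] returned by get_supported_kernel_block_size above). The local MultipleOf stand-in, the pick_kernel_block_size helper, and the divisibility-based selection rule are illustrative assumptions for this note, not the actual selection logic in gpu_model_runner.py.

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class MultipleOf:
    # Hypothetical stand-in for an "any multiple of `base`" constraint;
    # not the MultipleOf class imported in the diff above.
    base: int


def pick_kernel_block_size(
    kv_manager_block_size: int,
    supported: list[Union[int, MultipleOf]],
) -> int:
    # Assumed rule: prefer the largest exact size that evenly divides the
    # KV-manager block size; otherwise accept a MultipleOf constraint whose
    # base divides it.
    exact = [
        size
        for size in supported
        if isinstance(size, int) and kv_manager_block_size % size == 0
    ]
    if exact:
        return max(exact)
    for size in supported:
        if isinstance(size, MultipleOf) and kv_manager_block_size % size.base == 0:
            return size.base
    raise ValueError(
        f"no supported kernel block size fits {kv_manager_block_size}"
    )


# With the indexer backend advertising [64], a 128-token KV-manager block
# would be split into 64-token kernel blocks under this rule.
assert pick_kernel_block_size(128, [64]) == 64
assert pick_kernel_block_size(256, [MultipleOf(16)]) == 16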