Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-19 06:45:01 +08:00
[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
parent d0a4a3f645
commit 5c4b6e66fe

tests/v1/attention/test_attention_backends_selection.py (new file, 104 lines added)
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mamba attention backend selectors."""
+
+from types import SimpleNamespace
+
+import pytest
+
+from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.short_conv import ShortConv
+from vllm.model_executor.models.minimax_text_01 import (
+    MiniMaxText01LinearAttention)
+from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
+from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.short_conv_attn import (
+    ShortConvAttentionBackend)
+
+
+@pytest.mark.parametrize(
+    "layer_class, init_kwargs, expected_backend, expected_mamba_type", [
+        (
+            MambaMixer,
+            dict(
+                hidden_size=128,
+                ssm_state_size=16,
+                conv_kernel_size=4,
+                intermediate_size=256,
+                time_step_rank=8,
+                use_conv_bias=True,
+                use_bias=False,
+                use_rms_norm=True,
+            ),
+            Mamba1AttentionBackend,
+            "mamba1",
+        ),
+        (
+            MambaMixer2,
+            dict(
+                hidden_size=128,
+                ssm_state_size=16,
+                conv_kernel_size=4,
+                intermediate_size=256,
+                use_conv_bias=True,
+                use_bias=False,
+                n_groups=1,
+                num_heads=8,
+                head_dim=32,
+            ),
+            Mamba2AttentionBackend,
+            "mamba2",
+        ),
+        (
+            MiniMaxText01LinearAttention,
+            dict(
+                hidden_size=128,
+                hidden_inner_size=256,
+                num_heads=8,
+                head_dim=32,
+                max_position=2048,
+                block_size=64,
+                num_hidden_layer=12,
+                layer_idx=0,
+                linear_layer_idx=0,
+            ),
+            LinearAttentionBackend,
+            "linear_attention",
+        ),
+        (
+            ShortConv,
+            dict(
+                config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
+                dim=128,
+                layer_idx=0,
+            ),
+            ShortConvAttentionBackend,
+            "short_conv",
+        ),
+    ])
+def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
+                                       expected_backend, expected_mamba_type):
+    """Test that Mamba-like layers return the correct attention backend."""
+    layer = layer_class(**init_kwargs)
+
+    backend_class = layer.get_attn_backend()
+    assert backend_class is expected_backend
+    assert layer.mamba_type == expected_mamba_type
+
+
+@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
+    (MambaMixer, Mamba1AttentionBackend, "mamba1"),
+    (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
+    (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
+    (ShortConv, ShortConvAttentionBackend, "short_conv"),
+])
+def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
+                                             expected_mamba_type):
+    """Test that all Mamba layers have the unified get_attn_backend
+    interface."""
+    assert hasattr(layer_class, 'get_attn_backend'), (
+        f"{layer_class.__name__} should have get_attn_backend method")
+    assert hasattr(layer_class, 'mamba_type'), (
+        f"{layer_class.__name__} should have mamba_type property")
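
The new parametrized tests exercise the same two members, get_attn_backend() and mamba_type, across all four layer types. To run only this file locally, a plain pytest invocation should be enough; note that the dist_init fixture used by the first test is assumed to come from the surrounding test suite's conftest.py rather than from this file. A minimal sketch:

# Sketch: run just the new selection tests through pytest's public API
# (equivalent to `pytest <file> -v` on the command line).
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            ["tests/v1/attention/test_attention_backends_selection.py", "-v"]))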
@@ -1,25 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for mamba attention backend selectors."""
-
-import pytest
-
-from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
-
-
-@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
-                         argvalues=[("mamba2", Mamba2AttentionBackend)])
-def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
-    backend_class = get_mamba_attn_backend(mamba_type)
-
-    assert backend_class is expected_backend
-
-
-def test_get_mamba_attn_backend_unsupported():
-    unsupported_types = ["mamba", ""]
-
-    for mamba_type in unsupported_types:
-        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
-        with pytest.raises(NotImplementedError, match=err_message):
-            get_mamba_attn_backend(mamba_type)
@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -54,7 +55,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS


-class Attention(nn.Module):
+class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

     This class takes query, key, and value tensors as input. The input tensors
vllm/model_executor/layers/attention_layer_base.py (new file, 23 lines added)
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Base class for attention-like layers."""
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+
+class AttentionLayerBase(ABC):
+    """
+    Base class for attention-like layers (Attention, Mamba, etc.)
+    that support the v1 engine.
+
+    This provides a common interface for getting attention backends
+    from different layer types.
+    """
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this layer."""
+        pass
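
To make the new contract concrete, here is a minimal sketch (not part of this commit) of how a layer opts into the unified interface. MyCustomLayer is a hypothetical name; the deferred import inside get_attn_backend mirrors what the Mamba layers in the hunks below do, so a backend module is only loaded when the layer is actually asked for it.

# Hypothetical example only: a layer advertising its v1 attention backend
# through the AttentionLayerBase interface introduced in this commit.
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase


class MyCustomLayer(AttentionLayerBase):
    """Toy layer that reuses an existing backend class for illustration."""

    def get_attn_backend(self) -> type[AttentionBackend]:
        # Deferred import, mirroring MambaMixer / MambaMixer2 / ShortConv,
        # so the backend module is only imported when actually requested.
        from vllm.v1.attention.backends.mamba2_attn import (
            Mamba2AttentionBackend)
        return Mamba2AttentionBackend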
@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from collections.abc import Iterable
+from typing import TYPE_CHECKING

 import torch

+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+

-class MambaBase(ABC):
+class MambaBase(AttentionLayerBase):
     """
     Base class for Mamba-like layers which support the v1 engine.
     Inherit from this class if you implement a custom layer.
@@ -32,3 +38,8 @@ class MambaBase(ABC):
     @abstractmethod
     def mamba_type(self) -> str:
         pass
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        pass
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba1"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba1_attn import (
+            Mamba1AttentionBackend)
+        return Mamba1AttentionBackend
+
     def _time_proj_bias(self) -> Optional[torch.Tensor]:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba2_attn import (
+            Mamba2AttentionBackend)
+        return Mamba2AttentionBackend
+

 def mamba_mixer2(
     hidden_states: torch.Tensor,
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch

@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "short_conv"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.short_conv_attn import (
+            ShortConvAttentionBackend)
+        return ShortConvAttentionBackend
+

 def short_conv(
     hidden_states: torch.Tensor,
@@ -4,7 +4,10 @@
 import copy
 import math
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import regex as re
 import torch
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.linear_attn import (
+            LinearAttentionBackend)
+        return LinearAttentionBackend
+
     def get_state_dtype(self) -> tuple[torch.dtype]:
         return MambaStateDtypeCalculator.linear_attention_state_dtype(
             self.model_config.dtype,
@@ -1,22 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.attention.backends.abstract import AttentionBackend
-from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
-from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-from vllm.v1.attention.backends.short_conv_attn import (
-    ShortConvAttentionBackend)
-
-
-def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
-    if mamba_type == "mamba1":
-        return Mamba1AttentionBackend
-    if mamba_type == "mamba2":
-        return Mamba2AttentionBackend
-    if mamba_type == "linear_attention":
-        return LinearAttentionBackend
-    if mamba_type == "short_conv":
-        return ShortConvAttentionBackend
-
-    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
-                              "supported yet.")
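
The selector deleted above was the last piece of string-keyed backend dispatch; after this commit the backend choice lives on the layer itself. A rough before/after sketch (illustrative only, not code from this commit; the old-style mapping is abbreviated to two entries):

# Illustrative sketch of the old string-keyed lookup vs. the unified
# per-layer method this commit introduces.
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase


def resolve_backend_old(mamba_type: str):
    # Pre-commit style: a central mapping keyed by mamba_type strings, in the
    # spirit of the now-deleted get_mamba_attn_backend() (abbreviated here).
    from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
    from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
    mapping = {
        "mamba1": Mamba1AttentionBackend,
        "mamba2": Mamba2AttentionBackend,
    }
    return mapping[mamba_type]


def resolve_backend_new(layer: AttentionLayerBase):
    # Post-commit style: ask the layer directly, no central registry needed.
    return layer.get_attn_backend()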
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than using
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # they are cached correctly, there will be different objects per
             # layer.
             for layer_name in layer_names:
-                attn_backend = attn_layers[layer_name].get_attn_backend()
+                attn_backend = layers[layer_name].get_attn_backend()
                 key = attn_backend.full_cls_name()
                 attn_backends[key] = attn_backend
                 attn_backend_layers[key].append(layer_name)
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            if isinstance(kv_cache_spec, AttentionSpec):
-                attn_backends = get_attn_backends_for_layers(
-                    kv_cache_group_spec.layer_names)
-                # TODO(lucas): move `get_mamba_attn_backend` into the mamba
-                # layers like above
-            elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backends = {
-                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
-                    kv_cache_group_spec.layer_names
-                }
-            else:
-                raise ValueError(
-                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
+            attn_backends = get_attn_backends_for_layers(
+                kv_cache_group_spec.layer_names)
             self.attn_groups.append(
                 create_attn_groups(attn_backends, kv_cache_spec))
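
With the branch on the kv_cache_spec type gone, get_attn_backends_for_layers() treats every layer uniformly: look the layers up as AttentionLayerBase, ask each one for its backend, and dedupe by the backend's fully qualified class name. Below is a standalone sketch of that grouping pattern using hypothetical Layer and Backend stand-ins rather than the real vLLM types:

# Standalone sketch of the dedupe-and-group pattern used above; Layer,
# BackendA and BackendB are stand-ins, not vLLM classes.
from collections import defaultdict


class BackendA:
    pass


class BackendB:
    pass


class Layer:
    def __init__(self, backend: type) -> None:
        self._backend = backend

    def get_attn_backend(self) -> type:
        return self._backend


def group_layers_by_backend(layers: dict[str, Layer]) -> dict[type, list[str]]:
    backends: dict[str, type] = {}
    backend_layers: dict[str, list[str]] = defaultdict(list)
    for name, layer in layers.items():
        backend = layer.get_attn_backend()
        # Dedupe on the fully qualified class name, as the runner does, so
        # distinct backend objects of the same class collapse into one group.
        key = f"{backend.__module__}.{backend.__qualname__}"
        backends[key] = backend
        backend_layers[key].append(name)
    return {backends[k]: names for k, names in backend_layers.items()}


# Example: two layers sharing BackendA end up in a single group.
groups = group_layers_by_backend({
    "layer.0": Layer(BackendA),
    "layer.1": Layer(BackendA),
    "layer.2": Layer(BackendB),
})
assert groups[BackendA] == ["layer.0", "layer.1"]
assert groups[BackendB] == ["layer.2"]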