Enable V1 for Hybrid SSM/Attention Models (#20016)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent ffe00ef77a
commit 2f35a022e6
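
For context, the sketch below shows how one of the hybrid checkpoints touched by this commit might be run on the V1 engine with the same knobs the updated test exercises: V1 enabled, the FlashInfer attention backend, prefix caching disabled, and an enlarged attention block size. It is illustrative only and not part of the commit; the model name and block size are copied from the test tables further down, and whether they fit on a given GPU is an assumption.

import os

# Mirror the test settings: the V1 engine plus the FlashInfer attention
# backend, which provides the (num_blocks, 2, ...) KV cache layout that the
# hybrid attention/SSM path requires.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    max_num_seqs=4,               # avoid OOM, as in the test
    enforce_eager=True,
    enable_prefix_caching=False,  # prefix caching is not supported for Mamba yet
    block_size=528,               # attention block large enough to cover the mamba page
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
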
@@ -3,6 +3,7 @@
import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

@@ -19,31 +20,55 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
    # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
]

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    # NOTE: Currently the test fails due to HF transformers issue fixed in:
    # https://github.com/huggingface/transformers/pull/39033
    # We will enable vLLM test for Granite after next HF transformers release.
    # "ibm-granite/granite-4.0-tiny-preview",
    # NOTE: Running Plamo2 with the transformers implementation requires
    # installing the causal-conv1d package, which is not listed as a test
    # dependency as it's not compatible with pip-compile.
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-ai-platform/Bamba-9B-v1",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

HF_UNSUPPORTED_MODELS = [
    # The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
    # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
    # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
    # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
    "nvidia/Nemotron-H-8B-Base-8K",
    # NOTE: Currently the test fails due to HF transformers issue fixed in:
    # https://github.com/huggingface/transformers/pull/39033
    # We will enable vLLM test for Granite after next HF transformers release.
    "ibm-granite/granite-4.0-tiny-preview",
]

V1_SUPPORTED_MODELS = [
    "mistralai/Mamba-Codestral-7B-v0.1",
    "ibm-ai-platform/Bamba-9B-v1",
    "Zyphra/Zamba2-1.2B-instruct",
    "nvidia/Nemotron-H-8B-Base-8K",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
]

ATTN_BLOCK_SIZES = {
    "ibm-ai-platform/Bamba-9B-v1": 528,
    "Zyphra/Zamba2-1.2B-instruct": 80,
    "nvidia/Nemotron-H-8B-Base-8K": 528,
    "ibm-granite/granite-4.0-tiny-preview": 400,
    "tiiuae/Falcon-H1-0.5B-Base": 800,
}

# Avoid OOM
MAX_NUM_SEQS = 4

@@ -60,8 +85,16 @@ def test_models(
    max_tokens: int,
    num_logprobs: int,
) -> None:

    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        if model != "mistralai/Mamba-Codestral-7B-v0.1":
        if model not in HF_UNSUPPORTED_MODELS:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)
        else:

@@ -72,12 +105,21 @@ def test_models(
                example_prompts, max_tokens, num_logprobs)

    if model in V1_SUPPORTED_MODELS:
        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
            block_size = ATTN_BLOCK_SIZES[model]
        else:
            block_size = 16

        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enforce_eager=True,
                             enable_prefix_caching=False) as vllm_model:
                             enable_prefix_caching=False,
                             block_size=block_size) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)
    else:

@@ -111,6 +153,14 @@ def test_batching(
    max_tokens: int,
    num_logprobs: int,
) -> None:

    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    for_loop_outputs = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
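
Taken together, the constants and the V1 branch in test_models above amount to a simple per-model policy. The helper below restates that policy as a sketch; select_v1_settings is a hypothetical name, not part of the test file, and it assumes the HYBRID_MODELS and ATTN_BLOCK_SIZES constants defined above are in scope.

def select_v1_settings(model: str) -> tuple[dict[str, str], int]:
    """Return (env vars, attention block size) used to run `model` on V1."""
    env = {"VLLM_USE_V1": "1"}
    if model in HYBRID_MODELS:
        # Hybrid models need an attention backend whose batch reordering is
        # compatible with the mamba backend (see the FLASHINFER setenv above).
        env["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
    # A larger attention block keeps the attention page at least as big as the
    # mamba page; models not listed fall back to the default of 16.
    block_size = ATTN_BLOCK_SIZES.get(model, 16)
    return env, block_size
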
@@ -169,7 +169,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
    "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
    "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
    "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
    "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
                                          min_transformers_version="4.53"),
    "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
    "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
@@ -13,7 +13,6 @@ UNSUPPORTED_MODELS_V1 = [
    "openai/whisper-large-v3", # transcription
    "facebook/bart-large-cnn", # encoder decoder
    "state-spaces/mamba-130m-hf", # mamba1
    "hmellor/tiny-random-BambaForCausalLM", # hybrid
    "BAAI/bge-m3", # embedding
]
@@ -108,7 +108,7 @@ def _selective_scan_update_kernel(
    # is the same as the batch id.
    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
        state_ptr += (state_batch_idx * stride_state_batch +
                      pid_h * stride_state_head)
    else:
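
The one-line kernel change above casts the loaded state batch index to 64-bit before it is multiplied by the batch stride. The arithmetic below is a hedged sketch with assumed sizes (not measurements): once V1 indexes the SSM state by KV cache block rather than by a small per-request slot, the 32-bit product can overflow.

# Assumed Mamba2 state layout per slot: nheads * headdim * dstate elements.
nheads, headdim, dstate = 128, 64, 128
stride_state_batch = nheads * headdim * dstate   # 1,048,576 elements per slot

int32_max = 2**31 - 1                            # 2,147,483,647

# The index tensor is typically int32; with V1 the index is a block id and can
# be large, so the pointer offset must be computed in 64-bit arithmetic.
state_batch_idx = 4096
offset = state_batch_idx * stride_state_batch    # 4,294,967,296
print(offset > int32_max)                        # True -> needs the int64 cast
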
@@ -9,6 +9,7 @@ import torch
from torch import nn
from transformers import BambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -36,7 +37,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                         SupportsQuant, SupportsV0Only)
                         SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

@@ -97,7 +98,9 @@ class BambaMixerDecoderLayer(nn.Module):
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            quant_config=quant_config)
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
            chunk_size=config.mamba_chunk_size)

        self.feed_forward = BambaMLP(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,

@@ -313,10 +316,14 @@ class BambaModel(nn.Module):

        attn_metadata = get_forward_context().attn_metadata

        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.mamba_chunk_size,
            attn_metadata=attn_metadata,
        )
        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.mamba_chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from forward_context
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:

@@ -337,7 +344,8 @@ class BambaModel(nn.Module):
                num_attn += 1

            layer_mamba_cache_params = None
            if isinstance(layer, BambaMixerDecoderLayer):
            if isinstance(layer,
                          BambaMixerDecoderLayer) and mamba_cache_params:
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    i - num_attn)

@@ -411,7 +419,7 @@ class BambaModel(nn.Module):


class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                       IsHybrid, SupportsV0Only, SupportsQuant):
                       IsHybrid, SupportsQuant):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",

@@ -475,15 +483,22 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:

            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = \
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba
                    )

                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.lm_head.weight.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)
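
The Bamba changes above set the pattern that Falcon-H1, GraniteMoeHybrid, Nemotron-H and Zamba2 repeat below: the MambaCacheManager is only created on the V0 path, while V1 leaves the cache params as None and lets each mixer read its state from the engine-managed KV cache. The stub below is a distilled, self-contained restatement of that branch; the class names are stand-ins, not the real vLLM types.

class FakeCacheManager:
    """Stand-in for MambaCacheManager; only the V0 path ever builds one."""

    def current_run_tensors(self, **kwargs):
        return {"conv_state": "...", "ssm_state": "..."}


class HybridModelStub:

    def __init__(self, use_v1: bool):
        self.use_v1 = use_v1      # stands in for envs.VLLM_USE_V1
        self.mamba_cache = None

    def forward(self, **kwargs):
        mamba_cache_params = None
        if not self.use_v1:
            # V0: the model owns the mamba cache and threads per-layer
            # parameters through the forward pass.
            if self.mamba_cache is None:
                self.mamba_cache = FakeCacheManager()
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        # V1: mamba state lives in the KV cache managed by the engine and each
        # mixer reads its metadata from the forward context, so this stays None.
        return mamba_cache_params


print(HybridModelStub(use_v1=False).forward())  # V0 -> dict of state handles
print(HybridModelStub(use_v1=True).forward())   # V1 -> None
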
@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import FalconH1Config

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -33,8 +34,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                         SupportsV0Only)
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

@@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
        config: FalconH1Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

@@ -107,6 +108,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
            activation=config.hidden_act,
            quant_config=quant_config,
            use_rms_norm=config.mamba_rms_norm,
            prefix=f"{prefix}.mixer",
            chunk_size=config.mamba_chunk_size,
        )
        # n_groups is overridden later by `MambaMixer2`
        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state

@@ -316,6 +319,7 @@ class FalconH1ParallelHybrid(nn.Module):
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Instantiate the attention branch
        self.self_attn = FalconH1AttentionDecoderLayer(
            config=config,

@@ -323,11 +327,18 @@ class FalconH1ParallelHybrid(nn.Module):
            quant_config=quant_config,
            prefix=prefix,
        )

        # In V1 all attention/ssm layers must have
        # different index in prefix
        ssm_layer_idx = config.num_hidden_layers + layer_idx
        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"

        # Instantiate the SSM branch
        self.mamba = FalconH1SSMDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=ssm_prefix,
        )
        self.ssm_out_multiplier = config.ssm_out_multiplier
        self.ssm_in_multiplier = config.ssm_in_multiplier

@@ -452,10 +463,16 @@ class FalconH1Model(nn.Module):
        # proper continuous batching computation including
        # chunked prefill
        attn_metadata = get_forward_context().attn_metadata
        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.mamba_chunk_size,
            attn_metadata=attn_metadata,
        )

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.mamba_chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from forward_context
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds * self.embedding_multiplier

@@ -468,7 +485,9 @@ class FalconH1Model(nn.Module):

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
            layer_mamba_cache_params = None
            if mamba_cache_params:
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,

@@ -484,7 +503,7 @@ class FalconH1Model(nn.Module):


class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                          IsHybrid, SupportsV0Only):
                          IsHybrid):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],

@@ -558,15 +577,19 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if self.mamba_cache is None:
            self.mamba_cache = MambaCacheManager(
                self.vllm_config,
                self.lm_head.weight.dtype
                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
                self.config.num_hidden_layers,
                *self._get_mamba_cache_shape(),
            )
        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                self.mamba_cache = MambaCacheManager(
                    self.vllm_config,
                    self.lm_head.weight.dtype if hasattr(
                        self.lm_head, 'weight') else torch.bfloat16,
                    self.config.num_hidden_layers,
                    *self._get_mamba_cache_shape(),
                )
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(
            input_ids,
            positions,
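
One Falcon-H1-specific detail: each decoder layer holds both an attention branch and an SSM branch, and V1 distinguishes KV cache groups by the layer index embedded in the prefix, so the SSM branch gets an index offset by num_hidden_layers. A small illustration with assumed values (the prefix string is an example, not taken from a real run):

num_hidden_layers = 36
layer_idx = 3
prefix = f"model.{layer_idx}"                  # attention branch keeps this prefix

ssm_layer_idx = num_hidden_layers + layer_idx  # 39
ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
print(prefix, ssm_prefix)                      # model.3 model.39 -> no collision
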
@@ -9,6 +9,7 @@ import torch
from torch import nn
from transformers import GraniteMoeHybridConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -36,7 +37,7 @@ from vllm.utils import LayerBlockType
from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                         SupportsQuant, SupportsV0Only)
                         SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

@@ -67,7 +68,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            quant_config=quant_config)
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
            chunk_size=config.mamba_chunk_size)

        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:

@@ -361,10 +364,15 @@ class GraniteMoeHybridModel(nn.Module):
    ) -> torch.Tensor:

        attn_metadata = get_forward_context().attn_metadata
        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.mamba_chunk_size,
            attn_metadata=attn_metadata,
        )

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.mamba_chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from forward_context
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:

@@ -386,7 +394,9 @@ class GraniteMoeHybridModel(nn.Module):
                num_attn += 1

            layer_mamba_cache_params = None
            if isinstance(layer, GraniteMoeHybridMambaDecoderLayer):
            if isinstance(
                    layer,
                    GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    i - num_attn)

@@ -501,8 +511,7 @@ class GraniteMoeHybridModel(nn.Module):


class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                  SupportsPP, IsHybrid, SupportsV0Only,
                                  SupportsQuant):
                                  SupportsPP, IsHybrid, SupportsQuant):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",

@@ -571,14 +580,20 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.model_config.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.model_config.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)
@@ -23,6 +23,7 @@ from typing import Optional
import torch
from torch import nn

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -44,8 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant,
                                                    SupportsV0Only)
                                                    SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.models.utils import (

@@ -153,6 +153,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
            chunk_size=config.chunk_size,
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -348,10 +350,14 @@ class NemotronHModel(nn.Module):

        attn_metadata = get_forward_context().attn_metadata

        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.chunk_size,
            attn_metadata=attn_metadata,
        )
        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from forward_context
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:

@@ -369,7 +375,8 @@ class NemotronHModel(nn.Module):
        for i in range(len(self.layers)):
            layer = self.layers[i]
            layer_mamba_cache_params = None
            if isinstance(layer, NemotronHMambaDecoderLayer):
            if isinstance(layer,
                          NemotronHMambaDecoderLayer) and mamba_cache_params:
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    i - num_non_mamba_layers)
            else:

@@ -437,7 +444,7 @@ class NemotronHModel(nn.Module):


class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                           IsHybrid, SupportsV0Only, SupportsQuant):
                           IsHybrid, SupportsQuant):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",

@@ -499,15 +506,23 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:

            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:

                num_mamba_layers = \
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba
                    )

                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.lm_head.weight.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)
@@ -15,6 +15,7 @@ import torch
from torch import nn
from transformers import Zamba2Config

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -41,7 +42,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -58,6 +59,7 @@ class Zamba2LoRA(nn.Module):
        rank: int,
        output_dim: Union[int, list[int]],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Initialize the attention layer.

@@ -283,6 +285,7 @@ class Zamba2MLP(nn.Module):
        bare_block_idx: int,
        num_hybrid_layers: dict[int, int],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialize the MLP layer.

@@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module):
    computation depending on configuration.
    """

    def __init__(
        self,
        config: Zamba2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
    def __init__(self,
                 config: Zamba2Config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Initialize the Mamba decoder layer.

        Args:

@@ -486,20 +488,21 @@ class Zamba2MambaDecoderLayer(nn.Module):

        # Initialize Mamba mixer with expanded intermediate size
        intermediate_size = config.mamba_expand * config.hidden_size
        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=intermediate_size,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.add_bias_linear,
            n_groups=config.mamba_ngroups,
            num_heads=config.n_mamba_heads,
            head_dim=intermediate_size // config.n_mamba_heads,
            rms_norm_eps=config.rms_norm_eps,
            activation="silu",
            quant_config=quant_config,
        )
        self.mamba = MambaMixer2(hidden_size=config.hidden_size,
                                 ssm_state_size=config.mamba_d_state,
                                 conv_kernel_size=config.mamba_d_conv,
                                 intermediate_size=intermediate_size,
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.add_bias_linear,
                                 n_groups=config.mamba_ngroups,
                                 num_heads=config.n_mamba_heads,
                                 head_dim=intermediate_size //
                                 config.n_mamba_heads,
                                 rms_norm_eps=config.rms_norm_eps,
                                 activation="silu",
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mixer",
                                 chunk_size=config.chunk_size)

        # Input normalization
        self.input_layernorm = RMSNorm(config.hidden_size,

@@ -573,6 +576,7 @@ class Zamba2HybridLayer(nn.Module):
        config: Zamba2Config,
        block_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialize the hybrid layer.

@@ -589,7 +593,8 @@ class Zamba2HybridLayer(nn.Module):
            bias=False,
            quant_config=quant_config)
        self.mamba_decoder = Zamba2MambaDecoderLayer(config,
                                                     quant_config=quant_config)
                                                     quant_config=quant_config,
                                                     prefix=prefix)

    def forward(
        self,

@@ -699,14 +704,23 @@ class Zamba2Model(nn.Module):
        # Initialize layers according to block type configuration
        layers = []
        for layer_idx, layer_type in enumerate(config.layers_block_type):
            # tdoublep: avoid layers getting same index
            # somewhat hacky but correct (I think)
            prefix = str(len(layer2block_map) + layer_idx)
            if layer_type == "hybrid":
                block = next(blocks)
                block_idx = layer2block_map[layer_idx]
                layers.append(
                    Zamba2HybridLayer(block, config, block_idx, quant_config))
                    Zamba2HybridLayer(block,
                                      config,
                                      block_idx,
                                      quant_config,
                                      prefix=prefix))
            else:
                layers.append(
                    Zamba2MambaDecoderLayer(config, quant_config=quant_config))
                    Zamba2MambaDecoderLayer(config,
                                            quant_config=quant_config,
                                            prefix=prefix))
        self.layers = nn.ModuleList(layers)

        # Final layer normalization

@@ -751,19 +765,30 @@ class Zamba2Model(nn.Module):

        attn_metadata = get_forward_context().attn_metadata

        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.chunk_size,
            attn_metadata=attn_metadata,
        )
        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 gets mamba2_metadata from forward_context
            mamba2_metadata = None

        # Process through layers
        original_hidden_states = torch.clone(hidden_states)
        for layer_idx, layer in enumerate(self.layers):

            layer_mamba_cache_params = None
            if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
                    and mamba_cache_params):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    layer_idx)

            layer_outputs = layer(
                hidden_states,
                original_hidden_states=original_hidden_states,
                positions=positions,
                mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )
            hidden_states = layer_outputs

@@ -803,7 +828,7 @@ class Zamba2Model(nn.Module):
        return loaded_params


class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
    """Zamba2 model with causal language modeling head.

    This class wraps the core Zamba2 model and adds:

@@ -897,14 +922,16 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
            Output hidden states
        """
        # Initialize Mamba cache if needed
        if self.mamba_cache is None:
            num_mamba_layers = self.config.num_hidden_layers
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = self.config.num_hidden_layers
                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.lm_head.weight.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            # Get cache parameters for current run
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
            # Get cache parameters for current run
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        # Forward pass through model
        hidden_states = self.model(
@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
    ):
        self.kv_cache_config = kv_cache_config
        self.max_model_len = max_model_len
        self.enable_caching = enable_caching

        self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
                                    enable_kv_cache_events)

@@ -268,9 +269,13 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):

        self.full_attention_block_size = self.full_attention_spec.block_size
        self.other_block_size = self.other_spec.block_size
        assert self.other_block_size % self.full_attention_block_size == 0, (
            "KVCacheCoordinator assumes the block_size of full attention "
            "layers is divisible by other layers now.")

        if self.enable_caching:
            # this requirement is only needed for the prefix caching logic
            divisible = self.other_block_size % self.full_attention_block_size
            assert divisible == 0, (
                "KVCacheCoordinator assumes the block_size of full "
                "attention layers is divisible by other layers now.")

        if max(self.full_attention_group_ids) < min(self.other_group_ids):
            self.full_attn_first = True
@@ -84,12 +84,15 @@ class KVCacheManager:
        self.log_stats = log_stats
        # FIXME: make prefix cache stats conditional on log_stats
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
        assert len(
            set(g.kv_cache_spec.block_size
                for g in kv_cache_config.kv_cache_groups)
        ) == 1, "Only one block size is supported for now"
        self.block_size = kv_cache_config.kv_cache_groups[
            0].kv_cache_spec.block_size

        self.block_size: Optional[int] = None
        if self.enable_caching:
            assert len(
                set(g.kv_cache_spec.block_size
                    for g in kv_cache_config.kv_cache_groups)
            ) == 1, "Only one block size is supported for now"
            self.block_size = kv_cache_config.kv_cache_groups[
                0].kv_cache_spec.block_size

        self.coordinator = get_kv_cache_coordinator(
            kv_cache_config=kv_cache_config,

@@ -154,6 +157,7 @@ class KVCacheManager:
        # if the scheduler has tried to schedule the request before.
        block_hashes = self.req_to_block_hashes[request.request_id]
        if not block_hashes:
            assert self.block_size is not None
            block_hashes = hash_request_tokens(self.caching_hash_fn,
                                               self.block_size, request)
            self.req_to_block_hashes[request.request_id] = block_hashes
@@ -864,9 +864,11 @@ def _get_kv_cache_config_uniform_page_size(
        kv_cache_groups=kv_cache_groups,
    )

    min_block_size = min(
        [group.kv_cache_spec.block_size for group in kv_cache_groups])

    # Print the KV cache size and maximum concurrency.
    num_tokens = num_blocks // len(
        grouped_layers) * vllm_config.cache_config.block_size
    num_tokens = num_blocks // len(grouped_layers) * min_block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
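
Because hybrid models now have KV cache groups with different block sizes, the logged cache capacity is derived from the smallest block size across groups rather than the global cache_config.block_size. A tiny arithmetic example with assumed numbers:

num_blocks = 12_000
num_layer_groups = 2             # len(grouped_layers)
group_block_sizes = [16, 4096]   # e.g. attention blocks and a mamba group
                                 # whose block_size is set to max_model_len

min_block_size = min(group_block_sizes)
num_tokens = num_blocks // num_layer_groups * min_block_size
print(f"GPU KV cache size: {num_tokens:,} tokens")   # 96,000 tokens
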
@@ -159,6 +159,7 @@ class SlidingWindowSpec(AttentionSpec):
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype
    page_size_padded: Optional[int] = None

    def __post_init__(self):
        self.num_elements = sum(prod(shape) for shape in self.shapes)

@@ -169,7 +170,11 @@ class MambaSpec(KVCacheSpec):

    @property
    def page_size_bytes(self) -> int:
        return self.num_elements * get_dtype_size(self.dtype)
        page_size = self.num_elements * get_dtype_size(self.dtype)
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
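
The new page_size_padded field lets the model runner inflate a mamba page to match the attention page size so that both cache groups share one uniform page size. A hedged sketch of the property with assumed state shapes and sizes:

from math import prod

shapes = ((128, 64, 4), (128, 64, 128))   # assumed conv and ssm state shapes
dtype_size = 2                            # bf16

num_elements = sum(prod(s) for s in shapes)
page_size = num_elements * dtype_size     # 2,162,688 bytes

# When the attention page is larger, the runner passes it in as
# page_size_padded and page_size_bytes reports the padded value instead.
page_size_padded = 2_228_224              # assumed attention page size
assert page_size_padded >= page_size
print(page_size_padded if page_size_padded is not None else page_size)
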
@@ -334,6 +334,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # the same order of requests. We ensure this by only allowing the first
        # group to reorder the batch and asserting that all other groups do not
        # reorder the batch.
        # TODO(tdoublep): make this more flexible so that any group can
        # re-order the batch (not only the first).
        # TODO(tdoublep): verify this during engine init instead of at runtime
        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
                self.input_batch, scheduler_output)

@@ -2449,6 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            corresponding memory buffer for KV cache.
        """
        kv_caches: dict[str, torch.Tensor] = {}
        has_attn, has_mamba = False, False
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec

@@ -2458,6 +2462,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                num_blocks = (raw_tensor.numel() //
                              kv_cache_spec.page_size_bytes)
                if isinstance(kv_cache_spec, AttentionSpec):
                    has_attn = True
                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)

@@ -2486,25 +2491,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                        layer_name].view(dtype).view(kv_cache_shape).permute(
                            *inv_order)
                elif isinstance(kv_cache_spec, MambaSpec):
                    has_mamba = True
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    dtype = kv_cache_spec.dtype
                    num_element_per_page = (kv_cache_spec.page_size_bytes //
                                            get_dtype_size(dtype))
                    state_tensors = []
                    start_pos = 0
                    storage_offset = 0
                    for shape in kv_cache_spec.shapes:
                        target_shape = (num_blocks, *shape)
                        size_in_bytes = np.prod(shape) * get_dtype_size(
                            dtype) * num_blocks
                        tensor = raw_tensor[start_pos:start_pos +
                                            size_in_bytes]
                        tensor = tensor.view(dtype).view(target_shape)
                        stride = torch.empty(target_shape).stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset,
                        )
                        state_tensors.append(tensor)
                        start_pos += size_in_bytes
                    assert start_pos == raw_tensor.numel()
                    kv_caches[layer_name] = tuple(state_tensors)
                        storage_offset += stride[0]

                    kv_caches[layer_name] = state_tensors
                else:
                    raise NotImplementedError

        if has_attn and has_mamba:
            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
                                                       kv_cache_raw_tensors)

        return kv_caches

    def _verify_hybrid_attention_mamba_layout(
            self, kv_cache_config: KVCacheConfig,
            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
        """
        Verify that the KV cache memory layout is compatible for
        models with both attention and mamba KV cache groups.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer.
        """

        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            for layer_name in kv_cache_group_spec.layer_names:
                raw_tensor = kv_cache_raw_tensors[layer_name]
                num_blocks = (raw_tensor.numel() //
                              kv_cache_spec.page_size_bytes)
                if isinstance(kv_cache_spec, AttentionSpec):
                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
                            1] != 2:
                        raise ValueError(
                            "Hybrid models in V1 require an attention "
                            "backend with kv_cache_shape="
                            "(num_blocks, 2, ...). Please try setting "
                            "VLLM_ATTENTION_BACKEND=FLASHINFER")

    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
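
The reworked mamba branch above carves each per-layer state tensor out of one flat buffer with torch.as_strided, so that all the states belonging to one KV cache block sit inside a single page. The self-contained toy below reproduces the carving with made-up shapes:

import torch
from math import prod

num_blocks = 4
shapes = ((2, 3), (5,))                     # e.g. conv state and ssm state
page_elems = sum(prod(s) for s in shapes)   # elements per block (one "page")

raw = torch.arange(num_blocks * page_elems, dtype=torch.float32)

state_tensors = []
storage_offset = 0
for shape in shapes:
    target_shape = (num_blocks, *shape)
    stride = torch.empty(target_shape).stride()
    # Jump a full page between blocks, but keep the per-state layout inside it.
    target_stride = (page_elems, *stride[1:])
    t = torch.as_strided(raw, size=target_shape, stride=target_stride,
                         storage_offset=storage_offset)
    state_tensors.append(t)
    storage_offset += stride[0]             # next state starts after this one

# Block 1's conv state is the first 6 elements of page 1, its ssm state the next 5.
print(state_tensors[0][1].flatten())        # tensor([11., 12., 13., 14., 15., 16.])
print(state_tensors[1][1])                  # tensor([17., 18., 19., 20., 21.])
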
@@ -2623,11 +2670,69 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = self._maybe_pad_mamba_page_size(
                attn_layers, mamba_layers, kv_cache_spec, max_model_len,
                block_size)

            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtype=self.kv_cache_dtype,
                    block_size=max_model_len)
                    block_size=max_model_len,
                    page_size_padded=page_size_padded)

        return kv_cache_spec

    def _maybe_pad_mamba_page_size(
        self,
        attn_layers: dict[str, Attention],
        mamba_layers: dict[str, MambaMixer2],
        kv_cache_spec: dict[str, KVCacheSpec],
        max_model_len: int,
        block_size: int,
    ) -> Optional[int]:
        """
        Ensure that page size of attention KV cache groups is greater than or
        equal to the mamba KV cache groups. If not, we suggest to the user
        how to set the attention block size to ensure that it is.

        If the attention page size is strictly greater than the mamba page size,
        we pad the mamba page size to make them equal.

        Args:
            attn_layers: Attention layers
            mamba_layers: Mamba layers
            kv_cache_spec: KV cache spec (populated with attention layers)

        Returns:
            Optional[int]: Mamba page size with padding (None if no padding).
        """

        if len(attn_layers) == 0:
            return None

        attn_layer_name = next(iter(attn_layers))
        attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
        mamba_layer_name = next(iter(mamba_layers))
        mamba_page_size = MambaSpec(
            shapes=mamba_layers[mamba_layer_name].get_state_shape(),
            dtype=self.kv_cache_dtype,
            block_size=max_model_len).page_size_bytes
        if attn_page_size < mamba_page_size:
            # attention page size (for 16 tokens)
            attn_page_size_16 = 16 * attn_page_size // block_size
            # some attention backends (e.g. FA) only support setting
            # block size to multiple of 16, so let's suggest a value
            # that would work (note: FA is currently not compatible
            # with mamba layers, use FlashInfer instead).
            suggest_attn_block_size = 16 * cdiv(mamba_page_size,
                                                attn_page_size_16)
            raise ValueError(
                "Attention block size should be increased to at least "
                f"{suggest_attn_block_size} in order to match "
                "the mamba page size")

        return attn_page_size
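
Finally, the error path in _maybe_pad_mamba_page_size suggests the smallest attention block size (a multiple of 16) whose page is at least as large as the mamba page. A worked example with assumed page sizes, consistent with the sketch after the MambaSpec hunk above:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

block_size = 16               # currently configured attention block size (tokens)
attn_page_size = 65_536       # assumed bytes for one 16-token attention block
mamba_page_size = 2_162_688   # assumed bytes for one mamba state page

# attention page size normalised to exactly 16 tokens
attn_page_size_16 = 16 * attn_page_size // block_size

# smallest multiple of 16 tokens whose attention page covers the mamba page
suggest_attn_block_size = 16 * cdiv(mamba_page_size, attn_page_size_16)
print(suggest_attn_block_size)   # 528
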