[Model] PP support for Mamba-like models (#10992)

Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
Mor Zusman 2024-12-11 04:53:37 +02:00 committed by GitHub
parent d5c5154fcf
commit ffa48c9146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 227 additions and 79 deletions

View File

@ -128,7 +128,7 @@ Text Generation
- FalconMamba - FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎ - ✅︎
- - ✅︎
* - :code:`GemmaForCausalLM` * - :code:`GemmaForCausalLM`
- Gemma - Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@ -193,7 +193,7 @@ Text Generation
- Jamba - Jamba
- :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc. - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎ - ✅︎
- - ✅︎
* - :code:`LlamaForCausalLM` * - :code:`LlamaForCausalLM`
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
@ -203,7 +203,7 @@ Text Generation
- Mamba - Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- -
- - ✅︎
* - :code:`MiniCPMForCausalLM` * - :code:`MiniCPMForCausalLM`
- MiniCPM - MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc. - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.

View File

@ -156,13 +156,13 @@ TEXT_GENERATION_MODELS = {
# "internlm/internlm-chat-7b": PPTestSettings.fast(), # "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"inceptionai/jais-13b-chat": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
# Uses Llama # Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"state-spaces/mamba-130m-hf": PPTestSettings.fast(),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4), "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(), "mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
@ -234,6 +234,8 @@ TEST_MODELS = [
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct", "microsoft/Phi-3-vision-128k-instruct",
"fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_3",
# [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev",
] ]

View File

@ -27,8 +27,8 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config, ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config, get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
print_warning_once, random_uuid, get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname) resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -284,6 +284,7 @@ class ModelConfig:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self.is_attention_free = self._init_attention_free() self.is_attention_free = self._init_attention_free()
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state() self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron(): if current_platform.is_neuron():
@ -340,6 +341,10 @@ class ModelConfig:
architectures = getattr(self.hf_config, "architectures", []) architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures) return ModelRegistry.is_attention_free_model(architectures)
def _init_is_hybrid(self) -> bool:
    """Return whether the registry reports this model as hybrid.

    Reads the ``architectures`` list from ``hf_config`` (empty list when the
    attribute is missing) and forwards it to
    ``ModelRegistry.is_hybrid_model``.
    """
    architectures = getattr(self.hf_config, "architectures", [])
    return ModelRegistry.is_hybrid_model(architectures)
def _init_has_inner_state(self) -> bool: def _init_has_inner_state(self) -> bool:
architectures = getattr(self.hf_config, "architectures", []) architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.model_has_inner_state(architectures) return ModelRegistry.model_has_inner_state(architectures)
@ -669,26 +674,51 @@ class ModelConfig:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size return num_heads // parallel_config.tensor_parallel_size
def get_layers_start_end_indices(
        self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
    """Return the ``(start, end)`` layer indices owned by this PP rank.

    The total layer count is read from ``hf_text_config.num_hidden_layers``
    (0 when the attribute is absent). The pipeline rank is derived from the
    global rank and the tensor-parallel size, and the split itself is
    delegated to ``get_pp_indices``.
    """
    # Imported locally, as in the original code.
    from vllm.distributed.utils import get_pp_indices

    total_num_hidden_layers = getattr(self.hf_text_config,
                                      "num_hidden_layers", 0)
    pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
    return start, end
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
    """Return how many layers fall in this rank's ``[start, end)`` range.

    The range comes from ``get_layers_start_end_indices``.
    """
    start, end = self.get_layers_start_end_indices(parallel_config)
    return end - start
def get_num_layers_by_block_type(
    self,
    parallel_config: "ParallelConfig",
    block_type: LayerBlockType = LayerBlockType.attention,
) -> int:
    """Count the layers of ``block_type`` within this rank's layer range.

    Args:
        parallel_config: Used to resolve the ``[start, end)`` layer range
            via ``get_layers_start_end_indices``.
        block_type: The kind of block to count; defaults to attention.

    Returns:
        The number of layers of the requested type in the range.

    Raises:
        ValueError: If the model is hybrid but ``hf_config`` has no
            ``layers_block_type`` attribute.
    """
    # Only hybrid models are expected to expose 'layers_block_type' in
    # hf_config; the other two cases are answered from the model flags.
    attn_block_type = block_type == LayerBlockType.attention
    is_transformer = not self.is_hybrid and not self.is_attention_free
    start, end = self.get_layers_start_end_indices(parallel_config)

    if is_transformer:
        # Basic case: every layer in the range is an attention layer.
        return end - start if attn_block_type else 0

    if self.is_attention_free:
        # Attention-free model. Note that this code assumes there is only
        # one type of attention-free block.
        return 0 if attn_block_type else end - start

    # Hybrid model: count matching entries of the per-layer type list.
    layers_block_type_value = getattr(self.hf_config, "layers_block_type",
                                      None)
    if layers_block_type_value is None:
        # Adjacent literals are concatenated verbatim, so each piece must
        # carry its own trailing space.
        raise ValueError("The model is an hybrid without a "
                         "layers_block_type in the hf_config, "
                         "cannot determine the num of "
                         f"{block_type.value} layers")

    return sum(t == block_type.value
               for t in layers_block_type_value[start:end])
def get_multimodal_config(self) -> "MultiModalConfig": def get_multimodal_config(self) -> "MultiModalConfig":
""" """

View File

@ -363,6 +363,43 @@ def is_attention_free(
return isinstance(model, IsAttentionFree) return isinstance(model, IsAttentionFree)
@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models like Jamba that have both
    attention and mamba blocks.

    It also indicates that the model's hf_config has 'layers_block_type'.
    """

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks.
    It also indicates that the model's hf_config has 'layers_block_type'.
    """
@runtime_checkable
class _IsHybridType(Protocol):
is_hybrid: ClassVar[Literal[True]]
@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


@overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
    ...


def is_hybrid(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
    """Report whether ``model`` satisfies the hybrid-model interface.

    Accepts either a class or an instance. Classes are checked against
    ``_IsHybridType``; everything else is checked against ``IsHybrid``.
    """
    if isinstance(model, type):
        # A class object was passed rather than an instance.
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)
@runtime_checkable @runtime_checkable
class SupportsCrossEncoding(Protocol): class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding.""" """The interface required for all models that support cross encoding."""

View File

@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -25,9 +26,12 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import HasInnerState, SupportsLoRA from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import maybe_prefix from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -281,16 +285,24 @@ class JambaModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
decoder_layers = [] def get_layer(prefix: str):
for i in range(config.num_hidden_layers): layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] layer_class = ALL_DECODER_LAYER_TYPES[
decoder_layers.append( config.layers_block_type[layer_idx]]
layer_class(config, return layer_class(
layer_idx=i, config,
cache_config=cache_config, layer_idx,
quant_config=quant_config, cache_config,
prefix=f"{prefix}.layers.{i}")) quant_config=quant_config,
self.layers = nn.ModuleList(decoder_layers) prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.final_layernorm = RMSNorm(config.hidden_size, self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@ -304,26 +316,34 @@ class JambaModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
kv_cache_index = 0
mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
kv_cache = None kv_cache = None
layer_mamba_cache_params = None layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) // kv_cache = kv_caches[kv_cache_index]
self.config.attn_layer_period] kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer): if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = i - (1 + current_state_layer = mamba_cache_index
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
layer_mamba_cache_params = mamba_cache_params.at_layer_idx( layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer) current_state_layer)
mamba_cache_index += 1
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
@ -332,11 +352,17 @@ class JambaModel(nn.Module):
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params) mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.final_layernorm(hidden_states, residual) hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states return hidden_states
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@ -368,6 +394,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
super().__init__() super().__init__()
self.config = config self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.model = JambaModel(vllm_config=vllm_config, self.model = JambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
@ -390,6 +418,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
config.vocab_size) config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
@ -406,10 +437,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.scheduler_config.max_num_seqs) if self.scheduler_config self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2) else max(_BATCH_SIZES_TO_CAPTURE) + 2)
layers_type = self.config.layers_block_type num_mamba_layers = self.model_config.get_num_layers_by_block_type(
num_mamba_layers = sum( self.vllm_config.parallel_config, LayerBlockType.mamba)
[layer_type == "mamba" for layer_type in layers_type])
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
@ -423,7 +452,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
state_indices_tensor) state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params, attn_metadata, mamba_cache_params,
inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@ -504,8 +533,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@ -520,6 +553,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if weight_name not in name: if weight_name not in name:
continue continue
if is_pp_missing_parameter(name, self):
continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
@ -533,6 +568,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",

View File

@ -8,6 +8,7 @@ from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
@ -18,13 +19,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree) IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .utils import maybe_prefix from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -95,15 +99,17 @@ class MambaModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
decoder_layers = [] self.start_layer, self.end_layer, self.layers = make_layers(
for i in range(config.num_hidden_layers): config.num_hidden_layers,
decoder_layers.append( lambda prefix: MambaDecoderLayer(
MambaDecoderLayer(config, config, cache_config=cache_config, quant_config=quant_config),
cache_config=cache_config, prefix=f"{prefix}.layers")
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.norm_f = RMSNorm(config.hidden_size, self.norm_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids) return self.embeddings(input_ids)
@ -114,29 +120,40 @@ class MambaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(len(self.layers)): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(i)) mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer))
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm_f(hidden_states, residual) hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states return hidden_states
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@ -148,7 +165,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
super().__init__() super().__init__()
self.config = config self.config = config
self.vllm_config = vllm_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config
self.backbone = MambaModel(vllm_config=vllm_config, self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone")) prefix=maybe_prefix(prefix, "backbone"))
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
@ -174,6 +193,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
config.vocab_size) config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids) return self.backbone.get_input_embeddings(input_ids)
@ -189,9 +211,12 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
max_batch_size = (VllmConfig.get_graph_batch_size( max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2) else max(_BATCH_SIZES_TO_CAPTURE) + 2)
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, self.config.num_hidden_layers, self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
max_batch_size, *self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
( (
mamba_cache_tensors, mamba_cache_tensors,
@ -204,7 +229,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
state_indices_tensor) state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, inputs_embeds) mamba_cache_params, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
@ -252,6 +278,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",

View File

@ -21,7 +21,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .adapters import as_embedding_model from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal, supports_cross_encoding, supports_multimodal,
supports_pp) supports_pp)
from .interfaces_base import is_pooling_model, is_text_generation_model from .interfaces_base import is_pooling_model, is_text_generation_model
@ -218,6 +218,7 @@ class _ModelInfo:
supports_pp: bool supports_pp: bool
has_inner_state: bool has_inner_state: bool
is_attention_free: bool is_attention_free: bool
is_hybrid: bool
@staticmethod @staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@ -239,6 +240,7 @@ class _ModelInfo:
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model), has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model), is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model),
) )
@ -484,6 +486,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free return model_cls.is_attention_free
def is_hybrid_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``is_hybrid`` flag of the model resolved for
    ``architectures``.

    The lookup is delegated to ``inspect_model_cls``; only the first
    element of its result is used.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.is_hybrid
ModelRegistry = _ModelRegistry({ ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel( model_arch: _LazyRegisteredModel(

View File

@ -170,6 +170,11 @@ class Device(enum.Enum):
CPU = enum.auto() CPU = enum.auto()
class LayerBlockType(enum.Enum):
    """Kinds of layer block; each value is the block type's string name."""

    attention = "attention"
    mamba = "mamba"
class Counter: class Counter:
def __init__(self, start: int = 0) -> None: def __init__(self, start: int = 0) -> None:

View File

@ -15,8 +15,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available) LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
@ -68,8 +68,8 @@ class GPUModelRunner:
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
# Model-related. # Model-related.
self.num_attn_layers = model_config.get_num_attention_layers( self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config) parallel_config, LayerBlockType.attention)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size() self.hidden_size = model_config.get_hidden_size()

View File

@ -14,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@ -260,8 +260,8 @@ def _get_cache_block_size(
) -> int: ) -> int:
head_size = model_config.get_head_size() head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config) num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_attention_layers( num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config) parallel_config, LayerBlockType.attention)
key_cache_block = cache_config.block_size * num_heads * head_size key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block

View File

@ -6,8 +6,8 @@ import torch
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
is_pin_memory_available) get_dtype_size, is_pin_memory_available)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -34,8 +34,8 @@ class CacheEngine:
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
# Models like Jamba, have mixed typed layers, E.g Mamba # Models like Jamba, have mixed typed layers, E.g Mamba
self.num_attention_layers = model_config.get_num_attention_layers( self.num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config) parallel_config, LayerBlockType.attention)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
@ -105,8 +105,8 @@ class CacheEngine:
) -> int: ) -> int:
head_size = model_config.get_head_size() head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config) num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_attention_layers( num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config) parallel_config, LayerBlockType.attention)
key_cache_block = cache_config.block_size * num_heads * head_size key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block