[V1] v1 engine + full CUDA graph support for PLaMo2 (#23998)
Signed-off-by: Hemmi Shinichi <shemmi@preferred.jp>
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
Co-authored-by: Hemmi Shinichi <shemmi@preferred.jp>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
commit fa4311d85f (parent 6d80ae83e1)
@@ -395,7 +395,7 @@ th {
 | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
 | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
-| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | |
+| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ |
 | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
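Note: the cell flipped to ✅︎ in the last column above appears to mark V1 engine support for `Plamo2ForCausalLM` (the commit also adds full CUDA graph support, per its title). A minimal sketch of exercising the model follows; the `compilation_config` key name is an assumption about vLLM's API around this commit, not something this diff defines:

    # Sketch: run PLaMo2 on the V1 engine, optionally with full CUDA graphs.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="pfnet/plamo-2-1b",
        trust_remote_code=True,  # PLaMo2 ships custom code in its HF repo
        compilation_config={"full_cuda_graph": True},  # assumed flag name
    )
    params = SamplingParams(temperature=0.0, max_tokens=32)
    out = llm.generate(["The capital of Japan is"], params)
    print(out[0].outputs[0].text)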
@@ -110,7 +110,7 @@ Models using selective state-space mechanisms instead of standard transformer at
 Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported.
 
 Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`).
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`).
 
 Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`).
 
@@ -25,8 +25,7 @@ SSM_MODELS = [
 
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # skipping until vLLM implementation issues are resolved
-    # "pfnet/plamo-2-1b",
+    "pfnet/plamo-2-1b",
     "Zyphra/Zamba2-1.2B-instruct",
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
@@ -37,6 +36,7 @@ HYBRID_MODELS = [
 V1_SUPPORTED_MODELS = [
     "state-spaces/mamba-130m-hf",
     "ai21labs/Jamba-tiny-dev",
+    "pfnet/plamo-2-1b",
     "yujiepan/mamba2-codestral-v0.1-tiny-random",
     "Zyphra/Zamba2-1.2B-instruct",
     "hmellor/tiny-random-BambaForCausalLM",
@@ -47,6 +47,7 @@ V1_SUPPORTED_MODELS = [
 
 FULL_CUDA_GRAPH_MODELS = [
     "ai21labs/Jamba-tiny-dev",
+    "pfnet/plamo-2-1b",
     "Zyphra/Zamba2-1.2B-instruct",
 ]
 
@@ -287,8 +287,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
-                                         max_transformers_version="4.53",
-                                         transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings",  # noqa: E501
                                          trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        max_transformers_version="4.53",
@@ -340,6 +340,7 @@ class CompilationConfig:
         "vllm.mamba_mixer",
         "vllm.short_conv",
         "vllm.linear_attention",
+        "vllm.plamo2_mamba_mixer",
     ]
 
     def compute_hash(self) -> str:
@@ -3,19 +3,24 @@
 """Inference-only PLaMo2 model."""
 from collections.abc import Iterable
 from itertools import islice
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
 import torch
 from torch import nn
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import PretrainedConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -23,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    Mamba2Metadata, prepare_mamba2_metadata)
+    Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -39,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
-                                                   SupportsPP, SupportsV0Only)
+                                                   SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
@@ -47,8 +55,10 @@ from vllm.model_executor.models.utils import (
     make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType
+from vllm.utils import LayerBlockType, direct_register_custom_op
+from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 
 
 # Only used for type hinting.
@@ -73,20 +83,6 @@ class Plamo2Config(PretrainedConfig):  # type: ignore
     vocab_size: int
 
 
-class Plamo2PreTrainedModel(PreTrainedModel):  # type: ignore
-
-    def _init_weights(self, module: torch.nn.Module) -> None:
-        std = 0.02
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
 def is_mamba(config: Plamo2Config, i: int) -> bool:
     assert config.mamba_step > 1
 
@@ -99,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
 # Adapted from:
 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
 # transformers.models.mamba.modeling_mamba.MambaMixer
-class Plamo2MambaMixer(nn.Module):
+@CustomOp.register(name="plamo2_mamba_mixer")
+class Plamo2MambaMixer(MambaBase, CustomOp):
 
     def __init__(self,
                  vllm_config: VllmConfig,
@@ -108,6 +105,8 @@ class Plamo2MambaMixer(nn.Module):
                  **kwargs) -> None:
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.cache_config = vllm_config.cache_config
+        self.model_config = vllm_config.model_config
         self.quant_config = vllm_config.quant_config
         self.hidden_size = self.config.hidden_size
         self.ssm_state_size = self.config.mamba_d_state
@@ -115,8 +114,6 @@ class Plamo2MambaMixer(nn.Module):
         self.intermediate_size = (self.config.mamba_num_heads *
                                   self.config.hidden_size_per_head)
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.intermediate_size_per_tp_worker = \
-            self.intermediate_size // self.tp_size
         self.head_dim = self.config.hidden_size_per_head
         self.num_heads = self.config.mamba_num_heads
         self.time_step_rank = max(64, self.hidden_size // 16)
@@ -197,6 +194,22 @@ class Plamo2MambaMixer(nn.Module):
         self.C_norm = RMSNorm(self.ssm_state_size,
                               eps=self.config.rms_norm_eps)
 
+        self.chunk_size = self.config.mamba_chunk_size
+
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+            # The outer list is for v0 PP virtual engine. Though this code path
+            # only runs for v1, we have to do this to unify with the interface
+            # of Attention + v0 PP.
+            # The inner tuple is (conv_state, ssm_state)
+            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+            assert self.chunk_size != -1, "chunk_size must be set for v1"
+
+        self.prefix = prefix
+
     def _project_ssm_parameters(self, hidden_states):
         ssm_parameters = self.bcdt_proj(hidden_states)
         B, C, time_step = torch.split(
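Note: the constructor above registers each mixer under its `prefix` in `static_forward_context` and allocates placeholder state tensors, so the custom op can later fetch the live module by name. A minimal, self-contained sketch of that registry pattern (stub names are illustrative, not vLLM's):

    import torch

    static_forward_context: dict[str, "MixerStub"] = {}

    class MixerStub:
        def __init__(self, prefix: str) -> None:
            if prefix in static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            static_forward_context[prefix] = self
            # Placeholder (conv_state, ssm_state); real cache buffers are
            # bound in later by the engine.
            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
            self.prefix = prefix

    layer = MixerStub("model.layers.0.mixer")
    assert static_forward_context["model.layers.0.mixer"] is layer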
@@ -212,25 +225,76 @@ class Plamo2MambaMixer(nn.Module):
         dt = self.dt_proj(time_step)
         return B, C, dt
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
-    ) -> torch.Tensor:
+    ):
+        pass
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        **kwargs,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata)
+        else:
+            torch.ops.vllm.plamo2_mamba_mixer(
+                hidden_states,
+                output,
+                self.prefix,
+            )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        **kwargs,
+    ):
+
+        forward_context = get_forward_context()
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        num_prefills = attn_metadata.num_prefills  # request count
-        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
-        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
-        has_prefill = num_prefills > 0
-        has_decode = num_decodes > 0
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                mamba2_metadata = attn_metadata
+                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
+                conv_state = self_kv_cache[0].transpose(-1, -2)
+                ssm_state = self_kv_cache[1]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+                has_initial_states_p = attn_metadata.has_initial_states_p
+                prep_initial_states = attn_metadata.prep_initial_states
+                chunk_size = attn_metadata.chunk_size
+                seq_idx_p = attn_metadata.seq_idx_p
+                chunk_indices_p = attn_metadata.chunk_indices_p
+                chunk_offsets_p = attn_metadata.chunk_offsets_p
+        else:
+            conv_state = mamba_cache_params.conv_state
+            ssm_state = mamba_cache_params.ssm_state
+            state_indices_tensor = mamba_cache_params.state_indices_tensor
+            has_initial_states_p = mamba2_metadata.has_initial_states
+            prep_initial_states = mamba2_metadata.prep_initial_states
+            chunk_size = mamba2_metadata.chunk_size
+            seq_idx_p = mamba2_metadata.seq_idx
+            chunk_indices_p = mamba2_metadata.chunk_indices
+            chunk_offsets_p = mamba2_metadata.chunk_offsets
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
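Note: `forward` above is now a thin dispatcher: the V0 path goes through `CustomOp.forward`, while the V1 path calls the registered `torch.ops.vllm.plamo2_mamba_mixer`, passing only tensors plus the layer name so the captured graph stays free of Python object references. A sketch of that name-based indirection with stand-in types (not vLLM's real classes):

    import torch

    _layers: dict[str, "Mixer"] = {}

    class Mixer:
        def __init__(self, name: str) -> None:
            _layers[name] = self

        def forward_cuda(self, hidden_states: torch.Tensor,
                         output: torch.Tensor) -> None:
            # Stand-in for the mamba kernels; writes the caller's buffer.
            output.copy_(hidden_states * 2)

    def mixer_op(hidden_states: torch.Tensor, output: torch.Tensor,
                 layer_name: str) -> None:
        # Only tensors and a string cross the op boundary; the module with
        # its cached state is looked up from the registry.
        _layers[layer_name].forward_cuda(hidden_states, output)

    Mixer("layer0")
    x, out = torch.randn(4), torch.empty(4)
    mixer_op(x, out, "layer0")
    assert torch.equal(out, x * 2)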
@@ -240,22 +304,58 @@ class Plamo2MambaMixer(nn.Module):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
+        if envs.VLLM_USE_V1 and attn_metadata is None:
+            # V1 profile run
+            hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
+                0, 1)).contiguous()
+            output[:] = self.out_proj(hidden_states)
+            return
+
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes
+
+        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        hidden_states_p, hidden_states_d = torch.split(
-            hidden_states,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes],
-                                     dim=0)
-        # Split along batch dimension
-        state_indices_tensor_p, state_indices_tensor_d = torch.split(
-            mamba_cache_params.state_indices_tensor,
-            [num_prefills, num_decodes],
-            dim=0,
-        )
-        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
-                             if has_prefill else None)
+        if envs.VLLM_USE_V1:
+            hidden_states_d, hidden_states_p = torch.split(
+                hidden_states[:num_actual_tokens],
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+            gate_d, gate_p = torch.split(gate[:num_actual_tokens],
+                                         [num_decodes, num_prefill_tokens],
+                                         dim=0)
+            # Split along batch dimension
+            state_indices_tensor_d, state_indices_tensor_p = torch.split(
+                state_indices_tensor,
+                [num_decodes, num_prefills],
+                dim=0,
+            )
+            query_start_loc_p = (
+                attn_metadata.query_start_loc[-num_prefills - 1:] -
+                num_decodes if has_prefill else None)
+        else:
+            hidden_states_p, hidden_states_d = torch.split(
+                hidden_states,
+                [num_prefill_tokens, num_decodes],
+                dim=0,
+            )
+            gate_p, gate_d = torch.split(gate,
+                                         [num_prefill_tokens, num_decodes],
+                                         dim=0)
+            # Split along batch dimension
+            state_indices_tensor_p, state_indices_tensor_d = torch.split(
+                state_indices_tensor,
+                [num_prefills, num_decodes],
+                dim=0,
+            )
+            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
+                                                               1]
+                                 if has_prefill else None)
 
         # Preallocate output tensor to avoid memcpy cost for merging prefill
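Note: the V1 branch above encodes the engine's batch layout: decode tokens first, then prefill tokens, with any CUDA-graph padding sitting past `num_actual_tokens`. A toy illustration of the indexing (shapes are made up):

    import torch

    num_decodes, num_prefill_tokens = 2, 5
    num_actual_tokens = num_decodes + num_prefill_tokens  # 7 real tokens
    hidden = torch.arange(10.0).unsqueeze(-1)             # padded to 10

    dec, pre = torch.split(hidden[:num_actual_tokens],
                           [num_decodes, num_prefill_tokens], dim=0)
    assert dec.shape[0] == num_decodes
    assert pre.shape[0] == num_prefill_tokens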
@@ -268,6 +368,13 @@ class Plamo2MambaMixer(nn.Module):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-            preallocated_ssm_out,
-            [num_prefill_tokens, num_decodes],
+        if envs.VLLM_USE_V1:
+            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+                preallocated_ssm_out,
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+        else:
+            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+                preallocated_ssm_out,
+                [num_prefill_tokens, num_decodes],
@@ -278,15 +385,21 @@ class Plamo2MambaMixer(nn.Module):
         if has_prefill:
             # 2. Convolution sequence transformation
             # - "cache_indices" updates the conv_state cache in positions
-            # pointed to by "mamba_cache_params.state_indices_tensor"
+            # pointed to by "state_indices_tensor"
+            x = hidden_states_p.transpose(
+                0, 1)  # this is the form that causal-conv see
+            if mamba2_metadata.cu_seqlen is None:
+                mamba2_metadata = update_metadata(x, query_start_loc_p,
+                                                  mamba2_metadata)
             hidden_states_p = causal_conv1d_fn(
-                hidden_states_p.transpose(0, 1),
+                x,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                metadata=mamba2_metadata,
                 query_start_loc=query_start_loc_p)
             hidden_states_p = hidden_states_p.transpose(0, 1)
             hidden_states_p = hidden_states_p[:num_prefill_tokens]
|
|||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
initial_states = None
|
initial_states = None
|
||||||
if (mamba2_metadata.has_initial_states is not None
|
if has_initial_states_p is not None and prep_initial_states:
|
||||||
and mamba2_metadata.prep_initial_states):
|
|
||||||
# making a copy of the states
|
# making a copy of the states
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
initial_states = torch.where(
|
initial_states = torch.where(
|
||||||
mamba2_metadata.has_initial_states[:, None, None, None],
|
has_initial_states_p[:, None, None, None],
|
||||||
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
ssm_state[state_indices_tensor_p], 0)
|
||||||
|
else:
|
||||||
|
initial_states = torch.where(
|
||||||
|
has_initial_states_p[:num_prefills, None, None, None],
|
||||||
|
ssm_state[state_indices_tensor_p], 0)
|
||||||
varlen_state = mamba_chunk_scan_combined(
|
varlen_state = mamba_chunk_scan_combined(
|
||||||
hidden_states_p.view(1, num_prefill_tokens,
|
hidden_states_p.view(1, num_prefill_tokens,
|
||||||
self.num_heads // self.tp_size,
|
self.num_heads // self.tp_size,
|
||||||
@@ -313,15 +430,15 @@ class Plamo2MambaMixer(nn.Module):
             self.A,
             B.view(1, num_prefill_tokens, 1, -1),
             C.view(1, num_prefill_tokens, 1, -1),
-            chunk_size=mamba2_metadata.chunk_size,
+            chunk_size=chunk_size,
             D=self.D,
             z=gate_p.view(1, num_prefill_tokens,
                           self.num_heads // self.tp_size, self.head_dim),
             dt_bias=self.dt_bias,
-            seq_idx=mamba2_metadata.seq_idx,
-            chunk_indices=mamba2_metadata.chunk_indices,
-            chunk_offsets=mamba2_metadata.chunk_offsets,
-            cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
+            seq_idx=seq_idx_p,
+            chunk_indices=chunk_indices_p,
+            chunk_offsets=chunk_offsets_p,
+            cu_seqlens=query_start_loc_p,
             initial_states=initial_states,
             return_varlen_states=True,
             return_final_states=False,
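Note: `mamba_chunk_scan_combined` above processes the prefill tokens in fixed-size chunks (`chunk_size`, with `chunk_indices_p`/`chunk_offsets_p` locating chunk boundaries in the varlen batch), carrying SSM state across chunks. A toy analogue of chunked scanning with carried state, using a plain cumulative sum (purely illustrative, not the SSM recurrence itself):

    import torch

    def chunked_cumsum(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
        out, carry = [], x.new_zeros(())
        for chunk in torch.split(x, chunk_size):
            scanned = chunk.cumsum(0) + carry  # scan within the chunk
            carry = scanned[-1]                # state carried across chunks
            out.append(scanned)
        return torch.cat(out)

    x = torch.arange(10.0)
    assert torch.equal(chunked_cumsum(x, chunk_size=4), x.cumsum(0))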
@@ -329,18 +446,19 @@ class Plamo2MambaMixer(nn.Module):
             dt_limit=(0.0, float("inf")),
             out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                             self.head_dim),
+            state_dtype=ssm_state.dtype,
         )
 
         # update ssm states
         # - varlen state is a (batch, nheads, headdim, dstate) tensor
-        mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
+        ssm_state[state_indices_tensor_p] = varlen_state
 
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
             hidden_states_d = causal_conv1d_update(
                 hidden_states_d,
-                mamba_cache_params.conv_state,
+                conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
@@ -363,8 +481,10 @@ class Plamo2MambaMixer(nn.Module):
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             # using state_indices_tensor_d
+
+            # NOTE: final output is an in-place update of out tensor
             selective_state_update(
-                mamba_cache_params.ssm_state,
+                ssm_state,
                 hidden_states_d,
                 dt,
                 A,
@@ -378,11 +498,68 @@ class Plamo2MambaMixer(nn.Module):
                 out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                 self.head_dim),
             )
-        assert self.num_heads % self.tp_size == 0
 
         # 4. Final linear projection
-        out = self.out_proj(preallocated_ssm_out)
-        return out
+        output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out)
+
+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.mamba2_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=0,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
+
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba2_attn import (
+            Mamba2AttentionBackend)
+        return Mamba2AttentionBackend
+
+
+def plamo2_mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None)
+
+
+def plamo2_mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="plamo2_mamba_mixer",
+    op_func=plamo2_mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=plamo2_mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
 
 
 class DenseMLP(nn.Module):
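Note: `direct_register_custom_op` above declares that the op mutates `output` and supplies a do-nothing fake implementation so the compiler can trace the call without running kernels. The same pattern expressed with stock `torch.library` (made-up op name; vLLM's helper wraps equivalent machinery, so treat this as a sketch rather than the project's API):

    import torch

    @torch.library.custom_op("demo::mixer", mutates_args=("output",))
    def mixer(hidden_states: torch.Tensor, output: torch.Tensor,
              layer_name: str) -> None:
        output.copy_(hidden_states)  # stand-in for the real kernel

    @mixer.register_fake
    def _(hidden_states: torch.Tensor, output: torch.Tensor,
          layer_name: str) -> None:
        # Nothing to describe: the op returns no tensors and only writes
        # into a preallocated buffer, mirroring plamo2_mamba_mixer_fake.
        return

    x, out = torch.ones(3), torch.empty(3)
    torch.ops.demo.mixer(x, out, "layer0")
    assert torch.equal(out, x)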
@@ -418,7 +595,6 @@ class DenseMLP(nn.Module):
         return self.down_proj(h)
 
 
-@support_torch_compile
 class Plamo2AttentionMixer(nn.Module):
 
     def __init__(self,
@@ -575,12 +751,24 @@ class Plamo2DecoderLayer(nn.Module):
         hidden_states, residual = self.pre_mixer_norm(
             hidden_states, residual)
 
+        if self.is_mamba:
+            # Plamo2MambaMixer writes output to this tensor
+            output = torch.empty_like(hidden_states)
+            mixer_kwargs = {
+                "output": output,
+                "mamba_cache_params": mamba_cache_params,
+                "mamba2_metadata": mamba2_metadata,
+            }
+        else:
+            mixer_kwargs = {
+                "positions": positions,
+            }
         hidden_states = self.mixer(
-            positions=positions,
             hidden_states=hidden_states,
-            mamba_cache_params=mamba_cache_params,
-            mamba2_metadata=mamba2_metadata,
+            **mixer_kwargs,
         )
+        if self.is_mamba:
+            hidden_states = output
         hidden_states = self.post_mixer_norm(hidden_states)
         # Fully Connected
         hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
@@ -591,7 +779,7 @@ class Plamo2DecoderLayer(nn.Module):
 
 class Plamo2Decoder(torch.nn.Module):
 
-    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
         extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
@@ -617,7 +805,7 @@ class Plamo2Decoder(torch.nn.Module):
         mamba_cache_index = 0
         for layer in islice(self.layers, self.start_layer, self.end_layer):
             layer_mamba_cache_params = None
-            if layer.is_mamba:
+            if layer.is_mamba and mamba_cache_params is not None:
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     mamba_cache_index)
                 mamba_cache_index += 1
@@ -632,10 +820,11 @@ class Plamo2Decoder(torch.nn.Module):
         return hidden_states, residual
 
 
-class Plamo2Model(Plamo2PreTrainedModel):
+@support_torch_compile
+class Plamo2Model(torch.nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config.model_config.hf_config)
+        super().__init__()
 
         config = vllm_config.model_config.hf_config
 
@@ -653,9 +842,9 @@ class Plamo2Model(Plamo2PreTrainedModel):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
-        self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
+        self.layers = Plamo2Decoder(vllm_config=vllm_config,
+                                    prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_init()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -679,11 +868,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+        if not envs.VLLM_USE_V1:
+            attn_metadata: AttentionMetadata = get_forward_context(
+            ).attn_metadata
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
 
         hidden_states, residual = self.layers(
             positions=positions,
@@ -701,8 +895,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
         return hidden_states
 
 
-class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
-                        IsHybrid, SupportsV0Only):
+class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -712,12 +905,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
         config = vllm_config.model_config.hf_config
         scheduler_config = vllm_config.scheduler_config
-        assert not vllm_config.cache_config.enable_prefix_caching, \
-            "PLaMo2 currently does not support prefix caching"
 
-        super().__init__(config)
         self.config = config
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -751,8 +942,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        # Initialize weights and apply final processing
-        self.post_init()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -763,19 +952,27 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                num_mamba_layers,
-                *self._get_mamba_cache_shape(),
-                self.lm_head.weight.dtype,
-                self.lm_head.weight.dtype,
-            )
-
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_mamba_layers = (
+                    self.model_config.get_num_layers_by_block_type(
+                        self.vllm_config.parallel_config,
+                        LayerBlockType.mamba))
+
+                mamba_state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+                mamba_state_dtype = \
+                    self.get_mamba_state_dtype_from_config(
+                        self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     num_mamba_layers,
+                                                     *mamba_state_shape,
+                                                     *mamba_state_dtype)
+
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        else:
+            # NOTE: mamba_cache_params is not needed for v1
+            mamba_cache_params = None
 
         hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
@@ -788,21 +985,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = (self.config.mamba_num_heads *
-                       self.config.hidden_size_per_head)
-        conv_state_shape = (
-            hidden_size // world_size,
-            self.config.mamba_d_conv - 1,
-        )
-        temporal_state_shape = (
-            divide(self.config.mamba_num_heads, world_size),
-            self.config.hidden_size_per_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = \
+            hf_config.mamba_num_heads * hf_config.hidden_size_per_head
+
+        return MambaStateShapeCalculator.mamba2_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=0,
+            num_heads=hf_config.mamba_num_heads,
+            head_dim=hf_config.hidden_size_per_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
 
     def compute_logits(
         self,
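Note: the arithmetic the new classmethod delegates to `MambaStateShapeCalculator` lines up with the removed `_get_mamba_cache_shape`. A worked example with made-up config values (not PLaMo2's real hyperparameters):

    mamba_num_heads, hidden_size_per_head = 32, 128
    mamba_d_state, mamba_d_conv, tp_size = 64, 4, 1

    intermediate_size = mamba_num_heads * hidden_size_per_head  # 4096
    conv_state_shape = (intermediate_size // tp_size,
                        mamba_d_conv - 1)                       # (4096, 3)
    temporal_state_shape = (mamba_num_heads // tp_size,
                            hidden_size_per_head,
                            mamba_d_state)                      # (32, 128, 64)
    print(conv_state_shape, temporal_state_shape)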