[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Parent: 22341b996e
Commit: 75531a6c13
@@ -431,3 +431,65 @@ def test_full_cuda_graph(
         name_0="hf" if hf_outputs is not None else "vllm-v0",
         name_1="vllm-v1",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_fp32_state(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         mamba_ssm_cache_dtype="float32",
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
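For context (not part of the diff): a minimal, hedged sketch of how the new option added here is exercised from the offline API. The model name is the one used by the test above and is illustrative; `mamba_ssm_cache_dtype` is the knob this commit introduces.

```python
from vllm import LLM, SamplingParams

# Keep the conv state in the model dtype, but hold the SSM state in float32.
llm = LLM(model="Zyphra/Zamba2-1.2B-instruct",
          mamba_ssm_cache_dtype="float32")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```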
@@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
                 head_dim=hf_config.mamba_d_head,
                 rms_norm_eps=hf_config.rms_norm_eps,
                 activation=hf_config.hidden_act,
+                cache_config=cache_config,
+                model_config=model_config,
                 prefix=key,
             )
             # suppress var not used error
@@ -29,7 +29,7 @@ from typing_extensions import Self, assert_never, runtime_checkable

 import vllm.envs as envs
 from vllm import version
-from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
+from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
                                PrefixCachingHashAlgo)
 from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                      PassConfig)
@@ -23,6 +23,7 @@ logger = init_logger(__name__)

 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
+MambaDType = Literal["auto", "float32"]
 PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]


@@ -93,6 +94,15 @@ class CacheConfig:
     """ Optional override for mamba page size; used by hybrid mamba/attention
     models to ensure exact alignment with attention page size."""

+    mamba_cache_dtype: MambaDType = "auto"
+    """The data type to use for the Mamba cache (both the conv as well as the
+    ssm state). If set to 'auto', the data type will be inferred from the model
+    config."""
+    mamba_ssm_cache_dtype: MambaDType = "auto"
+    """The data type to use for the Mamba cache (ssm state only, conv state will
+    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
+    for the ssm state will be determined by mamba_cache_dtype."""
+
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
     """The number of blocks to allocate for GPU memory."""
@@ -123,6 +133,8 @@ class CacheConfig:
         """
         factors: list[Any] = []
         factors.append(self.cache_dtype)
+        factors.append(self.mamba_cache_dtype)
+        factors.append(self.mamba_ssm_cache_dtype)
         # `cpu_offload_gb` does not use `torch.compile` yet.
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
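As a hedged illustration of how the two new fields interact (the helper below is not vLLM code, just a restatement of the resolution rule documented above): `mamba_cache_dtype` governs both Mamba states, while `mamba_ssm_cache_dtype` overrides only the SSM (temporal) state and falls back to the former when left at "auto".

```python
import torch

def resolve_mamba_dtypes(model_dtype: torch.dtype,
                         mamba_cache_dtype: str,
                         mamba_ssm_cache_dtype: str):
    # "auto" for the conv state means: follow the model dtype.
    conv = model_dtype if mamba_cache_dtype == "auto" else torch.float32
    # The ssm state follows the conv state unless explicitly overridden.
    ssm = conv if mamba_ssm_cache_dtype == "auto" else torch.float32
    return conv, ssm

# bfloat16 model served with --mamba-ssm-cache-dtype float32:
print(resolve_mamba_dtypes(torch.bfloat16, "auto", "float32"))
# -> (torch.bfloat16, torch.float32)
```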
@@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
-                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
+                         ModelImpl, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+                         RunnerOption, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -422,6 +422,8 @@ class EngineArgs:
     override_attention_dtype: str = ModelConfig.override_attention_dtype

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
+    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype

     additional_config: dict[str, Any] = \
         get_field(VllmConfig, "additional_config")
@@ -694,6 +696,10 @@ class EngineArgs:
                                  **cache_kwargs["calculate_kv_scales"])
         cache_group.add_argument("--kv-sharing-fast-prefill",
                                  **cache_kwargs["kv_sharing_fast_prefill"])
+        cache_group.add_argument("--mamba-cache-dtype",
+                                 **cache_kwargs["mamba_cache_dtype"])
+        cache_group.add_argument("--mamba-ssm-cache-dtype",
+                                 **cache_kwargs["mamba_ssm_cache_dtype"])

         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
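The new flags map one-to-one onto the `EngineArgs`/`CacheConfig` fields above, so the same setting can be given as `--mamba-cache-dtype` / `--mamba-ssm-cache-dtype` on the command line (for example with `vllm serve`) or programmatically. A hedged sketch, assuming the model config can be fetched; the model name is illustrative:

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent CLI:
#   vllm serve Zyphra/Zamba2-1.2B-instruct --mamba-ssm-cache-dtype float32
args = EngineArgs(model="Zyphra/Zamba2-1.2B-instruct",
                  mamba_ssm_cache_dtype="float32")
config = args.create_engine_config()
print(config.cache_config.mamba_cache_dtype)      # "auto"
print(config.cache_config.mamba_ssm_cache_dtype)  # "float32"
```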
@@ -1105,6 +1111,8 @@ class EngineArgs:
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
+            mamba_cache_dtype=self.mamba_cache_dtype,
+            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
         )

         ray_runtime_env = None
@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -56,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
                  is_lora_enabled: bool = False,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  prefix: str = ""):
         super().__init__()
         self.time_step_rank = time_step_rank
@@ -153,6 +155,8 @@ class MambaMixer(MambaBase, CustomOp):
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def _ssm_transform(
@@ -369,6 +373,15 @@ class MambaMixer(MambaBase, CustomOp):

         return out

+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba1_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
             tp_world_size=get_tensor_model_parallel_world_size(),
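A hedged sketch of the behaviour this wiring produces for Mamba1 layers: with everything left at "auto" both states follow the model dtype, while an explicit float32 request is still rejected by `MambaStateDtypeCalculator.mamba1_state_dtype` (see the TODO in mamba_utils.py further down in this diff).

```python
import torch
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateDtypeCalculator

# Defaults: conv and ssm state both follow the model dtype.
print(MambaStateDtypeCalculator.mamba1_state_dtype(torch.bfloat16, "auto", "auto"))

# Explicit float32 for Mamba1 is not wired up yet and raises.
try:
    MambaStateDtypeCalculator.mamba1_state_dtype(torch.bfloat16, "auto", "float32")
except ValueError as err:
    print(err)  # fp32 state for mamba1 is not yet supported
```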
@@ -8,7 +8,7 @@ from torch import nn

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import get_current_vllm_config
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
     **selective** state spaces)
     """

-    def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 ssm_state_size: int,
+                 conv_kernel_size: int,
+                 intermediate_size: int,
+                 use_conv_bias: bool,
+                 use_bias: bool,
+                 n_groups: int = 1,
+                 num_heads: int = 128,
+                 head_dim: int = 64,
+                 rms_norm_eps: float = 1e-5,
+                 activation: str = "silu",
+                 use_rms_norm: bool = True,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()

         # For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix

     def forward_native(
@@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                 self.head_dim),
-            )
+                state_dtype=ssm_state.dtype)

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

+    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+            self.cache_config.mamba_ssm_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,
@@ -1,6 +1,58 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
+
+import torch
+
+from vllm.config import MambaDType, ModelDType
 from vllm.distributed import divide
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
+
+
+class MambaStateDtypeCalculator:
+
+    @classmethod
+    def linear_attention_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires testing
+        if mamba_cache_dtype == "float32":
+            raise ValueError("fp32 state for minimax is not yet supported")
+        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
+        return (state_dtype, )
+
+    @classmethod
+    def mamba1_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        # TODO (tdoublep) requires kernel changes
+        if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
+            raise ValueError("fp32 state for mamba1 is not yet supported")
+        else:
+            return MambaStateDtypeCalculator.mamba2_state_dtype(
+                model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
+
+    @classmethod
+    def mamba2_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
+                                                    model_dtype)
+        if mamba_ssm_cache_dtype == "auto":
+            temporal_state_dtype = conv_state_dtype
+        else:
+            temporal_state_dtype = (
+                STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
+
+        return (conv_state_dtype, temporal_state_dtype)


 class MambaStateShapeCalculator:
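A hedged usage sketch of the new calculator (the values follow directly from the code above): the conv state resolves via `get_kv_cache_torch_dtype`, and the temporal (SSM) state either follows it or honours the explicit override.

```python
import torch
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateDtypeCalculator

# bfloat16 model, default conv state, fp32 SSM state override.
conv_dtype, ssm_dtype = MambaStateDtypeCalculator.mamba2_state_dtype(
    torch.bfloat16, "auto", "float32")
print(conv_dtype, ssm_dtype)  # torch.bfloat16 torch.float32

# With both options left at "auto" the state simply follows the model dtype.
print(MambaStateDtypeCalculator.mamba2_state_dtype(torch.bfloat16, "auto", "auto"))
```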
@@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf")),
+                                   state_dtype=None,
                                    out=None):
     assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
     batch, seqlen, nheads, headdim = x.shape
@@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
         if initial_states is not None else None,
         seq_idx=seq_idx,
         chunk_size=chunk_size,
-        out_dtype=C.dtype,
+        out_dtype=state_dtype if state_dtype is not None else C.dtype,
         is_cont_batched=cu_seqlens is not None)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
@@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
                               dt_limit=(0.0, float("inf")),
                               out=None,
                               return_final_states=False,
-                              return_varlen_states=False):
+                              return_varlen_states=False,
+                              state_dtype=None):
     """
     Argument:
         x: (batch, seqlen, nheads, headdim)
@@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
         cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
         dt_softplus: Whether to apply softplus to dt
         out: Preallocated output tensor
+        state_dtype: The data type of the ssm state
     """

     if not return_varlen_states:
@@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
         dt_limit=dt_limit,
-        out=out)
+        out=out,
+        state_dtype=state_dtype)
     if not return_varlen_states:
         if not return_final_states:
             return
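Why a float32 state can matter: a self-contained toy recurrence (not vLLM code, just the same "decay + update" form the SSM scan uses) showing the accumulation drift that the new `state_dtype` plumbing avoids when the cached state is kept in fp32.

```python
import torch

def run(dtype: torch.dtype) -> torch.Tensor:
    # state <- a * state + b * x, repeated over a long decode
    state = torch.zeros(4096, dtype=dtype)
    decay = torch.full((4096,), 0.999, dtype=dtype)
    update = torch.full((4096,), 1e-3, dtype=dtype)
    for _ in range(2048):
        state = decay * state + update
    return state.float()

err = (run(torch.bfloat16) - run(torch.float32)).abs().max().item()
print(f"max |bf16 - fp32| after 2048 steps: {err:.4e}")
```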
@@ -12,7 +12,7 @@ from transformers import BambaConfig
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -83,6 +83,7 @@ class BambaMixerDecoderLayer(nn.Module):
     def __init__(self,
                  config: BambaConfig,
                  layer_idx: int,
+                 model_config: Optional[ModelConfig] = None,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "") -> None:
@@ -100,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module):
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
+                                 model_config=model_config,
+                                 cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mixer")

@@ -138,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         config: BambaConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -266,6 +270,7 @@ class BambaModel(nn.Module):
         super().__init__()

         config: BambaConfig = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -289,6 +294,7 @@ class BambaModel(nn.Module):
             return layer_class(
                 config,
                 layer_idx,
+                model_config,
                 cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
@@ -437,6 +443,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]

+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
@@ -528,10 +546,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
                                                  num_mamba_layers,
-                                                 *mamba_state_shape)
+                                                 *mamba_state_shape,
+                                                 *mamba_state_dtype)

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -318,7 +318,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         # get mamba page size
         mamba_page_size = MambaSpec(
             shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
-            dtype=kv_cache_dtype,
+            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
             block_size=model_config.max_model_len,
         ).page_size_bytes

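Illustrative arithmetic only (the shapes below are made up): switching the temporal state to float32 enlarges the mamba page, which is why `MambaSpec` is now built from per-tensor dtypes rather than a single `kv_cache_dtype` when aligning mamba and attention page sizes.

```python
import math
import torch

def page_bytes(shapes, dtypes):
    # bytes for one request's mamba state = sum over its (conv, ssm) tensors
    return sum(math.prod(shape) * torch.tensor([], dtype=dtype).element_size()
               for shape, dtype in zip(shapes, dtypes))

conv_shape, ssm_shape = (8256, 4), (128, 64, 128)   # hypothetical Mamba2 shapes
print(page_bytes((conv_shape, ssm_shape), (torch.bfloat16, torch.bfloat16)))
print(page_bytes((conv_shape, ssm_shape), (torch.bfloat16, torch.float32)))
```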
@@ -11,7 +11,7 @@ from transformers import FalconH1Config
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
     def __init__(
         self,
         config: FalconH1Config,
+        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -108,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
             head_dim=config.mamba_d_head,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
+            model_config=model_config,
+            cache_config=cache_config,
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
@@ -317,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module):
         self,
         config: FalconH1Config,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -339,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module):
         # Instantiate the SSM branch
         self.mamba = FalconH1SSMDecoderLayer(
             config=config,
+            model_config=model_config,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=ssm_prefix,
@@ -408,6 +413,7 @@ class FalconH1Model(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: FalconH1Config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -435,6 +441,7 @@ class FalconH1Model(nn.Module):
             return layer_class(
                 config,
                 layer_idx,
+                model_config,
                 cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
@@ -519,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]

+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
@@ -624,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(
                 self.vllm_config,
-                self.lm_head.weight.dtype if hasattr(
-                    self.lm_head, 'weight') else torch.bfloat16,
                 self.config.num_hidden_layers,
                 *mamba_state_shape,
+                *mamba_state_dtype,
             )
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -50,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: GraniteMoeHybridConfig,
                  layer_idx: int,
+                 model_config: Optional[ModelConfig] = None,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "") -> None:
@@ -70,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
+                                 model_config=model_config,
+                                 cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mixer")

@@ -137,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
         self,
         config: GraniteMoeHybridConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -217,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module):
     def __init__(
         self,
         config: GraniteMoeHybridConfig,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -316,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -340,6 +346,7 @@ class GraniteMoeHybridModel(nn.Module):
             return layer_class(
                 config,
                 layer_idx,
+                model_config,
                 cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
@@ -527,6 +534,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     }
     embedding_padding_modules = ["lm_head"]

+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
@@ -625,10 +644,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.model_config.dtype,
                                                  num_mamba_layers,
-                                                 *mamba_state_shape)
+                                                 *mamba_state_shape,
+                                                 *mamba_state_dtype)

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -10,7 +10,7 @@ from transformers import JambaConfig

 from vllm import envs
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -94,6 +94,7 @@ class JambaMambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: JambaConfig,
                  layer_idx: int,
+                 model_config: Optional[ModelConfig] = None,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
@@ -114,6 +115,8 @@ class JambaMambaDecoderLayer(nn.Module):
                                 rms_norm_eps=config.rms_norm_eps,
                                 activation=config.hidden_act,
                                 is_lora_enabled = self.is_lora_enabled,
+                                model_config=model_config,
+                                cache_config=cache_config,
                                 prefix=f"{prefix}.mixer",
                                 )

@@ -164,6 +167,7 @@ class JambaAttentionDecoderLayer(nn.Module):
     def __init__(self,
                  config: JambaConfig,
                  layer_idx: int,
+                 model_config: Optional[ModelConfig] = None,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "",
@@ -280,6 +284,7 @@ class JambaModel(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -304,6 +309,7 @@ class JambaModel(nn.Module):
                 config.layers_block_type[layer_idx]]
             return layer_class(config,
                                layer_idx,
+                               model_config,
                                cache_config,
                                quant_config=quant_config,
                                prefix=prefix,
@@ -520,9 +526,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             state_shape = self.get_mamba_state_shape_from_config(
                 self.vllm_config)
+            state_dtype = self.get_mamba_state_dtype_from_config(
+                self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
-                                                 num_layers, *state_shape)
+                                                 num_layers, *state_shape,
+                                                 *state_dtype)

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -537,6 +545,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba1_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
@@ -9,13 +9,13 @@ from torch import nn
 from transformers import MambaConfig

 from vllm import envs
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module):

     def __init__(self,
                  config: MambaConfig,
+                 model_config: Optional[ModelConfig] = None,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
@@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module):
                                 rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act,
                                 is_lora_enabled=self.is_lora_enabled,
+                                model_config=model_config,
+                                cache_config=cache_config,
                                 prefix=f"{prefix}.mixer")

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -88,6 +91,7 @@ class MambaModel(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -108,6 +112,7 @@ class MambaModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: MambaDecoderLayer(config,
+                                             model_config=model_config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              is_lora_enabled=is_lora_enabled,
@@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             state_shape = self.get_mamba_state_shape_from_config(
                 self.vllm_config)
+            state_dtype = self.get_mamba_state_dtype_from_config(
+                self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
-                                                 num_layers, *state_shape)
+                                                 num_layers, *state_shape,
+                                                 *state_dtype)

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

         return hidden_states

+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba1_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,
@ -11,7 +11,7 @@ from transformers import MambaConfig
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|||||||
Mamba2Metadata, prepare_mamba2_metadata)
|
Mamba2Metadata, prepare_mamba2_metadata)
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||

@@ -45,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module):
 
     def __init__(self,
                  config: MambaConfig,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "") -> None:
         super().__init__()

@@ -62,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module):
                                  head_dim=config.head_dim,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
+                                 model_config=model_config,
+                                 cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mixer")
 

@@ -93,6 +97,8 @@ class Mamba2Model(nn.Module):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         is_lora_enabled = bool(lora_config)

@@ -112,8 +118,11 @@ class Mamba2Model(nn.Module):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Mamba2DecoderLayer(
-                config, quant_config=quant_config, prefix=prefix),
+            lambda prefix: Mamba2DecoderLayer(config,
+                                              model_config=model_config,
+                                              cache_config=cache_config,
+                                              quant_config=quant_config,
+                                              prefix=prefix),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,
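
The same plumbing repeats across every model touched in this change: pull `model_config` and `cache_config` off `vllm_config` once, then let the per-layer factory closure hand them to each layer. A minimal imitation of the pattern, with a simplified stand-in for `make_layers` and a hypothetical toy layer:

```python
from typing import Callable

import torch.nn as nn

def make_layers(num_layers: int,
                layer_fn: Callable[[str], nn.Module],
                prefix: str) -> list[nn.Module]:
    # Simplified stand-in for vLLM's make_layers: builds the modules only,
    # without the pipeline-parallel partitioning the real utility performs.
    return [layer_fn(f"{prefix}.{i}") for i in range(num_layers)]

class ToyDecoderLayer(nn.Module):  # hypothetical layer, not a vLLM class
    def __init__(self, hidden_size: int, *, model_config=None,
                 cache_config=None, prefix: str = "") -> None:
        super().__init__()
        self.prefix = prefix
        self.proj = nn.Linear(hidden_size, hidden_size)

model_config, cache_config = object(), object()  # placeholders for real configs
layers = make_layers(
    4,
    lambda prefix: ToyDecoderLayer(32,
                                   model_config=model_config,
                                   cache_config=cache_config,
                                   prefix=prefix),
    prefix="backbone.layers")
print([layer.prefix for layer in layers])
# ['backbone.layers.0', 'backbone.layers.1', 'backbone.layers.2', 'backbone.layers.3']
```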

@@ -200,6 +209,18 @@ class Mamba2Model(nn.Module):
 
 class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
 
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,

@@ -290,10 +311,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
                                                  num_mamba_layers,
-                                                 *mamba_state_shape)
+                                                 *mamba_state_shape,
+                                                 *mamba_state_dtype)
 
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         else:

@@ -24,9 +24,14 @@ class MambaCacheParams:
 
 class MambaCacheManager(ConstantSizeCache):
 
-    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
-                 num_mamba_layers: int, conv_state_shape: tuple[int, int],
-                 temporal_state_shape: tuple[int, int]):
+    def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
+                 conv_state_shape: tuple[int, int],
+                 temporal_state_shape: tuple[int, int],
+                 conv_state_dtype: torch.dtype,
+                 temporal_state_dtype: torch.dtype):
+
+        self.conv_state_dtype = conv_state_dtype
+        self.temporal_state_dtype = temporal_state_dtype
 
         # Determine max batch size to set size of MambaCache
         max_batch_size = vllm_config.scheduler_config.max_num_seqs

@@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache):
         assert conv_state_shape[0] > conv_state_shape[1]
         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  (conv_state_shape[1], conv_state_shape[0]),
-                                 dtype=dtype,
+                                 dtype=self.conv_state_dtype,
                                  device="cuda").transpose(-1, -2)
         temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                      temporal_state_shape,
-                                     dtype=dtype,
+                                     dtype=self.temporal_state_dtype,
                                      device="cuda")
 
         self._mamba_cache = (conv_state, temporal_state)
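
With the constructor above, the convolution and SSM ("temporal") buffers no longer have to share one dtype. A runnable, CPU-only imitation of the allocation with made-up sizes (the real code allocates on CUDA and gets the shapes from the model config):

```python
import torch

num_mamba_layers, max_batch_size = 2, 4
conv_state_shape = (16, 4)          # (channels, conv width - 1), illustrative
temporal_state_shape = (8, 64, 64)  # (heads, head_dim, state_dim), illustrative

conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                         (conv_state_shape[1], conv_state_shape[0]),
                         dtype=torch.bfloat16).transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                             temporal_state_shape,
                             dtype=torch.float32)

# The conv state stays in the model dtype while the SSM state is held in fp32.
print(conv_state.shape, conv_state.dtype)          # [2, 4, 16, 4] torch.bfloat16
print(temporal_state.shape, temporal_state.dtype)  # [2, 4, 8, 64, 64] torch.float32
```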

@@ -16,7 +16,8 @@ from transformers import MiniMaxConfig
 
 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
+                         get_current_vllm_config)
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,

@@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"
 
+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        return MambaStateDtypeCalculator.linear_attention_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+        )
+
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.linear_attention_state_shape(
             num_heads=self.num_heads,
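
Each mamba-like layer now advertises its state dtypes next to its state shapes, and downstream code walks the two tuples in parallel. A toy illustration of that contract (the class and values are hypothetical, not the MiniMax layer):

```python
import torch

class ToyLinearAttention:  # hypothetical layer exposing the same two accessors
    def get_state_dtype(self) -> tuple[torch.dtype]:
        return (torch.float32,)

    def get_state_shape(self) -> tuple[tuple[int, ...]]:
        return ((8, 64, 64),)

layer = ToyLinearAttention()
for shape, dtype in zip(layer.get_state_shape(), layer.get_state_dtype()):
    print(shape, dtype)  # (8, 64, 64) torch.float32
```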

@@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                  max_position: int,
                  block_size: int,
                  num_hidden_layer: int,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  layer_idx: int = 0,
                  linear_layer_idx: int = 0,

@@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         self.tp_heads = self.total_num_heads // self.tp_size
         self.qkv_size = self.num_heads * self.head_dim
         self.tp_hidden = self.head_dim * self.tp_heads
+        self.model_config = model_config
+        self.cache_config = cache_config
         self.prefix = prefix
 
         self.qkv_proj = ColumnParallelLinear(

@@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
     def __init__(
         self,
         config: MiniMaxConfig,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         expert_num: int = 1,

@@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
             max_position=max_position_embeddings,
             block_size=config.block if hasattr(config, "block") else 256,
             num_hidden_layer=config.num_hidden_layers,
+            model_config=model_config,
+            cache_config=cache_config,
             quant_config=quant_config,
             layer_idx=self._ilayer,
             linear_layer_idx=linear_layer_id,

@@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module):
     def __init__(
         self,
         config: MiniMaxConfig,
+        model_config: Optional[ModelConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         scheduler_config=None,

@@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module):
             decoder_kwargs = {
                 "quant_config": quant_config,
                 "layer_id": layer_idx,
+                "model_config": model_config,
                 "cache_config": cache_config
             }
 

@@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         self.config.max_model_len = vllm_config.model_config.max_model_len
         self.model = MiniMaxText01Model(
             self.config,
-            quant_config,
+            model_config=vllm_config.model_config,
             cache_config=vllm_config.cache_config,
+            quant_config=quant_config,
             scheduler_config=vllm_config.scheduler_config,
             prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:

@@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
                 load_basic_weight(name, loaded_weight, self)
         return loaded_params
 
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.linear_attention_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,

@@ -26,7 +26,7 @@ from torch import nn
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)

@@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
         self,
         config: NemotronHConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",

@@ -149,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
         self,
         config: NemotronHConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",

@@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
             head_dim=config.mamba_head_dim,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.mamba_hidden_act,
+            model_config=model_config,
+            cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
         )

@@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module):
         self,
         config: NemotronHConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",

@@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
         self,
         config: NemotronHConfig,
         layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",

@@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
         self.mixer = NemotronHAttention(
             config,
             layer_idx,
+            model_config,
             cache_config,
             quant_config,
             prefix=f"{prefix}.mixer",

@@ -317,6 +324,7 @@ class NemotronHModel(nn.Module):
         super().__init__()
 
         config: NemotronHConfig = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config

@@ -340,6 +348,7 @@ class NemotronHModel(nn.Module):
             return layer_class(
                 config,
                 layer_idx,
+                model_config,
                 cache_config,
                 quant_config=quant_config,
                 prefix=prefix,

@@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,

@@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
                                                  num_mamba_layers,
-                                                 *mamba_state_shape)
+                                                 *mamba_state_shape,
+                                                 *mamba_state_dtype)
 
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 

@@ -18,7 +18,7 @@ from transformers import Zamba2Config
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import GeluAndMul

@@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateShapeCalculator)
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
 
     def __init__(self,
                  config: Zamba2Config,
+                 model_config: Optional[ModelConfig] = None,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "") -> None:
         """Initialize the Mamba decoder layer.

@@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
                                  config.n_mamba_heads,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation="silu",
+                                 model_config=model_config,
+                                 cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mixer")
 

@@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module):
         shared_transformer: Zamba2AttentionDecoderLayer,
         config: Zamba2Config,
         block_idx: int,
+        model_config: Optional[ModelConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:

@@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module):
                                          bias=False,
                                          quant_config=quant_config)
         self.mamba_decoder = Zamba2MambaDecoderLayer(config,
+                                                     model_config=model_config,
+                                                     cache_config=cache_config,
                                                      quant_config=quant_config,
                                                      prefix=prefix)
 

@@ -669,6 +677,7 @@ class Zamba2Model(nn.Module):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config

@@ -718,11 +727,15 @@ class Zamba2Model(nn.Module):
                     Zamba2HybridLayer(block,
                                       config,
                                       block_idx,
-                                      quant_config,
+                                      model_config=model_config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
                                       prefix=prefix))
             else:
                 layers.append(
                     Zamba2MambaDecoderLayer(config,
+                                            model_config=model_config,
+                                            cache_config=cache_config,
                                             quant_config=quant_config,
                                             prefix=prefix))
         self.layers = nn.ModuleList(layers)

@@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         "1.weight": "B.weight",
     })
 
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+
+        return MambaStateDtypeCalculator.mamba2_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+            vllm_config.cache_config.mamba_ssm_cache_dtype,
+        )
+
     @classmethod
     def get_mamba_state_shape_from_config(
         cls,

@@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             mamba_state_shape = \
                 self.get_mamba_state_shape_from_config(
                     self.vllm_config, use_v1=False)
+            mamba_state_dtype = \
+                self.get_mamba_state_dtype_from_config(
+                    self.vllm_config)
             self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                 self.lm_head.weight.dtype,
                                                  num_mamba_layers,
-                                                 *mamba_state_shape)
+                                                 *mamba_state_shape,
+                                                 *mamba_state_dtype)
 
             # Get cache parameters for current run
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -173,6 +173,7 @@ CYAN = '\033[1;36m'
 RESET = '\033[0;0m'
 
 STR_DTYPE_TO_TORCH_DTYPE = {
+    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
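
Adding an explicit `"float32"` key lets the cache-dtype strings used by the new options resolve directly instead of relying on the `"float"` alias. A local copy of the mapping, for illustration only (the real one lives in `vllm.utils`):

```python
import torch

STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,  # newly added key
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

# Both spellings resolve to the same torch dtype, since torch.float is float32.
assert STR_DTYPE_TO_TORCH_DTYPE["float32"] is STR_DTYPE_TO_TORCH_DTYPE["float"]
print(STR_DTYPE_TO_TORCH_DTYPE["float32"])  # torch.float32
```

If your vLLM build exposes the corresponding engine arguments under the same names as the CacheConfig fields touched here, requesting fp32 SSM state would look roughly like `LLM(model=..., mamba_ssm_cache_dtype="float32")`; treat that argument name as an assumption to verify against your version.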

@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
 @dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
-    dtype: torch.dtype
+    dtypes: tuple[torch.dtype]
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
 
     @property
     def page_size_bytes(self) -> int:
-        num_elements = sum(prod(shape) for shape in self.shapes)
-        page_size = num_elements * get_dtype_size(self.dtype)
+        page_size = sum(
+            prod(shape) * get_dtype_size(dtype)
+            for (shape, dtype) in zip(self.shapes, self.dtypes))
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded
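
With per-state dtypes, the page size is the sum of each state's element count times its own element size, rather than one shared dtype size. Re-computing the new expression with illustrative shapes and dtypes (none of these numbers come from a real model):

```python
from math import prod

import torch

def get_dtype_size(dtype: torch.dtype) -> int:
    # Same trick vLLM's utility uses: ask an empty tensor for its element size.
    return torch.tensor([], dtype=dtype).element_size()

shapes = ((16, 4), (8, 64, 64))           # conv state, SSM state (illustrative)
dtypes = (torch.bfloat16, torch.float32)  # conv in bf16, SSM state in fp32

page_size = sum(
    prod(shape) * get_dtype_size(dtype)
    for (shape, dtype) in zip(shapes, dtypes))
print(page_size)  # 16*4*2 + 8*64*64*4 = 128 + 131072 = 131200 bytes
```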

@@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             elif isinstance(kv_cache_spec, MambaSpec):
                 has_mamba = True
                 raw_tensor = kv_cache_raw_tensors[layer_name]
-                dtype = kv_cache_spec.dtype
-                num_element_per_page = (kv_cache_spec.page_size_bytes //
-                                        get_dtype_size(dtype))
                 state_tensors = []
-                storage_offset = 0
-                for shape in kv_cache_spec.shapes:
+                storage_offset_bytes = 0
+                for (shape, dtype) in zip(kv_cache_spec.shapes,
+                                          kv_cache_spec.dtypes):
+                    dtype_size = get_dtype_size(dtype)
+                    num_element_per_page = (
+                        kv_cache_spec.page_size_bytes // dtype_size)
                     target_shape = (num_blocks, *shape)
                     stride = torch.empty(target_shape).stride()
                     target_stride = (num_element_per_page, *stride[1:])
+                    assert storage_offset_bytes % dtype_size == 0
                     tensor = torch.as_strided(
                         raw_tensor.view(dtype),
                         size=target_shape,
                         stride=target_stride,
-                        storage_offset=storage_offset,
+                        storage_offset=storage_offset_bytes // dtype_size,
                     )
                     state_tensors.append(tensor)
-                    storage_offset += stride[0]
+                    storage_offset_bytes += stride[0] * dtype_size
 
                 kv_caches[layer_name] = state_tensors
             else:
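
The runner now carves each layer's state tensors out of one flat buffer while tracking the offset in bytes, since consecutive states can have different element sizes. A stripped-down, runnable imitation of that loop with made-up shapes and dtypes (the buffer is plain CPU memory here):

```python
from math import prod

import torch

def get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()

shapes = ((16, 4), (8, 64, 64))           # illustrative per-layer state shapes
dtypes = (torch.bfloat16, torch.float32)  # per-state dtypes
num_blocks = 2
page_size_bytes = sum(
    prod(s) * get_dtype_size(d) for s, d in zip(shapes, dtypes))

# One flat byte buffer holding `num_blocks` pages.
raw_tensor = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)

state_tensors = []
storage_offset_bytes = 0
for shape, dtype in zip(shapes, dtypes):
    dtype_size = get_dtype_size(dtype)
    num_element_per_page = page_size_bytes // dtype_size
    target_shape = (num_blocks, *shape)
    stride = torch.empty(target_shape).stride()
    # Jump one full page (counted in elements of this dtype) between blocks.
    target_stride = (num_element_per_page, *stride[1:])
    assert storage_offset_bytes % dtype_size == 0
    tensor = torch.as_strided(
        raw_tensor.view(dtype),
        size=target_shape,
        stride=target_stride,
        storage_offset=storage_offset_bytes // dtype_size,
    )
    state_tensors.append(tensor)
    storage_offset_bytes += stride[0] * dtype_size

for t in state_tensors:
    print(t.shape, t.dtype)
# torch.Size([2, 16, 4]) torch.bfloat16
# torch.Size([2, 8, 64, 64]) torch.float32
```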

@@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for layer_name, mamba_module in mamba_layers.items():
             kv_cache_spec[layer_name] = MambaSpec(
                 shapes=mamba_module.get_state_shape(),
-                dtype=self.kv_cache_dtype,
+                dtypes=mamba_module.get_state_dtype(),
                 block_size=max_model_len,
                 page_size_padded=page_size_padded,
                 mamba_type=mamba_module.mamba_type)