[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent 22341b996e
commit 75531a6c13
@@ -431,3 +431,65 @@ def test_full_cuda_graph(
        name_0="hf" if hf_outputs is not None else "vllm-v0",
        name_1="vllm-v1",
    )


@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:

    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        if model not in HF_UNSUPPORTED_MODELS:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)
        else:
            hf_outputs = None

    with vllm_runner(model,
                     max_num_seqs=MAX_NUM_SEQS,
                     mamba_ssm_cache_dtype="float32") as vllm_model:
        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        if model in HYBRID_MODELS:
            # required due to reorder_batch behaviour
            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
        with vllm_runner(model,
                         max_num_seqs=MAX_NUM_SEQS,
                         mamba_ssm_cache_dtype="float32",
                         enable_prefix_caching=False) as vllm_model:
            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    if hf_outputs is not None:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )

    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
    check_logprobs_close(
        outputs_0_lst=ref_outputs,
        outputs_1_lst=vllm_v1_outputs,
        name_0="hf" if hf_outputs is not None else "vllm-v0",
        name_1="vllm-v1",
    )


@@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
                head_dim=hf_config.mamba_d_head,
                rms_norm_eps=hf_config.rms_norm_eps,
                activation=hf_config.hidden_act,
                cache_config=cache_config,
                model_config=model_config,
                prefix=key,
            )
            # suppress var not used error

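The new test exercises the fp32 SSM state through the test runners. Roughly the equivalent offline setup is sketched below, under the assumption that extra LLM(...) keyword arguments are forwarded to EngineArgs as usual; the model name and prompt are illustrative only:

# Illustrative sketch, not part of the diff: enabling the float32 SSM state
# from the offline API. Assumes LLM(...) forwards unknown kwargs to EngineArgs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",   # any supported Mamba2 hybrid model
    mamba_ssm_cache_dtype="float32",       # conv state keeps the model dtype
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)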
@@ -29,7 +29,7 @@ from typing_extensions import Self, assert_never, runtime_checkable

import vllm.envs as envs
from vllm import version
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
                               PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                     PassConfig)

@@ -23,6 +23,7 @@ logger = init_logger(__name__)

BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]


@@ -93,6 +94,15 @@ class CacheConfig:
    """ Optional override for mamba page size; used by hybrid mamba/attention
    models to ensure exact alignment with attention page size."""

    mamba_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (both the conv as well as the
    ssm state). If set to 'auto', the data type will be inferred from the model
    config."""
    mamba_ssm_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (ssm state only, conv state will
    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
    for the ssm state will be determined by mamba_cache_dtype."""

    # Will be set after profiling.
    num_gpu_blocks: Optional[int] = field(default=None, init=False)
    """The number of blocks to allocate for GPU memory."""
@@ -123,6 +133,8 @@ class CacheConfig:
        """
        factors: list[Any] = []
        factors.append(self.cache_dtype)
        factors.append(self.mamba_cache_dtype)
        factors.append(self.mamba_ssm_cache_dtype)
        # `cpu_offload_gb` does not use `torch.compile` yet.
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()

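The two knobs cascade: mamba_cache_dtype picks the dtype for both states (falling back to the model dtype on "auto"), and mamba_ssm_cache_dtype overrides only the SSM state. A minimal sketch of that resolution rule; resolve_mamba_dtypes is a hypothetical helper written only to illustrate the documented semantics (the real logic lives in MambaStateDtypeCalculator further down in this diff):

# Sketch of the documented cascade between the two config fields.
def resolve_mamba_dtypes(model_dtype: str,
                         mamba_cache_dtype: str = "auto",
                         mamba_ssm_cache_dtype: str = "auto") -> tuple[str, str]:
    conv = model_dtype if mamba_cache_dtype == "auto" else mamba_cache_dtype
    ssm = conv if mamba_ssm_cache_dtype == "auto" else mamba_ssm_cache_dtype
    return conv, ssm

assert resolve_mamba_dtypes("bfloat16") == ("bfloat16", "bfloat16")
assert resolve_mamba_dtypes("bfloat16", mamba_ssm_cache_dtype="float32") == \
    ("bfloat16", "float32")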
@@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         DeviceConfig, DistributedExecutorBackend,
                         GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LogprobsMode,
                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
                         get_field)
                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         RunnerOption, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -422,6 +422,8 @@ class EngineArgs:
    override_attention_dtype: str = ModelConfig.override_attention_dtype

    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype

    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
@@ -694,6 +696,10 @@ class EngineArgs:
                                 **cache_kwargs["calculate_kv_scales"])
        cache_group.add_argument("--kv-sharing-fast-prefill",
                                 **cache_kwargs["kv_sharing_fast_prefill"])
        cache_group.add_argument("--mamba-cache-dtype",
                                 **cache_kwargs["mamba_cache_dtype"])
        cache_group.add_argument("--mamba-ssm-cache-dtype",
                                 **cache_kwargs["mamba_ssm_cache_dtype"])

        # Multimodal related configs
        multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1105,6 +1111,8 @@ class EngineArgs:
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
        )

        ray_runtime_env = None

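These options surface on the CLI as --mamba-cache-dtype and --mamba-ssm-cache-dtype (e.g. vllm serve <model> --mamba-ssm-cache-dtype float32). Programmatically the same fields sit on EngineArgs; the snippet below is a hedged sketch (the model name is a placeholder, and create_engine_config is assumed to work with its defaults):

# Sketch only: building EngineArgs with the new cache-dtype fields.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="ibm-ai-platform/Bamba-9B",     # illustrative hybrid model
    mamba_cache_dtype="auto",             # conv + ssm follow the model dtype...
    mamba_ssm_cache_dtype="float32",      # ...except the ssm state, kept in fp32
)
vllm_config = args.create_engine_config()
print(vllm_config.cache_config.mamba_ssm_cache_dtype)  # -> "float32"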
@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -56,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
                 rms_norm_eps: float = 1e-5,
                 activation="silu",
                 is_lora_enabled: bool = False,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.time_step_rank = time_step_rank
@@ -153,6 +155,8 @@ class MambaMixer(MambaBase, CustomOp):
        # The inner tuple is (conv_state, ssm_state)
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def _ssm_transform(
@@ -369,6 +373,15 @@ class MambaMixer(MambaBase, CustomOp):

        return out

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=get_tensor_model_parallel_world_size(),

@@ -8,7 +8,7 @@ from torch import nn

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                              update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
    **selective** state spaces)
    """

    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        use_conv_bias: bool,
        use_bias: bool,
        n_groups: int = 1,
        num_heads: int = 128,
        head_dim: int = 64,
        rms_norm_eps: float = 1e-5,
        activation: str = "silu",
        use_rms_norm: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 n_groups: int = 1,
                 num_heads: int = 128,
                 head_dim: int = 64,
                 rms_norm_eps: float = 1e-5,
                 activation: str = "silu",
                 use_rms_norm: bool = True,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # For TP, the sharding plan is as follows:
@@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
        # The inner tuple is (conv_state, ssm_state)
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def forward_native(
@@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
                dt_limit=(0.0, float("inf")),
                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
                                                self.head_dim),
            )
                state_dtype=ssm_state.dtype)

            # update ssm states
            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
        # 5. Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=self.intermediate_size,

@@ -1,6 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union

import torch

from vllm.config import MambaDType, ModelDType
from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype


class MambaStateDtypeCalculator:

    @classmethod
    def linear_attention_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        # TODO (tdoublep) requires testing
        if mamba_cache_dtype == "float32":
            raise ValueError("fp32 state for minimax is not yet supported")
        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (state_dtype, )

    @classmethod
    def mamba1_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        # TODO (tdoublep) requires kernel changes
        if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
            raise ValueError("fp32 state for mamba1 is not yet supported")
        else:
            return MambaStateDtypeCalculator.mamba2_state_dtype(
                model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)

    @classmethod
    def mamba2_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
                                                    model_dtype)
        if mamba_ssm_cache_dtype == "auto":
            temporal_state_dtype = conv_state_dtype
        else:
            temporal_state_dtype = (
                STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])

        return (conv_state_dtype, temporal_state_dtype)


class MambaStateShapeCalculator:

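Given the model dtype and the two config knobs, the calculator returns one dtype per cached state. A short usage sketch; the "auto" branch resolving to the model dtype follows get_kv_cache_torch_dtype's existing behaviour, which is assumed here:

# Usage sketch for the new calculator: bf16 model, only the ssm state forced to fp32.
import torch
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateDtypeCalculator

conv_dtype, ssm_dtype = MambaStateDtypeCalculator.mamba2_state_dtype(
    model_dtype=torch.bfloat16,
    mamba_cache_dtype="auto",          # conv state follows the model dtype
    mamba_ssm_cache_dtype="float32",   # ssm state promoted to fp32
)
assert (conv_dtype, ssm_dtype) == (torch.bfloat16, torch.float32)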
@@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
                                   cu_seqlens=None,
                                   dt_softplus=False,
                                   dt_limit=(0.0, float("inf")),
                                   state_dtype=None,
                                   out=None):
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
    batch, seqlen, nheads, headdim = x.shape
@@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
        if initial_states is not None else None,
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
        is_cont_batched=cu_seqlens is not None)
    states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                            for t in [states, final_states])
@@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
                              dt_limit=(0.0, float("inf")),
                              out=None,
                              return_final_states=False,
                              return_varlen_states=False):
                              return_varlen_states=False,
                              state_dtype=None):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
@@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
        dt_softplus: Whether to apply softplus to dt
        out: Preallocated output tensor
        state_dtype: The data type of the ssm state
    """

    if not return_varlen_states:
@@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
        cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        out=out)
        out=out,
        state_dtype=state_dtype)
    if not return_varlen_states:
        if not return_final_states:
            return

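Passing state_dtype=torch.float32 makes the chunked scan write its inter-chunk states in fp32 instead of the activation dtype. The motivation is numerical: the SSM state is an exponentially decayed running sum, and accumulating it in bf16 discards low-order bits. A toy, kernel-free illustration of that effect (not the Triton kernel above):

# Toy illustration: accumulate a decayed sum in bf16 vs. fp32 and compare
# against a float64 reference.
import torch

torch.manual_seed(0)
decay, steps = 0.99, 4096
inputs = torch.randn(steps, dtype=torch.float64)

ref = torch.zeros((), dtype=torch.float64)
s_bf16 = torch.zeros((), dtype=torch.bfloat16)
s_fp32 = torch.zeros((), dtype=torch.float32)
for x in inputs:
    ref = decay * ref + x
    s_bf16 = decay * s_bf16 + x.to(torch.bfloat16)
    s_fp32 = decay * s_fp32 + x.to(torch.float32)

print("bf16 error:", abs(s_bf16.double() - ref).item())
print("fp32 error:", abs(s_fp32.double() - ref).item())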
@@ -12,7 +12,7 @@ from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -83,6 +83,7 @@ class BambaMixerDecoderLayer(nn.Module):
    def __init__(self,
                 config: BambaConfig,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
@@ -100,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module):
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer")

@@ -138,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module):
        self,
        config: BambaConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -266,6 +270,7 @@ class BambaModel(nn.Module):
        super().__init__()

        config: BambaConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -289,6 +294,7 @@ class BambaModel(nn.Module):
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
@@ -437,6 +443,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -528,10 +546,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_mamba_layers,
                                                 *mamba_state_shape)
                                                 *mamba_state_shape,
                                                 *mamba_state_dtype)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -318,7 +318,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtype=kv_cache_dtype,
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

@@ -11,7 +11,7 @@ from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
    def __init__(
        self,
        config: FalconH1Config,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -108,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            use_rms_norm=config.mamba_rms_norm,
            prefix=f"{prefix}.mixer",
@@ -317,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module):
        self,
        config: FalconH1Config,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -339,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module):
        # Instantiate the SSM branch
        self.mamba = FalconH1SSMDecoderLayer(
            config=config,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=ssm_prefix,
@@ -408,6 +413,7 @@ class FalconH1Model(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: FalconH1Config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -435,6 +441,7 @@ class FalconH1Model(nn.Module):
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
@@ -519,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -624,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(
                self.vllm_config,
                self.lm_head.weight.dtype if hasattr(
                    self.lm_head, 'weight') else torch.bfloat16,
                self.config.num_hidden_layers,
                *mamba_state_shape,
                *mamba_state_dtype,
            )
        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -50,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
    def __init__(self,
                 config: GraniteMoeHybridConfig,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
@@ -70,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer")

@@ -137,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -217,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module):
    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -316,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -340,6 +346,7 @@ class GraniteMoeHybridModel(nn.Module):
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
@@ -527,6 +534,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -625,10 +644,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.model_config.dtype,
                                                 num_mamba_layers,
                                                 *mamba_state_shape)
                                                 *mamba_state_shape,
                                                 *mamba_state_dtype)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -10,7 +10,7 @@ from transformers import JambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -94,6 +94,7 @@ class JambaMambaDecoderLayer(nn.Module):
    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
@@ -114,6 +115,8 @@ class JambaMambaDecoderLayer(nn.Module):
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            is_lora_enabled = self.is_lora_enabled,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.mixer",
        )

@@ -164,6 +167,7 @@ class JambaAttentionDecoderLayer(nn.Module):
    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
@@ -280,6 +284,7 @@ class JambaModel(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -304,6 +309,7 @@ class JambaModel(nn.Module):
                config.layers_block_type[layer_idx]]
            return layer_class(config,
                               layer_idx,
                               model_config,
                               cache_config,
                               quant_config=quant_config,
                               prefix=prefix,
@@ -520,9 +526,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            state_shape = self.get_mamba_state_shape_from_config(
                self.vllm_config)
            state_dtype = self.get_mamba_state_dtype_from_config(
                self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_layers, *state_shape)
                                                 num_layers, *state_shape,
                                                 *state_dtype)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -537,6 +545,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,

@@ -9,13 +9,13 @@ from torch import nn
from transformers import MambaConfig

from vllm import envs
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module):

    def __init__(self,
                 config: MambaConfig,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
@@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module):
            rms_norm_eps=mixer_rms_eps,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
            model_config=model_config,
            cache_config=cache_config,
            prefix=f"{prefix}.mixer")

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -88,6 +91,7 @@ class MambaModel(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -108,6 +112,7 @@ class MambaModel(nn.Module):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MambaDecoderLayer(config,
                                             model_config=model_config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
                                             is_lora_enabled=is_lora_enabled,
@@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            state_shape = self.get_mamba_state_shape_from_config(
                self.vllm_config)
            state_dtype = self.get_mamba_state_dtype_from_config(
                self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_layers, *state_shape)
                                                 num_layers, *state_shape,
                                                 *state_dtype)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba1_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,

@@ -11,7 +11,7 @@ from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -45,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module):

    def __init__(self,
                 config: MambaConfig,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
@@ -62,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module):
            head_dim=config.head_dim,
            rms_norm_eps=config.layer_norm_epsilon,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer")

@@ -93,6 +97,8 @@ class Mamba2Model(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
@@ -112,8 +118,11 @@ class Mamba2Model(nn.Module):

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(
                config, quant_config=quant_config, prefix=prefix),
            lambda prefix: Mamba2DecoderLayer(config,
                                              model_config=model_config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
@@ -200,6 +209,18 @@ class Mamba2Model(nn.Module):

class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -290,10 +311,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_mamba_layers,
                                                 *mamba_state_shape)
                                                 *mamba_state_shape,
                                                 *mamba_state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        else:

@@ -24,9 +24,14 @@ class MambaCacheParams:

class MambaCacheManager(ConstantSizeCache):

    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
                 num_mamba_layers: int, conv_state_shape: tuple[int, int],
                 temporal_state_shape: tuple[int, int]):
    def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
                 conv_state_shape: tuple[int, int],
                 temporal_state_shape: tuple[int, int],
                 conv_state_dtype: torch.dtype,
                 temporal_state_dtype: torch.dtype):

        self.conv_state_dtype = conv_state_dtype
        self.temporal_state_dtype = temporal_state_dtype

        # Determine max batch size to set size of MambaCache
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
@@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache):
        assert conv_state_shape[0] > conv_state_shape[1]
        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                 (conv_state_shape[1], conv_state_shape[0]),
                                 dtype=dtype,
                                 dtype=self.conv_state_dtype,
                                 device="cuda").transpose(-1, -2)
        temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                     temporal_state_shape,
                                     dtype=dtype,
                                     dtype=self.temporal_state_dtype,
                                     device="cuda")

        self._mamba_cache = (conv_state, temporal_state)

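With the new signature the V0 cache manager allocates each state in its own dtype rather than one shared dtype. A minimal, stand-alone sketch of that allocation pattern (plain torch, CPU tensors with illustrative sizes so it runs anywhere; the real code allocates on "cuda" and takes its sizes from the scheduler config):

# Sketch of per-state-dtype allocation mirroring MambaCacheManager.__init__.
import torch

num_mamba_layers, max_batch_size = 4, 8
conv_state_shape = (1024, 3)           # illustrative (dim, kernel_size - 1)
temporal_state_shape = (16, 64, 128)   # illustrative (heads, head_dim, d_state)
conv_state_dtype, temporal_state_dtype = torch.bfloat16, torch.float32

conv_state = torch.empty(
    (num_mamba_layers, max_batch_size, conv_state_shape[1], conv_state_shape[0]),
    dtype=conv_state_dtype).transpose(-1, -2)
temporal_state = torch.empty(
    (num_mamba_layers, max_batch_size) + temporal_state_shape,
    dtype=temporal_state_dtype)

print(conv_state.shape, conv_state.dtype)        # bf16 conv state
print(temporal_state.shape, temporal_state.dtype)  # fp32 ssm state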
@@ -16,7 +16,8 @@ from transformers import MiniMaxConfig

from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank,
@@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_state_dtype(self) -> tuple[torch.dtype]:
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.num_heads,
@@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                 max_position: int,
                 block_size: int,
                 num_hidden_layer: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 layer_idx: int = 0,
                 linear_layer_idx: int = 0,
@@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        self.qkv_proj = ColumnParallelLinear(
@@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
    def __init__(
        self,
        config: MiniMaxConfig,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        expert_num: int = 1,
@@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
            max_position=max_position_embeddings,
            block_size=config.block if hasattr(config, "block") else 256,
            num_hidden_layer=config.num_hidden_layers,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            linear_layer_idx=linear_layer_id,
@@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module):
    def __init__(
        self,
        config: MiniMaxConfig,
        model_config: Optional[ModelConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        scheduler_config=None,
@@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module):
            decoder_kwargs = {
                "quant_config": quant_config,
                "layer_id": layer_idx,
                "model_config": model_config,
                "cache_config": cache_config
            }

@@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
        self.config.max_model_len = vllm_config.model_config.max_model_len
        self.model = MiniMaxText01Model(
            self.config,
            quant_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            scheduler_config=vllm_config.scheduler_config,
            prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
@@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
                    load_basic_weight(name, loaded_weight, self)
        return loaded_params

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,

@@ -26,7 +26,7 @@ from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -149,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
            head_dim=config.mamba_head_dim,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )
@@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module):
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
@@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            model_config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.mixer",
@@ -317,6 +324,7 @@ class NemotronHModel(nn.Module):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -340,6 +348,7 @@ class NemotronHModel(nn.Module):
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
@@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_mamba_layers,
                                                 *mamba_state_shape)
                                                 *mamba_state_shape,
                                                 *mamba_state_dtype)

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -18,7 +18,7 @@ from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module):

    def __init__(self,
                 config: Zamba2Config,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Initialize the Mamba decoder layer.
@@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
            config.n_mamba_heads,
            rms_norm_eps=config.rms_norm_eps,
            activation="silu",
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer")

@@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module):
        shared_transformer: Zamba2AttentionDecoderLayer,
        config: Zamba2Config,
        block_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
@@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module):
            bias=False,
            quant_config=quant_config)
        self.mamba_decoder = Zamba2MambaDecoderLayer(config,
                                                     model_config=model_config,
                                                     cache_config=cache_config,
                                                     quant_config=quant_config,
                                                     prefix=prefix)

@@ -669,6 +677,7 @@ class Zamba2Model(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
@@ -718,11 +727,15 @@ class Zamba2Model(nn.Module):
                    Zamba2HybridLayer(block,
                                      config,
                                      block_idx,
                                      quant_config,
                                      model_config=model_config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix))
            else:
                layers.append(
                    Zamba2MambaDecoderLayer(config,
                                            model_config=model_config,
                                            cache_config=cache_config,
                                            quant_config=quant_config,
                                            prefix=prefix))
        self.layers = nn.ModuleList(layers)
@@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
            "1.weight": "B.weight",
        })

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
@@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
            mamba_state_shape = \
                self.get_mamba_state_shape_from_config(
                    self.vllm_config, use_v1=False)
            mamba_state_dtype = \
                self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 self.lm_head.weight.dtype,
                                                 num_mamba_layers,
                                                 *mamba_state_shape)
                                                 *mamba_state_shape,
                                                 *mamba_state_dtype)

            # Get cache parameters for current run
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -173,6 +173,7 @@ CYAN = '\033[1;36m'
RESET = '\033[0;0m'

STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,

@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype
    dtypes: tuple[torch.dtype]
    page_size_padded: Optional[int] = None
    mamba_type: str = "mamba2"

    @property
    def page_size_bytes(self) -> int:
        num_elements = sum(prod(shape) for shape in self.shapes)
        page_size = num_elements * get_dtype_size(self.dtype)
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes))
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded

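page_size_bytes now sums each state's own footprint instead of assuming one dtype for every state. A worked example (the shapes are illustrative stand-ins for a single Mamba2 layer's conv and ssm states):

# Sketch of the new per-dtype page-size computation in MambaSpec.page_size_bytes.
import torch
from math import prod

shapes = ((1024, 3), (16, 64, 128))       # conv state, ssm state (illustrative)
dtypes = (torch.bfloat16, torch.float32)  # bf16 conv state, fp32 ssm state

def dtype_size(dt: torch.dtype) -> int:
    return torch.empty((), dtype=dt).element_size()

page_size = sum(prod(shape) * dtype_size(dtype)
                for shape, dtype in zip(shapes, dtypes))
print(page_size)  # 1024*3*2 + 16*64*128*4 = 530432 bytes per block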
@@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            elif isinstance(kv_cache_spec, MambaSpec):
                has_mamba = True
                raw_tensor = kv_cache_raw_tensors[layer_name]
                dtype = kv_cache_spec.dtype
                num_element_per_page = (kv_cache_spec.page_size_bytes //
                                        get_dtype_size(dtype))
                state_tensors = []
                storage_offset = 0
                for shape in kv_cache_spec.shapes:
                storage_offset_bytes = 0
                for (shape, dtype) in zip(kv_cache_spec.shapes,
                                          kv_cache_spec.dtypes):
                    dtype_size = get_dtype_size(dtype)
                    num_element_per_page = (
                        kv_cache_spec.page_size_bytes // dtype_size)
                    target_shape = (num_blocks, *shape)
                    stride = torch.empty(target_shape).stride()
                    target_stride = (num_element_per_page, *stride[1:])
                    assert storage_offset_bytes % dtype_size == 0
                    tensor = torch.as_strided(
                        raw_tensor.view(dtype),
                        size=target_shape,
                        stride=target_stride,
                        storage_offset=storage_offset,
                        storage_offset=storage_offset_bytes // dtype_size,
                    )
                    state_tensors.append(tensor)
                    storage_offset += stride[0]
                    storage_offset_bytes += stride[0] * dtype_size

                kv_caches[layer_name] = state_tensors
            else:
@@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        for layer_name, mamba_module in mamba_layers.items():
            kv_cache_spec[layer_name] = MambaSpec(
                shapes=mamba_module.get_state_shape(),
                dtype=self.kv_cache_dtype,
                dtypes=mamba_module.get_state_dtype(),
                block_size=max_model_len,
                page_size_padded=page_size_padded,
                mamba_type=mamba_module.mamba_type)
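Because the two states can now have different element sizes, the runner tracks the running offset in bytes and converts it to an element offset for each as_strided view. A self-contained toy version of that carving logic (tiny shapes, CPU byte buffer; raw stands in for the per-layer raw KV-cache tensor, and the shapes/dtypes are illustrative only):

# Toy sketch: slice per-state views with distinct dtypes out of one byte buffer.
import torch
from math import prod

num_blocks = 2
shapes = ((4, 3), (2, 5))                   # per-block conv and ssm state shapes
dtypes = (torch.bfloat16, torch.float32)    # mixed dtypes, as in the diff

def dtype_size(dt: torch.dtype) -> int:
    return torch.empty((), dtype=dt).element_size()

page_size_bytes = sum(prod(s) * dtype_size(d) for s, d in zip(shapes, dtypes))
raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)

state_tensors, offset_bytes = [], 0
for shape, dtype in zip(shapes, dtypes):
    elems_per_page = page_size_bytes // dtype_size(dtype)     # block stride
    target_shape = (num_blocks, *shape)
    stride = torch.empty(target_shape).stride()
    assert offset_bytes % dtype_size(dtype) == 0
    view = torch.as_strided(raw.view(dtype),
                            size=target_shape,
                            stride=(elems_per_page, *stride[1:]),
                            storage_offset=offset_bytes // dtype_size(dtype))
    state_tensors.append(view)
    offset_bytes += stride[0] * dtype_size(dtype)              # advance in bytes

print([t.shape for t in state_tensors], [t.dtype for t in state_tensors])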