Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 03:54:59 +08:00)
[V1] - Split Prefill and Decode for Mamba1 models (#22653)

Signed-off-by: amirk <amirk@ai21.com>
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>

parent 5406ebf5c9
commit fe91ce9591
@@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
+# Once we add support for FCG in Mamba1, this list will be removed and
+# all test cases will use enforce_eager=False
+ENFORCE_EAGER_MODELS_V1 = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
+]
+
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
         example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        enforce_eager = False
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
                 # required due to reorder_batch behaviour
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+
+            if model in ENFORCE_EAGER_MODELS_V1:
+                enforce_eager = True
+
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=enforce_eager,
                              enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import NamedTuple, Optional
 
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +155,38 @@ class MambaMixer(MambaBase, CustomOp):
 
         self.prefix = prefix
 
+    def _ssm_transform(
+            self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.is_lora_enabled:
+            # Lora kernel requires contiguous tensor.
+            ssm_params = self.x_proj(x.contiguous())[0]
+        else:
+            ssm_params = self.x_proj(x)[0]
+        time_step, B, C = torch.split(
+            ssm_params,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1)
+        if self.use_rms_norm:
+            assert self.dt_layernorm is not None
+            assert self.b_layernorm is not None
+            assert self.c_layernorm is not None
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
+        return discrete_time_step, B, C
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
             return CustomOp.forward(self, hidden_states, mamba_cache_params)
         else:
-            return self.forward_cuda(hidden_states, mamba_cache_params)
+            return self.forward_cuda(
+                hidden_states,
+                mamba_cache_params,
+            )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
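The `_ssm_transform` helper added above centralizes the x_proj projection, the time_step/B/C split, the optional RMS norms, and the dt_proj step that the prefill and decode branches now share. A toy sketch of the split step in plain torch (hypothetical sizes; not part of the diff):

import torch

time_step_rank, ssm_state_size, seq_len = 6, 16, 4
# x_proj emits time_step, B and C concatenated along the last dimension.
ssm_params = torch.randn(seq_len, time_step_rank + 2 * ssm_state_size)
time_step, B, C = torch.split(
    ssm_params, [time_step_rank, ssm_state_size, ssm_state_size], dim=-1)
print(time_step.shape, B.shape, C.shape)
# torch.Size([4, 6]) torch.Size([4, 16]) torch.Size([4, 16])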
@@ -170,6 +196,27 @@ class MambaMixer(MambaBase, CustomOp):
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
+        """
+        Run the Mamba-1 SSM pipeline.
+
+        Steps
+        -----
+        1. Apply the gated-MLP linear projection to the raw input.
+        2. Pass the projected sequence through the convolutional mixing layer.
+        3. Feed the result into the State-Space Model (SSM) blocks.
+        4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+           to produce contextual representations.
+        5. Project the contextualised sequence back
+           to the output embedding dimension.
+
+        Batch handling
+        --------------
+        Prefill and decode tokens are processed by dedicated CUDA
+        kernels for both the convolutional (conv1d) and SSM stages.
+        In the case of a mixed batch (containing both prefill and
+        decode tokens), both sets of kernels are executed independently
+        and their outputs are concatenated before the final output projection.
+        """
 
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
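For step 4 of the docstring above, the recurrence y ← SSM(A, B, C, Δ)(x) is the discretized selective scan that `selective_scan_fn` and `selective_state_update` implement as fused CUDA kernels. A minimal per-token reference in plain torch, illustrative only (hypothetical shapes; the real kernels also handle gating, state caching, and batching):

import torch

def ssm_scan_reference(x, dt, A, B, C, D):
    # x:  (d, L) input sequence; dt: (d, L) step sizes (after softplus)
    # A:  (d, n) diagonal state matrix per channel; B, C: (n, L); D: (d,)
    d, L = x.shape
    h = torch.zeros(d, A.shape[1])
    ys = []
    for t in range(L):
        dA = torch.exp(dt[:, t, None] * A)       # ZOH-style discretization of A
        dB = dt[:, t, None] * B[None, :, t]      # Euler discretization of B
        h = dA * h + dB * x[:, t, None]          # recurrent state update
        ys.append((h * C[None, :, t]).sum(-1) + D * x[:, t])
    return torch.stack(ys, dim=-1)               # (d, L)

d, n, L = 4, 8, 6
y = ssm_scan_reference(torch.randn(d, L), torch.rand(d, L),
                       -torch.rand(d, n), torch.randn(n, L),
                       torch.randn(n, L), torch.randn(d))
print(y.shape)  # torch.Size([4, 6])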
@@ -185,126 +232,142 @@ class MambaMixer(MambaBase, CustomOp):
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
-                has_initial_state = mamba1_metadata.has_initial_states
-                context_lens_tensor = mamba1_metadata.context_lens_tensor
+                has_initial_states = mamba1_metadata.has_initial_states
         else:
+            assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
             query_start_loc = attn_metadata.query_start_loc
             context_lens_tensor = attn_metadata.context_lens_tensor
+            has_initial_states = None
             if context_lens_tensor is not None:
-                has_initial_state = context_lens_tensor > 0
+                has_initial_states = context_lens_tensor > 0
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
-        hidden_states, gate = projected_states.chunk(2, dim=-2)
+        hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
         if envs.VLLM_USE_V1 and attn_metadata is None:
             # V1 profile run
-            hidden_states = hidden_states.contiguous()
-            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+            hidden_states_BC = hidden_states_BC.contiguous()
+            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
 
-        if query_start_loc is not None and context_lens_tensor is not None:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-            hidden_states = causal_conv1d_fn(
-                hidden_states,
-                conv_weights,
-                bias=self.conv1d.bias,
-                activation=self.activation,
-                conv_states=conv_state,
-                has_initial_state=has_initial_state,
-                cache_indices=state_indices_tensor,
-                query_start_loc=query_start_loc)
-        else:
-            hidden_states = causal_conv1d_update(
-                hidden_states.transpose(0, 1),
-                conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=state_indices_tensor)
-            hidden_states = hidden_states.transpose(0, 1)
-
-        # 3. State Space Model sequence transformation
-        # 3.a. input varying initialization of time_step, B and C
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            ssm_parameters = self.x_proj(
-                hidden_states.transpose(-2, -1).contiguous())[0]
-        else:
-            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
-
-        time_step, B, C = torch.split(
-            ssm_parameters,
-            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
-            dim=-1,
-        )
-        if self.use_rms_norm:
-            assert self.dt_layernorm is not None
-            assert self.b_layernorm is not None
-            assert self.c_layernorm is not None
-            time_step = self.dt_layernorm(time_step.contiguous())
-            B = self.b_layernorm(B.contiguous())
-            C = self.c_layernorm(C.contiguous())
-
-        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
-        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
-            self.dt_proj, "bias") else None)
-
-        if query_start_loc is not None and context_lens_tensor is not None:
-            scan_outputs = selective_scan_fn(
-                hidden_states,
-                ssm_state,
-                discrete_time_step,
-                self.A,
-                B.transpose(-2, -1),
-                C.transpose(-2, -1),
-                self.D.float(),
-                gate,
-                time_proj_bias,
-                delta_softplus=True,
-                cache_indices=state_indices_tensor,
-                has_initial_state=has_initial_state,
-                query_start_loc=query_start_loc)
-        else:
-            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
-            selective_state_update(ssm_state,
-                                   hidden_states.transpose(0, 1),
-                                   discrete_time_step.transpose(0, 1),
-                                   self.A,
-                                   B,
-                                   C,
-                                   self.D,
-                                   gate.transpose(0, 1),
-                                   time_proj_bias,
-                                   dt_softplus=True,
-                                   state_batch_indices=state_indices_tensor,
-                                   out=scan_outputs)
-            scan_outputs = scan_outputs.transpose(0, 1)
-
-        # 4. Final linear projection
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1).contiguous())[0]
-        else:
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1))[0]
-        return contextualized_states
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        has_prefill = num_prefill_tokens > 0
+        has_decode = num_decode_tokens > 0
+
+        prefill_decode_split = split_batch_to_prefill_and_decode(
+            hidden_states_BC,
+            gate,
+            state_indices_tensor,
+            query_start_loc,
+            has_initial_states,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+            num_decodes,
+        )
+        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
+        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
+        gate_p = prefill_decode_split.gate_p
+        gate_d = prefill_decode_split.gate_d
+        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
+        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
+        query_start_loc_p = prefill_decode_split.query_start_loc_p
+        has_initial_states_p = prefill_decode_split.has_initial_states_p
+
+        ssm_outputs = []
+
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            conv_out_p = causal_conv1d_fn(
+                hidden_states_BC_p,
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states_p,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p)
+
+            # 3. State Space Model sequence transformations.
+            discrete_time_step_p, B_p, C_p = self._ssm_transform(
+                conv_out_p.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_out_p = selective_scan_fn(
+                conv_out_p,
+                ssm_state,
+                discrete_time_step_p,
+                self.A,
+                B_p.transpose(-2, -1),
+                C_p.transpose(-2, -1),
+                self.D.float(),
+                gate_p,
+                time_proj_bias,
+                delta_softplus=True,
+                cache_indices=state_indices_tensor_p,
+                has_initial_state=has_initial_states_p,
+                query_start_loc=query_start_loc_p)
+            ssm_outputs.append(scan_out_p)
+
+        if has_decode:
+            # 2. Convolution sequence transformation
+            conv_out_d = causal_conv1d_update(
+                hidden_states_BC_d.transpose(0, 1),
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d).transpose(0, 1)
+
+            # 3. State Space Model sequence transformation.
+            discrete_time_step_d, B_d, C_d = self._ssm_transform(
+                conv_out_d.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_outputs_d = torch.empty_like(
+                hidden_states_BC_d.transpose(0, 1))
+            selective_state_update(ssm_state,
+                                   conv_out_d.transpose(0, 1),
+                                   discrete_time_step_d.transpose(0, 1),
+                                   self.A,
+                                   B_d,
+                                   C_d,
+                                   self.D,
+                                   gate_d.transpose(0, 1),
+                                   time_proj_bias,
+                                   dt_softplus=True,
+                                   state_batch_indices=state_indices_tensor_d,
+                                   out=scan_outputs_d)
+            scan_outputs_d = scan_outputs_d.transpose(0, 1)
+
+            if envs.VLLM_USE_V1:
+                ssm_outputs.insert(0, scan_outputs_d)
+            else:
+                ssm_outputs.append(scan_outputs_d)
+
+        scan_outputs_combined = ssm_outputs[0] if len(
+            ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
+
+        # 5. Final output projection
+        if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
+            scan_outputs_combined = scan_outputs_combined.transpose(
+                -2, -1).contiguous()
+            out = self.out_proj(scan_outputs_combined)[0]
+        else:
+            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
+
+        return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
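The mixed-batch handling above reduces to: split the token dimension into a decode slice and a prefill slice, run each through its dedicated kernel path, then re-join the partial outputs in the original token order before out_proj. In V1 the reordered batch puts decode tokens first, hence `ssm_outputs.insert(0, scan_outputs_d)`. A toy sketch of that ordering invariant, with stand-in arithmetic instead of the real kernels:

import torch

num_decode_tokens, num_prefill_tokens, d = 3, 5, 8
tokens = torch.randn(d, num_decode_tokens + num_prefill_tokens)

# V1 layout: [decode tokens | prefill tokens] along the sequence dim.
tok_d, tok_p = torch.split(tokens, [num_decode_tokens, num_prefill_tokens],
                           dim=-1)

out_p = tok_p * 2.0  # stand-in for causal_conv1d_fn + selective_scan_fn
out_d = tok_d * 2.0  # stand-in for causal_conv1d_update + selective_state_update

outputs = [out_p]
outputs.insert(0, out_d)                # decode-first, matching the input order
combined = torch.cat(outputs, dim=-1)   # same token layout as `tokens`
assert torch.equal(combined, tokens * 2.0)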
@@ -317,3 +380,69 @@ class MambaMixer(MambaBase, CustomOp):
     @property
     def mamba_type(self) -> str:
         return "mamba1"
+
+    def _time_proj_bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
+            return self.dt_proj.bias.float()
+        return None
+
+
+class PrefillDecodeSplit(NamedTuple):
+    hidden_states_BC_p: torch.Tensor
+    hidden_states_BC_d: torch.Tensor
+    gate_p: torch.Tensor
+    gate_d: torch.Tensor
+    state_indices_tensor_p: torch.Tensor
+    state_indices_tensor_d: torch.Tensor
+    query_start_loc_p: Optional[torch.Tensor]
+    has_initial_states_p: Optional[torch.Tensor]
+
+
+def split_batch_to_prefill_and_decode(
+    hidden_states_BC: torch.Tensor,
+    gate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    has_initial_states: Optional[torch.Tensor],
+    num_prefill_tokens: int,
+    num_decode_tokens: int,
+    num_prefills: int,
+    num_decodes: int,
+) -> PrefillDecodeSplit:
+    if envs.VLLM_USE_V1:
+        # In v1, decode tokens come first, then prefill tokens.
+        hidden_states_BC_d, hidden_states_BC_p = torch.split(
+            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
+        gate_d, gate_p = torch.split(gate,
+                                     [num_decode_tokens, num_prefill_tokens],
+                                     dim=-1)
+        state_indices_tensor_d, state_indices_tensor_p = torch.split(
+            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
+                             num_decodes if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[-num_prefills:] if (
+            has_initial_states is not None and num_prefills > 0) else None
+    else:
+        # In v0, prefill tokens come first, then decode tokens.
+        hidden_states_BC_p, hidden_states_BC_d = torch.split(
+            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
+        gate_p, gate_d = torch.split(gate,
+                                     [num_prefill_tokens, num_decode_tokens],
+                                     dim=-1)
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            state_indices_tensor, [num_prefills, num_decodes], dim=0)
+        query_start_loc_p = (query_start_loc[:num_prefills +
+                                             1] if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[:num_prefills] if (
+            has_initial_states is not None and num_prefills > 0) else None
+
+    return PrefillDecodeSplit(
+        hidden_states_BC_p=hidden_states_BC_p,
+        hidden_states_BC_d=hidden_states_BC_d,
+        gate_p=gate_p,
+        gate_d=gate_d,
+        state_indices_tensor_p=state_indices_tensor_p,
+        state_indices_tensor_d=state_indices_tensor_d,
+        query_start_loc_p=query_start_loc_p,
+        has_initial_states_p=has_initial_states_p,
+    )
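One subtle detail in `split_batch_to_prefill_and_decode` is the `query_start_loc` rebase for the V1 layout: the prefill boundaries are the tail of the cumulative array, shifted down by the number of decode tokens so that the prefill block is 0-based for the prefill kernels. A toy illustration with hypothetical values:

import torch

# 2 decode requests (1 token each) followed by 2 prefill requests (4 and 3 tokens).
query_start_loc = torch.tensor([0, 1, 2, 6, 9])
num_decodes, num_prefills = 2, 2

# Keep the last num_prefills + 1 boundaries and shift so the prefill
# token block starts at offset 0.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
print(query_start_loc_p)  # tensor([0, 4, 7])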
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
     query_start_loc: torch.Tensor
     context_lens_tensor: torch.Tensor
     state_indices_tensor: torch.Tensor
-    has_initial_states: torch.Tensor
+    has_initial_states: Optional[torch.Tensor]
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
 
 
 class Mamba1AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba1AttentionMetadata]):
 
     reorder_batch_threshold: ClassVar[int] = 1
 
     def __init__(
|
|||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
|
context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||||
query_start_loc.device)
|
query_start_loc.device)
|
||||||
has_initial_states = (context_lens_tensor > 0)
|
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
split_decodes_and_prefills(common_attn_metadata,
|
||||||
|
decode_threshold=1))
|
||||||
|
|
||||||
|
has_initial_states = None
|
||||||
|
|
||||||
|
if num_prefills > 0:
|
||||||
|
has_initial_states = context_lens_tensor > 0
|
||||||
|
|
||||||
return Mamba1AttentionMetadata(
|
return Mamba1AttentionMetadata(
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
context_lens_tensor=context_lens_tensor,
|
context_lens_tensor=context_lens_tensor,
|
||||||
has_initial_states=has_initial_states,
|
has_initial_states=has_initial_states,
|
||||||
state_indices_tensor=state_indices_tensor,
|
state_indices_tensor=state_indices_tensor,
|
||||||
|
num_prefills=num_prefills,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decodes=num_decodes,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
)
|
)
|
||||||
|
|||||||
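`split_decodes_and_prefills` (imported above) classifies the requests in the reordered V1 batch: with decode_threshold=1, any request contributing at most one token counts as a decode, and the resulting counts feed the new metadata fields. A rough, illustrative reimplementation of its semantics (not vLLM's actual code, which reads the counts from CommonAttentionMetadata):

import torch

def split_decodes_and_prefills_sketch(query_start_loc: torch.Tensor,
                                      decode_threshold: int = 1):
    # Per-request query lengths from the cumulative boundaries;
    # assumes the V1 reordered layout with decode requests first.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_decodes = int((query_lens <= decode_threshold).sum())
    num_prefills = int(query_lens.numel()) - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Two decodes (1 token each) followed by two prefills (4 and 3 tokens).
print(split_decodes_and_prefills_sketch(torch.tensor([0, 1, 2, 6, 9])))
# (2, 2, 2, 7)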