[Model] Add LongCat-Flash (#23991)

Signed-off-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
XuruiYang 2025-09-25 12:53:40 +08:00 committed by yewentao256
parent 12c21d28c1
commit c26e7b14d7
31 changed files with 1357 additions and 66 deletions

View File

@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel(
for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@ -95,12 +98,15 @@ template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
size_t numel) {
size_t numel, int32_t num_experts) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
}
});
}

View File

@ -428,6 +428,7 @@ th {
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ |✅︎ | ✅︎ |
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!

View File

@ -138,7 +138,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
@ -206,7 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,

View File

@ -273,6 +273,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
is_available_online=False),
"LongcatFlashForCausalLM": _HfExamplesInfo
("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
@ -639,6 +641,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.54",
is_available_online=False),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True,
speculative_model="meituan-longcat/LongCat-Flash-Chat"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),

View File

@ -428,9 +428,8 @@ def dummy_hf_overrides(
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)
text_config.update({
update_dict = {
"num_layers": num_layers,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
@ -440,7 +439,14 @@ def dummy_hf_overrides(
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
}
# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":
update_dict["num_hidden_layers"] = num_hidden_layers
text_config.update(update_dict)
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({

View File

@ -96,7 +96,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,

View File

@ -1131,7 +1131,8 @@ class ModelConfig:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
@ -1257,6 +1258,9 @@ class ModelConfig:
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
elif (self.hf_config.model_type == "longcat_flash_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 1)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)

View File

@ -31,7 +31,8 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
@config
@ -186,6 +187,13 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({
"n_predict": n_predict,
"architectures": ["LongCatFlashMTPModel"]
})
return hf_config
@ -332,6 +340,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
self.method = "longcat_flash_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
@ -548,7 +565,7 @@ class SpeculativeConfig:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
"qwen3_next_mtp", "longcat_flash_mtp")
def __repr__(self) -> str:
method = self.method

View File

@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
)
@triton.jit
def compute_identity_kernel(
top_k: int,
hidden_states_ptr: tl.tensor,
expert_scales_ptr: tl.tensor,
num_tokens: int,
output_ptr: tl.tensor,
hidden_dim: int,
scales_stride: int,
BLOCK_SIZE: tl.constexpr,
) -> None:
pid = tl.program_id(0)
batch_id = pid // (hidden_dim // BLOCK_SIZE)
dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
if batch_id >= num_tokens or dim_offset >= hidden_dim:
return
h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for i in range(top_k):
scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
result += h * scale
tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
result,
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
def zero_experts_compute_triton(expert_indices: torch.Tensor,
expert_scales: torch.Tensor, num_experts: int,
zero_expert_type: str,
hidden_states: torch.Tensor) -> torch.Tensor:
N = expert_indices.numel()
top_k = expert_indices.size(-1)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone()
zero_expert_scales[zero_expert_mask] = 0.0
normal_expert_mask = expert_indices >= num_experts
expert_indices[normal_expert_mask] = 0
expert_scales[normal_expert_mask] = 0.0
output = torch.zeros_like(hidden_states).to(hidden_states.device)
hidden_dim = hidden_states.size(-1)
num_tokens = hidden_states.size(0)
grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
compute_identity_kernel[grid](
top_k,
hidden_states,
zero_expert_scales,
num_tokens,
output,
hidden_dim,
zero_expert_scales.stride(0),
BLOCK_SIZE=256,
)
return output
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
N: int,
@ -940,6 +1010,25 @@ def fused_topk(
return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
sorted=False)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(

View File

@ -24,6 +24,8 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, biased_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import (
zero_experts_compute_triton)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
@ -548,7 +550,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts(
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
zero_expert_type = getattr(layer, 'zero_expert_type', None)
topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -565,11 +570,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type)
if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
return self.rocm_aiter_fused_experts(
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -591,7 +599,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if self.moe.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -605,7 +613,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
else:
assert fused_experts is not None
return fused_experts(
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -619,6 +627,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), \
"Shared + zero experts are mutually exclusive not yet supported"
return result, zero_expert_result
else:
return result
def forward_cpu(
self,
layer: torch.nn.Module,
@ -942,6 +957,8 @@ class FusedMoE(CustomOp):
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
):
super().__init__()
if params_dtype is None:
@ -976,6 +993,8 @@ class FusedMoE(CustomOp):
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
@ -1656,25 +1675,30 @@ class FusedMoE(CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
global_num_experts: Optional[int] = None,
zero_expert_num: Optional[int] = None,
zero_expert_type: Optional[str] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
The weights and *global physical* expert ids of the top-k experts.
(topk_weights, topk_ids, zero_expert_result)
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
The weights, expert ids, and zero expert computation result.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, fused_topk_bias)
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
return RoutingSimulator.simulate_routing(
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
@ -1697,6 +1721,16 @@ class FusedMoE(CustomOp):
e_score_correction_bias=e_score_correction_bias)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=e_score_correction_bias.data,
topk=top_k,
renormalize=renormalize,
)
if routed_scaling_factor is not None:
topk_weights *= routed_scaling_factor
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
@ -1729,7 +1763,20 @@ class FusedMoE(CustomOp):
assert topk_ids.dtype == indices_type or indices_type is None
return topk_weights, topk_ids
# Compute zero expert result if needed
if (zero_expert_num is not None and zero_expert_num > 0
and zero_expert_type is not None
and global_num_experts is not None):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
hidden_states=hidden_states,
)
else:
zero_expert_result = None
return topk_weights, topk_ids, zero_expert_result
def must_reduce_shared_expert_outputs(self) -> bool:
"""
@ -1878,6 +1925,11 @@ class FusedMoE(CustomOp):
assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)
if isinstance(final_hidden_states, tuple):
final_hidden_states, zero_expert_result = final_hidden_states
if zero_expert_result is not None:
final_hidden_states += zero_expert_result
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[
@ -1992,6 +2044,9 @@ class FusedMoE(CustomOp):
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
@ -2003,14 +2058,16 @@ class FusedMoE(CustomOp):
return states
if self.shared_experts is None:
assert not isinstance(final_hidden_states, tuple)
return reduce_output(final_hidden_states)
else:
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(

View File

@ -103,7 +103,6 @@ class MultiHeadLatentAttention(CustomOp):
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward_native(
self,

View File

@ -520,7 +520,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -486,7 +486,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -385,7 +385,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
"`CompressedTensorsW4A4MoeMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -934,7 +934,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for "
"`CompressedTensorsW8A8Fp8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -1195,7 +1195,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -1502,7 +1502,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -1747,7 +1747,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -146,7 +146,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -174,6 +176,10 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
@ -927,6 +933,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
@ -943,8 +950,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
@ -965,7 +971,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
result = apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@ -976,7 +982,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
topk_weights, topk_ids = FusedMoE.select_experts(
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
zero_expert_type = getattr(layer, 'zero_expert_type', None)
select_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -994,17 +1003,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
topk_weights, topk_ids, zero_expert_result = select_result
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
assert self.fused_experts is None
return rocm_aiter_fused_experts(
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
@ -1018,7 +1032,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
result = torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@ -1035,7 +1049,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
workspace=layer.workspace)
elif self.fused_experts:
return self.fused_experts(
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -1055,7 +1069,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
return flashinfer_cutlass_moe_fp8(
result = flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
@ -1068,7 +1082,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -1083,6 +1097,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), \
"Shared + zero experts are mutually exclusive not yet supported"
return result, zero_expert_result
else:
return result
class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@ -555,7 +555,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -669,7 +669,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -543,7 +543,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input)
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -1491,7 +1491,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)[0]
return out
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -332,7 +332,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -718,7 +718,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -783,7 +783,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -894,7 +894,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -329,7 +329,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@ -531,7 +531,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -318,7 +318,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,

View File

@ -292,6 +292,11 @@ def is_layer_skipped(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
elif "experts" in prefix:
return any([
prefix in layer_name for layer_name in ignored_layers
if "experts" in layer_name
])
else:
is_skipped = prefix in ignored_layers

View File

@ -0,0 +1,712 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Inference-only Flash model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
block_dequant)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class FlashConfig(PretrainedConfig):
"""Flash model configuration."""
model_type = "longcat_flash"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=131072,
hidden_size=4096,
intermediate_size=8192,
num_layers=28,
num_hidden_layers=None,
num_attention_heads=96,
num_key_value_heads=128,
ep_size=1,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
num_experts_per_tok=None,
norm_topk_prob=False,
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=1000000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mla_scale_q_lora=False,
mla_scale_kv_lora=False,
torch_dtype="bfloat16",
params_dtype="bfloat16",
router_dtype="float32",
router_bias=False,
topk_method=None,
routed_scaling_factor=None,
zero_expert_num=0,
zero_expert_type=None,
nextn_use_scmoe=False,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
torch_dtype=torch_dtype,
params_dtype=params_dtype,
router_dtype=router_dtype,
topk_method=topk_method,
router_bias=router_bias,
nextn_use_scmoe=nextn_use_scmoe,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = (num_hidden_layers if num_hidden_layers
is not None else num_layers)
self.num_attention_heads = num_attention_heads
self.ep_size = ep_size
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.num_experts_per_tok = num_experts_per_tok
self.norm_topk_prob = norm_topk_prob
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mla_scale_q_lora = mla_scale_q_lora
self.mla_scale_kv_lora = mla_scale_kv_lora
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
self.routed_scaling_factor = routed_scaling_factor
self.hidden_act = "silu"
self.intermediate_size = self.ffn_hidden_size if hasattr(
self, "ffn_hidden_size") else self.intermediate_size
if hasattr(self, "moe_intermediate_size"):
self.moe_intermediate_size = self.moe_intermediate_size
elif hasattr(self, "expert_ffn_hidden_size"):
self.moe_intermediate_size = self.expert_ffn_hidden_size
else:
self.moe_intermediate_size = self.intermediate_size
class FlashMLP(nn.Module):
"""Flash MLP layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0:
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LongcatRouter(nn.Module):
def __init__(self,
config,
zero_expert_num=0,
rounter_params_dtype=torch.bfloat16,
prefix: str = ""):
super().__init__()
self.n_routed_experts = config.n_routed_experts if hasattr(
config, "n_routed_experts") else config.num_experts[0]
self.n_routed_experts = self.n_routed_experts + zero_expert_num
self.classifier = ReplicatedLinear(
config.hidden_size,
self.n_routed_experts,
bias=config.router_bias,
params_dtype=rounter_params_dtype,
quant_config=None,
prefix=f"{prefix}.classifier",
)
self.e_score_correction_bias = nn.Parameter(
torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype))
def forward(self, hidden_states):
logits, _ = self.classifier(hidden_states)
return logits
class LongcatMoe(nn.Module):
def __init__(
self,
config: FlashConfig,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.zero_expert_num = config.zero_expert_num
self.zero_expert_type = config.zero_expert_type
self.routed_scaling_factor = config.routed_scaling_factor
self.enable_eplb = enable_eplb
# Gate always runs at half / full precision for now.
self.rounter_params_dtype = params_dtype
if config.router_dtype == "float32":
self.rounter_params_dtype = torch.float32
self.router = LongcatRouter(
config=config,
zero_expert_num=self.zero_expert_num,
rounter_params_dtype=self.rounter_params_dtype,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
params_dtype=params_dtype,
e_score_correction_bias=self.router.e_score_correction_bias,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
zero_expert_num=self.zero_expert_num,
zero_expert_type=self.zero_expert_type,
enable_eplb=self.enable_eplb,
routed_scaling_factor=config.routed_scaling_factor,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.router(hidden_states.to(
self.rounter_params_dtype))
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(num_tokens, hidden_dim)
class FlashDecoderLayer(nn.Module):
"""Flash decoder layer with dual attention and MLP structure."""
def __init__(
self,
config: FlashConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.layer_idx = int(prefix.split(sep='.')[-1])
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
# Dual attention structure
self.self_attn = nn.ModuleList([
DeepseekV2MLAAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=(config.q_lora_rank if hasattr(
config, "q_lora_rank") else None),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=None if "self_attn" in getattr(
config, "disable_quant_module", []) else quant_config,
prefix=f"{prefix}.self_attn.{i}",
) for i in range(2)
])
self.input_layernorm = nn.ModuleList([
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for i in range(2)
])
self.post_attention_layernorm = nn.ModuleList([
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for i in range(2)
])
# Dual MLP structure
self.mlps = nn.ModuleList([
FlashMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=None if "mlps" in getattr(
config, "disable_quant_module", []) else quant_config,
prefix=f"{prefix}.mlps.{i}",
) for i in range(2)
])
self.mlp = LongcatMoe(
config=config,
num_experts=config.n_routed_experts if hasattr(
config, "n_routed_experts") else
config.num_experts[self.layer_idx],
top_k=config.moe_topk
if hasattr(config, "moe_topk") else config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=(f"{prefix}.mlp"),
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm[0](hidden_states)
else:
hidden_states, residual = self.input_layernorm[0](hidden_states,
residual)
hidden_states = self.self_attn[0](
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm[0](
hidden_states, residual)
# moe
hidden_states_copy = hidden_states.clone()
moe_hidden_states = self.mlp(hidden_states_copy)
# first mlp
hidden_states = self.mlps[0](hidden_states)
hidden_states, residual = self.input_layernorm[1](hidden_states,
residual)
# second_attn
hidden_states = self.self_attn[1](
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm[1](
hidden_states, residual)
# second_mlp
hidden_states = self.mlps[1](hidden_states)
hidden_states = hidden_states + moe_hidden_states
return hidden_states, residual
@support_torch_compile
class FlashModel(nn.Module):
"""Flash model."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = getattr(config, "pad_token_id", None)
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: FlashDecoderLayer(
config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Flash model for causal language modeling."""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
config.intermediate_size = config.ffn_hidden_size if hasattr(
config, "ffn_hidden_size") else config.intermediate_size
self.lora_config = lora_config
self.quant_config = quant_config
self.model = FlashModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts if hasattr(
self.config, "n_routed_experts") else
self.config.num_experts[0],
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
expert_params_mapping = self.get_expert_mapping()
loaded_params: set[str] = set()
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp" in name and "mlps" not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if (name.endswith(".bias")
or name.endswith("_bias")) and name not in params_dict:
continue
# Skip mtp
if ".mtp." in name:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
# Skip mtp
if ".mtp." in name_mapped:
continue
if (name_mapped.endswith(".bias")
or name_mapped.endswith("_bias")
) and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name_mapped]
weight_loader = param.weight_loader
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip loading kv_scale from ckpts towards new design.
if name.endswith(".kv_scale") and name not in params_dict:
continue
# Skip mtp
if ".mtp." in name:
continue
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for layer_id in range(self.config.num_hidden_layers):
for i in range(2):
if isinstance(self.model.layers[layer_id], PPMissingLayer):
continue
self_attn = self.model.layers[layer_id].self_attn[i]
if hasattr(self.quant_config, "weight_block_size"
) and self_attn.kv_b_proj.weight.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
dtype = torch.get_default_dtype()
w = block_dequant(self_attn.kv_b_proj.weight,
self_attn.kv_b_proj.weight_scale_inv,
weight_block_size).to(dtype)
else:
w = self_attn.kv_b_proj.weight
w_kc, w_vc = w.unflatten(
0,
(-1,
self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
[self_attn.qk_nope_head_dim, self_attn.v_head_dim],
dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(
1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if self.config.mla_scale_q_lora:
self_attn.q_a_layernorm.weight.data *= (
self.config.hidden_size / self.config.q_lora_rank)**0.5
if self.config.mla_scale_kv_lora:
self_attn.kv_a_layernorm.weight.data *= (
self.config.hidden_size /
self.config.kv_lora_rank)**0.5
return loaded_params

View File

@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
block_dequant)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.longcat_flash import FlashConfig
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
from .interfaces import SupportsPP
from .utils import maybe_prefix
class LongCatMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = ReplicatedLinear(2 * config.hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix="eh_proj")
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states, _ = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
class LongCatMultiTokenPredictor(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
vllm_config.model_config.hf_config.intermediate_size \
= config.intermediate_size
self.mtp_start_layer_idx = config.num_hidden_layers * 2
self.num_mtp_layers = 1
self.layers = torch.nn.ModuleDict({
str(idx):
LongCatMultiTokenPredictorLayer(
config,
prefix=f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
quant_config=quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
class LongCatFlashMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
# LongCat MTP without MoE layers
vllm_config.model_config.hf_config.n_routed_experts = None
self.config = FlashConfig(
**vllm_config.model_config.hf_config.__dict__)
self.quant_config = None if "mtp" in getattr(
self.config, "disable_quant_module",
[]) else vllm_config.quant_config
self.model = LongCatMultiTokenPredictor(vllm_config=vllm_config,
quant_config=self.quant_config,
prefix=maybe_prefix(
prefix, "model"))
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=self.quant_config,
)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
new_to_old_names_mapping = {
"model.mtp.embed_tokens.weight":
"model.layers.0.embed_tokens.weight",
"model.mtp.layers.0.eh_proj.weight": "eh_proj.weight",
"model.mtp.layers.0.eh_proj.weight_scale_inv":
"eh_proj.weight_scale_inv",
"model.mtp.layers.0.enorm.m.weight": "enorm.weight",
"model.mtp.layers.0.hnorm.m.weight": "hnorm.weight",
"model.mtp.layers.0.input_layernorm.weight":
"model.layers.0.input_layernorm.weight",
"model.mtp.layers.0.post_attention_layernorm.weight":
"model.layers.0.post_attention_layernorm.weight",
"model.mtp.layers.0.self_attn.kv_a_layernorm.weight":
"model.layers.0.self_attn.kv_a_layernorm.weight",
"model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight":
"model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
"model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv":
"model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv",
"model.mtp.layers.0.self_attn.kv_b_proj.weight":
"model.layers.0.self_attn.kv_b_proj.weight",
"model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv":
"model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
"model.mtp.layers.0.self_attn.o_proj.weight":
"model.layers.0.self_attn.o_proj.weight",
"model.mtp.layers.0.self_attn.o_proj.weight_scale_inv":
"model.layers.0.self_attn.o_proj.weight_scale_inv",
"model.mtp.layers.0.self_attn.q_a_layernorm.weight":
"model.layers.0.self_attn.q_a_layernorm.weight",
"model.mtp.layers.0.self_attn.q_a_proj.weight":
"model.layers.0.self_attn.q_a_proj.weight",
"model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv":
"model.layers.0.self_attn.q_a_proj.weight_scale_inv",
"model.mtp.layers.0.self_attn.q_b_proj.weight":
"model.layers.0.self_attn.q_b_proj.weight",
"model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv":
"model.layers.0.self_attn.q_b_proj.weight_scale_inv",
"model.mtp.layers.0.transformer_layer.mlp.down_proj.weight":
"model.layers.0.mlp.down_proj.weight",
"model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv":
"model.layers.0.mlp.down_proj.weight_scale_inv",
"model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight":
"model.layers.0.mlp.gate_proj.weight",
"model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv":
"model.layers.0.mlp.gate_proj.weight_scale_inv",
"model.mtp.layers.0.transformer_layer.mlp.up_proj.weight":
"model.layers.0.mlp.up_proj.weight",
"model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv":
"model.layers.0.mlp.up_proj.weight_scale_inv",
"model.mtp.norm.weight": "final_layernorm.weight",
}
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = self.get_spec_layer_idx_from_weight_name(
self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name,
new_to_old_names_mapping)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj")
and name not in params_dict):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if (spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
spec_layer_id = self.config.num_hidden_layers * 2
self_attn = self.model.layers[str(spec_layer_id)].mtp_block.self_attn
if hasattr(
self.quant_config,
"weight_block_size") and self_attn.kv_b_proj.weight.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
dtype = torch.get_default_dtype()
w = block_dequant(self_attn.kv_b_proj.weight,
self_attn.kv_b_proj.weight_scale_inv,
weight_block_size).to(dtype)
else:
w = self_attn.kv_b_proj.weight
else:
w = self_attn.kv_b_proj.weight
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
[self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if self.config.mla_scale_q_lora:
self_attn.q_a_layernorm.weight.data *= (
self.config.hidden_size / self.config.q_lora_rank)**0.5
if self.config.mla_scale_kv_lora:
self_attn.kv_a_layernorm.weight.data *= (
self.config.hidden_size / self.config.kv_lora_rank)**0.5
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str,
new_to_old_names_mapping: dict) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
"""
if name in new_to_old_names_mapping:
name = new_to_old_names_mapping[name]
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
if name.startswith("enorm") or name.startswith(
"hnorm") or name.startswith("eh_proj") or name.startswith(
"final_layernorm"):
name = "model.layers." + str(spec_layer) + "." + name
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace("model.layers.0.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# treat shared weights as top level weights
name = name.replace("model.layers.0.", "model.")
return name
def get_spec_layer_idx_from_weight_name(self, config: PretrainedConfig,
weight_name: str) -> Optional[int]:
if "model.mtp" in weight_name:
return config.num_hidden_layers * 2
return None

View File

@ -109,6 +109,7 @@ _TEXT_GENERATION_MODELS = {
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
@ -287,6 +288,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),

View File

@ -691,14 +691,14 @@ def maybe_prefix(prefix: str, name: str) -> str:
return name if not prefix else f"{prefix}.{name}"
def extract_layer_index(layer_name: str) -> int:
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
"""
Extract the layer index from the module name.
Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError
- "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1
"""
subnames = layer_name.split(".")
int_vals: list[int] = []
@ -707,9 +707,17 @@ def extract_layer_index(layer_name: str) -> int:
int_vals.append(int(subname))
except ValueError:
continue
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]
if num_attn_module == 1 or "attn" not in layer_name:
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]
else:
assert len(int_vals) <= 2, (f"layer name {layer_name} should"
" contain most two integers")
layer_index = int_vals[0] * num_attn_module + int_vals[1] if len(
int_vals) == 2 else int_vals[0]
return layer_index
def cast_overflow_tensors(

View File

@ -169,7 +169,6 @@ class EagleProposer:
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@ -223,7 +222,8 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"):
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
"longcat_flash_mtp"):
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
@ -237,7 +237,10 @@ class EagleProposer:
return draft_token_ids.view(-1, 1)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
hidden_states = self.hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention.
@ -350,7 +353,7 @@ class EagleProposer:
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp"):
"qwen3_next_mtp", "longcat_flash_mtp"):
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:

View File

@ -3840,9 +3840,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
num_attn_module = 2 \
if self.model_config.hf_config.model_type == "longcat_flash" else 1
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
self.kv_caches, num_attn_module)
return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups(

View File

@ -266,6 +266,7 @@ def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
num_attn_module: Optional[int] = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
@ -289,7 +290,8 @@ def bind_kv_cache(
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]