[Model] Add LongCat-Flash (#23991)

Signed-off-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 12c21d28c1
commit c26e7b14d7
@@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel(
   for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@@ -95,12 +98,15 @@ template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
-    size_t numel) {
+    size_t numel, int32_t num_experts) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;

   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
@@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
           topk_ids.data_ptr<scalar_t>(),
           sorted_token_ids.data_ptr<int32_t>(),
-          cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
+          cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
     }
   });
}
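The new `expert_id >= num_experts` guards exist because LongCat-Flash's router can emit ids that point at "zero experts": placeholder experts appended after the real ones that only contribute an identity term (see `zero_experts_compute_triton` later in this commit). Such ids must not be counted or sorted by the alignment kernels. A small PyTorch sketch of the situation, with hypothetical tensor values:

```python
import torch

num_real_experts = 8   # experts with actual weights
num_zero_experts = 2   # identity-only "zero experts" appended after them

# Hypothetical top-2 routing result for 4 tokens; ids >= num_real_experts
# (8 and 9 here) refer to zero experts and are skipped by the CUDA kernels,
# mirroring the added `if (expert_id >= num_experts) continue;` guard.
topk_ids = torch.tensor([[1, 9],
                         [0, 3],
                         [8, 2],
                         [5, 7]], dtype=torch.int32)

counts_real = torch.bincount(topk_ids[topk_ids < num_real_experts].long(),
                             minlength=num_real_experts)
print(counts_real)  # per-expert token counts, zero experts excluded
```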
@@ -428,6 +428,7 @@ th {
 | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
 | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
 | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
+| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | ✅︎ |
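For reference, a minimal offline-inference sketch for the newly added model. The model name comes from the table row above; the parallelism setting is an assumption and must match your hardware:

```python
from vllm import LLM, SamplingParams

# Sketch only: LongCat-Flash is a large MoE model, so tensor_parallel_size
# here is a placeholder, not a recommendation.
llm = LLM(model="meituan-longcat/LongCat-Flash-Chat",
          trust_remote_code=True,
          tensor_parallel_size=8)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```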
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
@@ -138,7 +138,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids = FusedMoE.select_experts(
+    topk_weights, topk_ids, _ = FusedMoE.select_experts(
         hidden_states=td.hidden_states,
         router_logits=score,
         use_grouped_topk=False,
@@ -206,7 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids = FusedMoE.select_experts(
+    topk_weights, topk_ids, _ = FusedMoE.select_experts(
         hidden_states=td.hidden_states,
         router_logits=score,
         use_grouped_topk=False,
@@ -273,6 +273,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          is_available_online=False),
     "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                          is_available_online=False),
+    "LongcatFlashForCausalLM": _HfExamplesInfo
+    ("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
     "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
                                          min_transformers_version="4.55.3",
@@ -639,6 +641,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                     speculative_model="zai-org/GLM-4.5",
                                     min_transformers_version="4.54",
                                     is_available_online=False),
+    "LongCatFlashMTPModel": _HfExamplesInfo(
+        "meituan-longcat/LongCat-Flash-Chat",
+        trust_remote_code=True,
+        speculative_model="meituan-longcat/LongCat-Flash-Chat"),
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
                                     speculative_model="XiaomiMiMo/MiMo-7B-RL"),
@@ -428,9 +428,8 @@ def dummy_hf_overrides(
     num_hidden_layers = (3 if model_arch
                          == "Gemma3nForConditionalGeneration" else 1)

-    text_config.update({
+    update_dict = {
         "num_layers": num_layers,
-        "num_hidden_layers": num_hidden_layers,
         "num_experts": num_experts,
         "num_experts_per_tok": 2,
         "num_local_experts": num_experts,
@@ -440,7 +439,14 @@ def dummy_hf_overrides(
         "n_routed_experts": num_experts,
         # For Gemma-3n
         "num_kv_shared_layers": 1,
-    })
+    }
+
+    # Update num_hidden_layers for non-Longcat architectures
+    if model_arch != "LongcatFlashForCausalLM" \
+            and model_arch != "LongCatFlashMTPModel":
+        update_dict["num_hidden_layers"] = num_hidden_layers
+
+    text_config.update(update_dict)

     if hasattr(hf_config, "vision_config"):
         hf_config.vision_config.update({
@@ -96,7 +96,7 @@ def test_routing_strategy_integration(monkeypatch, device):
         envs.environment_variables[env_name] = lambda s=strategy: s

         # Test the select_experts method
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             top_k=top_k,
@@ -1131,7 +1131,8 @@ class ModelConfig:
         if not hasattr(self.hf_text_config, "model_type"):
             return False
         elif self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
+                 'kimi_k2', 'longcat_flash'):
             return self.hf_text_config.kv_lora_rank is not None
         elif self.hf_text_config.model_type == 'eagle':
             # if the model is an EAGLE module, check for the
@@ -1257,6 +1258,9 @@ class ModelConfig:
               or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
+        elif (self.hf_config.model_type == "longcat_flash_mtp"):
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 1)
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
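The first hunk above widens the MLA detection to the new model type: like the DeepSeek family, `longcat_flash` uses multi-head latent attention and exposes `kv_lora_rank`, which is what the KV-cache sizing logic keys on. A simplified stand-alone sketch of that check (names are illustrative; the real logic is the `ModelConfig` property shown in the hunk):

```python
# Simplified sketch of the MLA check extended above (illustrative names only).
MLA_MODEL_TYPES = ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
                   'kimi_k2', 'longcat_flash')

def uses_mla(hf_text_config) -> bool:
    model_type = getattr(hf_text_config, "model_type", None)
    if model_type in MLA_MODEL_TYPES:
        # MLA models advertise a compressed KV dimension via kv_lora_rank.
        return getattr(hf_text_config, "kv_lora_rank", None) is not None
    return False
```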
@@ -31,7 +31,8 @@ logger = init_logger(__name__)

 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
+                            "longcat_flash_mtp"]


 @config
@@ -186,6 +187,13 @@ class SpeculativeConfig:
                     "n_predict": n_predict,
                     "architectures": ["Qwen3NextMTP"]
                 })
+            if hf_config.model_type == "longcat_flash":
+                hf_config.model_type = "longcat_flash_mtp"
+                n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
+                hf_config.update({
+                    "n_predict": n_predict,
+                    "architectures": ["LongCatFlashMTPModel"]
+                })

         return hf_config

@@ -332,6 +340,15 @@ class SpeculativeConfig:
                         "one layer. Might need some code changes " \
                         "to support multiple layers."
                     )
+            elif (self.draft_model_config.hf_config.model_type
+                  in ("longcat_flash_mtp")):
+                self.method = "longcat_flash_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "LongCat MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(
@@ -548,7 +565,7 @@ class SpeculativeConfig:

     def use_eagle(self) -> bool:
         return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp")
+                               "qwen3_next_mtp", "longcat_flash_mtp")

     def __repr__(self) -> str:
         method = self.method
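With the hunks above, pointing the draft model at the same LongCat checkpoint is enough for `SpeculativeConfig` to rewrite the draft config to `longcat_flash_mtp` and select the MTP method. A hedged usage sketch; the exact engine arguments can differ between vLLM versions, and `num_speculative_tokens=1` matches the single MTP layer warned about above:

```python
from vllm import LLM

# Sketch only: speculative decoding with LongCat's MTP head. The draft
# "model" points at the same checkpoint; SpeculativeConfig derives
# method="longcat_flash_mtp" from the draft model's hf model_type.
llm = LLM(
    model="meituan-longcat/LongCat-Flash-Chat",
    trust_remote_code=True,
    speculative_config={
        "model": "meituan-longcat/LongCat-Flash-Chat",
        "num_speculative_tokens": 1,
    },
)
```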
@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     )


+@triton.jit
+def compute_identity_kernel(
+    top_k: int,
+    hidden_states_ptr: tl.tensor,
+    expert_scales_ptr: tl.tensor,
+    num_tokens: int,
+    output_ptr: tl.tensor,
+    hidden_dim: int,
+    scales_stride: int,
+    BLOCK_SIZE: tl.constexpr,
+) -> None:
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
+                tl.arange(0, BLOCK_SIZE),
+                mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
+             tl.arange(0, BLOCK_SIZE),
+             result,
+             mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+
+def zero_experts_compute_triton(expert_indices: torch.Tensor,
+                                expert_scales: torch.Tensor, num_experts: int,
+                                zero_expert_type: str,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+        normal_expert_mask = expert_indices >= num_experts
+        expert_indices[normal_expert_mask] = 0
+        expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
+
+
 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
 def get_config_file_name(E: int,
                          N: int,
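For intuition, the "identity" zero-expert path above reduces to scaling each token's hidden state by the sum of its zero-expert routing weights. A plain-PyTorch reference sketch of that contribution, not the Triton kernel itself and only for `zero_expert_type == "identity"`:

```python
import torch

def zero_expert_identity_reference(expert_indices: torch.Tensor,
                                   expert_scales: torch.Tensor,
                                   num_experts: int,
                                   hidden_states: torch.Tensor) -> torch.Tensor:
    """Reference for the identity zero-expert contribution (sketch only)."""
    # Zero experts are the routed ids at or beyond the real expert count.
    zero_scales = expert_scales.clone()
    zero_scales[expert_indices < num_experts] = 0.0
    # Each token gets hidden_state * sum(weights of its zero experts),
    # which is what compute_identity_kernel accumulates block by block.
    return hidden_states * zero_scales.sum(dim=-1,
                                           keepdim=True).to(hidden_states.dtype)
```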
@@ -940,6 +1010,25 @@ def fused_topk(
     return topk_weights, topk_ids, token_expert_indices


+def fused_topk_bias(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    e_score_correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    n_routed_experts = gating_output.shape[-1]
+    scores = gating_output.softmax(dim=-1)
+    scores_for_choice = scores.view(
+        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
+                              sorted=False)[1]
+    topk_weights = scores.gather(1, topk_indices)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
+
+
 # This is used by the Deepseek-V2 and Deepseek-V3 model
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(
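`fused_topk_bias` implements the bias-corrected routing used by LongCat's router: the correction bias only influences which experts are selected, while the returned weights are the plain softmax scores. A toy invocation; shapes are illustrative and `hidden_states` is accepted but unused by this routine:

```python
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias

gating_output = torch.randn(4, 8)   # 4 tokens, 8 routed experts
bias = torch.zeros(8)               # e_score_correction_bias

weights, ids = fused_topk_bias(hidden_states=None,
                               gating_output=gating_output,
                               e_score_correction_bias=bias,
                               topk=2,
                               renormalize=False)
print(weights.shape, ids.dtype)     # torch.Size([4, 2]) torch.int32
```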
@ -24,6 +24,8 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig, biased_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
zero_experts_compute_triton)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||
@ -548,7 +550,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
|
||||
zero_expert_type = getattr(layer, 'zero_expert_type', None)
|
||||
|
||||
topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -565,11 +570,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
logical_replica_count=logical_replica_count,
|
||||
global_num_experts=global_num_experts,
|
||||
zero_expert_num=zero_expert_num,
|
||||
zero_expert_type=zero_expert_type)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert self.fused_experts is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
result = self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -591,7 +599,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
if self.moe.has_bias:
|
||||
raise ValueError(
|
||||
"FusedMoEModularKernel does not support bias.")
|
||||
return self.fused_experts(
|
||||
result = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -605,7 +613,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
)
|
||||
else:
|
||||
assert fused_experts is not None
|
||||
return fused_experts(
|
||||
result = fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -619,6 +627,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
if zero_expert_num != 0 and zero_expert_type is not None:
|
||||
assert not isinstance(result, tuple), \
|
||||
"Shared + zero experts are mutually exclusive not yet supported"
|
||||
return result, zero_expert_result
|
||||
else:
|
||||
return result
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -942,6 +957,8 @@ class FusedMoE(CustomOp):
|
||||
num_redundant_experts: int = 0,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: Optional[int] = 0,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@ -976,6 +993,8 @@ class FusedMoE(CustomOp):
|
||||
vllm_parallel_config=vllm_config.parallel_config))
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
self.zero_expert_num = zero_expert_num
|
||||
self.zero_expert_type = zero_expert_type
|
||||
|
||||
# Round up hidden size if needed.
|
||||
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
|
||||
@@ -1656,25 +1675,30 @@ class FusedMoE(CustomOp):
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        global_num_experts: Optional[int] = None,
+        zero_expert_num: Optional[int] = None,
+        zero_expert_type: Optional[str] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
         router logits.

         Returns:
-            (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
-            The weights and *global physical* expert ids of the top-k experts.
+            (topk_weights, topk_ids, zero_expert_result)
+            (tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
+            The weights, expert ids, and zero expert computation result.

             **Compatibility**: When EPLB is not enabled, the returned ids are
             equivalent to global logical ids, so should be compatible with
             plain MoE implementations without redundant experts.
         """
-        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk, fused_topk_bias)

         # Check if we should use a routing simulation strategy
         routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
         if routing_strategy != "":
-            return RoutingSimulator.simulate_routing(
+            topk_weights, topk_ids = RoutingSimulator.simulate_routing(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
                 strategy_name=routing_strategy,
@@ -1697,6 +1721,16 @@ class FusedMoE(CustomOp):
                 e_score_correction_bias=e_score_correction_bias)
             if indices_type is not None:
                 topk_ids = topk_ids.to(dtype=indices_type)
+        elif e_score_correction_bias is not None:
+            topk_weights, topk_ids = fused_topk_bias(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                e_score_correction_bias=e_score_correction_bias.data,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+            if routed_scaling_factor is not None:
+                topk_weights *= routed_scaling_factor
         elif custom_routing_function is None:
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 hidden_states=hidden_states,
@@ -1729,7 +1763,20 @@ class FusedMoE(CustomOp):

         assert topk_ids.dtype == indices_type or indices_type is None

-        return topk_weights, topk_ids
+        # Compute zero expert result if needed
+        if (zero_expert_num is not None and zero_expert_num > 0
+                and zero_expert_type is not None
+                and global_num_experts is not None):
+            zero_expert_result = zero_experts_compute_triton(
+                expert_indices=topk_ids,
+                expert_scales=topk_weights,
+                num_experts=global_num_experts,
+                zero_expert_type=zero_expert_type,
+                hidden_states=hidden_states,
+            )
+        else:
+            zero_expert_result = None
+        return topk_weights, topk_ids, zero_expert_result

     def must_reduce_shared_expert_outputs(self) -> bool:
         """
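Callers that do not use zero experts simply discard the third return value, which is the pattern applied to the tests and quantization methods throughout this commit; LongCat-style callers keep it and add it back to the expert output. A schematic fragment (tensors and expert counts come from the surrounding layer code, and only a subset of keyword arguments is shown):

```python
# Without zero experts: ignore the third element.
topk_weights, topk_ids, _ = FusedMoE.select_experts(
    hidden_states=x,
    router_logits=router_logits,
    use_grouped_topk=False,
    top_k=top_k,
)

# With LongCat's zero experts: keep the identity contribution and add it
# to the fused-expert output afterwards, as UnquantizedFusedMoEMethod and
# Fp8MoEMethod now do.
topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
    hidden_states=x,
    router_logits=router_logits,
    use_grouped_topk=False,
    top_k=top_k,
    global_num_experts=global_num_experts,
    zero_expert_num=zero_expert_num,
    zero_expert_type="identity",
)
```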
@ -1878,6 +1925,11 @@ class FusedMoE(CustomOp):
|
||||
assert self.shared_experts is None or isinstance(
|
||||
final_hidden_states, tuple)
|
||||
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
if zero_expert_result is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
|
||||
if not skip_result_store:
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[
|
||||
@ -1992,6 +2044,9 @@ class FusedMoE(CustomOp):
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
@ -2003,14 +2058,16 @@ class FusedMoE(CustomOp):
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
return reduce_output(final_hidden_states)
|
||||
else:
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
||||
@@ -103,7 +103,6 @@ class MultiHeadLatentAttention(CustomOp):
         )

         self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])

     def forward_native(
         self,
@ -520,7 +520,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -486,7 +486,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -385,7 +385,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -934,7 +934,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
"EPLB not supported for "
|
||||
"`CompressedTensorsW8A8Fp8MoEMethod` yet.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -1195,7 +1195,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -1502,7 +1502,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -1747,7 +1747,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -146,7 +146,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -174,6 +176,10 @@ class Fp8Config(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if is_layer_skipped(prefix=prefix,
|
||||
ignored_layers=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return Fp8MoEMethod(self, layer)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
@ -927,6 +933,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
@ -943,8 +950,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
@ -965,7 +971,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
else:
|
||||
assert (not renormalize
|
||||
and custom_routing_function is not None)
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
result = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@ -976,7 +982,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
|
||||
zero_expert_type = getattr(layer, 'zero_expert_type', None)
|
||||
|
||||
select_result = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -994,17 +1003,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
global_num_experts=global_num_experts,
|
||||
zero_expert_num=zero_expert_num,
|
||||
zero_expert_type=zero_expert_type,
|
||||
)
|
||||
|
||||
#
|
||||
# Note: the order of checks is important since self.fused_experts
|
||||
# can override fused_experts or cutlass but not rocm or marlin.
|
||||
#
|
||||
topk_weights, topk_ids, zero_expert_result = select_result
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_fused_experts)
|
||||
assert self.fused_experts is None
|
||||
return rocm_aiter_fused_experts(
|
||||
result = rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@ -1018,7 +1032,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert self.fused_experts is None
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
result = torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@ -1035,7 +1049,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
workspace=layer.workspace)
|
||||
elif self.fused_experts:
|
||||
return self.fused_experts(
|
||||
result = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -1055,7 +1069,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
result = flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
@ -1068,7 +1082,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
result = fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -1083,6 +1097,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||
if zero_expert_num != 0 and zero_expert_type is not None:
|
||||
assert not isinstance(result, tuple), \
|
||||
"Shared + zero experts are mutually exclusive not yet supported"
|
||||
return result, zero_expert_result
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
|
||||
@ -555,7 +555,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -669,7 +669,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -543,7 +543,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -1491,7 +1491,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
)[0]
|
||||
return out
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -332,7 +332,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -718,7 +718,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -783,7 +783,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -894,7 +894,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -329,7 +329,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@ -531,7 +531,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -318,7 +318,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
|
||||
@ -292,6 +292,11 @@ def is_layer_skipped(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
elif "experts" in prefix:
|
||||
return any([
|
||||
prefix in layer_name for layer_name in ignored_layers
|
||||
if "experts" in layer_name
|
||||
])
|
||||
else:
|
||||
is_skipped = prefix in ignored_layers
|
||||
|
||||
|
||||
vllm/model_executor/models/longcat_flash.py (new file, 712 lines)
@@ -0,0 +1,712 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Apache License, Version 2.0:
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License:
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
"""Inference-only Flash model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
block_dequant)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashConfig(PretrainedConfig):
|
||||
"""Flash model configuration."""
|
||||
model_type = "longcat_flash"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=131072,
|
||||
hidden_size=4096,
|
||||
intermediate_size=8192,
|
||||
num_layers=28,
|
||||
num_hidden_layers=None,
|
||||
num_attention_heads=96,
|
||||
num_key_value_heads=128,
|
||||
ep_size=1,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
num_experts_per_tok=None,
|
||||
norm_topk_prob=False,
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=100000,
|
||||
eos_token_id=100001,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mla_scale_q_lora=False,
|
||||
mla_scale_kv_lora=False,
|
||||
torch_dtype="bfloat16",
|
||||
params_dtype="bfloat16",
|
||||
router_dtype="float32",
|
||||
router_bias=False,
|
||||
topk_method=None,
|
||||
routed_scaling_factor=None,
|
||||
zero_expert_num=0,
|
||||
zero_expert_type=None,
|
||||
nextn_use_scmoe=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
torch_dtype=torch_dtype,
|
||||
params_dtype=params_dtype,
|
||||
router_dtype=router_dtype,
|
||||
topk_method=topk_method,
|
||||
router_bias=router_bias,
|
||||
nextn_use_scmoe=nextn_use_scmoe,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = (num_hidden_layers if num_hidden_layers
|
||||
is not None else num_layers)
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ep_size = ep_size
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mla_scale_q_lora = mla_scale_q_lora
|
||||
self.mla_scale_kv_lora = mla_scale_kv_lora
|
||||
self.zero_expert_num = zero_expert_num
|
||||
self.zero_expert_type = zero_expert_type
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.hidden_act = "silu"
|
||||
self.intermediate_size = self.ffn_hidden_size if hasattr(
|
||||
self, "ffn_hidden_size") else self.intermediate_size
|
||||
if hasattr(self, "moe_intermediate_size"):
|
||||
self.moe_intermediate_size = self.moe_intermediate_size
|
||||
elif hasattr(self, "expert_ffn_hidden_size"):
|
||||
self.moe_intermediate_size = self.expert_ffn_hidden_size
|
||||
else:
|
||||
self.moe_intermediate_size = self.intermediate_size
|
||||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
"""Flash MLP layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.numel() == 0:
|
||||
return x
|
||||
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class LongcatRouter(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
zero_expert_num=0,
|
||||
rounter_params_dtype=torch.bfloat16,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.n_routed_experts = config.n_routed_experts if hasattr(
|
||||
config, "n_routed_experts") else config.num_experts[0]
|
||||
self.n_routed_experts = self.n_routed_experts + zero_expert_num
|
||||
self.classifier = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=config.router_bias,
|
||||
params_dtype=rounter_params_dtype,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.classifier",
|
||||
)
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits, _ = self.classifier(hidden_states)
|
||||
return logits
|
||||
|
||||
|
||||
class LongcatMoe(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlashConfig,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.zero_expert_num = config.zero_expert_num
|
||||
self.zero_expert_type = config.zero_expert_type
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.enable_eplb = enable_eplb
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.rounter_params_dtype = params_dtype
|
||||
if config.router_dtype == "float32":
|
||||
self.rounter_params_dtype = torch.float32
|
||||
|
||||
self.router = LongcatRouter(
|
||||
config=config,
|
||||
zero_expert_num=self.zero_expert_num,
|
||||
rounter_params_dtype=self.rounter_params_dtype,
|
||||
prefix=f"{prefix}.gate")
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
reduce_results=True,
|
||||
params_dtype=params_dtype,
|
||||
e_score_correction_bias=self.router.e_score_correction_bias,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
zero_expert_num=self.zero_expert_num,
|
||||
zero_expert_type=self.zero_expert_type,
|
||||
enable_eplb=self.enable_eplb,
|
||||
routed_scaling_factor=config.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
router_logits = self.router(hidden_states.to(
|
||||
self.rounter_params_dtype))
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class FlashDecoderLayer(nn.Module):
|
||||
"""Flash decoder layer with dual attention and MLP structure."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlashConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
|
||||
# Dual attention structure
|
||||
self.self_attn = nn.ModuleList([
|
||||
DeepseekV2MLAAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=(config.q_lora_rank if hasattr(
|
||||
config, "q_lora_rank") else None),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=None if "self_attn" in getattr(
|
||||
config, "disable_quant_module", []) else quant_config,
|
||||
prefix=f"{prefix}.self_attn.{i}",
|
||||
) for i in range(2)
|
||||
])
|
||||
self.input_layernorm = nn.ModuleList([
|
||||
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for i in range(2)
|
||||
])
|
||||
self.post_attention_layernorm = nn.ModuleList([
|
||||
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for i in range(2)
|
||||
])
|
||||
|
||||
# Dual MLP structure
|
||||
self.mlps = nn.ModuleList([
|
||||
FlashMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=None if "mlps" in getattr(
|
||||
config, "disable_quant_module", []) else quant_config,
|
||||
prefix=f"{prefix}.mlps.{i}",
|
||||
) for i in range(2)
|
||||
])
|
||||
|
||||
self.mlp = LongcatMoe(
|
||||
config=config,
|
||||
num_experts=config.n_routed_experts if hasattr(
|
||||
config, "n_routed_experts") else
|
||||
config.num_experts[self.layer_idx],
|
||||
top_k=config.moe_topk
|
||||
if hasattr(config, "moe_topk") else config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=(f"{prefix}.mlp"),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm[0](hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm[0](hidden_states,
|
||||
residual)
|
||||
|
||||
hidden_states = self.self_attn[0](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm[0](
|
||||
hidden_states, residual)
|
||||
|
||||
# moe
|
||||
hidden_states_copy = hidden_states.clone()
|
||||
moe_hidden_states = self.mlp(hidden_states_copy)
|
||||
|
||||
# first mlp
|
||||
hidden_states = self.mlps[0](hidden_states)
|
||||
|
||||
hidden_states, residual = self.input_layernorm[1](hidden_states,
|
||||
residual)
|
||||
|
||||
# second_attn
|
||||
hidden_states = self.self_attn[1](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states, residual = self.post_attention_layernorm[1](
|
||||
hidden_states, residual)
|
||||
|
||||
# second_mlp
|
||||
hidden_states = self.mlps[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states + moe_hidden_states
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
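The decoder layer above realizes LongCat-Flash's shortcut-connected MoE layout: each layer runs two attention plus dense-MLP passes, while the MoE branch is computed from the output of the first attention block and added back after the second dense MLP. A schematic view of one `forward` call, written as pseudocode over the modules defined above and not runnable on its own:

```python
# Schematic dataflow of FlashDecoderLayer.forward (modules as defined above).
def flash_layer_dataflow(x, positions, residual):
    x, residual = input_layernorm[0](x, residual)
    x = self_attn[0](positions=positions, hidden_states=x)   # first attention
    x, residual = post_attention_layernorm[0](x, residual)

    moe_out = mlp(x.clone())                                 # MoE shortcut branch
    x = mlps[0](x)                                           # first dense MLP

    x, residual = input_layernorm[1](x, residual)
    x = self_attn[1](positions=positions, hidden_states=x)   # second attention
    x, residual = post_attention_layernorm[1](x, residual)
    x = mlps[1](x)                                           # second dense MLP

    return x + moe_out, residual                             # add MoE output back
```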
@support_torch_compile
|
||||
class FlashModel(nn.Module):
|
||||
"""Flash model."""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = getattr(config, "pad_token_id", None)
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: FlashDecoderLayer(
|
||||
config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"""Flash model for causal language modeling."""
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
config.intermediate_size = config.ffn_hidden_size if hasattr(
|
||||
config, "ffn_hidden_size") else config.intermediate_size
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = FlashModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts if hasattr(
|
||||
self.config, "n_routed_experts") else
|
||||
self.config.num_experts[0],
|
||||
)
|
||||
|
||||
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:

        stacked_params_mapping = [
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        expert_params_mapping = self.get_expert_mapping()
        loaded_params: set[str] = set()

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp" in name and "mlps" not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (name.endswith(".bias")
                        or name.endswith("_bias")) and name not in params_dict:
                    continue
                # Skip MTP weights; they are loaded by LongCatFlashMTP.
                if ".mtp." in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    # Skip MTP weights; they are loaded by LongCatFlashMTP.
                    if ".mtp." in name_mapped:
                        continue
                    if (name_mapped.endswith(".bias")
                            or name_mapped.endswith("_bias")
                        ) and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name_mapped]
                    weight_loader = param.weight_loader
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
                        continue
                    # Skip MTP weights; they are loaded by LongCatFlashMTP.
                    if ".mtp." in name:
                        continue
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        for layer_id in range(self.config.num_hidden_layers):
            for i in range(2):
                if isinstance(self.model.layers[layer_id], PPMissingLayer):
                    continue
                self_attn = self.model.layers[layer_id].self_attn[i]
                if hasattr(self.quant_config, "weight_block_size"
                           ) and self_attn.kv_b_proj.weight.dtype in (
                               torch.float8_e4m3fn,
                               torch.float8_e4m3fnuz,
                           ):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        dtype = torch.get_default_dtype()
                        w = block_dequant(self_attn.kv_b_proj.weight,
                                          self_attn.kv_b_proj.weight_scale_inv,
                                          weight_block_size).to(dtype)
                else:
                    w = self_attn.kv_b_proj.weight

                w_kc, w_vc = w.unflatten(
                    0,
                    (-1,
                     self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
                         [self_attn.qk_nope_head_dim, self_attn.v_head_dim],
                         dim=1)
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(
                    1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if self.config.mla_scale_q_lora:
                    self_attn.q_a_layernorm.weight.data *= (
                        self.config.hidden_size / self.config.q_lora_rank)**0.5
                if self.config.mla_scale_kv_lora:
                    self_attn.kv_a_layernorm.weight.data *= (
                        self.config.hidden_size /
                        self.config.kv_lora_rank)**0.5
        return loaded_params
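
The tail of load_weights above splits each attention module's kv_b_proj into per-head w_kc / w_vc blocks for MLA. Below is a shape-only sketch of that unflatten/split with made-up dimensions; it is an illustration, not vLLM code.

import torch

# Hypothetical sizes: 4 heads, qk_nope_head_dim = v_head_dim = 128, kv_lora_rank = 512.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Unflatten dim 0 into (heads, qk_nope + v), then split that second axis.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1)

print(w_kc.shape)                               # torch.Size([4, 128, 512])
print(w_vc.contiguous().transpose(1, 2).shape)  # torch.Size([4, 512, 128])
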
vllm/model_executor/models/longcat_flash_mtp.py (new file, 352 lines)
@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    block_dequant)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.longcat_flash import FlashConfig
from vllm.sequence import IntermediateTensors

from .deepseek_v2 import DeepseekV2DecoderLayer
from .interfaces import SupportsPP
from .utils import maybe_prefix


class LongCatMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        vllm_config: VllmConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = ReplicatedLinear(2 * config.hidden_size,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix="eh_proj")
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        assert inputs_embeds is not None
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states, _ = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 residual=None)
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states


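LongCatMultiTokenPredictorLayer fuses the draft token's embedding with the base model's last hidden state: both are normalized, concatenated along the feature axis, and projected from 2 * hidden_size back to hidden_size before running a DeepseekV2DecoderLayer. A minimal standalone sketch of that fusion step follows; plain nn modules stand in for vLLM's RMSNorm and ReplicatedLinear, and the hidden size is made up.

import torch
import torch.nn as nn

hidden_size = 8
enorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm over the token embedding
hnorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm over the previous hidden state
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

inputs_embeds = torch.randn(4, hidden_size)            # embeddings of the shifted tokens
previous_hidden_states = torch.randn(4, hidden_size)   # hidden states from the base model

fused = eh_proj(
    torch.cat([enorm(inputs_embeds), hnorm(previous_hidden_states)], dim=-1))
print(fused.shape)  # torch.Size([4, 8]); this is what feeds the decoder block
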
class LongCatMultiTokenPredictor(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
        vllm_config.model_config.hf_config.intermediate_size \
            = config.intermediate_size
        self.mtp_start_layer_idx = config.num_hidden_layers * 2
        self.num_mtp_layers = 1
        self.layers = torch.nn.ModuleDict({
            str(idx):
            LongCatMultiTokenPredictorLayer(
                config,
                prefix=f"{prefix}.layers.{idx}",
                vllm_config=vllm_config,
                quant_config=quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )


class LongCatFlashMTP(nn.Module, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # LongCat MTP without MoE layers
        vllm_config.model_config.hf_config.n_routed_experts = None
        self.config = FlashConfig(
            **vllm_config.model_config.hf_config.__dict__)
        self.quant_config = None if "mtp" in getattr(
            self.config, "disable_quant_module",
            []) else vllm_config.quant_config

        self.model = LongCatMultiTokenPredictor(vllm_config=vllm_config,
                                                quant_config=self.quant_config,
                                                prefix=maybe_prefix(
                                                    prefix, "model"))
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=self.quant_config,
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        new_to_old_names_mapping = {
            "model.mtp.embed_tokens.weight":
            "model.layers.0.embed_tokens.weight",
            "model.mtp.layers.0.eh_proj.weight": "eh_proj.weight",
            "model.mtp.layers.0.eh_proj.weight_scale_inv":
            "eh_proj.weight_scale_inv",
            "model.mtp.layers.0.enorm.m.weight": "enorm.weight",
            "model.mtp.layers.0.hnorm.m.weight": "hnorm.weight",
            "model.mtp.layers.0.input_layernorm.weight":
            "model.layers.0.input_layernorm.weight",
            "model.mtp.layers.0.post_attention_layernorm.weight":
            "model.layers.0.post_attention_layernorm.weight",
            "model.mtp.layers.0.self_attn.kv_a_layernorm.weight":
            "model.layers.0.self_attn.kv_a_layernorm.weight",
            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight":
            "model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv":
            "model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv",
            "model.mtp.layers.0.self_attn.kv_b_proj.weight":
            "model.layers.0.self_attn.kv_b_proj.weight",
            "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv":
            "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.o_proj.weight":
            "model.layers.0.self_attn.o_proj.weight",
            "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv":
            "model.layers.0.self_attn.o_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.q_a_layernorm.weight":
            "model.layers.0.self_attn.q_a_layernorm.weight",
            "model.mtp.layers.0.self_attn.q_a_proj.weight":
            "model.layers.0.self_attn.q_a_proj.weight",
            "model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv":
            "model.layers.0.self_attn.q_a_proj.weight_scale_inv",
            "model.mtp.layers.0.self_attn.q_b_proj.weight":
            "model.layers.0.self_attn.q_b_proj.weight",
            "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv":
            "model.layers.0.self_attn.q_b_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight":
            "model.layers.0.mlp.down_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv":
            "model.layers.0.mlp.down_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight":
            "model.layers.0.mlp.gate_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv":
            "model.layers.0.mlp.gate_proj.weight_scale_inv",
            "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight":
            "model.layers.0.mlp.up_proj.weight",
            "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv":
            "model.layers.0.mlp.up_proj.weight_scale_inv",
            "model.mtp.norm.weight": "final_layernorm.weight",
        }

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = self.get_spec_layer_idx_from_weight_name(
                self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name,
                                                 new_to_old_names_mapping)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)

                # QKV fusion is optional; fall back to normal
                # weight loading if it's not enabled.
                if ((param_name == "fused_qkv_a_proj")
                        and name not in params_dict):
                    continue

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # According to the DeepSeek-V3 technical report, MTP modules
                # share the embedding layer. We only load the first weights.
                if (spec_layer != self.model.mtp_start_layer_idx
                        and ".layers" not in name):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        spec_layer_id = self.config.num_hidden_layers * 2
        self_attn = self.model.layers[str(spec_layer_id)].mtp_block.self_attn
        if hasattr(
                self.quant_config,
                "weight_block_size") and self_attn.kv_b_proj.weight.dtype in (
                    torch.float8_e4m3fn,
                    torch.float8_e4m3fnuz,
                ):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                dtype = torch.get_default_dtype()
                w = block_dequant(self_attn.kv_b_proj.weight,
                                  self_attn.kv_b_proj.weight_scale_inv,
                                  weight_block_size).to(dtype)
            else:
                w = self_attn.kv_b_proj.weight
        else:
            w = self_attn.kv_b_proj.weight
        w_kc, w_vc = w.unflatten(
            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
                [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
        if self.config.mla_scale_q_lora:
            self_attn.q_a_layernorm.weight.data *= (
                self.config.hidden_size / self.config.q_lora_rank)**0.5
        if self.config.mla_scale_kv_lora:
            self_attn.kv_a_layernorm.weight.data *= (
                self.config.hidden_size / self.config.kv_lora_rank)**0.5
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str,
                                 new_to_old_names_mapping: dict) -> str:
        """
        Rewrite the weight name to match the format of the original model:
        add .mtp_block for modules in the transformer layer block of the spec
        layer, and rename shared layer weights to be top level.
        """
        if name in new_to_old_names_mapping:
            name = new_to_old_names_mapping[name]
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        if name.startswith("enorm") or name.startswith(
                "hnorm") or name.startswith("eh_proj") or name.startswith(
                    "final_layernorm"):
            name = "model.layers." + str(spec_layer) + "." + name
        shared_weight_names = ["embed_tokens"]
        spec_layer_weight = False
        shared_weight = False
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                if weight_name in shared_weight_names:
                    shared_weight = True
                break
        if not spec_layer_weight:
            # treat the remaining weights as transformer layer block weights
            name = name.replace("model.layers.0.",
                                f"model.layers.{spec_layer}.mtp_block.")
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace("model.layers.0.", "model.")
        return name

    def get_spec_layer_idx_from_weight_name(self, config: PretrainedConfig,
                                            weight_name: str) -> Optional[int]:
        if "model.mtp" in weight_name:
            return config.num_hidden_layers * 2
        return None
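
To see where the renaming in load_weights and _rewrite_spec_layer_name lands, here is a simplified trace for one MTP attention weight, assuming (hypothetically) num_hidden_layers = 28, so the single MTP layer is indexed 28 * 2 = 56. The real loader may further fold q_a_proj into fused_qkv_a_proj via stacked_params_mapping when that fused parameter exists.

# Simplified trace of the MTP checkpoint-name rewrite; names and layer count
# are illustrative only.
new_to_old = {
    "model.mtp.layers.0.self_attn.q_a_proj.weight":
    "model.layers.0.self_attn.q_a_proj.weight",
}

ckpt_name = "model.mtp.layers.0.self_attn.q_a_proj.weight"
spec_layer = 28 * 2                   # get_spec_layer_idx_from_weight_name
name = new_to_old[ckpt_name]          # checkpoint name -> "original model" format
name = name.replace("model.layers.0.",
                    f"model.layers.{spec_layer}.mtp_block.")  # non-shared weight
print(name)  # model.layers.56.mtp_block.self_attn.q_a_proj.weight
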
@ -109,6 +109,7 @@ _TEXT_GENERATION_MODELS = {
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
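
With the architecture registered, a LongCat-Flash checkpoint (for example "meituan-longcat/LongCat-Flash-Chat") can be loaded through vLLM's normal entry points. A minimal usage sketch; the engine arguments are illustrative, and a model of this size typically needs several GPUs, so tensor_parallel_size is only a placeholder.

from vllm import LLM, SamplingParams

llm = LLM(model="meituan-longcat/LongCat-Flash-Chat",
          trust_remote_code=True,
          tensor_parallel_size=8)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.7, max_tokens=32))
print(outputs[0].outputs[0].text)
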
@ -287,6 +288,7 @@ _SPECULATIVE_DECODING_MODELS = {
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),

@ -691,14 +691,14 @@ def maybe_prefix(prefix: str, name: str) -> str:
    return name if not prefix else f"{prefix}.{name}"


def extract_layer_index(layer_name: str) -> int:
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    """
    Extract the layer index from the module name.
    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError
    - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1
    """
    subnames = layer_name.split(".")
    int_vals: list[int] = []
@ -707,9 +707,17 @@ def extract_layer_index(layer_name: str) -> int:
            int_vals.append(int(subname))
        except ValueError:
            continue
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]
    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                    " only contain one integer")

        return int_vals[0]
    else:
        assert len(int_vals) <= 2, (f"layer name {layer_name} should"
                                    " contain at most two integers")
        layer_index = int_vals[0] * num_attn_module + int_vals[1] if len(
            int_vals) == 2 else int_vals[0]
        return layer_index


def cast_overflow_tensors(

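The new num_attn_module argument lets a layer name carry two indices: the decoder-layer index and the index of the attention module inside that layer (LongCat-Flash runs two attention modules per layer), flattened into one global index. A standalone re-statement of just that arithmetic, using a hypothetical layer-name form:

# Illustration only; the real helper lives in vllm/model_executor/models/utils.py.
def flatten_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
    int_vals = [int(s) for s in layer_name.split(".") if s.isdigit()]
    if num_attn_module == 1 or "attn" not in layer_name:
        assert len(int_vals) == 1
        return int_vals[0]
    assert len(int_vals) <= 2
    return (int_vals[0] * num_attn_module + int_vals[1]
            if len(int_vals) == 2 else int_vals[0])


print(flatten_layer_index("model.layers.3.self_attn.1.attn", num_attn_module=2))  # 7
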
@ -169,7 +169,6 @@ class EagleProposer:
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@ -223,7 +222,8 @@ class EagleProposer:
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
        if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"):
        if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
                           "longcat_flash_mtp"):
            last_hidden_states = ret_hidden_states
            hidden_states = last_hidden_states
        else:
@ -237,7 +237,10 @@ class EagleProposer:
            return draft_token_ids.view(-1, 1)

        positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]
        if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
@ -350,7 +353,7 @@ class EagleProposer:
                inputs_embeds=inputs_embeds,
            )
            if self.method in ("deepseek_mtp", "ernie_mtp",
                               "qwen3_next_mtp"):
                               "qwen3_next_mtp", "longcat_flash_mtp"):
                last_hidden_states = ret_hidden_states
                hidden_states = ret_hidden_states
            else:

@ -3840,9 +3840,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    target_layer_name)
                kv_caches[layer_name] = kv_caches[target_layer_name]

        num_attn_module = 2 \
            if self.model_config.hf_config.model_type == "longcat_flash" else 1
        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)
                      self.kv_caches, num_attn_module)
        return kv_caches

    def maybe_add_kv_sharing_layers_to_kv_cache_groups(

@ -266,6 +266,7 @@ def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
@ -289,7 +290,8 @@ def bind_kv_cache(
    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)
        index2name[extract_layer_index(layer_name,
                                       num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]

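With the flattened indices, bind_kv_cache keeps one global index per attention module, so each module gets its own KV-cache slot. An illustration with hypothetical LongCat-style layer names and num_attn_module = 2:

from collections import defaultdict


def layer_index(name: str, num_attn_module: int) -> int:
    # same arithmetic as the extract_layer_index sketch above
    ints = [int(s) for s in name.split(".") if s.isdigit()]
    return ints[0] * num_attn_module + ints[1] if len(ints) == 2 else ints[0]


kv_caches = {
    "model.layers.0.self_attn.0.attn": "cache_a",
    "model.layers.0.self_attn.1.attn": "cache_b",
    "model.layers.1.self_attn.0.attn": "cache_c",
}
index2name = defaultdict(list)
for name in kv_caches:
    index2name[layer_index(name, num_attn_module=2)].append(name)
print(dict(index2name))
# {0: ['model.layers.0.self_attn.0.attn'],
#  1: ['model.layers.0.self_attn.1.attn'],
#  2: ['model.layers.1.self_attn.0.attn']}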