diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 8785e9dcff08a..51efbfe202f0b 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -37,7 +37,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, extract_layer_index,
+from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
 
@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
+        router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
         router_scores = torch.sigmoid(router_scores.float()).to(
             hidden_states.dtype)
         return (router_scores, router_indices.to(torch.int32))
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index f197434f31432..7ed0560ee43fe 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -703,3 +703,12 @@ def cast_overflow_tensors(
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
     return tensors
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        return torch.max(values, dim=dim, keepdim=True)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)
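
Note (not part of the patch): a minimal standalone sketch of the new fast_topk
helper, checking that the torch.max fast path for k == 1 returns the same
(values, indices) pair as torch.topk. The tensor shape is an illustrative
assumption, not taken from the PR.

# Standalone sketch; shapes are assumed for illustration.
import torch

def fast_topk(values, topk, dim):
    if topk == 1:
        # torch.max returns a (values, indices) named tuple, mirroring
        # torch.topk's return shape when keepdim=True
        return torch.max(values, dim=dim, keepdim=True)
    else:
        return torch.topk(values, topk, dim=dim)

gating_output = torch.randn(8, 16)  # (num_tokens, num_experts), illustrative
scores, indices = fast_topk(gating_output, 1, dim=-1)
ref_scores, ref_indices = torch.topk(gating_output, 1, dim=-1)
assert torch.equal(scores, ref_scores)
assert torch.equal(indices, ref_indices)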