From 5eeadc264246d8d8b95012350bde14b1cc431147 Mon Sep 17 00:00:00 2001
From: liuzhenwei
Date: Tue, 25 Mar 2025 00:48:40 +0800
Subject: [PATCH] [Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral (#12303)

Signed-off-by: zhenwei
---
 vllm/model_executor/layers/fused_moe/layer.py | 31 +++++++++++++++++++
 vllm/model_executor/model_loader/loader.py    | 10 ++++++
 vllm/worker/hpu_model_runner.py               | 18 +++++++++--
 3 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 917643134645f..739d216e6e80c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -213,6 +213,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             e_score_correction_bias,
         )
 
+    def forward_hpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        assert custom_routing_function is None
+        assert layer is not None
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax scoring function is supported for HPU.")
+        if e_score_correction_bias is not None:
+            raise NotImplementedError(
+                "Expert score correction bias is not supported for HPU.")
+        return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
+                                   router_logits, top_k)
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
@@ -411,6 +439,9 @@ class FusedMoE(torch.nn.Module):
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+        if current_platform.is_hpu():
+            from vllm_hpu_extension.ops import DynamicFusedMOE
+            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
 
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index de04c6f89c2f1..c969f18b822c4 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -387,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
 
+        elif current_platform.is_hpu():
+            import habana_frameworks.torch.core as htcore
+
+            def _hpu_weights_iterator(iterator: Generator):
+                for weights in iterator:
+                    yield weights
+                    htcore.mark_step()
+
+            weights_iterator = _hpu_weights_iterator(weights_iterator)
+
         if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
         # Apply the prefix.
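Note on the forward_hpu path above: the work is delegated to DynamicFusedMOE from vllm_hpu_extension, constructed per layer with num_experts and called with the hidden states, the fused w13/w2 expert weights, the router logits, and top_k. The sketch below is only an illustrative plain-PyTorch reference for the computation that call is expected to perform (softmax routing, top-k selection, SwiGLU expert MLPs); the name reference_dynamic_moe, the tensor layouts, and the renormalization step are assumptions for illustration, not the vllm_hpu_extension kernel.

# Illustrative reference only (assumption): approximates what
# layer.hpu_fused_moe(x, w13, w2, router_logits, top_k) computes.
import torch
import torch.nn.functional as F

def reference_dynamic_moe(x: torch.Tensor,             # [tokens, hidden]
                          w13: torch.Tensor,           # [experts, 2*inter, hidden]
                          w2: torch.Tensor,            # [experts, hidden, inter]
                          router_logits: torch.Tensor, # [tokens, experts]
                          top_k: int) -> torch.Tensor:
    scores = F.softmax(router_logits, dim=-1)
    weights, ids = torch.topk(scores, top_k, dim=-1)      # [tokens, top_k]
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize (Mixtral-style; assumption)
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        token_idx, slot_idx = (ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()          # fused gate/up projection
        gate, up = h.chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ w2[e].t()    # SwiGLU activation, then down projection
        out.index_add_(0, token_idx,
                       h * weights[token_idx, slot_idx].unsqueeze(-1))
    return out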
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 4ac547ae326da..6b1593eb8235c 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -376,8 +376,22 @@ class HpuModelAdapter:
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
-                                                    num_classes=batch_size)
+        if os.environ.get('VLLM_USE_FAKE_HPU',
+                          '0') == '0' and htorch.utils.internal.is_lazy():
+            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
+                                                        num_classes=batch_size)
+        else:
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
+            # doesn't handle out of bounds classes so we need to convert
+            # all negative values to 0 (block_mapping) or bs (block_groups)
+            block_groups = metadata.block_groups.to(torch.long)
+            block_mapping = torch.nn.functional.relu(block_groups)
+            block_mapping = torch.nn.functional.one_hot(block_mapping,
+                                                        num_classes=batch_size)
+            oob_values = block_groups.lt(0)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+            block_groups.masked_fill_(oob_values, batch_size)
+            metadata = metadata._replace(block_groups=block_groups)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
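For readers unfamiliar with the fallback branch above, here is a minimal standalone illustration of the out-of-bounds-safe one_hot trick it uses; the toy batch_size and block_groups values (with -1 marking padded slots) are assumptions chosen only to show the effect.

# Standalone toy run of the fallback path (assumed example values).
import torch

batch_size = 4
block_groups = torch.tensor([0, 2, -1, 1, -1])     # -1 = padding, out of bounds for one_hot

clamped = torch.nn.functional.relu(block_groups)    # map negatives to 0 so one_hot is legal
block_mapping = torch.nn.functional.one_hot(clamped, num_classes=batch_size)

oob_values = block_groups.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)          # padded rows become all-zero
block_groups = block_groups.masked_fill(oob_values, batch_size)  # padded groups get the bs sentinel

print(block_mapping)
# tensor([[1, 0, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 0, 0],
#         [0, 1, 0, 0],
#         [0, 0, 0, 0]])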