[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral (#12303)

Signed-off-by: zhenwei <zhenweiliu@habana.ai>
This commit is contained in:
liuzhenwei 2025-03-25 00:48:40 +08:00 committed by GitHub
parent 3aee6573dc
commit 5eeadc2642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 2 deletions

View File

@@ -213,6 +213,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias,
)
def forward_hpu(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Dispatch *x* to the HPU fused-MoE kernel attached to *layer*.

    Only the plain softmax top-k routing path is implemented on HPU:
    grouped top-k, custom routing functions, and expert-score
    correction are all rejected up front.  ``renormalize`` is accepted
    for signature parity with the other backends but is not consumed
    here (presumably handled inside ``hpu_fused_moe`` — TODO confirm).
    """
    # Routing features with no HPU implementation — reject early.
    assert layer is not None
    assert not use_grouped_topk
    assert topk_group is None
    assert num_expert_group is None
    assert custom_routing_function is None
    if scoring_func != "softmax":
        raise NotImplementedError(
            "Only softmax scoring function is supported for HPU.")
    if e_score_correction_bias is not None:
        raise NotImplementedError(
            "Expert score correction bias is not supported for HPU.")
    # ``hpu_fused_moe`` is the DynamicFusedMOE instance installed on the
    # layer when running on the HPU platform.
    w13, w2 = layer.w13_weight, layer.w2_weight
    return layer.hpu_fused_moe(x, w13, w2, router_logits, top_k)
def forward_tpu(
self,
layer: torch.nn.Module,
@@ -411,6 +439,9 @@ class FusedMoE(torch.nn.Module):
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.

View File

@@ -387,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = _xla_weights_iterator(weights_iterator)
elif current_platform.is_hpu():
import habana_frameworks.torch.core as htcore
def _hpu_weights_iterator(iterator: Generator):
    # Pass every weight through unchanged, inserting an
    # ``htcore.mark_step()`` after each one so the HPU lazy-execution
    # graph is flushed per weight during model loading.
    for item in iterator:
        yield item
        htcore.mark_step()
weights_iterator = _hpu_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.

View File

@@ -376,8 +376,22 @@ class HpuModelAdapter:
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
if os.environ.get('VLLM_USE_FAKE_HPU',
'0') == '0' and htorch.utils.internal.is_lazy():
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
else:
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# doesn't handle out of bounds classes so we need to convert
# all negative values to 0 (block_mapping) or bs (block_groups)
block_groups = metadata.block_groups.to(torch.long)
block_mapping = torch.nn.functional.relu(block_groups)
block_mapping = torch.nn.functional.one_hot(block_mapping,
num_classes=batch_size)
oob_values = block_groups.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
block_groups.masked_fill_(oob_values, batch_size)
metadata = metadata._replace(block_groups=block_groups)
block_mapping = block_mapping.to(dtype)
metadata = metadata._replace(block_mapping=block_mapping,
attn_bias=attn_bias)