mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 04:27:02 +08:00
[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral (#12303)
Signed-off-by: zhenwei <zhenweiliu@habana.ai>
This commit is contained in:
parent
3aee6573dc
commit
5eeadc2642
@ -213,6 +213,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
e_score_correction_bias,
|
||||
)
|
||||
|
||||
def forward_hpu(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Dispatch a fused-MoE forward pass to the HPU kernel held by *layer*.

    Only the plain softmax top-k routing path is implemented on HPU, so
    every grouped/custom routing option must be left at its disabled
    default; anything else trips an assertion or NotImplementedError.
    """
    # Grouped top-k and custom routing functions have no HPU kernel.
    assert not use_grouped_topk
    assert num_expert_group is None
    assert topk_group is None
    assert custom_routing_function is None
    assert layer is not None
    if scoring_func != "softmax":
        message = "Only softmax scoring function is supported for HPU."
        raise NotImplementedError(message)
    if e_score_correction_bias is not None:
        message = "Expert score correction bias is not supported for HPU."
        raise NotImplementedError(message)
    # Delegate to the DynamicFusedMOE op that FusedMoE.__init__ attaches
    # to the layer on HPU platforms; it consumes both expert weight sets.
    fused_output = layer.hpu_fused_moe(
        x, layer.w13_weight, layer.w2_weight, router_logits, top_k)
    return fused_output
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -411,6 +439,9 @@ class FusedMoE(torch.nn.Module):
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
if current_platform.is_hpu():
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
|
||||
@ -387,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
elif current_platform.is_hpu():
|
||||
import habana_frameworks.torch.core as htcore
|
||||
|
||||
def _hpu_weights_iterator(iterator: Generator):
    """Re-yield every item from *iterator*, marking an HPU step after each.

    The mark_step() call after each yielded weight flushes lazy-mode
    execution so weights materialize one group at a time during loading.
    """
    for weight_group in iterator:
        yield weight_group
        htcore.mark_step()
|
||||
|
||||
weights_iterator = _hpu_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
|
||||
@ -376,8 +376,22 @@ class HpuModelAdapter:
|
||||
mask = mask >= metadata.block_usage.unsqueeze(-1)
|
||||
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
|
||||
mask, -math.inf))
|
||||
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
|
||||
num_classes=batch_size)
|
||||
if os.environ.get('VLLM_USE_FAKE_HPU',
|
||||
'0') == '0' and htorch.utils.internal.is_lazy():
|
||||
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
|
||||
num_classes=batch_size)
|
||||
else:
|
||||
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
|
||||
# doesn't handle out of bounds classes so we need to convert
|
||||
# all negative values to 0 (block_mapping) or bs (block_groups)
|
||||
block_groups = metadata.block_groups.to(torch.long)
|
||||
block_mapping = torch.nn.functional.relu(block_groups)
|
||||
block_mapping = torch.nn.functional.one_hot(block_mapping,
|
||||
num_classes=batch_size)
|
||||
oob_values = block_groups.lt(0)
|
||||
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
|
||||
block_groups.masked_fill_(oob_values, batch_size)
|
||||
metadata = metadata._replace(block_groups=block_groups)
|
||||
block_mapping = block_mapping.to(dtype)
|
||||
metadata = metadata._replace(block_mapping=block_mapping,
|
||||
attn_bias=attn_bias)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user