mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 22:20:56 +08:00
[EPLB][ROCm]: support EPBL for ROCm backend (#27731)
Signed-off-by: Perry Zhang <perzhang@amd.com> Co-authored-by: Perry Zhang <perzhang@amd.com>
This commit is contained in:
parent
bac904565f
commit
a1e7fa362a
@ -278,10 +278,10 @@ class ParallelConfig:
|
||||
)
|
||||
|
||||
if self.enable_eplb:
|
||||
if not current_platform.is_cuda():
|
||||
if not current_platform.is_cuda_alike():
|
||||
raise ValueError(
|
||||
"Expert parallelism load balancing is only supported on "
|
||||
"CUDA devices now."
|
||||
"CUDA devices or ROCm devices now."
|
||||
)
|
||||
if not self.enable_expert_parallel:
|
||||
raise ValueError("enable_expert_parallel must be True to use EPLB.")
|
||||
|
||||
@ -1218,7 +1218,11 @@ class FusedMoE(CustomOp):
|
||||
|
||||
def get_expert_weights(self) -> Iterable[torch.Tensor]:
|
||||
weights = list(self.named_parameters())
|
||||
assert all(weight.is_contiguous() for _, weight in weights)
|
||||
assert all(
|
||||
weight.is_contiguous()
|
||||
for name, weight in weights
|
||||
if not name.startswith("_shared_experts.")
|
||||
)
|
||||
|
||||
# Filter out the non-expert weights.
|
||||
# `e_score_correction_bias` is a bias for each logical expert,
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
@ -1019,9 +1019,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
|
||||
)
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -1037,6 +1038,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
num_fused_shared_experts=layer.num_fused_shared_experts,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
@ -1145,6 +1151,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user