merge
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent 468d16654a
commit c169b05541
@@ -184,7 +184,6 @@ class FusedMoEParallelConfig:
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
    max_num_tokens: int
    num_experts: int
    experts_per_token: int
    hidden_dim: int
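Note: for readers unfamiliar with the config object, here is a minimal self-contained sketch of a dataclass shaped like the hunk above, with an example instantiation. The field subset, the default dtype, and the concrete values are assumptions for illustration, not the actual vLLM definition.

# --- illustrative sketch, not part of the commit ---
from dataclasses import dataclass

import torch


@dataclass
class MoEConfigSketch:
    """Stand-in for vLLM's MoEConfig, reduced to the fields visible above."""
    max_num_tokens: int
    num_experts: int
    experts_per_token: int
    hidden_dim: int
    in_dtype: torch.dtype = torch.bfloat16  # assumed default


# Example instantiation with made-up sizes.
cfg = MoEConfigSketch(
    max_num_tokens=256,
    num_experts=64,
    experts_per_token=8,
    hidden_dim=4096,
)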
@@ -347,33 +346,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

        if isinstance(prepare_finalize,
                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
            logger.debug("BatchedTritonExperts %s", self.moe)
            experts = BatchedTritonExperts(
            return BatchedTritonExperts(
                max_num_tokens=MOE_DP_CHUNK_SIZE,
                world_size=all2all_manager.world_size,
                # dp_size actually means tp_size, bug in pplx kernels
                dp_size=all2all_manager.tp_group.world_size,
                use_fp8_w8a8=False,
                use_int8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
                block_shape=None,
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            experts = TritonExperts(
                use_fp8_w8a8=False,
                use_int8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
                block_shape=None,
                per_channel_quant=False,
            )
        return experts
        return TritonExperts()

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
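Note: the hunk above boils down to one dispatch decision: pick the batched Triton experts when the prepare/finalize object is a batched or pplx all-to-all variant, otherwise fall back to the plain Triton experts. Below is a stripped-down sketch of that pattern with stand-in classes and far fewer constructor arguments than the real kernels take.

# --- illustrative sketch, not part of the commit ---
class BatchedPrepareAndFinalize:      # stand-in for the vLLM class
    pass


class PplxPrepareAndFinalize:         # stand-in for the vLLM class
    pass


class TritonExperts:                  # stand-in experts implementation
    pass


class BatchedTritonExperts:           # stand-in experts implementation
    def __init__(self, max_num_tokens: int, world_size: int, dp_size: int):
        self.max_num_tokens = max_num_tokens
        self.world_size = world_size
        self.dp_size = dp_size


def select_gemm_impl(prepare_finalize, max_num_tokens: int,
                     world_size: int, dp_size: int):
    # Batched/pplx all-to-all paths need the batched kernel with explicit
    # sizing; every other path uses the default Triton experts.
    if isinstance(prepare_finalize,
                  (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
        return BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                    world_size=world_size,
                                    dp_size=dp_size)
    return TritonExperts()


print(type(select_gemm_impl(PplxPrepareAndFinalize(), 256, 8, 2)).__name__)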
@@ -472,35 +456,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input)

    def set_prepare_finalize(
        self,
        dp_size: int,
        world_size: int,
        prepare_finalize: FusedMoEPrepareAndFinalize,
    ) -> bool:
        assert self.fused_experts == fused_experts

        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

        if isinstance(prepare_finalize,
                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
            logger.debug("BatchedTritonExperts %s", self.moe)
            experts = BatchedTritonExperts(
                max_num_tokens=MOE_DP_CHUNK_SIZE,
                world_size=world_size,
                dp_size=dp_size,
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            experts = TritonExperts()

        self.fused_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
        )

        return True

    def forward_cuda(
        self,
        layer: torch.nn.Module,
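Note: set_prepare_finalize above wires a prepare/finalize step and an experts implementation into one FusedMoEModularKernel. Here is a rough sketch of what that composition amounts to; the prepare/apply/finalize call protocol shown is an assumption, not the actual vLLM interface.

# --- illustrative sketch, not part of the commit ---
class ModularKernelSketch:
    """Assumed three-stage pipeline: prepare -> experts -> finalize."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, hidden_states, router_logits):
        # 1. Scatter / quantize activations into the experts' layout.
        staged = self.prepare_finalize.prepare(hidden_states, router_logits)
        # 2. Run the grouped expert GEMMs.
        expert_out = self.experts.apply(staged)
        # 3. Gather / combine per-expert outputs back to token order.
        return self.prepare_finalize.finalize(expert_out)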
@@ -815,16 +770,14 @@ class FusedMoE(torch.nn.Module):
            from vllm_hpu_extension.ops import DynamicFusedMOE
            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

        logger.debug("Model dtype = %s", vllm_config.model_config.dtype)

        moe = MoEConfig(
            max_num_tokens=MOE_DP_CHUNK_SIZE,
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
            hidden_dim=hidden_size,
            num_local_experts=self.local_num_experts,
            moe_parallel_config=self.moe_parallel_config,
            in_dtype=moe.in_dtype,
            in_dtype=vllm_config.model_config.dtype,
            max_num_tokens=MOE_DP_CHUNK_SIZE,
        )
        self.moe_config = moe
        self.quant_config = quant_config
@@ -1281,7 +1234,7 @@ class FusedMoE(torch.nn.Module):
        assert (self.batched_hidden_states.size(0)  # type: ignore
                >= chunk_size)
        assert (self.batched_router_logits.size(0)  # type: ignore
                >= chunk_size)
        staged_hidden_states = self.batched_hidden_states[:
                                                          chunk_size, :]  # type: ignore
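Note: the assertions above guard the chunked data-parallel path: the preallocated batched buffers must hold at least one chunk before a chunk-sized slice is staged out of them. Below is a self-contained sketch of that slicing, with made-up buffer sizes.

# --- illustrative sketch, not part of the commit ---
import torch

MOE_DP_CHUNK_SIZE = 256          # assumed value for illustration
hidden_dim, num_experts = 4096, 64

# Staging buffers preallocated to the largest chunk the layer will see.
batched_hidden_states = torch.empty(MOE_DP_CHUNK_SIZE, hidden_dim)
batched_router_logits = torch.empty(MOE_DP_CHUNK_SIZE, num_experts)

chunk_size = 128                 # tokens in the current chunk
assert batched_hidden_states.size(0) >= chunk_size
assert batched_router_logits.size(0) >= chunk_size

# Only the live prefix of each buffer is handed to the kernel.
staged_hidden_states = batched_hidden_states[:chunk_size, :]
staged_router_logits = batched_router_logits[:chunk_size, :]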
@@ -108,8 +108,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        # There's not much point setting this unless it is != indices.size(0)
        bound_m: Optional[torch.Tensor] = None

        #print(f"SCALE= {a1q_scale.shape}")

        self.a2a.dispatch(
            out_expert_num_tokens=expert_num_tokens,
            out_expert_x=expert_x,
@@ -11,7 +11,7 @@ from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
                                                  FusedMoEMethodBase,
@@ -771,17 +771,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    def select_gemm_impl(self, prepare_finalize):
        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
            TritonOrDeepGemmExperts)
        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
            BatchedPrepareAndFinalize,
            BatchedTritonExperts)
        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
            PplxPrepareAndFinalize)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        if isinstance(prepare_finalize,
                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
            logger.debug("BatchedTritonExperts(fp8)")
            self.use_pplx_kernels = True
            return BatchedTritonExperts(
                max_num_tokens=MOE_DP_CHUNK_SIZE,
                world_size=world_size,
                dp_size=dp_size,
                world_size=all2all_manager.world_size,
                dp_size=all2all_manager.tp_group.world_size,
                qtype=torch.float8_e4m3fn,
                block_shape=self.quant_config.weight_block_size,
                per_act_token=False,  #?
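Note: the fp8 hunk above replaces caller-supplied world_size/dp_size with values read off the all2all manager. The shape of that change, sketched with placeholder objects; the attribute names follow the hunk, everything else is assumed.

# --- illustrative sketch, not part of the commit ---
from dataclasses import dataclass


@dataclass
class TPGroupStub:                 # stand-in for the manager's TP group
    world_size: int


@dataclass
class All2AllManagerStub:          # stand-in for the device communicator's manager
    world_size: int
    tp_group: TPGroupStub


def batched_expert_sizes(all2all_manager: All2AllManagerStub) -> tuple[int, int]:
    # Read the sizes the batched kernel needs from communicator state
    # instead of threading them through as extra method arguments.
    assert all2all_manager is not None
    world_size = all2all_manager.world_size
    # Per the comment in the unquantized hunk, "dp_size" actually carries the
    # TP group size because of a naming quirk in the pplx kernels.
    dp_size = all2all_manager.tp_group.world_size
    return world_size, dp_size


print(batched_expert_sizes(All2AllManagerStub(world_size=8,
                                              tp_group=TPGroupStub(2))))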
@@ -492,11 +492,6 @@ class MPClient(EngineCoreClient):
            (e for e in self.core_engines if e.identity == eng_identity),
            None)
        if engine is None:
            msg = msgspec.msgpack.decode(ready_msg_bytes)
            status, local = msg["status"], msg["local"]
            logger.debug(f"XXXXXX {status} message from "
                         f"{'local' if local else 'remote'} "
                         f"engine {eng_index}")
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
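Note: the ready-message handling above round-trips a msgpack payload with msgspec and reads status/local fields from it. A minimal, self-contained sketch of that round trip; the message schema is only what the hunk implies.

# --- illustrative sketch, not part of the commit ---
import msgspec

# Encode a ready message roughly the way an engine might send it; the field
# set here is inferred from the keys the hunk reads, not the real schema.
ready_msg_bytes = msgspec.msgpack.encode({"status": "READY", "local": True})

# Client side: decode and pull out the same fields the hunk reads.
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
print(f"{status} message from {'local' if local else 'remote'} engine")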