diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 7d369edfc86a4..3f38b9fbcb3ca 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -28,17 +28,27 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): + if config.dtype == torch.float8_e4m3fn: + config_dtype = torch.bfloat16 + else: + config_dtype = config.dtype + A = torch.randn( (config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", - dtype=config.dtype) / 10 + dtype=config_dtype) / 10 B = torch.randn((config.num_experts, config.N, config.K), device="cuda", - dtype=config.dtype) + dtype=config_dtype) C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.dtype) + dtype=config_dtype) + + A = A.to(config.dtype) + B = B.to(config.dtype) + C = C.to(config.dtype) + num_expert_tokens = torch.randint(low=0, high=config.max_tokens_per_expert, size=(config.num_experts, ), @@ -66,8 +76,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "dtype", + [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype): @@ -78,6 +89,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ref_output = test_output.clone() compute_tl_dtype = { + torch.float8_e4m3fn: tl.bfloat16, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32 @@ -93,7 +105,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, None, None, # Quantization schemes - False, + dtype == torch.float8_e4m3fn, 
False, False, config={ @@ -106,6 +118,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, tensors.num_expert_tokens) rtol, atol = { + torch.float8_e4m3fn: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2), torch.float32: (1e-2, 1e-2), diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 5c262287f7dd4..24ab494e9908c 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,7 +4,8 @@ from contextlib import contextmanager from typing import Any, Optional from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -29,6 +30,7 @@ __all__ = [ "FusedMoeWeightScaleSupported", "override_config", "get_config", + "MOE_DP_CHUNK_SIZE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index c2db793659312..50c9e21af6a04 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -9,7 +9,8 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, + _resize_cache) @triton.jit @@ -733,12 +734,27 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale # TODO (varun) : support w8a8 - assert not self.use_fp8_w8a8 - #if self.use_fp8_w8a8: - # 
qintermediate_cache2, a2q_scale = _fp8_quantize( - # intermediate_cache2, a2_scale, self.block_shape) + #assert not self.use_fp8_w8a8 + if self.use_fp8_w8a8: + per_act_token = False + qintermediate_cache2 = torch.empty_like(intermediate_cache2, + dtype=torch.float8_e4m3fn) + if per_act_token: + scale_shape = (E, num_tokens, 1) + else: + scale_shape = (E, 1) + a2q_scale = torch.empty(scale_shape, + dtype=torch.float32, + device=hidden_states.device) + for e in range(E): + qintermediate_cache2[e], a2q_scale[e] = _fp8_quantize( + intermediate_cache2[e, :expert_num_tokens[e]], + a2_scale[e] if a2_scale is not None else None, + per_act_token, self.block_shape) + else: + qintermediate_cache2 = intermediate_cache2 - invoke_moe_batched_triton_kernel(A=intermediate_cache2, + invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, C=intermediate_cache3, expert_num_tokens=expert_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 838a7c24b642f..0f9058a3feedd 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -56,7 +56,7 @@ logger = init_logger(__name__) # Note: this limit is somewhat arbitrary and might be changed later. # The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim. 
-MOE_DP_CHUNK_SIZE = 256 +MOE_DP_CHUNK_SIZE = 128 @dataclass @@ -72,7 +72,7 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.dp_size > 1 and self.use_ep and \ + return self.dp_size > 1 and self.use_ep and has_pplx and \ envs.VLLM_ALL2ALL_BACKEND == "pplx" @staticmethod @@ -184,6 +184,7 @@ class FusedMoEParallelConfig: # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: + max_num_tokens: int num_experts: int experts_per_token: int hidden_dim: int @@ -471,6 +472,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) + def set_prepare_finalize( + self, + dp_size: int, + world_size: int, + prepare_finalize: FusedMoEPrepareAndFinalize, + ) -> bool: + assert self.fused_experts == fused_experts + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + return True + def forward_cuda( self, layer: torch.nn.Module, @@ -785,14 +827,17 @@ class FusedMoE(torch.nn.Module): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + logger.debug(f"PARAM DTYPE = {params_dtype}") + #assert params_dtype.itemsize == 1 + moe = MoEConfig( + 
max_num_tokens=MOE_DP_CHUNK_SIZE, num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. - in_dtype=params_dtype, + in_dtype=params_dtype, max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe @@ -1195,6 +1240,8 @@ class FusedMoE(torch.nn.Module): if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + assert topk_ids.dtype == indices_type + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 783ebebbfec94..a449ac5ca596f 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): per_act_token, self.block_shape) + if a1q_scale is not None and a1q_scale.dim() == 1: + assert a1q_scale.numel() == 1 + a1q_scale = a1q_scale.view(1, 1) + # rem_experts need to be 0 for pplx to work properly. 
rem_experts = num_experts % self.world_size assert rem_experts == 0 @@ -104,6 +108,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None + #print(f"SCALE= {a1q_scale.shape}") + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ac9b74945e0ce..7effd0e1ad24b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -13,7 +13,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, +from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE, + FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.fused_experts = functools.partial( # type: ignore fused_experts, + use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) + self.use_pplx_kernels = False + self.rocm_aiter_moe_enabled = False + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -770,13 +775,26 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") - experts = TritonOrDeepGemmExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - ) - - return experts + if 
isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts(fp8)") + return BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=world_size, + dp_size=dp_size, + use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=self.quant_config.weight_block_size, + ) + else: + logger.debug("TritonOrDeepGemmExperts(fp8)") + return TritonOrDeepGemmExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) def apply( self, @@ -807,7 +825,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - ) + indices_type=torch.uint32 if self.use_pplx_kernels else None) if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 @@ -854,7 +872,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_fp8_w8a8=True, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0d52bc9a68148..fda0a5b554bcd 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -492,6 +492,11 @@ class MPClient(EngineCoreClient): (e for e in self.core_engines if e.identity == eng_identity), None) if engine is None: + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + logger.debug(f"XXXXXX {status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}") raise RuntimeError(f"Message from engine with unexpected data " f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes)