fp8 support

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-21 03:21:18 +00:00
parent 269d901734
commit e568e401da
7 changed files with 133 additions and 27 deletions

View File

@ -28,17 +28,27 @@ class BatchedMMTensors:
@staticmethod
def make_tensors(config: BatchedMMConfig):
if config.dtype == torch.torch.float8_e4m3fn:
config_dtype = torch.bfloat16
else:
config_dtype = config.dtype
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
dtype=config_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
dtype=config_dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
dtype=config_dtype)
A = A.to(config.dtype)
B = B.to(config.dtype)
C = C.to(config.dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
@ -66,8 +76,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"dtype",
[torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
@ -78,6 +89,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
ref_output = test_output.clone()
compute_tl_dtype = {
torch.torch.float8_e4m3fn: tl.bfloat16,
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
@ -93,7 +105,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
None,
None,
# Quantization schemes
False,
dtype == torch.torch.float8_e4m3fn,
False,
False,
config={
@ -106,6 +118,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
tensors.num_expert_tokens)
rtol, atol = {
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),

View File

@ -4,7 +4,8 @@ from contextlib import contextmanager
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
@ -29,6 +30,7 @@ __all__ = [
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
"MOE_DP_CHUNK_SIZE",
]
if HAS_TRITON:

View File

@ -9,7 +9,8 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
@triton.jit
@ -733,12 +734,27 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
#assert not self.use_fp8_w8a8
if self.use_fp8_w8a8:
per_act_token = False
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
dtype=torch.float8_e4m3fn)
if per_act_token:
scale_shape = (E, num_tokens, 1)
else:
scale_shape = (E, 1)
a2q_scale = torch.empty(scale_shape,
dtype=torch.float32,
device=hidden_states.device)
for e in range(E):
qintermediate_cache2[e], a2q_scale[e] = _fp8_quantize(
intermediate_cache2[e, :expert_num_tokens[e]],
a2_scale[e] if a2_scale is not None else None,
per_act_token, self.block_shape)
else:
qintermediate_cache2 = intermediate_cache2
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,

View File

@ -56,7 +56,7 @@ logger = init_logger(__name__)
# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE = 256
MOE_DP_CHUNK_SIZE = 128
@dataclass
@ -72,7 +72,7 @@ class FusedMoEParallelConfig:
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and \
return self.dp_size > 1 and self.use_ep and has_pplx and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
@staticmethod
@ -184,6 +184,7 @@ class FusedMoEParallelConfig:
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
max_num_tokens: int
num_experts: int
experts_per_token: int
hidden_dim: int
@ -471,6 +472,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
    self,
    dp_size: int,
    world_size: int,
    prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
    """Replace ``self.fused_experts`` with a modular kernel.

    The experts implementation is chosen from the type of
    ``prepare_finalize`` and then combined with it into a
    ``FusedMoEModularKernel``, which is stored on ``self.fused_experts``.

    Args:
        dp_size: Passed through to ``BatchedTritonExperts`` as ``dp_size``.
        world_size: Passed through to ``BatchedTritonExperts`` as
            ``world_size``.
        prepare_finalize: First argument of the modular kernel; a
            ``BatchedPrepareAndFinalize`` or ``PplxPrepareAndFinalize``
            instance selects ``BatchedTritonExperts``, anything else
            selects ``TritonExperts``.

    Returns:
        Always ``True``.
    """
    # Sanity check: self.fused_experts must still equal `fused_experts`
    # when this method is called.
    assert self.fused_experts == fused_experts

    # Every quantization scheme is disabled on this path.
    # NOTE(review): an earlier sketch derived the fp8 flag from
    # `moe.in_dtype == torch.float8_e4m3fn`; it is hard-coded to False here.
    quant_disabled = dict(
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        block_shape=None,
    )

    wants_batched = isinstance(
        prepare_finalize,
        (BatchedPrepareAndFinalize, PplxPrepareAndFinalize),
    )

    experts: FusedMoEPermuteExpertsUnpermute
    if not wants_batched:
        logger.debug("TritonExperts %s", self.moe)
        experts = TritonExperts(per_channel_quant=False, **quant_disabled)
    else:
        logger.debug("BatchedTritonExperts %s", self.moe)
        experts = BatchedTritonExperts(
            max_num_tokens=MOE_DP_CHUNK_SIZE,
            world_size=world_size,
            dp_size=dp_size,
            **quant_disabled,
        )

    self.fused_experts = FusedMoEModularKernel(prepare_finalize, experts)
    return True
def forward_cuda(
self,
layer: torch.nn.Module,
@ -785,14 +827,17 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
logger.debug(f"PARAM DTYPE = {params_dtype}")
#assert params_dtype.itemsize == 1
moe = MoEConfig(
max_num_tokens=MOE_DP_CHUNK_SIZE,
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
in_dtype=moe.in_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
@ -1195,6 +1240,8 @@ class FusedMoE(torch.nn.Module):
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:

View File

@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
per_act_token,
self.block_shape)
if a1q_scale is not None and a1q_scale.dim() == 1:
assert a1q_scale.numel() == 1
a1q_scale = a1q_scale.view(1, 1)
# rem_experts needs to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size
assert rem_experts == 0
@ -104,6 +108,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
#print(f"SCALE= {a1q_scale.shape}")
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,

View File

@ -13,7 +13,8 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
self.use_pplx_kernels = False
self.rocm_aiter_moe_enabled = False
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -770,13 +775,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
return experts
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts(fp8)")
return BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=self.quant_config.weight_block_size,
)
else:
logger.debug("TritonOrDeepGemmExperts(fp8)")
return TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
def apply(
self,
@ -807,7 +825,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
indices_type=torch.uint32 if self.use_pplx_kernels else None)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
@ -854,7 +872,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,

View File

@ -492,6 +492,11 @@ class MPClient(EngineCoreClient):
(e for e in self.core_engines if e.identity == eng_identity),
None)
if engine is None:
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
logger.debug(f"XXXXXX {status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}")
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)