mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 06:37:03 +08:00
fp8 support
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
269d901734
commit
e568e401da
@ -28,17 +28,27 @@ class BatchedMMTensors:
|
||||
|
||||
@staticmethod
|
||||
def make_tensors(config: BatchedMMConfig):
|
||||
if config.dtype == torch.torch.float8_e4m3fn:
|
||||
config_dtype = torch.bfloat16
|
||||
else:
|
||||
config_dtype = config.dtype
|
||||
|
||||
A = torch.randn(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||
device="cuda",
|
||||
dtype=config.dtype) / 10
|
||||
dtype=config_dtype) / 10
|
||||
B = torch.randn((config.num_experts, config.N, config.K),
|
||||
device="cuda",
|
||||
dtype=config.dtype)
|
||||
dtype=config_dtype)
|
||||
C = torch.zeros(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||
device="cuda",
|
||||
dtype=config.dtype)
|
||||
dtype=config_dtype)
|
||||
|
||||
A = A.to(config.dtype)
|
||||
B = B.to(config.dtype)
|
||||
C = C.to(config.dtype)
|
||||
|
||||
num_expert_tokens = torch.randint(low=0,
|
||||
high=config.max_tokens_per_expert,
|
||||
size=(config.num_experts, ),
|
||||
@ -66,8 +76,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
[32, 64, 128, 192, 224, 256, 512])
|
||||
@pytest.mark.parametrize("K", [128, 256, 1024])
|
||||
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
N: int, dtype: torch.dtype):
|
||||
|
||||
@ -78,6 +89,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
ref_output = test_output.clone()
|
||||
|
||||
compute_tl_dtype = {
|
||||
torch.torch.float8_e4m3fn: tl.bfloat16,
|
||||
torch.float16: tl.float16,
|
||||
torch.bfloat16: tl.bfloat16,
|
||||
torch.float32: tl.float32
|
||||
@ -93,7 +105,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
None,
|
||||
None,
|
||||
# Quantization schemes
|
||||
False,
|
||||
dtype == torch.torch.float8_e4m3fn,
|
||||
False,
|
||||
False,
|
||||
config={
|
||||
@ -106,6 +118,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
tensors.num_expert_tokens)
|
||||
|
||||
rtol, atol = {
|
||||
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
|
||||
@ -4,7 +4,8 @@ from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
@ -29,6 +30,7 @@ __all__ = [
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"override_config",
|
||||
"get_config",
|
||||
"MOE_DP_CHUNK_SIZE",
|
||||
]
|
||||
|
||||
if HAS_TRITON:
|
||||
|
||||
@ -9,7 +9,8 @@ import triton.language as tl
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -733,12 +734,27 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
#qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
# TODO (varun) : support w8a8
|
||||
assert not self.use_fp8_w8a8
|
||||
#if self.use_fp8_w8a8:
|
||||
# qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
# intermediate_cache2, a2_scale, self.block_shape)
|
||||
#assert not self.use_fp8_w8a8
|
||||
if self.use_fp8_w8a8:
|
||||
per_act_token = False
|
||||
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
if per_act_token:
|
||||
scale_shape = (E, num_tokens, 1)
|
||||
else:
|
||||
scale_shape = (E, 1)
|
||||
a2q_scale = torch.empty(scale_shape,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
for e in range(E):
|
||||
qintermediate_cache2[e], a2q_scale[e] = _fp8_quantize(
|
||||
intermediate_cache2[e, :expert_num_tokens[e]],
|
||||
a2_scale[e] if a2_scale is not None else None,
|
||||
per_act_token, self.block_shape)
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
|
||||
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
|
||||
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
|
||||
B=w2,
|
||||
C=intermediate_cache3,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
|
||||
@ -56,7 +56,7 @@ logger = init_logger(__name__)
|
||||
|
||||
# Note: this limit is somewhat arbitrary and might be changed later.
|
||||
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
||||
MOE_DP_CHUNK_SIZE = 256
|
||||
MOE_DP_CHUNK_SIZE = 128
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -72,7 +72,7 @@ class FusedMoEParallelConfig:
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep and \
|
||||
return self.dp_size > 1 and self.use_ep and has_pplx and \
|
||||
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
||||
|
||||
@staticmethod
|
||||
@ -184,6 +184,7 @@ class FusedMoEParallelConfig:
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class MoEConfig:
|
||||
max_num_tokens: int
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
@ -471,6 +472,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
assert self.fused_experts == fused_experts
|
||||
|
||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -785,14 +827,17 @@ class FusedMoE(torch.nn.Module):
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||
|
||||
logger.debug(f"PARAM DTYPE = {params_dtype}")
|
||||
#assert params_dtype.itemsize == 1
|
||||
|
||||
moe = MoEConfig(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
in_dtype=moe.in_dtype,
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
)
|
||||
self.moe_config = moe
|
||||
@ -1195,6 +1240,8 @@ class FusedMoE(torch.nn.Module):
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
|
||||
@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
per_act_token,
|
||||
self.block_shape)
|
||||
|
||||
if a1q_scale is not None and a1q_scale.dim() == 1:
|
||||
assert a1q_scale.numel() == 1
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
|
||||
# rem_experts need to be 0 for pplx to work properly.
|
||||
rem_experts = num_experts % self.world_size
|
||||
assert rem_experts == 0
|
||||
@ -104,6 +108,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
#print(f"SCALE= {a1q_scale.shape}")
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
|
||||
@ -13,7 +13,8 @@ import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.fused_experts = functools.partial( # type: ignore
|
||||
fused_experts,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm)
|
||||
|
||||
self.use_pplx_kernels = False
|
||||
self.rocm_aiter_moe_enabled = False
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@ -770,13 +775,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
experts = TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
return experts
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
logger.debug("BatchedTritonExperts(fp8)")
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonOrDeepGemmExperts(fp8)")
|
||||
return TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -807,7 +825,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
indices_type=torch.uint32 if self.use_pplx_kernels else None)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
@ -854,7 +872,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
|
||||
@ -492,6 +492,11 @@ class MPClient(EngineCoreClient):
|
||||
(e for e in self.core_engines if e.identity == eng_identity),
|
||||
None)
|
||||
if engine is None:
|
||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||
status, local = msg["status"], msg["local"]
|
||||
logger.debug(f"XXXXXX {status} message from "
|
||||
f"{'local' if local else 'remote'} "
|
||||
f"engine {eng_index}")
|
||||
raise RuntimeError(f"Message from engine with unexpected data "
|
||||
f"parallel rank: {eng_index}")
|
||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user