mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 22:04:29 +08:00
fp8 support
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
269d901734
commit
e568e401da
@ -28,17 +28,27 @@ class BatchedMMTensors:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_tensors(config: BatchedMMConfig):
|
def make_tensors(config: BatchedMMConfig):
|
||||||
|
if config.dtype == torch.torch.float8_e4m3fn:
|
||||||
|
config_dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
config_dtype = config.dtype
|
||||||
|
|
||||||
A = torch.randn(
|
A = torch.randn(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype) / 10
|
dtype=config_dtype) / 10
|
||||||
B = torch.randn((config.num_experts, config.N, config.K),
|
B = torch.randn((config.num_experts, config.N, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype)
|
dtype=config_dtype)
|
||||||
C = torch.zeros(
|
C = torch.zeros(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype)
|
dtype=config_dtype)
|
||||||
|
|
||||||
|
A = A.to(config.dtype)
|
||||||
|
B = B.to(config.dtype)
|
||||||
|
C = C.to(config.dtype)
|
||||||
|
|
||||||
num_expert_tokens = torch.randint(low=0,
|
num_expert_tokens = torch.randint(low=0,
|
||||||
high=config.max_tokens_per_expert,
|
high=config.max_tokens_per_expert,
|
||||||
size=(config.num_experts, ),
|
size=(config.num_experts, ),
|
||||||
@ -66,8 +76,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
|||||||
[32, 64, 128, 192, 224, 256, 512])
|
[32, 64, 128, 192, 224, 256, 512])
|
||||||
@pytest.mark.parametrize("K", [128, 256, 1024])
|
@pytest.mark.parametrize("K", [128, 256, 1024])
|
||||||
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize(
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
"dtype",
|
||||||
|
[torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
|
||||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||||
N: int, dtype: torch.dtype):
|
N: int, dtype: torch.dtype):
|
||||||
|
|
||||||
@ -78,6 +89,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
ref_output = test_output.clone()
|
ref_output = test_output.clone()
|
||||||
|
|
||||||
compute_tl_dtype = {
|
compute_tl_dtype = {
|
||||||
|
torch.torch.float8_e4m3fn: tl.bfloat16,
|
||||||
torch.float16: tl.float16,
|
torch.float16: tl.float16,
|
||||||
torch.bfloat16: tl.bfloat16,
|
torch.bfloat16: tl.bfloat16,
|
||||||
torch.float32: tl.float32
|
torch.float32: tl.float32
|
||||||
@ -93,7 +105,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
# Quantization schemes
|
# Quantization schemes
|
||||||
False,
|
dtype == torch.torch.float8_e4m3fn,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
config={
|
config={
|
||||||
@ -106,6 +118,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
tensors.num_expert_tokens)
|
tensors.num_expert_tokens)
|
||||||
|
|
||||||
rtol, atol = {
|
rtol, atol = {
|
||||||
|
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
|
||||||
torch.float16: (6e-2, 6e-2),
|
torch.float16: (6e-2, 6e-2),
|
||||||
torch.bfloat16: (6e-2, 6e-2),
|
torch.bfloat16: (6e-2, 6e-2),
|
||||||
torch.float32: (1e-2, 1e-2),
|
torch.float32: (1e-2, 1e-2),
|
||||||
|
|||||||
@ -4,7 +4,8 @@ from contextlib import contextmanager
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
|
||||||
|
FusedMoeWeightScaleSupported)
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
_config: Optional[dict[str, Any]] = None
|
_config: Optional[dict[str, Any]] = None
|
||||||
@ -29,6 +30,7 @@ __all__ = [
|
|||||||
"FusedMoeWeightScaleSupported",
|
"FusedMoeWeightScaleSupported",
|
||||||
"override_config",
|
"override_config",
|
||||||
"get_config",
|
"get_config",
|
||||||
|
"MOE_DP_CHUNK_SIZE",
|
||||||
]
|
]
|
||||||
|
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
|
|||||||
@ -9,7 +9,8 @@ import triton.language as tl
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
get_config_dtype_str, try_get_optimal_moe_config)
|
get_config_dtype_str, try_get_optimal_moe_config)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||||
|
_resize_cache)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@ -733,12 +734,27 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
#qintermediate_cache2 = intermediate_cache2
|
#qintermediate_cache2 = intermediate_cache2
|
||||||
a2q_scale = a2_scale
|
a2q_scale = a2_scale
|
||||||
# TODO (varun) : support w8a8
|
# TODO (varun) : support w8a8
|
||||||
assert not self.use_fp8_w8a8
|
#assert not self.use_fp8_w8a8
|
||||||
#if self.use_fp8_w8a8:
|
if self.use_fp8_w8a8:
|
||||||
# qintermediate_cache2, a2q_scale = _fp8_quantize(
|
per_act_token = False
|
||||||
# intermediate_cache2, a2_scale, self.block_shape)
|
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
|
||||||
|
dtype=torch.float8_e4m3fn)
|
||||||
|
if per_act_token:
|
||||||
|
scale_shape = (E, num_tokens, 1)
|
||||||
|
else:
|
||||||
|
scale_shape = (E, 1)
|
||||||
|
a2q_scale = torch.empty(scale_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
for e in range(E):
|
||||||
|
qintermediate_cache2[e], a2q_scale[e] = _fp8_quantize(
|
||||||
|
intermediate_cache2[e, :expert_num_tokens[e]],
|
||||||
|
a2_scale[e] if a2_scale is not None else None,
|
||||||
|
per_act_token, self.block_shape)
|
||||||
|
else:
|
||||||
|
qintermediate_cache2 = intermediate_cache2
|
||||||
|
|
||||||
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
|
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
|
||||||
B=w2,
|
B=w2,
|
||||||
C=intermediate_cache3,
|
C=intermediate_cache3,
|
||||||
expert_num_tokens=expert_num_tokens,
|
expert_num_tokens=expert_num_tokens,
|
||||||
|
|||||||
@ -56,7 +56,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
# Note: this limit is somewhat arbitrary and might be changed later.
|
# Note: this limit is somewhat arbitrary and might be changed later.
|
||||||
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
||||||
MOE_DP_CHUNK_SIZE = 256
|
MOE_DP_CHUNK_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -72,7 +72,7 @@ class FusedMoEParallelConfig:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_pplx_kernels(self):
|
def use_pplx_kernels(self):
|
||||||
return self.dp_size > 1 and self.use_ep and \
|
return self.dp_size > 1 and self.use_ep and has_pplx and \
|
||||||
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -184,6 +184,7 @@ class FusedMoEParallelConfig:
|
|||||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||||
@dataclass
|
@dataclass
|
||||||
class MoEConfig:
|
class MoEConfig:
|
||||||
|
max_num_tokens: int
|
||||||
num_experts: int
|
num_experts: int
|
||||||
experts_per_token: int
|
experts_per_token: int
|
||||||
hidden_dim: int
|
hidden_dim: int
|
||||||
@ -471,6 +472,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||||
|
|
||||||
|
def set_prepare_finalize(
|
||||||
|
self,
|
||||||
|
dp_size: int,
|
||||||
|
world_size: int,
|
||||||
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
) -> bool:
|
||||||
|
assert self.fused_experts == fused_experts
|
||||||
|
|
||||||
|
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||||
|
|
||||||
|
if isinstance(prepare_finalize,
|
||||||
|
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||||
|
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||||
|
experts = BatchedTritonExperts(
|
||||||
|
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||||
|
world_size=world_size,
|
||||||
|
dp_size=dp_size,
|
||||||
|
use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn,
|
||||||
|
use_int8_w8a8=False,
|
||||||
|
use_int8_w8a16=False,
|
||||||
|
use_int4_w4a16=False,
|
||||||
|
block_shape=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("TritonExperts %s", self.moe)
|
||||||
|
experts = TritonExperts(
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
use_int8_w8a8=False,
|
||||||
|
use_int8_w8a16=False,
|
||||||
|
use_int4_w4a16=False,
|
||||||
|
block_shape=None,
|
||||||
|
per_channel_quant=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fused_experts = FusedMoEModularKernel(
|
||||||
|
prepare_finalize,
|
||||||
|
experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -785,14 +827,17 @@ class FusedMoE(torch.nn.Module):
|
|||||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||||
|
|
||||||
|
logger.debug(f"PARAM DTYPE = {params_dtype}")
|
||||||
|
#assert params_dtype.itemsize == 1
|
||||||
|
|
||||||
moe = MoEConfig(
|
moe = MoEConfig(
|
||||||
|
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
hidden_dim=hidden_size,
|
hidden_dim=hidden_size,
|
||||||
num_local_experts=self.local_num_experts,
|
num_local_experts=self.local_num_experts,
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
# TODO (bnell): this needs to be fixed for quantized types.
|
in_dtype=moe.in_dtype,
|
||||||
in_dtype=params_dtype,
|
|
||||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||||
)
|
)
|
||||||
self.moe_config = moe
|
self.moe_config = moe
|
||||||
@ -1195,6 +1240,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if indices_type is not None:
|
if indices_type is not None:
|
||||||
topk_ids = topk_ids.to(dtype=indices_type)
|
topk_ids = topk_ids.to(dtype=indices_type)
|
||||||
|
|
||||||
|
assert topk_ids.dtype == indices_type
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||||
|
|||||||
@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
per_act_token,
|
per_act_token,
|
||||||
self.block_shape)
|
self.block_shape)
|
||||||
|
|
||||||
|
if a1q_scale is not None and a1q_scale.dim() == 1:
|
||||||
|
assert a1q_scale.numel() == 1
|
||||||
|
a1q_scale = a1q_scale.view(1, 1)
|
||||||
|
|
||||||
# rem_experts need to be 0 for pplx to work properly.
|
# rem_experts need to be 0 for pplx to work properly.
|
||||||
rem_experts = num_experts % self.world_size
|
rem_experts = num_experts % self.world_size
|
||||||
assert rem_experts == 0
|
assert rem_experts == 0
|
||||||
@ -104,6 +108,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# There's not much point setting this unless it is != indices.size(0)
|
# There's not much point setting this unless it is != indices.size(0)
|
||||||
bound_m: Optional[torch.Tensor] = None
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
#print(f"SCALE= {a1q_scale.shape}")
|
||||||
|
|
||||||
self.a2a.dispatch(
|
self.a2a.dispatch(
|
||||||
out_expert_num_tokens=expert_num_tokens,
|
out_expert_num_tokens=expert_num_tokens,
|
||||||
out_expert_x=expert_x,
|
out_expert_x=expert_x,
|
||||||
|
|||||||
@ -13,7 +13,8 @@ import vllm.envs as envs
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
|
||||||
|
FusedMoEMethodBase,
|
||||||
FusedMoeWeightScaleSupported)
|
FusedMoeWeightScaleSupported)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
self.fused_experts = functools.partial( # type: ignore
|
self.fused_experts = functools.partial( # type: ignore
|
||||||
fused_experts,
|
fused_experts,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
allow_deep_gemm=self.allow_deep_gemm)
|
allow_deep_gemm=self.allow_deep_gemm)
|
||||||
|
|
||||||
|
self.use_pplx_kernels = False
|
||||||
|
self.rocm_aiter_moe_enabled = False
|
||||||
|
|
||||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
@ -770,13 +775,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||||
|
|
||||||
experts = TritonOrDeepGemmExperts(
|
if isinstance(prepare_finalize,
|
||||||
use_fp8_w8a8=True,
|
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||||
block_shape=self.quant_config.weight_block_size,
|
logger.debug("BatchedTritonExperts(fp8)")
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
return BatchedTritonExperts(
|
||||||
)
|
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||||
|
world_size=world_size,
|
||||||
return experts
|
dp_size=dp_size,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
use_int8_w8a8=False,
|
||||||
|
use_int8_w8a16=False,
|
||||||
|
use_int4_w4a16=False,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("TritonOrDeepGemmExperts(fp8)")
|
||||||
|
return TritonOrDeepGemmExperts(
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -807,7 +825,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
indices_type=torch.uint32 if self.use_pplx_kernels else None)
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||||
@ -854,7 +872,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_fp8_w8a8=True,
|
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
|
|||||||
@ -492,6 +492,11 @@ class MPClient(EngineCoreClient):
|
|||||||
(e for e in self.core_engines if e.identity == eng_identity),
|
(e for e in self.core_engines if e.identity == eng_identity),
|
||||||
None)
|
None)
|
||||||
if engine is None:
|
if engine is None:
|
||||||
|
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||||
|
status, local = msg["status"], msg["local"]
|
||||||
|
logger.debug(f"XXXXXX {status} message from "
|
||||||
|
f"{'local' if local else 'remote'} "
|
||||||
|
f"engine {eng_index}")
|
||||||
raise RuntimeError(f"Message from engine with unexpected data "
|
raise RuntimeError(f"Message from engine with unexpected data "
|
||||||
f"parallel rank: {eng_index}")
|
f"parallel rank: {eng_index}")
|
||||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user