mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 20:55:49 +08:00
[Bugfix] Disable shared expert overlap if Marlin MoE is used (#28410)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
28534b92b9
commit
e5f599d4d1
@ -678,6 +678,10 @@ class FusedMoE(CustomOp):
|
|||||||
and self.moe_config.use_flashinfer_cutlass_kernels
|
and self.moe_config.use_flashinfer_cutlass_kernels
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_marlin_kernels(self):
|
||||||
|
return getattr(self.quant_method, "use_marlin", False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_dp_chunking(self) -> bool:
|
def use_dp_chunking(self) -> bool:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -28,17 +28,17 @@ class SharedFusedMoE(FusedMoE):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._shared_experts = shared_experts
|
self._shared_experts = shared_experts
|
||||||
|
|
||||||
# Disable shared expert overlap if we are using eplb, because of
|
# Disable shared expert overlap if:
|
||||||
# correctness issues, or if using flashinfer with DP, since there
|
# - we are using eplb, because of correctness issues
|
||||||
# is nothing to be gained in this case. Disabling the overlap
|
# - we are using flashinfer with DP, since there is nothing to gain
|
||||||
# optimization also prevents the shared experts from being hidden
|
# - we are using marlin kernels
|
||||||
# from torch.compile.
|
|
||||||
self.use_overlapped = (
|
self.use_overlapped = (
|
||||||
use_overlapped
|
use_overlapped
|
||||||
and not (
|
and not (
|
||||||
# TODO(wentao): find the root cause and remove this condition
|
# TODO(wentao): find the root cause and remove this condition
|
||||||
self.enable_eplb
|
self.enable_eplb
|
||||||
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
|
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
|
||||||
|
or self.use_marlin_kernels
|
||||||
)
|
)
|
||||||
and self._shared_experts is not None
|
and self._shared_experts is not None
|
||||||
)
|
)
|
||||||
|
|||||||
@ -424,6 +424,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
if self.quant_config.weight_bits != 4:
|
if self.quant_config.weight_bits != 4:
|
||||||
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||||
self.quant_type = scalar_types.uint4
|
self.quant_type = scalar_types.uint4
|
||||||
|
self.use_marlin = True
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1342,6 +1342,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
f"{WNA16_SUPPORTED_BITS}",
|
f"{WNA16_SUPPORTED_BITS}",
|
||||||
)
|
)
|
||||||
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
|
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
|
||||||
|
self.use_marlin = True
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -482,6 +482,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
self.quant_type = scalar_types.uint8b128
|
self.quant_type = scalar_types.uint8b128
|
||||||
else:
|
else:
|
||||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||||
|
self.use_marlin = True
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
def __init__(self, moe: FusedMoEConfig):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
super().__init__(moe)
|
super().__init__(moe)
|
||||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||||
|
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||||
self.max_capture_size = (
|
self.max_capture_size = (
|
||||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user