mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 16:53:38 +08:00
[Bug] Fix compressed tensor not using deepgemm (#30820)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d2dc5dfc6e
commit
97000a2be7
@ -1696,7 +1696,6 @@ def fused_experts(
|
|||||||
and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
|
and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
|
||||||
):
|
):
|
||||||
assert quant_config is not None
|
assert quant_config is not None
|
||||||
assert apply_router_weight_on_input is False
|
|
||||||
return deep_gemm_moe_fp8(
|
return deep_gemm_moe_fp8(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
|
|||||||
@ -96,6 +96,7 @@ from vllm.utils.deep_gemm import (
|
|||||||
get_col_major_tma_aligned_tensor,
|
get_col_major_tma_aligned_tensor,
|
||||||
get_mk_alignment_for_contiguous_layout,
|
get_mk_alignment_for_contiguous_layout,
|
||||||
is_deep_gemm_e8m0_used,
|
is_deep_gemm_e8m0_used,
|
||||||
|
is_deep_gemm_supported,
|
||||||
)
|
)
|
||||||
from vllm.utils.import_utils import has_deep_gemm
|
from vllm.utils.import_utils import has_deep_gemm
|
||||||
|
|
||||||
@ -716,6 +717,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
get_marlin_input_dtype(layer_name) if self.use_marlin else None
|
get_marlin_input_dtype(layer_name) if self.use_marlin else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.allow_deep_gemm = (
|
||||||
|
self.block_quant
|
||||||
|
and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||||
|
and is_deep_gemm_supported()
|
||||||
|
and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout()
|
||||||
|
)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1231,6 +1239,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
if self.disable_expert_map
|
if self.disable_expert_map
|
||||||
else layer.expert_map, # ???
|
else layer.expert_map, # ???
|
||||||
quant_config=self.moe_quant_config,
|
quant_config=self.moe_quant_config,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
@ -1272,6 +1281,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
global_num_experts=layer.global_num_experts,
|
global_num_experts=layer.global_num_experts,
|
||||||
expert_map=layer.expert_map,
|
expert_map=layer.expert_map,
|
||||||
quant_config=self.moe_quant_config,
|
quant_config=self.moe_quant_config,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user