[Bugfix] Fix fusion for VL models (#30244)

Signed-off-by: ElizaWszola <ewszola@redhat.com>

parent 48b8456ff9
commit 994acec0cc
@@ -27,6 +27,7 @@ is_blackwell = lambda: current_platform.is_device_capability_family(100)
 class Matches(NamedTuple):
     attention_fusion: int = 0
     allreduce_fusion: int = 0
+    rms_quant_norm_fusion: int = 0
     sequence_parallel: int = 0
     async_tp: int = 0

@@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):

 MODELS_FP8: list[ModelBackendTestCase] = []
 MODELS_FP4: list[ModelBackendTestCase] = []
+MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
 MODELS: list[ModelBackendTestCase] = []  # tp-only

 if current_platform.is_cuda():
@@ -498,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
     compilation_config.compile_ranges_split_points = (
         llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
     )
+
+
+if current_platform.is_cuda():
+    MODELS_GROUP_FP8 = [
+        ModelBackendTestCase(
+            model_name="Qwen/Qwen3-30B-A3B-FP8",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=AttentionBackendEnum.TRITON_ATTN,
+            matches=Matches(
+                rms_quant_norm_fusion=48,
+            ),
+        ),
+    ]
+
+CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
+
+
+@pytest.mark.parametrize(
+    "model_name, model_kwargs, backend, matches, custom_ops",
+    # Test rms norm+group quant_fp8 fusion
+    list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+def test_rms_group_quant(
+    model_name: str,
+    model_kwargs: dict[str, Any],
+    backend: AttentionBackendEnum,
+    matches: Matches,
+    custom_ops: str,
+    inductor_graph_partition: bool,
+    caplog_mp_spawn,
+    monkeypatch,
+):
+    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition requires torch>=2.9")
+
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+    if inductor_graph_partition:
+        mode = CUDAGraphMode.FULL_AND_PIECEWISE
+        splitting_ops: list[str] | None = None
+    else:
+        mode = CUDAGraphMode.FULL_DECODE_ONLY
+        splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+    # Otherwise, we can't verify fusion happened through the logs.
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    # To capture subprocess logs, we need to know whether spawn or fork is used.
+    # Force spawn as it is more general.
+    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    compilation_config = CompilationConfig(
+        # Testing properties
+        custom_ops=custom_ops_list,
+        use_inductor_graph_partition=inductor_graph_partition,
+        cudagraph_mode=mode,
+        splitting_ops=splitting_ops,
+        # Common
+        mode=CompilationMode.VLLM_COMPILE,
+        pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
+        # Inductor caches custom passes by default as well via uuid
+        inductor_compile_config={"force_disable_caches": True},
+    )
+
+    with caplog_mp_spawn(logging.DEBUG) as log_holder:
+        run_model(compilation_config, model_name, **model_kwargs)
+
+    log_matches = re.findall(
+        r"\[fusion.py:\d+] Replaced (\d+) patterns",
+        log_holder.text,
+    )
+    assert len(log_matches) == 1, log_holder.text
+    assert int(log_matches[0]) == matches.rms_quant_norm_fusion
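The parametrization in the new test relies on a flat_product helper from the same test module, which this diff does not show. Below is a minimal sketch of the behavior it is assumed to have, so the five-name parametrize signature above reads clearly; the real helper may be implemented differently.

# Hedged sketch only: assumed behavior of the test suite's flat_product helper.
from itertools import product
from typing import Any, Iterable, Iterator


def flat_product(*iterables: Iterable[Any]) -> Iterator[tuple[Any, ...]]:
    # Cross product that splices NamedTuple entries (like ModelBackendTestCase)
    # into flat argument tuples and keeps plain items (like the custom-ops
    # string "+quant_fp8,+rms_norm") as single elements.
    for combo in product(*iterables):
        flat: tuple[Any, ...] = ()
        for item in combo:
            flat += tuple(item) if isinstance(item, tuple) else (item,)
        yield flat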
@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kNvfp4Quant,
     kStaticTensorScale,
 )
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    cutlass_block_fp8_supported,
-)
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import (
-    is_deep_gemm_e8m0_used,
-    should_use_deepgemm_for_fp8_linear_for_nk,
-)

 from .inductor_pass import enable_fake_mode
-from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
+from .matcher_utils import (
+    MatcherFusedAddRMSNorm,
+    MatcherQuantFP8,
+    MatcherRMSNorm,
+)
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)
@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {


 class RMSNormQuantPattern:
-    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
+    def __init__(
+        self,
+        epsilon: float,
+        key: FusedRMSQuantKey,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
+    ):
         self.epsilon = epsilon
         self.quant_dtype = key.quant.dtype
         config = get_current_vllm_config()
         self.model_dtype = config.model_config.dtype if config.model_config else None

-        # groupwise FP8 linear uses col major scales if deepgemm and cutlass
-        using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
-            self.model_dtype,
-            config.model_config.hf_config.intermediate_size,
-            config.model_config.hf_config.hidden_size,
-        )
-        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
-        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
-
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]

@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
             else MatcherFusedAddRMSNorm(epsilon)
         )
         self.quant_matcher = MatcherQuantFP8(
-            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
+            key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
         )


@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
         quant_dtype: torch.dtype,
         group_shape: GroupShape,
         symmetric=True,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         scale = ScaleDesc(torch.float32, False, group_shape)
         key = FusedRMSQuantKey(
@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
             quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
         )
         self.group_shape = group_shape
-        super().__init__(epsilon, key)
+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )

     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
                 input = input.to(dtype=self.model_dtype)

             result = torch.empty_like(input, dtype=self.quant_dtype)
-            scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
-            )
+            scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=residual,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.has_col_major_scales,
             )

             # result, residual, scale
@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
         quant_dtype: torch.dtype,
         group_shape: GroupShape,
         symmetric=True,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         scale = ScaleDesc(torch.float32, False, group_shape)
         key = FusedRMSQuantKey(
@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
             quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
         )
         self.group_shape = group_shape
-        super().__init__(epsilon, key)
+        super().__init__(
+            epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
+        )

     def register(self, pm_pass: PatternMatcherPass):
         def pattern(input: torch.Tensor, weight: torch.Tensor):
@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):

             result = torch.empty_like(input, dtype=self.quant_dtype)
             scale = self.quant_matcher.make_scale(
-                input, transposed=self.quant_matcher.use_col_major_scales
+                input, transposed=self.quant_matcher.has_col_major_scales
             )
             at = auto_functionalized(
                 self.FUSED_OP,
@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
                 scale_ub=None,
                 residual=None,
                 group_size=self.group_shape[1],
-                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+                is_scale_transposed=self.quant_matcher.has_col_major_scales,
             )

             # result, scale
@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
         # Make sure fused add patterns are before simple rms norm,
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Fuse fused_add_rms_norm + fp8 group quant
-            # Only register group quant patterns on CUDA where the C++ op exists
-            if current_platform.is_cuda():
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
-                ).register(self.patterns)
-
-                FusedAddRMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
-                # Fuse rms_norm + fp8 group quant
-                RMSNormGroupQuantPattern(
-                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
-                ).register(self.patterns)
-
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
             # Fuse rms_norm + dynamic per-token fp8 quant
             RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

+            # Only register group quant patterns on CUDA where the C++ op exists
+            if current_platform.is_cuda():
+                for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
+                    for has_col_major_scales in [True, False]:
+                        for is_e8m0 in [True, False]:
+                            # Fuse fused_add_rms_norm + fp8 group quant
+                            FusedAddRMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
+                            # Fuse rms_norm + fp8 group quant
+                            RMSNormGroupQuantPattern(
+                                epsilon,
+                                FP8_DTYPE,
+                                group_shape=group_shape,
+                                has_col_major_scales=has_col_major_scales,
+                                is_e8m0=is_e8m0,
+                            ).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)

     @VllmInductorPass.time_and_log
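Inferred from the hunks above (the commit itself does not spell this out): the pass no longer derives the group-quant scale layout from config.model_config.hf_config inside RMSNormQuantPattern, which presumably failed for VL models whose language-model dimensions live in a nested text config; instead it registers one pattern variant per layout combination and lets the pattern matcher pick whichever variant actually appears in the traced graph. A quick count of the variants now registered per pass build, mirroring the loop variables:

# Back-of-the-envelope count of group-quant pattern variants registered above;
# values are taken from the diff, the arithmetic is only an illustration.
epsilons = [1e-5, 1e-6]
group_shapes = [(1, 128), (1, 64)]
col_major_options = [True, False]  # has_col_major_scales
e8m0_options = [True, False]       # is_e8m0
pattern_classes = 2                # FusedAddRMSNormGroupQuantPattern, RMSNormGroupQuantPattern

variants = (
    len(epsilons)
    * len(group_shapes)
    * len(col_major_options)
    * len(e8m0_options)
    * pattern_classes
)
print(variants)  # 32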
@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
         self,
         quant_key: QuantKey,
         enabled: bool | None = None,
-        use_col_major_scales: bool = False,
-        use_e8m0: bool = False,
+        has_col_major_scales: bool = False,
+        is_e8m0: bool = False,
     ):
         if enabled is None:
             enabled = QuantFP8.enabled()

         super().__init__(enabled)
         self.quant_key = quant_key
-        self.use_col_major_scales = use_col_major_scales
-        self.use_e8m0 = use_e8m0
         assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]

+        self.has_col_major_scales = has_col_major_scales
+        self.is_e8m0 = is_e8m0
+
         assert quant_key.dtype == current_platform.fp8_dtype(), (
             "Only QuantFP8 supported by"
         )
         assert quant_key.scale2 is None
-        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
+        self.quant_fp8 = QuantFP8(
+            quant_key.scale.static,
+            quant_key.scale.group_shape,
+            column_major_scales=has_col_major_scales,
+            use_ue8m0=is_e8m0,
+        )

     def forward_custom(
         self,
@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):

         if self.quant_key.scale.group_shape.is_per_group():
             assert scale is None
-            scale = self.make_scale(input, transposed=self.use_col_major_scales)
+            scale = self.make_scale(input, transposed=self.has_col_major_scales)

         finfo = torch.finfo(self.quant_key.dtype)
         fp8_min = finfo.min
@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
             eps=1e-10,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            scale_ue8m0=self.use_e8m0,
+            scale_ue8m0=self.is_e8m0,
         )
         return result, scale

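For context on the has_col_major_scales flag threaded through MatcherQuantFP8 and QuantFP8 above, here is a small self-contained sketch of the assumed scale layouts for per-group FP8 quantization (illustrative only, not code from this repository): one scale per (token, group), where the column-major variant keeps the same logical shape but transposed strides, the layout block-scaled GEMM backends such as DeepGEMM expect.

# Illustrative only: assumed memory layouts of per-group quantization scales.
import torch

num_tokens, hidden_size, group_size = 4, 512, 128
num_groups = hidden_size // group_size

# Row-major scales: the token dimension is contiguous.
row_major = torch.empty(num_tokens, num_groups, dtype=torch.float32)

# "Column-major" scales: same logical (num_tokens, num_groups) view, but backed
# by a transposed buffer so each group's scales are contiguous across tokens.
col_major = torch.empty(num_groups, num_tokens, dtype=torch.float32).t()

print(row_major.shape, row_major.stride())  # torch.Size([4, 4]) (4, 1)
print(col_major.shape, col_major.stride())  # torch.Size([4, 4]) (1, 4)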
@@ -381,22 +381,6 @@ def should_use_deepgemm_for_fp8_linear(
     )


-def should_use_deepgemm_for_fp8_linear_for_nk(
-    output_dtype: torch.dtype,
-    shape0: int,
-    shape1: int,
-    supports_deep_gemm: bool | None = None,
-):
-    if supports_deep_gemm is None:
-        supports_deep_gemm = is_deep_gemm_supported()
-    return (
-        supports_deep_gemm
-        and output_dtype == torch.bfloat16
-        and shape0 % 128 == 0
-        and shape1 % 128 == 0
-    )
-
-
 __all__ = [
     "calc_diff",
     "DeepGemmQuantScaleFMT",
@@ -411,7 +395,6 @@ __all__ = [
     "is_deep_gemm_supported",
     "get_num_sms",
     "should_use_deepgemm_for_fp8_linear",
-    "should_use_deepgemm_for_fp8_linear_for_nk",
     "get_col_major_tma_aligned_tensor",
     "get_mk_alignment_for_contiguous_layout",
 ]