[Bugfix] Add fake mode around passes (#23349)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi 2025-08-28 08:25:56 -07:00 committed by GitHub
parent 95089607fa
commit db74d60490
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 64 additions and 39 deletions

View File

@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)

View File

@ -19,6 +19,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class AsyncTPPass(VllmInductorPass): class AsyncTPPass(VllmInductorPass):
@enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)
@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer # in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns()
@enable_fake_mode
def register_patterns(self):
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern( AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon, epsilon,

View File

@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
cls._instance.pass_config = config.compilation_config.pass_config cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance return cls._instance
@enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \ assert self.__class__._instance is None, \
"FusionPass singleton instance already exists" "FusionPass singleton instance already exists"

View File

@ -7,8 +7,6 @@ import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
@ -19,6 +17,7 @@ from vllm.platforms import current_platform
from vllm.utils import round_up from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -139,24 +138,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
output_block_scale=None) output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors. inputs = [
# That would not work for the unified_attention custom op. empty_bf16(5, self.num_heads, self.head_size), # q
with unset_fake_temporarily(), FakeTensorMode(): empty_bf16(5, self.num_heads, self.head_size), # k
inputs = [ empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # q empty_bf16(5, self.num_heads, self.head_size), # attn_output
empty_bf16(5, self.num_heads, self.head_size), # k self.empty_quant(5,
empty_bf16(5, self.num_heads, self.head_size), # v self.num_heads * self.head_size), # quant_output
empty_bf16(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1) # scale
self.empty_quant(5, self.num_heads * ]
self.head_size), # quant_output
empty_fp32(1, 1) # scale
]
pm.register_replacement( pm.register_replacement(
pattern, replacement, inputs, pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn( AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass) pm_pass)
class AttentionNvfp4QuantPattern(AttentionQuantPattern): class AttentionNvfp4QuantPattern(AttentionQuantPattern):
@ -219,27 +215,23 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
[-1, self.num_heads * self.head_size // 2]) [-1, self.num_heads * self.head_size // 2])
return output, at2[2] return output, at2[2]
# Need custom fake mode, otherwise tracing happens with real tensors. inputs = [
# That would not work for the unified_attention custom op. empty_bf16(5, self.num_heads, self.head_size), # q
with unset_fake_temporarily(), FakeTensorMode(): empty_bf16(5, self.num_heads, self.head_size), # k
inputs = [ empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # q empty_bf16(5, self.num_heads, self.head_size), # output_attn
empty_bf16(5, self.num_heads, self.head_size), # k self.empty_quant(5, self.num_heads * self.head_size //
empty_bf16(5, self.num_heads, self.head_size), # v 2), # output_quant
empty_bf16(5, self.num_heads, self.head_size), # output_attn empty_i32(128, round_up(self.num_heads * self.head_size // 16,
self.empty_quant(5, self.num_heads * self.head_size // 4)), # output_scale
2), # output_quant empty_fp32(1, 1), # input_scale
empty_i32(128, ]
round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement( pm.register_replacement(
pattern, replacement, inputs, pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn( AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass) pm_pass)
class AttnFusionPass(VllmInductorPass): class AttnFusionPass(VllmInductorPass):
@ -255,6 +247,7 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant. support are attention kernels, which need to support fusing output quant.
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib import hashlib
import inspect import inspect
import json import json
@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import torch import torch
from torch import fx from torch import fx
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def uuid(self) -> Any: def uuid(self) -> Any:
return self._uuid return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
with torch._guards.tracing(
None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new

View File

@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance. performance.
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
super().__init__(config) super().__init__(config)