mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 22:57:29 +08:00
[Bugfix] Add fake mode around passes (#23349)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
parent
95089607fa
commit
db74d60490
@ -10,6 +10,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
|
|||||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
|
|||||||
|
|
||||||
class AsyncTPPass(VllmInductorPass):
|
class AsyncTPPass(VllmInductorPass):
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
|
|||||||
# in fallback path, when we don't use flashinfer
|
# in fallback path, when we don't use flashinfer
|
||||||
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
|
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
|
||||||
|
|
||||||
|
self.register_patterns()
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
|
def register_patterns(self):
|
||||||
for epsilon in [1e-5, 1e-6]:
|
for epsilon in [1e-5, 1e-6]:
|
||||||
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
||||||
epsilon,
|
epsilon,
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .fx_utils import find_getitem_maybe
|
from .fx_utils import find_getitem_maybe
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
from .multi_output_match import MultiOutputMatch
|
from .multi_output_match import MultiOutputMatch
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
|
|||||||
cls._instance.pass_config = config.compilation_config.pass_config
|
cls._instance.pass_config = config.compilation_config.pass_config
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
assert self.__class__._instance is None, \
|
assert self.__class__._instance is None, \
|
||||||
"FusionPass singleton instance already exists"
|
"FusionPass singleton instance already exists"
|
||||||
|
|||||||
@ -7,8 +7,6 @@ import torch
|
|||||||
import torch._inductor.pattern_matcher as pm
|
import torch._inductor.pattern_matcher as pm
|
||||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
|
||||||
unset_fake_temporarily)
|
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
@ -19,6 +17,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils import round_up
|
from vllm.utils import round_up
|
||||||
|
|
||||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -139,24 +138,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
output_block_scale=None)
|
output_block_scale=None)
|
||||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||||
|
|
||||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
inputs = [
|
||||||
# That would not work for the unified_attention custom op.
|
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||||
with unset_fake_temporarily(), FakeTensorMode():
|
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||||
inputs = [
|
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
empty_bf16(5, self.num_heads, self.head_size), # attn_output
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
self.empty_quant(5,
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
self.num_heads * self.head_size), # quant_output
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # attn_output
|
empty_fp32(1, 1) # scale
|
||||||
self.empty_quant(5, self.num_heads *
|
]
|
||||||
self.head_size), # quant_output
|
|
||||||
empty_fp32(1, 1) # scale
|
|
||||||
]
|
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, inputs,
|
pattern, replacement, inputs,
|
||||||
AttentionQuantPattern.wrap_trace_fn(
|
AttentionQuantPattern.wrap_trace_fn(
|
||||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||||
pm_pass)
|
pm_pass)
|
||||||
|
|
||||||
|
|
||||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||||
@ -219,27 +215,23 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
[-1, self.num_heads * self.head_size // 2])
|
[-1, self.num_heads * self.head_size // 2])
|
||||||
return output, at2[2]
|
return output, at2[2]
|
||||||
|
|
||||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
inputs = [
|
||||||
# That would not work for the unified_attention custom op.
|
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||||
with unset_fake_temporarily(), FakeTensorMode():
|
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||||
inputs = [
|
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
self.empty_quant(5, self.num_heads * self.head_size //
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
2), # output_quant
|
||||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
empty_i32(128, round_up(self.num_heads * self.head_size // 16,
|
||||||
self.empty_quant(5, self.num_heads * self.head_size //
|
4)), # output_scale
|
||||||
2), # output_quant
|
empty_fp32(1, 1), # input_scale
|
||||||
empty_i32(128,
|
]
|
||||||
round_up(self.num_heads * self.head_size // 16,
|
|
||||||
4)), # output_scale
|
|
||||||
empty_fp32(1, 1), # input_scale
|
|
||||||
]
|
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, inputs,
|
pattern, replacement, inputs,
|
||||||
AttentionQuantPattern.wrap_trace_fn(
|
AttentionQuantPattern.wrap_trace_fn(
|
||||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||||
pm_pass)
|
pm_pass)
|
||||||
|
|
||||||
|
|
||||||
class AttnFusionPass(VllmInductorPass):
|
class AttnFusionPass(VllmInductorPass):
|
||||||
@ -255,6 +247,7 @@ class AttnFusionPass(VllmInductorPass):
|
|||||||
support are attention kernels, which need to support fusing output quant.
|
support are attention kernels, which need to support fusing output quant.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import fx
|
from torch import fx
|
||||||
|
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||||
|
unset_fake_temporarily)
|
||||||
|
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
|
|||||||
|
|
||||||
def uuid(self) -> Any:
|
def uuid(self) -> Any:
|
||||||
return self._uuid
|
return self._uuid
|
||||||
|
|
||||||
|
|
||||||
|
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""
|
||||||
|
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||||
|
create or run things with real tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def fn_new(*args, **kwargs) -> Any:
|
||||||
|
with torch._guards.tracing(
|
||||||
|
None), unset_fake_temporarily(), FakeTensorMode():
|
||||||
|
result = fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return fn_new
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
|
|||||||
performance.
|
performance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user