mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 03:40:15 +08:00
[BugFix] Fix de-functionalization pass for rotary_embedding (#23953)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
parent
b71fcd4905
commit
7cfa4b24bf
@ -397,6 +397,7 @@ steps:
|
|||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
- pytest -v -s compile/test_fusion_attn.py
|
- pytest -v -s compile/test_fusion_attn.py
|
||||||
|
- pytest -v -s compile/test_functionalization.py
|
||||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||||
- pytest -v -s compile/test_sequence_parallelism.py
|
- pytest -v -s compile/test_sequence_parallelism.py
|
||||||
- pytest -v -s compile/test_async_tp.py
|
- pytest -v -s compile/test_async_tp.py
|
||||||
|
|||||||
@ -5,54 +5,237 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
|
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
GroupShape)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
Fp8LinearOp)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
OPS_IN_MODEL = [
|
TEST_FP8 = current_platform.supports_fp8()
|
||||||
torch.ops._C.rotary_embedding.default,
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
|
||||||
]
|
|
||||||
|
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
|
||||||
|
|
||||||
RMS_QUANT_OPS = {
|
class TestSiluMul(torch.nn.Module):
|
||||||
"static_fp8": [
|
|
||||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
|
||||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
def __init__(self, hidden_size: int = 128):
|
||||||
|
super().__init__()
|
||||||
|
self.silu_and_mul = SiluAndMul()
|
||||||
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
|
|
||||||
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
|
if TEST_FP8:
|
||||||
prompts = [
|
self.w = torch.rand(hidden_size,
|
||||||
"Hello, my name is",
|
hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
"The president of the United States is",
|
self.fp8_linear = Fp8LinearOp(
|
||||||
"The capital of France is",
|
act_quant_static=True,
|
||||||
"The future of AI is",
|
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
    """Apply SiLU-and-mul, then an FP8 linear op when FP8 is enabled.

    Args:
        x: Input tensor handed to ``self.silu_and_mul``.

    Returns:
        The result of ``self.fp8_linear.apply`` when ``TEST_FP8`` is
        truthy, otherwise the ``silu_and_mul`` output itself.
    """
    y = self.silu_and_mul(x)
    if not TEST_FP8:
        return y

    # Quantize the activation with the dedicated activation scale
    # (``self.scale``). ``self.wscale`` is the weight scale and is only
    # passed in the weight-scale position.
    return self.fp8_linear.apply(y,
                                 self.w,
                                 self.wscale,
                                 input_scale=self.scale)
|
||||||
|
|
||||||
|
def example_inputs(self, num_tokens=32, hidden_size=128):
    """Build the positional inputs for ``forward``.

    Returns:
        A one-element tuple holding a random tensor of shape
        ``(num_tokens, hidden_size * 2)``; float16 when ``TEST_FP8`` is
        truthy, float32 otherwise.
    """
    if TEST_FP8:
        dtype = torch.float16
    else:
        dtype = torch.float32
    sample = torch.rand(num_tokens, hidden_size * 2, dtype=dtype)
    return (sample, )
|
||||||
|
|
||||||
|
def ops_in_model(self, do_fusion):
    """Return the custom ops the test expects to find for this model.

    Args:
        do_fusion: Whether the fusion passes are part of the pipeline.

    Returns:
        A single-element list: the fused silu-mul-quant op when FP8 is
        enabled and fusion is requested, else the plain silu-and-mul op.
    """
    expect_fused = TEST_FP8 and do_fusion
    op = (torch.ops._C.silu_and_mul_quant.default
          if expect_fused else torch.ops._C.silu_and_mul.default)
    return [op]
|
||||||
|
|
||||||
|
def ops_not_in_model(self):
    """Return the ops that must be absent from the graph (none here)."""
    return list()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedAddRMSNorm(torch.nn.Module):
    """Small model: matmul -> RMSNorm with residual -> optional FP8 linear.

    Used as a compile target by the functionalization test; the ops it is
    expected to produce are reported by ``ops_in_model``.
    """

    def __init__(self, hidden_size=16, intermediate_size=32):
        """Create the projection weight, the norm and (if FP8) the linear op.

        Args:
            hidden_size: Last dimension of the ``hidden_states`` input.
            intermediate_size: Output width of the projection, which is also
                the width the norm operates on.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        dtype = torch.float16 if TEST_FP8 else torch.float32

        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size), dtype=dtype))
        self.norm = RMSNorm(intermediate_size, 1e-05)
        self.norm.weight = torch.nn.Parameter(
            torch.ones(intermediate_size, dtype=dtype))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if TEST_FP8:
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

            self.scale = torch.rand(1, dtype=torch.float32)
            self.w = torch.rand(hidden_size,
                                intermediate_size).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        """Project, normalize with residual, and optionally apply FP8 linear.

        Args:
            hidden_states: Tensor reshaped to ``(-1, hidden_size)``.
            residual: Residual passed to the norm together with the matmul
                output; it must have the same shape as that output, i.e. a
                last dimension of ``intermediate_size``.

        Returns:
            ``(output, residual_output)`` where ``output`` is the FP8 linear
            result when ``TEST_FP8`` is truthy, else the norm output.
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # layer normalization
        norm_output, residual_output = self.norm(mm, residual)

        if TEST_FP8:
            # scaled_mm with static input quantization
            fp8_linear_result = self.fp8_linear.apply(
                norm_output,
                self.w,
                self.wscale,
                input_scale=self.scale.to(norm_output.device),
            )

            return fp8_linear_result, residual_output

        else:
            return norm_output, residual_output

    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
        """Build ``(hidden_states, residual)`` inputs for ``forward``.

        ``hidden_states`` has shape ``(batch_size * seq_len, hidden_size)``.
        ``residual`` has last dimension ``self.intermediate_size`` because
        ``forward`` adds it to the matmul output, whose width is
        ``intermediate_size`` rather than ``hidden_size``.
        """
        dtype = torch.float16 if TEST_FP8 else torch.float32
        num_tokens = batch_size * seq_len
        hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype)
        residual = torch.randn((num_tokens, self.intermediate_size),
                               dtype=dtype)
        return (hidden_states, residual)

    def ops_in_model(self, do_fusion):
        """Return the ops the test expects to find in the compiled graph.

        Args:
            do_fusion: Whether the fusion passes are part of the pipeline.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        else:
            return [torch.ops._C.fused_add_rms_norm.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the graph (none here)."""
        return []
|
||||||
|
|
||||||
|
|
||||||
|
class TestRotaryEmbedding(torch.nn.Module):
    """Model that applies only the rotary embedding built by ``get_rope``."""

    def __init__(self,
                 head_dim=64,
                 rotary_dim=None,
                 max_position=2048,
                 base=10000):
        """Store the dimensions and construct the rotary embedding module.

        Args:
            head_dim: Head size forwarded to ``get_rope``.
            rotary_dim: Rotary dimension; falls back to ``head_dim`` when
                falsy.
            max_position: Forwarded to ``get_rope``.
            base: Forwarded to ``get_rope``.
        """
        super().__init__()
        self.head_dim = head_dim
        self.rotary_dim = rotary_dim if rotary_dim else head_dim

        rope_kwargs = dict(
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, q, k):
        """Run the rotary embedding and return the ``(query, key)`` pair."""
        rotated = self.rotary_emb(positions, q, k)
        query_out, key_out = rotated
        return query_out, key_out

    def example_inputs(self, num_tokens=32, head_dim=64):
        """Build ``(positions, q, k)`` inputs for ``forward``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``q`` and ``k`` are
        random float16 tensors of shape ``(num_tokens, head_dim)``.
        """
        positions = torch.arange(num_tokens, dtype=torch.long)
        # Draw q first, then k, from the random generator.
        q, k = (torch.randn(num_tokens, head_dim, dtype=torch.float16)
                for _ in range(2))
        return (positions, q, k)

    def ops_in_model(self, do_fusion):
        """Return the ops the test expects to find; ``do_fusion`` is unused."""
        expected = [torch.ops._C.rotary_embedding.default]
        return expected

    def ops_not_in_model(self):
        """Return the ops that must be absent from the graph (none here)."""
        return list()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Model applying rotary embedding to slices of a fused QKV projection.

    The test expects ``aten.slice_scatter`` to be absent from the graph after
    the functionalization fix (see ``ops_not_in_model``).
    """

    def __init__(self,
                 head_dim=64,
                 num_heads=4,
                 max_position=2048,
                 base=10000):
        """Create the float16 QKV projection and the rotary embedding.

        Args:
            head_dim: Per-head size; also used as the rotary dimension.
            num_heads: Number of heads; ``hidden_size = head_dim * num_heads``.
            max_position: Forwarded to ``get_rope``.
            base: Forwarded to ``get_rope``.
        """
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        qkv_width = self.hidden_size * 3
        self.qkv_proj = torch.nn.Linear(self.hidden_size,
                                        qkv_width,
                                        bias=False,
                                        dtype=torch.float16)

        rope_kwargs = dict(
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, hidden_states):
        """Project to QKV, rotate q and k, and concatenate back with v."""
        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes
        qkv = self.qkv_proj(hidden_states)

        # Three equal chunks of width hidden_size along the last dim.
        split_sizes = [self.hidden_size] * 3
        q, k, v = torch.split(qkv, split_sizes, dim=-1)

        q_rotated, k_rotated = self.rotary_emb(positions, q, k)

        return torch.cat([q_rotated, k_rotated, v], dim=-1)

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        """Build ``(positions, hidden_states)`` inputs for ``forward``.

        ``positions`` is ``arange(num_tokens)`` as int64; ``hidden_states`` is
        a random float16 tensor of shape
        ``(num_tokens, head_dim * num_heads)``.
        """
        width = head_dim * num_heads
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, width, dtype=torch.float16)
        return (positions, hidden_states)

    def ops_in_model(self, do_fusion):
        """Return the ops the test expects to find; ``do_fusion`` is unused."""
        expected = [torch.ops._C.rotary_embedding.default]
        return expected

    def ops_not_in_model(self):
        """Return the ops that must not remain in the graph."""
        forbidden = [torch.ops.aten.slice_scatter.default]
        return forbidden
|
||||||
|
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
TestSiluMul,
|
||||||
|
TestFusedAddRMSNorm,
|
||||||
|
TestRotaryEmbedding,
|
||||||
|
TestRotaryEmbeddingSliceScatter,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("model_class", MODELS)
|
||||||
"model, quant_key",
|
|
||||||
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
|
|
||||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
|
|
||||||
kFp8DynamicTokenSym)])
|
|
||||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
||||||
reason="Only test on CUDA")
|
reason="Only test on CUDA")
|
||||||
def test_fix_functionalization(model: str, quant_key: QuantKey,
|
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||||
do_fusion: bool):
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
|||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||||
|
|
||||||
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
|
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||||
] if do_fusion else [noop_pass, cleanup_pass]
|
if do_fusion else [noop_pass, cleanup_pass])
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
|
||||||
backend_func = TestBackend(*passes, func_pass)
|
backend_func = TestBackend(*passes, func_pass)
|
||||||
backend_no_func = TestBackend(*passes)
|
backend_no_func = TestBackend(*passes)
|
||||||
|
|
||||||
# instantiate a full engine and manually compile the model 2x
|
model = model_class()
|
||||||
# (with and without FixFunctionalizationPass)
|
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||||
llm = LLM(model=model, enforce_eager=True)
|
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||||
model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
|
|
||||||
orig_model = model_runner.model
|
|
||||||
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
|
|
||||||
# Can only do that by using the decorator but then we'd have to instantiate
|
|
||||||
# 2 LLM instances.
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
# check if the functionalization pass is applied
|
||||||
model_runner.model = torch.compile(orig_model,
|
for op in model.ops_in_model(do_fusion):
|
||||||
fullgraph=True,
|
|
||||||
backend=backend_func)
|
|
||||||
gen_func = llm.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
model_runner.model = torch.compile(orig_model,
|
|
||||||
fullgraph=True,
|
|
||||||
backend=backend_no_func)
|
|
||||||
|
|
||||||
gen_no_func = llm.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
for output_func, output_no_func in zip(gen_func, gen_no_func):
|
|
||||||
assert output_func.outputs[0].text == output_no_func.outputs[0].text
|
|
||||||
|
|
||||||
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
|
|
||||||
# and replaced by fused quantized ops in RMS_QUANT_OPS.
|
|
||||||
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
|
|
||||||
] if do_fusion else [RMS_OP]
|
|
||||||
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
|
|
||||||
quant_key == kFp8StaticTensorSym else [
|
|
||||||
SILU_MUL_OP
|
|
||||||
]
|
|
||||||
|
|
||||||
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
|
|
||||||
|
|
||||||
for op in ops:
|
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
|
||||||
op) is None # noqa: E501
|
is None) # noqa: E501
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
# make sure the ops were all de-functionalized
|
||||||
found = dict()
|
found = dict()
|
||||||
for node in backend_func.graph_post_pass.nodes:
|
for node in backend_func.graph_post_pass.nodes:
|
||||||
for op in ops:
|
for op in model.ops_in_model(do_fusion):
|
||||||
if is_func(node, op):
|
if is_func(node, op):
|
||||||
found[op] = True
|
found[op] = True
|
||||||
assert all(found[op] for op in ops)
|
for op in model.ops_not_in_model():
|
||||||
|
if is_func(node, op):
|
||||||
|
found[op] = True
|
||||||
|
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||||
|
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||||
|
|||||||
@ -46,23 +46,43 @@ class FixFunctionalizationPass(VllmInductorPass):
|
|||||||
|
|
||||||
if at_target == torch.ops._C.rotary_embedding.default:
|
if at_target == torch.ops._C.rotary_embedding.default:
|
||||||
query = kwargs['query']
|
query = kwargs['query']
|
||||||
mm_node = query.args[0].args[0]
|
key = kwargs['key']
|
||||||
|
getitem_nodes = self.getitem_users(node)
|
||||||
|
|
||||||
# rotary_embedding is a special case: the two mutating inputs
|
if (is_func(query, operator.getitem)
|
||||||
# are query and key, which are slices of mm_node.
|
and is_func(key, operator.getitem)
|
||||||
# While functionalized, results at[1] and at[2] are scattered
|
and query.args[0] == key.args[0]
|
||||||
# back into mm_node. After de-functionalization, we can just
|
and is_func(query.args[0],
|
||||||
# use mm_node directly.
|
torch.ops.aten.split_with_sizes.default)
|
||||||
for idx, user in self.getitem_users(node).items():
|
and all(
|
||||||
for user_of_getitem in user.users:
|
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||||
if is_func(user_of_getitem,
|
for getitem_node in getitem_nodes.values()
|
||||||
torch.ops.aten.slice_scatter.default):
|
for user in getitem_node.users)):
|
||||||
user_of_getitem.replace_all_uses_with(mm_node)
|
# Pattern where query and key are slices of an mm_node.
|
||||||
self._remove(user_of_getitem)
|
# While functionalized, results at [1] and [2] are scattered
|
||||||
self._remove(user)
|
# back into mm_node. So after de-functionalization, we can
|
||||||
|
# just use mm_node directly.
|
||||||
|
|
||||||
self.insert_defunctionalized(graph, node)
|
mm_node = query.args[0].args[0]
|
||||||
self._remove(node)
|
for user in getitem_nodes.values():
|
||||||
|
for user_of_getitem in user.users:
|
||||||
|
if is_func(user_of_getitem,
|
||||||
|
torch.ops.aten.slice_scatter.default):
|
||||||
|
user_of_getitem.replace_all_uses_with(mm_node)
|
||||||
|
self._remove(user_of_getitem)
|
||||||
|
self._remove(user)
|
||||||
|
|
||||||
|
self.insert_defunctionalized(graph, node)
|
||||||
|
self._remove(node)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Directly replace the auto_functionalize(rotary_embedding)
|
||||||
|
# with the inplace rotary_embedding. In theory, we shouldn't
|
||||||
|
# do this blindly, but in practice in vLLM it's ok. The best
|
||||||
|
# solution is to use auto_functionalization_v2 and then use
|
||||||
|
# inductor's builtin defunctionalization (reinplacing) pass.
|
||||||
|
mutated_args = {1: 'query', 2: 'key'}
|
||||||
|
self.defunctionalize(graph, node, mutated_args)
|
||||||
|
|
||||||
# rms_norm replacements avoid the most copies for LLaMa.
|
# rms_norm replacements avoid the most copies for LLaMa.
|
||||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user