mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 10:57:06 +08:00
Refactor duplicate torch operation registrations to use shared module
Instead of changing library names (not scalable), create a shared test_operations.py module that: - Provides a single "silly" library for all compilation tests - Registers a unified attention operation that can handle both standard and counting modes - Eliminates duplicate registration errors when running all tests together - Maintains backward compatibility with existing test behavior Addresses feedback to make the solution more scalable and maintainable. Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
47dcf0940f
commit
2c81fbbb3c
@ -6,7 +6,6 @@ are compiled and graph captured separately.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
@ -16,10 +15,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly_multiple", "FRAGMENT") # noqa
|
||||
# Import shared test operations
|
||||
# The standard attention operation is automatically registered when imported
|
||||
import tests.compile.test_operations
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
@ -27,27 +26,6 @@ HIDDEN_SIZE = 1024
|
||||
RANDOM_SEED = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
@ -90,7 +68,7 @@ class Attention(nn.Module):
|
||||
x = self.pre_attn(x)
|
||||
x = self.rms_norm_ref(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly_multiple.attention(x, x, x, attn_output)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.rms_norm_ref(x)
|
||||
x = self.post_attn(x)
|
||||
@ -188,7 +166,7 @@ def test_ignore_torch_compile_decorator():
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly_multiple.attention(x, x, x, attn_output)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
@ -347,4 +325,4 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
assert torch.allclose(outputs[0], outputs[1])
|
||||
|
||||
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
|
||||
assert torch.equal(outputs[0], outputs[2])
|
||||
assert torch.equal(outputs[0], outputs[2])
|
||||
@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@ -15,36 +14,15 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly_simple", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
global global_counter
|
||||
global_counter += 1
|
||||
print(f"{global_counter=}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
# Import shared test operations
|
||||
from tests.compile.test_operations import (
|
||||
get_global_counter, reset_global_counter, enable_counting_mode
|
||||
)
|
||||
|
||||
# Enable counting mode for this test's specific behavior
|
||||
enable_counting_mode()
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
@ -66,12 +44,12 @@ class SillyModel(nn.Module):
|
||||
x = x + 1
|
||||
x = x + 2
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly_simple.attention(x, x, x, out)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x - 2
|
||||
x = x - 1
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly_simple.attention(x, x, x, out)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x + 1
|
||||
return x
|
||||
@ -121,13 +99,12 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
reset_global_counter()
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
@ -14,38 +14,16 @@ from typing import Any, Optional
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly_toy_llama", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
# Import shared test operations
|
||||
# The standard attention operation is automatically registered when imported
|
||||
import tests.compile.test_operations
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -160,7 +138,7 @@ class LlamaAttention(nn.Module):
|
||||
k = k + positions.unsqueeze(1)
|
||||
|
||||
attn_output = torch.empty_like(q)
|
||||
torch.ops.silly_toy_llama.attention(q, k, v, attn_output)
|
||||
torch.ops.silly.attention(q, k, v, attn_output)
|
||||
|
||||
output = self.output_projection(attn_output)
|
||||
return output
|
||||
@ -492,4 +470,4 @@ def benchmark():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark()
|
||||
benchmark()
|
||||
88
tests/compile/test_operations.py
Normal file
88
tests/compile/test_operations.py
Normal file
@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared PyTorch custom operations for compilation tests.
|
||||
|
||||
This module provides a centralized place to define and register custom
|
||||
PyTorch operations used across multiple compilation tests. This avoids
|
||||
duplicate operation registrations that would cause RuntimeErrors when
|
||||
running tests together.
|
||||
|
||||
The main "attention" operation is automatically registered when this module
|
||||
is imported. Individual test files can access additional functionality
|
||||
through helper functions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# Shared library for all compilation test operations
|
||||
# Using "silly" namespace to match existing test expectations
|
||||
silly_lib = Library("silly", "FRAGMENT")
|
||||
|
||||
|
||||
# Global state for test_simple.py compatibility
|
||||
_global_counter = 0
|
||||
_use_counting_mode = False
|
||||
|
||||
|
||||
def get_global_counter():
|
||||
"""Get the current global counter value (for test_simple.py)"""
|
||||
return _global_counter
|
||||
|
||||
|
||||
def reset_global_counter():
|
||||
"""Reset the global counter to 0 (for test_simple.py)"""
|
||||
global _global_counter
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def enable_counting_mode():
|
||||
"""Enable counting mode for test_simple.py"""
|
||||
global _use_counting_mode
|
||||
_use_counting_mode = True
|
||||
reset_global_counter()
|
||||
|
||||
|
||||
def disable_counting_mode():
|
||||
"""Disable counting mode"""
|
||||
global _use_counting_mode
|
||||
_use_counting_mode = False
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Unified attention implementation that can handle both standard and counting modes.
|
||||
"""
|
||||
global _global_counter, _use_counting_mode
|
||||
|
||||
if _use_counting_mode:
|
||||
# Counting mode for test_simple.py
|
||||
_global_counter += 1
|
||||
print(f"global_counter={_global_counter}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
else:
|
||||
# Standard mode for test_multiple_graphs.py and test_toy_llama.py
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""Fake implementation for testing"""
|
||||
return
|
||||
|
||||
|
||||
# Register the unified attention operation
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user