mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-29 18:27:12 +08:00
Refactor duplicate torch operation registrations to use shared module
Instead of changing library names (not scalable), create a shared test_operations.py module that: - Provides a single "silly" library for all compilation tests - Registers a unified attention operation that can handle both standard and counting modes - Eliminates duplicate registration errors when running all tests together - Maintains backward compatibility with existing test behavior Addresses feedback to make the solution more scalable and maintainable. Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
47dcf0940f
commit
2c81fbbb3c
@ -6,7 +6,6 @@ are compiled and graph captured separately.
|
|||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.backends import set_model_tag
|
from vllm.compilation.backends import set_model_tag
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
@ -16,10 +15,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
|||||||
set_current_vllm_config)
|
set_current_vllm_config)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# Import shared test operations
|
||||||
silly_lib = Library("silly_multiple", "FRAGMENT") # noqa
|
# The standard attention operation is automatically registered when imported
|
||||||
|
import tests.compile.test_operations
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
MLP_SIZE = 128
|
MLP_SIZE = 128
|
||||||
@ -27,27 +26,6 @@ HIDDEN_SIZE = 1024
|
|||||||
RANDOM_SEED = 0
|
RANDOM_SEED = 0
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class ParentModel(nn.Module):
|
class ParentModel(nn.Module):
|
||||||
|
|
||||||
@ -90,7 +68,7 @@ class Attention(nn.Module):
|
|||||||
x = self.pre_attn(x)
|
x = self.pre_attn(x)
|
||||||
x = self.rms_norm_ref(x)
|
x = self.rms_norm_ref(x)
|
||||||
attn_output = torch.empty_like(x)
|
attn_output = torch.empty_like(x)
|
||||||
torch.ops.silly_multiple.attention(x, x, x, attn_output)
|
torch.ops.silly.attention(x, x, x, attn_output)
|
||||||
x = attn_output
|
x = attn_output
|
||||||
x = self.rms_norm_ref(x)
|
x = self.rms_norm_ref(x)
|
||||||
x = self.post_attn(x)
|
x = self.post_attn(x)
|
||||||
@ -188,7 +166,7 @@ def test_ignore_torch_compile_decorator():
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = x + x
|
x = x + x
|
||||||
attn_output = torch.empty_like(x)
|
attn_output = torch.empty_like(x)
|
||||||
torch.ops.silly_multiple.attention(x, x, x, attn_output)
|
torch.ops.silly.attention(x, x, x, attn_output)
|
||||||
x = attn_output
|
x = attn_output
|
||||||
x = x * 3
|
x = x * 3
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects.
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
@ -15,36 +14,15 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
|||||||
VllmConfig, set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
global_counter = 0
|
# Import shared test operations
|
||||||
|
from tests.compile.test_operations import (
|
||||||
# create a library to hold the custom op
|
get_global_counter, reset_global_counter, enable_counting_mode
|
||||||
silly_lib = Library("silly_simple", "FRAGMENT") # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
global global_counter
|
|
||||||
global_counter += 1
|
|
||||||
print(f"{global_counter=}")
|
|
||||||
out.copy_(q)
|
|
||||||
out[0] += 1
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Enable counting mode for this test's specific behavior
|
||||||
|
enable_counting_mode()
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class SillyModel(nn.Module):
|
class SillyModel(nn.Module):
|
||||||
@ -66,12 +44,12 @@ class SillyModel(nn.Module):
|
|||||||
x = x + 1
|
x = x + 1
|
||||||
x = x + 2
|
x = x + 2
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
torch.ops.silly_simple.attention(x, x, x, out)
|
torch.ops.silly.attention(x, x, x, out)
|
||||||
x = out
|
x = out
|
||||||
x = x - 2
|
x = x - 2
|
||||||
x = x - 1
|
x = x - 1
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
torch.ops.silly_simple.attention(x, x, x, out)
|
torch.ops.silly.attention(x, x, x, out)
|
||||||
x = out
|
x = out
|
||||||
x = x + 1
|
x = x + 1
|
||||||
return x
|
return x
|
||||||
@ -121,13 +99,12 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
model(torch.randn(1).cuda())
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input = torch.zeros(2).cuda()
|
input = torch.zeros(2).cuda()
|
||||||
global global_counter
|
reset_global_counter()
|
||||||
global_counter = 0
|
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
None,
|
None,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||||
output = model(input)
|
output = model(input)
|
||||||
assert global_counter == 2
|
assert get_global_counter() == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||||
@ -14,38 +14,16 @@ from typing import Any, Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
VllmConfig, set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# Import shared test operations
|
||||||
silly_lib = Library("silly_toy_llama", "FRAGMENT") # noqa
|
# The standard attention operation is automatically registered when imported
|
||||||
|
import tests.compile.test_operations
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -160,7 +138,7 @@ class LlamaAttention(nn.Module):
|
|||||||
k = k + positions.unsqueeze(1)
|
k = k + positions.unsqueeze(1)
|
||||||
|
|
||||||
attn_output = torch.empty_like(q)
|
attn_output = torch.empty_like(q)
|
||||||
torch.ops.silly_toy_llama.attention(q, k, v, attn_output)
|
torch.ops.silly.attention(q, k, v, attn_output)
|
||||||
|
|
||||||
output = self.output_projection(attn_output)
|
output = self.output_projection(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|||||||
88
tests/compile/test_operations.py
Normal file
88
tests/compile/test_operations.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Shared PyTorch custom operations for compilation tests.
|
||||||
|
|
||||||
|
This module provides a centralized place to define and register custom
|
||||||
|
PyTorch operations used across multiple compilation tests. This avoids
|
||||||
|
duplicate operation registrations that would cause RuntimeErrors when
|
||||||
|
running tests together.
|
||||||
|
|
||||||
|
The main "attention" operation is automatically registered when this module
|
||||||
|
is imported. Individual test files can access additional functionality
|
||||||
|
through helper functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.library import Library
|
||||||
|
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
# Shared library for all compilation test operations
|
||||||
|
# Using "silly" namespace to match existing test expectations
|
||||||
|
silly_lib = Library("silly", "FRAGMENT")
|
||||||
|
|
||||||
|
|
||||||
|
# Global state for test_simple.py compatibility
|
||||||
|
_global_counter = 0
|
||||||
|
_use_counting_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_counter():
|
||||||
|
"""Get the current global counter value (for test_simple.py)"""
|
||||||
|
return _global_counter
|
||||||
|
|
||||||
|
|
||||||
|
def reset_global_counter():
|
||||||
|
"""Reset the global counter to 0 (for test_simple.py)"""
|
||||||
|
global _global_counter
|
||||||
|
_global_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def enable_counting_mode():
|
||||||
|
"""Enable counting mode for test_simple.py"""
|
||||||
|
global _use_counting_mode
|
||||||
|
_use_counting_mode = True
|
||||||
|
reset_global_counter()
|
||||||
|
|
||||||
|
|
||||||
|
def disable_counting_mode():
|
||||||
|
"""Disable counting mode"""
|
||||||
|
global _use_counting_mode
|
||||||
|
_use_counting_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
out: torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Unified attention implementation that can handle both standard and counting modes.
|
||||||
|
"""
|
||||||
|
global _global_counter, _use_counting_mode
|
||||||
|
|
||||||
|
if _use_counting_mode:
|
||||||
|
# Counting mode for test_simple.py
|
||||||
|
_global_counter += 1
|
||||||
|
print(f"global_counter={_global_counter}")
|
||||||
|
out.copy_(q)
|
||||||
|
out[0] += 1
|
||||||
|
else:
|
||||||
|
# Standard mode for test_multiple_graphs.py and test_toy_llama.py
|
||||||
|
out.copy_(q)
|
||||||
|
out += k
|
||||||
|
out += v
|
||||||
|
|
||||||
|
|
||||||
|
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
out: torch.Tensor) -> None:
|
||||||
|
"""Fake implementation for testing"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# Register the unified attention operation
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="attention",
|
||||||
|
op_func=silly_attention,
|
||||||
|
mutates_args=["out"],
|
||||||
|
fake_impl=silly_attention_fake,
|
||||||
|
target_lib=silly_lib,
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user