mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 07:47:00 +08:00
Refactor test_decorator.py to use shared testing_ops module
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
2b81d5fd2f
commit
afcb616e89
@ -2,44 +2,20 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.library import Library
|
|
||||||
|
|
||||||
|
# This import automatically registers torch ops for testing (like silly.attention)
|
||||||
|
import tests.compile.testing_ops
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||||
support_torch_compile)
|
support_torch_compile)
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
|
||||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
MLP_SIZE = 128
|
MLP_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
out.copy_(q)
|
|
||||||
out += k
|
|
||||||
out += v
|
|
||||||
|
|
||||||
|
|
||||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
out: torch.Tensor) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="attention",
|
|
||||||
op_func=silly_attention,
|
|
||||||
mutates_args=["out"],
|
|
||||||
fake_impl=silly_attention_fake,
|
|
||||||
target_lib=silly_lib,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||||
cudagraph_runtime_mode: CUDAGraphMode):
|
cudagraph_runtime_mode: CUDAGraphMode):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user