[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
commit 7e8977fcd4 (parent f1e840e842)
@@ -10,5 +10,7 @@ setup(
     entry_points={
         'vllm.platform_plugins': [
             "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
-        ]
+        ],
+        "vllm.general_plugins":
+        ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
     })
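The new "vllm.general_plugins" entry point is how the dummy custom ops get pulled in at startup: the plugin loader looks up this group and calls each registered function. A minimal sketch of that discovery step, assuming standard importlib.metadata entry-point lookup (illustrative only, not vLLM's exact loader code):

# Illustrative sketch (assumption): discovering and invoking a
# "vllm.general_plugins" entry point such as
# "dummy_custom_ops = vllm_add_dummy_platform:register_ops" (Python 3.10+).
from importlib.metadata import entry_points

def run_general_plugins():
    for ep in entry_points(group="vllm.general_plugins"):
        plugin_func = ep.load()  # resolves "vllm_add_dummy_platform:register_ops"
        plugin_func()            # importing dummy_custom_ops registers the OOT ops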
@@ -6,3 +6,7 @@ from typing import Optional

 def dummy_platform_plugin() -> Optional[str]:
     return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
+
+
+def register_ops():
+    import vllm_add_dummy_platform.dummy_custom_ops  # noqa
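dummy_platform_plugin returns the platform class as a dotted path string rather than the class itself, so the plugin can be enumerated without importing anything heavy. A rough sketch of how such a string can be resolved into the class (illustrative; resolve_platform_cls is a hypothetical helper, not a vLLM API):

# Illustrative sketch (assumption): resolving the dotted class path returned
# by a platform plugin, e.g.
# "vllm_add_dummy_platform.dummy_platform.DummyPlatform".
import importlib

def resolve_platform_cls(path: str):
    module_name, _, cls_name = path.rpartition(".")
    module = importlib.import_module(module_name)  # imported only when selected
    return getattr(module, cls_name)               # -> the DummyPlatform class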
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from vllm.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionBackend)


-class DummyAttentionBackend(FlashAttentionBackend):
+class DummyAttentionBackend(PlaceholderAttentionBackend):

     @staticmethod
     def get_name() -> str:
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+# Register CustomRotaryEmbedding to CustomOP.
+@RotaryEmbedding.register_oot
+class DummyRotaryEmbedding(RotaryEmbedding):
+    """Original rotary positional embedding."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.addition_config = True
+
+    def forward_oot(self, *args,
+                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        return super().forward_oot(*args, **kwargs)
@@ -1,12 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING

-from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.interface import Platform, PlatformEnum

+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+from vllm import envs
+

-class DummyPlatform(CudaPlatform):
+class DummyPlatform(Platform):
+    _enum = PlatformEnum.OOT
     device_name = "DummyDevice"
+    device_type: str = "privateuseone"
+    dispatch_key: str = "PrivateUse1"
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        if envs.VLLM_USE_V1:
+            compilation_config = vllm_config.compilation_config
+            # Activate custom ops for v1.
+            compilation_config.custom_ops = ["all"]

     def get_attn_backend_cls(self, backend_name, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
         return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
@@ -5,6 +5,7 @@ import pytest
 import torch

 from vllm.attention.selector import get_attn_backend
+from vllm.plugins import load_general_plugins
 from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL

@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
         backend = get_attn_backend(16, torch.float16, "auto", 16, False)
         assert backend.get_name() == "Dummy_Backend"
+
+
+def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
+    # simulate workload by running an example
+    load_general_plugins()
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
+    assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
+        f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
+        "possibly because the custom op is not registered correctly.")
+    assert hasattr(layer, "addition_config"), (
+        "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
+        "which is set by the custom op.")
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from typing import Optional
+
 import torch.nn as nn

 from vllm.config import get_current_vllm_config
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
     Dispatches the forward method to the appropriate backend.
     """

+    def __new__(cls, *args, **kwargs):
+        try:
+            op_name = cls.__name__
+        except AttributeError:
+            raise TypeError(
+                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
+                f"was not set, possibly because it was not decorated with "
+                f"@CustomOp.register, or it's the CustomOp base class itself."
+            ) from None
+
+        if op_name not in cls.op_registry_oot:
+            op_cls_to_instantiate = cls
+        else:
+            op_cls_to_instantiate = cls.op_registry_oot[op_name]
+            logger.debug("Instantiating custom op: %s using %s", op_name,
+                         str(op_cls_to_instantiate))
+        return super().__new__(op_cls_to_instantiate)
+
     def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()
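The __new__ override is what makes out-of-tree substitution transparent: whenever a CustomOp subclass is instantiated, its class name is looked up in op_registry_oot, and if a replacement was registered, that class is constructed instead. The same pattern in isolation (a minimal, self-contained sketch; the names below are illustrative, not vLLM's):

# Minimal sketch of the __new__ redirection used above (illustrative names).
class Base:
    _oot_registry: dict[str, type] = {}

    def __new__(cls, *args, **kwargs):
        # If an out-of-tree replacement is registered under this class's
        # name, construct the replacement class instead of the requested one.
        override_cls = cls._oot_registry.get(cls.__name__, cls)
        return super().__new__(override_cls)


class InTreeOp(Base):
    pass


class OotOp(InTreeOp):
    pass


Base._oot_registry["InTreeOp"] = OotOp

layer = InTreeOp()
assert type(layer) is OotOp  # instantiation was transparently redirected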
@@ -138,6 +158,7 @@ class CustomOp(nn.Module):
     # - MyOp.enabled()
     # - op_registry["my_op"].enabled()
     op_registry: dict[str, type['CustomOp']] = {}
+    op_registry_oot: dict[str, type['CustomOp']] = {}

     # Decorator to register custom ops.
     @classmethod
@@ -150,3 +171,38 @@ class CustomOp(nn.Module):
             return op_cls

         return decorator
+
+    # Decorator to register out-of-tree(oot) custom ops.
+    # For OOT custom ops:
+    # if in-tree layer class is registered with an oot_custom_op layer,
+    # the oot_custom_op layer will be used instead.
+    # Example:
+    #  - @UnquantizedFusedMoEMethod.register_oot
+    #    class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
+    #  or
+    #  - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
+    @classmethod
+    def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
+
+        def decorator(op_cls):
+            reg_name = name if name is not None else cls.__name__
+            assert reg_name not in cls.op_registry_oot, \
+                f"Duplicate op name: {reg_name}"
+            op_cls.name = reg_name
+            cls.op_registry_oot[reg_name] = op_cls
+            return op_cls
+
+        if _decorated_op_cls is None:
+            # Called with parentheses: @CustomOP.register_oot()
+            # or @CustomOP.register_oot(name="...")
+            # So, _decorated_op_cls is None.
+            # We return the actual decorator function.
+            return decorator
+        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
+            # Called without parentheses: @CustomOP.register_oot
+            # The first argument is the class itself.
+            # We call the 'decorator' function immediately with the class.
+            return decorator(_decorated_op_cls)
+        else:
+            # Handle other unexpected cases if necessary
+            raise TypeError("Decorator can only be applied to classes.")
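register_oot supports both of the forms named in the comment above: the bare decorator keys the registry on the in-tree class's own name, while the parameterized form lets the caller pick the key explicitly. A hedged usage sketch from an out-of-tree plugin's point of view, mirroring the dummy plugin in this commit (MyRotaryEmbedding is a made-up name, and the sketch assumes a vLLM build containing this change):

# Usage sketch (assumption: register_oot as added in this commit is available).
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


# Bare form: registers under the in-tree class's name ("RotaryEmbedding"),
# so later RotaryEmbedding(...) instantiations are redirected to this class.
@RotaryEmbedding.register_oot
class MyRotaryEmbedding(RotaryEmbedding):

    def forward_oot(self, *args, **kwargs):
        return super().forward_oot(*args, **kwargs)


# Parameterized form: the registry key is given explicitly, e.g.
#   @CustomOp.register_oot(name="RotaryEmbedding")
#   class SomeOtherRope(RotaryEmbedding): ...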