[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)

Luka Govedič 2024-10-17 14:36:37 -04:00 committed by GitHub
parent 7871659abb
commit 0f41fbe5a3
8 changed files with 220 additions and 21 deletions

View File

@@ -0,0 +1,92 @@
import os
from typing import List

import pytest

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
                                                   ReLUSquaredActivation,
                                                   SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm


# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    pass


@pytest.mark.parametrize(
    "env, torch_level, ops_enabled, default_on",
    [
        # Default values based on compile level
        ("", 0, [True] * 4, True),
        ("", 1, [True] * 4, True),
        ("", 2, [True] * 4, True),  # All by default
        ("", 3, [False] * 4, False),
        ("", 4, [False] * 4, False),  # None by default

        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
        # GeluAndMul and SiluAndMul
        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
        # All but RMSNorm
        ("-rms_norm", 2, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                     default_on: bool):
    os.environ["VLLM_CUSTOM_OPS"] = env
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)

    # Reset default_on (computed once):
    CustomOp.default_on.cache_clear()

    assert CustomOp.default_on() == default_on

    ops_enabled = [bool(x) for x in ops_enabled]

    assert RMSNorm(1024).enabled() == ops_enabled[0]
    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

    assert SiluAndMul().enabled() == ops_enabled[1]
    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

    assert GeluAndMul().enabled() == ops_enabled[2]
    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

    # If registered, subclasses should follow their own name
    assert Relu3().enabled() == ops_enabled[3]
    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

    # Unregistered subclass
    class SiluAndMul2(SiluAndMul):
        pass

    # Subclasses should not require registration
    assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
    os.environ["VLLM_CUSTOM_OPS"] = env
    CustomOp.default_on.cache_clear()

    with pytest.raises(AssertionError):
        RMSNorm(1024).enabled()

View File

@@ -65,6 +65,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_CUSTOM_OPS: List[str] = []
    VLLM_DISABLED_KERNELS: List[str] = []
@@ -205,7 +206,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
    "VLLM_TORCH_COMPILE_LEVEL":
    lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

+    # Fine-grained control over which custom ops to enable/disable.
+    # Use 'all' to enable all, 'none' to disable all.
+    # Also specify a list of custom op names to enable (prefixed with a '+'),
+    # or disable (prefixed with a '-').
+    # Examples:
+    # - 'all,-op1' to enable all except op1
+    # - 'none,+op1,+op2' to enable only op1 and op2
+    # By default, all custom ops are enabled when running without Inductor
+    # and disabled when running with Inductor (compile_level >= Inductor).
+    "VLLM_CUSTOM_OPS":
+    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
+
    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK":

View File

@@ -1,14 +1,24 @@
+from functools import lru_cache
+from typing import Dict, Type
+
import torch.nn as nn

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
+from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_xpu
+from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
+
+logger = init_logger(__name__)


class CustomOp(nn.Module):
+    """
+    Base class for custom ops.
+    Dispatches the forward method to the appropriate backend.
+    """

-    def __init__(self, *args, **kwargs):
+    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
-
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
@@ -56,7 +65,11 @@
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
-        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
+        enabled = self.enabled()
+        logger.debug("custom op %s %s", self.__class__.name,
+                     "enabled" if enabled else "disabled")
+
+        if not enabled:
            return self.forward_native

        if is_hip():
@@ -69,3 +82,50 @@
            return self.forward_xpu
        else:
            return self.forward_cuda
+
+    @classmethod
+    def enabled(cls) -> bool:
+        # if no name, then it was not registered
+        if not hasattr(cls, "name"):
+            print_warning_once(
+                f"Custom op {cls.__name__} was not registered, "
+                f"which means it won't appear in the op registry. "
+                f"It will be enabled/disabled based on the global settings.")
+            return CustomOp.default_on()
+
+        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
+        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        assert not (enabled
+                    and disabled), f"Cannot enable and disable {cls.name}"
+
+        return (CustomOp.default_on() or enabled) and not disabled
+
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
+    @staticmethod
+    @lru_cache()
+    def default_on() -> bool:
+        count_none = envs.VLLM_CUSTOM_OPS.count("none")
+        count_all = envs.VLLM_CUSTOM_OPS.count("all")
+        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+            not count_none > 0 or count_all > 0
+
+    # Dictionary of all custom ops (classes, indexed by registered name).
+    # To check if an op with a name is enabled, call .enabled() on the class.
+    # Examples:
+    # - MyOp.enabled()
+    # - op_registry["my_op"].enabled()
+    op_registry: Dict[str, Type['CustomOp']] = {}
+
+    # Decorator to register custom ops.
+    @classmethod
+    def register(cls, name: str):

+        def decorator(op_cls):
+            assert name not in cls.op_registry, f"Duplicate op name: {name}"
+            op_cls.name = name
+            cls.op_registry[name] = op_cls
+            return op_cls
+
+        return decorator
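
A hedged usage sketch (not part of the commit) of the registration and enablement API added above. MyFancyOp and its trivial forward_native are hypothetical; the sketch assumes a fresh process so the lru_cache on default_on() has not yet been primed.

import os

# Must be set before the values are read through vllm.envs.
os.environ["VLLM_CUSTOM_OPS"] = "none,+my_fancy_op"
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "3"  # Inductor level: default off

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("my_fancy_op")  # adds the class to CustomOp.op_registry
class MyFancyOp(CustomOp):  # hypothetical op, for illustration only

    def forward_native(self, x):
        return x


assert CustomOp.default_on() is False  # 'none' (and level 3) turn the default off
assert MyFancyOp.enabled() is True     # but '+my_fancy_op' overrides it
assert CustomOp.op_registry["my_fancy_op"].enabled() is True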

View File

@@ -11,11 +11,13 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import LazyDict


+@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.
@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
        return self.forward_native(x)


+@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.
@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
        return out


+@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.
@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
        return f'approximate={repr(self.approximate)}'


+@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
        return ops.gelu_new(x)


+@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
        return ops.gelu_fast(x)


+@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
-
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


+@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
        param_data.copy_(loaded_weight)


-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    "gelu_fast": FastGELU(),
-    "gelu_new": NewGELU(),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-    "relu2": ReLUSquaredActivation(),
-    "quick_gelu": QuickGELU(),
-}
+_ACTIVATION_REGISTRY = LazyDict({
+    "gelu":
+    lambda: nn.GELU(),
+    "gelu_fast":
+    lambda: FastGELU(),
+    "gelu_new":
+    lambda: NewGELU(),
+    "gelu_pytorch_tanh":
+    lambda: nn.GELU(approximate="tanh"),
+    "relu":
+    lambda: nn.ReLU(),
+    "relu2":
+    lambda: ReLUSquaredActivation(),
+    "quick_gelu":
+    lambda: QuickGELU(),
+})


def get_act_fn(
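
The registry switch above defers construction: with LazyDict (imported from vllm.utils above, defined at the end of this commit), each activation module is built on first lookup rather than at import time, presumably so that CustomOp dispatch and the new enabled() check only run for activations that are actually requested. A small behavior sketch, not part of the commit, assuming _ACTIVATION_REGISTRY is imported directly (it is module-internal by convention):

from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY

act = _ACTIVATION_REGISTRY["gelu_fast"]   # FastGELU() is built here, on first access
same = _ACTIVATION_REGISTRY["gelu_fast"]  # second access returns the cached instance
assert act is same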

View File

@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
        raise NotImplementedError


+@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
-
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
-
        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
-
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
-
        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
        assert not use_grouped_topk
        assert num_expert_group is None

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp


+@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.
@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
        return s


+@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

View File

@@ -72,6 +72,7 @@ def _apply_rotary_emb(
    return torch.stack((o1, o2), dim=-1).flatten(-2)


+@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""
@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
        self.long_factor = long_factor

        scale = self.max_position_embeddings / \
            self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:

View File

@@ -17,6 +17,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
+from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -1442,3 +1443,24 @@ class AtomicCounter:
    @property
    def value(self):
        return self._value
+
+
+# Adapted from: https://stackoverflow.com/a/47212782/5082708
+class LazyDict(Mapping, Generic[T]):
+
+    def __init__(self, factory: Dict[str, Callable[[], T]]):
+        self._factory = factory
+        self._dict: Dict[str, T] = {}
+
+    def __getitem__(self, key) -> T:
+        if key not in self._dict:
+            if key not in self._factory:
+                raise KeyError(key)
+            self._dict[key] = self._factory[key]()
+        return self._dict[key]
+
+    def __iter__(self):
+        return iter(self._factory)
+
+    def __len__(self):
+        return len(self._factory)
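
A minimal usage sketch of the new LazyDict helper (not part of the commit); the factory callable and the call counter are illustrative only. Factories run lazily on first access and the built value is memoized:

from vllm.utils import LazyDict

calls = {"count": 0}


def make_value():
    # Illustrative factory: counts how many times it is invoked.
    calls["count"] += 1
    return "built"


lazy = LazyDict({"a": make_value})
assert calls["count"] == 0   # nothing constructed yet
assert lazy["a"] == "built"  # factory runs on first lookup
assert lazy["a"] == "built"  # cached; factory is not re-run
assert calls["count"] == 1
assert list(lazy) == ["a"] and len(lazy) == 1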