[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)

commit 0f41fbe5a3 (parent 7871659abb)

tests/model_executor/test_enabled_custom_ops.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import os

from typing import List

import pytest

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
                                                   ReLUSquaredActivation,
                                                   SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm


# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    pass


@pytest.mark.parametrize(
    "env, torch_level, ops_enabled, default_on",
    [
        # Default values based on compile level
        ("", 0, [True] * 4, True),
        ("", 1, [True] * 4, True),
        ("", 2, [True] * 4, True),  # All by default
        ("", 3, [False] * 4, False),
        ("", 4, [False] * 4, False),  # None by default
        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
        # GeluAndMul and SiluAndMul
        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
        # All but RMSNorm
        ("-rms_norm", 2, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                     default_on: bool):
    os.environ["VLLM_CUSTOM_OPS"] = env
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)

    # Reset default_on (computed once):
    CustomOp.default_on.cache_clear()

    assert CustomOp.default_on() == default_on

    ops_enabled = [bool(x) for x in ops_enabled]

    assert RMSNorm(1024).enabled() == ops_enabled[0]
    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

    assert SiluAndMul().enabled() == ops_enabled[1]
    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

    assert GeluAndMul().enabled() == ops_enabled[2]
    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

    # If registered, subclasses should follow their own name
    assert Relu3().enabled() == ops_enabled[3]
    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

    # Unregistered subclass
    class SiluAndMul2(SiluAndMul):
        pass

    # Subclasses should not require registration
    assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
    os.environ["VLLM_CUSTOM_OPS"] = env
    CustomOp.default_on.cache_clear()

    with pytest.raises(AssertionError):
        RMSNorm(1024).enabled()
vllm/envs.py (13 lines changed)
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []

@@ -205,7 +206,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
     "VLLM_TORCH_COMPILE_LEVEL":
     lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

+    # Fine-grained control over which custom ops to enable/disable.
+    # Use 'all' to enable all, 'none' to disable all.
+    # Also specify a list of custom op names to enable (prefixed with a '+'),
+    # or disable (prefixed with a '-').
+    # Examples:
+    # - 'all,-op1' to enable all except op1
+    # - 'none,+op1,+op2' to enable only op1 and op2
+    # By default, all custom ops are enabled when running without Inductor
+    # and disabled when running with Inductor (compile_level >= Inductor).
+    "VLLM_CUSTOM_OPS":
+    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),

     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
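The comment block above is the whole contract for the new flag: 'all'/'none' (or, failing that, the compile level) sets the default, and '+name'/'-name' toggle individual ops, with whitespace stripped before splitting on commas. A minimal sketch of how the flag plays out at runtime, assuming a working vLLM install that includes this commit; it mirrors what the new test file does:

import os

# Set before querying ops; CustomOp.default_on() is cached via lru_cache,
# so changing the env vars afterwards requires a cache_clear().
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"         # below Inductor: ops on by default
os.environ["VLLM_CUSTOM_OPS"] = "all,-silu_and_mul"  # keep the default, turn one op off

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm

CustomOp.default_on.cache_clear()

assert CustomOp.default_on() is True
assert RMSNorm(1024).enabled() is True   # covered by 'all'
assert SiluAndMul().enabled() is False   # explicitly disabled with '-silu_and_mul'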
vllm/model_executor/custom_op.py
@@ -1,14 +1,24 @@
+from functools import lru_cache
+from typing import Dict, Type
+
 import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_xpu
+from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
+
+logger = init_logger(__name__)
 
 
 class CustomOp(nn.Module):
+    """
+    Base class for custom ops.
+    Dispatches the forward method to the appropriate backend.
+    """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()

@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
-
     def forward_native(self, *args, **kwargs):
         """PyTorch-native implementation of the forward method.
 
         This method is optional. If implemented, it can be used with compilers
         such as torch.compile or PyTorch XLA. Also, it can be used for testing
         purposes.

@@ -56,7 +65,11 @@ class CustomOp(nn.Module):
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
 
-        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
+        enabled = self.enabled()
+        logger.debug("custom op %s %s", self.__class__.name,
+                     "enabled" if enabled else "disabled")
+
+        if not enabled:
             return self.forward_native
 
         if is_hip():

@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
             return self.forward_xpu
         else:
             return self.forward_cuda
+
+    @classmethod
+    def enabled(cls) -> bool:
+        # if no name, then it was not registered
+        if not hasattr(cls, "name"):
+            print_warning_once(
+                f"Custom op {cls.__name__} was not registered, "
+                f"which means it won't appear in the op registry. "
+                f"It will be enabled/disabled based on the global settings.")
+            return CustomOp.default_on()
+
+        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
+        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        assert not (enabled
+                    and disabled), f"Cannot enable and disable {cls.name}"
+
+        return (CustomOp.default_on() or enabled) and not disabled
+
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
+    @staticmethod
+    @lru_cache()
+    def default_on() -> bool:
+        count_none = envs.VLLM_CUSTOM_OPS.count("none")
+        count_all = envs.VLLM_CUSTOM_OPS.count("all")
+        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+            not count_none > 0 or count_all > 0
+
+    # Dictionary of all custom ops (classes, indexed by registered name).
+    # To check if an op with a name is enabled, call .enabled() on the class.
+    # Examples:
+    # - MyOp.enabled()
+    # - op_registry["my_op"].enabled()
+    op_registry: Dict[str, Type['CustomOp']] = {}
+
+    # Decorator to register custom ops.
+    @classmethod
+    def register(cls, name: str):
+
+        def decorator(op_cls):
+            assert name not in cls.op_registry, f"Duplicate op name: {name}"
+            op_cls.name = name
+            cls.op_registry[name] = op_cls
+            return op_cls
+
+        return decorator
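The new classmethods are the core of the mechanism: register() attaches a name and records the class in op_registry, enabled() combines default_on() with any '+name'/'-name' overrides, and dispatch_forward() now falls back to forward_native() whenever an op is disabled. A small registration sketch under the same assumptions as above; "my_relu3" and MyRelu3 are illustrative names, not ops added by this commit:

import os

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"       # ops on by default at this level
os.environ["VLLM_CUSTOM_OPS"] = "none,+my_relu3"   # global default off, this op on

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import ReLUSquaredActivation


# The decorator sets MyRelu3.name = "my_relu3" and adds the class to
# CustomOp.op_registry, so enabled() can match '+my_relu3' / '-my_relu3'.
@CustomOp.register("my_relu3")
class MyRelu3(ReLUSquaredActivation):
    pass


CustomOp.default_on.cache_clear()

assert CustomOp.default_on() is False                       # 'none' wins over level 0
assert MyRelu3().enabled() is True                          # '+my_relu3' overrides it
assert CustomOp.op_registry["my_relu3"].enabled() is True   # same check via the registry
assert ReLUSquaredActivation().enabled() is False           # 'relu2' stays at the default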
vllm/model_executor/layers/activation.py
@@ -11,11 +11,13 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import LazyDict
 
 
+@CustomOp.register("fatrelu_and_mul")
 class FatreluAndMul(CustomOp):
     """An activation function for FATReLU.
 
     The function computes x -> FATReLU(x[:d]) * x[d:] where
     d = x.shape[-1] // 2.
     This is used in openbmb/MiniCPM-S-1B-sft.

@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
         return self.forward_native(x)
 
 
+@CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 

@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 

@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
         return f'approximate={repr(self.approximate)}'
 
 
+@CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:

@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
         return ops.gelu_new(x)
 
 
+@CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:

@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
         return ops.gelu_fast(x)
 
 
+@CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
-
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""

@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
 
 
+@CustomOp.register("relu2")
 class ReLUSquaredActivation(CustomOp):
     """
     Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2

@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
         param_data.copy_(loaded_weight)
 
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    "gelu_fast": FastGELU(),
-    "gelu_new": NewGELU(),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-    "relu2": ReLUSquaredActivation(),
-    "quick_gelu": QuickGELU(),
-}
+_ACTIVATION_REGISTRY = LazyDict({
+    "gelu":
+    lambda: nn.GELU(),
+    "gelu_fast":
+    lambda: FastGELU(),
+    "gelu_new":
+    lambda: NewGELU(),
+    "gelu_pytorch_tanh":
+    lambda: nn.GELU(approximate="tanh"),
+    "relu":
+    lambda: nn.ReLU(),
+    "relu2":
+    lambda: ReLUSquaredActivation(),
+    "quick_gelu":
+    lambda: QuickGELU(),
+})
 
 
 def get_act_fn(
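One behavioral consequence of switching _ACTIVATION_REGISTRY to a LazyDict: the activation modules (now all registered CustomOps) are no longer built at import time but on first lookup, and each built instance is cached. A short sketch of the lookup semantics, assuming the same vLLM install; _ACTIVATION_REGISTRY is a module-private name, used here only for illustration:

from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY, FastGELU

# The factory for "gelu_fast" runs on the first access and its result is
# cached inside the LazyDict, so later lookups return the same instance.
act = _ACTIVATION_REGISTRY["gelu_fast"]
assert isinstance(act, FastGELU)
assert _ACTIVATION_REGISTRY["gelu_fast"] is act

# Unknown names raise KeyError, matching plain-dict behavior.
try:
    _ACTIVATION_REGISTRY["not_an_activation"]
except KeyError:
    pass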
vllm/model_executor/layers/fused_moe/layer.py
@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
 
+@CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
-
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                     2 * intermediate_size,

@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         return self.forward(x=x,
                             layer=layer,
                             router_logits=router_logits,

@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
 

@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 from vllm.model_executor.custom_op import CustomOp
 
 
+@CustomOp.register("rms_norm")
 class RMSNorm(CustomOp):
     """Root mean square normalization.
 

@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
         return s
 
 
+@CustomOp.register("gemma_rms_norm")
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
 
vllm/model_executor/layers/rotary_embedding.py
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+@CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 

@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         self.long_factor = long_factor
 
         scale = self.max_position_embeddings / \
             self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
         else:
vllm/utils.py
@@ -17,6 +17,7 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, ensure_future
+from collections.abc import Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,

@@ -1442,3 +1443,24 @@ class AtomicCounter:
     @property
     def value(self):
         return self._value
+
+
+# Adapted from: https://stackoverflow.com/a/47212782/5082708
+class LazyDict(Mapping, Generic[T]):
+
+    def __init__(self, factory: Dict[str, Callable[[], T]]):
+        self._factory = factory
+        self._dict: Dict[str, T] = {}
+
+    def __getitem__(self, key) -> T:
+        if key not in self._dict:
+            if key not in self._factory:
+                raise KeyError(key)
+            self._dict[key] = self._factory[key]()
+        return self._dict[key]
+
+    def __iter__(self):
+        return iter(self._factory)
+
+    def __len__(self):
+        return len(self._factory)
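LazyDict is a read-only Mapping whose values are produced by zero-argument factories on first access and then memoized. A tiny standalone check of that behavior, assuming a vLLM install that includes this commit:

from vllm.utils import LazyDict

calls = {"n": 0}


def build() -> str:
    calls["n"] += 1
    return "built"


lazy = LazyDict({"item": build})

assert calls["n"] == 0           # nothing constructed yet
assert lazy["item"] == "built"   # factory runs on first access...
assert lazy["item"] == "built"   # ...and is not run again
assert calls["n"] == 1
assert list(lazy) == ["item"] and len(lazy) == 1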