[platform] support custom torch.compile backend key (#11318)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
wangxiyuan 2025-01-10 23:46:51 +08:00 committed by GitHub
parent 12664ddda5
commit 20410b2fda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 14 additions and 5 deletions

View File

@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeStochasticBaseSampler)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -368,7 +369,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
probs: torch.Tensor,
num_samples: int,

View File

@ -133,7 +133,7 @@ class VocabParallelEmbeddingShardIndices:
assert self.num_added_elements <= self.num_added_elements_padded
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,

View File

@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -53,7 +54,7 @@ from .utils import (extract_layer_index, is_pp_missing_parameter,
maybe_prefix)
@torch.compile
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)

View File

@ -20,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
@ -54,12 +55,12 @@ class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
return load_column_parallel_weight(param, loaded_weight)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:

View File

@ -82,6 +82,12 @@ class Platform:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
# The torch.compile backend for compiling simple and
# standalone functions. The default value is "inductor" to keep
# the same behavior as PyTorch.
# NOTE: for the forward part of the model, vLLM has another separate
# compilation strategy.
simple_compile_backend: str = "inductor"
supported_quantization: list[str] = []
def is_cuda(self) -> bool: