Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-21 05:45:27 +08:00
[platform] support custom torch.compile backend key (#11318)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

parent 12664ddda5
commit 20410b2fda
@@ -9,6 +9,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeStochasticBaseSampler)
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -368,7 +369,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
 # Note that we always sample with replacement.
 # probs will be modified in place, but this is fine, as we pass
 # in a copy already.
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
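The decorator change above is the whole pattern this commit repeats: torch.compile takes backend= as either a callable or the string name of a registered backend, so the platform object can carry the choice as plain data. Below is a minimal self-contained sketch of that pattern; the _Platform stand-in and the scaled_softmax body are hypothetical, and only the simple_compile_backend attribute name comes from this diff.

import torch


class _Platform:
    # stand-in for vLLM's Platform class; mirrors only the attribute
    # introduced by this commit
    simple_compile_backend: str = "inductor"


current_platform = _Platform()


# "inductor", "eager", and "aot_eager" are all built-in backend names
# that torch.compile accepts as strings
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def scaled_softmax(x: torch.Tensor, scale: float) -> torch.Tensor:
    # toy body standing in for _multinomial above
    return torch.softmax(x * scale, dim=-1)


print(scaled_softmax(torch.randn(4, 8), 0.5).sum(dim=-1))  # rows sum to 1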
@@ -133,7 +133,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -53,7 +54,7 @@ from .utils import (extract_layer_index, is_pp_missing_parameter,
                     maybe_prefix)


-@torch.compile
+@torch.compile(backend=current_platform.simple_compile_backend)
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
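The hunk cuts off right after the float32 upcast. For readers following along, a typical upcast-normalize-downcast helper with this exact signature looks like the sketch below; everything after the upcast is an assumed RMS-norm-style body, not necessarily the file's actual code.

import torch


def layer_norm_func_sketch(hidden_states: torch.Tensor,
                           weight: torch.Tensor,
                           variance_epsilon: float) -> torch.Tensor:
    # keep the caller's dtype so the result can be cast back at the end
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # normalize in float32 for numerical stability (assumed RMS-norm body)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)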
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -54,12 +55,12 @@ class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
         return load_column_parallel_weight(param, loaded_weight)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def quick_gelu(x):
     return x * torch.sigmoid(1.702 * x)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def gegelu(input, limit: Optional[float] = None):
     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
     if limit is not None:
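quick_gelu above is the sigmoid approximation of GELU: x * sigmoid(1.702 * x) ≈ x * Φ(x). A quick numerical check of how close the two are (the printed bound is on the order of 1e-2):

import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=201)
quick = x * torch.sigmoid(1.702 * x)   # quick_gelu from the diff
exact = F.gelu(x)                      # erf-based GELU
print(torch.max(torch.abs(quick - exact)).item())  # roughly 0.02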
@@ -82,6 +82,12 @@ class Platform:
     # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    # NOTE: for the forward part of the model, vLLM has another separate
+    # compilation strategy.
+    simple_compile_backend: str = "inductor"
     supported_quantization: list[str] = []

     def is_cuda(self) -> bool:
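With the attribute in place on the Platform base class, an out-of-tree platform redirects every helper decorated above by overriding a single string. A sketch, assuming Platform is importable from vllm.platforms.interface as in upstream vLLM; MyPlatform is hypothetical, and "eager" is chosen only because it is a built-in torch.compile backend name. How this subclass becomes current_platform is left to vLLM's platform resolution.

from vllm.platforms.interface import Platform


class MyPlatform(Platform):
    # hypothetical out-of-tree platform; every @torch.compile(...) site in
    # this diff now picks up "eager" instead of "inductor"
    simple_compile_backend: str = "eager"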