mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 11:56:00 +08:00
[1/N] torch.compile user interface design (#10237)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
9cdba9669c
commit
eea55cca5b
@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -48,7 +47,11 @@ direct_register_custom_op(
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -74,11 +77,12 @@ class SillyModel(nn.Module):
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
model = SillyModel()
|
||||
|
||||
directory = os.path.dirname(__file__)
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
model = SillyModel(vllm_config=VllmConfig(), prefix='')
|
||||
|
||||
inputs = torch.randn(100).cuda()
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.plugins import set_compilation_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
config: LlamaConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.embedding_tokens = nn.Embedding(
|
||||
num_embeddings=config.vocab_size,
|
||||
@ -265,10 +272,9 @@ def run_model(llama_config,
|
||||
CompilationLevel.NO_COMPILATION)
|
||||
set_compilation_config(None)
|
||||
|
||||
cls = LlamaModel
|
||||
if use_compile:
|
||||
cls = support_torch_compile(LlamaModel)
|
||||
model = cls(llama_config).eval().cuda()
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=VllmConfig(),
|
||||
prefix="").eval().cuda()
|
||||
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
@ -357,7 +363,6 @@ def test_toy_llama():
|
||||
def benchmark():
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
from triton.testing import do_bench
|
||||
cls = support_torch_compile(LlamaModel)
|
||||
|
||||
# similar to llama 3.1-8B
|
||||
llama_config = LlamaConfig(hidden_size=4096,
|
||||
@ -390,7 +395,9 @@ def benchmark():
|
||||
else:
|
||||
set_compilation_config(None)
|
||||
|
||||
model = cls(llama_config).eval().cuda().to(torch.bfloat16)
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=VllmConfig(),
|
||||
prefix="").eval().cuda().to(torch.bfloat16)
|
||||
|
||||
B = 256 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||
|
||||
@ -6,6 +6,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import supports_dynamo
|
||||
@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
|
||||
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
if envs.VLLM_TORCH_COMPILE_LEVEL in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||
] or not supports_dynamo():
|
||||
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWrapperWithCustomDispatcher
|
||||
if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
cls.__bases__ = cls.__bases__ + (
|
||||
TorchCompileWrapperWithCustomDispatcher, )
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||
|
||||
old_init = cls.__init__ # type: ignore
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
old_init(self, *args, **kwargs)
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||
] or not supports_dynamo()
|
||||
if self.do_not_compile:
|
||||
return
|
||||
TorchCompileWrapperWithCustomDispatcher.__init__(self)
|
||||
|
||||
cls.__init__ = __init__ # type: ignore
|
||||
@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if torch.compiler.is_compiling():
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
|
||||
@ -2041,12 +2041,15 @@ class VllmConfig:
|
||||
simplifies passing around the distinct configurations in the codebase.
|
||||
"""
|
||||
|
||||
model_config: ModelConfig
|
||||
cache_config: CacheConfig
|
||||
parallel_config: ParallelConfig
|
||||
scheduler_config: SchedulerConfig
|
||||
device_config: DeviceConfig
|
||||
load_config: LoadConfig
|
||||
model_config: ModelConfig = field(default=None, init=True) # type: ignore
|
||||
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
|
||||
parallel_config: ParallelConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
scheduler_config: SchedulerConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
device_config: DeviceConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True) # type: ignore
|
||||
lora_config: Optional[LoRAConfig] = None
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
decoding_config: Optional[DecodingConfig] = None
|
||||
@ -2091,11 +2094,14 @@ class VllmConfig:
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
"""
|
||||
self.model_config.verify_async_output_proc(self.parallel_config,
|
||||
self.speculative_config,
|
||||
self.device_config)
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
if self.model_config is not None:
|
||||
self.model_config.verify_async_output_proc(self.parallel_config,
|
||||
self.speculative_config,
|
||||
self.device_config)
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
@ -2149,4 +2155,4 @@ class VllmConfig:
|
||||
self.scheduler_config.num_scheduler_steps,
|
||||
self.cache_config.enable_prefix_caching,
|
||||
self.model_config.use_async_output_proc,
|
||||
self.model_config.mm_processor_kwargs)
|
||||
self.model_config.mm_processor_kwargs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user