[1/N] torch.compile user interface design (#10237)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored 2024-11-11 18:01:06 -08:00, committed by GitHub
parent 9cdba9669c
commit eea55cca5b
4 changed files with 55 additions and 37 deletions

File 1 of 4

@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.utils import direct_register_custom_op
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
global_counter = 0
# create a library to hold the custom op
@@ -48,7 +47,11 @@ direct_register_custom_op(
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile():
model = SillyModel()
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
model = SillyModel(vllm_config=VllmConfig(), prefix='')
inputs = torch.randn(100).cuda()
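Taken together, these hunks show the new user-facing convention: a model class decorated with @support_torch_compile takes keyword-only vllm_config and prefix arguments, and tests build it from a bare VllmConfig() instead of steering everything through module-level environment variables. Below is a minimal sketch of that convention; ToyModel and its tiny MLP are illustrative and not part of this commit, and with the default compilation level the call simply runs eagerly.

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):
    # keyword-only vllm_config/prefix match what the decorator's wrapped
    # __init__ forwards to the original constructor
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(),
                                 nn.Linear(16, 16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# default VLLM_TORCH_COMPILE_LEVEL means no compilation, so this is eager
model = ToyModel(vllm_config=VllmConfig(), prefix='').eval()
out = model(torch.randn(4, 16))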

File 2 of 4

@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op
@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
cls = LlamaModel
if use_compile:
cls = support_torch_compile(LlamaModel)
model = cls(llama_config).eval().cuda()
model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda()
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -357,7 +363,6 @@ def test_toy_llama():
def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench
cls = support_torch_compile(LlamaModel)
# similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
else:
set_compilation_config(None)
model = cls(llama_config).eval().cuda().to(torch.bfloat16)
model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda().to(torch.bfloat16)
B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
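The LlamaModel hunks make the same point for a model that already has its own constructor argument: config=llama_config is forwarded through the decorator's wrapped __init__ as an ordinary extra keyword argument, while vllm_config and prefix become part of the uniform signature, so run_model() and benchmark() no longer need the conditional cls = support_torch_compile(LlamaModel) wrapping. A hedged sketch of that pass-through follows; TinyLlamaLike is illustrative, and only the vllm_config/prefix convention comes from this commit.

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class TinyLlamaLike(nn.Module):
    # config is an ordinary extra kwarg; the decorator forwards it untouched
    def __init__(self, *, vllm_config: VllmConfig, config: LlamaConfig,
                 prefix: str = '', **kwargs) -> None:
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)


cfg = LlamaConfig(hidden_size=64, num_hidden_layers=1,
                  num_attention_heads=2, vocab_size=128)
model = TinyLlamaLike(config=cfg, vllm_config=VllmConfig(), prefix="").eval()
out = model(torch.randint(0, cfg.vocab_size, (4, )))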

File 3 of 4

@@ -6,6 +6,7 @@ import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
"""
A decorator to add support for compiling the forward method of a class.
"""
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
if envs.VLLM_TORCH_COMPILE_LEVEL in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo():
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
# support decorating multiple times
return cls
# take care of method resolution order
# make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher
if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
# support decorating multiple times
cls.__bases__ = cls.__bases__ + (
TorchCompileWrapperWithCustomDispatcher, )
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore
def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
if self.do_not_compile:
return
TorchCompileWrapperWithCustomDispatcher.__init__(self)
cls.__init__ = __init__ # type: ignore
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if torch.compiler.is_compiling():
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# the first compilation needs to have dynamic shapes marked
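Implementation-wise, the decorator keeps the earlier trick of appending TorchCompileWrapperWithCustomDispatcher to cls.__bases__ (guarded so decorating twice stays a no-op), but the "should this instance compile at all" decision now happens inside the wrapped __init__, is remembered as self.do_not_compile, and is consulted again in __call__ before dispatching. Here is a stripped-down, framework-free sketch of that pattern; the names (CompileMixin, COMPILE_LEVEL, supports_compile) are illustrative, not vLLM's.

import os


class CompileMixin:
    """Stand-in for the real wrapper that would drive torch.compile."""

    def init_compiled_dispatch(self) -> None:
        # the real wrapper would wrap self.forward with torch.compile here
        self.compiled_forward = self.forward

    def __call__(self, *args, **kwargs):
        # mirror the diff: fall back to eager whenever compilation is off
        if self.do_not_compile:
            return self.forward(*args, **kwargs)
        return self.compiled_forward(*args, **kwargs)


def supports_compile(cls: type) -> type:
    # append the mixin only once, so the decorator is safe to apply twice
    if CompileMixin not in cls.__bases__:
        cls.__bases__ = cls.__bases__ + (CompileMixin, )

    old_init = cls.__init__

    def __init__(self, *, prefix: str = '', **kwargs):
        old_init(self, prefix=prefix, **kwargs)
        # the decision moves from decoration time to construction time
        self.do_not_compile = os.environ.get("COMPILE_LEVEL", "0") == "0"
        if self.do_not_compile:
            return
        self.init_compiled_dispatch()

    cls.__init__ = __init__
    return cls


class LayerBase:
    def __init__(self, *, prefix: str = '', **kwargs) -> None:
        self.prefix = prefix


@supports_compile
class ToyLayer(LayerBase):
    def forward(self, x: float) -> float:
        return x * 2


layer = ToyLayer(prefix='model.layers.0')
assert layer(3.0) == 6.0 and layer.do_not_compile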

File 4 of 4

@@ -2041,12 +2041,15 @@ class VllmConfig:
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig
cache_config: CacheConfig
parallel_config: ParallelConfig
scheduler_config: SchedulerConfig
device_config: DeviceConfig
load_config: LoadConfig
model_config: ModelConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
scheduler_config: SchedulerConfig = field(default=None,
init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] = None
@@ -2091,11 +2094,14 @@ def __post_init__(self):
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.model_config is not None:
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
if self.cache_config is not None:
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
@@ -2149,4 +2155,4 @@ class VllmConfig:
self.scheduler_config.num_scheduler_steps,
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)
self.model_config.mm_processor_kwargs)
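This config change is what makes the bare VllmConfig() constructions in the tests above legal: every core sub-config now defaults to None, and __post_init__ only runs a cross-check when the configs it touches are actually present. A minimal sketch of the behaviour this enables, covering only what the diff itself shows:

from vllm.config import VllmConfig

# a placeholder config for compilation-only code paths and tests;
# verification steps are skipped for whichever sub-configs are None
config = VllmConfig()
assert config.model_config is None
assert config.cache_config is None
assert config.lora_config is None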