[1/N] torch.compile user interface design (#10237)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao <youkaichao@gmail.com> (committed via GitHub)
Date: 2024-11-11 18:01:06 -08:00
Commit: eea55cca5b (parent: 9cdba9669c)
GPG Key ID: B5690EEEBB952194

4 changed files with 55 additions and 37 deletions

File 1/4: test for simple piecewise compilation (defines SillyModel)

@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.utils import direct_register_custom_op

-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
 global_counter = 0

 # create a library to hold the custom op
@@ -48,7 +47,11 @@ direct_register_custom_op(
 @support_torch_compile
 class SillyModel(nn.Module):

-    def __init__(self) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
 def test_simple_piecewise_compile():

-    model = SillyModel()
-
     directory = os.path.dirname(__file__)
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
+
+    model = SillyModel(vllm_config=VllmConfig(), prefix='')

     inputs = torch.randn(100).cuda()
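For illustration (not part of the commit): under the interface this diff introduces, a model decorated with @support_torch_compile takes keyword-only vllm_config and prefix arguments, and tests construct it with a bare VllmConfig(). A minimal sketch against this commit's API; the class name ToyAdder is hypothetical:

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyAdder(nn.Module):
    # follows the new constructor convention: keyword-only vllm_config/prefix
    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


# construction mirrors the updated test: a default VllmConfig and a prefix
model = ToyAdder(vllm_config=VllmConfig(), prefix='')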

File 2/4: toy Llama compilation test and benchmark

@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.plugins import set_compilation_config
 from vllm.utils import direct_register_custom_op

@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
         return hidden_states, residual


+@support_torch_compile
 class LlamaModel(nn.Module):

-    def __init__(self, config: LlamaConfig) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 config: LlamaConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
         self.embedding_tokens = nn.Embedding(
             num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)

-    cls = LlamaModel
-    if use_compile:
-        cls = support_torch_compile(LlamaModel)
-    model = cls(llama_config).eval().cuda()
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda()

     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -357,7 +363,6 @@ def test_toy_llama():
 def benchmark():
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench
-    cls = support_torch_compile(LlamaModel)

     # similar to llama 3.1-8B
     llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
     else:
         set_compilation_config(None)

-    model = cls(llama_config).eval().cuda().to(torch.bfloat16)
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda().to(torch.bfloat16)

     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
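For illustration (not part of the commit): run_model and benchmark no longer choose between LlamaModel and support_torch_compile(LlamaModel); the decorator is applied once at class-definition time, and whether compilation actually happens is controlled per run by VLLM_TORCH_COMPILE_LEVEL (and the compilation config). A sketch of that toggle, assuming the test's LlamaModel and llama_config are in scope:

import os

from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig

# eager run: the decorator is still on LlamaModel, but its wrapped __init__
# sees NO_COMPILATION and skips the compile path entirely
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.NO_COMPILATION)
eager_model = LlamaModel(config=llama_config,
                         vllm_config=VllmConfig(),
                         prefix="").eval().cuda()

# compiled run: the same constructor call, only the environment changes
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
compiled_model = LlamaModel(config=llama_config,
                            vllm_config=VllmConfig(),
                            prefix="").eval().cuda()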

File 3/4: vllm/compilation/decorators.py

@@ -6,6 +6,7 @@ import torch
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
     """
     A decorator to add support for compiling the forward method of a class.
     """
-    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
-    # will handle the compilation, so we don't need to do anything here.
-    if envs.VLLM_TORCH_COMPILE_LEVEL in [
-            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-    ] or not supports_dynamo():
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
         return cls

     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
-        # support decorating multiple times
-        cls.__bases__ = cls.__bases__ + (
-            TorchCompileWrapperWithCustomDispatcher, )
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

     old_init = cls.__init__  # type: ignore

-    def __init__(self, *args, **kwargs):
-        old_init(self, *args, **kwargs)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
         TorchCompileWrapperWithCustomDispatcher.__init__(self)

     cls.__init__ = __init__  # type: ignore
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if torch.compiler.is_compiling():
+        if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)

         # the first compilation needs to have dynamic shapes marked
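For illustration (not part of the commit): the decorator now makes the compile/no-compile decision per instance, inside the wrapped __init__, instead of returning the class untouched at decoration time. A simplified standalone sketch of the same pattern; LazyCompileMixin, add_lazy_compile, ModuleBase, and COMPILE_LEVEL are hypothetical names, not vLLM APIs:

import os


class LazyCompileMixin:
    # stands in for TorchCompileWrapperWithCustomDispatcher in this sketch
    def __init__(self):
        self.dispatcher_ready = True


def add_lazy_compile(cls: type) -> type:
    if LazyCompileMixin in cls.__bases__:
        # support decorating multiple times
        return cls

    # inject the mixin after the class's own bases so that super().__init__
    # still resolves to the original base class first
    cls.__bases__ = cls.__bases__ + (LazyCompileMixin, )

    old_init = cls.__init__

    def __init__(self, *, config=None, prefix: str = '', **kwargs):
        old_init(self, config=config, prefix=prefix, **kwargs)
        # decide at construction time, per instance, whether to compile
        self.do_not_compile = os.environ.get("COMPILE_LEVEL", "0") == "0"
        if self.do_not_compile:
            return
        LazyCompileMixin.__init__(self)

    cls.__init__ = __init__
    return cls


class ModuleBase:
    # stands in for nn.Module in this sketch
    def __init__(self, *, config=None, prefix: str = '', **kwargs):
        self.config = config
        self.prefix = prefix


@add_lazy_compile
class ToyModel(ModuleBase):
    pass


m = ToyModel(config=None, prefix="layers.0")
print(m.do_not_compile)  # True unless COMPILE_LEVEL is set to a non-zero value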

File 4/4: vllm/config.py

@@ -2041,12 +2041,15 @@ class VllmConfig:
     simplifies passing around the distinct configurations in the codebase.
     """

-    model_config: ModelConfig
-    cache_config: CacheConfig
-    parallel_config: ParallelConfig
-    scheduler_config: SchedulerConfig
-    device_config: DeviceConfig
-    load_config: LoadConfig
+    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
+    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default=None,
+                                            init=True)  # type: ignore
+    scheduler_config: SchedulerConfig = field(default=None,
+                                              init=True)  # type: ignore
+    device_config: DeviceConfig = field(default=None,
+                                        init=True)  # type: ignore
+    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
     lora_config: Optional[LoRAConfig] = None
     speculative_config: Optional[SpeculativeConfig] = None
     decoding_config: Optional[DecodingConfig] = None
@@ -2091,11 +2094,14 @@
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
-        self.model_config.verify_async_output_proc(self.parallel_config,
-                                                   self.speculative_config,
-                                                   self.device_config)
-        self.model_config.verify_with_parallel_config(self.parallel_config)
-        self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.model_config is not None:
+            self.model_config.verify_async_output_proc(self.parallel_config,
+                                                       self.speculative_config,
+                                                       self.device_config)
+            self.model_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.cache_config is not None:
+            self.cache_config.verify_with_parallel_config(self.parallel_config)

         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
@@ -2149,4 +2155,4 @@
             self.scheduler_config.num_scheduler_steps,
             self.cache_config.enable_prefix_caching,
             self.model_config.use_async_output_proc,
-            self.model_config.mm_processor_kwargs)
+            self.model_config.mm_processor_kwargs)
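For illustration (not part of the commit): the config change turns the previously required fields of VllmConfig into fields that default to None, so a bare VllmConfig() can be built in tests, while __post_init__ only validates the pieces that were actually supplied. A minimal standalone sketch of the same dataclass pattern; EngineConfig, ModelPart, and ParallelPart are hypothetical names, not vLLM classes:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelPart:
    name: str = "dummy"

    def verify_with(self, other: "ParallelPart") -> None:
        assert other is not None, "parallel part is required for verification"


@dataclass
class ParallelPart:
    world_size: int = 1


@dataclass
class EngineConfig:
    # defaulting to None keeps the field effectively required for real runs
    # while allowing EngineConfig() to be constructed empty in tests
    model_part: ModelPart = field(default=None, init=True)  # type: ignore
    parallel_part: ParallelPart = field(default=None, init=True)  # type: ignore
    extra: Optional[str] = None

    def __post_init__(self):
        # only validate the pieces that were actually provided
        if self.model_part is not None:
            self.model_part.verify_with(self.parallel_part)


empty = EngineConfig()                       # fine: nothing to verify
full = EngineConfig(model_part=ModelPart(),  # verified in __post_init__
                    parallel_part=ParallelPart())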