[1/N] torch.compile user interface design (#10237)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 9cdba9669c
commit eea55cca5b
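This change unifies the model construction interface around torch.compile: a decorated model's __init__ takes a keyword-only vllm_config: VllmConfig plus a prefix, @support_torch_compile is applied once at class-definition time, and the decision whether to actually compile is deferred to runtime inside the wrapped __init__. The VllmConfig fields also gain None defaults so a bare VllmConfig() can be constructed, e.g. in tests.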
@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.utils import direct_register_custom_op
 
-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
 global_counter = 0
 
 # create a library to hold the custom op
@@ -48,7 +47,11 @@ direct_register_custom_op(
 @support_torch_compile
 class SillyModel(nn.Module):
 
-    def __init__(self) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
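For reference, a minimal sketch of the usage pattern this hunk establishes; ToyModel and its linear layer are illustrative, not part of the commit:

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):

    # keyword-only parameters: call sites must name vllm_config and prefix
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = ToyModel(vllm_config=VllmConfig(), prefix='')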
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
 
 def test_simple_piecewise_compile():
 
-    model = SillyModel()
-
     directory = os.path.dirname(__file__)
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
+
+    model = SillyModel(vllm_config=VllmConfig(), prefix='')
 
     inputs = torch.randn(100).cuda()
 
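Note the reordering: both environment variables are now set before SillyModel(...) is constructed, because with this change the decorated __init__ reads VLLM_TORCH_COMPILE_LEVEL to decide whether to compile (see the decorators hunks below).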
@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.plugins import set_compilation_config
 from vllm.utils import direct_register_custom_op
 
@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 config: LlamaConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
         self.embedding_tokens = nn.Embedding(
             num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    cls = LlamaModel
-    if use_compile:
-        cls = support_torch_compile(LlamaModel)
-    model = cls(llama_config).eval().cuda()
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
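With the use_compile branch gone, whether the decorated LlamaModel compiles is controlled purely by the environment at construction time. A sketch of the two modes, assuming the test file's LlamaModel and llama_config are in scope:

import os

from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig

# run eagerly: the wrapped __init__ sees NO_COMPILATION and skips compilation
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.NO_COMPILATION)
eager = LlamaModel(config=llama_config, vllm_config=VllmConfig(), prefix="")

# compile piecewise: same class, same call site, different environment
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
compiled = LlamaModel(config=llama_config, vllm_config=VllmConfig(), prefix="")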
@@ -357,7 +363,6 @@ def test_toy_llama():
 def benchmark():
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench
-    cls = support_torch_compile(LlamaModel)
 
     # similar to llama 3.1-8B
     llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
         else:
             set_compilation_config(None)
 
-        model = cls(llama_config).eval().cuda().to(torch.bfloat16)
+        model = LlamaModel(config=llama_config,
+                           vllm_config=VllmConfig(),
+                           prefix="").eval().cuda().to(torch.bfloat16)
 
         B = 256  # max batch size
         input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -6,6 +6,7 @@ import torch
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
     """
     A decorator to add support for compiling the forward method of a class.
     """
-
-    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
-    # will handle the compilation, so we don't need to do anything here.
-    if envs.VLLM_TORCH_COMPILE_LEVEL in [
-            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-    ] or not supports_dynamo():
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
-        # support decorating multiple times
-        cls.__bases__ = cls.__bases__ + (
-            TorchCompileWrapperWithCustomDispatcher, )
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
 
     old_init = cls.__init__  # type: ignore
 
-    def __init__(self, *args, **kwargs):
-        old_init(self, *args, **kwargs)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
 
     cls.__init__ = __init__  # type: ignore
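The cls.__bases__ manipulation above is subtle, so here is a self-contained toy version of the same technique (all names illustrative, not vLLM APIs): a mixin is appended to the decorated class's bases, and __init__ is wrapped so the mixin is initialized explicitly after the original constructor.

class Base:
    def __init__(self):
        self.ready = True


class CompileMixin:
    def __init__(self):
        self.wrapped = True


def add_mixin(cls):
    # decorating twice must be a no-op, hence the membership check
    if CompileMixin in cls.__bases__:
        return cls
    # append the mixin so the original base stays first in the MRO
    cls.__bases__ = cls.__bases__ + (CompileMixin, )
    old_init = cls.__init__

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        # the original class never calls super() into the mixin,
        # so initialize it explicitly after the original __init__
        CompileMixin.__init__(self)

    cls.__init__ = __init__
    return cls


@add_mixin
class Model(Base):
    def __init__(self):
        super().__init__()


m = Model()
assert m.ready and m.wrapped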
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if torch.compiler.is_compiling():
+        if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
@@ -2041,12 +2041,15 @@ class VllmConfig:
     simplifies passing around the distinct configurations in the codebase.
     """
 
-    model_config: ModelConfig
-    cache_config: CacheConfig
-    parallel_config: ParallelConfig
-    scheduler_config: SchedulerConfig
-    device_config: DeviceConfig
-    load_config: LoadConfig
+    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
+    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default=None,
+                                            init=True)  # type: ignore
+    scheduler_config: SchedulerConfig = field(default=None,
+                                              init=True)  # type: ignore
+    device_config: DeviceConfig = field(default=None,
+                                        init=True)  # type: ignore
+    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
     lora_config: Optional[LoRAConfig] = None
     speculative_config: Optional[SpeculativeConfig] = None
     decoding_config: Optional[DecodingConfig] = None
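The field(default=None, init=True) pattern above is what allows a bare VllmConfig() in the tests while keeping non-Optional annotations for readers; a minimal sketch with toy names:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfig:
    name: str = "toy"


@dataclass
class Config:
    # annotated non-Optional for readers, but defaulted to None so the
    # dataclass can be built empty; hence the `# type: ignore`
    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
    extra: Optional[str] = None

    def __post_init__(self):
        # validate only the pieces that were actually supplied
        if self.model_config is not None:
            assert isinstance(self.model_config, ModelConfig)


empty = Config()                          # everything defaults to None
full = Config(model_config=ModelConfig())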
@@ -2091,11 +2094,14 @@ class VllmConfig:
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
-        self.model_config.verify_async_output_proc(self.parallel_config,
-                                                   self.speculative_config,
-                                                   self.device_config)
-        self.model_config.verify_with_parallel_config(self.parallel_config)
-        self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.model_config is not None:
+            self.model_config.verify_async_output_proc(self.parallel_config,
+                                                       self.speculative_config,
+                                                       self.device_config)
+            self.model_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.cache_config is not None:
+            self.cache_config.verify_with_parallel_config(self.parallel_config)
 
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
@@ -2149,4 +2155,4 @@ class VllmConfig:
             self.scheduler_config.num_scheduler_steps,
             self.cache_config.enable_prefix_caching,
             self.model_config.use_async_output_proc,
             self.model_config.mm_processor_kwargs)