[Bugfix] VLLM_V1 supports passing other compilation levels (#19340)

Signed-off-by: Richard Zou <zou3519@gmail.com>

Parent: ab714131e4
Commit: 04e38500ee
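
What this fixes, in user terms: with the V1 engine, a compilation level other than the default can now be passed through compilation_config instead of being silently overridden. A minimal sketch of the call pattern the new tests exercise (the model name and memory fraction are copied from the tests below; treat the snippet as illustrative, not as the only entry point):

    # Sketch: pass an explicit compilation level through the offline API.
    from vllm import LLM

    # level 1 is CompilationLevel.DYNAMO_AS_IS; with this fix, V1 honors it.
    llm = LLM(model="facebook/opt-125m",
              compilation_config={"level": 1},
              gpu_memory_utilization=0.4)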
@@ -26,6 +26,8 @@ def test_use_cudagraphs_dynamic(monkeypatch):
     assert not vllm_config.compilation_config.use_cudagraph


+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
 # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
 # on the state of the cache directory on the current machine, which
 # may be influenced by other tests.

@@ -33,8 +35,8 @@ def test_use_cudagraphs_dynamic(monkeypatch):
 def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
     assert vllm.envs.VLLM_USE_V1

-    # spawn means that the counters are in the same process.
-    monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn")
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
     monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)

     compilation_config = {

@@ -50,6 +52,8 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
         pass


+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
 @pytest.mark.parametrize("enabled", [True, False])
 def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
     assert vllm.envs.VLLM_USE_V1

@@ -72,3 +76,50 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
                     compilation_config=compilation_config,
                     gpu_memory_utilization=0.4) as _):
         pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_dynamo_as_is(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(dynamo_as_is_count=1),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config={"level": 1},
+                        gpu_memory_utilization=0.4) as _):
+        pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_no_compilation(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(num_graphs_seen=0,
+                                       dynamo_as_is_count=0),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config={"level": 0},
+                        gpu_memory_utilization=0.4) as _):
+        pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_enforce_eager(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(num_graphs_seen=0,
+                                       dynamo_as_is_count=0),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        enforce_eager=True,
+                        gpu_memory_utilization=0.4) as _):
+        pass

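A note on the repeated @pytest.mark.forked: the mark (from the pytest-forked plugin) runs each test in its own forked subprocess, so process-global state such as the compilation counter and Dynamo caches cannot leak between tests, which is the workaround for the linked issue. A toy illustration of that isolation property, not vLLM code (POSIX-only, since it uses os.fork):

    import os

    COUNTER = 0  # stands in for process-global state like compilation_counter

    pid = os.fork()
    if pid == 0:
        COUNTER += 1     # the child mutates its own copy-on-write copy
        os._exit(0)
    os.waitpid(pid, 0)
    assert COUNTER == 0  # the parent's copy is untouched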
@@ -27,6 +27,8 @@ class CompilationCounter:
     num_cache_entries_updated: int = 0
     # The number of standalone_compile compiled artifacts saved
     num_compiled_artifacts_saved: int = 0
+    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
+    dynamo_as_is_count: int = 0

     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)

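The new field is consumed through the counter's expect context manager, exactly as the tests above do. A minimal sketch of that pattern (assuming, as the tests imply, that expect asserts the change in each named field across the block; the increment line stands in for loading a model at level 1):

    from vllm.compilation.counter import compilation_counter

    # Passes only if dynamo_as_is_count grows by exactly 1 inside the block.
    with compilation_counter.expect(dynamo_as_is_count=1):
        compilation_counter.dynamo_as_is_count += 1  # what the runner now does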
@@ -4106,9 +4106,11 @@ class CompilationConfig:
     certain small batchsizes, where inductor is good at optimizing.
     """
     # Top-level Compilation control
-    level: int = 0
+    level: Optional[int] = None
     """The level of compilation:

+    - None: If None, we will select the default compilation level.
+      For V1 engine this is 3, for V0 engine this is 0.
     - 0: no compilation.
     - 1: dynamo as is.
     - 2: dynamo once.

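For reference, the numeric levels in this docstring correspond to the CompilationLevel constants used elsewhere in the commit. A sketch of the mapping, assembled from the names that appear in these hunks (NO_COMPILATION, DYNAMO_AS_IS, PIECEWISE) plus the docstring's level 2:

    # Mapping implied by the docstring and the names used in this commit.
    class CompilationLevel:
        NO_COMPILATION = 0
        DYNAMO_AS_IS = 1
        DYNAMO_ONCE = 2
        PIECEWISE = 3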
@@ -4664,6 +4666,22 @@ class VllmConfig:
                 "To workaround this limitation, vLLM will set 'ieee' input "
                 "precision for chunked prefill triton kernels.")

+        # If the user does not explicitly set a compilation level, then
+        # we use the default level. The default level depends on other
+        # settings (see the below code).
+        if self.compilation_config.level is None:
+            if envs.VLLM_USE_V1:
+                if (self.model_config is not None
+                        and not self.model_config.enforce_eager):
+                    self.compilation_config.level = CompilationLevel.PIECEWISE
+                else:
+                    self.compilation_config.level = \
+                        CompilationLevel.NO_COMPILATION
+            else:
+                # NB: Passing both --enforce-eager and a compilation level
+                # in V0 means the compilation level wins out.
+                self.compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
         if self.compilation_config.pass_config.enable_async_tp:

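The branch added above reduces to a small decision table. A hedged restatement as a standalone function (hypothetical name, not a vLLM API; it ignores the model_config-is-None guard for brevity):

    def default_compilation_level(use_v1: bool, enforce_eager: bool) -> int:
        # Restates only the defaulting logic above (user passed level=None).
        if use_v1 and not enforce_eager:
            return 3  # CompilationLevel.PIECEWISE
        return 0      # CompilationLevel.NO_COMPILATION

    assert default_compilation_level(use_v1=True, enforce_eager=False) == 3
    assert default_compilation_level(use_v1=True, enforce_eager=True) == 0
    assert default_compilation_level(use_v1=False, enforce_eager=True) == 0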
@@ -4676,7 +4694,6 @@ class VllmConfig:
             # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
             # is set to True, full CUDA graphs will be used.
             self.compilation_config.cudagraph_num_of_warmups = 1
-            self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()

         self._set_cudagraph_sizes()

@@ -43,7 +43,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
-                        is_pin_memory_available, round_up)
+                        is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,

@@ -1930,6 +1930,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 rank_mapping,
             )

+        if (
+            self.vllm_config.compilation_config.level == \
+            CompilationLevel.DYNAMO_AS_IS and supports_dynamo()
+        ):
+            backend = self.vllm_config.compilation_config.init_backend(
+                self.vllm_config)
+            compilation_counter.dynamo_as_is_count += 1
+            self.model.compile(
+                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=backend)
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, \
             "Cannot reload weights before model is loaded."

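What DYNAMO_AS_IS amounts to in the V1 runner: hand the whole model to Dynamo via the in-place nn.Module.compile with the backend produced by init_backend, rather than vLLM's piecewise compilation. A self-contained toy sketch of the same PyTorch mechanism (stand-in module and the built-in "eager" backend; not vLLM code):

    import torch

    model = torch.nn.Linear(8, 8)
    # Module.compile wraps the forward pass with torch.compile in place, as
    # the new runner code does; backend="eager" exercises Dynamo capture only.
    model.compile(backend="eager")
    out = model(torch.randn(2, 8))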
@@ -22,6 +22,7 @@ import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import broadcast_tensor_dict, get_pp_group

@@ -1121,6 +1122,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
             backend = self.vllm_config.compilation_config.init_backend(
                 self.vllm_config)
+            compilation_counter.dynamo_as_is_count += 1
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,