mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[Multimodal][torch.compile] Add compilation config field for turning off ViT/MM compile (#28242)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
parent
59b453eaa2
commit
4bf56c79cc
@ -3,10 +3,17 @@
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def test_compile():
|
||||
vllm_config = VllmConfig()
|
||||
# Default configuration compiles mm encoder
|
||||
assert vllm_config.compilation_config.compile_mm_encoder
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
@ -31,8 +38,33 @@ def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
|
||||
vllm_runner(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=2048,
|
||||
gpu_memory_utilization=0.7,
|
||||
gpu_memory_utilization=0.8,
|
||||
compilation_config={"mode": CompilationMode.VLLM_COMPILE},
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
|
||||
"""Test that Qwen2.5-VL vision submodules are not compiled when the
|
||||
config is passed off
|
||||
"""
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(num_models_seen=1),
|
||||
vllm_runner(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=2048,
|
||||
gpu_memory_utilization=0.8,
|
||||
compilation_config={
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_mm_encoder": False,
|
||||
},
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -150,6 +150,7 @@ class CompilationConfig:
|
||||
- [`backend`][vllm.config.CompilationConfig.backend]
|
||||
- [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
|
||||
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
|
||||
- [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
|
||||
- CudaGraph capture:
|
||||
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
|
||||
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
|
||||
@ -250,6 +251,13 @@ class CompilationConfig:
|
||||
disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
|
||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||
splitting_ops: list[str] | None = None
|
||||
|
||||
"""
|
||||
Provide control over whether to compile the multimodal encoder
|
||||
such as Qwen2_5_vl
|
||||
"""
|
||||
compile_mm_encoder: bool = True
|
||||
|
||||
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
||||
|
||||
The behavior depends on use_inductor_graph_partition:
|
||||
|
||||
@ -67,6 +67,9 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.transformers.utils import (
|
||||
should_torch_compile_mm_vit,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.evs import (
|
||||
compute_mrope_for_media,
|
||||
@ -464,6 +467,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
"seqlens": 0,
|
||||
},
|
||||
mark_unbacked_dims={"seqlens": 0},
|
||||
enable_if=should_torch_compile_mm_vit,
|
||||
)
|
||||
class Qwen2_5_VisionBlock(nn.Module):
|
||||
def __init__(
|
||||
@ -529,7 +533,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"x": 0,
|
||||
}
|
||||
},
|
||||
enable_if=should_torch_compile_mm_vit,
|
||||
)
|
||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
@ -560,7 +565,8 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"x": 0,
|
||||
}
|
||||
},
|
||||
enable_if=should_torch_compile_mm_vit,
|
||||
)
|
||||
class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
def __init__(
|
||||
|
||||
@ -205,3 +205,14 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
|
||||
# Dynamic rope scaling is not compatible with torch.compile
|
||||
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
|
||||
return rope_scaling.get("rope_type") != "dynamic"
|
||||
|
||||
|
||||
def should_torch_compile_mm_vit(vllm_config: "VllmConfig") -> bool:
|
||||
"""
|
||||
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
|
||||
|
||||
Defaults to `True` but is disabled in the following situations:
|
||||
|
||||
- The model uses dynamic rope scaling.
|
||||
"""
|
||||
return vllm_config.compilation_config.compile_mm_encoder
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user