mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 05:55:02 +08:00
[torch.compile] Add encoder tag for compilation (#30489)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
3a20450d31
commit
3224ea9915
@ -463,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
# the tag for the part of model being compiled,
|
# the tag for the part of model being compiled,
|
||||||
# e.g. backbone/eagle_head
|
# e.g. backbone/eagle_head
|
||||||
model_tag: str = "backbone"
|
model_tag: str = "backbone"
|
||||||
|
model_is_encoder: bool = False
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_model_tag(tag: str):
|
def set_model_tag(tag: str, is_encoder: bool = False):
|
||||||
"""Context manager to set the model tag."""
|
"""Context manager to set the model tag."""
|
||||||
global model_tag
|
global model_tag
|
||||||
|
global model_is_encoder
|
||||||
assert tag != model_tag, (
|
assert tag != model_tag, (
|
||||||
f"Model tag {tag} is the same as the current tag {model_tag}."
|
f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||||
)
|
)
|
||||||
old_tag = model_tag
|
old_tag = model_tag
|
||||||
|
old_is_encoder = model_is_encoder
|
||||||
|
|
||||||
model_tag = tag
|
model_tag = tag
|
||||||
|
model_is_encoder = is_encoder
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
model_tag = old_tag
|
model_tag = old_tag
|
||||||
|
model_is_encoder = old_is_encoder
|
||||||
|
|
||||||
|
|
||||||
class VllmBackend:
|
class VllmBackend:
|
||||||
@ -523,6 +529,9 @@ class VllmBackend:
|
|||||||
# them, e.g. backbone (default), eagle_head, etc.
|
# them, e.g. backbone (default), eagle_head, etc.
|
||||||
self.prefix = prefix or model_tag
|
self.prefix = prefix or model_tag
|
||||||
|
|
||||||
|
# Mark compilation for encoder.
|
||||||
|
self.is_encoder = model_is_encoder
|
||||||
|
|
||||||
# Passes to run on the graph post-grad.
|
# Passes to run on the graph post-grad.
|
||||||
self.pass_manager = resolve_obj_by_qualname(
|
self.pass_manager = resolve_obj_by_qualname(
|
||||||
current_platform.get_pass_manager_cls()
|
current_platform.get_pass_manager_cls()
|
||||||
|
|||||||
@ -53,12 +53,7 @@ class PiecewiseBackend:
|
|||||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||||
|
|
||||||
self.is_full_graph = total_piecewise_compiles == 1
|
self.is_full_graph = total_piecewise_compiles == 1
|
||||||
# TODO: we need to generalize encoder compilation to other models
|
self.is_encoder_compilation = vllm_backend.is_encoder
|
||||||
self.is_encoder_compilation = vllm_backend.prefix in [
|
|
||||||
"Qwen2_5_VisionPatchEmbed",
|
|
||||||
"Qwen2_5_VisionPatchMerger",
|
|
||||||
"Qwen2_5_VisionBlock",
|
|
||||||
]
|
|
||||||
|
|
||||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
||||||
if self.is_encoder_compilation:
|
if self.is_encoder_compilation:
|
||||||
|
|||||||
@ -612,7 +612,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
# DO NOT MOVE THIS IMPORT
|
# DO NOT MOVE THIS IMPORT
|
||||||
from vllm.compilation.backends import set_model_tag
|
from vllm.compilation.backends import set_model_tag
|
||||||
|
|
||||||
with set_model_tag("Qwen2_5_VisionPatchEmbed"):
|
with set_model_tag("Qwen2_5_VisionPatchEmbed", is_encoder=True):
|
||||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
temporal_patch_size=temporal_patch_size,
|
temporal_patch_size=temporal_patch_size,
|
||||||
@ -651,7 +651,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_model_tag("Qwen2_5_VisionBlock"):
|
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen2_5_VisionBlock(
|
Qwen2_5_VisionBlock(
|
||||||
@ -670,7 +670,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_model_tag("Qwen2_5_VisionPatchMerger"):
|
with set_model_tag("Qwen2_5_VisionPatchMerger", is_encoder=True):
|
||||||
self.merger = Qwen2_5_VisionPatchMerger(
|
self.merger = Qwen2_5_VisionPatchMerger(
|
||||||
d_model=vision_config.out_hidden_size,
|
d_model=vision_config.out_hidden_size,
|
||||||
context_dim=self.hidden_size,
|
context_dim=self.hidden_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user