[Bugfix][torch2.10] Fix test_qwen2_5_vl_compilation with 2.10 RC (#30822)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Lucas Kabela 2025-12-18 08:23:31 -08:00 committed by GitHub
parent 28d15ab56b
commit 0db5439ded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 7 deletions

View File

@@ -520,6 +520,7 @@ class VllmBackend:
self,
vllm_config: VllmConfig,
prefix: str = "",
is_encoder: bool = False,
):
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
@@ -530,7 +531,7 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Mark compilation for encoder.
self.is_encoder = model_is_encoder
self.is_encoder = is_encoder or model_is_encoder
# Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname(
@@ -797,7 +798,7 @@ class VllmBackend:
or not self.compilation_config.cudagraph_copy_inputs
):
return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm
graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
)
# index of tensors that have symbolic shapes (batch size)
@@ -835,5 +836,5 @@ class VllmBackend:
return self.split_gm(*list_args)
return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call
graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
)

View File

@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable):
serializing the Dynamo fx graph plus example inputs.
"""
def __init__(self, graph_module, example_inputs, prefix, optimized_call):
def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
):
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
self.example_inputs = example_inputs
self.prefix = prefix
self.optimized_call = optimized_call
self.is_encoder = is_encoder
self.shape_env = None
sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
@@ -106,7 +109,10 @@ class VllmSerializableFunction(SerializableCallable):
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])
is_encoder = state.get("is_encoder", False)
vllm_backend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
def optimized_call(*example_inputs):
"""

View File

@@ -170,8 +170,7 @@ class PiecewiseBackend:
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]"
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
)
self._maybe_compile_for_range_entry(range_entry, args)