diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a1eec7d74483f..2fb6265560b19 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -520,6 +520,7 @@ class VllmBackend: self, vllm_config: VllmConfig, prefix: str = "", + is_encoder: bool = False, ): # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, @@ -530,7 +531,7 @@ class VllmBackend: self.prefix = prefix or model_tag # Mark compilation for encoder. - self.is_encoder = model_is_encoder + self.is_encoder = is_encoder or model_is_encoder # Passes to run on the graph post-grad. self.pass_manager = resolve_obj_by_qualname( @@ -797,7 +798,7 @@ class VllmBackend: or not self.compilation_config.cudagraph_copy_inputs ): return VllmSerializableFunction( - graph, example_inputs, self.prefix, self.split_gm + graph, example_inputs, self.prefix, self.split_gm, self.is_encoder ) # index of tensors that have symbolic shapes (batch size) @@ -835,5 +836,5 @@ class VllmBackend: return self.split_gm(*list_args) return VllmSerializableFunction( - graph, example_inputs, self.prefix, copy_and_call + graph, example_inputs, self.prefix, copy_and_call, self.is_encoder ) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index fc02a08f74265..8c9ec87bcad56 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable): serializing the Dynamo fx graph plus example inputs. """ - def __init__(self, graph_module, example_inputs, prefix, optimized_call): + def __init__( + self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False + ): assert isinstance(graph_module, torch.fx.GraphModule) self.graph_module = graph_module self.example_inputs = example_inputs self.prefix = prefix self.optimized_call = optimized_call + self.is_encoder = is_encoder self.shape_env = None sym_input = next( (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None @@ -106,7 +109,10 @@ class VllmSerializableFunction(SerializableCallable): state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"].recompile() state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) - vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) + is_encoder = state.get("is_encoder", False) + vllm_backend = VllmBackend( + get_current_vllm_config(), state["prefix"], is_encoder + ) def optimized_call(*example_inputs): """ diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 58d3e2a14b22a..12cc49971e08b 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -170,8 +170,7 @@ class PiecewiseBackend: range_entry = self._find_range_for_shape(runtime_shape) assert range_entry is not None, ( - f"Shape out of considered range: {runtime_shape} " - "[1, max_num_batched_tokens]" + f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}" ) self._maybe_compile_for_range_entry(range_entry, args)