From 7b28ef2bc1c54c1dd24c77c2a5b70b4795d0d4a2 Mon Sep 17 00:00:00 2001 From: "Patrick C. Toulme" <135739773+patrick-toulme@users.noreply.github.com> Date: Sat, 27 Sep 2025 11:40:59 -0400 Subject: [PATCH] [Core] Refactor self.model() to call a helper for subclassing. (#25084) Signed-off-by: Patrick Toulme Signed-off-by: Patrick Toulme Signed-off-by: yewentao256 --- vllm/v1/worker/gpu_model_runner.py | 34 +++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1bae0d4ce4d1f..2354e8222e7af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2268,6 +2268,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): finally: self.prepare_inputs_event.record() + def _model_forward( + self, + input_ids: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any], + ) -> Any: + """Helper method to call the model forward pass. + + This method can be overridden by subclasses for model execution. + Motivation: We can inspect only this method versus + the whole execute_model, which has additional logic. + + Args: + input_ids: Input token IDs + positions: Token positions + intermediate_tensors: Tensors from previous pipeline stages + inputs_embeds: Input embeddings (alternative to input_ids) + **model_kwargs: Additional model arguments + + Returns: + Model output tensor + """ + return self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + @torch.inference_mode() def execute_model( self, @@ -2337,7 +2369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output): - model_output = self.model( + model_output = self._model_forward( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors,