mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 07:45:01 +08:00
[Core] Refactor self.model() to call a helper for subclassing. (#25084)
Signed-off-by: Patrick Toulme <ptoulme@meta.com> Signed-off-by: Patrick Toulme <pctoulme+1@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
9b4c752106
commit
7b28ef2bc1
@ -2268,6 +2268,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
finally:
|
finally:
|
||||||
self.prepare_inputs_event.record()
|
self.prepare_inputs_event.record()
|
||||||
|
|
||||||
|
def _model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
positions: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**model_kwargs: dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
"""Helper method to call the model forward pass.
|
||||||
|
|
||||||
|
This method can be overridden by subclasses for model execution.
|
||||||
|
Motivation: We can inspect only this method versus
|
||||||
|
the whole execute_model, which has additional logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input token IDs
|
||||||
|
positions: Token positions
|
||||||
|
intermediate_tensors: Tensors from previous pipeline stages
|
||||||
|
inputs_embeds: Input embeddings (alternative to input_ids)
|
||||||
|
**model_kwargs: Additional model arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model output tensor
|
||||||
|
"""
|
||||||
|
return self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -2337,7 +2369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
), record_function_or_nullcontext("Forward"),
|
), record_function_or_nullcontext("Forward"),
|
||||||
self.maybe_get_kv_connector_output(scheduler_output) as
|
self.maybe_get_kv_connector_output(scheduler_output) as
|
||||||
kv_connector_output):
|
kv_connector_output):
|
||||||
model_output = self.model(
|
model_output = self._model_forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user