Merge branch 'main' into woosuk/input-prep

Woosuk Kwon 2025-09-14 00:44:56 +00:00
commit 9314a83b56

@@ -87,7 +87,7 @@ from vllm.v1.worker.gpu_worker_states import RequestState
 from vllm.v1.worker.gpu_block_table import BlockTables
 from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -197,6 +197,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
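Reviewer note: the hunk above simply hoists a per-step invariant into __init__ (the executor backend and PP world size never change during a run). A minimal standalone sketch of the pattern follows; ParallelConfig, PPGroup, and Runner here are illustrative stand-ins, not vLLM's classes.

from dataclasses import dataclass, field

# Illustrative stand-ins for vLLM's config and PP-group state; not the real API.
@dataclass
class ParallelConfig:
    distributed_executor_backend: str = "mp"

@dataclass
class PPGroup:
    ranks: list = field(default_factory=list)

class Runner:
    def __init__(self, parallel_config, pp_group):
        # Computed once at init: the backend and PP world size are fixed for
        # the lifetime of the runner, so there is no need to re-evaluate this
        # condition on every step.
        self.broadcast_pp_output = (
            parallel_config.distributed_executor_backend == "external_launcher"
            and len(pp_group.ranks) > 0)

    def execute_step(self):
        # The hot path only reads the cached flag.
        return "broadcast path" if self.broadcast_pp_output else "common path"

runner = Runner(ParallelConfig("external_launcher"), PPGroup(ranks=[0, 1]))
assert runner.execute_step() == "broadcast path"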
@@ -1363,7 +1371,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self,
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         hidden_states = hidden_states[:num_scheduled_tokens]
         pooling_metadata = self.req_states.get_pooling_metadata()
@@ -1389,7 +1396,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            kv_connector_output=kv_connector_output,
         )
 
     def _preprocess(
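Reviewer note: the two _pool hunks above, together with the import hunk at the top, are one change: _pool no longer receives the KV-connector result as a parameter; callers attach it to the returned output instead, so the KVConnectorOutput type import drops out of this file. A minimal sketch of the pattern, with hypothetical stand-in classes rather than vLLM's real output types:

from dataclasses import dataclass, field
from typing import Optional

# Hypothetical stand-ins, just to show the shape of the refactor.
@dataclass
class KVConnectorOutput:
    finished_sending: bool = False

@dataclass
class ModelRunnerOutput:
    pooler_output: list = field(default_factory=list)
    kv_connector_output: Optional[KVConnectorOutput] = None

def pool(hidden_states: list) -> ModelRunnerOutput:
    # Pooling no longer needs to know about the KV connector at all.
    return ModelRunnerOutput(pooler_output=hidden_states)

def execute_model(hidden_states: list,
                  kv_connector_output: Optional[KVConnectorOutput]):
    # The caller attaches the connector result to whatever output it built,
    # which is why the parameter (and the type import) could be dropped.
    output = pool(hidden_states)
    output.kv_connector_output = kv_connector_output
    return output

out = execute_model([0.1, 0.2], KVConnectorOutput(finished_sending=True))
assert out.kv_connector_output.finished_sending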
@@ -1736,39 +1742,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
                 # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
                 # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None
 
-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping micro-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
                 if self.is_pooling_model:
-                    return self._pool(hidden_states, num_scheduled_tokens,
-                                      num_scheduled_tokens_np,
-                                      kv_connector_output)
+                    # Return the pooling output.
+                    output = self._pool(hidden_states, num_scheduled_tokens,
+                                        num_scheduled_tokens_np)
+                    output.kv_connector_output = kv_connector_output
+                    return output
+
                 sample_hidden_states = hidden_states[logits_indices]
                 logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
+                # Rare case.
+                assert not self.is_pooling_model
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
                 model_output_broadcast_data = get_pp_group(
                 ).broadcast_tensor_dict(model_output_broadcast_data,
                                         src=len(get_pp_group().ranks) - 1)
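Reviewer note: the restructured hunk separates a common path, where only the last PP rank produces logits and mid-pipeline ranks return IntermediateTensors to the next stage, from the rare external_launcher path, where every rank joins a broadcast rooted at the last rank so all torchrun processes observe the same output. A single-process simulation of that control flow follows; the bus dict and compute_logits are stand-ins for the real collective (broadcast_tensor_dict) and model head, not vLLM APIs.

from typing import Optional

def compute_logits(hidden: list) -> list:
    return [2.0 * h for h in hidden]  # placeholder for the model's head

def postprocess(rank: int, world_size: int, broadcast_pp_output: bool,
                hidden: list, bus: dict) -> Optional[list]:
    """One rank's view of the postprocess step; `bus` stands in for a
    broadcast collective rooted at the last rank."""
    is_last_rank = rank == world_size - 1
    if not broadcast_pp_output:
        # Common case: mid-pipeline ranks return early (their tensors go to
        # the next stage); only the last rank computes logits.
        return compute_logits(hidden) if is_last_rank else None
    # Rare case: the last rank computes logits and publishes them; every
    # rank then reads the same result, keeping all processes in sync.
    if is_last_rank:
        bus["logits"] = compute_logits(hidden)
    return bus["logits"]

bus: dict = {}
# Run the last rank first, mimicking the synchronization the broadcast
# provides in the real collective.
results = {r: postprocess(r, 2, True, [0.5], bus) for r in (1, 0)}
assert results[0] == results[1] == [1.0]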