Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Woosuk Kwon 2025-09-19 07:17:53 +00:00
parent 9c75d896a8
commit d30c0d50a6
2 changed files with 30 additions and 14 deletions

vllm/v1/worker/gpu/dist_utils.py

@@ -3,6 +3,7 @@
 import torch
 
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.v1.outputs import SamplerOutput
 
 
 def evenly_split(
@@ -34,3 +35,24 @@ def pad_and_all_gather(
     x = tensor_model_parallel_all_gather(padded_x)
     return x
+
+
+def all_gather_sampler_output(
+    sampler_output: SamplerOutput,
+    num_reqs: int,
+    tp_size: int,
+) -> SamplerOutput:
+    n = (num_reqs + tp_size - 1) // tp_size
+    sampler_output.sampled_token_ids = pad_and_all_gather(
+        sampler_output.sampled_token_ids, n)[:num_reqs]
+    # TODO(woosuk): 3 small all-gathers, could be merged into one.
+    logprobs_tensors = sampler_output.logprobs_tensors
+    if logprobs_tensors is not None:
+        logprobs_tensors.logprob_token_ids = pad_and_all_gather(
+            logprobs_tensors.logprob_token_ids, n)[:num_reqs]
+        logprobs_tensors.logprobs = pad_and_all_gather(
+            logprobs_tensors.logprobs, n)[:num_reqs]
+        logprobs_tensors.selected_token_ranks = pad_and_all_gather(
+            logprobs_tensors.selected_token_ranks, n)[:num_reqs]
+    return sampler_output
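
For intuition: each rank pads its shard of the batch to n = ceil(num_reqs / tp_size) rows before the collective, so the `[:num_reqs]` truncation is only correct if evenly_split hands each rank min(n, remaining) contiguous requests, leaving only the trailing rank(s) short. Below is a minimal single-process sketch of that round trip; fake_all_gather and the concrete sizes are illustrative stand-ins, not vLLM APIs.

import torch

def fake_all_gather(shards: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_gather: concatenate the
    # equal-sized contributions from all TP ranks along dim 0.
    return torch.cat(shards, dim=0)

num_reqs, tp_size = 7, 4
n = (num_reqs + tp_size - 1) // tp_size   # ceil(7 / 4) = 2 rows per rank

# The full batch of sampled token ids, owned shard-by-shard by the ranks.
full = torch.arange(num_reqs).unsqueeze(1)   # shape [num_reqs, 1]
shards = list(full.split(n, dim=0))          # row counts [2, 2, 2, 1]

# Each rank pads its shard with zeros to exactly n rows before the gather.
padded = [torch.cat([s, s.new_zeros(n - s.shape[0], 1)]) for s in shards]
gathered = fake_all_gather(padded)           # shape [tp_size * n, 1]

# Only the last rank was short, so truncation removes padding alone.
assert torch.equal(gathered[:num_reqs], full)

The same truncation would silently reorder results if shards were interleaved rather than contiguous, which is why the split and the gather must agree on request ordering.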

vllm/v1/worker/gpu/model_runner.py

@@ -24,7 +24,8 @@ from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
-from vllm.v1.worker.gpu.dist_utils import evenly_split, pad_and_all_gather
+from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
+                                           evenly_split)
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
@@ -398,19 +399,12 @@ class GPUModelRunner:
         )
         if use_dp_sampler:
-            # Gather the outputs.
-            # TODO(woosuk): Optimize.
-            sampler_output.sampled_token_ids = pad_and_all_gather(
-                sampler_output.sampled_token_ids, n)[:num_reqs]
-            logprobs_tensors = sampler_output.logprobs_tensors
-            if logprobs_tensors is not None:
-                logprobs_tensors.logprob_token_ids = pad_and_all_gather(
-                    logprobs_tensors.logprob_token_ids, n)[:num_reqs]
-                logprobs_tensors.logprobs = pad_and_all_gather(
-                    logprobs_tensors.logprobs, n)[:num_reqs]
-                logprobs_tensors.selected_token_ranks = pad_and_all_gather(
-                    logprobs_tensors.selected_token_ranks, n)[:num_reqs]
+            # All-gather the outputs.
+            sampler_output = all_gather_sampler_output(
+                sampler_output,
+                num_reqs,
+                tp_size,
+            )
         return sampler_output
 
     def postprocess(
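
The TODO kept in all_gather_sampler_output flags the three small collectives for the logprobs tensors. One conceivable merge, purely a sketch and not necessarily the fix the TODO intends, is to reinterpret the three tensors as raw bytes, concatenate them per row, run one pad-and-gather over the packed buffer, and split it back afterwards. The round trip below checks only the lossless pack/unpack step; the shapes and dtypes are assumptions.

import torch

num_reqs, k = 5, 4   # k: logprobs kept per request (illustrative)
logprob_token_ids = torch.randint(0, 32000, (num_reqs, k), dtype=torch.int64)
logprobs = torch.randn(num_reqs, k, dtype=torch.float32)
selected_token_ranks = torch.randint(0, 32000, (num_reqs,), dtype=torch.int64)

# Pack: view each tensor as bytes and concatenate per row, so a single
# pad_and_all_gather over `packed` could replace three collectives.
packed = torch.cat([
    logprob_token_ids.view(torch.uint8),                  # [num_reqs, 8 * k]
    logprobs.view(torch.uint8),                           # [num_reqs, 4 * k]
    selected_token_ranks.unsqueeze(1).view(torch.uint8),  # [num_reqs, 8]
], dim=1)

# Unpack: split the byte columns and reinterpret the original dtypes.
ids_b, lp_b, rank_b = packed.split([8 * k, 4 * k, 8], dim=1)
assert torch.equal(ids_b.contiguous().view(torch.int64), logprob_token_ids)
assert torch.equal(lp_b.contiguous().view(torch.float32), logprobs)
assert torch.equal(
    rank_b.contiguous().view(torch.int64).squeeze(1), selected_token_ranks)

Whether one large gather beats three small ones then comes down to per-collective launch latency versus the extra pack/unpack copies.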