diff --git a/vllm/v1/worker/gpu/dist_utils.py b/vllm/v1/worker/gpu/dist_utils.py
index 108a52edec186..c4c56de47e1c6 100644
--- a/vllm/v1/worker/gpu/dist_utils.py
+++ b/vllm/v1/worker/gpu/dist_utils.py
@@ -3,6 +3,7 @@
 import torch
 
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.v1.outputs import SamplerOutput
 
 
 def evenly_split(
@@ -34,3 +35,24 @@ def pad_and_all_gather(
 
     x = tensor_model_parallel_all_gather(padded_x)
     return x
+
+
+def all_gather_sampler_output(
+    sampler_output: SamplerOutput,
+    num_reqs: int,
+    tp_size: int,
+) -> SamplerOutput:
+    n = (num_reqs + tp_size - 1) // tp_size
+    sampler_output.sampled_token_ids = pad_and_all_gather(
+        sampler_output.sampled_token_ids, n)[:num_reqs]
+
+    # TODO(woosuk): 3 small all-gathers, could be merged into one.
+    logprobs_tensors = sampler_output.logprobs_tensors
+    if logprobs_tensors is not None:
+        logprobs_tensors.logprob_token_ids = pad_and_all_gather(
+            logprobs_tensors.logprob_token_ids, n)[:num_reqs]
+        logprobs_tensors.logprobs = pad_and_all_gather(
+            logprobs_tensors.logprobs, n)[:num_reqs]
+        logprobs_tensors.selected_token_ranks = pad_and_all_gather(
+            logprobs_tensors.selected_token_ranks, n)[:num_reqs]
+    return sampler_output
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 178ccaef1d869..4a7ed7a6af40b 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -24,7 +24,8 @@ from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, init_attn_backend,
                                            init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
-from vllm.v1.worker.gpu.dist_utils import evenly_split, pad_and_all_gather
+from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
+                                           evenly_split)
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
@@ -398,19 +399,12 @@ class GPUModelRunner:
         )
 
         if use_dp_sampler:
-            # Gather the outputs.
-            # TODO(woosuk): Optimize.
-            sampler_output.sampled_token_ids = pad_and_all_gather(
-                sampler_output.sampled_token_ids, n)[:num_reqs]
-
-            logprobs_tensors = sampler_output.logprobs_tensors
-            if logprobs_tensors is not None:
-                logprobs_tensors.logprob_token_ids = pad_and_all_gather(
-                    logprobs_tensors.logprob_token_ids, n)[:num_reqs]
-                logprobs_tensors.logprobs = pad_and_all_gather(
-                    logprobs_tensors.logprobs, n)[:num_reqs]
-                logprobs_tensors.selected_token_ranks = pad_and_all_gather(
-                    logprobs_tensors.selected_token_ranks, n)[:num_reqs]
+            # All-gather the outputs.
+            sampler_output = all_gather_sampler_output(
+                sampler_output,
+                num_reqs,
+                tp_size,
+            )
         return sampler_output
 
     def postprocess(
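
For reference, a minimal single-process sketch of the pad / all-gather / truncate pattern that `all_gather_sampler_output` relies on. The gather is simulated with `torch.cat`; `ceil_div` and `fake_pad_and_all_gather` are illustrative stand-ins, not part of this diff or of vLLM's API, and the example assumes `evenly_split` fills the leading ranks with ceil-sized chunks so that all padding lands at the tail and can be dropped by slicing to `num_reqs`.

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def fake_pad_and_all_gather(per_rank_chunks: list, n: int) -> torch.Tensor:
    # Pad each rank's chunk to n rows with zeros, then concatenate along dim 0,
    # standing in for tensor_model_parallel_all_gather across tp_size ranks.
    padded = []
    for chunk in per_rank_chunks:
        pad_rows = n - chunk.shape[0]
        if pad_rows > 0:
            pad = chunk.new_zeros((pad_rows, *chunk.shape[1:]))
            chunk = torch.cat([chunk, pad], dim=0)
        padded.append(chunk)
    return torch.cat(padded, dim=0)


num_reqs, tp_size = 5, 2
n = ceil_div(num_reqs, tp_size)  # max requests handled by any single rank

# Rank 0 sampled token ids for 3 requests; rank 1 for the remaining 2.
rank_outputs = [torch.tensor([[10], [11], [12]]), torch.tensor([[13], [14]])]

# Pad to n rows per rank, gather, then cut back to the true request count.
gathered = fake_pad_and_all_gather(rank_outputs, n)[:num_reqs]
print(gathered.squeeze(1))  # tensor([10, 11, 12, 13, 14]); the padding row is sliced off
```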