diff --git a/vllm/v1/worker/gpu/dist_utils.py b/vllm/v1/worker/gpu/dist_utils.py
new file mode 100644
index 0000000000000..3496a095c8d38
--- /dev/null
+++ b/vllm/v1/worker/gpu/dist_utils.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.distributed import tensor_model_parallel_all_gather
+
+
+def pad_and_all_gather(
+    x: torch.Tensor,
+    padded_size: int,
+) -> torch.Tensor:
+    """Pad x to padded_size along dim 0, then all-gather across TP ranks."""
+    n = x.shape[0]
+    if n != padded_size:
+        padded_x = torch.empty(
+            (padded_size, *x.shape[1:]),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        padded_x[:n] = x
+    else:
+        padded_x = x
+
+    # Gather along dim 0 (the request dim); the default dim=-1 would
+    # concatenate 2-D tensors along the wrong axis.
+    x = tensor_model_parallel_all_gather(padded_x, dim=0)
+    return x
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index a4834ef5d975d..0e24569067e67 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 from vllm.config import VllmConfig
+from vllm.distributed import get_tp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
@@ -22,6 +23,7 @@
 from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, init_attn_backend,
                                            init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dist_utils import pad_and_all_gather
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
@@ -364,17 +366,51 @@ class GPUModelRunner:
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
     ) -> SamplerOutput:
-        # TODO(woosuk): Support DP sampler + CUDA graphs.
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
+        num_reqs = logits.shape[0]
         pos = input_batch.positions[input_batch.logits_indices]
+        idx_mapping_np = input_batch.idx_mapping_np
+
+        # When the batch size is large enough, use the DP sampler.
+        tp_group = get_tp_group()
+        tp_size = tp_group.world_size
+        n = (num_reqs + tp_size - 1) // tp_size
+        use_dp_sampler = tp_size > 1 and n > 32
+        if use_dp_sampler:
+            # Shard the inputs as evenly as possible.
+            # Make sure that no rank gets zero requests.
+            tp_rank = tp_group.rank
+            q = num_reqs // tp_size
+            r = num_reqs % tp_size
+            start = q * tp_rank + min(tp_rank, r)
+            end = start + q + (1 if tp_rank < r else 0)
+
+            logits = logits[start:end]
+            pos = pos[start:end]
+            idx_mapping_np = idx_mapping_np[start:end]
+
         sampling_metadata = self.req_states.make_sampling_metadata(
-            input_batch.idx_mapping_np, pos)
+            idx_mapping_np, pos)
         sampler_output = self.sampler(
             logits=logits,
             sampling_metadata=sampling_metadata,
         )
+
+        if use_dp_sampler:
+            # Gather the outputs.
+            # TODO(woosuk): Optimize.
+            sampler_output.sampled_token_ids = pad_and_all_gather(
+                sampler_output.sampled_token_ids, n)[:num_reqs]
+
+            logprobs_tensors = sampler_output.logprobs_tensors
+            if logprobs_tensors is not None:
+                logprobs_tensors.logprob_token_ids = pad_and_all_gather(
+                    logprobs_tensors.logprob_token_ids, n)[:num_reqs]
+                logprobs_tensors.logprobs = pad_and_all_gather(
+                    logprobs_tensors.logprobs, n)[:num_reqs]
+                logprobs_tensors.selected_token_ranks = pad_and_all_gather(
+                    logprobs_tensors.selected_token_ranks, n)[:num_reqs]
         return sampler_output
 
     def postprocess(
@@ -395,6 +431,9 @@ class GPUModelRunner:
             sampled_token_ids_np,
             num_sampled_tokens,
         )
+
+        # self.req_states.last_token_ids[input_batch.idx_mapping] = (
+        #     sampler_output.sampled_token_ids)
         return sampled_token_ids_np, num_sampled_tokens
 
     def execute_model(
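
Reviewer note: the shard arithmetic in `sample()` is the standard balanced split of `num_reqs` items over `tp_size` ranks. A minimal standalone sketch of just that arithmetic (the values `num_reqs=10`, `tp_size=4` are made up for illustration and not part of the diff):

```python
# Balanced sharding as in sample() above: ranks 0..r-1 get q+1 requests,
# the rest get q, so sizes differ by at most one and no rank is empty.
num_reqs, tp_size = 10, 4                        # toy values
q, r = num_reqs // tp_size, num_reqs % tp_size   # q=2, r=2
n = (num_reqs + tp_size - 1) // tp_size          # per-rank padded size, n=3
for tp_rank in range(tp_size):
    start = q * tp_rank + min(tp_rank, r)
    end = start + q + (1 if tp_rank < r else 0)
    print(tp_rank, (start, end))                 # (0,3) (3,6) (6,8) (8,10)
```

Since every shard has at most `n` rows, padding each shard to `n` gives all ranks a uniform shape, which the all-gather collective requires.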
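And a single-rank view of the padding step inside `pad_and_all_gather`, before the collective runs; the shard contents here are invented and this sketch stands in for the distributed call rather than reproducing it:

```python
import torch

# One rank's shard of a 2-D logprobs tensor: 2 valid rows, padded to n=3.
shard = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
padded_size = 3
padded = torch.empty((padded_size, *shard.shape[1:]), dtype=shard.dtype)
padded[: shard.shape[0]] = shard     # trailing row stays uninitialized
print(padded.shape)                  # torch.Size([3, 2])
# tensor_model_parallel_all_gather(padded, dim=0) then yields a
# (tp_size * padded_size, 2) tensor, which the caller in sample()
# slices back down to num_reqs rows.
```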