# Reviewer notes. This preamble sits ahead of the first "diff --git" header,
# where patch tools are expected to skip free text -- TODO confirm for the
# tool in use.
#
# NOTE(review): the line breaks of this patch appear to have been collapsed;
# it presumably has to be re-split into one diff line per physical line
# (on the leading ' ', '+', '-' markers) before it can be applied -- verify
# against the original commit.
#
# vllm/v1/worker/gpu/dist_utils.py
#   Adds evenly_split(n, tp_size, tp_rank) -> tuple[int, int], returning
#   (start, end). With q = n // tp_size and r = n % tp_size:
#     start = q * tp_rank + min(tp_rank, r)
#     end   = start + q + (1 if tp_rank < r else 0)
#   so ranks with tp_rank < r get a range of length q + 1 and the remaining
#   ranks get a range of length q.
#
# vllm/v1/worker/gpu/model_runner.py
#   * Imports evenly_split alongside pad_and_all_gather.
#   * Moves `num_reqs = logits.shape[0]` from before to after the
#     `logits = self.model.compute_logits(sample_hidden_states, None)`
#     assignment, so num_reqs is read from the freshly computed logits.
#   * Inside `if use_dp_sampler:` (use_dp_sampler = tp_size > 1 and n > 32,
#     where n = (num_reqs + tp_size - 1) // tp_size), replaces the inline
#     q / r / start / end arithmetic with
#     `start, end = evenly_split(num_reqs, tp_size, tp_rank)`.
#     logits, pos and idx_mapping_np are then sliced with [start:end].
#   * Replaces the two-line "Shard the inputs..." comment with a single
#     NOTE(woosuk) comment carrying the "no rank gets zero requests" reminder.
diff --git a/vllm/v1/worker/gpu/dist_utils.py b/vllm/v1/worker/gpu/dist_utils.py index 3496a095c8d38..108a52edec186 100644 --- a/vllm/v1/worker/gpu/dist_utils.py +++ b/vllm/v1/worker/gpu/dist_utils.py @@ -5,6 +5,18 @@ import torch from vllm.distributed import tensor_model_parallel_all_gather +def evenly_split( + n: int, + tp_size: int, + tp_rank: int, +) -> tuple[int, int]: + q = n // tp_size + r = n % tp_size + start = q * tp_rank + min(tp_rank, r) + end = start + q + (1 if tp_rank < r else 0) + return start, end + + def pad_and_all_gather( x: torch.Tensor, padded_size: int, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 0e24569067e67..153f4af34eb8d 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -23,7 +23,7 @@ from vllm.v1.sample.sampler import SamplerOutput from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, init_attn_backend, init_kv_cache) from vllm.v1.worker.gpu.block_table import BlockTables -from vllm.v1.worker.gpu.dist_utils import pad_and_all_gather +from vllm.v1.worker.gpu.dist_utils import evenly_split, pad_and_all_gather from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers, prepare_inputs) from vllm.v1.worker.gpu.sampler import Sampler @@ -366,11 +366,11 @@ class GPUModelRunner: hidden_states: torch.Tensor, input_batch: InputBatch, ) -> SamplerOutput: - num_reqs = logits.shape[0] sample_hidden_states = hidden_states[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) pos = input_batch.positions[input_batch.logits_indices] idx_mapping_np = input_batch.idx_mapping_np + num_reqs = logits.shape[0] # When the batch size is large enough, use DP sampler. tp_group = get_tp_group() @@ -378,14 +378,9 @@ class GPUModelRunner: n = (num_reqs + tp_size - 1) // tp_size use_dp_sampler = tp_size > 1 and n > 32 if use_dp_sampler: - # Shard the inputs as evenly as possible. - # Make sure that no rank gets zero requests. 
+ # NOTE(woosuk): Make sure that no rank gets zero requests. tp_rank = tp_group.rank - q = num_reqs // tp_size - r = num_reqs % tp_size - start = q * tp_rank + min(tp_rank, r) - end = start + q + (1 if tp_rank < r else 0) - + start, end = evenly_split(num_reqs, tp_size, tp_rank) logits = logits[start:end] pos = pos[start:end] idx_mapping_np = idx_mapping_np[start:end]