Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
Woosuk Kwon 2025-09-19 06:50:56 +00:00
parent b405d78c07
commit 0d3de9e082
2 changed files with 16 additions and 9 deletions

View File

@ -5,6 +5,18 @@ import torch
from vllm.distributed import tensor_model_parallel_all_gather
def evenly_split(
n: int,
tp_size: int,
tp_rank: int,
) -> tuple[int, int]:
q = n // tp_size
r = n % tp_size
start = q * tp_rank + min(tp_rank, r)
end = start + q + (1 if tp_rank < r else 0)
return start, end
def pad_and_all_gather(
x: torch.Tensor,
padded_size: int,

View File

@ -23,7 +23,7 @@ from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dist_utils import pad_and_all_gather
from vllm.v1.worker.gpu.dist_utils import evenly_split, pad_and_all_gather
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
@ -366,11 +366,11 @@ class GPUModelRunner:
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
num_reqs = logits.shape[0]
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
num_reqs = logits.shape[0]
# When the batch size is large enough, use DP sampler.
tp_group = get_tp_group()
@ -378,14 +378,9 @@ class GPUModelRunner:
n = (num_reqs + tp_size - 1) // tp_size
use_dp_sampler = tp_size > 1 and n > 32
if use_dp_sampler:
# Shard the inputs as evenly as possible.
# Make sure that no rank gets zero requests.
# NOTE(woosuk): Make sure that no rank gets zero requests.
tp_rank = tp_group.rank
q = num_reqs // tp_size
r = num_reqs % tp_size
start = q * tp_rank + min(tp_rank, r)
end = start + q + (1 if tp_rank < r else 0)
start, end = evenly_split(num_reqs, tp_size, tp_rank)
logits = logits[start:end]
pos = pos[start:end]
idx_mapping_np = idx_mapping_np[start:end]