DP sampler

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

commit b405d78c07
parent 8af87986aa
vllm/v1/worker/gpu/dist_utils.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.distributed import tensor_model_parallel_all_gather
+
+
+def pad_and_all_gather(
+    x: torch.Tensor,
+    padded_size: int,
+) -> torch.Tensor:
+    n = x.shape[0]
+    if n != padded_size:
+        padded_x = torch.empty(
+            (padded_size, *x.shape[1:]),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        padded_x[:n] = x
+    else:
+        padded_x = x
+
+    x = tensor_model_parallel_all_gather(padded_x)
+    return x
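Every rank must pass an identically shaped tensor to the all-gather, which is why pad_and_all_gather pads the leading dimension to a fixed padded_size before gathering. Below is a minimal single-process sketch of that pattern; fake_all_gather is a hypothetical stand-in for tensor_model_parallel_all_gather, and the pad rows are zero-filled for readability where the helper above leaves them uninitialized via torch.empty.

import torch


def fake_all_gather(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical stand-in for tensor_model_parallel_all_gather: in a real
    # TP group, each rank contributes one same-shaped tensor and receives
    # the concatenation along dim 0.
    return torch.cat(chunks, dim=0)


padded_size = 4  # every "rank" must contribute exactly this many rows
per_rank = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6, 7])]  # rank 0 is short

padded = []
for x in per_rank:
    n = x.shape[0]
    if n != padded_size:
        p = torch.zeros(padded_size, dtype=x.dtype)  # zero pad rows, for clarity
        p[:n] = x
    else:
        p = x
    padded.append(p)

out = fake_all_gather(padded)
print(out)  # tensor([1, 2, 3, 0, 4, 5, 6, 7]) -- index 3 is rank 0's pad row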
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 
 from vllm.config import VllmConfig
+from vllm.distributed import get_tp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
@@ -22,6 +23,7 @@ from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dist_utils import pad_and_all_gather
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
@@ -364,17 +366,51 @@ class GPUModelRunner:
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
     ) -> SamplerOutput:
+        # TODO(woosuk): Support DP sampler + CUDA graphs.
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
+        num_reqs = logits.shape[0]
 
         pos = input_batch.positions[input_batch.logits_indices]
+        idx_mapping_np = input_batch.idx_mapping_np
+
+        # When the batch size is large enough, use the DP sampler.
+        tp_group = get_tp_group()
+        tp_size = tp_group.world_size
+        n = (num_reqs + tp_size - 1) // tp_size
+        use_dp_sampler = tp_size > 1 and n > 32
+        if use_dp_sampler:
+            # Shard the inputs as evenly as possible.
+            # Make sure that no rank gets zero requests.
+            tp_rank = tp_group.rank
+            q = num_reqs // tp_size
+            r = num_reqs % tp_size
+            start = q * tp_rank + min(tp_rank, r)
+            end = start + q + (1 if tp_rank < r else 0)
+
+            logits = logits[start:end]
+            pos = pos[start:end]
+            idx_mapping_np = idx_mapping_np[start:end]
+
         sampling_metadata = self.req_states.make_sampling_metadata(
-            input_batch.idx_mapping_np, pos)
+            idx_mapping_np, pos)
         sampler_output = self.sampler(
             logits=logits,
             sampling_metadata=sampling_metadata,
         )
+
+        if use_dp_sampler:
+            # Gather the outputs.
+            # TODO(woosuk): Optimize.
+            sampler_output.sampled_token_ids = pad_and_all_gather(
+                sampler_output.sampled_token_ids, n)[:num_reqs]
+
+            logprobs_tensors = sampler_output.logprobs_tensors
+            if logprobs_tensors is not None:
+                logprobs_tensors.logprob_token_ids = pad_and_all_gather(
+                    logprobs_tensors.logprob_token_ids, n)[:num_reqs]
+                logprobs_tensors.logprobs = pad_and_all_gather(
+                    logprobs_tensors.logprobs, n)[:num_reqs]
+                logprobs_tensors.selected_token_ranks = pad_and_all_gather(
+                    logprobs_tensors.selected_token_ranks, n)[:num_reqs]
         return sampler_output
 
     def postprocess(
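The sharding arithmetic deserves a worked check. The following self-contained sketch (the helper name shard_bounds is hypothetical, not part of the diff) shows that with q, r = divmod(num_reqs, tp_size), ranks 0..r-1 take q + 1 requests and the rest take q, so the slices are contiguous, non-empty under the n > 32 guard, and exactly tile [0, num_reqs).

def shard_bounds(num_reqs: int, tp_size: int, tp_rank: int) -> tuple[int, int]:
    # Mirrors the arithmetic in the diff: ranks 0..r-1 get q + 1 requests,
    # the remaining ranks get q, tiling [0, num_reqs) contiguously.
    q, r = divmod(num_reqs, tp_size)
    start = q * tp_rank + min(tp_rank, r)
    end = start + q + (1 if tp_rank < r else 0)
    return start, end


num_reqs, tp_size = 70, 4  # q = 17, r = 2
bounds = [shard_bounds(num_reqs, tp_size, rank) for rank in range(tp_size)]
print(bounds)  # [(0, 18), (18, 36), (36, 53), (53, 70)]
assert bounds[0][0] == 0 and bounds[-1][1] == num_reqs
assert all(prev[1] == nxt[0] for prev, nxt in zip(bounds, bounds[1:]))

On the return path, each rank's outputs are padded to n = ceil(num_reqs / tp_size) rows, all-gathered, and truncated with [:num_reqs]. Note the truncation drops exactly the pad rows only when those rows land at the tail of the gathered tensor, e.g. when tp_size divides num_reqs; the diff's own "TODO(woosuk): Optimize" marks this gather path for follow-up.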
@@ -395,6 +431,9 @@
             sampled_token_ids_np,
             num_sampled_tokens,
         )
+
+        # self.req_states.last_token_ids[input_batch.idx_mapping] = (
+        #     sampler_output.sampled_token_ids)
         return sampled_token_ids_np, num_sampled_tokens
 
     def execute_model(