mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 14:17:09 +08:00
refactor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
9c75d896a8
commit
d30c0d50a6
@ -3,6 +3,7 @@
|
||||
import torch
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
|
||||
|
||||
def evenly_split(
|
||||
@ -34,3 +35,24 @@ def pad_and_all_gather(
|
||||
|
||||
x = tensor_model_parallel_all_gather(padded_x)
|
||||
return x
|
||||
|
||||
|
||||
def all_gather_sampler_output(
|
||||
sampler_output: SamplerOutput,
|
||||
num_reqs: int,
|
||||
tp_size: int,
|
||||
) -> SamplerOutput:
|
||||
n = (num_reqs + tp_size - 1) // tp_size
|
||||
sampler_output.sampled_token_ids = pad_and_all_gather(
|
||||
sampler_output.sampled_token_ids, n)[:num_reqs]
|
||||
|
||||
# TODO(woosuk): 3 small all-gathers, could be merged into one.
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
if logprobs_tensors is not None:
|
||||
logprobs_tensors.logprob_token_ids = pad_and_all_gather(
|
||||
logprobs_tensors.logprob_token_ids, n)[:num_reqs]
|
||||
logprobs_tensors.logprobs = pad_and_all_gather(
|
||||
logprobs_tensors.logprobs, n)[:num_reqs]
|
||||
logprobs_tensors.selected_token_ranks = pad_and_all_gather(
|
||||
logprobs_tensors.selected_token_ranks, n)[:num_reqs]
|
||||
return sampler_output
|
||||
|
||||
@ -24,7 +24,8 @@ from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
|
||||
init_attn_backend, init_kv_cache)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dist_utils import evenly_split, pad_and_all_gather
|
||||
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
|
||||
evenly_split)
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
@ -398,19 +399,12 @@ class GPUModelRunner:
|
||||
)
|
||||
|
||||
if use_dp_sampler:
|
||||
# Gather the outputs.
|
||||
# TODO(woosuk): Optimize.
|
||||
sampler_output.sampled_token_ids = pad_and_all_gather(
|
||||
sampler_output.sampled_token_ids, n)[:num_reqs]
|
||||
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
if logprobs_tensors is not None:
|
||||
logprobs_tensors.logprob_token_ids = pad_and_all_gather(
|
||||
logprobs_tensors.logprob_token_ids, n)[:num_reqs]
|
||||
logprobs_tensors.logprobs = pad_and_all_gather(
|
||||
logprobs_tensors.logprobs, n)[:num_reqs]
|
||||
logprobs_tensors.selected_token_ranks = pad_and_all_gather(
|
||||
logprobs_tensors.selected_token_ranks, n)[:num_reqs]
|
||||
# All-gather the outputs.
|
||||
sampler_output = all_gather_sampler_output(
|
||||
sampler_output,
|
||||
num_reqs,
|
||||
tp_size,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def postprocess(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user