[V1][Perf] Reduce scheduling overhead in model runner after cuda sync (#12094)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
This commit is contained in: parent 2a0309a646, commit fa63e710c7
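The change is easiest to read as a reordering of work around the single GPU -> CPU synchronization point. Below is a minimal, hedged sketch of that pattern, not vLLM code: the helper names (sample_on_device, do_cpu_bookkeeping) are illustrative stand-ins. The point is that .tolist() on a CUDA tensor blocks until the GPU finishes, so CPU-side bookkeeping that does not need the token values should run before that call.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def sample_on_device(num_reqs: int) -> torch.Tensor:
    # Stand-in for the sampler: token ids that still live on the device.
    return torch.randint(0, 32000, (num_reqs,), dtype=torch.int32, device=device)

def do_cpu_bookkeeping(num_reqs: int) -> list:
    # Stand-in for per-request CPU work that does not need the token values.
    return list(range(num_reqs))

num_reqs = 8
sampled = sample_on_device(num_reqs)

# Before this commit: calling sampled.tolist() at this point blocks until the
# GPU is done, and only then does the CPU bookkeeping start.
# After this commit: the CPU-side work runs first, overlapping with the GPU...
slots = do_cpu_bookkeeping(num_reqs)
# ...and the sync happens once, as late as possible.
token_ids = sampled.tolist()  # GPU -> CPU sync happens here
for slot, token_id in zip(slots, token_ids):
    pass  # patch the real token ids into the per-request state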
@@ -8,7 +8,7 @@ import torch
 class SamplerOutput:
 
     # [num_reqs]
-    sampled_token_ids: List[int]
+    sampled_token_ids: torch.Tensor
 
     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
@@ -50,9 +50,8 @@ class Sampler(nn.Module):
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
 
-        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled.tolist(),
+            sampled_token_ids=sampled,
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
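The two hunks above change SamplerOutput to carry the sampled ids as a torch.Tensor instead of a Python list, so constructing it no longer triggers a device-to-host copy. A minimal sketch of the idea; the class name SamplerOutputSketch and the values are illustrative, not the actual vLLM definitions:

from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class SamplerOutputSketch:
    # Simplified stand-in for vLLM's SamplerOutput.
    # [num_reqs]; stays on the sampling device, no sync at construction time
    sampled_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor] = None

sampled = torch.randint(0, 32000, (4,), dtype=torch.int32)
out = SamplerOutputSketch(sampled_token_ids=sampled)   # no GPU -> CPU sync here
token_ids: List[int] = out.sampled_token_ids.tolist()  # the caller syncs when ready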
@@ -775,10 +775,10 @@ class GPUModelRunner:
             sampling_metadata=sampling_metadata,
         )
 
-        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
+        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -787,10 +787,10 @@ class GPUModelRunner:
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 self.input_batch.num_tokens[i] += 1
-                req_state.output_token_ids.append(token_id)
+                # OPTIMIZATION: Priming the state updates for later updates.
+                req_state.output_token_ids.append(0)
+                request_seq_lens.append((i, req_state, seq_len))
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@ class GPUModelRunner:
                 # This relies on cuda-specific torch-internal impl details
                 generator.set_offset(generator.get_offset() - 4)
 
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        # Update with the actual token ids
+        for i, req_state, seq_len in request_seq_lens:
+            token_id = sampled_token_ids[i]
+            self.input_batch.token_ids_cpu[i, seq_len] = token_id
+            req_state.output_token_ids[-1] = token_id
+
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
         else:
@@ -808,12 +823,6 @@ class GPUModelRunner:
         else:
             logprobs = sampler_output.logprobs.cpu()
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
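The model runner side follows a prime-then-patch scheme: each finished request first gets a placeholder token (0) appended and its slot recorded, and only after the single .tolist() sync are the placeholders overwritten with the real ids. A hedged sketch of just that bookkeeping, with RequestState as a simplified stand-in for the vLLM request state:

from dataclasses import dataclass, field
from typing import List, Tuple

import torch

@dataclass
class RequestState:
    # Simplified stand-in for vLLM's CachedRequestState.
    output_token_ids: List[int] = field(default_factory=list)

num_reqs = 4
requests = [RequestState() for _ in range(num_reqs)]
request_seq_lens: List[Tuple[int, RequestState]] = []

# Phase 1 (before the sync): per-request CPU updates with a placeholder value.
for i, req_state in enumerate(requests):
    req_state.output_token_ids.append(0)  # primed; real value not known yet
    request_seq_lens.append((i, req_state))

# Phase 2: a single GPU -> CPU sync fetches all sampled ids at once...
sampled_token_ids = torch.randint(0, 32000, (num_reqs,), dtype=torch.int32).tolist()

# ...then the placeholders are patched with the real token ids.
for i, req_state in request_seq_lens:
    req_state.output_token_ids[-1] = sampled_token_ids[i]

The partial-request branch in the diff is unchanged in spirit: a token sampled for an unfinished request is ignored and the generator offset is rewound, as if it had never been drawn.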