[V1][Perf] Reduce scheduling overhead in model runner after cuda sync (#12094)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Author: Keyun Tong, 2025-01-26 00:42:37 -08:00 (committed by GitHub)
parent 2a0309a646
commit fa63e710c7
3 changed files with 21 additions and 13 deletions

View File

@@ -8,7 +8,7 @@ import torch
 class SamplerOutput:
     # [num_reqs]
-    sampled_token_ids: List[int]
+    sampled_token_ids: torch.Tensor
     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
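
The field change above means SamplerOutput now hands back the sampled ids as a GPU tensor rather than a Python list. As a rough stand-alone illustration (names below are hypothetical, not from this commit), keeping the tensor lets the consumer decide when to pay for the device-to-host copy and the CUDA sync it implies:

    import torch

    def sample_ids(logits: torch.Tensor) -> torch.Tensor:
        # Result stays on the same device as the logits; no host sync yet.
        return logits.argmax(dim=-1).to(torch.int32)

    def finish_step(sampled: torch.Tensor) -> list[int]:
        # The device-to-host copy (and any implicit cuda sync) happens only
        # here, after independent CPU-side work has already been queued.
        return sampled.tolist()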

View File

@@ -50,9 +50,8 @@ class Sampler(nn.Module):
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
-        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled.tolist(),
+            sampled_token_ids=sampled,
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
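
With the .tolist() call gone, Sampler.forward no longer stalls on the GPU; the synchronization now happens wherever the caller converts the tensor. A small self-contained sketch (not part of this commit) of why deferring that stall matters:

    import time
    import torch

    if torch.cuda.is_available():
        x = torch.randn(8192, 8192, device="cuda")
        # Kernels are only queued here; the CPU keeps running.
        y = (x @ x).argmax(dim=-1).to(torch.int32)
        t0 = time.perf_counter()
        ids = y.tolist()  # blocks until the queued kernels finish (cuda sync)
        print(f"tolist() waited {time.perf_counter() - t0:.4f}s for the GPU")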

View File

@@ -775,10 +775,10 @@ class GPUModelRunner:
             sampling_metadata=sampling_metadata,
         )
-        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
+        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -787,10 +787,10 @@
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 self.input_batch.num_tokens[i] += 1
-                req_state.output_token_ids.append(token_id)
+                # OPTIMIZATION: Priming the state updates for later updates.
+                req_state.output_token_ids.append(0)
+                request_seq_lens.append((i, req_state, seq_len))
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        # Update with the actual token ids
+        for i, req_state, seq_len in request_seq_lens:
+            token_id = sampled_token_ids[i]
+            self.input_batch.token_ids_cpu[i, seq_len] = token_id
+            req_state.output_token_ids[-1] = token_id
+
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
         else:
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
@@ -808,12 +823,6 @@
         else:
             logprobs = sampler_output.logprobs.cpu()
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
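
Taken together, the runner now primes each request's output with a placeholder token, pays the GPU -> CPU sync exactly once via .tolist(), and then patches the real token ids in. A condensed, self-contained sketch of that pattern (simplified, with hypothetical names rather than the runner's actual API):

    import torch

    def apply_sampled_tokens(output_token_ids_per_req: list[list[int]],
                             sampled: torch.Tensor) -> None:
        # Phase 1: CPU-only bookkeeping, using a placeholder token id.
        primed: list[int] = []
        for i, out_ids in enumerate(output_token_ids_per_req):
            out_ids.append(0)  # placeholder, overwritten after the sync
            primed.append(i)
        # Phase 2: the single GPU -> CPU sync point.
        token_ids = sampled.tolist()
        # Phase 3: patch in the actual sampled token ids.
        for i in primed:
            output_token_ids_per_req[i][-1] = token_ids[i]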