commit e107680d8a
parent f1981db101
Author: Woosuk Kwon <woosuk@thinkingmachines.ai>
Date: 2025-09-15 21:19:18 +00:00
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

2 changed files with 73 additions and 14 deletions


@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.sampler import SamplerOutput
@@ -178,7 +179,6 @@ class GPUModelRunner:
positions = self.input_buffers.positions
query_start_loc = self.input_buffers.query_start_loc
seq_lens = self.input_buffers.seq_lens
prepare_inputs(
idx_mapping_np,
self.req_states.token_ids,
@@ -196,30 +196,59 @@ class GPUModelRunner:
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
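# query_start_loc holds the cumulative sum of per-request query lengths:
# the tokens of request i occupy [query_start_loc[i], query_start_loc[i + 1]).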
max_query_len = int(num_scheduled_tokens.max())
seq_lens.copy_to_gpu()
seq_lens_cpu = seq_lens.cpu[:num_reqs]
seq_lens_np = seq_lens.np[:num_reqs]
max_seq_len = int(seq_lens_np.max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
num_computed_tokens_np = self.req_states.num_computed_tokens[
idx_mapping_np]
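# torch.from_numpy shares memory with the numpy array; no copy is made.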
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
num_tokens_per_req = self.req_states.num_tokens[idx_mapping_np]
is_chunked_prefilling = seq_lens_np < num_tokens_per_req
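# A request is still chunked-prefilling if this step's seq_len has not yet
# reached its total token count; such a request samples no token this step.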
# Slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, positions.gpu[:num_tokens])
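# The last scheduled token of each request provides the logits for sampling.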
logits_indices = query_start_loc_gpu[1:] - 1
# Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {}
for i, kv_cache_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
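# Cascade attention (shared common prefix) is not used here, so the
# common prefix is left empty (see common_prefix_len=0 below).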
num_common_prefix_blocks = 0
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
logits_indices_padded=None,
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=None,
)
attn_metadata_builder = self.attn_metadata_builders[i]
group_attn_metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
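# All layers in this KV cache group share the same metadata object.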
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = group_attn_metadata
return InputBatch(
req_ids=req_ids,
@@ -237,8 +266,8 @@ class GPUModelRunner:
def sample(
self,
logits: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np)
@@ -246,13 +275,31 @@ class GPUModelRunner:
logits=logits,
sampling_metadata=sampling_metadata,
)
return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
):
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[
idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# Common case.
# No request in the batch needs prompt logprobs.
return None
num_prompt_tokens_scheduled = ...
if not np.any((num_prompt_tokens_scheduled > 0) & needs_prompt_logprobs):
# No prompt tokens are scheduled this step for the requests that need
# prompt logprobs; theirs were already computed in earlier steps.
return None
return
def postprocess(
self,
sampler_output: SamplerOutput,
input_batch: InputBatch,
) -> np.ndarray:
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
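# A possible sketch, assuming the is_chunked_prefilling mask computed in
# prepare_inputs is carried on the InputBatch:
#   num_sampled = np.where(input_batch.is_chunked_prefilling, 0, 1)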
@@ -282,9 +329,13 @@ class GPUModelRunner:
positions=input_batch.positions,
)
# Compute logits to sample next tokens.
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
sampler_output = self.sample(logits, input_batch)
prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
input_batch)
output = self.postprocess(sampler_output, input_batch)
return output


@@ -51,6 +51,7 @@ class RequestState:
)
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
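# Number of prompt tokens per request slot; lets us tell prompt
# positions apart from generated ones (e.g., for prompt logprobs).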
self.num_prompt_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled token ids.
self.last_token = torch.zeros(
@@ -67,6 +68,8 @@ class RequestState:
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
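# Whether each request's sampling params ask for prompt logprobs.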
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@@ -85,6 +88,7 @@ class RequestState:
prompt_len = len(prompt_token_ids)
self.num_tokens[req_idx] = prompt_len
self.num_prompt_tokens[req_idx] = prompt_len
self.token_ids[req_idx, :prompt_len] = prompt_token_ids
self.num_computed_tokens[req_idx] = num_computed_tokens
@@ -102,6 +106,10 @@ class RequestState:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, prompt logprobs are supported only for the prompt tokens,
# not for generated tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def append_token_ids(
self,
req_idx: int,