diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index eb8e610ae4710..7f2994eeca008 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -78,7 +78,7 @@ class CudaGraphManager:
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
-        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        input_ids = input_buffers.input_ids[:num_tokens]
         positions = input_buffers.positions[:num_tokens]
         attn_metadata = prepare_inputs_to_capture(
             num_reqs,
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 43fd53d3acaae..3f8ef03f96445 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -3,7 +3,6 @@
 from dataclasses import dataclass
 from typing import Any

-import numba
 import numpy as np
 import torch

@@ -30,15 +29,12 @@ class InputBuffers:
         self.pin_memory = pin_memory

         self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
-        self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
+        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
         self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
         self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
         self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)

-        # Spec decoding.
-        self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
-
         # Structured outputs.
         self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
         self.grammar_bitmask = self._make_buffer(
@@ -120,7 +116,7 @@ class InputBatch:
     input_buffers.seq_lens[num_reqs:] = 0
     seq_lens = input_buffers.seq_lens[:num_reqs]

-    input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
+    input_ids = input_buffers.input_ids[:num_tokens]
     positions = input_buffers.positions[:num_tokens]
     # attn_metadata = defaultdict(lambda: None)
     logits_indices = query_start_loc[1:] - 1
@@ -146,41 +142,63 @@ class InputBatch:
     )


-@numba.njit(cache=True)
-def _prepare_prefill_inputs(
-    idx_mapping: np.ndarray,  # [B]
-    query_lens: np.ndarray,  # [B]
-    query_start_loc: np.ndarray,  # [B + 1]
-    prefill_token_ids: np.ndarray,  # [N, max_model_len]
-    num_computed_prefill_tokens: np.ndarray,  # [N]
-    input_ids: np.ndarray,  # [num_input_tokens]
-) -> None:
-    num_reqs = idx_mapping.shape[0]
-    query_starts = query_start_loc[:num_reqs]
-    query_ends = query_start_loc[1 : num_reqs + 1]
-    starts = num_computed_prefill_tokens[idx_mapping]
-    ends = starts + query_lens
-    for i in range(num_reqs):
-        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
-            idx_mapping[i], starts[i] : ends[i]
-        ]
+@triton.jit
+def _prepare_prefill_inputs_kernel(
+    input_ids_ptr,
+    next_prefill_tokens_ptr,
+    idx_mapping_ptr,
+    query_start_loc_ptr,
+    prefill_token_ids_ptr,
+    prefill_token_ids_stride,
+    prefill_lens_ptr,
+    num_computed_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
+    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+    if num_computed >= prefill_len:
+        # Not prefill.
+        return
+
+    query_start = tl.load(query_start_loc_ptr + batch_idx)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    query_len = query_end - query_start
+
+    prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
+    for i in range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
+        tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
+
+    next_pos = num_computed + query_len
+    if next_pos < prefill_len:
+        next_token = tl.load(prefill_ptr + next_pos)
+        tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)


 def prepare_prefill_inputs(
-    idx_mapping: np.ndarray,
-    num_scheduled_tokens: np.ndarray,
-    query_start_loc: np.ndarray,
-    prefill_token_ids: np.ndarray,
-    num_computed_prefill_tokens: np.ndarray,
-    input_ids: np.ndarray,
+    input_ids: torch.Tensor,
+    next_prefill_tokens: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    prefill_token_ids: torch.Tensor,
+    prefill_len: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
 ) -> None:
-    _prepare_prefill_inputs(
+    num_reqs = idx_mapping.shape[0]
+    _prepare_prefill_inputs_kernel[(num_reqs,)](
+        input_ids,
+        next_prefill_tokens,
         idx_mapping,
-        num_scheduled_tokens,
         query_start_loc,
         prefill_token_ids,
-        num_computed_prefill_tokens,
-        input_ids,
+        prefill_token_ids.stride(0),
+        prefill_len,
+        num_computed_tokens,
+        BLOCK_SIZE=1024,
     )


diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9ba234544421d..1b512b2ee3e50 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.use_async_scheduling:
             self.input_prep_event = torch.cuda.Event()
             self.structured_outputs_event = torch.cuda.Event()
-            self.spec_decode_event = torch.cuda.Event()
         else:
             self.input_prep_event = None
             self.structured_outputs_event = None
-            self.spec_decode_event = None

         if self.speculative_config is not None:
             self.do_spec_decode = True
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cu_num_new_blocks[i].append(x + len(block_ids))
                 new_block_ids[i].extend(block_ids)
                 overwrite.append(True)
-        # Update the GPU tensors for request states.
-        if scheduler_output.scheduled_new_reqs:
-            self.req_states.prefill_len.copy_to_gpu()

         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
         query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]

-        # Copy prefill tokens from CPU to GPU.
+        # Get prefill tokens.
         prepare_prefill_inputs(
-            idx_mapping_np,
-            num_scheduled_tokens,
-            query_start_loc_np,
-            self.req_states.prefill_token_ids.np,
-            self.req_states.num_computed_prefill_tokens,
-            self.input_buffers.input_ids.np,
+            self.input_buffers.input_ids,
+            self.req_states.next_prefill_tokens,
+            idx_mapping,
+            query_start_loc_gpu,
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.gpu,
+            self.req_states.num_computed_tokens,
         )
-        self.input_buffers.input_ids.copy_to_gpu(num_tokens)

         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Some input token ids are directly read from the last sampled tokens
         # and draft tokens. Also, get the logits indices to sample tokens from.
         logits_indices = combine_sampled_and_draft_tokens(
-            self.input_buffers.input_ids.gpu,
+            self.input_buffers.input_ids,
             idx_mapping,
             self.req_states.last_sampled_tokens,
             query_start_loc_gpu,
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config=self.kv_cache_config,
         )

-        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
+        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_sampled: torch.Tensor,
         num_rejected: torch.Tensor,
     ) -> torch.Tensor:
-        num_reqs = input_batch.num_reqs
-        idx_mapping_np = input_batch.idx_mapping_np
-        with async_barrier(self.spec_decode_event):
-            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-                self.req_states.prefill_token_ids.np[
-                    idx_mapping_np,
-                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
-                ]
-            )
-            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
-                num_reqs
-            )
-
         assert self.speculator is not None
+        last_sampled_tokens = self.req_states.last_sampled_tokens[
+            input_batch.idx_mapping
+        ]
+        next_prefill_tokens = self.req_states.next_prefill_tokens[
+            input_batch.idx_mapping
+        ]
         draft_tokens = self.speculator.propose(
             input_batch,
             sampling_metadata,
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             aux_hidden_states,
             num_sampled,
             num_rejected,
-            self.req_states.last_sampled_tokens,
+            last_sampled_tokens,
             next_prefill_tokens,
         )
         return draft_tokens
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
index daf2775e8b92d..580d67246dfa1 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -121,7 +121,7 @@ class EagleSpeculator:
             num_tokens_across_dp=num_tokens_across_dp,
         ):
             ret_hidden_states = self.model(
-                input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
+                input_ids=self.input_buffers.input_ids[:num_tokens],
                 positions=self.input_buffers.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
@@ -194,7 +194,7 @@ class EagleSpeculator:
         num_sampled: torch.Tensor,
         # [num_reqs]
         num_rejected: torch.Tensor,
-        # [max_num_reqs, 1]
+        # [num_reqs]
         last_sampled: torch.Tensor,
         # [num_reqs]
         next_prefill_tokens: torch.Tensor,
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
     eagle_positions_ptr,
     target_input_ids_ptr,
     target_positions_ptr,
-    idx_mapping_ptr,
     last_sampled_ptr,
     next_prefill_tokens_ptr,
     num_sampled_ptr,
@@ -335,8 +334,7 @@

     num_sampled = tl.load(num_sampled_ptr + batch_idx)
     if num_sampled > 0:
-        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-        next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
+        next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
     else:
         # Chunked prefilling.
         # Get the next prefill token.
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
     num_sampled: torch.Tensor,
     # [num_reqs]
     num_rejected: torch.Tensor,
-    # [max_num_reqs, 1]
+    # [num_reqs]
     last_sampled: torch.Tensor,
-    # [max_num_reqs]
+    # [num_reqs]
     next_prefill_tokens: torch.Tensor,
 ) -> torch.Tensor:
     num_reqs = input_batch.num_reqs
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
     )
     _prepare_eagle_inputs_kernel[(num_reqs,)](
         last_token_indices,
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         input_batch.input_ids,
         input_batch.positions,
-        input_batch.idx_mapping,
         last_sampled,
         next_prefill_tokens,
         num_sampled,
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
         last_token_indices,
         target_seq_lens,
         num_rejected,
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         input_hidden_states,
         input_hidden_states.stride(0),
@@ -553,7 +550,7 @@ def update_eagle_inputs(
 ):
     num_reqs, hidden_size = output_hidden_states.shape
     _update_eagle_inputs_kernel[(num_reqs,)](
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         hidden_states,
         hidden_states.stride(0),
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 64874b72e60cf..4ddd2dfdd731f 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -117,8 +117,7 @@ class RequestState:
         self.prefill_token_ids = UvaBuffer(
             self.max_num_reqs, self.max_model_len, dtype=torch.int32
         )
-        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-
+        self.prefill_len = UvaBuffer(self.max_num_reqs, dtype=torch.int32)
         # Number of computed tokens.
         self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.num_computed_tokens = torch.zeros(
@@ -140,6 +139,9 @@ class RequestState:
             dtype=torch.int64,
             device=device,
         )
+        self.next_prefill_tokens = torch.zeros(
+            self.max_num_reqs, dtype=torch.int32, device=device
+        )

         # LoRA.
         self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
     expanded_top_p_ptr,
     top_k_ptr,
     expanded_top_k_ptr,
-    seeds_ptr,
     rep_penalty_ptr,
     expanded_rep_penalty_ptr,
     freq_penalty_ptr,
     expanded_freq_penalty_ptr,
     pres_penalty_ptr,
     expanded_pres_penalty_ptr,
+    seeds_ptr,
     expanded_seeds_ptr,
     cu_num_logits_ptr,
     BLOCK_SIZE: tl.constexpr,
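
For reviewers who want to poke at the new Triton path in isolation, the sketch below (not part of the patch) drives prepare_prefill_inputs() from vllm/v1/worker/gpu/input_batch.py on hand-built tensors and compares the result against plain tensor slicing. It assumes a CUDA device with the patch applied; the request counts, lengths, and token values are arbitrary example numbers.

# Illustrative only: minimal check of the Triton-based prepare_prefill_inputs().
# Assumes a CUDA device and this patch applied; shapes below are arbitrary.
import torch

from vllm.v1.worker.gpu.input_batch import prepare_prefill_inputs


def main() -> None:
    device = torch.device("cuda")
    torch.manual_seed(0)
    max_num_reqs, max_model_len = 4, 32

    # Persistent per-request state, indexed by request-state index.
    prefill_token_ids = torch.randint(
        0, 1000, (max_num_reqs, max_model_len), dtype=torch.int32, device=device
    )
    prefill_len = torch.tensor([10, 7, 0, 20], dtype=torch.int32, device=device)
    num_computed_tokens = torch.tensor([2, 7, 0, 5], dtype=torch.int32, device=device)

    # A batch of two partial prefills: request-state 0 (4 tokens) and 3 (8 tokens).
    idx_mapping = torch.tensor([0, 3], dtype=torch.int32, device=device)
    query_start_loc = torch.tensor([0, 4, 12], dtype=torch.int32, device=device)
    num_tokens = 12

    input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
    next_prefill_tokens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)

    prepare_prefill_inputs(
        input_ids,
        next_prefill_tokens,
        idx_mapping,
        query_start_loc,
        prefill_token_ids,
        prefill_len,
        num_computed_tokens,
    )

    # Reference: gather the same prefill slices with plain tensor indexing.
    expected = torch.cat([prefill_token_ids[0, 2:6], prefill_token_ids[3, 5:13]])
    assert torch.equal(input_ids, expected)
    # Both requests still have prefill tokens left, so the kernel also records
    # the next prefill token for each request state.
    assert next_prefill_tokens[0] == prefill_token_ids[0, 6]
    assert next_prefill_tokens[3] == prefill_token_ids[3, 13]
    print("prepare_prefill_inputs matches the reference slices")


if __name__ == "__main__":
    main()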