diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 7675cb45170b..1177d25e300c 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -37,6 +37,9 @@ class InputBuffers:
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
         self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
 
+        # Spec decoding.
+        self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
+
         # Structured outputs.
         self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
         self.grammar_bitmask = self._make_buffer(
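
The new next_prefill_tokens field follows the same host-staging buffer pattern as the neighboring fields: propose_draft (in model_runner.py below) fills it through a NumPy view (.np) and then pushes a prefix to the device with copy_to_gpu(num_reqs). A minimal sketch of that pattern, assuming a pinned CPU tensor paired with a device tensor; the class below is illustrative only, not vLLM's actual _make_buffer helper.

import torch

class PairedBuffer:
    """Illustrative stand-in for the buffer returned by _make_buffer()."""

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        # Host side: pinned memory, exposed as a NumPy array for cheap writes.
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()
        # Device side: destination of the (async) prefix copy.
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # A pinned source lets this copy run asynchronously on the stream.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
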
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 6e332ee4b75b..205298a415d4 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -45,6 +45,7 @@ from vllm.v1.worker.gpu.input_batch import (
     prepare_prefill_inputs,
 )
 from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
+from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
@@ -97,16 +98,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.use_async_scheduling:
             self.input_prep_event = torch.cuda.Event()
             self.structured_outputs_event = torch.cuda.Event()
+            self.spec_decode_event = torch.cuda.Event()
         else:
             self.input_prep_event = None
             self.structured_outputs_event = None
+            self.spec_decode_event = None
 
         if self.speculative_config is not None:
             self.do_spec_decode = True
             self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+            self.speculator = init_speculator(self.vllm_config, self.device)
         else:
             self.do_spec_decode = False
             self.num_speculative_steps = 0
+            self.speculator = None
 
         self.req_states = RequestState(
             max_num_reqs=self.max_num_reqs,
@@ -153,6 +158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.vllm_config,
                 self.device,
             )
+            if self.do_spec_decode:
+                self.speculator.load_model(self.model)
         time_after_load = time.perf_counter()
         self.model_memory_usage = m.consumed_memory
 
@@ -285,6 +292,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logits = self.model.compute_logits(hidden_states)
         self.sampler(logits, sampling_metadata)
 
+    @torch.inference_mode()
+    def _dummy_speculator_run(
+        self,
+        hidden_states: torch.Tensor,
+        aux_hidden_states: list[torch.Tensor] | None,
+    ) -> None:
+        num_tokens = hidden_states.shape[0]
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_batch = InputBatch.make_dummy(
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            input_buffers=self.input_buffers,
+            device=self.device,
+        )
+        sampling_metadata = SamplingMetadata.make_dummy(
+            num_reqs=num_reqs,
+            device=self.device,
+        )
+        num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
+        self.propose_draft(
+            input_batch=input_batch,
+            sampling_metadata=sampling_metadata,
+            last_hidden_states=hidden_states,
+            aux_hidden_states=aux_hidden_states,
+            num_sampled=num_sampled,
+        )
+
     @torch.inference_mode()
     def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
@@ -292,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             skip_attn=True,
         )
         self._dummy_sampler_run(sample_hidden_states)
+        if self.do_spec_decode:
+            self._dummy_speculator_run(hidden_states, None)
         torch.cuda.synchronize()
         del hidden_states, sample_hidden_states
         gc.collect()
@@ -727,6 +763,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.prefill_len.np[idx_mapping_np],
         )
 
+    @torch.inference_mode()
+    def propose_draft(
+        self,
+        input_batch: InputBatch,
+        sampling_metadata: SamplingMetadata,
+        last_hidden_states: torch.Tensor,
+        aux_hidden_states: list[torch.Tensor] | None,
+        num_sampled: torch.Tensor,
+    ) -> torch.Tensor:
+        num_reqs = input_batch.num_reqs
+        idx_mapping_np = input_batch.idx_mapping_np
+        with async_barrier(self.spec_decode_event):
+            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
+                self.req_states.prefill_token_ids[
+                    idx_mapping_np,
+                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
+                ]
+            )
+            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
+                num_reqs
+            )
+
+        assert self.speculator is not None
+        draft_tokens = self.speculator.propose(
+            input_batch,
+            sampling_metadata,
+            last_hidden_states,
+            aux_hidden_states,
+            num_sampled,
+            self.req_states.last_sampled_tokens,
+            next_prefill_tokens,
+        )
+        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
+        return draft_tokens
+
     def get_cudagraph_and_dp_padding(
         self,
         scheduler_output: SchedulerOutput,
@@ -913,6 +984,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.postprocess(
             input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
         )
+        if self.do_spec_decode:
+            _ = self.propose_draft(
+                input_batch,
+                sampling_metadata,
+                hidden_states,
+                None,  # aux_hidden_states
+                num_sampled_tokens,
+            )
 
         if self.use_async_scheduling:
             return async_output
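
The host-side staging in propose_draft is a per-request gather: for each scheduled batch slot it picks the first prompt token that has not been computed yet, so a request still in chunked prefill can hand the speculator its next input token. A toy NumPy example of that indexing (made-up shapes and values; only the indexing expression is taken from the code above):

import numpy as np

prefill_token_ids = np.array([[11, 12, 13, 14],   # request state 0
                              [21, 22, 23, 24]])  # request state 1
num_computed_prefill_tokens = np.array([2, 3])    # per request state
idx_mapping_np = np.array([1, 0])                 # batch slot -> request state

next_prefill_tokens = prefill_token_ids[
    idx_mapping_np,
    num_computed_prefill_tokens[idx_mapping_np],
]
print(next_prefill_tokens)  # [24 13]
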
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index c48ed2d8ca16..d8676079ab95 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -100,8 +100,9 @@ def _gumbel_sample_kernel(
         mask=mask,
         other=float("-inf"),
     )
+    logits = logits.to(tl.float32)
 
-    temp = tl.load(temp_ptr + req_idx)
+    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
     if temp != 0.0:
         # Calculate the seed for gumbel noise.
         seed = tl.load(seeds_ptr + req_idx)
@@ -116,7 +117,7 @@ def _gumbel_sample_kernel(
         # Apply temperature.
         if APPLY_TEMPERATURE:
             # NOTE(woosuk): Use div_rn to match the behavior of torch.
-            logits = tl.div_rn(logits, temp.to(tl.float32))
+            logits = tl.div_rn(logits, temp)
 
         # Apply gumbel noise.
         logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
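
For reference, the Gumbel-max sampling this kernel implements, in plain PyTorch (a sketch, not the Triton code): with temperature t and uniform noise u keyed by the request seed and the token position, argmax(logits / t - log(-log(u))) is a sample from softmax(logits / t), and t == 0 falls back to greedy. The sketch computes in float32 throughout, matching the casts added above; the seeding below is illustrative, not the kernel's scheme.

import torch

def gumbel_sample_reference(
    logits: torch.Tensor, temperature: float, seed: int, pos: int
) -> int:
    logits = logits.float()
    if temperature == 0.0:
        return int(torch.argmax(logits))  # greedy
    # Key the noise on (seed, position) so sampling at the same position
    # always reproduces the same perturbation.
    gen = torch.Generator().manual_seed((seed * 1_000_003 + pos) & 0x7FFFFFFF)
    u = torch.rand(logits.shape, generator=gen, dtype=torch.float32)
    gumbel_noise = -torch.log(-torch.log(u))
    return int(torch.argmax(logits / temperature + gumbel_noise))

Because the noise is keyed by (seed, position), the eagle speculator below samples its draft for position p + 1 with that position's noise, so draft and target sampling use matching perturbations for the same position, as the NOTE in propose() points out.
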
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py
index e69de29bb2d1..15b85204e05c 100644
--- a/vllm/v1/worker/gpu/spec_decode/__init__.py
+++ b/vllm/v1/worker/gpu/spec_decode/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.config import VllmConfig
+
+
+def init_speculator(
+    vllm_config: VllmConfig,
+    device: torch.device,
+):
+    speculative_config = vllm_config.speculative_config
+    assert speculative_config is not None
+    if speculative_config.use_eagle():
+        from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+
+        return EagleSpeculator(vllm_config, device)
+    raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
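
init_speculator returns an object that GPUModelRunner only touches through two methods, load_model(target_model) and propose(...). A hypothetical Protocol capturing that surface, inferred from the calls in model_runner.py above (vLLM does not define this type; it is shown only to make the expected interface explicit):

from typing import Protocol

import torch
import torch.nn as nn

class Speculator(Protocol):
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model, optionally sharing weights with the target."""
        ...

    def propose(
        self,
        input_batch,                        # InputBatch
        sampling_metadata,                  # SamplingMetadata
        last_hidden_states: torch.Tensor,   # [num_tokens, hidden_size]
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,          # [num_reqs]
        last_sampled: torch.Tensor,         # [max_num_reqs, 1]
        next_prefill_tokens: torch.Tensor,  # [num_reqs]
    ) -> torch.Tensor:                      # draft tokens, [num_reqs, num_steps]
        """Propose draft tokens for the next scheduler step."""
        ...
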
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
new file mode 100644
index 000000000000..0f11903e1454
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -0,0 +1,197 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader import get_model
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.sampler import gumbel_sample
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
+class EagleSpeculator:
+    def __init__(self, vllm_config: VllmConfig, device: torch.device):
+        self.vllm_config = vllm_config
+        self.device = device
+
+        self.speculative_config = vllm_config.speculative_config
+        assert self.speculative_config is not None
+        self.method = self.speculative_config.method
+        self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+        self.draft_model_config = self.speculative_config.draft_model_config
+
+        self.scheduler_config = vllm_config.scheduler_config
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+
+        self.input_ids = torch.zeros(
+            self.max_num_tokens, dtype=torch.int32, device=device
+        )
+        self.positions = torch.zeros(
+            self.max_num_tokens, dtype=torch.int64, device=device
+        )
+
+    def load_model(self, target_model: nn.Module) -> None:
+        from vllm.compilation.backends import set_model_tag
+
+        with set_model_tag("eagle_head"):
+            self.model = get_model(
+                vllm_config=self.vllm_config, model_config=self.draft_model_config
+            )
+
+        share_lm_head = True
+        if share_lm_head and hasattr(target_model, "lm_head"):
+            if hasattr(self.model, "lm_head"):
+                del self.model.lm_head
+            self.model.lm_head = target_model.lm_head
+
+    @torch.inference_mode()
+    def propose(
+        self,
+        input_batch: InputBatch,
+        sampling_metadata: SamplingMetadata,
+        # [num_tokens, hidden_size]
+        last_hidden_states: torch.Tensor,
+        # num_layers x [num_tokens, hidden_size]
+        aux_hidden_states: list[torch.Tensor] | None,
+        # [num_reqs]
+        num_sampled: torch.Tensor,
+        # [max_num_reqs, 1]
+        last_sampled: torch.Tensor,
+        # [num_reqs]
+        next_prefill_tokens: torch.Tensor,
+    ) -> torch.Tensor:
+        if aux_hidden_states:
+            assert self.method == "eagle3"
+            hidden_states = self.model.combine_hidden_states(
+                torch.cat(aux_hidden_states, dim=-1)
+            )
+        else:
+            hidden_states = last_hidden_states
+
+        # Get the input ids and last token indices for the speculator.
+        last_token_indices = prepare_eagle_inputs(
+            self.input_ids,
+            input_batch,
+            num_sampled,
+            last_sampled,
+            next_prefill_tokens,
+        )
+        input_ids = self.input_ids[: input_batch.num_tokens_after_padding]
+
+        # Prefill: Run the eagle speculator with eager mode.
+        with set_forward_context(
+            input_batch.attn_metadata,
+            self.vllm_config,
+            num_tokens=input_batch.num_tokens_after_padding,
+            cudagraph_runtime_mode=CUDAGraphMode.NONE,
+        ):
+            ret_hidden_states = self.model(
+                input_ids=input_ids,
+                positions=input_batch.positions,
+                hidden_states=hidden_states,
+            )
+        if self.method == "mtp":
+            last_hidden_states = ret_hidden_states
+            hidden_states = ret_hidden_states
+        else:
+            last_hidden_states, hidden_states = ret_hidden_states
+        sample_hidden_states = last_hidden_states[last_token_indices]
+        logits = self.model.compute_logits(sample_hidden_states)
+
+        num_reqs = input_batch.num_reqs
+        cu_num_logits = input_batch.cu_num_logits[:num_reqs]
+        temperature = sampling_metadata.temperature[cu_num_logits]
+        seed = sampling_metadata.seeds[cu_num_logits]
+        # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
+        # used for draft and target sampling.
+        pos = input_batch.positions[last_token_indices] + 1
+        draft_tokens = gumbel_sample(
+            logits, temperature, seed, pos, apply_temperature=True
+        )
+        if self.num_speculative_steps == 1:
+            # Early exit.
+            return draft_tokens.view(-1, 1)
+        raise NotImplementedError("num_speculative_steps > 1 is not supported yet.")
+
+
+@triton.jit
+def _prepare_eagle_inputs_kernel(
+    last_token_indices_ptr,
+    eagle_input_ids_ptr,
+    target_input_ids_ptr,
+    idx_mapping_ptr,
+    last_sampled_ptr,
+    next_prefill_tokens_ptr,
+    num_sampled_ptr,
+    query_start_loc_ptr,
+    cu_num_logits_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    query_start = tl.load(query_start_loc_ptr + batch_idx)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    query_len = query_end - query_start
+
+    # Get the true query length and next token after accounting for rejected tokens.
+    num_sampled = tl.load(num_sampled_ptr + batch_idx)
+    if num_sampled > 0:
+        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+        next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
+
+        logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+        logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+        num_logits = logits_end - logits_start
+
+        num_rejected = num_logits - num_sampled
+        query_len -= num_rejected
+    else:
+        # Chunked prefilling.
+        # Get the next prefill token.
+        next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
+
+    # Shift target_input_ids by one.
+    for i in range(1, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
+        tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
+
+    last_token_index = query_start + query_len - 1
+    tl.store(last_token_indices_ptr + batch_idx, last_token_index)
+    tl.store(eagle_input_ids_ptr + last_token_index, next_token)
+
+
+def prepare_eagle_inputs(
+    eagle_input_ids: torch.Tensor,
+    input_batch: InputBatch,
+    # [num_reqs]
+    num_sampled: torch.Tensor,
+    # [max_num_reqs, 1]
+    last_sampled: torch.Tensor,
+    # [max_num_reqs]
+    next_prefill_tokens: torch.Tensor,
+) -> torch.Tensor:
+    num_reqs = input_batch.num_reqs
+    last_token_indices = torch.empty(
+        num_reqs,
+        dtype=torch.int64,
+        device=eagle_input_ids.device,
+    )
+    _prepare_eagle_inputs_kernel[(num_reqs,)](
+        last_token_indices,
+        eagle_input_ids,
+        input_batch.input_ids,
+        input_batch.idx_mapping,
+        last_sampled,
+        next_prefill_tokens,
+        num_sampled,
+        input_batch.query_start_loc,
+        input_batch.cu_num_logits,
+        BLOCK_SIZE=1024,
+    )
+    return last_token_indices
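
For reference, a pure-Python rendering of the per-request logic in _prepare_eagle_inputs_kernel (simplified: one request at a time, plain lists instead of GPU tensors, no block/padding handling; variable names follow the kernel). The eagle head sees the target's input ids shifted left by one, with the token sampled this step (or, for a request still in chunked prefill, its next prompt token) written at the last surviving position.

def prepare_eagle_inputs_reference(
    eagle_input_ids: list[int],
    target_input_ids: list[int],
    query_start: int,
    query_len: int,
    num_logits: int,
    num_sampled: int,
    last_sampled: int,
    next_prefill_token: int,
) -> int:
    if num_sampled > 0:
        # Decode/verify step: drop the positions whose draft tokens were
        # rejected, and append the token that was actually sampled.
        num_rejected = num_logits - num_sampled
        query_len -= num_rejected
        next_token = last_sampled
    else:
        # Chunked prefill: the request still has prompt tokens left to feed.
        next_token = next_prefill_token

    # Shift the target's input ids left by one position.
    for i in range(1, query_len):
        eagle_input_ids[query_start + i - 1] = target_input_ids[query_start + i]

    last_token_index = query_start + query_len - 1
    eagle_input_ids[last_token_index] = next_token
    return last_token_index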