diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 42584938bc06..c0b0e1ea226e 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
         pytest.param(
             (
                 "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                "meta-llama/Llama-3.2-1B-Instruct",
+                "nm-testing/Llama3_2_1B_speculator.eagle3",
             ),
             marks=large_gpu_mark(min_gb=32),
         ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM
 
-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 39c63fe31ad2..c75b4f0543c0 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
         if logprobs_mode == "raw_logprobs":
             raw_logprobs = self.compute_logprobs(logits)
         elif logprobs_mode == "raw_logits":
-            raw_logprobs = logits.clone()
+            if logits.dtype == torch.float32:
+                raw_logprobs = logits.clone()
+            else:
+                raw_logprobs = logits.to(torch.float32)
 
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 298bb1ef5f6f..979f97758703 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2466,7 +2466,9 @@ class GPUModelRunner(
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
+        cu_num_new_tokens: list[int] | None = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
                 sampled_token_ids,
                 self.input_batch.vocab_size,
             )
+            if logprobs_tensors:
+                # Needed for extracting logprobs when spec decoding.
+                # This must be done prior to discarding sampled tokens.
+                cu_num_new_tokens = [0]
+                for toks in valid_sampled_token_ids:
+                    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[int(i)].clear()
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        logprobs_tensors = sampler_output.logprobs_tensors
-        cu_num_accepted_tokens = (
-            [0] if spec_decode_metadata and logprobs_tensors else None
-        )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@ class GPUModelRunner(
             num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
 
-            if cu_num_accepted_tokens is not None:
-                cu_num_accepted_tokens.append(
-                    cu_num_accepted_tokens[-1] + num_sampled_ids
-                )
-
             if not sampled_ids:
                 continue
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
             req_state.output_token_ids.extend(sampled_ids)
 
         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            logprobs_tensors.tolists(cu_num_new_tokens)
            if not self.use_async_scheduling and logprobs_tensors is not None
            else None
        )
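Note on the gpu_model_runner.py change above: the cumulative offsets are now built from the per-request accepted-token lists before discarded requests are cleared, so the flat logprobs buffer can still be split back into one chunk per request. The following is a minimal standalone sketch of that offset-and-slice logic; the helper name split_logprobs_by_request is hypothetical and this is not vLLM's actual LogprobsTensors.tolists implementation.

    import torch


    def split_logprobs_by_request(
        logprob_rows: torch.Tensor,
        valid_sampled_token_ids: list[list[int]],
    ) -> list[list[list[float]]]:
        # Running offsets over the per-request accepted-token counts,
        # mirroring how cu_num_new_tokens is built in the diff above.
        cu_num_new_tokens = [0]
        for toks in valid_sampled_token_ids:
            cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
        # Slice the flat [total_new_tokens, num_logprobs] rows into one
        # chunk per request using consecutive offset pairs.
        return [
            logprob_rows[start:end].tolist()
            for start, end in zip(cu_num_new_tokens[:-1], cu_num_new_tokens[1:])
        ]


    # Example: request 0 had 3 tokens accepted by spec decode, request 1 had 1.
    rows = torch.arange(8, dtype=torch.float32).reshape(4, 2)
    chunks = split_logprobs_by_request(rows, [[11, 12, 13], [21]])
    assert [len(c) for c in chunks] == [3, 1]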