diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5d4822a6279b2..1e1161727be1e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -522,13 +522,9 @@ class EagleProposer: ) # Generate a mask for all valid tokens within those requests - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool) - else: - valid_mask = (valid_sampled_token_ids_gpu != -1) & ( - valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size - ) + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1)