Merge a2b34802361df50da4ef3ac250707b71b2ce282a into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
ゆり 2025-12-25 00:07:15 +00:00 committed by GitHub
commit 1d3be5cbed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ import numpy as np
from numba import get_num_threads, jit, njit, prange, set_num_threads from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
class NgramProposer: class NgramProposer:
@ -33,21 +34,17 @@ class NgramProposer:
# Threshold of total number of tokens in the batch to enable # Threshold of total number of tokens in the batch to enable
# multi-threading in numba batch propose. # multi-threading in numba batch propose.
self.num_tokens_threshold = 8192 self.num_tokens_threshold = 8192
tp_size = vllm_config.parallel_config.tensor_parallel_size
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
# Max number of threads for numba parallel processing. # Max number of threads for numba parallel processing.
# Since draft tokens are computed only on rank 0 and broadcast to other
# ranks (for TP consistency), rank 0 can use all available threads.
if cpu_count: if cpu_count:
# Divide by 2 to use physical cores # Divide by 2 to use physical cores
# and not logical cores (hyper-threading). # and not logical cores (hyper-threading).
# Cap the number of threads to 8 to avoid using too many threads # Cap the number of threads to 8 to avoid using too many threads
# since other components like frontend (incl tokenization) # since other components like frontend (incl tokenization)
# and Structured Outputs also use multiple threads. # and Structured Outputs also use multiple threads.
# TODO(ekagra-ranjan): bump up the cap from 1 to 8 self.num_numba_thread_available = min(8, cpu_count // 2)
# when TP parallelization for ngram is implemented.
self.num_numba_thread_available = min(1, (cpu_count // 2))
# Divide by tp_size to ensure each tensor parallel rank
# has some threads since all ranks will run this.
self.num_numba_thread_available //= tp_size
else: else:
self.num_numba_thread_available = 1 self.num_numba_thread_available = 1
@ -137,33 +134,47 @@ class NgramProposer:
token_ids_cpu: np.ndarray, token_ids_cpu: np.ndarray,
spec_decode_unsupported_reqs: set, spec_decode_unsupported_reqs: set,
) -> list[list[int]]: ) -> list[list[int]]:
# find which requests need ngram proposals # Only compute draft tokens on TP rank 0 and broadcast to other ranks.
valid_ngram_requests = [] # This ensures all TP ranks have identical draft tokens, which is
for i, sampled_ids in enumerate(sampled_token_ids): # required because numba parallel execution can produce different
num_sampled_ids = len(sampled_ids) # results across ranks due to non-determinism.
if not num_sampled_ids: tp_group = get_tp_group()
# Skip speculative decoding. if tp_group.is_first_rank:
continue # find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue
# Skip requests that require sampling parameters that are not # Skip requests that require sampling parameters that are not
# supported with speculative decoding. # supported with speculative decoding.
req_id = req_ids[i] req_id = req_ids[i]
if req_id in spec_decode_unsupported_reqs: if req_id in spec_decode_unsupported_reqs:
continue continue
num_tokens = num_tokens_no_spec[i] num_tokens = num_tokens_no_spec[i]
if num_tokens >= self.max_model_len: if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length. # Skip requests that have already reached the max model length.
continue continue
valid_ngram_requests.append(i) valid_ngram_requests.append(i)
draft_token_ids = self.batch_propose( draft_token_ids = self.batch_propose(
len(sampled_token_ids), len(sampled_token_ids),
valid_ngram_requests, valid_ngram_requests,
num_tokens_no_spec, num_tokens_no_spec,
token_ids_cpu, token_ids_cpu,
) )
else:
draft_token_ids = None
# Broadcast draft tokens from rank 0 to all other ranks.
# Rank 0 always computes valid draft_token_ids, so broadcast
# will never return None.
draft_token_ids = tp_group.broadcast_object(draft_token_ids, src=0)
assert draft_token_ids is not None
return draft_token_ids return draft_token_ids