mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 11:37:00 +08:00
Merge a2b34802361df50da4ef3ac250707b71b2ce282a into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
1d3be5cbed
@ -6,6 +6,7 @@ import numpy as np
|
|||||||
from numba import get_num_threads, jit, njit, prange, set_num_threads
|
from numba import get_num_threads, jit, njit, prange, set_num_threads
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_tp_group
|
||||||
|
|
||||||
|
|
||||||
class NgramProposer:
|
class NgramProposer:
|
||||||
@ -33,21 +34,17 @@ class NgramProposer:
|
|||||||
# Threshold of total number of tokens in the batch to enable
|
# Threshold of total number of tokens in the batch to enable
|
||||||
# multi-threading in numba batch propose.
|
# multi-threading in numba batch propose.
|
||||||
self.num_tokens_threshold = 8192
|
self.num_tokens_threshold = 8192
|
||||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
cpu_count = os.cpu_count()
|
cpu_count = os.cpu_count()
|
||||||
# Max number of threads for numba parallel processing.
|
# Max number of threads for numba parallel processing.
|
||||||
|
# Since draft tokens are computed only on rank 0 and broadcast to other
|
||||||
|
# ranks (for TP consistency), rank 0 can use all available threads.
|
||||||
if cpu_count:
|
if cpu_count:
|
||||||
# Divide by 2 to use physical cores
|
# Divide by 2 to use physical cores
|
||||||
# and not logical cores (hyper-threading).
|
# and not logical cores (hyper-threading).
|
||||||
# Cap the number of threads to 8 to avoid using too many threads
|
# Cap the number of threads to 8 to avoid using too many threads
|
||||||
# since other components like frontend (incl tokenization)
|
# since other components like frontend (incl tokenization)
|
||||||
# and Structured Outputs also use multiple threads.
|
# and Structured Outputs also use multiple threads.
|
||||||
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
|
self.num_numba_thread_available = min(8, cpu_count // 2)
|
||||||
# when TP parallelization for ngram is implemented.
|
|
||||||
self.num_numba_thread_available = min(1, (cpu_count // 2))
|
|
||||||
# Divide by tp_size to ensure each tensor parallel rank
|
|
||||||
# has some threads since all ranks will run this.
|
|
||||||
self.num_numba_thread_available //= tp_size
|
|
||||||
else:
|
else:
|
||||||
self.num_numba_thread_available = 1
|
self.num_numba_thread_available = 1
|
||||||
|
|
||||||
@ -137,33 +134,47 @@ class NgramProposer:
|
|||||||
token_ids_cpu: np.ndarray,
|
token_ids_cpu: np.ndarray,
|
||||||
spec_decode_unsupported_reqs: set,
|
spec_decode_unsupported_reqs: set,
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
# find which requests need ngram proposals
|
# Only compute draft tokens on TP rank 0 and broadcast to other ranks.
|
||||||
valid_ngram_requests = []
|
# This ensures all TP ranks have identical draft tokens, which is
|
||||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
# required because numba parallel execution can produce different
|
||||||
num_sampled_ids = len(sampled_ids)
|
# results across ranks due to non-determinism.
|
||||||
if not num_sampled_ids:
|
tp_group = get_tp_group()
|
||||||
# Skip speculative decoding.
|
if tp_group.is_first_rank:
|
||||||
continue
|
# find which requests need ngram proposals
|
||||||
|
valid_ngram_requests = []
|
||||||
|
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||||
|
num_sampled_ids = len(sampled_ids)
|
||||||
|
if not num_sampled_ids:
|
||||||
|
# Skip speculative decoding.
|
||||||
|
continue
|
||||||
|
|
||||||
# Skip requests that require sampling parameters that are not
|
# Skip requests that require sampling parameters that are not
|
||||||
# supported with speculative decoding.
|
# supported with speculative decoding.
|
||||||
req_id = req_ids[i]
|
req_id = req_ids[i]
|
||||||
if req_id in spec_decode_unsupported_reqs:
|
if req_id in spec_decode_unsupported_reqs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_tokens = num_tokens_no_spec[i]
|
num_tokens = num_tokens_no_spec[i]
|
||||||
if num_tokens >= self.max_model_len:
|
if num_tokens >= self.max_model_len:
|
||||||
# Skip requests that have already reached the max model length.
|
# Skip requests that have already reached the max model length.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
valid_ngram_requests.append(i)
|
valid_ngram_requests.append(i)
|
||||||
|
|
||||||
draft_token_ids = self.batch_propose(
|
draft_token_ids = self.batch_propose(
|
||||||
len(sampled_token_ids),
|
len(sampled_token_ids),
|
||||||
valid_ngram_requests,
|
valid_ngram_requests,
|
||||||
num_tokens_no_spec,
|
num_tokens_no_spec,
|
||||||
token_ids_cpu,
|
token_ids_cpu,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
draft_token_ids = None
|
||||||
|
|
||||||
|
# Broadcast draft tokens from rank 0 to all other ranks.
|
||||||
|
# Rank 0 always computes valid draft_token_ids, so broadcast
|
||||||
|
# will never return None.
|
||||||
|
draft_token_ids = tp_group.broadcast_object(draft_token_ids, src=0)
|
||||||
|
assert draft_token_ids is not None
|
||||||
|
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user