diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 75d0b3b835576..52c40ee3375fc 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -34,21 +34,17 @@ class NgramProposer: # Threshold of total number of tokens in the batch to enable # multi-threading in numba batch propose. self.num_tokens_threshold = 8192 - tp_size = vllm_config.parallel_config.tensor_parallel_size cpu_count = os.cpu_count() # Max number of threads for numba parallel processing. + # Draft tokens are computed only on rank 0 and broadcast to other ranks + # (for TP consistency), so the thread cap is not divided by tp_size. if cpu_count: # Divide by 2 to use physical cores # and not logical cores (hyper-threading). # Cap the number of threads to 8 to avoid using too many threads # since other components like frontend (incl tokenization) # and Structured Outputs also use multiple threads. - # TODO(ekagra-ranjan): bump up the cap from 1 to 8 - # when TP parallelization for ngram is implemented. - self.num_numba_thread_available = min(1, (cpu_count // 2)) - # Divide by tp_size to ensure each tensor parallel rank - # has some threads since all ranks will run this. - self.num_numba_thread_available //= tp_size + self.num_numba_thread_available = min(8, cpu_count // 2) else: self.num_numba_thread_available = 1