diff --git a/vllm/envs.py b/vllm/envs.py
index 2c731eda7836a..bb419dacb1ee8 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     VLLM_LOGGING_LEVEL: str = "INFO"
     VLLM_LOGGING_PREFIX: str = ""
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
+    VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
@@ -282,6 +283,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_LOGGING_PREFIX":
     lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),

+    # if set, vllm will call logits processors in a thread pool with this many
+    # threads. This is useful when using custom logits processors that either
+    # (a) launch additional CUDA kernels or (b) do significant CPU-bound work
+    # while not holding the python GIL, or both.
+    "VLLM_LOGITS_PROCESSOR_THREADS":
+    lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
+    if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
+
     # Trace function calls
     # If set to 1, vllm will trace function calls
     # Useful for debugging
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index ebf74c67d64cd..cdc67ca83d489 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that compute logits from hidden_stats."""
 import inspect
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 import torch
@@ -15,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform

+_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
+if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
+    _logits_processor_threadpool = ThreadPoolExecutor(
+        envs.VLLM_LOGITS_PROCESSOR_THREADS)
+

 class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.
@@ -135,6 +141,7 @@ def _apply_logits_processors(
 ) -> torch.Tensor:
     found_logits_processors = False
     logits_processed = 0
+    logits_row_ids_and_logits_row_futures = []
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
         sampling_params = seq_group.sampling_params
@@ -148,22 +155,39 @@ def _apply_logits_processors(
                 past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

-                for logits_processor in logits_processors:
-                    parameters = inspect.signature(logits_processor).parameters
-                    if len(parameters) == 3:
-                        logits_row = logits_processor(prompt_tokens_ids,
-                                                      past_tokens_ids,
-                                                      logits_row)
-                    else:
-                        logits_row = logits_processor(past_tokens_ids,
-                                                      logits_row)
-
-                logits[logits_row_idx] = logits_row
+                if _logits_processor_threadpool is not None:
+                    logits_row_ids_and_logits_row_futures.append(
+                        (logits_row_idx,
+                         _logits_processor_threadpool.submit(
+                             _apply_logits_processors_single_seq, logits_row,
+                             logits_processors, past_tokens_ids,
+                             prompt_tokens_ids)))
+                else:
+                    logits[logits_row_idx] = \
+                        _apply_logits_processors_single_seq(
+                            logits_row, logits_processors, past_tokens_ids,
+                            prompt_tokens_ids)

             logits_processed += len(seq_group.sample_indices) + len(
                 seq_group.prompt_logprob_indices)

+    for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
+        logits[logits_row_idx] = future.result()
+
     if found_logits_processors:
         # verifies that no rows in logits were missed unexpectedly
         assert logits_processed == logits.shape[0]
     return logits
+
+
+def _apply_logits_processors_single_seq(logits_row, logits_processors,
+                                        past_tokens_ids,
+                                        prompt_tokens_ids) -> torch.Tensor:
+    for logits_processor in logits_processors:
+        parameters = inspect.signature(logits_processor).parameters
+        if len(parameters) == 3:
+            logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
+                                          logits_row)
+        else:
+            logits_row = logits_processor(past_tokens_ids, logits_row)
+    return logits_row
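For reference, below is a minimal sketch (not part of the patch) of how the new knob could be exercised end to end. Note that the diff reads envs.VLLM_LOGITS_PROCESSOR_THREADS at module import time to build the ThreadPoolExecutor, so the environment variable must be set before vllm is imported. The model name and the ban_token_42 processor are illustrative assumptions, not anything introduced by this change; the logits_processors field of SamplingParams and the two- vs. three-parameter callable forms are the existing vLLM interfaces that the refactored helper above dispatches on.

    # Illustrative sketch only. The model and the ban-list processor below
    # are made up for the demo; VLLM_LOGITS_PROCESSOR_THREADS is the knob
    # added by this diff.
    import os

    # Must be set before importing vllm, since the ThreadPoolExecutor in
    # logits_processor.py is created once at module import.
    os.environ["VLLM_LOGITS_PROCESSOR_THREADS"] = "8"

    import torch
    from vllm import LLM, SamplingParams


    def ban_token_42(past_token_ids, logits_row: torch.Tensor) -> torch.Tensor:
        # Two-parameter form: (past_token_ids, logits_row). A three-parameter
        # callable would additionally receive prompt_token_ids first, matching
        # the inspect.signature check in _apply_logits_processors_single_seq.
        logits_row[42] = float("-inf")
        return logits_row


    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(
        ["Hello, my name is"],
        SamplingParams(max_tokens=16, logits_processors=[ban_token_42]),
    )
    print(outputs[0].outputs[0].text)

With the variable unset, the else branch above runs each processor inline on the current thread, preserving the old behavior; with it set, rows are submitted to the pool and gathered via future.result() before the row-count assertion.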