[BugFix] Fix tokenize asyncio task leak (#24677)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-09-11 12:44:04 -07:00 committed by GitHub
parent c733bd5e87
commit b971f91504
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -168,8 +168,8 @@ class BaseRenderer(ABC):
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
else:
return [_load_and_validate_embed(prompt_embeds)]
return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
@@ -182,7 +182,7 @@ class CompletionRenderer(BaseRenderer):
AsyncMicrobatchTokenizer]] = None,
):
super().__init__(model_config, tokenizer)
self.async_tokenizer_pool = async_tokenizer_pool or {}
self.async_tokenizer_pool = async_tokenizer_pool
self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None
async def render_prompt(
@@ -208,23 +208,21 @@ class CompletionRenderer(BaseRenderer):
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is True:
# Token input
detokenize_task = asyncio.create_task(
# Note: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
self._maybe_detokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.cache_salt,
config.needs_detokenization))
tasks.append(detokenize_task)
# Note: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
task = self._maybe_detokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.cache_salt,
config.needs_detokenization)
else:
# Text input
tokenize_task = asyncio.create_task(
self._tokenize(prompt_input["content"], config.max_length,
truncate_prompt_tokens,
config.add_special_tokens,
config.cache_salt))
tasks.append(tokenize_task)
task = self._tokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.add_special_tokens,
config.cache_salt)
tasks.append(task)
# Wait for all text tokenization to finish
if tasks:
@@ -356,20 +354,24 @@ class CompletionRenderer(BaseRenderer):
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so lines from the pre-change and post-change
# versions of this method are interleaved below. The "presumably" labels
# are a reviewer's attribution -- confirm against the commit itself.
#
# Presumably pre-change: early return of the tokenizer cached on the
# instance, reading the attribute directly.
if self.async_tokenizer is not None:
return self.async_tokenizer
# Presumably post-change: same early return, but the cached value is first
# copied into a local that the rest of the method reuses.
async_tokenizer = self.async_tokenizer
if async_tokenizer is not None:
return async_tokenizer
# Presumably post-change: local alias used by the pool lookup and the
# AsyncMicrobatchTokenizer(...) calls further down.
tokenizer = self.tokenizer
# Present in both versions: without a tokenizer nothing can be built, so
# the method raises ValueError.
if self.tokenizer is None:
raise ValueError(
"No tokenizer available for text input processing")
# Presumably pre-change: on a pool hit the pooled tokenizer is returned
# without being stored on self.async_tokenizer; on a miss a new one is
# created, stored on the instance and registered in the pool.
# Check shared pool first
if self.tokenizer in self.async_tokenizer_pool:
return self.async_tokenizer_pool[self.tokenizer]
# Create new async tokenizer and add to pool
self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer)
self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer
return self.async_tokenizer
# Presumably post-change: a pool of None means "no sharing", so a new
# AsyncMicrobatchTokenizer is created and not registered anywhere. With a
# pool, the entry is looked up via .get() and created and registered on a
# miss. In every case the result is stored on self.async_tokenizer before
# returning.
if self.async_tokenizer_pool is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
else:
async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self.async_tokenizer_pool[tokenizer] = async_tokenizer
self.async_tokenizer = async_tokenizer
return async_tokenizer
def _create_tokens_prompt(
self,