[TPU][Bugfix] fix the missing apply_model in tpu worker (#25526)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Chengji Yao 2025-09-23 22:18:08 -07:00 committed by yewentao256
parent d7fb5a4ae8
commit 5b4ba2e1e1
2 changed files with 8 additions and 6 deletions

View File

@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
"or, being injured, not kill, except in",
"without the heart, one can only see wrongly.",
"but in rising every time we fall. - Nelson"
"or kill a human being",
]
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

View File

@ -3,7 +3,7 @@
"""A TPU worker class."""
import os
from typing import Any, Optional
from typing import Any, Callable, Optional, TypeVar
import torch
import torch.distributed
@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
_R = TypeVar("_R")
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
@ -333,6 +335,10 @@ class TPUWorker:
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
if USE_TPU_COMMONS:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker