mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 12:09:38 +08:00
[TPU][Bugfix] fix the missing apply_model in tpu worker (#25526)
Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d7fb5a4ae8
commit
5b4ba2e1e1
@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
||||
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
"It is only with the heart that one can see rightly;",
|
||||
"The greatest glory in living lies not in never falling,",
|
||||
]
|
||||
answers = [
|
||||
"or, being injured, not kill, except in",
|
||||
"without the heart, one can only see wrongly.",
|
||||
"but in rising every time we fall. - Nelson"
|
||||
"or kill a human being",
|
||||
]
|
||||
|
||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""A TPU worker class."""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
if not USE_TPU_COMMONS:
|
||||
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
|
||||
import torch_xla.core.xla_model as xm
|
||||
@ -333,6 +335,10 @@ class TPUWorker:
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
|
||||
if USE_TPU_COMMONS:
|
||||
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user