mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 06:44:29 +08:00
[TPU][Bugfix] fix the missing apply_model in tpu worker (#25526)
Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d7fb5a4ae8
commit
5b4ba2e1e1
@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
|||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"A robot may not injure a human being",
|
"A robot may not injure a human being",
|
||||||
"It is only with the heart that one can see rightly;",
|
|
||||||
"The greatest glory in living lies not in never falling,",
|
|
||||||
]
|
]
|
||||||
answers = [
|
answers = [
|
||||||
"or, being injured, not kill, except in",
|
"or kill a human being",
|
||||||
"without the heart, one can only see wrongly.",
|
|
||||||
"but in rising every time we fall. - Nelson"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
"""A TPU worker class."""
|
"""A TPU worker class."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
if not USE_TPU_COMMONS:
|
if not USE_TPU_COMMONS:
|
||||||
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
|
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
@ -333,6 +335,10 @@ class TPUWorker:
|
|||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
self.model_runner.ensure_kv_transfer_shutdown()
|
self.model_runner.ensure_kv_transfer_shutdown()
|
||||||
|
|
||||||
|
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||||
|
"""Apply a function on the model inside this worker."""
|
||||||
|
return fn(self.get_model())
|
||||||
|
|
||||||
|
|
||||||
if USE_TPU_COMMONS:
|
if USE_TPU_COMMONS:
|
||||||
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user