JAX-based TPU worker
commit 91b47e3f2f
parent 6d62e4c6aa
@@ -1,18 +1,28 @@
-"""A TPU worker class."""
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import jax.numpy as jnp
 import torch
-import torch_xla.core.xla_model as xm
 
-from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 from vllm.worker.tpu_model_runner import TPUModelRunner
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.utils import get_dtype_size, STR_DTYPE_TO_TORCH_DTYPE
 
+logger = init_logger(__name__)
+
 
-class TPUWorker:
+class TPUWorker(LoraNotSupportedWorkerBase):
+    """A worker class that executes (a partition of) the model on a CPU socket.
+
+    Each worker is associated with a single CPU socket. The worker is
+    responsible for maintaining the KV cache and executing the model on the
+    CPU. In case of distributed inference, each worker is assigned a partition
+    of the model.
+    """
 
     def __init__(
         self,
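Note on the class change above: `TPUWorker` now subclasses `LoraNotSupportedWorkerBase`, which ties it to vLLM's generic worker contract. The sketch below is inferred only from the methods this diff implements; the names of the base classes' members are assumptions, not copied from `vllm/worker/worker_base.py`:

# Hypothetical sketch of the worker contract implied by this diff.
from abc import ABC, abstractmethod
from typing import Tuple


class WorkerBaseSketch(ABC):

    @abstractmethod
    def init_device(self) -> None:
        """Bind the worker to its device and set seeds."""

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return (num_device_blocks, num_cpu_blocks) for the KV cache."""

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate the KV cache once block counts are decided."""

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Size of a single KV-cache block in bytes."""


class LoraNotSupportedSketch(WorkerBaseSketch, ABC):
    # Assumed behavior, based on the class name alone: LoRA hooks refuse.
    def add_lora(self, *args, **kwargs) -> None:
        raise ValueError(f"{type(self).__name__} does not support LoRA.")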
@@ -20,60 +30,74 @@ class TPUWorker:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        vision_language_config: Optional[VisionLanguageConfig] = None,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
+        cache_config: CacheConfig,
+        vision_language_config: Optional[VisionLanguageConfig],
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        self.cache_config = cache_config
 
         self.vision_language_config = vision_language_config
-        if self.vision_language_config:
-            assert not self.lora_config, (
-                "To be tested: vision language model with LoRA settings.")
 
         assert self.device_config.device_type == "tpu"
-        self.device_config.device = xm.xla_device()
-        self.device = self.device_config.device
+        if self.cache_config.cache_dtype == "auto":
+            self.cache_dtype = self.model_config.dtype
+        else:
+            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                self.cache_config.cache_dtype]
 
         self.model_runner = TPUModelRunner(
             model_config,
             parallel_config,
             scheduler_config,
             device_config,
-            lora_config=self.lora_config,
-            kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker,
             vision_language_config=vision_language_config)
-        self.cache_config = None
         self.tpu_cache = None
 
     def init_device(self) -> None:
         # Set random seed.
-        self._set_random_seed(self.model_config.seed)
+        set_random_seed(self.model_config.seed)
+        # TODO: JAX
 
     def load_model(self):
         self.model_runner.load_model()
 
-    def warm_up_model(self) -> None:
-        # Reset the seed to ensure that the random state is not affected by
-        # the model initialization and profiling.
-        self._set_random_seed(self.model_config.seed)
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        num_tpu_blocks = 100
+        return num_tpu_blocks, 0
+
+    def initialize_cache(
+        self,
+        num_gpu_blocks: int,
+        num_cpu_blocks: int,
+    ) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+        self.block_size = self.cache_config.block_size
+
+        dtype = _torch_dtype_to_jax(self.cache_dtype)
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        head_size = self.model_config.get_head_size()
+        self.tpu_cache = [
+            jnp.zeros(
+                (2, num_kv_heads, num_gpu_blocks, self.block_size, head_size),
+                dtype=dtype) for _ in range(num_layers)
+        ]
+        self.model_runner.block_size = self.block_size
+
+    def get_cache_block_size_bytes(self) -> int:
+        head_size = self.model_config.get_head_size()
+        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+
+        key_cache_block = self.cache_config.block_size * num_heads * head_size
+        value_cache_block = key_cache_block
+        total = num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.cache_dtype)
+        return dtype_size * total
 
-    @torch.inference_mode()
     def execute_model(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
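Two things are worth flagging in the hunk above. `determine_num_available_blocks` is a placeholder: it returns a hardcoded 100 TPU blocks and 0 CPU blocks instead of profiling free memory. And the per-layer cache allocated in `initialize_cache` has shape `(2, num_kv_heads, num_gpu_blocks, block_size, head_size)`, where the leading 2 holds the key and value planes, so its footprint should agree with `get_cache_block_size_bytes`. A standalone re-computation; the model dimensions below are illustrative assumptions (roughly a 7B-class model), not values from this diff:

# Re-deriving get_cache_block_size_bytes() by hand, outside vLLM.
num_layers = 32
num_kv_heads = 32
head_size = 128
block_size = 16   # tokens per KV-cache block
dtype_size = 2    # bytes per element, e.g. bfloat16

key_cache_block = block_size * num_kv_heads * head_size   # 65,536 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
print(dtype_size * total)            # 8,388,608 bytes = 8 MiB per block

# Cross-check against the jnp.zeros shape: per layer, the cache holds
# 2 * num_kv_heads * num_blocks * block_size * head_size elements.
num_blocks = 100                     # the hardcoded stub value
per_layer = 2 * num_kv_heads * num_blocks * block_size * head_size
assert num_layers * per_layer * dtype_size == num_blocks * dtype_size * total
print(num_layers * per_layer * dtype_size / 2**20, "MiB")   # 800.0 MiB total

So under these assumed dimensions, the stub's 100 blocks pin roughly 800 MiB of KV cache, far below a real TPU's memory; a later profiling pass would presumably replace the constant.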
@@ -101,84 +125,11 @@
             self.tpu_cache)
         return output
 
-    def allocate_kv_cache(self, cache_config: CacheConfig) -> None:
-        self.cache_config = cache_config
-        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-            cache_config.num_gpu_blocks, cache_config.block_size, self.num_heads, self.head_size)
-        kv_cache: List[torch.Tensor] = []
-        for _ in range(self.num_layers):
-            kv_cache.append(
-                torch.empty(kv_cache_shape,
-                            dtype=self.dtype,
-                            device=self.device))
-        self.tpu_cache = kv_cache
 
-    def _set_random_seed(self, seed: int) -> None:
-        xm.set_rng_state(seed, device=self.device)
-        set_random_seed(seed)
-
-
-class CacheEngine:
-
-    def __init__(
-        self,
-        cache_config: CacheConfig,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        device_config: DeviceConfig,
-    ) -> None:
-        self.cache_config = cache_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.device_config = device_config
-
-        self.head_size = model_config.get_head_size()
-        self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_kv_heads(parallel_config)
-
-        self.block_size = cache_config.block_size
-        self.num_tpu_blocks = cache_config.num_gpu_blocks
-
-        if cache_config.cache_dtype == "auto":
-            self.dtype = model_config.dtype
-        else:
-            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-        self.device = device_config.device
-
-        # Get attention backend.
-        self.attn_backend = get_attn_backend(self.dtype)
-
-        # Initialize the cache.
-        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-            self.num_tpu_blocks, self.block_size, self.num_heads, self.head_size)
-        self.tpu_cache: List[torch.Tensor] = []
-        for _ in range(self.num_layers):
-            self.tpu_cache.append(
-                torch.empty(kv_cache_shape,
-                            dtype=self.dtype,
-                            device=self.device))
-
-    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        raise NotImplementedError(
-            "Copying blocks is not supported on TPU backend.")
-
-    @staticmethod
-    def get_cache_block_size(
-        block_size: int,
-        cache_dtype: str,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-    ) -> int:
-        head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_kv_heads(parallel_config)
-        num_layers = model_config.get_num_layers(parallel_config)
-
-        key_cache_block = block_size * num_heads * head_size
-        value_cache_block = key_cache_block
-        total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype == "auto":
-            dtype = model_config.dtype
-        else:
-            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-        dtype_size = get_dtype_size(dtype)
-        return dtype_size * total
+
+def _torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
+    mapping = {
+        torch.float32: jnp.float32,
+        torch.float16: jnp.float16,
+        torch.bfloat16: jnp.bfloat16,
+    }
+    return mapping[dtype]
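The new `_torch_dtype_to_jax` helper only knows the three floating-point dtypes listed and will raise a bare `KeyError` for anything else (`torch.float64`, integer dtypes, fp8 cache dtypes). A minimal standalone demo of the same mapping with a clearer failure mode; the wrapper name and error message are my own, only the three mapped pairs come from the diff:

import jax.numpy as jnp
import torch

# Same three pairs as the diff's _torch_dtype_to_jax.
_TORCH_TO_JAX = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}


def torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
    # Fail loudly on unmapped dtypes instead of leaking a KeyError.
    try:
        return _TORCH_TO_JAX[dtype]
    except KeyError:
        raise ValueError(f"No JAX dtype registered for {dtype}") from None


# The result is usable wherever the worker passes dtype= to jnp.zeros.
x = jnp.zeros((2, 4), dtype=torch_dtype_to_jax(torch.bfloat16))
print(x.dtype)  # bfloat16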