[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>

parent c57d577e8d
commit 9112b443a0
@@ -155,6 +155,10 @@ run_and_track_test 12 "test_moe_pallas.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 13 "test_lora.py" \
    "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
run_and_track_test 14 "test_tpu_qkv_linear.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

from vllm import LLM, SamplingParams

prompts = [
@@ -18,14 +21,28 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)


def main():
    parser = argparse.ArgumentParser(description="TPU offline inference example")
    parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
    args = parser.parse_args()

    llm_args = {
        "model": "Qwen/Qwen2-1.5B-Instruct",
        "max_num_batched_tokens": 64,
        "max_num_seqs": 4,
        "max_model_len": 128,
    }
    if args.use_spmd:
        os.environ["VLLM_XLA_USE_SPMD"] = "1"
        # Can only hardcode the number of chips for now. Calling
        # xr.global_runtime_device_count() before initializing the SPMD env in
        # torch_xla will mess up the distributed env.
        llm_args["tensor_parallel_size"] = 8
        # Use Llama, which has num_kv_heads = 8.
        llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"

    # Set `enforce_eager=True` to avoid ahead-of-time compilation.
    # In real workloads, `enforce_eager` should be `False`.
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",
        max_num_batched_tokens=64,
        max_num_seqs=4,
        max_model_len=128,
    )
    llm = LLM(**llm_args)
    outputs = llm.generate(prompts, sampling_params)
    print("-" * 50)
    for output, answer in zip(outputs, answers):
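For orientation (not part of the diff): the example switches to Llama-3.1-8B under --use-spmd because the SPMD model runner asserts that the KV-head count divides evenly by the tensor-parallel size. A minimal sketch of that constraint, assuming a single TPU host with 8 chips:

# Hypothetical check mirroring the assertion added in tpu_model_runner.py below.
num_kv_heads = 8          # Llama-3.1-8B config value
tensor_parallel_size = 8  # assumed: one TPU host with 8 chips
assert num_kv_heads % tensor_parallel_size == 0, (
    "KV cache duplication under SPMD is not handled yet")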
tests/v1/tpu/test_spmd_model_weight_loading.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
import gc
import tempfile

import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader


def _setup_environment(model):
    engine_args = EngineArgs(model=model, )
    vllm_config = engine_args.create_engine_config()
    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        # Under single-worker mode, the full model is initialized first and
        # then partitioned using GSPMD.
        ensure_model_parallel_initialized(1, 1)
    return vllm_config


MESH = None


def _get_spmd_mesh():
    global MESH
    if MESH is None:
        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (num_devices, 1)
        device_ids = np.array(range(num_devices))
        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
    return MESH


@pytest.mark.parametrize("model", [
    "Qwen/Qwen2-1.5B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
])
def test_tpu_model_loader(model):
    # Skip the 70B test if there are fewer than 8 chips.
    # TODO: Query using the torch_xla API; the query API is not working
    # with SPMD now. However, this test is running under SPMD mode.
    if '70B' in model and xr.global_runtime_device_count() < 8:
        pytest.skip(
            "Skipping 70B model if the TPU VM has less than 8 chips to \
            avoid OOM.")

    vllm_config = _setup_environment(model)
    loader = TPUModelLoader(load_config=vllm_config.load_config)
    mesh = _get_spmd_mesh()
    model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
    del model
    gc.collect()
tests/v1/tpu/test_tpu_qkv_linear.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile

import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear


@pytest.fixture(autouse=True)
def setup_environment():
    # This is a fake config used to init the dist env.
    # QKVParallelLinear needs the dist env to be initialized.
    engine_args = EngineArgs(
        model="Qwen/Qwen2-1.5B-Instruct",
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)
        yield


MESH = None


def _get_spmd_mesh():
    global MESH
    if MESH is None:
        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (num_devices, 1)
        device_ids = np.array(range(num_devices))
        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
    return MESH


@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ['cpu', 'xla'])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
    torch.manual_seed(123)

    qkv_linear = QKVParallelLinear(
        hidden_size=4096,
        head_size=128,
        total_num_heads=32,
        total_num_kv_heads=8,
        bias=bias,
        params_dtype=torch.bfloat16,
        return_bias=False,
    )

    qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
    if bias:
        qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)

    xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)

    qkv_linear = qkv_linear.to(device)
    xla_qkv_linear = xla_qkv_linear.to(device)
    input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
    input_tensor = input_tensor.to(device)

    output = qkv_linear(input_tensor)
    xla_output = xla_qkv_linear(input_tensor)
    assert torch.allclose(output.cpu(), xla_output.cpu())
@@ -1901,6 +1901,8 @@ class ParallelConfig:
        if current_platform.is_neuron():
            # neuron uses single process to control multiple devices
            backend = "uni"
        elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
            backend = "uni"
        elif (current_platform.is_cuda()
              and cuda_device_count_stateless() < self.world_size):
            if not ray_found:
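A hedged sketch of what this branch means in practice (the executor choice is read off this diff; the model and TP values are assumptions): with the SPMD flag set on TPU, a single Python worker drives all chips via XLA's partitioner, so the uniprocess ("uni") executor backend is selected even though tensor_parallel_size > 1.

import os
os.environ["VLLM_XLA_USE_SPMD"] = "1"  # must be set before engine construction

from vllm import LLM
# Assumed setup: a single TPU host with 8 chips.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=8)
# ParallelConfig resolves the distributed executor backend to "uni" on this path.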
vllm/distributed/tpu_distributed_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)

logger = init_logger(__name__)


class XlaQKVParallelLinear(nn.Module):

    def __init__(self,
                 qkv_linear: nn.Module,
                 mesh: Optional["xs.Mesh"] = None):
        super().__init__()
        assert isinstance(qkv_linear, QKVParallelLinear)
        self.skip_bias_add = qkv_linear.skip_bias_add
        self.return_bias = qkv_linear.return_bias
        assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."

        self.q_weight: Parameter
        self.k_weight: Parameter
        self.v_weight: Parameter
        self.q_bias: Optional[Parameter]
        self.k_bias: Optional[Parameter]
        self.v_bias: Optional[Parameter]
        self._load_weights_from_qkv_linear(qkv_linear)
        if mesh is not None:
            self._shard_weight(mesh)

    def _shard_weight(self, mesh: "xs.Mesh"):
        self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
        self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
        self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
        xs.mark_sharding(self.q_weight, mesh, ('x', None))
        xs.mark_sharding(self.k_weight, mesh, ('x', None))
        xs.mark_sharding(self.v_weight, mesh, ('x', None))
        if self.q_bias is not None:
            assert self.k_bias is not None and self.v_bias is not None, \
                "QKVParallelLinear should have q, k, and v biases together."
            self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.q_bias, mesh, ('x', ))
            self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.k_bias, mesh, ('x', ))
            self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
            xs.mark_sharding(self.v_bias, mesh, ('x', ))

    def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
        q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
        # The weight of the qkv linear is a concatenation of q, k, and v
        # weights along the output dimension.
        qkv_weight = qkv_linear.weight.data.cpu()
        q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
        k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
                             requires_grad=False)
        v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
                             requires_grad=False)
        self.register_parameter("q_weight", q_weight)
        self.register_parameter("k_weight", k_weight)
        self.register_parameter("v_weight", v_weight)

        if qkv_linear.bias is not None:
            q_bias = Parameter(qkv_linear.bias[:q_proj_size],
                               requires_grad=False)
            k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
                                               k_proj_size],
                               requires_grad=False)
            v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
                               requires_grad=False)
            self.register_parameter("q_bias", q_bias)
            self.register_parameter("k_bias", k_bias)
            self.register_parameter("v_bias", v_bias)
        else:
            self.register_parameter("q_bias", None)
            self.register_parameter("k_bias", None)
            self.register_parameter("v_bias", None)

    def forward(self, input):
        # Same forward functionality as QKVParallelLinear, but doing the q/k/v
        # projections separately.
        q_bias = self.q_bias if not self.skip_bias_add else None
        k_bias = self.k_bias if not self.skip_bias_add else None
        v_bias = self.v_bias if not self.skip_bias_add else None
        q_proj = F.linear(input, self.q_weight, q_bias)
        k_proj = F.linear(input, self.k_weight, k_bias)
        v_proj = F.linear(input, self.v_weight, v_bias)
        # The q/k/v projections will be split outside of the QKVParallelLinear.
        # Because we are replacing QKVParallelLinear with XlaQKVParallelLinear,
        # we need to concatenate the q, k, and v projections to match the
        # output shape of the QKVParallelLinear implementation, even though it
        # seems redundant.
        # The concat and the following split will be no-ops, and should be
        # optimized away by the compiler.
        qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
        output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \
            self.skip_bias_add else None
        if not self.return_bias:
            return qkv_proj
        return qkv_proj, output_bias


def partition_column_parallel_linear(layer: torch.nn.Module,
                                     mesh: xs.Mesh) -> torch.nn.Module:
    assert isinstance(layer, ColumnParallelLinear)
    xs.mark_sharding(layer.weight, mesh, ('x', None))
    logger.debug("Applied column-parallel sharding to %s", layer)
    return layer


def partition_row_parallel_linear(layer: torch.nn.Module,
                                  mesh: xs.Mesh) -> torch.nn.Module:
    assert isinstance(layer, RowParallelLinear)
    xs.mark_sharding(layer.weight, mesh, (None, 'x'))
    logger.debug("Applied row-parallel sharding to %s", layer)
    return layer


def partition_qkv_parallel_linear(layer: torch.nn.Module,
                                  mesh: xs.Mesh) -> torch.nn.Module:
    assert isinstance(layer, QKVParallelLinear)
    xla_layer = XlaQKVParallelLinear(layer, mesh)
    logger.debug("Applied qkv parallel sharding to %s", layer)
    return xla_layer


MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
    ("QKVParallelLinear", partition_qkv_parallel_linear),
    ("ColumnParallelLinear", partition_column_parallel_linear),
    ("RowParallelLinear", partition_row_parallel_linear),
])


def get_fqn(module):
    # Get the fully qualified name of the module.
    return module.__class__.__qualname__


def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
    """
    Recursively check a PyTorch model and apply appropriate sharding based on
    the MODULE_TYPE_TO_WRAPPING_FUNC mapping.

    Args:
        model: torch.nn.Module to process
        mesh: An XLA SPMD mesh object used for sharding
    """

    def _process_module(module, name=None, parent=None):
        for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
            if get_fqn(module) == module_type:
                wrapped_module = wrapping_func(module, mesh)

                assert parent is not None and name is not None, (
                    "Top-level module is not expected to be wrapped.")
                if wrapped_module is not module:
                    # The wrapped module and the original module are different
                    # Python objects. The original module should be replaced
                    # by the wrapped_module.
                    logger.debug("replace %s with %s", module, wrapped_module)
                    setattr(parent, name, wrapped_module)

                module = wrapped_module
                break

        for child_name, child_module in list(module.named_children()):
            _process_module(child_module, child_name, module)

    _process_module(model)
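A brief usage sketch (not part of the diff) of how shard_model is meant to be driven, with the mesh construction borrowed from the new tests; the `model` variable is assumed to be a vLLM model initialized with full, unsharded weights on CPU:

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from vllm.distributed.tpu_distributed_utils import shard_model

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))

model = model.to('xla')   # move the fully materialized weights to the XLA device
shard_model(model, mesh)  # swaps QKVParallelLinear for XlaQKVParallelLinear and
                          # marks column/row shardings over the 'x' mesh axis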
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
    VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
    VLLM_XLA_USE_SPMD: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
    VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
    VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -513,6 +514,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # If set, assert on XLA recompilation after each execution step.
    "VLLM_XLA_CHECK_RECOMPILATION":
    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),

    # Enable SPMD mode for the TPU backend.
    "VLLM_XLA_USE_SPMD":
    lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
    "VLLM_FUSED_MOE_CHUNK_SIZE":
    lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
vllm/model_executor/model_loader/tpu.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model, process_weights_after_loading, set_default_torch_dtype)

logger = init_logger(__name__)


class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: Optional[xs.Mesh] = None,
    ) -> nn.Module:
        # Initialize the model and load weights on CPU. Then, during SPMD
        # partitioning, weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device('cpu')
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                weights_to_load = {
                    name
                    for name, _ in model.named_parameters()
                }
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights -
                    self.counter_before_loading_weights)
                # We only enable the strict check for non-quantized models
                # that currently track loaded weights.
                if model_config.quantization is None and \
                        loaded_weights is not None:
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}")
            else:
                logger.info("Use dummy weight during weight loading.")

            process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to('xla')
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info("Partition model took %.2f seconds",
                    counter_after_partition - counter_before_partition)

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to run torch.compile after model sharding is done, because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = \
                torch.compile(model.language_model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                               model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on the XLA device.
        2. Non-SPMD-friendly layers are replaced as expected.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, f"Parameter {name} is on \
                {param.device.type} instead of {device_type}"

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, \
                f"Buffer {name} is on {buffer.device.type} instead of \
                {device_type}"

        for module in model.modules():
            if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
                raise AssertionError("QKVParallelLinear should be replaced by \
                    XlaQKVParallelLinear under SPMD mode.")
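For reference, the new weight-loading test exercises this loader roughly as follows (a condensed usage sketch, not additional code in the diff; `mesh` is the xs.Mesh over all local chips built in _get_spmd_mesh()):

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader

engine_args = EngineArgs(model="Qwen/Qwen2-1.5B-Instruct")
vllm_config = engine_args.create_engine_config()

loader = TPUModelLoader(load_config=vllm_config.load_config)
# Loads full weights on CPU, moves them to XLA, shards via GSPMD, then compiles.
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)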
@@ -49,7 +49,9 @@ def _make_synced_weight_loader(original_weight_loader):

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        torch._sync(param)
        # torch._sync doesn't support CPU tensors, and is not needed for them.
        if param.device != torch.device("cpu"):
            torch._sync(param)

    return _synced_weight_loader
@@ -7,21 +7,22 @@ from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                    PlaceholderRange)
@@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: Optional[ParallelConfig] = None,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
@@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
@@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD related
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            mesh_shape = (num_devices, 1)
            device_ids = np.array(range(num_devices))
            self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))

        self.enforce_eager = model_config.enforce_eager

        self.num_xla_graphs = 0
@@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                max_num_mm_items_decoder_budget)
            self.max_num_mm_items_by_modality[modality] = max_num_mm_items

        if not self.use_spmd:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False)
        else:
            self.sample_from_logits_func = self.sample_from_logits

    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
@@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            logits = self.structured_decode(require_struct_decoding,
                                            grammar_bitmask_padded, logits,
                                            arange)
        selected_token_ids = self.sample_from_logits(logits,
                                                     tpu_sampling_metadata)

        selected_token_ids = self.sample_from_logits_func(
            logits, tpu_sampling_metadata)
        # NOTE (NickLucche) Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs. We can't enforce it due
        # to recompilations outside torch.compiled code, so just make sure
@@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            # model = get_model(vllm_config=self.vllm_config)
            model_loader = get_model_loader(self.load_config)
            if not hasattr(self, "model"):
                logger.info("Loading model from scratch...")
                model = model_loader.load_model(vllm_config=self.vllm_config,
                                                model_config=self.model_config)
            if self.use_spmd:
                tpu_loader = TPUModelLoader(
                    load_config=self.vllm_config.load_config)
                model = tpu_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.vllm_config.model_config,
                    mesh=self.mesh)
            else:
                logger.info(
                    "Model was already initialized. Loading weights inplace..."
                )
                model_loader.load_weights(self.model,
                                          model_config=self.model_config)
                # model = get_model(vllm_config=self.vllm_config)
                model_loader = get_model_loader(self.load_config)
                if not hasattr(self, "model"):
                    logger.info("Loading model from scratch...")
                    model = model_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.model_config)
                else:
                    logger.info("Model was already initialized. \
                        Loading weights inplace...")
                    model_loader.load_weights(self.model,
                                              model_config=self.model_config)
        if self.lora_config is not None:
            model = self.load_lora_model(model, self.model_config,
                                         self.scheduler_config,
@@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                     device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32,
                                    device=self.device)
                                    dtype=torch.int32).to(self.device)
            inputs_embeds = None
        actual_num_reqs = min(num_tokens, self.max_num_reqs)
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=self.device)
                                   dtype=torch.int32).to(self.device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=self.device)
                                   dtype=torch.int64).to(self.device)
        block_tables = torch.zeros(
            (self.max_num_reqs, self.block_table_cpu.shape[1]),
            dtype=torch.int32,
            device=self.device)
            dtype=torch.int32).to(self.device)
        query_lens = [1] * self.max_num_reqs
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
        context_lens = torch.ones((self.max_num_reqs, ),
                                  dtype=torch.int32,
                                  device=self.device)
                                  dtype=torch.int32).to(self.device)
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32,
                                device=self.device)
                                dtype=torch.int32).to(self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
@@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs],
                                               dtype=np.int32)):
                self.sample_from_logits(dummy_logits, sampling_metadata)
                self.sample_from_logits_func(dummy_logits,
                                             sampling_metadata)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
@@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                if isinstance(kv_cache_spec, AttentionSpec):
                    if self.use_spmd:
                        num_kv_heads = kv_cache_spec.num_kv_heads
                        assert self.original_parallel_config is not None
                        tp_size = \
                            self.original_parallel_config.tensor_parallel_size
                        # TODO: Handle kv cache duplication under SPMD mode.
                        assert num_kv_heads % tp_size == 0, (
                            f"num_kv_heads {num_kv_heads} must be divisible by "
                            f"tp_size {tp_size} under SPMD mode")
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape,
                                               dtype=dtype,
                                               device=self.device)
                                               dtype=dtype).to(self.device)

                    kv_caches[layer_name] = tpu_kv_cache
                else:
@@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

        if self.use_spmd:
            # Shard the KV cache.
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))

    def reset_dynamo_cache(self):
        if self.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
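A short aside on what the partition spec above expresses: a generic sketch of the torch_xla SPMD annotation, not code from this diff, and the tensor shape is purely illustrative.

import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n)), (n, 1), ('x', 'y'))

# Each entry of the spec names the mesh axis that tensor dimension is split
# over; None means the dimension is replicated on every chip. With the spec
# (None, 'x', None, None), dimension 1 of the tensor is sharded n ways.
t = torch.zeros(16, 8, 32, 128).to('xla')  # illustrative 4-D tensor
xs.mark_sharding(t, mesh, (None, 'x', None, None))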
@@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                          sample_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model.compute_logits(sample_hidden_states, None)

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    # TODO: Under SPMD mode, sample_from_logits has a correctness issue.
    # Re-enable torch.compile once the issue is fixed in torch_xla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
@@ -45,6 +45,15 @@ class TPUWorker:
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, the distributed env is initialized as if there
            # is only one worker/device.
            self.original_parallel_config = self.parallel_config
            self.parallel_config.tensor_parallel_size = 1
            self.parallel_config.pipeline_parallel_size = 1
            self.parallel_config.world_size = 1
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
@@ -95,10 +104,9 @@ class TPUWorker:
        torch.set_default_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        init_tpu_worker_distributed_environment(self.parallel_config,
                                                self.rank,
                                                self.distributed_init_method,
                                                self.local_rank)
        self._init_tpu_worker_distributed_environment(
            self.parallel_config, self.rank, self.distributed_init_method,
            self.local_rank)

        # Device initialization should happen after initializing
        # the distributed runtime.
@@ -132,7 +140,9 @@ class TPUWorker:
            xr.initialize_cache(per_rank_path, readonly=False)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = TPUModelRunner(self.vllm_config, self.device)
        self.model_runner = \
            TPUModelRunner(self.vllm_config, self.device,
                           self.original_parallel_config)

        if rank == 0:
            # If usage stat is enabled, collect relevant info.
@@ -147,9 +157,7 @@ class TPUWorker:

                # Use an empty tensor instead of `None` to force Dynamo to pass
                # it by reference, rather than by specializing on the value `None`.
                tpu_kv_cache = torch.tensor([],
                                            dtype=dtype,
                                            device=self.device)
                tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
                kv_caches[layer_name] = tpu_kv_cache
            else:
                raise NotImplementedError(
@@ -178,9 +186,20 @@ class TPUWorker:

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        current_mem = m["bytes_used"]
        if self.use_spmd:
            # This is a workaround for the TPU SPMD mode. The get_memory_info
            # API doesn't work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it's supported in
            # PyTorch/XLA.
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            total_memory_size = device_usage[0].total_memory
            current_mem = device_usage[0].memory_usage
        else:
            m = xm.get_memory_info(self.device)
            total_memory_size = m["bytes_limit"]
            current_mem = m["bytes_used"]
        # Ideally we would use profiled = m["peak_bytes_used"] to
        # get weights + activations. But there is memory used during
        # compilation / weight loading that impacts the peak and
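For illustration, the same tpu_info calls can be used standalone to inspect per-chip HBM usage (a sketch that reuses only the calls appearing in this diff; the loop and print formatting are illustrative):

import tpu_info

chip_type, _ = tpu_info.device.get_local_chips()
for i, usage in enumerate(tpu_info.metrics.get_chip_usage(chip_type)):
    # total_memory / memory_usage are the same fields the worker reads above.
    print(f"chip {i}: {usage.memory_usage} / {usage.total_memory} bytes used")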
@@ -241,28 +260,30 @@ class TPUWorker:
        # worker will always be healthy as long as it's running.
        return


def init_tpu_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""

    # NOTE(woosuk): This is just to initialize the TP group and broadcast
    # the input objects on CPU. The all-reduce and all-gather ops on TPU
    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
    # own context.
    init_distributed_environment(
        world_size=parallel_config.world_size,
        rank=rank,
        local_rank=local_rank,
        distributed_init_method=distributed_init_method,
        backend="gloo",
    )
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
    def _init_tpu_worker_distributed_environment(
        self,
        parallel_config: ParallelConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
try: