mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 05:15:01 +08:00
[Hardware][TPU] Implement tensor parallelism with Ray (#5871)
This commit is contained in:
parent
14dbd5a767
commit
52f07e3dec
@ -4,4 +4,5 @@
|
|||||||
# Dependencies for TPU
|
# Dependencies for TPU
|
||||||
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
|
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
|
||||||
# You can install the dependencies in Dockerfile.tpu.
|
# You can install the dependencies in Dockerfile.tpu.
|
||||||
|
ray
|
||||||
triton # To avoid import errors
|
triton # To avoid import errors
|
||||||
|
|||||||
@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata):
|
|||||||
|
|
||||||
# Currently, input sequences can only contain all prefills
|
# Currently, input sequences can only contain all prefills
|
||||||
# or all decoding.
|
# or all decoding.
|
||||||
block_tables: Optional[torch.Tensor]
|
block_tables: Optional[torch.Tensor] = None
|
||||||
context_lens: Optional[torch.Tensor]
|
context_lens: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
||||||
|
|||||||
@ -394,8 +394,14 @@ class LLMEngine:
|
|||||||
from vllm.executor.neuron_executor import NeuronExecutor
|
from vllm.executor.neuron_executor import NeuronExecutor
|
||||||
executor_class = NeuronExecutor
|
executor_class = NeuronExecutor
|
||||||
elif engine_config.device_config.device_type == "tpu":
|
elif engine_config.device_config.device_type == "tpu":
|
||||||
from vllm.executor.tpu_executor import TPUExecutor
|
if distributed_executor_backend == "ray":
|
||||||
executor_class = TPUExecutor
|
initialize_ray_cluster(engine_config.parallel_config)
|
||||||
|
from vllm.executor.ray_tpu_executor import RayTPUExecutor
|
||||||
|
executor_class = RayTPUExecutor
|
||||||
|
else:
|
||||||
|
assert distributed_executor_backend is None
|
||||||
|
from vllm.executor.tpu_executor import TPUExecutor
|
||||||
|
executor_class = TPUExecutor
|
||||||
elif engine_config.device_config.device_type == "cpu":
|
elif engine_config.device_config.device_type == "cpu":
|
||||||
from vllm.executor.cpu_executor import CPUExecutor
|
from vllm.executor.cpu_executor import CPUExecutor
|
||||||
executor_class = CPUExecutor
|
executor_class = CPUExecutor
|
||||||
|
|||||||
313
vllm/executor/ray_tpu_executor.py
Normal file
313
vllm/executor/ray_tpu_executor.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import islice, repeat
|
||||||
|
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
|
||||||
|
Union)
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||||
|
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||||
|
from vllm.executor.tpu_executor import TPUExecutor
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||||
|
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||||
|
get_vllm_instance_id, make_async)
|
||||||
|
|
||||||
|
if ray is not None:
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RayTPUExecutor(TPUExecutor):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# This is non-None when the execute model loop is running
|
||||||
|
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||||
|
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||||
|
# Updated by implementations that require additional args to be passed
|
||||||
|
# to the _run_workers execute_model call
|
||||||
|
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _init_executor(self) -> None:
|
||||||
|
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||||
|
placement_group = self.parallel_config.placement_group
|
||||||
|
|
||||||
|
# Disable Ray usage stats collection.
|
||||||
|
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||||
|
if ray_usage != "1":
|
||||||
|
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||||
|
|
||||||
|
# Create the parallel TPU workers.
|
||||||
|
self._init_workers_ray(placement_group)
|
||||||
|
|
||||||
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
|
**ray_remote_kwargs):
|
||||||
|
# The driver dummy worker does not actually use any resources.
|
||||||
|
# It holds the resource for the driver worker.
|
||||||
|
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||||
|
# The remaining workers are the actual ray actors.
|
||||||
|
self.workers: List[RayWorkerWrapper] = []
|
||||||
|
|
||||||
|
# Create the workers.
|
||||||
|
driver_ip = get_ip()
|
||||||
|
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||||
|
if not bundle.get("TPU", 0):
|
||||||
|
continue
|
||||||
|
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=placement_group,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
placement_group_bundle_index=bundle_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert self.speculative_config is None
|
||||||
|
worker_module_name = "vllm.worker.tpu_worker"
|
||||||
|
worker_class_name = "TPUWorker"
|
||||||
|
|
||||||
|
worker = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
resources={"TPU": 1},
|
||||||
|
scheduling_strategy=scheduling_strategy,
|
||||||
|
**ray_remote_kwargs,
|
||||||
|
)(RayWorkerWrapper).remote(
|
||||||
|
worker_module_name=worker_module_name,
|
||||||
|
worker_class_name=worker_class_name,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||||
|
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||||
|
# If the worker is on the same node as the driver, we use it
|
||||||
|
# as the resource holder for the driver process.
|
||||||
|
self.driver_dummy_worker = worker
|
||||||
|
self.driver_worker = RayWorkerWrapper(
|
||||||
|
worker_module_name=worker_module_name,
|
||||||
|
worker_class_name=worker_class_name,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Else, added to the list of workers.
|
||||||
|
self.workers.append(worker)
|
||||||
|
|
||||||
|
if self.driver_dummy_worker is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Ray does not allocate any TPUs on the driver node. Consider "
|
||||||
|
"adjusting the Ray placement group or running the driver on a "
|
||||||
|
"TPU node.")
|
||||||
|
|
||||||
|
# Get the set of TPU IDs used on each node.
|
||||||
|
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
|
||||||
|
use_dummy_driver=True)
|
||||||
|
|
||||||
|
node_workers = defaultdict(list)
|
||||||
|
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
|
||||||
|
node_workers[node_id].append(i)
|
||||||
|
|
||||||
|
VLLM_INSTANCE_ID = get_vllm_instance_id()
|
||||||
|
|
||||||
|
# Set environment variables for the driver and workers.
|
||||||
|
all_args_to_update_environment_variables = [({
|
||||||
|
"VLLM_INSTANCE_ID":
|
||||||
|
VLLM_INSTANCE_ID,
|
||||||
|
"VLLM_TRACE_FUNCTION":
|
||||||
|
str(envs.VLLM_TRACE_FUNCTION),
|
||||||
|
}, ) for _ in worker_node_and_gpu_ids]
|
||||||
|
self._run_workers("update_environment_variables",
|
||||||
|
all_args=all_args_to_update_environment_variables)
|
||||||
|
|
||||||
|
if len(node_workers) == 1:
|
||||||
|
# in single node case, we don't need to get the IP address.
|
||||||
|
# the loopback address is sufficient
|
||||||
|
# NOTE: a node may have several IP addresses, one for each
|
||||||
|
# network interface. `get_ip()` might return any of them,
|
||||||
|
# while they might not work for communication inside the node
|
||||||
|
# if the network setup is complicated. Using the loopback address
|
||||||
|
# solves this issue, as it always works for communication inside
|
||||||
|
# the node.
|
||||||
|
driver_ip = "127.0.0.1"
|
||||||
|
distributed_init_method = get_distributed_init_method(
|
||||||
|
driver_ip, get_open_port())
|
||||||
|
|
||||||
|
# Initialize the actual workers inside worker wrapper.
|
||||||
|
init_worker_all_kwargs = [
|
||||||
|
self._get_worker_kwargs(
|
||||||
|
local_rank=node_workers[node_id].index(rank),
|
||||||
|
rank=rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
|
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
|
||||||
|
]
|
||||||
|
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||||
|
|
||||||
|
self._run_workers("init_device")
|
||||||
|
self._run_workers("load_model",
|
||||||
|
max_concurrent_workers=self.parallel_config.
|
||||||
|
max_parallel_loading_workers)
|
||||||
|
|
||||||
|
def _driver_execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Run execute_model in the driver worker.
|
||||||
|
|
||||||
|
Passing None will cause the driver to stop the model execution
|
||||||
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
return self.driver_worker.execute_method("execute_model",
|
||||||
|
execute_model_req)
|
||||||
|
|
||||||
|
def _run_workers(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*args,
|
||||||
|
async_run_remote_workers_only: bool = False,
|
||||||
|
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||||
|
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
use_dummy_driver: bool = False,
|
||||||
|
max_concurrent_workers: Optional[int] = None,
|
||||||
|
use_ray_compiled_dag: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Runs the given method on all workers. Can be used in the following
|
||||||
|
ways:
|
||||||
|
|
||||||
|
- async_run_remote_workers_only: If True the method will be run only
|
||||||
|
in the remote workers, not the driver worker. It will also be
|
||||||
|
run asynchronously and return a list of futures rather than blocking
|
||||||
|
on the results.
|
||||||
|
- args/kwargs: All workers share the same args/kwargs
|
||||||
|
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||||
|
individually
|
||||||
|
"""
|
||||||
|
|
||||||
|
if max_concurrent_workers:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"max_concurrent_workers is not supported yet.")
|
||||||
|
|
||||||
|
count = len(self.workers)
|
||||||
|
all_worker_args = repeat(args, count) if all_args is None \
|
||||||
|
else islice(all_args, 1, None)
|
||||||
|
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||||
|
else islice(all_kwargs, 1, None)
|
||||||
|
|
||||||
|
# Start the ray workers first.
|
||||||
|
ray_worker_outputs = [
|
||||||
|
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||||
|
for (worker, worker_args, worker_kwargs
|
||||||
|
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||||
|
]
|
||||||
|
|
||||||
|
if async_run_remote_workers_only:
|
||||||
|
# Just return futures
|
||||||
|
return ray_worker_outputs
|
||||||
|
|
||||||
|
driver_args = args if all_args is None else all_args[0]
|
||||||
|
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||||
|
|
||||||
|
# Start the driver worker after all the ray workers.
|
||||||
|
if not use_dummy_driver:
|
||||||
|
driver_worker_output = self.driver_worker.execute_method(
|
||||||
|
method, *driver_args, **driver_kwargs)
|
||||||
|
else:
|
||||||
|
assert self.driver_dummy_worker is not None
|
||||||
|
driver_worker_output = ray.get(
|
||||||
|
self.driver_dummy_worker.execute_method.remote(
|
||||||
|
method, *driver_args, **driver_kwargs))
|
||||||
|
# Get the results of the ray workers.
|
||||||
|
if self.workers:
|
||||||
|
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||||
|
|
||||||
|
return [driver_worker_output] + ray_worker_outputs
|
||||||
|
|
||||||
|
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||||
|
"""Wait for futures returned from _run_workers() with
|
||||||
|
async_run_remote_workers_only to complete."""
|
||||||
|
ray.get(parallel_worker_tasks)
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
|
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||||
|
num_tpu_blocks = min(b[0] for b in num_blocks)
|
||||||
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
return num_tpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
|
||||||
|
num_cpu_blocks)
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
self._run_workers("initialize_cache",
|
||||||
|
num_gpu_blocks=num_gpu_blocks,
|
||||||
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: ExecuteModelRequest,
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
self.parallel_worker_tasks = self._run_workers(
|
||||||
|
"start_worker_execution_loop",
|
||||||
|
async_run_remote_workers_only=True,
|
||||||
|
**self.extra_execute_model_run_workers_kwargs)
|
||||||
|
|
||||||
|
# Only the driver worker returns the sampling results.
|
||||||
|
return self._driver_execute_model(execute_model_req)
|
||||||
|
|
||||||
|
def stop_remote_worker_execution_loop(self) -> None:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._driver_execute_model()
|
||||||
|
parallel_worker_tasks = self.parallel_worker_tasks
|
||||||
|
self.parallel_worker_tasks = None
|
||||||
|
# Ensure that workers exit model loop cleanly
|
||||||
|
# (this will raise otherwise)
|
||||||
|
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||||
|
|
||||||
|
|
||||||
|
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||||
|
|
||||||
|
async def execute_model_async(
|
||||||
|
self,
|
||||||
|
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
# Start model execution loop running in the parallel workers
|
||||||
|
self.parallel_worker_tasks = asyncio.create_task(
|
||||||
|
self._start_worker_execution_loop())
|
||||||
|
|
||||||
|
# Only the driver worker returns the sampling results.
|
||||||
|
return await self._driver_execute_model_async(execute_model_req)
|
||||||
|
|
||||||
|
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._driver_execute_model_async()
|
||||||
|
parallel_worker_tasks = self.parallel_worker_tasks
|
||||||
|
self.parallel_worker_tasks = None
|
||||||
|
# Ensure that workers exit model loop cleanly
|
||||||
|
# (this will raise otherwise)
|
||||||
|
await parallel_worker_tasks
|
||||||
|
|
||||||
|
async def _driver_execute_model_async(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
return await self.driver_exec_method("execute_model",
|
||||||
|
execute_model_req)
|
||||||
|
|
||||||
|
async def _start_worker_execution_loop(self):
|
||||||
|
coros = [
|
||||||
|
worker.execute_method.remote("start_worker_execution_loop")
|
||||||
|
for worker in self.workers
|
||||||
|
]
|
||||||
|
return await asyncio.gather(*coros)
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -45,6 +46,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
|
|||||||
num_samples: int
|
num_samples: int
|
||||||
best_of: List[int]
|
best_of: List[int]
|
||||||
seq_groups: List[List[int]]
|
seq_groups: List[List[int]]
|
||||||
|
virtual_engine: int = 0
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(
|
def as_broadcastable_tensor_dict(
|
||||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||||
@ -55,6 +57,9 @@ class ModelInputForTPU(ModelRunnerInputBase):
|
|||||||
"t": self.t,
|
"t": self.t,
|
||||||
"p": self.p,
|
"p": self.p,
|
||||||
"num_samples": self.num_samples,
|
"num_samples": self.num_samples,
|
||||||
|
"best_of": self.best_of,
|
||||||
|
"seq_groups": self.seq_groups,
|
||||||
|
"virtual_engine": self.virtual_engine,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
return tensor_dict
|
return tensor_dict
|
||||||
@ -113,16 +118,30 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
self.device = self.device_config.device
|
self.device = self.device_config.device
|
||||||
|
|
||||||
model = get_model(
|
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
|
||||||
model_config=self.model_config,
|
# process, the ranks can be different from the ranks internally assigned
|
||||||
load_config=self.load_config,
|
# by the xm runtime. Therefore, there is a mismatch in the rank
|
||||||
device_config=self.device_config,
|
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
|
||||||
parallel_config=self.parallel_config,
|
# This is not a problem in linear layers because all-reduce is
|
||||||
cache_config=self.cache_config,
|
# rank-agnostic. However, it matters for all-gather as the ranks
|
||||||
scheduler_config=self.scheduler_config,
|
# determine the order of concatenating the output tensors.
|
||||||
multimodal_config=self.multimodal_config,
|
# As a workaround, we use the xm's rank assignment only when loading
|
||||||
lora_config=None,
|
# the embedding weights.
|
||||||
)
|
xm_tp_rank = xm.get_ordinal()
|
||||||
|
with patch(
|
||||||
|
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||||
|
"get_tensor_model_parallel_rank",
|
||||||
|
return_value=xm_tp_rank):
|
||||||
|
model = get_model(
|
||||||
|
model_config=self.model_config,
|
||||||
|
load_config=self.load_config,
|
||||||
|
device_config=self.device_config,
|
||||||
|
parallel_config=self.parallel_config,
|
||||||
|
cache_config=self.cache_config,
|
||||||
|
scheduler_config=self.scheduler_config,
|
||||||
|
multimodal_config=self.multimodal_config,
|
||||||
|
lora_config=None,
|
||||||
|
)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
|
|
||||||
@ -463,10 +482,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
tensor_dict, attn_backend=self.attn_backend)
|
tensor_dict, attn_backend=self.attn_backend)
|
||||||
return model_input
|
return model_input
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
model_input: ModelInputForTPU,
|
model_input: ModelInputForTPU,
|
||||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_caches: Optional[List[Any]],
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
num_steps: int = 1,
|
num_steps: int = 1,
|
||||||
) -> List[SamplerOutput]:
|
) -> List[SamplerOutput]:
|
||||||
|
|||||||
@ -70,13 +70,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
|||||||
|
|
||||||
def init_device(self) -> None:
|
def init_device(self) -> None:
|
||||||
os.environ["PJRT_DEVICE"] = "TPU"
|
os.environ["PJRT_DEVICE"] = "TPU"
|
||||||
self.device = xm.xla_device()
|
|
||||||
self.device_config.device = self.device
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_default_dtype(self.model_config.dtype)
|
torch.set_default_dtype(self.model_config.dtype)
|
||||||
|
|
||||||
# NOTE(woosuk): This is just a hack to initialize the TP group.
|
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||||
# This cannot perform the actual communication ops.
|
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||||
|
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||||
|
# own context.
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
world_size=self.parallel_config.world_size,
|
world_size=self.parallel_config.world_size,
|
||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
@ -88,6 +88,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
|||||||
self.parallel_config.tensor_parallel_size,
|
self.parallel_config.tensor_parallel_size,
|
||||||
self.parallel_config.pipeline_parallel_size)
|
self.parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
# Device initialization should happen after initializing the distributed
|
||||||
|
# runtime.
|
||||||
|
self.device = xm.xla_device()
|
||||||
|
self.device_config.device = self.device
|
||||||
|
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
xm.set_rng_state(self.model_config.seed, self.device)
|
xm.set_rng_state(self.model_config.seed, self.device)
|
||||||
@ -200,8 +205,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def do_metadata_broadcast(self) -> bool:
|
def do_metadata_broadcast(self) -> bool:
|
||||||
# TODO(woosuk): Support TP.
|
return self.parallel_config.tensor_parallel_size > 1
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user