mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 07:02:14 +08:00
[Hardware][TPU] Support parallel sampling & Swapping (#5855)
This commit is contained in:
parent
c54269d967
commit
cbc53b6b8d
@ -28,21 +28,35 @@ class PallasAttentionBackend(AttentionBackend):
|
|||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||||
|
|
||||||
|
@torch.compile(backend="openxla")
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
src_kv_cache: torch.Tensor,
|
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||||
dst_kv_cache: torch.Tensor,
|
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||||
src_to_dst: Dict[int, int],
|
src_to_dst: Tuple[torch.Tensor, torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError("swap_blocks is not implemented.")
|
src_k_cache, src_v_cache = src_kv_cache
|
||||||
|
dst_k_cache, dst_v_cache = dst_kv_cache
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
|
||||||
|
|
||||||
|
device = dst_k_cache.device
|
||||||
|
src_indices, dst_indices = src_to_dst
|
||||||
|
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
|
||||||
|
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
|
||||||
|
|
||||||
|
@torch.compile(backend="openxla")
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_blocks(
|
def copy_blocks(
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
src_to_dists: Dict[int, List[int]],
|
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO(woosuk): Implement this.
|
src_indices, dst_indices = src_to_dists
|
||||||
raise NotImplementedError("copy_blocks is not implemented.")
|
for k_cache, v_cache in kv_caches:
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
|
||||||
|
k_cache[:, dst_indices] = k_cache[:, src_indices]
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
|
||||||
|
v_cache[:, dst_indices] = v_cache[:, src_indices]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -22,6 +22,9 @@ logger = init_logger(__name__)
|
|||||||
_PAD_SLOT_ID = 0 # FIXME(woosuk)
|
_PAD_SLOT_ID = 0 # FIXME(woosuk)
|
||||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||||
_ENABLE_TOP_P = False
|
_ENABLE_TOP_P = False
|
||||||
|
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||||
|
# This can significantly affect the performance if too large.
|
||||||
|
_MAX_NUM_SAMPLES = 128
|
||||||
|
|
||||||
|
|
||||||
class TPUModelRunner:
|
class TPUModelRunner:
|
||||||
@ -143,8 +146,9 @@ class TPUModelRunner:
|
|||||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||||
|
|
||||||
# Dummy run.
|
# Dummy run.
|
||||||
|
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||||
self.model(token_ids, position_ids, kv_caches, attn_metadata,
|
self.model(token_ids, position_ids, kv_caches, attn_metadata,
|
||||||
input_lens, t, p)
|
input_lens, t, p, num_samples)
|
||||||
|
|
||||||
def warmup_model(
|
def warmup_model(
|
||||||
self,
|
self,
|
||||||
@ -268,14 +272,11 @@ class TPUModelRunner:
|
|||||||
input_positions: List[List[int]] = []
|
input_positions: List[List[int]] = []
|
||||||
slot_mapping: List[List[int]] = []
|
slot_mapping: List[List[int]] = []
|
||||||
context_lens: List[int] = []
|
context_lens: List[int] = []
|
||||||
num_seq_groups = len(seq_group_metadata_list)
|
|
||||||
batch_size = _get_padded_batch_size(num_seq_groups)
|
|
||||||
|
|
||||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
batch_idx = 0
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
assert not seq_group_metadata.is_prompt
|
assert not seq_group_metadata.is_prompt
|
||||||
|
|
||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
|
||||||
for seq_id in seq_ids:
|
for seq_id in seq_ids:
|
||||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||||
generation_token = seq_data.get_last_token_id()
|
generation_token = seq_data.get_last_token_id()
|
||||||
@ -288,14 +289,16 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
assert seq_group_metadata.block_tables is not None
|
assert seq_group_metadata.block_tables is not None
|
||||||
block_table = seq_group_metadata.block_tables[seq_id]
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
self.block_tables[i, :len(block_table)] = block_table
|
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||||
|
batch_idx += 1
|
||||||
|
|
||||||
block_number = block_table[position // self.block_size]
|
block_number = block_table[position // self.block_size]
|
||||||
block_offset = position % self.block_size
|
block_offset = position % self.block_size
|
||||||
slot = block_number * self.block_size + block_offset
|
slot = block_number * self.block_size + block_offset
|
||||||
slot_mapping.append([slot])
|
slot_mapping.append([slot])
|
||||||
|
|
||||||
num_paddings = batch_size - num_seq_groups
|
batch_size = _get_padded_batch_size(batch_idx)
|
||||||
|
num_paddings = batch_size - batch_idx
|
||||||
input_tokens = input_tokens + [[0]] * num_paddings
|
input_tokens = input_tokens + [[0]] * num_paddings
|
||||||
input_positions = input_positions + [[0]] * num_paddings
|
input_positions = input_positions + [[0]] * num_paddings
|
||||||
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||||
@ -333,14 +336,13 @@ class TPUModelRunner:
|
|||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
padded_batch_size: int,
|
padded_batch_size: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
|
||||||
assert len(seq_group_metadata_list) > 0
|
assert len(seq_group_metadata_list) > 0
|
||||||
t = []
|
t = []
|
||||||
p = []
|
p = []
|
||||||
|
best_of = []
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
assert seq_group_metadata.sampling_params is not None
|
|
||||||
sampling_params = seq_group_metadata.sampling_params
|
sampling_params = seq_group_metadata.sampling_params
|
||||||
|
|
||||||
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
|
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
|
||||||
# low temperature. This is not accurate.
|
# low temperature. This is not accurate.
|
||||||
t.append(sampling_params.temperature
|
t.append(sampling_params.temperature
|
||||||
@ -354,10 +356,11 @@ class TPUModelRunner:
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Top-k sampling is currently disabled for the TPU backend "
|
"Top-k sampling is currently disabled for the TPU backend "
|
||||||
"due to performance issues.")
|
"due to performance issues.")
|
||||||
if sampling_params.best_of > 1:
|
if sampling_params.best_of > _MAX_NUM_SAMPLES:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"best_of > 1 is not currently supported by the TPU "
|
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
|
||||||
"backend.")
|
"backend.")
|
||||||
|
best_of.append(sampling_params.best_of)
|
||||||
if sampling_params.use_beam_search:
|
if sampling_params.use_beam_search:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Beam search is not supported by the TPU backend.")
|
"Beam search is not supported by the TPU backend.")
|
||||||
@ -369,13 +372,19 @@ class TPUModelRunner:
|
|||||||
"prompt_logprobs is not currently supported by the TPU "
|
"prompt_logprobs is not currently supported by the TPU "
|
||||||
"backend.")
|
"backend.")
|
||||||
|
|
||||||
num_paddings = padded_batch_size - len(seq_group_metadata_list)
|
# Repeat the sampling params if the seq group has multiple seqs.
|
||||||
|
num_seqs = len(seq_group_metadata.seq_data)
|
||||||
|
t += [t[-1]] * (num_seqs - 1)
|
||||||
|
p += [p[-1]] * (num_seqs - 1)
|
||||||
|
best_of += [best_of[-1]] * (num_seqs - 1)
|
||||||
|
|
||||||
|
num_paddings = padded_batch_size - len(t)
|
||||||
t += [1.0] * num_paddings
|
t += [1.0] * num_paddings
|
||||||
p += [1.0] * num_paddings
|
p += [1.0] * num_paddings
|
||||||
|
|
||||||
t = torch.tensor(t, dtype=torch.float32, device=self.device)
|
t = torch.tensor(t, dtype=torch.float32, device=self.device)
|
||||||
p = torch.tensor(p, dtype=torch.float32, device=self.device)
|
p = torch.tensor(p, dtype=torch.float32, device=self.device)
|
||||||
return t, p
|
return t, p, best_of
|
||||||
|
|
||||||
def _execute_model(
|
def _execute_model(
|
||||||
self,
|
self,
|
||||||
@ -392,28 +401,41 @@ class TPUModelRunner:
|
|||||||
else:
|
else:
|
||||||
inputs = self._prepare_decode(seq_group_metadata_list)
|
inputs = self._prepare_decode(seq_group_metadata_list)
|
||||||
padded_batch_size = inputs[0].shape[0]
|
padded_batch_size = inputs[0].shape[0]
|
||||||
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
|
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
|
||||||
|
padded_batch_size)
|
||||||
|
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
|
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
|
||||||
*inputs[2:], t, p)
|
*inputs[2:], t, p, num_samples)
|
||||||
# Retrieve the outputs to CPU.
|
# Retrieve the outputs to CPU.
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
|
||||||
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
||||||
# The TPU backend does not reuse the sampler, since the TPU backend
|
# The TPU backend does not reuse the sampler, since the TPU backend
|
||||||
# does not support the advanced sampling parameters such as logprobs.
|
# does not support the advanced sampling parameters such as logprobs.
|
||||||
i = 0
|
zero_logprob = Logprob(0.0)
|
||||||
|
batch_idx = 0
|
||||||
sampler_outputs = []
|
sampler_outputs = []
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
seq_outputs = []
|
seq_outputs = []
|
||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
for seq_id in seq_ids:
|
if is_prompt:
|
||||||
next_token_id = next_token_ids[i]
|
assert len(seq_ids) == 1
|
||||||
seq_outputs.append(
|
seq_id = seq_ids[0]
|
||||||
SequenceOutput(seq_id, next_token_id,
|
for i in range(best_of[batch_idx]):
|
||||||
{next_token_id: Logprob(0.0)}))
|
next_token_id = next_token_ids[batch_idx][i]
|
||||||
i += 1
|
seq_outputs.append(
|
||||||
|
SequenceOutput(seq_id, next_token_id,
|
||||||
|
{next_token_id: zero_logprob}))
|
||||||
|
batch_idx += 1
|
||||||
|
else:
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
next_token_id = next_token_ids[batch_idx][0]
|
||||||
|
seq_outputs.append(
|
||||||
|
SequenceOutput(seq_id, next_token_id,
|
||||||
|
{next_token_id: zero_logprob}))
|
||||||
|
batch_idx += 1
|
||||||
sampler_outputs.append(
|
sampler_outputs.append(
|
||||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||||
return sampler_outputs
|
return sampler_outputs
|
||||||
@ -458,6 +480,7 @@ class ModelWrapper(nn.Module):
|
|||||||
input_lens: torch.Tensor,
|
input_lens: torch.Tensor,
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
p: torch.Tensor,
|
p: torch.Tensor,
|
||||||
|
num_samples: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Executes the forward pass of the model and samples the next token.
|
"""Executes the forward pass of the model and samples the next token.
|
||||||
|
|
||||||
@ -520,8 +543,9 @@ class ModelWrapper(nn.Module):
|
|||||||
if _ENABLE_TOP_P:
|
if _ENABLE_TOP_P:
|
||||||
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
# FIXME(woosuk): best_of > 1 is not supported.
|
next_token_ids = torch.multinomial(probs,
|
||||||
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
|
num_samples,
|
||||||
|
replacement=True)
|
||||||
return next_token_ids
|
return next_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
@ -117,19 +117,26 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
# Synchronize before measuring the memory usage.
|
# Synchronize before measuring the memory usage.
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
|
|
||||||
|
dtype_btyes = get_dtype_size(self.cache_dtype)
|
||||||
|
block_size = self.cache_config.block_size
|
||||||
|
block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
|
||||||
|
head_size * num_kv_heads)
|
||||||
|
|
||||||
|
# Calculate the TPU KV cache size based on profiling.
|
||||||
m = xm.get_memory_info(self.device)
|
m = xm.get_memory_info(self.device)
|
||||||
total_memory_size = m["bytes_limit"]
|
total_memory_size = m["bytes_limit"]
|
||||||
usable_memory_size = int(total_memory_size *
|
usable_memory_size = int(total_memory_size *
|
||||||
self.cache_config.gpu_memory_utilization)
|
self.cache_config.gpu_memory_utilization)
|
||||||
profiled = m["bytes_used"] # Weights + intermediate activations.
|
profiled = m["bytes_used"] # Weights + intermediate activations.
|
||||||
kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||||
dtype_btyes = get_dtype_size(self.cache_dtype)
|
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
|
||||||
block_size = self.cache_config.block_size
|
|
||||||
num_tpu_blocks = (kv_cache_bytes //
|
|
||||||
(dtype_btyes * block_size * num_layers * 2 *
|
|
||||||
head_size * num_kv_heads))
|
|
||||||
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
||||||
return num_tpu_blocks, 0
|
|
||||||
|
# Calculate the CPU KV cache size based on the config.
|
||||||
|
num_cpu_blocks = (self.cache_config.swap_space_bytes //
|
||||||
|
block_size_bytes)
|
||||||
|
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
|
||||||
|
return num_tpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def initialize_cache(
|
def initialize_cache(
|
||||||
self,
|
self,
|
||||||
@ -145,15 +152,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||||
head_size = self.model_config.get_head_size()
|
head_size = self.model_config.get_head_size()
|
||||||
|
|
||||||
|
self.cpu_cache = []
|
||||||
self.tpu_cache = []
|
self.tpu_cache = []
|
||||||
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||||
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
key_cache = torch.zeros(tpu_cache_shape,
|
tpu_k_cache = torch.zeros(tpu_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
value_cache = torch.zeros_like(key_cache)
|
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||||
self.tpu_cache.append((key_cache, value_cache))
|
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
|
||||||
|
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
|
||||||
|
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
|
||||||
|
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
|
||||||
self._warmup_model()
|
self._warmup_model()
|
||||||
|
|
||||||
def _warmup_model(self) -> None:
|
def _warmup_model(self) -> None:
|
||||||
@ -187,22 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
self._execute_model_non_driver()
|
self._execute_model_non_driver()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
assert execute_model_req is not None
|
assert execute_model_req is not None
|
||||||
# Currently, TPUWorker does not support swapping.
|
# Issue cache operations.
|
||||||
# TODO(woosuk): Support block copying.
|
self.cache_swap(
|
||||||
assert len(execute_model_req.blocks_to_swap_in) == 0, (
|
execute_model_req.blocks_to_swap_in,
|
||||||
"Swapping is not supported for the TPU backend.")
|
execute_model_req.blocks_to_swap_out,
|
||||||
assert len(execute_model_req.blocks_to_swap_out) == 0, (
|
execute_model_req.blocks_to_copy,
|
||||||
"Swapping is not supported for the TPU backend.")
|
)
|
||||||
assert len(execute_model_req.blocks_to_copy) == 0
|
# Run the model.
|
||||||
|
|
||||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||||
assert len(seq_group_metadata_list) > 0
|
assert len(seq_group_metadata_list) > 0
|
||||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||||
self.tpu_cache)
|
self.tpu_cache)
|
||||||
return [output]
|
return [output]
|
||||||
|
|
||||||
|
def cache_swap(
|
||||||
|
self,
|
||||||
|
blocks_to_swap_in: List[Tuple[int, int]],
|
||||||
|
blocks_to_swap_out: List[Tuple[int, int]],
|
||||||
|
blocks_to_copy: List[Tuple[int, int]],
|
||||||
|
) -> None:
|
||||||
|
attn_backend = self.model_runner.attn_backend
|
||||||
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||||
|
|
||||||
|
if blocks_to_swap_in:
|
||||||
|
# Swap from CPU to TPU.
|
||||||
|
src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
|
||||||
|
self.device)
|
||||||
|
for i in range(num_layers):
|
||||||
|
attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
|
||||||
|
src_to_dst)
|
||||||
|
if blocks_to_swap_out:
|
||||||
|
# Swap from TPU to CPU.
|
||||||
|
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
|
||||||
|
"cpu")
|
||||||
|
for i in range(num_layers):
|
||||||
|
attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
|
||||||
|
src_to_dst)
|
||||||
|
if blocks_to_copy:
|
||||||
|
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
|
||||||
|
self.device)
|
||||||
|
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
|
||||||
|
|
||||||
def start_worker_execution_loop(self) -> None:
|
def start_worker_execution_loop(self) -> None:
|
||||||
while self._execute_model_non_driver():
|
while self._execute_model_non_driver():
|
||||||
pass
|
pass
|
||||||
@ -210,3 +247,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
def _execute_model_non_driver(self) -> bool:
|
def _execute_model_non_driver(self) -> bool:
|
||||||
self.model_runner.execute_model(None, self.tpu_cache)
|
self.model_runner.execute_model(None, self.tpu_cache)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _make_src_to_dst(
|
||||||
|
mapping: List[Tuple[int, int]],
|
||||||
|
src_device: Union[torch.device, str],
|
||||||
|
dst_device: Union[torch.device, str],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
src_indices = [i for i, _ in mapping]
|
||||||
|
dst_indices = [i for _, i in mapping]
|
||||||
|
src_indices = torch.tensor(src_indices,
|
||||||
|
device=src_device,
|
||||||
|
dtype=torch.int64)
|
||||||
|
dst_indices = torch.tensor(dst_indices,
|
||||||
|
device=dst_device,
|
||||||
|
dtype=torch.int64)
|
||||||
|
return src_indices, dst_indices
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user