mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 13:44:58 +08:00
[Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824)
This commit is contained in:
parent
08ccee1e83
commit
8afca50889
@ -50,20 +50,15 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
||||
AttentionMetadataPerStage):
|
||||
"""Metadata for TorchSDPABackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
slot_mapping: torch.Tensor
|
||||
prompt_lens: Optional[List[int]]
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
max_subquery_len: Optional[int] = None
|
||||
max_prompt_len: Optional[int] = None
|
||||
subquery_start_loc: Optional[torch.Tensor] = None
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
@ -111,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata[TorchSDPAMetadata],
|
||||
attn_metadata: TorchSDPAMetadata,
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
@ -140,51 +135,36 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
|
||||
if attn_metadata.is_prompt:
|
||||
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=1)
|
||||
|
||||
if prefill_meta.attn_bias is None:
|
||||
if attn_metadata.attn_bias is None:
|
||||
if self.alibi_slopes is not None:
|
||||
att_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
prefill_meta.prompt_lens) # type: ignore
|
||||
attn_metadata.prompt_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
att_masks = _make_sliding_window_bias(
|
||||
prefill_meta.prompt_lens, self.sliding_window,
|
||||
attn_metadata.prompt_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
att_masks = [None] * len(prefill_meta.prompt_lens)
|
||||
prefill_meta.attn_bias = att_masks
|
||||
att_masks = [None] * len(attn_metadata.prompt_lens)
|
||||
attn_metadata.attn_bias = att_masks
|
||||
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
start = 0
|
||||
out = torch.empty((num_tokens, self.num_heads, self.head_size),
|
||||
output = torch.empty(
|
||||
(num_tokens, self.num_heads, self.head_size),
|
||||
dtype=query.dtype)
|
||||
for prompt_len, mask in zip(prefill_meta.prompt_lens,
|
||||
prefill_meta.attn_bias):
|
||||
for prompt_len, mask in zip(attn_metadata.prompt_lens,
|
||||
attn_metadata.attn_bias):
|
||||
end = start + prompt_len
|
||||
sub_out = scaled_dot_product_attention(
|
||||
query[:, start:end, :],
|
||||
@ -194,32 +174,28 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
dropout_p=0.0,
|
||||
is_causal=not self.need_mask,
|
||||
scale=self.scale).movedim(query.dim() - 2, 0)
|
||||
out[start:end, :, :] = sub_out
|
||||
output[start:end, :, :] = sub_out
|
||||
start = end
|
||||
assert out.shape == output[:num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
raise RuntimeError(
|
||||
"Torch SDPA backend doesn't support prefix decoding.")
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
else:
|
||||
# Decoding run.
|
||||
out = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
kv_scale,
|
||||
)
|
||||
assert out.shape == output[num_prefill_tokens:].shape
|
||||
output[num_prefill_tokens:]
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
@ -241,7 +217,7 @@ def _make_alibi_bias(
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None])
|
||||
inf_mask = torch.empty(
|
||||
(1, prompt_len, prompt_len),
|
||||
|
||||
@ -25,6 +25,7 @@ class CPUExecutor(ExecutorBase):
|
||||
assert lora_config is None, "cpu backend doesn't support LoRA"
|
||||
model_config = _verify_and_get_model_config(model_config)
|
||||
cache_config = _verify_and_get_cache_config(cache_config)
|
||||
scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_scheduler_config(
|
||||
config: SchedulerConfig) -> SchedulerConfig:
|
||||
if config.chunked_prefill_enabled:
|
||||
logger.warning("Chunked prefill is not supported on CPU, disable it.")
|
||||
config.chunked_prefill_enabled = False
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
|
||||
_GB = 1 << 30
|
||||
if config.enable_prefix_caching:
|
||||
|
||||
@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
|
||||
print_warning_once("Pin memory is not supported on Neuron.")
|
||||
return False
|
||||
elif is_cpu():
|
||||
print_warning_once("Pin memory is not supported on CPU.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
408
vllm/worker/cpu_model_runner.py
Normal file
408
vllm/worker/cpu_model_runner.py
Normal file
@ -0,0 +1,408 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
class CPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lora_config = lora_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# model_config can be None in tests/samplers/test_sampler.py.
|
||||
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
||||
self.sliding_window = (model_config.get_sliding_window()
|
||||
if model_config is not None else None)
|
||||
self.device_config = (device_config
|
||||
if device_config is not None else DeviceConfig())
|
||||
self.device = self.device_config.device
|
||||
|
||||
self.model = None
|
||||
self.block_size = None # Set after initial profiling.
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.dtype if model_config is not None else None)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(self.model_config,
|
||||
self.device_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
prompt_lens: List[int] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
prompt_len = len(prompt_tokens)
|
||||
|
||||
prompt_lens.append(prompt_len) # Prompt token num
|
||||
input_tokens.extend(prompt_tokens) # Token ids
|
||||
|
||||
# Token position ids
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, prompt_len)))
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, prompt_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
# block size is 4, the first two tokens are masked and the slot
|
||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
start_idx = 0
|
||||
if self.sliding_window is not None:
|
||||
start_idx = max(0, prompt_len - self.sliding_window)
|
||||
|
||||
for i in range(computed_len, prompt_len):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
|
||||
block_number = block_table[i //
|
||||
self.block_size] # type: ignore
|
||||
block_offset = i % self.block_size # type: ignore
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
num_prompt_tokens = len(input_tokens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
prompt_lens=prompt_lens,
|
||||
num_prefills=len(prompt_lens),
|
||||
num_prefill_tokens=num_prompt_tokens,
|
||||
num_decode_tokens=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
max_context_len=None,
|
||||
context_lens=None,
|
||||
block_tables=torch.tensor([]),
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
return (
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
prompt_lens,
|
||||
)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append(generation_token)
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
context_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
if self.sliding_window is not None:
|
||||
sliding_window_blocks = (self.sliding_window //
|
||||
self.block_size)
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
max_context_len = max(context_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
max_block_table_len = max(
|
||||
len(block_table) for block_table in block_tables)
|
||||
block_tables = make_tensor_with_pad(
|
||||
block_tables,
|
||||
max_len=max_block_table_len,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping,
|
||||
prompt_lens=None,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=len(input_tokens),
|
||||
max_context_len=max_context_len,
|
||||
num_prefills=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
return (
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
def _prepare_sample(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
) -> SamplingMetadata:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||
selected_token_indices: List[int] = []
|
||||
generators: List[torch.Generator] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
categorized_sampled_token_indices_start_idx = 0
|
||||
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
|
||||
if seq_group_metadata.is_prompt:
|
||||
assert len(seq_ids) == 1
|
||||
subquery_len = prompt_lens[i]
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
# NOTE: prompt token positions do not need sample, skip
|
||||
categorized_sample_indices_start_idx += subquery_len - 1
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].append([
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx
|
||||
])
|
||||
categorized_sample_indices_start_idx += 1
|
||||
categorized_sampled_token_indices_start_idx += 1
|
||||
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + subquery_len - 1))
|
||||
selected_token_indices.append(selected_token_start_idx +
|
||||
subquery_len - 1)
|
||||
selected_token_start_idx += subquery_len
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
seq_group_metadata.state.generator = torch.Generator(
|
||||
device=self.device).manual_seed(sampling_params.seed)
|
||||
else:
|
||||
num_seqs = len(seq_ids)
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + num_seqs))
|
||||
selected_token_start_idx += num_seqs
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].extend(
|
||||
zip(
|
||||
range(
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sample_indices_start_idx +
|
||||
num_seqs),
|
||||
range(
|
||||
categorized_sampled_token_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx +
|
||||
num_seqs)))
|
||||
categorized_sample_indices_start_idx += num_seqs
|
||||
categorized_sampled_token_indices_start_idx += num_seqs
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
generators.append(seq_group_metadata.state.generator)
|
||||
|
||||
selected_token_indices = torch.tensor(selected_token_indices,
|
||||
dtype=torch.long)
|
||||
|
||||
categorized_sample_indices = {
|
||||
t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
|
||||
for t, seq_ids in categorized_sample_indices.items()
|
||||
}
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=categorized_sample_indices,
|
||||
generators=generators,
|
||||
)
|
||||
return sampling_metadata
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
|
||||
SamplingMetadata]:
|
||||
if self.is_driver_worker:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, attn_metadata,
|
||||
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
|
||||
prompt_lens = []
|
||||
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
# Broadcast the metadata.
|
||||
metadata_dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"input_positions": input_positions,
|
||||
"selected_token_indices":
|
||||
sampling_metadata.selected_token_indices,
|
||||
}
|
||||
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
else:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
input_tokens = metadata_dict.pop("input_tokens")
|
||||
input_positions = metadata_dict.pop("input_positions")
|
||||
selected_token_indices = metadata_dict.pop(
|
||||
"selected_token_indices")
|
||||
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
seq_data=None,
|
||||
prompt_lens=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
generators=None,
|
||||
perform_sampling=False,
|
||||
)
|
||||
|
||||
return (
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Optional[SamplerOutput]:
|
||||
(input_tokens, input_positions, attn_metadata, sampling_metadata
|
||||
) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
model_executable = self.model
|
||||
execute_model_kwargs = {
|
||||
"input_ids": input_tokens,
|
||||
"positions": input_positions,
|
||||
"kv_caches": kv_caches,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
# Only perform sampling in the driver worker.
|
||||
if not sampling_metadata.perform_sampling:
|
||||
return None
|
||||
|
||||
# Sample the next token.
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
return output
|
||||
@ -12,25 +12,14 @@ from vllm.distributed import (broadcast_tensor_dict,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUModelRunner(ModelRunner):
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(self.model_config,
|
||||
self.device_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
|
||||
|
||||
class CPUCacheEngine:
|
||||
"""Manages the KV cache for CPU backend.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user