[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)

Isotr0py 2024-10-07 14:50:35 +08:00 committed by GitHub
parent 8c6de96ea1
commit 4f95ffee6f
6 changed files with 811 additions and 264 deletions
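
For context, a minimal offline-inference sketch of what this commit enables on the CPU backend (hypothetical usage, assuming the standard vLLM LLM API and one of the BART checkpoints exercised by the tests below):

from vllm import LLM, SamplingParams

# Hypothetical sketch: with this commit, the CPU backend can run
# encoder/decoder models such as BART. enforce_eager mirrors what the
# BART tests below pass to VllmRunner.
llm = LLM(model="facebook/bart-base", dtype="float", enforce_eager=True)
outputs = llm.generate(["vLLM now supports encoder/decoder models on CPU."],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)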

.buildkite/run-cpu-test.sh

@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language \
--ignore=tests/models/test_fp8.py \
--ignore=tests/models/decoder_only/language/test_jamba.py \

tests/models/encoder_decoder/language/test_bart.py

@ -4,13 +4,6 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
"""
from typing import List, Optional, Tuple, Type
-from vllm.utils import is_cpu
-if not is_cpu():
-# CPU backend is not currently supported with encoder/decoder models
-# skip test definitions entirely to avoid importing GPU kernel libs
-# (xFormers, etc.)
import pytest
from transformers import AutoModelForSeq2SeqLM
@ -23,6 +16,7 @@ if not is_cpu():
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType,
@ -36,6 +30,7 @@ if not is_cpu():
return output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
@ -131,8 +126,7 @@ if not is_cpu():
# decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor.
-with vllm_runner(
-model,
+with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
@ -154,16 +148,15 @@ if not is_cpu():
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-hf_outputs = (
-hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
-hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
-else 0)
+hf_skip_tokens = (1
+if decoder_prompt_type == DecoderPromptType.NONE else 0)
check_logprobs_close(
outputs_0_lst=hf_outputs,
@ -176,14 +169,14 @@ if not is_cpu():
num_outputs_0_skip_tokens=hf_skip_tokens,
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
-def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts,
-model, dtype, max_tokens, num_logprobs,
-decoder_prompt_type) -> None:
+def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
+dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
run_test(
hf_runner,
@ -197,6 +190,7 @@ if not is_cpu():
tensor_parallel_size=1,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])

vllm/attention/backends/torch_sdpa.py

@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt
@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths (list) for the queries
* Appropriate sequence lengths (list) for the keys & values
'''
if attn_type == AttentionType.DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
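# Illustration (hypothetical values, not part of this diff): for a batch
# with decoder sequence lengths [3, 5] and encoder input lengths [7, 4],
# get_seq_lens() selects:
#   AttentionType.DECODER         -> q: [3, 5], kv: [3, 5]
#   AttentionType.ENCODER         -> q: [7, 4], kv: [7, 4]
#   AttentionType.ENCODER_DECODER -> q: [3, 5], kv: [7, 4]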
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if attn_type == AttentionType.DECODER:
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: List[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if attn_type == AttentionType.DECODER:
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention: use decode-phase sequence lengths
# and the self-attention block tables
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
-if attn_type != AttentionType.DECODER:
-raise NotImplementedError("Encoder self-attention and "
-"encoder/decoder cross-attention "
-"are not implemented for "
-"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape
+if (attn_type == AttentionType.ENCODER
+and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+raise AttributeError("Encoder attention requires setting "
+"encoder metadata attributes.")
+elif (attn_type == AttentionType.ENCODER_DECODER
+and (not attn_metadata.is_all_cross_attn_metadata_set)):
+raise AttributeError("Encoder/decoder cross-attention "
+"requires setting cross-attention "
+"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
-if kv_cache.numel() > 0:
+if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
+# KV-cache during decoder-self- or
+# encoder-decoder-cross-attention, but not
+# during encoder attention.
+#
+# Even if there are no new key/value pairs to cache,
+# we still need to break out key_cache and value_cache
+# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
+if (key is not None) and (value is not None):
+if attn_type == AttentionType.ENCODER_DECODER:
+# Update cross-attention KV cache (prefill-only)
+# During cross-attention decode, key & value will be None,
+# preventing this IF-statement branch from running
+updated_slot_mapping = attn_metadata.cross_slot_mapping
+else:
+# Update self-attention KV cache (prefill/decode)
+updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
-attn_metadata.slot_mapping,
-self.kv_cache_dtype, k_scale,
-v_scale)
+updated_slot_mapping,
+self.kv_cache_dtype,
+k_scale, v_scale)
-if attn_metadata.is_prompt:
+if attn_type != AttentionType.ENCODER:
+# Decoder self-attention supports chunked prefill.
+# Encoder/decoder cross-attention requires no chunked
+# prefill (100% prefill or 100% decode tokens, no mix)
+num_prefill_tokens = attn_metadata.num_prefill_tokens
+num_decode_tokens = attn_metadata.num_decode_tokens
+else:
+# Encoder attention - chunked prefill is not applicable;
+# derive token-count from query shape and treat them
+# as 100% prefill tokens
+assert attn_metadata.num_encoder_tokens is not None
+num_prefill_tokens = attn_metadata.num_encoder_tokens
+num_decode_tokens = 0
+if attn_type == AttentionType.DECODER:
+# Only enforce this shape-constraint for decoder
+# self-attention
+assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
-or attn_metadata.block_tables.numel() == 0):
-if self.num_kv_heads != self.num_heads:
-key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
-value = value.repeat_interleave(self.num_queries_per_kv,
-dim=1)
-if attn_metadata.attn_bias is None:
-if self.alibi_slopes is not None:
-att_masks = _make_alibi_bias(
-self.alibi_slopes, query.dtype,
-attn_metadata.seq_lens)  # type: ignore
-elif self.sliding_window is not None:
-att_masks = _make_sliding_window_bias(
-attn_metadata.seq_lens, self.sliding_window,
-query.dtype)  # type: ignore
-else:
-att_masks = [None] * len(attn_metadata.seq_lens)
-attn_metadata.attn_bias = att_masks
-query = query.movedim(0, query.dim() - 2)
-key = key.movedim(0, key.dim() - 2)
-value = value.movedim(0, value.dim() - 2)
-start = 0
-output = torch.empty(
-(num_tokens, self.num_heads, self.head_size),
-dtype=query.dtype)
-for seq_len, mask in zip(attn_metadata.seq_lens,
-attn_metadata.attn_bias):
-end = start + seq_len
-sub_out = scaled_dot_product_attention(
-query[None, :, start:end, :],
-key[None, :, start:end, :],
-value[None, :, start:end, :],
-attn_mask=mask,
-dropout_p=0.0,
-is_causal=not self.need_mask,
-scale=self.scale).squeeze(0).movedim(
-query.dim() - 2, 0)
-output[start:end, :, :] = sub_out
-start = end
+or prefill_meta.block_tables.numel() == 0):
+output = self._run_sdpa_forward(query,
+key,
+value,
+prefill_meta,
+attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
-else:
+if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
+(
+seq_lens_arg,
+max_seq_len_arg,
+block_tables_arg,
+) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
-attn_metadata.block_tables,
-attn_metadata.seq_lens_tensor,
-attn_metadata.max_decode_seq_len,
+block_tables_arg,
+seq_lens_arg,
+max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_sdpa_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_kv:end_kv, :],
value[None, :, start_kv:end_kv, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and not self.need_mask,
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
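
The per-sequence loop in _run_sdpa_forward above can be reproduced standalone. A minimal sketch (hypothetical shapes; torch.nn.functional.scaled_dot_product_attention is the same kernel the backend calls) of cross-attention, where query lengths come from the decoder and key/value lengths from the encoder:

import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 8
seq_lens_q, seq_lens_kv = [3, 5], [7, 4]  # decoder vs. encoder lengths
# Head-first layout, matching the movedim() calls in the backend.
q = torch.randn(num_heads, sum(seq_lens_q), head_size)
k = torch.randn(num_heads, sum(seq_lens_kv), head_size)
v = torch.randn_like(k)

out = torch.empty_like(q)
start_q = start_kv = 0
for len_q, len_kv in zip(seq_lens_q, seq_lens_kv):
    end_q, end_kv = start_q + len_q, start_kv + len_kv
    # Cross-attention is non-causal: every decoder token may attend to
    # every encoder token, hence is_causal=False.
    out[:, start_q:end_q] = scaled_dot_product_attention(
        q[None, :, start_q:end_q],
        k[None, :, start_kv:end_kv],
        v[None, :, start_kv:end_kv],
        is_causal=False).squeeze(0)
    start_q, start_kv = end_q, end_kv
print(out.shape)  # torch.Size([4, 8, 8])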

vllm/worker/cpu_enc_dec_model_runner.py

@ -0,0 +1,311 @@
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunner,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInputForCPU":
return cast(
EncoderDecoderModelInputForCPU,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(CPUModelRunner):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str,
Any]) -> EncoderDecoderModelInputForCPU:
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = super().prepare_model_input(seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
model_input = cast(EncoderDecoderModelInputForCPU, model_input)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
return dataclasses.replace(
model_input,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
)
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInputForCPU,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInputForCPU`
data structure whose decoder-related model inputs are
already populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Produces the following additional values, which the caller
(`prepare_model_input`) merges into the decoder-oriented
`model_input` via `dataclasses.replace`:
* attn_metadata (updated in place with the fields above)
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_input: model input data structure with decoder-oriented
fields already computed.
Return:
* Updated attention metadata
* Encoder input tokens tensor (or None during decode)
* Encoder input positions tensor (or None during decode)
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since chunked prefill is not supported, the batch is either
# entirely prefill or entirely decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
@torch.no_grad()
def execute_model(
self,
model_input: EncoderDecoderModelInputForCPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"encoder_input_ids":
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
intermediate_tensors,
}
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
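
Two of the computations above are easy to check standalone. A minimal sketch (hypothetical block size, block table, and lengths) of the cross-attention slot mapping built in the prefill branch, plus the zero-prefixed cumulative sum used for encoder_seq_start_loc:

import torch

# Slot mapping: encoder token i of a sequence lands in slot
#   cross_block_table[i // block_size] * block_size + i % block_size
block_size = 4
cross_block_table = [7, 2]  # physical block numbers for one sequence
encoder_seq_len = 6
slots = [cross_block_table[i // block_size] * block_size + i % block_size
         for i in range(encoder_seq_len)]
print(slots)  # [28, 29, 30, 31, 8, 9]

# Start offsets: cumsum of per-sequence encoder lengths, prefixed with 0.
encoder_seq_lens_tensor = torch.tensor([7, 4, 5], dtype=torch.int32)
start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
                        dtype=torch.int32)
torch.cumsum(encoder_seq_lens_tensor, dim=0, dtype=torch.int32,
             out=start_loc[1:])
print(start_loc)  # tensor([ 0,  7, 11, 16], dtype=torch.int32)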

vllm/worker/cpu_model_runner.py

@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
-from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
# Lazy initialization.
self.model: nn.Module  # Set after load_model
-if self.model_config.is_encoder_decoder_model:
-raise NotImplementedError(
-STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
-) -> ModelInputForCPU:
-return ModelInputForCPU.from_broadcasted_tensor_dict(
+) -> ModelInputForCPUWithSamplingMetadata:
+return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)

vllm/worker/cpu_worker.py

@ -1,5 +1,5 @@
"""A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.distributed
@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
-self.model_runner: CPUModelRunner = CPUModelRunner(
+ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
+if self._is_encoder_decoder_model():
+ModelRunnerClass = CPUEncoderDecoderModelRunner
+self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)