[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)

Isotr0py 2024-10-07 14:50:35 +08:00 committed by GitHub
parent 8c6de96ea1
commit 4f95ffee6f
6 changed files with 811 additions and 264 deletions
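
As a quick end-to-end illustration of what this change enables, an offline-inference run of a BART model on a CPU-only vLLM build should now work roughly as sketched below. The model choice, sampling settings, and the plain-string prompt (used as the encoder prompt) are illustrative assumptions rather than code taken from this commit; note also that, as the test comments in this PR point out, encoder/decoder models default to enforce_eager=True.

# Hypothetical smoke test of the new capability: run a BART summarization model
# through vLLM's offline API on a host where the CPU backend is installed.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")
params = SamplingParams(temperature=0.0, max_tokens=64)

# A plain string is treated as the encoder prompt; the decoder prompt is left
# unset, so generation starts from the model's decoder start token.
outputs = llm.generate("vLLM now runs encoder/decoder models on CPU.", params)
print(outputs[0].outputs[0].text)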

View File: .buildkite/run-cpu-test.sh

@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
  pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
+  pytest -v -s tests/models/encoder_decoder/language
  pytest -v -s tests/models/decoder_only/language \
    --ignore=tests/models/test_fp8.py \
    --ignore=tests/models/decoder_only/language/test_jamba.py \

View File: tests/models/encoder_decoder/language/test_bart.py

@@ -4,220 +4,214 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
"""
from typing import List, Optional, Tuple, Type

-from vllm.utils import is_cpu
-
-if not is_cpu():
-    # CPU backend is not currently supported with encoder/decoder models
-    # skip test definitions entirely to avoid importing GPU kernel libs
-    # (xFormers, etc.)
+import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
+                          HfRunner, VllmRunner)
+from ....utils import multi_gpu_test
+from ...utils import check_logprobs_close
+
+MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
+
+
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+    decoder_prompt_type: DecoderPromptType,
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+) -> None:
+    '''
+    Test the vLLM BART model for a variety of encoder/decoder input prompts,
+    by validating it against HuggingFace (HF) BART.
+
+    Arguments:
+
+    * hf_runner: HuggingFace (HF) test model runner
+    * vllm_runner: vLLM test model runner
+    * example_encoder_decoder_prompts: test fixture which provides a
+      dictionary of dummy prompts
+    * model: the HF ID of the specific BART variant under test
+    * dtype: the tensor datatype to employ
+    * max_tokens
+    * num_logprobs
+    * decoder_prompt_type: key into the example_encoder_decoder_prompts
+      dictionary; selects specific encoder/decoder prompt scenarios to test
+
+    A note on using HF BART as a baseline for validating vLLM BART,
+    specifically when the decoder prompt is None.
+
+    The HF GenerationMixin's default behavior is to force the first
+    decoded token to be <BOS> if the prompt does not already contain
+    <BOS> (this is accomplished using a logit
+    processor setting.)
+
+    So when we use HF BART as our baseline for comparison, note that
+    when the user provides a request with a None decoder prompt
+    (i.e. a singleton encoder prompt, or else an explicit encoder/
+    decoder prompt with the decoder sub-prompt set to None), HF and
+    vLLM handle this in different ways:
+
+    * HF will (1) tokenize the None prompt as an empty token-list,
+      (2) append <decoder-start-token> to the beginning, yielding
+      [<decoder-start-token>], (3) pass this token list to the model, and
+      then (4) after computing logits during prefill, override the model
+      logits & force <BOS> to be the first generated token.
+
+    * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
+      start-token to the beginning, yielding [<decoder-start-token><BOS>],
+      (3) pass these tokens to the model & proceed with generation.
+
+    The net effect is that compared to vLLM, the list of HF *decoded* tokens
+    will contain one more initial <BOS> than the vLLM generated tokens,
+    because vLLM's <BOS> token is injected into the prompt rather than into
+    the generated output. This is in spite of the fact that overall, the
+    complete sequences (prompt + decoded tokens) produced by vLLM will match
+    HF.
+
+    So when we use HF decoded token output to validate vLLM's decoded token
+    output, the testing process must account for the difference in decoded
+    token sequences between vLLM and HF specifically in the
+    decoder-prompt-is-None case.
+
+    One option is to disable the logit processor feature that forces the
+    <BOS> token to be decoded (forced_bos_token_id = None), eliminating
+    the problem entirely. However this is not "normal" BART usage.
+
+    The other option is - only in the decoder-prompt-is-None case - to
+    discard the first decoded token from the HF output before comparing it
+    to vLLM.
+
+    To that end, when testing the scenario where the decoder prompt is None
+    (and only in that one scenario), this test skips the first HF decoded
+    token during the process of validating the vLLM decoded output.
+    '''
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default).
+
+    # Note: currently encoder/decoder models are only compatible with
+    # enforce_eager=True. Normally this is not a problem because
+    # for encoder/decoder models vLLM will
+    # default to enforce_eager=True if enforce_eager
+    # is left unspecified. However, the
+    # VllmRunner test fixture (which wraps around the LLM class) defaults to
+    # enforce_eager=False (a behavior which a number of already-existing
+    # decoder-only unit tests expect), so when testing an encoder/decoder
+    # model we must explicitly specify enforce_eager=True in the VllmRunner
+    # constructor.
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+            prompts, max_tokens, num_logprobs)
+
+    # Configuration settings for HF baseline
+    hf_kwargs = {
+        "top_k": None,
+        "num_beams": 1,
+        "repetition_penalty": 1.0,
+        "top_p": 1.0,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "no_repeat_ngram_size": None,
+        "min_length": 0
+    }
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            **hf_kwargs,
+        ))
+
+    hf_skip_tokens = (1
+                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_skip_tokens,
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
+def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
+                dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        example_encoder_decoder_prompts[decoder_prompt_type],
+        decoder_prompt_type,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
+@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
+def test_models_distributed(hf_runner, vllm_runner,
+                            example_encoder_decoder_prompts,
+                            distributed_executor_backend, model, dtype,
+                            max_tokens, num_logprobs,
+                            decoder_prompt_type) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        example_encoder_decoder_prompts[decoder_prompt_type],
+        decoder_prompt_type,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=2,
+        distributed_executor_backend=distributed_executor_backend,
+    )
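
The decoder-prompt-is-None bookkeeping described in the docstring above boils down to a one-token offset between the two baselines. A tiny illustration, with made-up token strings rather than real model output:

# With a None decoder prompt, HF emits <s> as its first *generated* token,
# while vLLM places <s> inside the decoder prompt instead.
hf_generated = ["<s>", "The", "cat", "sat", "</s>"]   # HF: forced <BOS> first
vllm_generated = ["The", "cat", "sat", "</s>"]        # vLLM: <BOS> is in prompt

# This is why run_test() passes num_outputs_0_skip_tokens=1 for
# DecoderPromptType.NONE: dropping the first HF token aligns the sequences.
assert hf_generated[1:] == vllm_generated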

View File: vllm/attention/backends/torch_sdpa.py

@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[torch.Tensor]] = None
+        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
+        self.cross_attn_bias: Optional[List[torch.Tensor]] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self):
+        '''
+        All attention metadata required for encoder attention is set.
+        '''
+        return ((self.encoder_seq_lens is not None)
+                and (self.encoder_seq_lens_tensor is not None)
+                and (self.max_encoder_seq_len is not None))
+
+    @property
+    def is_all_cross_attn_metadata_set(self):
+        '''
+        All attention metadata required for enc/dec cross-attention is set.
+
+        Superset of encoder attention required metadata.
+        '''
+        return (self.is_all_encoder_attn_metadata_set
+                and (self.cross_slot_mapping is not None)
+                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
        return self

+    def get_seq_lens(
+        self,
+        attn_type: AttentionType,
+    ):
+        '''
+        Extract appropriate sequence lengths from attention metadata
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+
+        * Appropriate sequence lengths tensor for query
+        * Appropriate sequence lengths tensor for key & value
+        '''
+        if attn_type == AttentionType.DECODER:
+            seq_lens_q = self.seq_lens
+            seq_lens_kv = self.seq_lens
+        elif attn_type == AttentionType.ENCODER:
+            seq_lens_q = self.encoder_seq_lens
+            seq_lens_kv = self.encoder_seq_lens
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            seq_lens_q = self.seq_lens
+            seq_lens_kv = self.encoder_seq_lens
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+        return seq_lens_q, seq_lens_kv
+
+    def get_attn_bias(
+        self,
+        attn_type: AttentionType,
+    ) -> Optional[List[torch.Tensor]]:
+        '''
+        Extract appropriate attention bias from attention metadata
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+
+        * Appropriate attention bias value given the attention type
+        '''
+        if attn_type == AttentionType.DECODER:
+            return self.attn_bias
+        elif attn_type == AttentionType.ENCODER:
+            return self.encoder_attn_bias
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            return self.cross_attn_bias
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+    def set_attn_bias(
+        self,
+        attn_bias: List[torch.Tensor],
+        attn_type: AttentionType,
+    ) -> None:
+        '''
+        Update appropriate attention bias field of attention metadata,
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_bias: The desired attention bias value
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+        '''
+        if attn_type == AttentionType.DECODER:
+            self.attn_bias = attn_bias
+        elif attn_type == AttentionType.ENCODER:
+            self.encoder_attn_bias = attn_bias
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            self.cross_attn_bias = attn_bias
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+    def get_seq_len_block_table_args(
+        self,
+        attn_type: AttentionType,
+    ) -> tuple:
+        '''
+        The particular choice of sequence-length- and block-table-related
+        attributes which should be extracted from attn_metadata is dependent
+        on the type of attention operation.
+
+        Decoder attn -> select entirely decoder self-attention-related fields
+        Encoder/decoder cross-attn -> select encoder sequence lengths &
+                                      cross-attn block-tables fields
+        Encoder attn -> select encoder sequence lengths fields & no block tables
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * is_prompt: True if prefill, False otherwise
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+
+        * Appropriate sequence-lengths tensor
+        * Appropriate max sequence-length scalar
+        * Appropriate block tables (or None)
+        '''
+        if attn_type == AttentionType.DECODER:
+            # Decoder self-attention
+            # Choose max_seq_len based on whether we are in prompt_run
+            return (self.seq_lens_tensor, self.max_decode_seq_len,
+                    self.block_tables)
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            # Enc/dec cross-attention KVs match encoder sequence length;
+            # cross-attention utilizes special "cross" block tables
+            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+                    self.cross_block_tables)
+        elif attn_type == AttentionType.ENCODER:
+            # No block tables associated with encoder attention
+            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+                    None)
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+

class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
            shape = [num_tokens, num_heads * head_size]
        """
        assert k_scale == 1.0 and v_scale == 1.0
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "TorchSDPABackendImpl")
-        num_tokens, hidden_size = query.shape
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None

-        if kv_cache.numel() > 0:
+        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
+            # KV-cache during decoder-self- or
+            # encoder-decoder-cross-attention, but not
+            # during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, k_scale,
-                                                v_scale)
-
-        if attn_metadata.is_prompt:
+
+            if (key is not None) and (value is not None):
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    k_scale, v_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive token-count from query shape & and treat them
+            # as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
            assert attn_metadata.seq_lens is not None
            if (kv_cache.numel() == 0
-                    or attn_metadata.block_tables.numel() == 0):
-                if self.num_kv_heads != self.num_heads:
-                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
-                    value = value.repeat_interleave(self.num_queries_per_kv,
-                                                    dim=1)
-
-                if attn_metadata.attn_bias is None:
-                    if self.alibi_slopes is not None:
-                        att_masks = _make_alibi_bias(
-                            self.alibi_slopes, query.dtype,
-                            attn_metadata.seq_lens)  # type: ignore
-                    elif self.sliding_window is not None:
-                        att_masks = _make_sliding_window_bias(
-                            attn_metadata.seq_lens, self.sliding_window,
-                            query.dtype)  # type: ignore
-                    else:
-                        att_masks = [None] * len(attn_metadata.seq_lens)
-                    attn_metadata.attn_bias = att_masks
-
-                query = query.movedim(0, query.dim() - 2)
-                key = key.movedim(0, key.dim() - 2)
-                value = value.movedim(0, value.dim() - 2)
-
-                start = 0
-                output = torch.empty(
-                    (num_tokens, self.num_heads, self.head_size),
-                    dtype=query.dtype)
-                for seq_len, mask in zip(attn_metadata.seq_lens,
-                                         attn_metadata.attn_bias):
-                    end = start + seq_len
-                    sub_out = scaled_dot_product_attention(
-                        query[None, :, start:end, :],
-                        key[None, :, start:end, :],
-                        value[None, :, start:end, :],
-                        attn_mask=mask,
-                        dropout_p=0.0,
-                        is_causal=not self.need_mask,
-                        scale=self.scale).squeeze(0).movedim(
-                            query.dim() - 2, 0)
-                    output[start:end, :, :] = sub_out
-                    start = end
+                    or prefill_meta.block_tables.numel() == 0):
+                output = self._run_sdpa_forward(query,
+                                                key,
+                                                value,
+                                                prefill_meta,
+                                                attn_type=attn_type)
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")

-        else:
+        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = decode_meta.get_seq_len_block_table_args(attn_type)
+
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
-                attn_metadata.block_tables,
-                attn_metadata.seq_lens_tensor,
-                attn_metadata.max_decode_seq_len,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

+    def _run_sdpa_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: TorchSDPAMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ):
+        if self.num_kv_heads != self.num_heads:
+            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
+
+        attn_masks = attn_metadata.get_attn_bias(attn_type)
+        if attn_masks is None:
+            if self.alibi_slopes is not None:
+                attn_masks = _make_alibi_bias(
+                    self.alibi_slopes, query.dtype,
+                    attn_metadata.seq_lens)  # type: ignore
+            elif self.sliding_window is not None:
+                assert attn_metadata.seq_lens is not None
+                attn_masks = _make_sliding_window_bias(
+                    attn_metadata.seq_lens, self.sliding_window,
+                    query.dtype)  # type: ignore
+            else:
+                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
+                attn_masks = [None] * len(seq_lens)
+            attn_metadata.set_attn_bias(attn_masks, attn_type)
+
+        output = torch.empty_like(query)
+        query = query.movedim(0, query.dim() - 2)
+        key = key.movedim(0, key.dim() - 2)
+        value = value.movedim(0, value.dim() - 2)
+
+        causal_attn = (attn_type == AttentionType.DECODER)
+
+        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
+        start_q, start_kv = 0, 0
+        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
+                                               attn_masks):
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+            sub_out = scaled_dot_product_attention(
+                query[None, :, start_q:end_q, :],
+                key[None, :, start_kv:end_kv, :],
+                value[None, :, start_kv:end_kv, :],
+                attn_mask=mask,
+                dropout_p=0.0,
+                is_causal=causal_attn and not self.need_mask,
+                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
+            output[start_q:end_q, :, :] = sub_out
+            start_q, start_kv = end_q, end_kv
+        return output
+

def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
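
For reference, the core pattern inside _run_sdpa_forward (all sequences packed along the token dimension and attended one at a time, with query and key/value lengths allowed to differ, as cross-attention requires) can be reproduced standalone with plain PyTorch. Head count, head size, and the example lengths below are arbitrary.

# Toy reproduction of the per-sequence SDPA loop used by the CPU backend.
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 32
seq_lens_q = [3, 5]    # decoder (query) lengths per sequence
seq_lens_kv = [7, 2]   # encoder (key/value) lengths per sequence

q = torch.randn(sum(seq_lens_q), num_heads, head_size)
k = torch.randn(sum(seq_lens_kv), num_heads, head_size)
v = torch.randn(sum(seq_lens_kv), num_heads, head_size)

out = torch.empty_like(q)
# Move the packed token dimension next to last, as SDPA expects (..., L, E).
q_t, k_t, v_t = (t.movedim(0, t.dim() - 2) for t in (q, k, v))

start_q = start_kv = 0
for len_q, len_kv in zip(seq_lens_q, seq_lens_kv):
    end_q, end_kv = start_q + len_q, start_kv + len_kv
    sub = scaled_dot_product_attention(
        q_t[None, :, start_q:end_q, :],
        k_t[None, :, start_kv:end_kv, :],
        v_t[None, :, start_kv:end_kv, :],
        is_causal=False,  # cross-attention: no causal mask
    ).squeeze(0).movedim(q_t.dim() - 2, 0)
    out[start_q:end_q] = sub
    start_q, start_kv = end_q, end_kv

print(out.shape)  # torch.Size([8, 4, 32])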

View File: vllm/worker/cpu_enc_dec_model_runner.py (new file)

@@ -0,0 +1,311 @@
+import dataclasses
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.multimodal import MultiModalInputs
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
+from vllm.worker.cpu_model_runner import (CPUModelRunner,
+                                          ModelInputForCPUBuilder,
+                                          ModelInputForCPUWithSamplingMetadata)
+from vllm.worker.model_runner_base import (
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+
+@dataclasses.dataclass(frozen=True)
+class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
+    """
+    Used by the EncoderDecoderModelRunner.
+    """
+    encoder_input_tokens: Optional[torch.Tensor] = None
+    encoder_input_positions: Optional[torch.Tensor] = None
+
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "encoder_input_tokens": self.encoder_input_tokens,
+            "encoder_input_positions": self.encoder_input_positions,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "EncoderDecoderModelInputForCPU":
+        return cast(
+            EncoderDecoderModelInputForCPU,
+            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
+
+
+class CPUEncoderDecoderModelRunner(CPUModelRunner):
+    _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
+        EncoderDecoderModelInputForCPU)
+    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
+
+    def _list_to_int32_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.int32, device=self.device)
+
+    def _list_to_long_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.long, device=self.device)
+
+    def _empty_int32_tensor(self) -> torch.Tensor:
+        return self._list_to_int32_tensor([])
+
+    def _empty_long_tensor(self) -> torch.Tensor:
+        return self._list_to_long_tensor([])
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str,
+                                    Any]) -> EncoderDecoderModelInputForCPU:
+        return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        )
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> EncoderDecoderModelInputForCPU:
+        model_input = super().prepare_model_input(seq_group_metadata_list,
+                                                  virtual_engine,
+                                                  finished_requests_ids)
+        model_input = cast(EncoderDecoderModelInputForCPU, model_input)
+        (
+            attn_metadata,
+            encoder_input_tokens_tensor,
+            encoder_input_positions_tensor,
+        ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
+                                                      model_input)
+        return dataclasses.replace(
+            model_input,
+            attn_metadata=attn_metadata,
+            encoder_input_tokens=encoder_input_tokens_tensor,
+            encoder_input_positions=encoder_input_positions_tensor,
+        )
+
+    def _prepare_encoder_model_input_tensors(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: EncoderDecoderModelInputForCPU,
+    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        """Helper method to prepare the encoder- and cross-attn-related
+        model inputs based on a given sequence group. These additional inputs
+        are used to augment an already-computed `EncoderDecoderModelInput`
+        data structure which already has decoder-related model inputs
+        populated.
+
+        Sets the following attn_metadata fields:
+        * `num_encoder_tokens`
+        * `encoder_seq_lens`
+        * `encoder_seq_lens_tensor`
+        * `max_encoder_seq_len`
+        * `cross_slot_mapping`
+        * `cross_block_tables`
+
+        Constructs a new model inputs data structure, based on
+        (1) the existing fields in the `model_inputs` argument,
+        and (2) the following additional fields which are
+        computed (or in the case of `attn_metadata`, updated)
+        by this function:
+        * attn_metadata
+        * encoder_input_tokens
+        * encoder_input_positions
+
+        Arguments:
+
+        * seq_group_metadata_list: list of sequence groups for which to
+          compute inputs
+        * model_inputs: model inputs data structure with decoder-oriented
+          fields already computed.
+
+        Return:
+
+        * Updated model inputs data structure
+        """
+        if len(seq_group_metadata_list) == 0:
+            return (model_input.attn_metadata, None, None)
+
+        # Since we are not supporting chunked prefill either the entire
+        # batch is prefill or it is decode
+        is_prompt = seq_group_metadata_list[0].is_prompt
+
+        # Build encoder inputs
+        encoder_seq_lens: List[int] = []
+        if is_prompt:
+            # Prefill phase.
+            cross_block_tables = self._empty_int32_tensor().view(
+                len(seq_group_metadata_list), -1)
+
+            # Extract input tokens/positions, cross-attention slot-mapping,
+            # & seq len from each sequence group metadata
+            (
+                encoder_input_tokens,
+                encoder_input_positions,
+                cross_slot_mapping,
+            ) = (
+                [],
+                [],
+                [],
+            )
+            for seq_group_metadata in seq_group_metadata_list:
+                # Build seq lens
+                seq_len = seq_group_metadata.encoder_seq_data.get_len()
+                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
+                encoder_seq_lens.append(seq_len)
+
+                # Build slot mapping
+                for i in range(0, seq_len):
+                    block_number = seq_group_metadata.cross_block_table[
+                        i // self.block_size]
+                    block_offset = i % self.block_size
+                    slot = block_number * self.block_size + block_offset
+                    cross_slot_mapping.append(slot)
+
+                # Build encoder input tokens
+                encoder_input_tokens.extend(token_ids)
+                encoder_input_positions.extend(list(range(0, seq_len)))
+
+            # Convert tokens/positions & cross-attention
+            # slot-mapping to encoder input tensors
+            encoder_input_tokens_tensor = self._list_to_long_tensor(
+                encoder_input_tokens)
+            encoder_input_positions_tensor = self._list_to_long_tensor(
+                encoder_input_positions)
+            cross_slot_mapping_tensor = self._list_to_long_tensor(
+                cross_slot_mapping)
+
+        else:
+            # Decode phase.
+            encoder_input_tokens_tensor = self._empty_long_tensor()
+            encoder_input_positions_tensor = self._empty_long_tensor()
+            cross_slot_mapping_tensor = self._empty_long_tensor()
+
+            # Extract cross-attention block tables &
+            # seq len from each sequence group metadata.
+            # Cross-attention block tables are empty
+            # during vLLM memory profiling.
+            cross_block_tables = []
+            for seq_group_metadata in seq_group_metadata_list:
+                for _ in range(len(seq_group_metadata.seq_data)):
+                    encoder_seq_lens.append(
+                        seq_group_metadata.encoder_seq_data.get_len())
+                    cross_block_table = seq_group_metadata.cross_block_table
+                    cross_block_tables.append([] if (
+                        cross_block_table is None) else cross_block_table)
+
+            max_len_of_block_table = max(
+                len(block_table) for block_table in cross_block_tables)
+
+            cross_block_tables = make_tensor_with_pad(
+                cross_block_tables,
+                max_len=max_len_of_block_table,
+                pad=0,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
+        # Compute encoder sequence lengths & encoder
+        # sequence starting offset tensors
+        max_encoder_seq_len = max(encoder_seq_lens, default=0)
+        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
+        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
+                                            1,
+                                            dtype=torch.int32,
+                                            device=self.device)
+        torch.cumsum(encoder_seq_lens_tensor,
+                     dim=0,
+                     dtype=encoder_seq_start_loc.dtype,
+                     out=encoder_seq_start_loc[1:])
+
+        # Update attention metadata with encoder-oriented attributes
+        attn_metadata = model_input.attn_metadata
+        assert attn_metadata is not None
+        (
+            attn_metadata.num_encoder_tokens,
+            attn_metadata.encoder_seq_lens,
+            attn_metadata.encoder_seq_lens_tensor,
+            attn_metadata.max_encoder_seq_len,
+            attn_metadata.cross_slot_mapping,
+            attn_metadata.cross_block_tables,
+        ) = (
+            sum(encoder_seq_lens),
+            encoder_seq_lens,
+            encoder_seq_lens_tensor,
+            max_encoder_seq_len,
+            cross_slot_mapping_tensor,
+            cross_block_tables,
+        )
+
+        return (attn_metadata, encoder_input_tokens_tensor,
+                encoder_input_positions_tensor)
+
+    @torch.no_grad()
+    def execute_model(
+        self,
+        model_input: EncoderDecoderModelInputForCPU,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
+        model_executable = self.model
+        execute_model_kwargs = {
+            "input_ids":
+            model_input.input_tokens,
+            "positions":
+            model_input.input_positions,
+            "encoder_input_ids":
+            model_input.encoder_input_tokens,
+            "encoder_positions":
+            model_input.encoder_input_positions,
+            "kv_caches":
+            kv_caches,
+            "attn_metadata":
+            model_input.attn_metadata,
+            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
+                                         device=self.device),
+            "intermediate_tensors":
+            intermediate_tensors,
+        }
+
+        hidden_states = model_executable(**execute_model_kwargs)
+
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)
+
+        # Only perform sampling in the driver worker.
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+        return [output]
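
The cross-attention slot mapping built in _prepare_encoder_model_input_tensors places encoder token i into physical slot block_table[i // block_size] * block_size + i % block_size of the paged KV cache. A small worked example, with a made-up block table and sequence length:

block_size = 16
cross_block_table = [8, 3]   # hypothetical physical blocks for one sequence
encoder_seq_len = 20

cross_slot_mapping = []
for i in range(encoder_seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)

print(cross_slot_mapping[:4])   # [128, 129, 130, 131]  (tokens 0-3 in block 8)
print(cross_slot_mapping[16:])  # [48, 49, 50, 51]      (tokens 16-19 in block 3)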

View File: vllm/worker/cpu_model_runner.py

@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
-from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

-        if self.model_config.is_encoder_decoder_model:
-            raise NotImplementedError(
-                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
-
    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
-    ) -> ModelInputForCPU:
-        return ModelInputForCPU.from_broadcasted_tensor_dict(
+    ) -> ModelInputForCPUWithSamplingMetadata:
+        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

View File: vllm/worker/cpu_worker.py

@@ -1,5 +1,5 @@
"""A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type

import torch
import torch.distributed
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoraNotSupportedWorkerBase, WorkerInput)
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

-        self.model_runner: CPUModelRunner = CPUModelRunner(
+        ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
+        if self._is_encoder_decoder_model():
+            ModelRunnerClass = CPUEncoderDecoderModelRunner
+        self.model_runner: CPUModelRunner = ModelRunnerClass(
            model_config,
            parallel_config,
            scheduler_config,
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

+    def _is_encoder_decoder_model(self):
+        return self.model_config.is_encoder_decoder_model
+
    def init_device(self) -> None:
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)