[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)

parent 8c6de96ea1
commit 4f95ffee6f
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
+  pytest -v -s tests/models/encoder_decoder/language
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \
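Beyond CI, the same path can be exercised locally. Below is a minimal offline-inference sketch for an encoder/decoder model on a CPU build of vLLM; it is not part of this commit, and the model choice, dtype="float", and enforce_eager=True simply mirror the test settings further down rather than being requirements of the API.

# Minimal sketch: BART summarization on the CPU backend (illustrative values).
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn",
          dtype="float",          # CPU tests below use float32
          enforce_eager=True)     # encoder/decoder models run in eager mode

# A plain string prompt is treated as the encoder prompt; the decoder
# prompt is left as None, matching the DecoderPromptType.NONE scenario.
outputs = llm.generate("vLLM is a high-throughput inference engine.",
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)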
@@ -4,220 +4,214 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
 """
 from typing import List, Optional, Tuple, Type

-from vllm.utils import is_cpu
-
-if not is_cpu():
-    # CPU backend is not currently supported with encoder/decoder models
-    # skip test definitions entirely to avoid importing GPU kernel libs
-    # (xFormers, etc.)
-
 import pytest
 from transformers import AutoModelForSeq2SeqLM

 from vllm.sequence import SampleLogprobs

 from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
                           HfRunner, VllmRunner)
 from ....utils import multi_gpu_test
 from ...utils import check_logprobs_close

 MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]


 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
     decoder_prompt_type: DecoderPromptType,
 ):
     """Sanitize vllm output to be comparable with hf output."""
     output_ids, output_str, out_logprobs = vllm_output

     hf_output_str = output_str + "</s>"
     if decoder_prompt_type == DecoderPromptType.NONE:
         hf_output_str = "<s>" + hf_output_str

     return output_ids, hf_output_str, out_logprobs


 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
     prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
     decoder_prompt_type: DecoderPromptType,
     model: str,
     *,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ) -> None:
     '''
     Test the vLLM BART model for a variety of encoder/decoder input prompts,
     by validating it against HuggingFace (HF) BART.

     Arguments:

     * hf_runner: HuggingFace (HF) test model runner
     * vllm_runner: vLLM test model runner
     * example_encoder_decoder_prompts: test fixture which provides a
       dictionary of dummy prompts
     * model: the HF ID of the specific BART variant under test
     * dtype: the tensor datatype to employ
     * max_tokens
     * num_logprobs
     * decoder_prompt_type: key into the example_encoder_decoder_prompts
       dictionary; selects specific encoder/decoder
       prompt scenarios to test

     A note on using HF BART as a baseline for validating vLLM BART,
     specifically when the decoder prompt is None.

     The HF GenerationMixin's default behavior is to force the first
     decoded token to be <BOS> if the prompt does not already contain
     <BOS> (this is accomplished using a logit
     processor setting.)

     So when we use HF BART as our baseline for comparison, note that
     when the user provides a request with a None decoder prompt
     (i.e. a singleton encoder prompt, or else an explicit encoder/
     decoder prompt with the decoder sub-prompt set to None), HF and
     vLLM handle this in different ways:

     * HF will (1) tokenize the None prompt as an empty token-list,
       (2) append <decoder-start-token> to the beginning, yielding
       [<decoder-start-token>], (3) pass this token list to the model, and
       then (4) after computing logits during prefill, override the model
       logits & force <BOS> to be the first generated token.

     * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
       start-token to the beginning, yielding [<decoder-start-token><BOS>],
       (3) pass these tokens to the model & proceed with generation.

     The net effect is that compared to vLLM, the list of HF *decoded* tokens
     will contain one more initial <BOS> than the vLLM generated tokens,
     because vLLM's <BOS> token is injected into the prompt rather than into
     the generated output. This is in spite of the fact that overall, the
     complete sequences (prompt + decoded tokens) produced by vLLM will match
     HF.

     So when we use HF decoded token output to validate vLLM's decoded token
     output, the testing process must account for the difference in decoded
     token sequences between vLLM and HF specifically in the
     decoder-prompt-is-None case.

     One option is to disable the logit processor feature that forces the
     <BOS> token to be decoded (forced_bos_token_id = None), eliminating
     the problem entirely. However this is not "normal" BART usage.

     The other option is - only in the decoder-prompt-is-None case - to
     discard the first decoded token from the HF output before comparing it
     to vLLM.

     To that end, when testing the scenario where the decoder prompt is None
     (and only in that one scenario), this test skips the first HF decoded
     token during the process of validating the vLLM decoded output.
     '''

     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default).

     # Note: currently encoder/decoder models are only compatible with
     # enforce_eager=True. Normally this is not a problem because
     # for encoder/decoder models vLLM will
     # default to enforce_eager=True if enforce_eager
     # is left unspecified. However, the
     # VllmRunner test fixture (which wraps around the LLM class) defaults to
     # enforce_eager=False (a behavior which a number of already-exisitng
     # decoder-only unit tests expect), so when testing an encoder/decoder
     # model we must explicitly specify enforce_eager=True in the VllmRunner
     # constructor.
-    with vllm_runner(
-        model,
-        dtype=dtype,
-        tensor_parallel_size=tensor_parallel_size,
-        distributed_executor_backend=distributed_executor_backend,
-        enforce_eager=True) as vllm_model:
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
         vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
             prompts, max_tokens, num_logprobs)

     # Configuration settings for HF baseline
     hf_kwargs = {
         "top_k": None,
         "num_beams": 1,
         "repetition_penalty": 1.0,
         "top_p": 1.0,
         "length_penalty": 1.0,
         "early_stopping": False,
         "no_repeat_ngram_size": None,
         "min_length": 0
     }

     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-        hf_outputs = (
-            hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-                prompts,
-                max_tokens,
-                num_logprobs,
-                **hf_kwargs,
-            ))
+        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            **hf_kwargs,
+        ))

-    hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
-                      else 0)
+    hf_skip_tokens = (1
+                      if decoder_prompt_type == DecoderPromptType.NONE else 0)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=[
             vllm_to_hf_output(vllm_output, decoder_prompt_type)
             for vllm_output in vllm_outputs
         ],
         name_0="hf",
         name_1="vllm",
         num_outputs_0_skip_tokens=hf_skip_tokens,
     )


 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
-def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts,
-                model, dtype, max_tokens, num_logprobs,
-                decoder_prompt_type) -> None:
+def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
+                dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:

     run_test(
         hf_runner,
         vllm_runner,
         example_encoder_decoder_prompts[decoder_prompt_type],
         decoder_prompt_type,
         model,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )


 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
 def test_models_distributed(hf_runner, vllm_runner,
                             example_encoder_decoder_prompts,
                             distributed_executor_backend, model, dtype,
                             max_tokens, num_logprobs,
                             decoder_prompt_type) -> None:
     run_test(
         hf_runner,
         vllm_runner,
         example_encoder_decoder_prompts[decoder_prompt_type],
         decoder_prompt_type,
         model,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
         tensor_parallel_size=2,
         distributed_executor_backend=distributed_executor_backend,
     )
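The long docstring in run_test above is the crux of the test: HF emits the forced <s> as its first *decoded* token, while vLLM injects <s> into the decoder prompt. A toy walk-through of how vllm_to_hf_output plus num_outputs_0_skip_tokens line the two outputs up (illustrative token strings only, not real model output):

# Toy illustration of the HF/vLLM alignment in the decoder-prompt-is-None case.
hf_decoded = ["<s>", "The", "cat", "sat"]     # HF forces <s> as decoded token 0
vllm_decoded = ["The", "cat", "sat"]          # vLLM placed <s> in the prompt

# vllm_to_hf_output() appends </s> and, for DecoderPromptType.NONE,
# prepends <s>, so the *strings* become directly comparable.
vllm_as_hf_str = "<s>" + " ".join(vllm_decoded) + "</s>"

# check_logprobs_close() is then told to skip the first HF decoded token
# (num_outputs_0_skip_tokens=1), so the token-by-token comparison lines up.
hf_skip_tokens = 1
assert hf_decoded[hf_skip_tokens:] == vllm_decoded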
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     slot_mapping: torch.Tensor
     seq_lens: Optional[List[int]]

+    # Begin encoder attn & enc/dec cross-attn fields...
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         # Set during the execution of the first attention op.
         # It is a list because it is needed to set per prompt
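For a concrete sense of the new fields: with two encoder sequences of lengths 5 and 3 in a batch, the CPU runner added later in this commit populates them roughly as follows (a sketch with toy values; the int32 dtype follows the runner's _list_to_int32_tensor helper):

import torch

encoder_seq_lens = [5, 3]                                 # toy lengths
encoder_seq_lens_tensor = torch.tensor(encoder_seq_lens, dtype=torch.int32)
max_encoder_seq_len = max(encoder_seq_lens, default=0)    # -> 5
num_encoder_tokens = sum(encoder_seq_lens)                # -> 8

# cross_slot_mapping / cross_block_tables are filled from each sequence
# group's cross_block_table during prefill (see the runner changes below).
print(encoder_seq_lens_tensor, max_encoder_seq_len, num_encoder_tokens)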
@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
         # from xformer API.
         # will not appear in the __repr__ and __init__
         self.attn_bias: Optional[List[torch.Tensor]] = None
+        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
+        self.cross_attn_bias: Optional[List[torch.Tensor]] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self):
+        '''
+        All attention metadata required for encoder attention is set.
+        '''
+        return ((self.encoder_seq_lens is not None)
+                and (self.encoder_seq_lens_tensor is not None)
+                and (self.max_encoder_seq_len is not None))
+
+    @property
+    def is_all_cross_attn_metadata_set(self):
+        '''
+        All attention metadata required for enc/dec cross-attention is set.
+
+        Superset of encoder attention required metadata.
+        '''
+        return (self.is_all_encoder_attn_metadata_set
+                and (self.cross_slot_mapping is not None)
+                and (self.cross_block_tables is not None))

     @property
     def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
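The two properties form a simple containment relationship: cross-attention metadata is only "all set" once the encoder metadata is set and the cross slot mapping/block tables exist. A duck-typed mirror of that check (standalone sketch, not the vLLM class itself):

from types import SimpleNamespace

def encoder_metadata_set(m) -> bool:
    # Mirrors TorchSDPAMetadata.is_all_encoder_attn_metadata_set
    return (m.encoder_seq_lens is not None
            and m.encoder_seq_lens_tensor is not None
            and m.max_encoder_seq_len is not None)

def cross_metadata_set(m) -> bool:
    # Mirrors TorchSDPAMetadata.is_all_cross_attn_metadata_set
    return (encoder_metadata_set(m)
            and m.cross_slot_mapping is not None
            and m.cross_block_tables is not None)

meta = SimpleNamespace(encoder_seq_lens=[4], encoder_seq_lens_tensor=object(),
                       max_encoder_seq_len=4, cross_slot_mapping=None,
                       cross_block_tables=None)
assert encoder_metadata_set(meta) and not cross_metadata_set(meta)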
@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):

         return self

+    def get_seq_lens(
+        self,
+        attn_type: AttentionType,
+    ):
+        '''
+        Extract appropriate sequence lengths from attention metadata
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+        * Appropriate sequence lengths tensor for query
+        * Appropriate sequence lengths tensor for key & value
+        '''
+
+        if attn_type == AttentionType.DECODER:
+            seq_lens_q = self.seq_lens
+            seq_lens_kv = self.seq_lens
+        elif attn_type == AttentionType.ENCODER:
+            seq_lens_q = self.encoder_seq_lens
+            seq_lens_kv = self.encoder_seq_lens
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            seq_lens_q = self.seq_lens
+            seq_lens_kv = self.encoder_seq_lens
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+        return seq_lens_q, seq_lens_kv
+
+    def get_attn_bias(
+        self,
+        attn_type: AttentionType,
+    ) -> Optional[List[torch.Tensor]]:
+        '''
+        Extract appropriate attention bias from attention metadata
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+        * Appropriate attention bias value given the attention type
+        '''
+
+        if attn_type == AttentionType.DECODER:
+            return self.attn_bias
+        elif attn_type == AttentionType.ENCODER:
+            return self.encoder_attn_bias
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            return self.cross_attn_bias
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+    def set_attn_bias(
+        self,
+        attn_bias: List[torch.Tensor],
+        attn_type: AttentionType,
+    ) -> None:
+        '''
+        Update appropriate attention bias field of attention metadata,
+        according to attention type.
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * attn_bias: The desired attention bias value
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+        '''
+
+        if attn_type == AttentionType.DECODER:
+            self.attn_bias = attn_bias
+        elif attn_type == AttentionType.ENCODER:
+            self.encoder_attn_bias = attn_bias
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            self.cross_attn_bias = attn_bias
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+    def get_seq_len_block_table_args(
+        self,
+        attn_type: AttentionType,
+    ) -> tuple:
+        '''
+        The particular choice of sequence-length- and block-table-related
+        attributes which should be extracted from attn_metadata is dependent
+        on the type of attention operation.
+
+        Decoder attn -> select entirely decoder self-attention-related fields
+        Encoder/decoder cross-attn -> select encoder sequence lengths &
+                                      cross-attn block-tables fields
+        Encoder attn -> select encoder sequence lengths fields & no block tables
+
+        Arguments:
+
+        * attn_metadata: Attention metadata structure associated with attention
+        * is_prompt: True if prefill, False otherwise
+        * attn_type: encoder attention, decoder self-attention,
+                     encoder/decoder cross-attention
+
+        Returns:
+
+        * Appropriate sequence-lengths tensor
+        * Appropriate max sequence-length scalar
+        * Appropriate block tables (or None)
+        '''
+
+        if attn_type == AttentionType.DECODER:
+            # Decoder self-attention
+            # Choose max_seq_len based on whether we are in prompt_run
+            return (self.seq_lens_tensor, self.max_decode_seq_len,
+                    self.block_tables)
+        elif attn_type == AttentionType.ENCODER_DECODER:
+            # Enc/dec cross-attention KVs match encoder sequence length;
+            # cross-attention utilizes special "cross" block tables
+            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+                    self.cross_block_tables)
+        elif attn_type == AttentionType.ENCODER:
+            # No block tables associated with encoder attention
+            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+                    None)
+        else:
+            raise AttributeError(f"Invalid attention type {str(attn_type)}")
+

 class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

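The asymmetry these helpers encode is that, for encoder/decoder cross-attention, query lengths come from the decoder sequences while key/value lengths come from the encoder sequences. A compact standalone mirror of get_seq_lens (illustration only; the real method lives on TorchSDPAMetadata):

from types import SimpleNamespace

def pick_seq_lens(meta, attn_type: str):
    # Mirrors TorchSDPAMetadata.get_seq_lens for the three attention types.
    if attn_type == "DECODER":
        return meta.seq_lens, meta.seq_lens
    if attn_type == "ENCODER":
        return meta.encoder_seq_lens, meta.encoder_seq_lens
    if attn_type == "ENCODER_DECODER":
        # queries from the decoder, keys/values from the encoder
        return meta.seq_lens, meta.encoder_seq_lens
    raise ValueError(attn_type)

meta = SimpleNamespace(seq_lens=[1, 1], encoder_seq_lens=[5, 3])
assert pick_seq_lens(meta, "ENCODER_DECODER") == ([1, 1], [5, 3])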
@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         shape = [num_tokens, num_heads * head_size]
         """
         assert k_scale == 1.0 and v_scale == 1.0
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "TorchSDPABackendImpl")
-        num_tokens, hidden_size = query.shape
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None

-        if kv_cache.numel() > 0:
+        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
+            # KV-cache during decoder-self- or
+            # encoder-decoder-cross-attention, but not
+            # during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, k_scale,
-                                                v_scale)

-        if attn_metadata.is_prompt:
+            if (key is not None) and (value is not None):
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    k_scale, v_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive token-count from query shape & and treat them
+            # as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
             assert attn_metadata.seq_lens is not None
             if (kv_cache.numel() == 0
-                    or attn_metadata.block_tables.numel() == 0):
-                if self.num_kv_heads != self.num_heads:
-                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
-                    value = value.repeat_interleave(self.num_queries_per_kv,
-                                                    dim=1)
-
-                if attn_metadata.attn_bias is None:
-                    if self.alibi_slopes is not None:
-                        att_masks = _make_alibi_bias(
-                            self.alibi_slopes, query.dtype,
-                            attn_metadata.seq_lens)  # type: ignore
-                    elif self.sliding_window is not None:
-                        att_masks = _make_sliding_window_bias(
-                            attn_metadata.seq_lens, self.sliding_window,
-                            query.dtype)  # type: ignore
-                    else:
-                        att_masks = [None] * len(attn_metadata.seq_lens)
-                    attn_metadata.attn_bias = att_masks
-
-                query = query.movedim(0, query.dim() - 2)
-                key = key.movedim(0, key.dim() - 2)
-                value = value.movedim(0, value.dim() - 2)
-
-                start = 0
-                output = torch.empty(
-                    (num_tokens, self.num_heads, self.head_size),
-                    dtype=query.dtype)
-                for seq_len, mask in zip(attn_metadata.seq_lens,
-                                         attn_metadata.attn_bias):
-                    end = start + seq_len
-                    sub_out = scaled_dot_product_attention(
-                        query[None, :, start:end, :],
-                        key[None, :, start:end, :],
-                        value[None, :, start:end, :],
-                        attn_mask=mask,
-                        dropout_p=0.0,
-                        is_causal=not self.need_mask,
-                        scale=self.scale).squeeze(0).movedim(
-                            query.dim() - 2, 0)
-                    output[start:end, :, :] = sub_out
-                    start = end
+                    or prefill_meta.block_tables.numel() == 0):
+                output = self._run_sdpa_forward(query,
+                                                key,
+                                                value,
+                                                prefill_meta,
+                                                attn_type=attn_type)
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")

-        else:
+        if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = decode_meta.get_seq_len_block_table_args(attn_type)
+
             output = PagedAttention.forward_decode(
                 query,
                 key_cache,
                 value_cache,
-                attn_metadata.block_tables,
-                attn_metadata.seq_lens_tensor,
-                attn_metadata.max_decode_seq_len,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
                 self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
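A consequence of the attn_type plumbing above is the causality flag: only decoder self-attention may be causal; encoder and cross-attention attend over the full key/value length, which in general differs from the query length. A small standalone PyTorch SDPA check with made-up shapes (not vLLM code):

import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 2, 8
q_len, kv_len = 1, 5          # e.g. one decode-step query vs 5 encoder tokens

q = torch.randn(1, num_heads, q_len, head_size)
k = torch.randn(1, num_heads, kv_len, head_size)
v = torch.randn(1, num_heads, kv_len, head_size)

# Cross-attention: is_causal stays False so the single decoder query can see
# every encoder position (matches causal_attn = attn_type == DECODER above).
out = scaled_dot_product_attention(q, k, v, is_causal=False)
print(out.shape)   # torch.Size([1, 2, 1, 8])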
@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)

+    def _run_sdpa_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: TorchSDPAMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ):
+        if self.num_kv_heads != self.num_heads:
+            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
+
+        attn_masks = attn_metadata.get_attn_bias(attn_type)
+        if attn_masks is None:
+            if self.alibi_slopes is not None:
+                attn_masks = _make_alibi_bias(
+                    self.alibi_slopes, query.dtype,
+                    attn_metadata.seq_lens)  # type: ignore
+            elif self.sliding_window is not None:
+                assert attn_metadata.seq_lens is not None
+                attn_masks = _make_sliding_window_bias(
+                    attn_metadata.seq_lens, self.sliding_window,
+                    query.dtype)  # type: ignore
+            else:
+                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
+                attn_masks = [None] * len(seq_lens)
+            attn_metadata.set_attn_bias(attn_masks, attn_type)
+
+        output = torch.empty_like(query)
+        query = query.movedim(0, query.dim() - 2)
+        key = key.movedim(0, key.dim() - 2)
+        value = value.movedim(0, value.dim() - 2)
+
+        causal_attn = (attn_type == AttentionType.DECODER)
+
+        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
+        start_q, start_kv = 0, 0
+        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
+                                               attn_masks):
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+            sub_out = scaled_dot_product_attention(
+                query[None, :, start_q:end_q, :],
+                key[None, :, start_kv:end_kv, :],
+                value[None, :, start_kv:end_kv, :],
+                attn_mask=mask,
+                dropout_p=0.0,
+                is_causal=causal_attn and not self.need_mask,
+                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
+            output[start_q:end_q, :, :] = sub_out
+            start_q, start_kv = end_q, end_kv
+        return output
+

 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
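_run_sdpa_forward walks the batch with two running offsets because each sequence can have different query and key/value lengths (equal for self-attention, different for cross-attention). The slicing bookkeeping in isolation, with toy lengths:

# Offsets produced by the start_q/start_kv walk in _run_sdpa_forward,
# for two sequences with decoder lengths [2, 1] and encoder lengths [5, 3].
seq_lens_q = [2, 1]
seq_lens_kv = [5, 3]

start_q = start_kv = 0
slices = []
for seq_len_q, seq_len_kv in zip(seq_lens_q, seq_lens_kv):
    end_q, end_kv = start_q + seq_len_q, start_kv + seq_len_kv
    slices.append(((start_q, end_q), (start_kv, end_kv)))
    start_q, start_kv = end_q, end_kv

assert slices == [((0, 2), (0, 5)), ((2, 3), (5, 8))]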
vllm/worker/cpu_enc_dec_model_runner.py (new file, 311 lines)
@@ -0,0 +1,311 @@
+import dataclasses
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.multimodal import MultiModalInputs
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
+from vllm.worker.cpu_model_runner import (CPUModelRunner,
+                                          ModelInputForCPUBuilder,
+                                          ModelInputForCPUWithSamplingMetadata)
+from vllm.worker.model_runner_base import (
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+
+@dataclasses.dataclass(frozen=True)
+class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
+    """
+    Used by the EncoderDecoderModelRunner.
+    """
+    encoder_input_tokens: Optional[torch.Tensor] = None
+    encoder_input_positions: Optional[torch.Tensor] = None
+
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "encoder_input_tokens": self.encoder_input_tokens,
+            "encoder_input_positions": self.encoder_input_positions,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "EncoderDecoderModelInputForCPU":
+        return cast(
+            EncoderDecoderModelInputForCPU,
+            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
+
+
+class CPUEncoderDecoderModelRunner(CPUModelRunner):
+    _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
+        EncoderDecoderModelInputForCPU)
+    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
+
+    def _list_to_int32_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.int32, device=self.device)
+
+    def _list_to_long_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.long, device=self.device)
+
+    def _empty_int32_tensor(self) -> torch.Tensor:
+        return self._list_to_int32_tensor([])
+
+    def _empty_long_tensor(self) -> torch.Tensor:
+        return self._list_to_long_tensor([])
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str,
+                                    Any]) -> EncoderDecoderModelInputForCPU:
+        return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        )
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> EncoderDecoderModelInputForCPU:
+        model_input = super().prepare_model_input(seq_group_metadata_list,
+                                                  virtual_engine,
+                                                  finished_requests_ids)
+        model_input = cast(EncoderDecoderModelInputForCPU, model_input)
+        (
+            attn_metadata,
+            encoder_input_tokens_tensor,
+            encoder_input_positions_tensor,
+        ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
+                                                      model_input)
+        return dataclasses.replace(
+            model_input,
+            attn_metadata=attn_metadata,
+            encoder_input_tokens=encoder_input_tokens_tensor,
+            encoder_input_positions=encoder_input_positions_tensor,
+        )
+
+    def _prepare_encoder_model_input_tensors(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: EncoderDecoderModelInputForCPU,
+    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        """Helper method to prepare the encoder- and cross-attn-related
+        model inputs based on a given sequence group. These additional inputs
+        are used to augment an already-computed `EncoderDecoderModelInput`
+        data structure which already has decoder-related model inputs
+        populated.
+
+        Sets the following attn_metadata fields:
+        * `num_encoder_tokens`
+        * `encoder_seq_lens`
+        * `encoder_seq_lens_tensor`
+        * `max_encoder_seq_len`
+        * `cross_slot_mapping`
+        * `cross_block_tables`
+
+        Constructs a new model inputs data structure, based on
+        (1) the existing fields in the `model_inputs` argument,
+        and (2) the following additional fields which are
+        computed (or in the case of `attn_metadata`, updated)
+        by this function:
+        * attn_metadata
+        * encoder_input_tokens
+        * encoder_input_positions
+
+        Arguments:
+
+        * seq_group_metadata_list: list of sequence groups for which to
+          compute inputs
+        * model_inputs: model inputs data structure with decoder-oriented
+          fields already computed.
+
+        Return:
+
+        * Updated model inputs data structure
+        """
+
+        if len(seq_group_metadata_list) == 0:
+            return (model_input.attn_metadata, None, None)
+
+        # Since we are not supporting chunked prefill either the entire
+        # batch is prefill or it is decode
+        is_prompt = seq_group_metadata_list[0].is_prompt
+
+        # Build encoder inputs
+        encoder_seq_lens: List[int] = []
+        if is_prompt:
+            # Prefill phase.
+            cross_block_tables = self._empty_int32_tensor().view(
+                len(seq_group_metadata_list), -1)
+
+            # Extract input tokens/positions, cross-attention slot-mapping,
+            # & seq len from each sequence group metadata
+            (
+                encoder_input_tokens,
+                encoder_input_positions,
+                cross_slot_mapping,
+            ) = (
+                [],
+                [],
+                [],
+            )
+            for seq_group_metadata in seq_group_metadata_list:
+                # Build seq lens
+                seq_len = seq_group_metadata.encoder_seq_data.get_len()
+                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
+                encoder_seq_lens.append(seq_len)
+
+                # Build slot mapping
+                for i in range(0, seq_len):
+                    block_number = seq_group_metadata.cross_block_table[
+                        i // self.block_size]
+                    block_offset = i % self.block_size
+                    slot = block_number * self.block_size + block_offset
+                    cross_slot_mapping.append(slot)
+
+                # Build encoder input tokens
+                encoder_input_tokens.extend(token_ids)
+                encoder_input_positions.extend(list(range(0, seq_len)))
+
+            # Convert tokens/positions & cross-attention
+            # slot-mapping to encoder input tensors
+            encoder_input_tokens_tensor = self._list_to_long_tensor(
+                encoder_input_tokens)
+            encoder_input_positions_tensor = self._list_to_long_tensor(
+                encoder_input_positions)
+            cross_slot_mapping_tensor = self._list_to_long_tensor(
+                cross_slot_mapping)
+
+        else:
+            # Decode phase.
+            encoder_input_tokens_tensor = self._empty_long_tensor()
+            encoder_input_positions_tensor = self._empty_long_tensor()
+            cross_slot_mapping_tensor = self._empty_long_tensor()
+            # Extract cross-attention block tables &
+            # seq len from each sequence group metadata.
+            # Cross-attention block tables are empty
+            # during vLLM memory profiling.
+            cross_block_tables = []
+            for seq_group_metadata in seq_group_metadata_list:
+                for _ in range(len(seq_group_metadata.seq_data)):
+                    encoder_seq_lens.append(
+                        seq_group_metadata.encoder_seq_data.get_len())
+                cross_block_table = seq_group_metadata.cross_block_table
+                cross_block_tables.append([] if (
+                    cross_block_table is None) else cross_block_table)
+
+            max_len_of_block_table = max(
+                len(block_table) for block_table in cross_block_tables)
+
+            cross_block_tables = make_tensor_with_pad(
+                cross_block_tables,
+                max_len=max_len_of_block_table,
+                pad=0,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
+        # Compute encoder sequence lengths & encoder
+        # sequence starting offset tensors
+        max_encoder_seq_len = max(encoder_seq_lens, default=0)
+        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
+        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
+                                            1,
+                                            dtype=torch.int32,
+                                            device=self.device)
+        torch.cumsum(encoder_seq_lens_tensor,
+                     dim=0,
+                     dtype=encoder_seq_start_loc.dtype,
+                     out=encoder_seq_start_loc[1:])
+
+        # Update attention metadata with encoder-oriented attributes
+        attn_metadata = model_input.attn_metadata
+        assert attn_metadata is not None
+        (
+            attn_metadata.num_encoder_tokens,
+            attn_metadata.encoder_seq_lens,
+            attn_metadata.encoder_seq_lens_tensor,
+            attn_metadata.max_encoder_seq_len,
+            attn_metadata.cross_slot_mapping,
+            attn_metadata.cross_block_tables,
+        ) = (
+            sum(encoder_seq_lens),
+            encoder_seq_lens,
+            encoder_seq_lens_tensor,
+            max_encoder_seq_len,
+            cross_slot_mapping_tensor,
+            cross_block_tables,
+        )
+
+        return (attn_metadata, encoder_input_tokens_tensor,
+                encoder_input_positions_tensor)
+
+    @torch.no_grad()
+    def execute_model(
+        self,
+        model_input: EncoderDecoderModelInputForCPU,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
+        model_executable = self.model
+        execute_model_kwargs = {
+            "input_ids":
+            model_input.input_tokens,
+            "positions":
+            model_input.input_positions,
+            "encoder_input_ids":
+            model_input.encoder_input_tokens,
+            "encoder_positions":
+            model_input.encoder_input_positions,
+            "kv_caches":
+            kv_caches,
+            "attn_metadata":
+            model_input.attn_metadata,
+            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
+                                         device=self.device),
+            "intermediate_tensors":
+            intermediate_tensors,
+        }
+
+        hidden_states = model_executable(**execute_model_kwargs)
+
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)
+
+        # Only perform sampling in the driver worker.
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+        return [output]
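The cross-attention slot mapping built during prefill is the usual paged-KV arithmetic, driven by the encoder length and the sequence group's cross_block_table. The same loop with toy values (block numbers are made up):

# Toy version of the cross-slot-mapping loop in
# _prepare_encoder_model_input_tensors.
block_size = 4
cross_block_table = [7, 2]      # physical blocks assigned to this sequence
encoder_seq_len = 6

cross_slot_mapping = []
for i in range(encoder_seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)

assert cross_slot_mapping == [28, 29, 30, 31, 8, 9]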
@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model

-        if self.model_config.is_encoder_decoder_model:
-            raise NotImplementedError(
-                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
-
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
     def make_model_input_from_broadcasted_tensor_dict(
         self,
         tensor_dict: Dict[str, Any],
-    ) -> ModelInputForCPU:
-        return ModelInputForCPU.from_broadcasted_tensor_dict(
+    ) -> ModelInputForCPUWithSamplingMetadata:
+        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
             tensor_dict,
             attn_backend=self.attn_backend,
         )
@@ -1,5 +1,5 @@
 """A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerInput)
@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]

-        self.model_runner: CPUModelRunner = CPUModelRunner(
+        ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
+        if self._is_encoder_decoder_model():
+            ModelRunnerClass = CPUEncoderDecoderModelRunner
+        self.model_runner: CPUModelRunner = ModelRunnerClass(
             model_config,
             parallel_config,
             scheduler_config,
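The worker keeps a single model_runner attribute and only swaps the class up front, so the rest of CPUWorker is untouched. The selection pattern in isolation (stand-in classes, not the real runners):

class CPUModelRunner: ...
class CPUEncoderDecoderModelRunner(CPUModelRunner): ...

def pick_runner_cls(is_encoder_decoder: bool) -> type:
    # Same shape as the worker code above: default to the decoder-only
    # runner, upgrade to the encoder/decoder runner when needed.
    ModelRunnerClass = CPUModelRunner
    if is_encoder_decoder:
        ModelRunnerClass = CPUEncoderDecoderModelRunner
    return ModelRunnerClass

assert pick_runner_cls(True) is CPUEncoderDecoderModelRunner
assert issubclass(pick_runner_cls(True), CPUModelRunner)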
@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()

+    def _is_encoder_decoder_model(self):
+        return self.model_config.is_encoder_decoder_model
+
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)