Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 18:25:40 +08:00)
[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

commit 26e673fe93 (parent 65a5910ce3)
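Most of the changes below are mechanical import moves that follow from deleting the V0 `Sequence` class and `Sampler`: call sites that pulled V0 types from `vllm.sequence`, `vllm.transformers_utils.detokenizer`, or `vllm.model_executor.layers.sampler` now import the equivalents that remain after this commit. A minimal before/after sketch of the moves seen in the hunks below (illustrative only; the diff itself is authoritative):

```python
# Before this commit (V0 paths that are removed or emptied):
#   from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
#   from vllm.transformers_utils.detokenizer import detokenize_incrementally
#   from vllm.model_executor.layers.sampler import SamplerOutput

# After this commit (paths used by the updated files):
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.v1.outputs import SamplerOutput
```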
@@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               initialize_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.multimodal.utils import fetch_image
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.sequence import Logprob
 from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import set_default_torch_num_threads

@@ -7,8 +7,8 @@ from typing import Optional
 import pytest
 from transformers import AutoModelForSpeechSeq2Seq

+from vllm.logprobs import SampleLogprobs
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SampleLogprobs

 from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput,
                           VllmRunner)
@@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

 from vllm.assets.image import ImageAsset
+from vllm.logprobs import SampleLogprobs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import convert_image_mode, rescale_image_size
 from vllm.platforms import current_platform
-from vllm.sequence import SampleLogprobs

 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
                           PromptImageInput, VllmRunner)
@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 from transformers import AutoProcessor

 from vllm import SamplingParams, TextPrompt, TokensPrompt
+from vllm.logprobs import Logprob, SampleLogprobs
 from vllm.multimodal import MultiModalDataBuiltins
-from vllm.sequence import Logprob, SampleLogprobs

 from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
@@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
                           GenerationConfig, GenerationMixin)
 from transformers.video_utils import VideoMetadata

-from vllm.sequence import SampleLogprobs
+from vllm.logprobs import SampleLogprobs
 from vllm.utils import is_list_of

 from .....conftest import HfRunner, ImageAsset, ImageTestAssets
@@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM
 from transformers.models.auto.auto_factory import _BaseAutoModelClass

 from vllm.config import RunnerOption
-from vllm.sequence import SampleLogprobs
+from vllm.logprobs import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
@@ -12,7 +12,7 @@ from transformers import PretrainedConfig

 from vllm.config import ModelConfig, ModelDType, RunnerOption
 from vllm.inputs import InputContext
-from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs

 from .registry import HF_EXAMPLE_MODELS

@@ -8,10 +8,7 @@ import pytest
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

-from vllm.inputs import token_inputs
-from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast):

     assert decoded_text == ''
     assert out_ids == [len(tokenizer)]
-
-
-@pytest.fixture
-def detokenizer(tokenizer_name: str) -> Detokenizer:
-    tokenizer = get_tokenizer(
-        tokenizer_name,
-        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
-        trust_remote_code=False,
-        revision=None,
-    )
-
-    return Detokenizer(tokenizer)
-
-
-@pytest.fixture(name="complete_sequence_token_ids")
-def create_complete_sequence_token_ids(complete_sequence: str,
-                                       tokenizer) -> list[int]:
-    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
-
-
-def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or []
-    return Sequence(
-        seq_id=0,
-        inputs=token_inputs(prompt_token_ids),
-        block_size=16,
-    )
-
-
-def create_dummy_logprobs(
-        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
-    return [{
-        token_id: Logprob(logprob=0.0),
-        token_id + 1: Logprob(logprob=0.1)
-    } for token_id in complete_sequence_token_ids]
-
-
-def create_dummy_prompt_logprobs(
-        complete_sequence_token_ids: list[int]
-) -> list[Optional[dict[int, Any]]]:
-    # logprob for the first prompt token is None.
-    logprobs: list[Optional[dict[int, Any]]] = [None]
-    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
-    return logprobs
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
-def test_decode_sequence_logprobs(complete_sequence: str,
-                                  complete_sequence_token_ids: list[int],
-                                  detokenizer: Detokenizer,
-                                  skip_special_tokens: bool):
-    """Verify Detokenizer decodes logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     logprobs=2)
-
-    # Run sequentially.
-    seq = create_sequence()
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    sequential_logprobs_text_chosen_token: list[str] = []
-    sequential_logprobs_text_other_token: list[str] = []
-    for new_token, logprobs in zip(complete_sequence_token_ids,
-                                   dummy_logprobs):
-        seq.append_token_id(new_token, logprobs)
-        detokenizer.decode_sequence_inplace(seq, sampling_params)
-        sequential_logprobs_text_chosen_token.append(
-            seq.output_logprobs[-1][new_token].decoded_token)
-        sequential_logprobs_text_other_token.append(
-            seq.output_logprobs[-1][new_token + 1].decoded_token)
-    sequential_result = seq.output_text
-
-    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
-    assert sequential_result != "".join(sequential_logprobs_text_other_token)
-
-    if not skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # generated text. Note that this will only be true if we skip
-        # special tokens.
-        assert sequential_result == complete_sequence
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: list[int],
-                                detokenizer: Detokenizer):
-
-    # We want to use skip_special_tokens=False here but Mistral tokenizers
-    # don't support that.
-    if complete_sequence not in SPECIAL_TOKS_TRUTH:
-        skip_special_tokens = True
-    elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
-        skip_special_tokens = False
-    else:
-        pytest.skip("MistralTokenizers don't support "
-                    "skip_special_tokens=False")
-        return
-    """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     prompt_logprobs=1)
-
-    # Run sequentially.
-    seq = create_sequence(complete_sequence_token_ids)
-    seq_group = SequenceGroup(request_id="1",
-                              seqs=[seq],
-                              sampling_params=sampling_params,
-                              arrival_time=0.0)
-    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group,
-                                               dummy_logprobs,
-                                               position_offset=0)
-    # First logprob is None.
-    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
-        1:]  # type: ignore
-
-    # decoded_prompt_logprobs doesn't contain the first token.
-    token_ids = complete_sequence_token_ids
-    tokenizer = detokenizer.tokenizer
-    text_full = tokenizer.decode(token_ids,
-                                 skip_special_tokens=skip_special_tokens)
-    text_first = tokenizer.decode(token_ids[0],
-                                  skip_special_tokens=skip_special_tokens)
-    text = text_full[len(text_first):]
-
-    # Text for logprobs for the chosen token should be the same as the
-    # prompt text. Note that the first logprob is None.
-    assert text == "".join([
-        logprobs[token_id].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
-    assert text != "".join([
-        logprobs[token_id + 1].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
@@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow
 from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import JambaToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 MODEL = "ai21labs/Jamba-tiny-dev"
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
     Qwen3CoderToolParser)
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 # Use a common model that is likely to be available
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 # Use a common model that is likely to be available
@@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                    STOP_STRINGS,
                                    DummyOutputProcessorTestVectors,
                                    MockEngineCore)
+from vllm.logprobs import PromptLogprobs, SampleLogprobs
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.sequence import PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import (OutputProcessor,
@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
 from vllm.tasks import SupportedTask
 from vllm.utils import make_async
+from vllm.v1.outputs import SamplerOutput
 from vllm.worker.worker_base import WorkerBase

 logger = init_logger(__name__)
@@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook
 from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                      ray)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                         get_ip, get_open_port, make_async)
+from vllm.v1.outputs import SamplerOutput

 if ray is not None:
     from ray.actor import ActorHandle
@@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                    SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
                    build_explicit_enc_dec_prompt, embeds_inputs,
                    to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
-from .registry import (DummyData, InputContext, InputProcessingContext,
-                       InputRegistry)
-
-INPUT_REGISTRY = InputRegistry()
-"""
-The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
-by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
-target model.
-"""
+from .registry import InputContext, InputProcessingContext

 __all__ = [
     "DataPrompt",
@@ -36,9 +28,6 @@ __all__ = [
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
-    "INPUT_REGISTRY",
-    "DummyData",
     "InputContext",
     "InputProcessingContext",
-    "InputRegistry",
 ]
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 import torch
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
@@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves

 if TYPE_CHECKING:
     from vllm.config import ModelConfig
-    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
-                                 MultiModalRegistry)
-    from vllm.sequence import SequenceData
     from vllm.transformers_utils.tokenizer import AnyTokenizer
 else:
     ModelConfig = Any
-    MultiModalDataDict = Any
-    MultiModalPlaceholderDict = Any
-    MultiModalRegistry = Any
-    SequenceData = Any
     AnyTokenizer = Any

 _T = TypeVar("_T")
@@ -191,61 +184,3 @@ class InputProcessingContext(InputContext):
                    f"on data={data} with kwargs={allowed_kwargs}")

             raise ValueError(msg) from exc
-
-
-class DummyData(NamedTuple):
-    """
-    Dummy data used for profiling.
-
-    Note: This is only used in V0.
-    """
-
-    seq_data: SequenceData
-    multi_modal_data: Optional[MultiModalDataDict] = None
-    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
-
-
-class InputRegistry:
-    """
-    Note: This is only used in V0.
-    """
-
-    def dummy_data_for_profiling(
-        self,
-        model_config: ModelConfig,
-        seq_len: int,
-        mm_registry: MultiModalRegistry,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        """
-        Create dummy data for profiling the memory usage of a model.
-
-        The model is identified by ``model_config``.
-        """
-        # Avoid circular import
-        from vllm.multimodal.cache import processor_only_cache_from_config
-        from vllm.sequence import SequenceData
-
-        if not model_config.is_multimodal_model:
-            seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
-            return DummyData(seq_data=seq_data)
-
-        cache = processor_only_cache_from_config(model_config, mm_registry)
-
-        # Encoder dummy data does not contain multi-modal data
-        if is_encoder_data:
-            enc_data = mm_registry.get_encoder_dummy_data(model_config,
-                                                          seq_len,
-                                                          cache=cache)
-            seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
-            return DummyData(seq_data=seq_data)
-
-        dec_data = mm_registry.get_decoder_dummy_data(model_config,
-                                                      seq_len,
-                                                      cache=cache)
-
-        return DummyData(
-            seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
-            multi_modal_data=dec_data.multi_modal_data.get_data(),
-            multi_modal_placeholders=dec_data.multi_modal_placeholders,
-        )
@@ -3,13 +3,11 @@

 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            PackedvLLMParameter)
-from vllm.model_executor.sampling_metadata import (SamplingMetadata,
-                                                   SamplingMetadataCache)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed

 __all__ = [
     "SamplingMetadata",
-    "SamplingMetadataCache",
     "set_random_seed",
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -1,13 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A layer that compute logits from hidden_stats."""
-import inspect
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 import torch

-import vllm.envs as envs
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.custom_op import CustomOp
@@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform

-_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
-if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
-    _logits_processor_threadpool = ThreadPoolExecutor(
-        envs.VLLM_LOGITS_PROCESSOR_THREADS)
-
-
 @CustomOp.register("logits_processor")
 class LogitsProcessor(CustomOp):
@@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp):
         hidden_states: torch.Tensor,
         sampling_metadata: Optional[SamplingMetadata] = None,
         embedding_bias: Optional[torch.Tensor] = None,
-        prune_hidden_states: bool = True,
     ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
-            if sampling_metadata is not None and prune_hidden_states:
-                hidden_states = _prune_hidden_states(hidden_states,
-                                                     sampling_metadata)
-
             # Get the logits for the next tokens.
             logits = self._get_logits(hidden_states, lm_head, embedding_bias)
             if logits is not None:
@@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp):

         if self.scale != 1.0:
             logits *= self.scale
-
-        # Apply logits processors (if any).
-        if sampling_metadata is not None and \
-                sampling_metadata.seq_groups is not None:
-            logits = _apply_logits_processors(logits, sampling_metadata)
-
         return logits

     def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
@@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp):
         s += f", org_vocab_size={self.org_vocab_size}"
         s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
         return s
-
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
-    # (warmup, profile_run) we might not have selected_token_indices,
-    # so we skip pruning.
-    if sampling_metadata.selected_token_indices is not None:
-        return hidden_states.index_select(
-            0, sampling_metadata.selected_token_indices)
-    else:
-        return hidden_states
-
-
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    found_logits_processors = False
-    logits_processed = 0
-    logits_row_ids_and_logits_row_futures = []
-    for seq_group in sampling_metadata.seq_groups:
-        seq_ids = seq_group.seq_ids
-        sampling_params = seq_group.sampling_params
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-
-            for seq_id, logits_row_idx in zip(seq_ids,
-                                              seq_group.sample_indices):
-                logits_row = logits[logits_row_idx]
-                past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
-                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
-
-                if _logits_processor_threadpool is not None:
-                    logits_row_ids_and_logits_row_futures.append(
-                        (logits_row_idx,
-                         _logits_processor_threadpool.submit(
-                             _apply_logits_processors_single_seq, logits_row,
-                             logits_processors, past_tokens_ids,
-                             prompt_tokens_ids)))
-                else:
-                    logits[logits_row_idx] = \
-                        _apply_logits_processors_single_seq(
-                            logits_row, logits_processors, past_tokens_ids,
-                            prompt_tokens_ids)
-
-        logits_processed += len(seq_group.sample_indices) + len(
-            seq_group.prompt_logprob_indices)
-
-    for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
-        logits[logits_row_idx] = future.result()
-
-    if found_logits_processors:
-        # verifies that no rows in logits were missed unexpectedly
-        assert logits_processed == logits.shape[0]
-    return logits
-
-
-def _apply_logits_processors_single_seq(logits_row, logits_processors,
-                                        past_tokens_ids,
-                                        prompt_tokens_ids) -> torch.Tensor:
-    for logits_processor in logits_processors:
-        parameters = inspect.signature(logits_processor).parameters
-        if len(parameters) == 3:
-            logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
-                                          logits_row)
-        else:
-            logits_row = logits_processor(past_tokens_ids, logits_row)
-    return logits_row
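For context on the helpers deleted above: under V0, `SamplingParams.logits_processors` held plain callables that `_apply_logits_processors_single_seq` invoked with either `(past_token_ids, logits)` or `(prompt_token_ids, past_token_ids, logits)`, dispatching on the callable's arity. A minimal sketch of that callable shape, which loses its V0 dispatch path with this commit (the `ban_token_zero` function is hypothetical, not a vLLM API):

```python
import torch


def ban_token_zero(past_token_ids: list[int],
                   logits: torch.Tensor) -> torch.Tensor:
    # Two-argument V0-style logits processor: mask out token id 0
    # by driving its logit to -inf before sampling.
    logits[0] = float("-inf")
    return logits
```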
(File diff suppressed because it is too large.)
@@ -2,18 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable
-from typing import Optional

 import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from .utils import maybe_prefix

@@ -105,8 +102,10 @@ class Medusa(nn.Module):
         return [block(hidden_states) for block in self.blocks]

     def compute_logits(
-            self, hidden_states: list[torch.Tensor],
-            sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
+        self,
+        hidden_states: list[torch.Tensor],
+        sampling_metadata,
+    ) -> list[torch.Tensor]:
         logits_lst: list[torch.Tensor] = []

         for hs, lm_head in zip(hidden_states, self.lm_heads):
@@ -130,57 +129,6 @@ class Medusa(nn.Module):

         return logits_lst

-    def sample(
-        self,
-        logits: list[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        logits = torch.stack(logits, dim=0).float()
-        logprobs = torch.log_softmax(logits, dim=-1)
-        token_ids = logits.argmax(-1)  # support only top-1 for now
-        probs = torch.softmax(logits, dim=-1)
-
-        token_id_list = []
-        token_prob_list = []
-        token_logprob_list = []
-
-        for idx, seq_group in enumerate(sampling_metadata.seq_groups):
-            token_id_list.append(token_ids[:, seq_group.sample_indices])
-            token_prob_list.append(probs[:, seq_group.sample_indices])
-            token_logprob_list.append(logprobs[:, seq_group.sample_indices])
-
-        outputs: list[Optional[SamplerOutput]] = []
-        for idx in range(len(sampling_metadata.seq_groups)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_prob_list[idx].squeeze(1),
-                    logprobs=token_logprob_list[idx].squeeze(1),
-                    sampled_token_ids=token_id_list[idx].squeeze(1),
-                ))
-
-        return outputs
-
-    def generate_proposals(
-        self,
-        previous_hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[list[SamplerOutput]]:
-        # During preemption, we may receive an empty tensor (batch_size=0)
-        if previous_hidden_states.size(0) == 0:
-            # Return None to signal the Top1Proposer that no proposals
-            # were generated for this batch, allowing it to handle this
-            # special case appropriately
-            return None
-
-        return self.sample(
-            logits=self.compute_logits(
-                hidden_states=self.forward(previous_hidden_states),
-                sampling_metadata=sampling_metadata,
-            ),
-            sampling_metadata=sampling_metadata,
-        )
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -8,9 +8,7 @@ import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module):
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 config.vocab_size, 1.0)
-        self.sampler = get_sampler()

-    def generate_proposals(
-        self,
-        input_ids: torch.Tensor,
-        previous_hidden_states: torch.Tensor,
-        num_predict_tokens: int,
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        if num_predict_tokens > self.max_speculative_tokens:
-            raise ValueError(f"Max speculative tokens for model is "
-                             f"{self.max_speculative_tokens}, but "
-                             f"{num_predict_tokens} were requested")
-
-        # b x 1 x d
-        previous_hidden_states = previous_hidden_states.unsqueeze(1)
-
-        if self.scale_input:
-            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
-
-        # b x 1
-        last_tokens = input_ids.unsqueeze(1)
-
-        next_tokens = []
-
-        for head_index in range(num_predict_tokens):
-
-            # Project and predict
-            z = self.emb[head_index](last_tokens)  # b k d
-            states = self.proj[head_index](previous_hidden_states)
-
-            # Weighted add of state_weight*state and emb_weight*z
-            # Let subsequent LN take care of denominator
-            # state_weight is close to 1, so shouldn't be any precision issues
-            states.add_(z, alpha=self.emb_weight / self.state_weight)
-
-            states = self.activation(self.ln[head_index](states))  # b k d
-            previous_hidden_states = states
-            # TODO: not yet supporting top_k_tokens_per_head
-            states = states.flatten(0, 1)
-
-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
-
-            output = self.sampler(logits, sampling_metadata)
-            last_tokens = output.sampled_token_ids
-            next_tokens.append(output)
-
-        return next_tokens
+    # NOTE(woosuk): This method is commented out because it is old code
+    # using V0. We should either port it to V1 or remove it.
+
+    # def generate_proposals(
+    #     self,
+    #     input_ids: torch.Tensor,
+    #     previous_hidden_states: torch.Tensor,
+    #     num_predict_tokens: int,
+    #     sampling_metadata: SamplingMetadata,
+    # ) -> list[SamplerOutput]:
+    #     if num_predict_tokens > self.max_speculative_tokens:
+    #         raise ValueError(f"Max speculative tokens for model is "
+    #                          f"{self.max_speculative_tokens}, but "
+    #                          f"{num_predict_tokens} were requested")
+
+    #     # b x 1 x d
+    #     previous_hidden_states = previous_hidden_states.unsqueeze(1)
+
+    #     if self.scale_input:
+    #         previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+
+    #     # b x 1
+    #     last_tokens = input_ids.unsqueeze(1)
+
+    #     next_tokens = []
+
+    #     for head_index in range(num_predict_tokens):
+
+    #         # Project and predict
+    #         z = self.emb[head_index](last_tokens)  # b k d
+    #         states = self.proj[head_index](previous_hidden_states)
+
+    #         # Weighted add of state_weight*state and emb_weight*z
+    #         # Let subsequent LN take care of denominator
+    #         # state_weight is close to 1, so shouldn't be any precision issues
+    #         states.add_(z, alpha=self.emb_weight / self.state_weight)
+
+    #         states = self.activation(self.ln[head_index](states))  # b k d
+    #         previous_hidden_states = states
+    #         # TODO: not yet supporting top_k_tokens_per_head
+    #         states = states.flatten(0, 1)
+
+    #         logits = self.logits_processor(self.head[head_index], states,
+    #                                        sampling_metadata)
+
+    #         output = self.sampler(logits, sampling_metadata)
+    #         last_tokens = output.sampled_token_ids
+    #         next_tokens.append(output)
+
+    #     return next_tokens

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        # If the shape is the same, it means that we have already
-        # prune hidden states manually.
-        prune_hidden_states = hidden_states.size(
-            0) != sampling_metadata.selected_token_indices.size(0)
         processed_logits = self.logits_processor(
             self.lm_head,
             hidden_states,
             sampling_metadata,
             self.embedding_bias,
-            prune_hidden_states=prune_hidden_states)
+        )
         return processed_logits

     def load_weights(
@ -1,597 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from array import array
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
|
||||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
|
|
||||||
SequenceGroupMetadata)
|
|
||||||
from vllm.utils import (PyObjectCache, async_tensor_h2d,
|
|
||||||
is_pin_memory_available, make_tensor_with_pad)
|
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SequenceGroupToSample:
|
|
||||||
# |---------- N-1 iteration --------|
|
|
||||||
# |---------------- N iteration ---------------------|
|
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
|
||||||
# |---------- context_len ----------|
|
|
||||||
# |-------------------- seq_len ----------------------|
|
|
||||||
# |-- query_len ---|
|
|
||||||
|
|
||||||
# Sequence ids for the sequence group in a previous step.
|
|
||||||
seq_ids: list[int]
|
|
||||||
sampling_params: SamplingParams
|
|
||||||
# seq_id -> sequence data.
|
|
||||||
seq_data: dict[int, SequenceData]
|
|
||||||
# The length of the sequence (all tokens seen in the past + new token to
|
|
||||||
# compute attention) of the sequence group. None if it is in a decode
|
|
||||||
# stage.
|
|
||||||
seq_len: Optional[int]
|
|
||||||
# The length of new query tokens to compute in the current step. None if it
|
|
||||||
# is in a decode stage. The length of query_len <= seq_len if chunked
|
|
||||||
# prefill is enabled.
|
|
||||||
query_len: Optional[int]
|
|
||||||
# A random number generator for sampling.
|
|
||||||
generator: Optional[torch.Generator]
|
|
||||||
# True if the sequence group is in prefill stage. False if it is in a
|
|
||||||
# decode stage.
|
|
||||||
is_prompt: bool
|
|
||||||
# Query token indices from logits. to compute prompt logprob. Empty if
|
|
||||||
# prompt logprob is not required.
|
|
||||||
prompt_logprob_indices: list[int]
|
|
||||||
# Sample token indices from logits. Empty if sampling is not required.
|
|
||||||
sample_indices: list[int]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def do_sample(self):
|
|
||||||
return len(self.sample_indices) > 0
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if len(self.prompt_logprob_indices) > 0:
|
|
||||||
assert self.sampling_params.prompt_logprobs is not None
|
|
||||||
if self.is_prompt:
|
|
||||||
assert self.seq_len is not None
|
|
||||||
assert self.query_len is not None
|
|
||||||
|
|
||||||
|
|
||||||
def gen_seq_group_to_sample_builder(num_seqs: int):
|
|
||||||
return lambda: SequenceGroupToSample(
|
|
||||||
seq_ids=[0] * num_seqs,
|
|
||||||
sampling_params=None,
|
|
||||||
seq_data=None, # type: ignore
|
|
||||||
seq_len=0,
|
|
||||||
query_len=0,
|
|
||||||
generator=None,
|
|
||||||
is_prompt=True,
|
|
||||||
prompt_logprob_indices=[],
|
|
||||||
sample_indices=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SamplingMetadataCache:
|
|
||||||
"""Used to cache SamplingMetadata objects between scheduler iterations"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}
|
|
||||||
|
|
||||||
def get_cached_seq_group_to_sample(self, num_seqs):
|
|
||||||
if num_seqs not in self._seq_group_to_sample_cache:
|
|
||||||
self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
|
|
||||||
gen_seq_group_to_sample_builder(num_seqs))
|
|
||||||
|
|
||||||
obj = self._seq_group_to_sample_cache[num_seqs].get_object()
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
for cache in self._seq_group_to_sample_cache.values():
|
|
||||||
cache.reset()
|
|
||||||
|
|
||||||
|
|
||||||
class SamplingMetadata:
|
class SamplingMetadata:
|
||||||
"""Metadata for input sequences. Used in sampler.
|
# Placeholder until it can be safely removed.
|
||||||
|
pass
|
||||||
The usage is as follows;
|
|
||||||
```
|
|
||||||
hidden_states = execute_model(...)
|
|
||||||
logits = hidden_states[sampling_metadata.selected_token_indices]
|
|
||||||
sample(logits)
|
|
||||||
|
|
||||||
def sample(logits):
|
|
||||||
# Use categorized_sample_indices for sampling....
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_groups: List of batched sequence groups.
|
|
||||||
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
|
|
||||||
logits from the initial model output hidden states.
|
|
||||||
categorized_sample_indices: SamplingType -> token indices to sample.
|
|
||||||
Each token indices is 2D tensor of (num_indices, num_indices) where
|
|
||||||
the first item means the sample index within the returned logit
|
|
||||||
(before pruning padding), and the second item means the sample
|
|
||||||
index after pruning using selected_token_indices.
|
|
||||||
For example, if the returned logit is [1, 2, 3], and we select
|
|
||||||
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
|
|
||||||
The first tuple is [1, 2] (sampled index within original logit),
|
|
||||||
and the second tuple is [0, 1] (sampled index within pruned logit).
|
|
||||||
num_prompts: Number of prompt sequence groups in seq_groups.
|
|
||||||
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
|
|
||||||
serialization of token outputs.
|
|
||||||
reuse_sampling_tensors: Indicates if we want to reuse sampling
|
|
||||||
tensors that are part of the sampler forward pass. Currently,
|
|
||||||
it is mainly used for multi-step decode.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
seq_groups: list[SequenceGroupToSample],
|
|
||||||
selected_token_indices: torch.Tensor,
|
|
||||||
categorized_sample_indices: dict[SamplingType, torch.Tensor],
|
|
||||||
num_prompts: int,
|
|
||||||
skip_sampler_cpu_output: bool = False,
|
|
||||||
reuse_sampling_tensors: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self.seq_groups = seq_groups
|
|
||||||
self.selected_token_indices = selected_token_indices
|
|
||||||
self.categorized_sample_indices = categorized_sample_indices
|
|
||||||
self.num_prompts = num_prompts
|
|
||||||
self.skip_sampler_cpu_output = skip_sampler_cpu_output
|
|
||||||
self.reuse_sampling_tensors = reuse_sampling_tensors
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prepare(
|
|
||||||
seq_group_metadata_list: list[SequenceGroupMetadata],
|
|
||||||
seq_lens: list[int],
|
|
||||||
query_lens: list[int],
|
|
||||||
device: str,
|
|
||||||
pin_memory: bool,
|
|
||||||
generators: Optional[dict[str, torch.Generator]] = None,
|
|
||||||
cache: Optional[SamplingMetadataCache] = None,
|
|
||||||
) -> "SamplingMetadata":
|
|
||||||
(
|
|
||||||
seq_groups,
|
|
||||||
selected_token_indices,
|
|
||||||
categorized_sample_indices,
|
|
||||||
num_prompts,
|
|
||||||
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
|
|
||||||
device, generators, cache)
|
|
||||||
selected_token_indices = async_tensor_h2d(
|
|
||||||
selected_token_indices,
|
|
||||||
dtype=torch.long,
|
|
||||||
target_device=device,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
)
|
|
||||||
categorized_sample_indices = {
|
|
||||||
t:
|
|
||||||
async_tensor_h2d(
|
|
||||||
seq_ids,
|
|
||||||
dtype=torch.int,
|
|
||||||
target_device=device,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
)
|
|
||||||
for t, seq_ids in categorized_sample_indices.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
sampling_metadata = SamplingMetadata(
|
|
||||||
seq_groups=seq_groups,
|
|
||||||
selected_token_indices=selected_token_indices,
|
|
||||||
categorized_sample_indices=categorized_sample_indices,
|
|
||||||
num_prompts=num_prompts,
|
|
||||||
)
|
|
||||||
return sampling_metadata
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
"SamplingMetadata("
|
|
||||||
f"seq_groups={self.seq_groups}, "
|
|
||||||
f"selected_token_indices={self.selected_token_indices}, "
|
|
||||||
f"categorized_sample_indices={self.categorized_sample_indices})")
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_seq_groups(
|
|
||||||
seq_group_metadata_list: list[SequenceGroupMetadata],
|
|
||||||
seq_lens: list[int],
|
|
||||||
query_lens: list[int],
|
|
||||||
device: str,
|
|
||||||
generators: Optional[dict[str, torch.Generator]] = None,
|
|
||||||
cache: Optional[SamplingMetadataCache] = None,
|
|
||||||
) -> tuple[
|
|
||||||
list[SequenceGroupToSample],
|
|
||||||
list[int],
|
|
||||||
dict[SamplingType, list[int]],
|
|
||||||
int,
|
|
||||||
]:
|
|
||||||
"""Prepare sequence groups and indices for sampling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_group_metadata_list: A list of sequence group to batch.
|
|
||||||
seq_lens: A list of sequence lens per sequence group.
|
|
||||||
Index of prompt len should match with seq_group_metadata_list.
|
|
||||||
query_lens: A list of query lengths. Prompt lens include the length
|
|
||||||
of entire prompt tokens, and it could be shorter.
|
|
||||||
device: A device to use for random number generators,
|
|
||||||
`SequenceGroupToSample.generator`.
|
|
||||||
generators: A store of per-request random number generators used
|
|
||||||
for seeded requests.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
seq_groups: A list of sequence group to sample.
|
|
||||||
selected_token_indices: See the definition from `SamplingMetadata`.
|
|
||||||
categorized_sample_indices: See the definition from `SamplingMetadata`.
|
|
||||||
num_prompts: Total number of prompts from `seq_group_metadata_list`.
|
|
||||||
"""
|
|
||||||
# Batched sequence groups for the current model forward stsep.
|
|
||||||
seq_groups: list[SequenceGroupToSample] = []
|
|
||||||
# A list of token indices to sample/compute logprob. It is used to
|
|
||||||
# prune the outcome logits from the model for the performance.
|
|
||||||
selected_token_indices: list[int] = []
|
|
||||||
# Used for selected_token_indices.
|
|
||||||
model_output_idx = 0
|
|
||||||
|
|
||||||
# Sampling type -> (
|
|
||||||
# indices to sample/prompt logprob within pruned output logits,
|
|
||||||
# indices to sample within pruned logits)
|
|
||||||
categorized_sample_indices: dict[SamplingType, list[int]] = {
|
|
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: list[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)

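To make the index bookkeeping above concrete, here is a minimal, self-contained sketch (toy shapes and a single chunked-prefill group are assumptions for illustration, not taken from the diff) of how `selected_token_indices`, `prompt_logprob_indices`, and `sample_indices` slice the model output:

import torch

# Assumed toy batch: one prefill group with query_len = 4, prompt_logprobs
# enabled and do_sample=True, so prompt_logprob_len = 3 and sample_len = 1.
hidden_states = torch.randn(4, 8)              # [num_scheduled_tokens, hidden_size]
selected_token_indices = [0, 1, 2, 3]          # rows that need logits at all
logits = hidden_states[torch.tensor(selected_token_indices)]

# Within those logits, the per-group lists separate the two roles.
prompt_logprob_indices = [0, 1, 2]             # rows used only for prompt logprobs
sample_indices = [3]                           # row actually sampled from
prompt_logprob_logits = logits[torch.tensor(prompt_logprob_indices)]
sample_logits = logits[torch.tensor(sample_indices)]
assert sample_logits.shape == (1, 8)
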
@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple["SamplingTensors", bool, bool, bool]:
        prompt_tokens: list[array] = []
        output_tokens: list[array] = []
        top_ks: list[int] = []
        temperatures: list[float] = []
        top_ps: list[float] = []
        min_ps: list[float] = []
        presence_penalties: list[float] = []
        frequency_penalties: list[float] = []
        repetition_penalties: list[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k < 1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            is_prompt = seq_group.is_prompt
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # For prompt tokens for which we only need logprobs.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens >= len(seq_ids)
                temperatures += [temperature] * sample_lens
                top_ps += [top_p] * sample_lens
                top_ks += [top_k] * sample_lens
                min_ps += [min_p] * sample_lens
                presence_penalties += [p] * sample_lens
                frequency_penalties += [f] * sample_lens
                repetition_penalties += [r] * sample_lens

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

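The three boolean flags returned alongside the tensors exist so the sampler can skip whole stages for batches that do not need them. A rough, self-contained sketch of that gating idea (toy tensors and an inline per-row top-k mask; the real sampler helpers are not shown or named here):

import torch

logits = torch.randn(3, 1000)                  # [num_tokens, vocab_size]
temperatures = torch.tensor([1.0, 0.7, 0.7])
top_ks = torch.tensor([1000, 50, 50])          # top_k == vocab_size means "off"
do_top_p_top_k = bool((top_ks != logits.shape[1]).any())

logits = logits / temperatures.unsqueeze(dim=1)
if do_top_p_top_k:
    # Per-row top-k: mask everything strictly below each row's k-th largest logit.
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    kth_values = sorted_logits.gather(-1, (top_ks - 1).unsqueeze(-1))
    logits = logits.masked_fill(logits < kth_values, float("-inf"))
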
    @classmethod
    def from_lists(
        cls,
        temperatures: list[float],
        top_ps: list[float],
        top_ks: list[int],
        min_ps: list[float],
        presence_penalties: list[float],
        frequency_penalties: list[float],
        repetition_penalties: list[float],
        prompt_tokens: list[array],
        output_tokens: list[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
        )

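For the pinned-memory comment above, a minimal sketch of the staging pattern (assumes an optional CUDA device; on a CPU-only machine the copy is a no-op): only when the host buffer is pinned does `non_blocking=True` give a genuinely asynchronous host-to-device copy.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin_memory = device.type == "cuda"

# Stage the values on the CPU (pinned when possible), then issue an
# asynchronous host-to-device copy.
temperatures_cpu = torch.tensor([1.0, 0.7, 0.2],
                                device="cpu",
                                dtype=torch.float32,
                                pin_memory=pin_memory)
temperatures_gpu = temperatures_cpu.to(device=device, non_blocking=True)
if device.type == "cuda":
    # Synchronize before reading the result back on the host side.
    torch.cuda.synchronize()
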
1322  vllm/sequence.py
File diff suppressed because it is too large
@ -1,162 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

from vllm.logprobs import Logprob
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
                           SequenceGroup)

from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
from .tokenizer import AnyTokenizer


class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer: AnyTokenizer):
        self.tokenizer = tokenizer

    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                       prompt_logprobs: list[Optional[dict[
                                           int, Logprob]]],
                                       position_offset: int) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode.
            position_offset: Offset of the first index of the logprobs
                relative to the start of the sequence (for chunked prefill).

        Returns:
            The prompt logprobs with the decoded tokens.
        """
        prms = seq_group.sampling_params
        assert prms is not None

        # We can pick any sequence for the prompt.
        seq = seq_group.get_seqs()[0]
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens: list[str] = []
        prev_tokens = None

        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):

            # Absolute token position equals the index in the logprobs
            # list plus the offset of the entire logprobs list relative
            # to the start of the sequence.
            token_position = token_position_in_logprob + position_offset
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=self.tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=self.tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=self.tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=self.tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)
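For context on what the removed Detokenizer and `detokenize_incrementally` were doing, here is a rough, self-contained sketch of the underlying idea (assumes a Hugging Face tokenizer such as `gpt2` is available locally; the helper below is a deliberate simplification, not the removed implementation): decoding token ids one at a time can split multi-byte characters or drop leading spaces, so text is produced by re-decoding a window and emitting only the suffix past what was already read.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any local tokenizer

def decode_delta(token_ids: list[int], chars_read: int) -> tuple[str, int]:
    # Naive version: re-decode everything and emit only the new suffix.
    # The removed Detokenizer keeps prefix/read *token* offsets instead,
    # so it never re-decodes the whole prompt on every step.
    text = tokenizer.decode(token_ids)
    return text[chars_read:], len(text)

token_ids: list[int] = []
chars_read = 0
for token_id in tokenizer("Hello, incremental world")["input_ids"]:
    token_ids.append(token_id)
    delta, chars_read = decode_delta(token_ids, chars_read)
    print(delta, end="")
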
@ -11,12 +11,12 @@ import torch.nn as nn
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
                         update_environment_variables,
                         warn_for_unimplemented_methods)
+from vllm.v1.outputs import SamplerOutput

 logger = init_logger(__name__)