[Doc] Add docs for prompt replacement (#12318)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-22 22:56:29 +08:00 committed by GitHub
parent 16366ee8bb
commit 6609cdf019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 79 additions and 15 deletions

View File

@ -218,7 +218,7 @@ class UltravoxMultiModalProcessor(
return [
PromptReplacement(
modality="audio",
target='<|audio|>',
target="<|audio|>",
replacement=get_replacement_ultravox,
)
]

View File

@ -29,41 +29,101 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]
PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
@dataclass
class PromptReplacementDetails:
full: _PromptSeq
"""Details about the replacement token sequence or text."""
full: PromptSeq
"""The full replacement."""
features: _PromptSeq
features: PromptSeq
"""
The part of the replacement that corresponds to placeholder feature tokens.
The part of the replacement that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
inference.
"""
@staticmethod
def from_seq(seq: _PromptSeq) -> "PromptReplacementDetails":
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
return PromptReplacementDetails(full=seq, features=seq)
_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
"""
The replacement token sequence or text.
If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part.
"""
@dataclass
class PromptReplacement:
"""
Defines how to replace portions of an input prompt with placeholder tokens.
Example:
For each image, replace one ``<image>`` input placeholder in the prompt
with a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder:
.. code-block:: python
PromptReplacement(
modality="image",
target="<image>",
replacement="<image>" * image_feature_size,
)
As above, but further pad the feature placeholders with ``<image_bos>``
and `<image_eos>``, which are not supposed to be passed to the vision
encoder:
.. code-block:: python
PromptReplacement(
modality="image",
target="<image>",
replacement=PromptReplacementDetails(
full="".join([
"<image_bos>",
"<image>" * image_feature_size,
"<image_eos>",
]),
features="<image>" * image_feature_size,
),
)
To avoid unnecessary tokenization during prompt replacement,
we recommended passing token sequences instead of text:
.. code-block:: python
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]),
features=[image_token_id] * image_feature_size,
),
)
"""
modality: str
"""The modality for which the replacement is made."""
target: _PromptSeq
target: PromptSeq
"""The token sequence (or text) to find and replace."""
replacement: Union[Callable[[int], _PromptRepl],
_PromptRepl] = field(repr=False)
replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
@ -126,6 +186,10 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
@dataclass
class _BoundPromptSequence:
"""
A :data:`_PromptSeq` bound to a tokenizer to automatically
convert between token sequence and text representations.
"""
tokenizer: AnyTokenizer = field(repr=False)
_text: Optional[str]
@ -134,7 +198,7 @@ class _BoundPromptSequence:
@staticmethod
def from_seq(
tokenizer: AnyTokenizer,
seq: _PromptSeq,
seq: PromptSeq,
) -> "_BoundPromptSequence":
return _BoundPromptSequence(
tokenizer=tokenizer,
@ -180,9 +244,9 @@ class BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str
_target: _PromptSeq
_replacement: Union[Callable[[int], _PromptRepl],
_PromptRepl] = field(repr=False)
_target: PromptSeq
_replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
@ -350,7 +414,7 @@ def find_text_matches(
def _resolve_matches(
prompt: _PromptSeq,
prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
"""