# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses

import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):
    """Stub attention backend passed to the deserialization call below.

    The metadata/builder/state class getters return the base vLLM classes,
    the name/impl/KV-cache-shape getters raise ``NotImplementedError``, and
    the block swap/copy operations do nothing.
    """

    @staticmethod
    def get_name() -> str:
        """Not provided by this stub."""
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        """Not provided by this stub."""
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the base ``AttentionMetadata`` class."""
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        """Return the base ``AttentionMetadataBuilder`` class."""
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the ``CommonAttentionState`` class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Not provided by this stub."""
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """No-op."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """No-op."""
        return None


def test_model_runner_input():
    """Round-trip a GPU model input through its broadcastable tensor dict.

    Builds a ``ModelInputForGPUWithSamplingMetadata``, converts it with
    ``as_broadcastable_tensor_dict``, rebuilds it with
    ``from_broadcasted_tensor_dict`` and compares the rebuilt object with
    the original field by field.
    """
    # Placeholder strings stand in for the real sampling metadata values;
    # the test only compares them for equality after the round trip.
    sampling_md = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_kwargs = dict(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    attn_md = AttentionMetadata(**attn_kwargs)
    original = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_md,
        attn_metadata=attn_md,
    )
    input_cls = ModelInputForGPUWithSamplingMetadata
    assert isinstance(original, input_cls)

    # Test round trip serialization.
    broadcast_dict = original.as_broadcastable_tensor_dict()
    restored = input_cls.from_broadcasted_tensor_dict(
        broadcast_dict,
        attn_backend=MockAttentionBackend(),
    )

    # Check that received copy has correct values.
    assert isinstance(restored, input_cls)

    # Tensor-valued inputs must be present and element-wise equal.
    for tensor_attr in ("input_tokens", "input_positions"):
        restored_tensor = getattr(restored, tensor_attr)
        assert restored_tensor is not None
        assert (restored_tensor == getattr(original, tensor_attr)).all()

    # Fields that were never set must come back as None on both sides.
    for unset_attr in ("multi_modal_kwargs", "lora_requests", "lora_mapping"):
        restored_value = getattr(restored, unset_attr)
        assert restored_value is None
        assert restored_value == getattr(original, unset_attr)

    # Every declared AttentionMetadata field must match; a field missing on
    # either object is read as None.
    restored_attn = restored.attn_metadata
    for meta_field in dataclasses.fields(AttentionMetadata):
        actual = getattr(restored_attn, meta_field.name, None)
        expected = getattr(attn_md, meta_field.name, None)
        assert actual == expected

    # For sampling metadata, only selected_token_indices is copied.
    restored_sampling = restored.sampling_metadata
    assert (restored_sampling.selected_token_indices ==
            sampling_md.selected_token_indices)
    assert restored_sampling.seq_groups is None