[Bugfix] SharedStorage Connector for V1 PD multimodal (#21611)

Signed-off-by: fake0fan <645327136@qq.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
commit 4904e53c32 (parent 004203e953)
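In brief: SharedStorageConnector keys saved KV by a hash of the prompt token ids alone. For multimodal prompts that is not a unique key: two different images of the same pixel size expand to identical placeholder tokens, so distinct inputs could collide on the same storage folder and load each other's KV. This commit threads each request's mm_hashes through ReqMeta and mixes them into the folder-name hash. A minimal standalone sketch of the resulting keying scheme (the helper name folder_key is ours, not vLLM API):

    import hashlib

    import torch

    def folder_key(token_ids: list[int], mm_hashes: list[str]) -> str:
        # Same byte-level recipe as the connector: token bytes, with the
        # joined multimodal hashes appended before hashing.
        key_bytes = torch.tensor(token_ids).numpy().tobytes()
        if mm_hashes:
            key_bytes += "-".join(mm_hashes).encode("utf-8")
        return hashlib.md5(key_bytes, usedforsecurity=False).hexdigest()

    # Same token ids, different image hashes -> different storage folders.
    assert folder_key([1, 2, 3], ["img_a"]) != folder_key([1, 2, 3], ["img_b"])
    # Image order matters too, matching the swapped-order test case below.
    assert folder_key([1, 2, 3], ["a", "b"]) != folder_key([1, 2, 3], ["b", "a"])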
tests/v1/kv_connector/unit/test_shared_storage_connector.py | 215 (new file)
@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

TEXT_PROMPTS = [
    "What's in the image(s)? Around 30 words. What's special in 2nd image?",
    "The future of AI is",
]


class InputCase(NamedTuple):
    text: str
    img: list[Image.Image]
    expected_len: int
    info: str


def _check_path_len(path):
    """Return the current number of entries under the storage path."""
    return len(list(path.iterdir()))


def _list_path(path):
    """Return the list of folder names (generated hashes) under the path."""
    return list(path.iterdir())


def run_test(tmp_path, processor, llm: LLM, question: str,
             image_urls: list[Image.Image], expected_len: int, info: str):
    """
    Run one individual test: process the prompt and generate output for one
    set of inputs, then check that the number of entries under the storage
    path matches the expected length.

    `info` describes the purpose of the individual test.
    """
    print(f"***info: {info}***")
    print(
        f"**Expected storage path length after llm generate: {expected_len}**")
    process_prompt(processor, llm, question, image_urls)

    print(f"Actual storage path length: {_check_path_len(tmp_path)}")
    print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    assert _check_path_len(tmp_path) == expected_len, (
        f"Expected storage path length {expected_len}; ",
        f"but got {_check_path_len(tmp_path)} instead. ", f"Info: {info}")


def process_prompt(processor, llm: LLM, question: str,
                   image_urls: list[Image.Image]):
    """
    Form the prompt from the text and image inputs, then generate the output.
    """
    placeholders = [{
        "type": "image_url",
        "image_url": {
            "url": f"data:image;base64,{encode_image_base64(image_pil)}"
        }
    } for image_pil in image_urls]

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                *placeholders,
                {
                    "type": "text",
                    "text": question
                },
            ],
        },
    ]

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    outputs = llm.generate(
        {
            "prompt":
            prompt,
            **({
                "multi_modal_data": {
                    "image": [*image_urls]
                }
            } if image_urls else {})
        },
        sampling_params=SAMPLING_PARAMS,
    )

    print("-" * 50)
    print("Output:")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        print("-" * 50)


def test_shared_storage_connector_hashes(tmp_path):
    """
    Test that SharedStorageConnector saves KV to storage locations with
    proper hashes that are unique for inputs with identical text but
    different images (of the same size), or the same multiple images in
    different orders.
    """
    # Use tmp_path as the storage path for the KV cache
    print(f"KV storage path at: {str(tmp_path)}")

    # Configure the SharedStorageConnector
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(tmp_path)})

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # Don't put this import at the top level:
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor  # noqa: F401

    # Create a processor to handle the chat prompt
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Prepare images for the tests.
    # Resize them to the same size to check hash correctness.
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure that they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    engine_args = asdict(engine_args)
    llm = LLM(**engine_args)

    # Prepare the input cases
    input_cases = [
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=1,
                  info="image_1 single input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the first time. "
                        "It has the same pixel size as image_1, yet it "
                        "should still form a new unique hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=2,
                  info=("image_1 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=3,
                  info="image_1 with image_2 input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info="The image order is swapped. Should form a new hash."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=4,
                  info=("[image_1, image_2] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info=("[image_2, image_1] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Pure text input test as a case-control"),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Identical pure text input as a case-control"),
        InputCase(text=TEXT_PROMPTS[1],
                  img=[],
                  expected_len=6,
                  info="Another pure text input as a case-control"),
    ]

    # Run the tests
    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
        run_test(tmp_path, processor, llm, text, img, expected_len, info)

    print("All tests passed successfully!")
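The expected_len progression above encodes the property under test: a repeated input must reuse its existing folder (no new hash), while a new image, a new image order, or new text must each mint exactly one new folder. A quick sanity check of the first six cases using the folder_key sketch above (hypothetical hashes stand in for real mm_hashes):

    seen: set[str] = set()
    for mm in (["a"], ["b"], ["a"], ["b"], ["a", "b"], ["b", "a"]):
        seen.add(folder_key([1, 2, 3], mm))
    assert len(seen) == 4  # mirrors expected_len 1, 2, 2, 2, 3, 4

The file can also be run on its own with pytest, e.g. pytest tests/v1/kv_connector/unit/test_shared_storage_connector.py.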
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -32,10 +32,11 @@ class ReqMeta:
     slot_mapping: torch.Tensor
     # Is store or load
     is_store: bool
+    mm_hashes: list[str]
 
     @staticmethod
     def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
-                  is_store: bool) -> "ReqMeta":
+                  is_store: bool, mm_hashes: list[str]) -> "ReqMeta":
         valid_num_tokens = align_to_block_size(len(token_ids), block_size)
         token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
         block_ids_tensor = torch.tensor(block_ids)
@@ -48,6 +49,7 @@ class ReqMeta:
             token_ids=token_ids_tensor,
             slot_mapping=slot_mapping,
             is_store=is_store,
+            mm_hashes=mm_hashes,
         )
 
 
@@ -64,9 +66,11 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
         block_ids: list[int],
         block_size: int,
         is_store: bool,
+        mm_hashes: list[str],
     ) -> None:
         self.requests.append(
-            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
+            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store,
+                              mm_hashes))
 
 
 class SharedStorageConnector(KVConnectorBase_V1):
@@ -169,7 +173,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
                     forward_context.virtual_engine]
 
                 filename = self._generate_filename_debug(
-                    layer_name, request.token_ids)
+                    layer_name, request.token_ids, request.mm_hashes)
                 kv_cache = safetensors.torch.load_file(
                     filename)["kv_cache"].cuda()
                 inject_kv_into_layer(kv_cache_layer, kv_cache,
@@ -221,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
         for request in connector_metadata.requests:
             if request.is_store:
                 filename = self._generate_filename_debug(
-                    layer_name, request.token_ids)
+                    layer_name, request.token_ids, request.mm_hashes)
                 kv_cache = extract_kv_from_layer(kv_layer,
                                                  request.slot_mapping)
                 tensors = {"kv_cache": kv_cache.detach().cpu()}
@@ -299,7 +303,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 meta.add_request(token_ids=new_req.prompt_token_ids,
                                  block_ids=new_req.block_ids[0],
                                  block_size=self._block_size,
-                                 is_store=False)
+                                 is_store=False,
+                                 mm_hashes=new_req.mm_hashes)
                 total_need_load += 1
             else:
                 # NOTE: here, we set the store and load being exclusive,
@@ -310,7 +315,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 meta.add_request(token_ids=new_req.prompt_token_ids,
                                  block_ids=new_req.block_ids[0],
                                  block_size=self._block_size,
-                                 is_store=True)
+                                 is_store=True,
+                                 mm_hashes=new_req.mm_hashes)
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -338,7 +344,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 meta.add_request(token_ids=token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size,
-                                 is_store=False)
+                                 is_store=False,
+                                 mm_hashes=request.mm_hashes)
                 total_need_load += 1
 
         assert total_need_load == len(self._requests_need_load)
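The hunks above thread mm_hashes through the metadata plumbing: every add_request call-site now forwards the request's mm_hashes, and ReqMeta carries them to the worker side where filenames are generated. Schematically (values invented for illustration, not taken from the diff):

    meta = SharedStorageConnectorMetadata()
    meta.add_request(token_ids=[101, 102, 103, 104],
                     block_ids=[0],
                     block_size=16,
                     is_store=True,
                     mm_hashes=["d41d8cd9"])  # one hash per multimodal item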
@@ -359,20 +366,28 @@ class SharedStorageConnector(KVConnectorBase_V1):
             len(request.prompt_token_ids) - 1, self._block_size)
         foldername = self._generate_foldername_debug(torch.tensor(
             request.prompt_token_ids)[:num_tokens_to_check],
+                                                     request.mm_hashes,
                                                      create_folder=False)
         return os.path.exists(foldername)
 
     def _generate_foldername_debug(
         self,
-        input_ids: torch.Tensor,
+        token_ids: torch.Tensor,
+        mm_hashes: list[str],
         create_folder=False,
     ) -> str:
         """Generate a folder name based on the hash of the bytes of the input
         ids.
         """
-        input_ids_bytes = input_ids.numpy().tobytes()
-        input_ids_hash = hashlib.md5(input_ids_bytes,
+        token_bytes = token_ids.numpy().tobytes()
+        # Add mm_hashes to the bytes being hashed to avoid path traversal and
+        # to create a canonical key.
+        if mm_hashes:
+            mm_str = "-".join(mm_hashes)
+            token_bytes += mm_str.encode('utf-8')
+        input_ids_hash = hashlib.md5(token_bytes,
                                      usedforsecurity=False).hexdigest()
 
         foldername = os.path.join(self._storage_path, input_ids_hash)
         if create_folder:
             os.makedirs(foldername, exist_ok=True)
@@ -381,12 +396,14 @@ class SharedStorageConnector(KVConnectorBase_V1):
     def _generate_filename_debug(
         self,
         layer_name: str,
-        input_ids: torch.Tensor,
+        token_ids: torch.Tensor,
+        mm_hashes: list[str],
     ) -> str:
         """Generate a file name based on the layer name and the hash
         of the bytes of the input ids.
         """
-        foldername = self._generate_foldername_debug(input_ids,
+        foldername = self._generate_foldername_debug(token_ids,
+                                                     mm_hashes=mm_hashes,
                                                      create_folder=True)
         return os.path.join(foldername, f"{layer_name}.safetensors")
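Because the folder name is an md5 digest over the token bytes plus the joined mm_hashes, no user-controlled string ever reaches the filesystem path directly (hence the path-traversal note in the hunk above), and the key is canonical for a given (tokens, images) pair. The resulting storage layout is one folder per key with one safetensors file per attention layer (names illustrative):

    <shared_storage_path>/
        <md5 hex digest>/
            model.layers.0.self_attn.attn.safetensors
            model.layers.1.self_attn.attn.safetensors
            ...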