disable mm cache when enable_tower_connector_lora

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-23 01:47:18 +00:00
parent f3a55ff958
commit f114b4e143
3 changed files with 106 additions and 74 deletions

View File

@ -15,10 +15,11 @@ class TestConfig:
max_num_seqs: int = 2
max_loras: int = 2
max_lora_rank: int = 32
enable_tower_connector_lora: bool = True
enable_tower_connector_lora: bool = False
max_model_len: int = 8192
gpu_memory_utilization: float = 0.85
mm_processor_kwargs: dict[str, int] | None = None
mm_processor_cache_gb: float = 4
def __post_init__(self):
if self.mm_processor_kwargs is None:
@ -54,6 +55,7 @@ class Qwen2VLTester:
trust_remote_code=True,
gpu_memory_utilization=self.config.gpu_memory_utilization,
mm_processor_kwargs=self.config.mm_processor_kwargs,
mm_processor_cache_gb=self.config.mm_processor_cache_gb,
max_model_len=self.config.max_model_len,
)
@ -62,6 +64,7 @@ class Qwen2VLTester:
images: list[ImageAsset],
expected_outputs: list[str],
lora_id: int | None = None,
lora_name: str | None = None,
temperature: float = 0,
max_tokens: int = 5,
):
@ -77,7 +80,9 @@ class Qwen2VLTester:
for asset in images
]
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
lora_request = LoRARequest(
lora_name if lora_name else str(lora_id), lora_id, self.config.lora_path
)
outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
generated_texts = [output.outputs[0].text.strip() for output in outputs]
# Validate outputs
@ -207,59 +212,15 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
"""
Test language-only LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
)
def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
"""
Test vision tower + connector LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_vision_tower_connector_lora_files,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
)
def test_qwen2vl_vision_no_connector_lora(
qwen2vl_vision_tower_lora_files,
):
"""
Test vision tower only LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_vision_tower_lora_files,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES,
expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
lora_id=lora_id,
)
def test_qwen25vl_vision_lora(qwen25vl_vision_lora_files):
config = TestConfig(
model_path=QWEN25VL_MODEL_PATH,
lora_path=qwen25vl_vision_lora_files,
# Currently, tower_connector_lora is incompatible with
# the multi-modal processor cache.
# TODO: Remove this restriction
mm_processor_cache_gb=0,
enable_tower_connector_lora=True,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
@ -274,6 +235,11 @@ def test_qwen3vl_vision_lora(qwen3vl_vision_lora_files):
config = TestConfig(
model_path=QWEN3VL_MODEL_PATH,
lora_path=qwen3vl_vision_lora_files,
# Currently, tower_connector_lora is incompatible with
# the multi-modal processor cache.
# TODO: Remove this restriction
mm_processor_cache_gb=0,
enable_tower_connector_lora=True,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
@ -282,3 +248,61 @@ def test_qwen3vl_vision_lora(qwen3vl_vision_lora_files):
expected_outputs=EXPECTED_OUTPUTS_VISION_QWEN3_VL,
lora_id=lora_id,
)
def test_qwen2vl_multiple_lora_types(
qwen2vl_language_lora_files,
qwen2vl_vision_tower_connector_lora_files,
qwen2vl_vision_tower_lora_files,
):
"""
Test multiple LoRA adapter types (language, vision tower + connector,
vision tower only) using the same LLM instance to verify mm_encoder_cache
behavior with different LoRA requests.
By reusing the same LLM instance across different LoRA requests, we ensure that
the multimodal encoder cache correctly manages state transitions between
language-only and vision-enabled LoRA adapters.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH,
# We'll override the lora_path for each specific test, but need to provide
# an initial path for initialization
lora_path=qwen2vl_language_lora_files,
# Currently, tower_connector_lora is incompatible with
# the multi-modal processor cache.
# TODO: Remove this restriction
mm_processor_cache_gb=0,
enable_tower_connector_lora=True,
)
tester = Qwen2VLTester(config)
# Test 1: Language-only LoRA adapter
tester.config.lora_path = qwen2vl_language_lora_files
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES,
expected_outputs=EXPECTED_OUTPUTS_LANGUAGE,
lora_id=lora_id,
lora_name="language_only",
)
# Test 2: Vision tower + connector LoRA adapter
tester.config.lora_path = qwen2vl_vision_tower_connector_lora_files
for lora_id in [3, 4]:
tester.run_test(
TEST_IMAGES,
expected_outputs=EXPECTED_OUTPUTS_VISION,
lora_id=lora_id,
lora_name="vision_tower_connector",
)
# Test 3: Vision tower only LoRA adapter (no connector)
tester.config.lora_path = qwen2vl_vision_tower_lora_files
for lora_id in [5, 6]:
tester.run_test(
TEST_IMAGES,
expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
lora_id=lora_id,
lora_name="vision_tower",
)

View File

@ -1647,6 +1647,19 @@ class EngineArgs:
else None
)
if (
lora_config is not None
and lora_config.enable_tower_connector_lora
and self.mm_processor_cache_gb != 0
):
raise ValueError(
"Currently, enable_tower_connector_lora is "
"incompatible with the multi-modal processor cache. "
"When enable_tower_connector_lora is set, "
"mm_processor_cache_gb must be 0, got %s",
self.mm_processor_cache_gb,
)
if (
lora_config is not None
and speculative_config is not None

View File

@ -406,6 +406,20 @@ class InputProcessor:
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
def _get_mm_identifier(
self,
mm_hash: str,
lora_request: LoRARequest | None,
) -> str:
"""
When enable_tower_connector_lora is True, multi-modal embeddings
vary depending on the LoRA request. Therefore, the mm_hash must be
generated based on the LoRA request to prevent incorrect cache hits.
"""
if lora_request is None or not self.lora_config.enable_tower_connector_lora:
return mm_hash
return f"{lora_request.lora_name}:{mm_hash}"
def process_inputs(
self,
request_id: str,
@ -458,28 +472,6 @@ class InputProcessor:
else:
mm_uuids = None
# When enable_tower_connector_lora is True, multi-modal embeddings
# vary depending on the LoRA request. Therefore, the mm_hash must be
# generated based on the LoRA request to prevent incorrect cache hits.
lora_config = self.lora_config
if (
mm_uuids
and lora_request
and lora_config
and lora_config.enable_tower_connector_lora
):
def add_mm_lora_prefix(val):
if isinstance(val, list):
return [
f"{lora_request.lora_name}:{v}" if v is not None else None
for v in val
]
else:
return f"{lora_request.lora_name}:{val}"
mm_uuids = {k: add_mm_lora_prefix(v) for k, v in mm_uuids.items()}
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
@ -548,7 +540,10 @@ class InputProcessor:
MultiModalFeatureSpec(
data=decoder_mm_inputs[modality][idx],
modality=modality,
identifier=decoder_mm_hashes[modality][idx],
identifier=self._get_mm_identifier(
decoder_mm_hashes[modality][idx],
lora_request,
),
mm_position=decoder_mm_positions[modality][idx],
)
)