[MM] Add text-only mode for Qwen3-VL (#26000)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Roger Wang 2025-09-30 21:13:42 -07:00 committed by yewentao256
parent cd0bbf5de2
commit 4c094b339e
2 changed files with 45 additions and 26 deletions

View File

@ -1125,14 +1125,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
@ -1148,11 +1151,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
@ -1526,7 +1533,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
                                              torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model.

    Supports text-only mode: when the vision tower was not constructed
    (``self.visual is None``), all ``visual.*`` checkpoint weights are
    skipped so loading does not fail on parameters that have no
    corresponding module.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names that were actually loaded.
    """
    # Skip the vision-tower weights only when no vision encoder exists;
    # otherwise load everything. (The original had a leftover redundant
    # `AutoWeightsLoader(self)` construction before this one.)
    skip_prefixes = ["visual."] if self.visual is None else []
    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level