[MM] Add text-only mode for Qwen3-VL (#26000)
commit 66bca9b8bd (parent 99028fda44)
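Summary: when the per-prompt multimodal limits for both images and videos are zero, the vision tower is never constructed (self.visual = None), the deepstack staging buffers are not allocated, and "visual."-prefixed weights are skipped at load time. The same guards are applied to the dense model (Qwen3VLForConditionalGeneration) and, in the last two hunks, to the MoE variant (Qwen3VLMoeForConditionalGeneration).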
@@ -1125,14 +1125,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
-
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel,
-        )
+        if not multimodal_config.get_limit_per_prompt("image") and \
+            not multimodal_config.get_limit_per_prompt("video"):
+            self.visual = None
+        else:
+            self.visual = Qwen3_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+            )
 
         self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
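For context, a minimal usage sketch (not part of this diff) of how the new branch is reached. It assumes the standard vllm.LLM entrypoint, where limit_mm_per_prompt feeds the get_limit_per_prompt values checked above; the checkpoint name is only an example:

from vllm import LLM

# With both limits at 0, get_limit_per_prompt("image") and
# get_limit_per_prompt("video") are falsy, so self.visual stays None
# and the vision transformer is never built.
llm = LLM(
    model="Qwen/Qwen3-VL-30B-A3B-Instruct",  # example checkpoint
    limit_mm_per_prompt={"image": 0, "video": 0},
)
outputs = llm.generate("Summarize the benefits of text-only mode.")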
@@ -1148,11 +1151,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             config.vision_config.deepstack_visual_indexes
         ) if self.use_deepstack else 0
         # register buffer for deepstack
-        self.deepstack_input_embeds = [
-            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
-                        config.text_config.hidden_size)
-            for _ in range(self.deepstack_num_level)
-        ] if self.use_deepstack else None
+        if self.use_deepstack and self.visual is not None:
+            self.deepstack_input_embeds = [
+                torch.zeros(
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    config.text_config.hidden_size)
+                for _ in range(self.deepstack_num_level)
+            ]
+        else:
+            self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
 
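A standalone sketch of the guard above, with made-up sizes; in the real code they come from scheduler_config and config.text_config:

import torch

use_deepstack = True
visual = None                  # text-only mode: vision tower was not built
max_num_batched_tokens = 8192  # stand-in for scheduler_config value
hidden_size = 2048             # stand-in for config.text_config.hidden_size
deepstack_num_level = 3        # len(deepstack_visual_indexes) when present

if use_deepstack and visual is not None:
    deepstack_input_embeds = [
        torch.zeros(max_num_batched_tokens, hidden_size)
        for _ in range(deepstack_num_level)
    ]
else:
    # No vision tower means no multiscale features to stage, so the
    # per-level [tokens, hidden] buffers are never allocated.
    deepstack_input_embeds = None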
@@ -1526,7 +1533,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
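The skip_prefixes argument amounts to dropping checkpoint tensors by name prefix. A simplified stand-in for illustration (AutoWeightsLoader itself does more, such as prefix mapping and tracking which params were loaded):

from collections.abc import Iterable

import torch


def filter_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: list[str],
) -> Iterable[tuple[str, torch.Tensor]]:
    # Drop any tensor whose name starts with a skipped prefix, e.g.
    # "visual.blocks.0.attn.qkv.weight" when skip_prefixes == ["visual."].
    for name, tensor in weights:
        if not any(name.startswith(p) for p in skip_prefixes):
            yield name, tensor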
@@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel,
-        )
+        if not multimodal_config.get_limit_per_prompt("image") and \
+            not multimodal_config.get_limit_per_prompt("video"):
+            self.visual = None
+        else:
+            self.visual = Qwen3_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+            )
 
         self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
                                                      prefix=maybe_prefix(
@@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             config.vision_config.deepstack_visual_indexes
         ) if self.use_deepstack else 0
         # register buffer for deepstack
-        self.deepstack_input_embeds = [
-            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
-                        config.text_config.hidden_size)
-            for _ in range(self.deepstack_num_level)
-        ] if self.use_deepstack else None
+        if self.use_deepstack and self.visual is not None:
+            self.deepstack_input_embeds = [
+                torch.zeros(
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    config.text_config.hidden_size)
+                for _ in range(self.deepstack_num_level)
+            ]
+        else:
+            self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
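Note that Qwen3VLMoeForConditionalGeneration subclasses Qwen3VLForConditionalGeneration (per the hunk header), so the load_weights change in the third hunk should cover the MoE variant as well, assuming it does not override load_weights; only the constructor guards are duplicated here.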