mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[Model] Support pp for qwen2-vl (#8696)
This commit is contained in:
parent
3e83c12b5c
commit
a79e522984
@ -8,6 +8,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -37,6 +39,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
|
||||
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
|
||||
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
|
||||
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
|
||||
# Skip tests that require transformers>=4.45.0
|
||||
if "Qwen2-VL" in MODEL_NAME and version.parse(
|
||||
transformers_version) < version.parse("4.45.0.dev0"):
|
||||
pytest.skip("This test requires transformers>=4.45.0")
|
||||
|
||||
pp_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
|
||||
@ -51,6 +51,7 @@ _PP_SUPPORTED_MODELS = [
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2MoeForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
from .utils import is_pp_missing_parameter, make_layers
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
@ -235,11 +235,16 @@ class Qwen2Model(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Qwen2DecoderLayer(config=config,
|
||||
@ -248,7 +253,10 @@ class Qwen2Model(nn.Module):
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
@ -45,7 +45,7 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.selector import (_Backend, backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import get_pp_group, parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
@ -68,6 +68,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === Vision Inputs === #
|
||||
@ -856,15 +859,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
self.model = Qwen2Model(config, cache_config, quant_config)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def _validate_and_reshape_mm_tensor(self,
|
||||
mm_input: Union[torch.Tensor,
|
||||
@ -979,7 +988,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
video_input = self._parse_and_validate_video_input(**kwargs)
|
||||
|
||||
if image_input is None and video_input is None:
|
||||
if (image_input is None
|
||||
and video_input is None) or not get_pp_group().is_first_rank:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
if getattr(self.config, "rope_scaling", {}).get("type",
|
||||
@ -1015,6 +1025,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
@ -1055,6 +1066,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@ -1081,6 +1094,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
except KeyError:
|
||||
print(params_dict.keys())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user