[Bugfix] Fix PP for ChatGLM and Molmo (#9422)

This commit is contained in:
Cyrus Leung 2024-10-24 14:12:05 +08:00 committed by GitHub
parent 056a68c7db
commit 836e8ef6ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 195 additions and 122 deletions

View File

@ -425,7 +425,7 @@ Text Generation
- -
* - :code:`MolmoForCausalLM` * - :code:`MolmoForCausalLM`
- Molmo - Molmo
- Image - T + I
- :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc.
- -
- ✅︎ - ✅︎

View File

@ -118,11 +118,8 @@ class PPTestSettings:
# The values displayed here are only a rough indicator of the size of the model # The values displayed here are only a rough indicator of the size of the model
# yapf: disable # yapf: disable
GENERATION_MODEL_SETTINGS = { TEXT_GENERATION_MODELS = {
# [DETAILED TESTS] # [Decoder-only]
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
# [FAST TESTS]
# Uses Llama # Uses Llama
# "BAAI/AquilaChat-7B": PPTestSettings.fast(), # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501 "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
@ -151,6 +148,7 @@ GENERATION_MODEL_SETTINGS = {
"core42/jais-13b-chat": PPTestSettings.fast(), "core42/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP # TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
# Uses Llama # Uses Llama
@ -163,6 +161,7 @@ GENERATION_MODEL_SETTINGS = {
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(), "microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(), "adept/persimmon-8b-chat": PPTestSettings.fast(),
@ -174,40 +173,40 @@ GENERATION_MODEL_SETTINGS = {
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
# FIXME: Cannot load tokenizer in latest transformers version # FIXME: Cannot load tokenizer in latest transformers version
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
# [Encoder-only]
# TODO: Implement PP
# "facebook/bart-base": PPTestSettings.fast(),
} }
EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated] EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [FAST TESTS] # [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501 "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
} }
MULTIMODAL_MODEL_SETTINGS = { MULTIMODAL_MODELS = {
# [FAST TESTS] # [Decoder-only]
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(), "facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(), "adept/fuyu-8b": PPTestSettings.fast(),
"THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True),
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True), "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(), "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(), "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
"llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(), "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
# TODO: Implement PP "allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True),
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501 "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(), "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
} # [Encoder-decoder]
CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated]
# [FAST TESTS]
# TODO: Implement PP # TODO: Implement PP
# "facebook/bart-base": PPTestSettings.fast(), # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
} }
# yapf: enable # yapf: enable
@ -323,7 +322,7 @@ def _compare_tp(
("model_name", "parallel_setup", "distributed_backend", "task", ("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"), "test_options"),
[ [
params for model_name, settings in GENERATION_MODEL_SETTINGS.items() params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name) for params in settings.iter_params(model_name)
if model_name in TEST_MODELS if model_name in TEST_MODELS
], ],
@ -350,7 +349,7 @@ def test_tp_language_generation(
("model_name", "parallel_setup", "distributed_backend", "task", ("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"), "test_options"),
[ [
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items() params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name) for params in settings.iter_params(model_name)
if model_name in TEST_MODELS if model_name in TEST_MODELS
], ],
@ -377,7 +376,7 @@ def test_tp_language_embedding(
("model_name", "parallel_setup", "distributed_backend", "task", ("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"), "test_options"),
[ [
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items() params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name) for params in settings.iter_params(model_name)
if model_name in TEST_MODELS if model_name in TEST_MODELS
], ],

View File

@ -13,8 +13,9 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -22,8 +23,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -39,7 +39,9 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -150,6 +152,10 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(ChatGLMConfig) hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None) vision_config = getattr(hf_config, 'vision_config', None)
@ -161,8 +167,8 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
input_ids = inputs.get("prompt_token_ids") input_ids = inputs["prompt_token_ids"]
position_ids = inputs.get("position_ids")
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.model, ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code) trust_remote_code=ctx.model_config.trust_remote_code)
@ -171,20 +177,19 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
raw_batch_data = tokenizer.apply_chat_template( raw_batch_data = tokenizer.apply_chat_template(
conversation=[{ conversation=[{
"role": "user", "role": "user",
"image": inputs['multi_modal_data']["image"], "image": multi_modal_data["image"],
"content": inputs['prompt'] "content": inputs['prompt'],
}], }],
add_generation_prompt=True, add_generation_prompt=True,
tokenize=True, tokenize=True,
return_tensors="pt", return_tensors="pt",
return_dict=True).data return_dict=True,
).data
except Exception: except Exception:
logger.error("Failed to process content (%s)", inputs['prompt']) logger.error("Failed to process content (%s)", inputs['prompt'])
raise raise
input_ids = raw_batch_data['input_ids'][0].tolist() input_ids = raw_batch_data['input_ids'][0].tolist()
if position_ids is None:
position_ids = list(range(len(input_ids)))
boi_token_id = hf_config.boi_token_id boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id) boi_positions = find_all_positions(input_ids, boi_token_id)
@ -193,7 +198,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
assert len(boi_positions) == len(eoi_positions) assert len(boi_positions) == len(eoi_positions)
new_input_ids = [] new_input_ids = []
new_position_ids = []
final_processed_position = 0 final_processed_position = 0
final_processed_position = 0 final_processed_position = 0
@ -201,29 +205,28 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
assert boi_position < eoi_position assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position + new_input_ids.extend(input_ids[final_processed_position:boi_position +
1]) 1])
new_position_ids.extend(
list(range(final_processed_position, boi_position + 1)))
new_input_ids.extend([input_ids[boi_position + 1]] * new_input_ids.extend([input_ids[boi_position + 1]] *
image_placeholder_length) image_placeholder_length)
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
final_processed_position = eoi_position final_processed_position = eoi_position
new_input_ids.extend(input_ids[final_processed_position:]) new_input_ids.extend(input_ids[final_processed_position:])
new_position_ids.extend(
list(range(final_processed_position, len(input_ids))))
assert len(new_input_ids) == len(new_position_ids) prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(new_input_ids)
inputs["prompt_token_ids"] = new_input_ids return token_inputs(
inputs["position_ids"] = new_position_ids prompt_token_ids=new_input_ids,
return inputs prompt=prompt,
multi_modal_data=multi_modal_data,
)
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
@ -314,7 +317,7 @@ class GLMMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
@ -357,7 +360,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
@ -428,9 +431,10 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
@ -439,10 +443,11 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GLMBlock(config, cache_config, quant_config) self.num_layers,
for i in range(self.num_layers) lambda prefix: GLMBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers",
)
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@ -450,6 +455,10 @@ class GLMTransformer(nn.Module):
self.final_layernorm = layer_norm_func( self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -457,16 +466,16 @@ class GLMTransformer(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
for i in range(self.num_layers): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
hidden_states=hidden_states, hidden_states=hidden_states,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_caches[i], kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
# Final layer norm. # Final layer norm.
if self.post_layer_norm: if get_pp_group().is_last_rank and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
@ -476,7 +485,7 @@ class ChatGLMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
@ -504,6 +513,9 @@ class ChatGLMModel(nn.Module):
else: else:
self.vision = None self.vision = None
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs: self, **kwargs: object) -> GLMImagePixelInputs:
@ -529,24 +541,26 @@ class ChatGLMModel(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is None:
inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)
inputs_embeds = self.embedding(input_ids) if image_input["pixel_values"] is not None:
image_input = self._parse_and_validate_image_input(**kwargs) pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)
if image_input["pixel_values"] is not None: boi_token_id = self.config.boi_token_id
pixel_values = image_input["pixel_values"].to( eoi_token_id = self.config.eoi_token_id
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)
boi_token_id = self.config.boi_token_id inputs_embeds = merge_glm_vision_embeddings(
eoi_token_id = self.config.eoi_token_id input_ids=input_ids,
inputs_embeds=inputs_embeds,
inputs_embeds = merge_glm_vision_embeddings( vision_embeddings=image_embeds,
input_ids=input_ids, boi_token_id=boi_token_id,
inputs_embeds=inputs_embeds, eoi_token_id=eoi_token_id)
vision_embeddings=image_embeds, else:
boi_token_id=boi_token_id, inputs_embeds = intermediate_tensors["hidden_states"]
eoi_token_id=eoi_token_id)
# Run encoder. # Run encoder.
hidden_states = self.encoder( hidden_states = self.encoder(
@ -555,6 +569,9 @@ class ChatGLMModel(nn.Module):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states return hidden_states
@ -562,7 +579,8 @@ class ChatGLMModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
@ -610,7 +628,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, **kwargs) attn_metadata, intermediate_tensors,
**kwargs)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
@ -656,6 +675,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -30,21 +30,21 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import make_layers
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .utils import get_vit_attn_backend from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
make_empty_intermediate_tensors_factory, make_layers)
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
@ -744,6 +744,10 @@ class MolmoModel(nn.Module):
assert config.layer_norm_type == "rms" assert config.layer_norm_type == "rms"
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -925,16 +929,19 @@ def pad_images(
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
prompt = inputs.get("prompt", None) prompt = inputs.get("prompt")
multi_modal_data = inputs.get("multi_modal_data", None) multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None: image = None if multi_modal_data is None else multi_modal_data.get("image")
image = multi_modal_data.get("image", None)
else:
image = None
processor = cached_get_processor(ctx.model_config.model, processor = cached_get_processor(ctx.model_config.model,
trust_remote_code=True, trust_remote_code=True,
revision=ctx.model_config.code_revision) revision=ctx.model_config.code_revision)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
# NOTE: message formatting for raw text prompt is only applied for # NOTE: message formatting for raw text prompt is only applied for
# offline inference; for online inference, the prompt is always in # offline inference; for online inference, the prompt is always in
# instruction format and tokenized. # instruction format and tokenized.
@ -997,9 +1004,13 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = dict(image=image_data) multi_modal_data = dict(image=image_data)
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(out["input_ids"])
return token_inputs( return token_inputs(
prompt_token_ids=out["input_ids"], prompt_token_ids=out["input_ids"],
prompt=inputs["prompt"], prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
@ -1008,7 +1019,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal): class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
@ -1040,6 +1051,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
or config.vocab_size) or config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, self,
**kwargs: object, **kwargs: object,
@ -1123,31 +1137,36 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
positions: torch.LongTensor, positions: torch.LongTensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
if intermediate_tensors is not None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.model.embed_tokens(input_ids)
image_features = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
inputs_embeds,
image_features,
image_input["image_input_idx"],
image_input["seq_len"],
)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.model.embed_tokens(input_ids)
image_features = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
inputs_embeds,
image_features,
image_input["image_input_idx"],
image_input["seq_len"],
)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )

View File

@ -119,5 +119,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights) loader.load_weights(weights)

View File

@ -61,6 +61,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs) MultiModalInputs)
from vllm.multimodal.base import MultiModalData from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
@ -817,7 +818,7 @@ def input_processor_for_qwen2_vl(
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data", None) multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None: if multi_modal_data is None:
return inputs return inputs
@ -830,6 +831,7 @@ def input_processor_for_qwen2_vl(
min_pixels = min_pixels if min_pixels else image_processor.min_pixels min_pixels = min_pixels if min_pixels else image_processor.min_pixels
max_pixels = max_pixels if max_pixels else image_processor.max_pixels max_pixels = max_pixels if max_pixels else image_processor.max_pixels
model_config = ctx.model_config
hf_config = ctx.get_hf_config(Qwen2VLConfig) hf_config = ctx.get_hf_config(Qwen2VLConfig)
# To avoid redundant processing of vision objects (resize, rescale, etc.), # To avoid redundant processing of vision objects (resize, rescale, etc.),
@ -845,14 +847,11 @@ def input_processor_for_qwen2_vl(
# return_tensors="pt") # return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist() # prompt_token_ids = inputs["input_ids"][0].tolist()
prompt_token_ids = inputs.get("prompt_token_ids", None) tokenizer = cached_get_tokenizer(
if prompt_token_ids is None: model_config.tokenizer,
prompt = inputs["prompt"] trust_remote_code=model_config.trust_remote_code)
prompt_token_ids = processor.tokenizer(
prompt, prompt_token_ids = inputs["prompt_token_ids"]
padding=True,
return_tensors=None,
)["input_ids"]
# Expand image pad tokens. # Expand image pad tokens.
@ -894,9 +893,13 @@ def input_processor_for_qwen2_vl(
min_pixels=min_pixels, min_pixels=min_pixels,
max_pixels=max_pixels) max_pixels=max_pixels)
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
return token_inputs( return token_inputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=inputs["prompt"], prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )

View File

@ -79,6 +79,9 @@ class AutoWeightsLoader:
Similarly, the weight loading logic for individual parameters can be Similarly, the weight loading logic for individual parameters can be
overridden by defining a ``weight_loader`` method. overridden by defining a ``weight_loader`` method.
Detailed weight loading information can be viewed by setting the
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
""" """
def __init__( def __init__(
@ -136,20 +139,27 @@ class AutoWeightsLoader:
weight_qualname = self._get_qualname(base_prefix, weight_name) weight_qualname = self._get_qualname(base_prefix, weight_name)
if self._can_skip(weight_qualname): if self._can_skip(weight_qualname):
logger.debug("Skipping weight %s", weight_qualname)
continue continue
if weight_name != "": if weight_name != "":
if not self._can_ignore_unexpected(weight_qualname): if self._can_ignore_unexpected(weight_qualname):
raise ValueError( logger.debug("Ignoring weight %s", weight_qualname)
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
continue continue
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, weight_data) weight_loader(param, weight_data)
logger.debug("Loaded weight %s with shape %s", weight_qualname,
param.shape)
yield weight_qualname yield weight_qualname
def _load_module( def _load_module(
@ -175,21 +185,41 @@ class AutoWeightsLoader:
for child_prefix, child_weights in self._groupby_prefix(weights): for child_prefix, child_weights in self._groupby_prefix(weights):
prefix = self._get_qualname(base_prefix, child_prefix) prefix = self._get_qualname(base_prefix, child_prefix)
if self._can_skip(prefix):
continue
if child_prefix in child_modules: if child_prefix in child_modules:
if self._can_skip(prefix + "."):
logger.debug("Skipping module %s", prefix)
continue
yield from self._load_module(prefix, yield from self._load_module(prefix,
child_modules[child_prefix], child_modules[child_prefix],
child_weights) child_weights)
elif child_prefix in child_params: elif child_prefix in child_params:
if self._can_skip(prefix):
logger.debug("Skipping param %s", prefix)
continue
yield from self._load_param(prefix, child_params[child_prefix], yield from self._load_param(prefix, child_params[child_prefix],
child_weights) child_weights)
else: else:
if not self._can_ignore_unexpected(prefix): can_skip_module = self._can_skip(prefix + ".")
msg = (f"There is no module or parameter named '{prefix}' " can_skip_param = self._can_skip(prefix)
f"in {type(self.module).__name__}") if can_skip_module or can_skip_param:
raise ValueError(msg) logger.debug("Skipping missing %s", prefix)
continue
can_ignore_module = self._can_ignore_unexpected(prefix + ".")
can_ignore_param = self._can_ignore_unexpected(prefix)
if can_ignore_module or can_ignore_param:
logger.debug("Ignoring missing %s", prefix)
continue
msg = (f"There is no module or parameter named '{prefix}' "
f"in {type(self.module).__name__}")
raise ValueError(msg)
def load_weights( def load_weights(
self, self,