[Model] Support Pipeline Parallelism for moonshotai/Kimi-VL-A3B-Thinking-2506 (#23114)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Jiangyun Zhu authored on 2025-08-19 14:24:31 +08:00; committed by GitHub
parent 90bbe0a5ad
commit fda9537c5e
2 changed files with 18 additions and 13 deletions


@@ -626,7 +626,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
-| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
+| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
 | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
 | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
 | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
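
The newly checked column is pipeline parallelism (PP). A minimal sketch of how a user would exercise it after this commit, assuming two GPUs; the prompt, `max_model_len`, and stage count are illustrative, not part of the commit:

```python
# Sketch: run Kimi-VL across two pipeline stages (values illustrative).
from vllm import LLM, SamplingParams

llm = LLM(
    model="moonshotai/Kimi-VL-A3B-Thinking-2506",
    pipeline_parallel_size=2,  # standard vLLM engine argument
    trust_remote_code=True,    # the Kimi-VL repos ship custom code
    max_model_len=4096,
)
outputs = llm.generate(["Describe this image."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```

The server equivalent is `vllm serve moonshotai/Kimi-VL-A3B-Thinking-2506 --pipeline-parallel-size 2`.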


@@ -54,8 +54,7 @@ from transformers import BatchFeature
 from transformers.activations import GELUActivation

 from vllm.config import VllmConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
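
`get_pp_group()` replaces the tensor-parallel helpers: it returns the coordinator for this process's pipeline group, and its `is_first_rank`/`is_last_rank` properties are what models use to gate stage-specific modules. A hedged sketch of that gating idiom (layer shapes are made up, and this assumes vLLM's distributed state is already initialized):

```python
# Sketch of the rank-gating idiom enabled by this import; shapes are
# illustrative and distributed init is assumed to have happened.
import torch.nn as nn
from vllm.distributed import get_pp_group

pp = get_pp_group()
embed = nn.Embedding(32000, 4096) if pp.is_first_rank else None  # first stage embeds
head = nn.Linear(4096, 32000) if pp.is_last_rank else None       # last stage emits logits
```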
@@ -63,7 +62,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
-from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                   SupportsPP)
 from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
 from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .utils import is_pp_missing_parameter, maybe_prefix
+from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix

 # For dummy input only
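
`PPMissingLayer` is the placeholder vLLM substitutes for modules owned by another pipeline stage, so attribute lookups and the `is_pp_missing_parameter` weight-loading check still have something to point at on ranks that don't host the layer. Roughly (the real class lives in `vllm/model_executor/models/utils.py`):

```python
# Rough sketch of the PPMissingLayer idea; see
# vllm/model_executor/models/utils.py for the actual implementation.
import torch.nn as nn

class PPMissingLayer(nn.Identity):
    """Stand-in for a layer that lives on a different pipeline rank."""

    def __init__(self, *args, **kwargs):
        # Swallow the real layer's constructor arguments so call sites
        # need no rank-specific branching.
        super().__init__()
```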
@@ -270,7 +270,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor,
                                         info=KimiVLProcessingInfo,
                                         dummy_inputs=KimiVLDummyInputsBuilder)
-class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                     SupportsPP):

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
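
Declaring `SupportsPP` tells the engine this model can be split across pipeline stages. In essence it requires a `make_empty_intermediate_tensors` factory (used to allocate the buffers a stage receives from its predecessor) and a `forward` that can consume and produce `IntermediateTensors`. A simplified, hedged rendering of the contract; the authoritative definition is in `vllm/model_executor/models/interfaces.py`:

```python
# Simplified sketch of the SupportsPP contract; signatures are paraphrased
# from vllm/model_executor/models/interfaces.py, not copied.
from typing import Callable, Optional, Protocol, Union

import torch
from vllm.sequence import IntermediateTensors

class SupportsPPSketch(Protocol):
    # Factory building empty per-stage receive tensors (batch size, dtype,
    # and device are supplied by the engine).
    make_empty_intermediate_tensors: Callable[..., IntermediateTensors]

    def forward(
        self,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Non-last ranks return IntermediateTensors to forward downstream;
        # the last rank returns hidden states for logits computation.
        ...
```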
@@ -304,17 +305,21 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
             prefix=maybe_prefix(prefix, "language_model"),
         )
         self.unpadded_vocab_size = config.text_config.vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.text_config.hidden_size,
-            org_num_embeddings=self.config.text_config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE)
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.text_config.hidden_size,
+                org_num_embeddings=self.config.text_config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_world_size = get_tensor_model_parallel_world_size()

     # ref: qwen2_vl.py
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
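
The constructor now follows vLLM's standard PP recipe: build the real `ParallelLMHead` only where it will be used (the last rank), give every other rank a `PPMissingLayer`, delegate `make_empty_intermediate_tensors` to the wrapped `DeepseekV2Model` so the engine knows the shape of inter-stage activations, and drop the now-unused `tp_rank`/`tp_world_size` bookkeeping along with its imports. Downstream, logits computation only ever runs on the rank holding the real head; a hedged sketch of that consumer (simplified from the common vLLM pattern, not the literal file):

```python
# Hedged sketch of the downstream consumer of the gated lm_head
# (simplified; not the literal Kimi-VL code).
def compute_logits(self, hidden_states, sampling_metadata):
    # On non-last ranks forward() returned IntermediateTensors, so this
    # method is only invoked where self.lm_head is a real ParallelLMHead.
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
```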