[Bugfix] Fix image input for Pixtral-HF (#11741)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 5950f555a1
commit 91445c7bc8
@@ -23,7 +23,7 @@ IMAGE_URLS = [
 class ModelRequestData(NamedTuple):
     llm: LLM
     prompt: str
-    stop_token_ids: Optional[List[str]]
+    stop_token_ids: Optional[List[int]]
     image_data: List[Image]
     chat_template: Optional[str]
 
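Note: `stop_token_ids` carries integer token IDs, which is what `SamplingParams` ultimately consumes, so the annotation is corrected from `List[str]` to `List[int]`. A minimal sketch of the consuming side (the ID values are illustrative, taken from the `load_aria` example below):

from vllm import SamplingParams

# stop_token_ids expects integer token IDs, not token strings.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=128,
    stop_token_ids=[93532, 93653],  # illustrative IDs
)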
@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
     prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
 
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
         stop_token_ids=stop_token_ids,
         image_data=[fetch_image(url) for url in image_urls],
-        chat_template=None)
+        chat_template=None,
+    )
 
 
 def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    placeholders = "<|image|>" * len(image_urls)
+    prompt = f"{placeholders}<|begin_of_text|>{question}"
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
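Note: the mllama prompt previously hardcoded exactly two `<|image|>` placeholders regardless of how many images were passed; it now emits one placeholder per image. A quick illustration (the URLs are hypothetical):

image_urls = ["https://example.com/a.jpg",
              "https://example.com/b.jpg",
              "https://example.com/c.jpg"]  # hypothetical
placeholders = "<|image|>" * len(image_urls)
assert placeholders == "<|image|><|image|><|image|>"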
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
     )
 
 
+def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
+    model_name = "mistral-community/pixtral-12b"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        tensor_parallel_size=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "[IMG]" * len(image_urls)
+    prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
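Note: the new loader follows the same ModelRequestData flow as the other models in this script. A hedged sketch of how the example drives it (the question text and generation settings are illustrative, not part of this diff):

from vllm import SamplingParams

req = load_pixtral_hf("What is shown in each image?", IMAGE_URLS)
outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=SamplingParams(max_tokens=128,
                                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)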
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     )
 
 
-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
 
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
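Note: with the `List[int]` annotation above, the Qwen-VL-Chat stop tokens must be converted from strings to IDs before being returned. A standalone sketch of that conversion, assuming a Hugging Face tokenizer (Qwen-VL-Chat requires trust_remote_code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat",
                                          trust_remote_code=True)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]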
@@ -348,7 +378,8 @@ model_example_map = {
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "phi3_v": load_phi3v,
-    "qwen_vl_chat": load_qwenvl_chat,
+    "pixtral_hf": load_pixtral_hf,
+    "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
 }
 
@@ -546,6 +546,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=flatten_bn(pixel_values),
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
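Note: Pixtral-HF encodes images at their native resolutions, so the per-image pixel tensors can have different shapes and cannot be validated and stacked like fixed-size LLaVA inputs; the new branch keeps them as a flat list instead. A self-contained sketch of the idea (shapes are illustrative):

import torch

# One entry per prompt, each holding per-image tensors of differing shape.
pixel_values = [
    [torch.randn(3, 512, 768), torch.randn(3, 256, 256)],
]
# Flatten the batch/image nesting, as flatten_bn does for list inputs.
flat = [img for per_prompt in pixel_values for img in per_prompt]
assert len(flat) == 2 and flat[0].shape != flat[1].shape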
@@ -774,7 +774,7 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
     ) -> int:
         return get_pixtral_hf_image_feature_size(
             image_size=self.vision_config.image_size,
-            patch_size=self.get_image_size(),
+            patch_size=self.vision_config.patch_size,
         )
 
     def get_max_image_tokens(self) -> int:
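Note: this is the core bugfix. `self.get_image_size()` returns the image size, not the patch size, so the image feature size was computed at the wrong granularity. A worked example with hypothetical sizes:

image_size, patch_size = 1024, 16

correct_patches_per_side = image_size // patch_size  # 64
buggy_patches_per_side = image_size // image_size    # 1, because the image
                                                     # size was passed where
                                                     # the patch size belongs
assert (correct_patches_per_side, buggy_patches_per_side) == (64, 1)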
@@ -281,6 +281,15 @@ def flatten_bn(
     ...
 
 
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    ...
+
+
 def flatten_bn(
     x: Union[List[torch.Tensor], torch.Tensor],
     *,
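Note: the new @overload lets callers pass `concat` as a runtime bool rather than a literal True/False without tripping the type checker. A minimal illustration, assuming flatten_bn keeps its existing semantics (the import path is an assumption):

from typing import List, Union

import torch
from vllm.model_executor.models.utils import flatten_bn  # assumed path

def flatten_maybe_concat(
        x: List[torch.Tensor],
        concat: bool) -> Union[List[torch.Tensor], torch.Tensor]:
    # Resolves against the new bool overload; returns either form.
    return flatten_bn(x, concat=concat)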