[TPU] Remove multi-modal args in TPU backend (#6504)
commit e09ce759aa (parent 5fa6e9876e)
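In short: this diff strips the multi-modal plumbing from the TPU model runner end to end — the registry-created input mapper in __init__, the per-sequence multi_modal_data mapping in _prepare_prompt and _prepare_decode, the MultiModalInputs.batch(...) calls and the extra tuple element both methods returned, and the multi_modal_kwargs parameter that ModelWrapper.forward passed through to the model.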
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
-                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -68,10 +66,6 @@ class TPUModelRunner:
             False,
         )
 
-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
     def load_model(self) -> None:
         self.device = self.device_config.device
 
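The input mapper deleted here was the sole producer for the multi-modal paths below: with no self.multi_modal_input_mapper, the per-sequence mapping blocks in _prepare_prompt and _prepare_decode and the MultiModalInputs.batch(...) call sites lose their inputs, which is why the rest of the diff unwinds them.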
@@ -154,7 +148,7 @@ class TPUModelRunner:
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, None, t, p, num_samples)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
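Note that the dropped None was the positional placeholder for multi_modal_kwargs, the ModelWrapper.forward parameter removed further down in this diff; once the parameter is gone, the dummy-run call must shed the placeholder too.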
@@ -199,14 +193,12 @@ class TPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -232,11 +224,6 @@ class TPUModelRunner:
                     slot = block_number * self.block_size + block_offset
                     slot_mapping[-1].append(slot)
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -274,24 +261,17 @@ class TPUModelRunner:
             block_tables=None,
             context_lens=None,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, prompt_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
-               Mapping[str, BatchedTensors]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
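For orientation, a minimal sketch of the new call convention. The caller below is hypothetical (run_prefill and runner are illustration names, not code from this commit); it only shows that both prepare methods now hand back a plain 4-tuple, with the trailing Mapping[str, BatchedTensors] of batched multi-modal kwargs gone.

    # Hypothetical sketch, not the vLLM source: unpacking the trimmed
    # 4-tuple that _prepare_prompt (and _prepare_decode) now return.
    def run_prefill(runner, seq_group_metadata_list):
        input_tokens, input_positions, attn_metadata, prompt_lens = (
            runner._prepare_prompt(seq_group_metadata_list))
        return input_tokens, input_positions, attn_metadata, prompt_lens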
@@ -317,11 +297,6 @@ class TPUModelRunner:
             slot = block_number * self.block_size + block_offset
             slot_mapping.append([slot])
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                multi_modal_inputs_list.append(mm_kwargs)
-
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -355,12 +330,7 @@ class TPUModelRunner:
             block_tables=block_tables,
             context_lens=context_lens,
         )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
-                                                    device=self.device)
-
-        return (input_tokens, input_positions, attn_metadata, input_lens,
-                multi_modal_kwargs)
+        return input_tokens, input_positions, attn_metadata, input_lens
 
     def _prepare_sample(
         self,
@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
-        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
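Pieced together from the hunks in this diff, ModelWrapper.forward is left with roughly the parameter list below. This is a reconstruction, not the file itself: the token_ids and position_ids names come from the dummy-run call site above, their annotations are assumed, and the return annotation is omitted because the diff does not show it.

    # Reconstructed "after" view of ModelWrapper.forward (sketch only).
    def forward(
        self,
        token_ids: torch.Tensor,      # name from the call site; annotation assumed
        position_ids: torch.Tensor,   # name from the call site; annotation assumed
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
    ):
        ...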
@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
-            multi_modal_kwargs: Keyword arguments from multi-modal data to
-                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
             position_ids,
             kv_caches,
             attn_metadata,
-            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)