mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Model][Qwen3VL] Simplify get_mrope_input_positions using numpy (#28302)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
9f0247cfa4
commit
cbb799e314
@ -1432,13 +1432,11 @@ class Qwen3VLForConditionalGeneration(
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
input_tokens_array = np.array(input_tokens)
|
||||
vision_start_mask = input_tokens_array == vision_start_token_id
|
||||
vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
|
||||
image_nums = np.count_nonzero(vision_tokens == image_token_id)
|
||||
video_nums = np.count_nonzero(vision_tokens == video_token_id)
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
@ -1474,43 +1472,23 @@ class Qwen3VLForConditionalGeneration(
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
|
||||
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
return torch.from_numpy(llm_positions), mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user