mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[Model][Qwen3VL] Simplify get_mrope_input_positions using numpy (#28302)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
9f0247cfa4
commit
cbb799e314
@ -1432,13 +1432,11 @@ class Qwen3VLForConditionalGeneration(
|
|||||||
vision_start_token_id = hf_config.vision_start_token_id
|
vision_start_token_id = hf_config.vision_start_token_id
|
||||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
|
||||||
input_tokens_tensor = torch.tensor(input_tokens)
|
input_tokens_array = np.array(input_tokens)
|
||||||
vision_start_indices = torch.argwhere(
|
vision_start_mask = input_tokens_array == vision_start_token_id
|
||||||
input_tokens_tensor == vision_start_token_id
|
vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
|
||||||
).squeeze(1)
|
image_nums = np.count_nonzero(vision_tokens == image_token_id)
|
||||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
video_nums = np.count_nonzero(vision_tokens == video_token_id)
|
||||||
image_nums = (vision_tokens == image_token_id).sum()
|
|
||||||
video_nums = (vision_tokens == video_token_id).sum()
|
|
||||||
llm_pos_ids_list: list = []
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
st = 0
|
st = 0
|
||||||
@ -1474,43 +1472,23 @@ class Qwen3VLForConditionalGeneration(
|
|||||||
|
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(
|
||||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
t_index = (
|
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
|
||||||
torch.arange(llm_grid_t)
|
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
|
||||||
.view(-1, 1)
|
|
||||||
.expand(-1, llm_grid_h * llm_grid_w)
|
|
||||||
.flatten()
|
|
||||||
)
|
|
||||||
h_index = (
|
|
||||||
torch.arange(llm_grid_h)
|
|
||||||
.view(1, -1, 1)
|
|
||||||
.expand(llm_grid_t, -1, llm_grid_w)
|
|
||||||
.flatten()
|
|
||||||
)
|
|
||||||
w_index = (
|
|
||||||
torch.arange(llm_grid_w)
|
|
||||||
.view(1, 1, -1)
|
|
||||||
.expand(llm_grid_t, llm_grid_h, -1)
|
|
||||||
.flatten()
|
|
||||||
)
|
|
||||||
llm_pos_ids_list.append(
|
|
||||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
|
||||||
)
|
|
||||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
if st < len(input_tokens):
|
if st < len(input_tokens):
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
text_len = len(input_tokens) - st
|
text_len = len(input_tokens) - st
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(
|
||||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
|
||||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||||
|
return torch.from_numpy(llm_positions), mrope_position_delta
|
||||||
return llm_positions, mrope_position_delta
|
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
return self.language_model
|
return self.language_model
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user