[Model][Qwen3VL] Simplify get_mrope_input_positions using numpy (#28302)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-11-12 02:55:10 +00:00 committed by GitHub
parent 9f0247cfa4
commit cbb799e314
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1432,13 +1432,11 @@ class Qwen3VLForConditionalGeneration(
         vision_start_token_id = hf_config.vision_start_token_id
         spatial_merge_size = hf_config.vision_config.spatial_merge_size
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        input_tokens_array = np.array(input_tokens)
+        vision_start_mask = input_tokens_array == vision_start_token_id
+        vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
+        image_nums = np.count_nonzero(vision_tokens == image_token_id)
+        video_nums = np.count_nonzero(vision_tokens == video_token_id)
         llm_pos_ids_list: list = []
         st = 0
@@ -1474,43 +1472,23 @@ class Qwen3VLForConditionalGeneration(
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
-            t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
+            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
+            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w

         if st < len(input_tokens):
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             text_len = len(input_tokens) - st
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )

-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        return llm_positions, mrope_position_delta
+        return torch.from_numpy(llm_positions), mrope_position_delta

     def get_language_model(self) -> torch.nn.Module:
         return self.language_model