diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 64c2dac524f2..d4b8cf25fec5 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -923,14 +923,19 @@ class MRotaryEmbedding(RotaryEmbedding): def get_input_positions( input_tokens: List[int], hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: Optional[List[float]] = None, + image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], + video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], + second_per_grid_ts: Optional[List[float]], context_len: int = 0, seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else \ + second_per_grid_ts + llm_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( input_tokens=input_tokens, @@ -950,7 +955,7 @@ class MRotaryEmbedding(RotaryEmbedding): hf_config: PretrainedConfig, image_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: Optional[List[float]] = None, + second_per_grid_ts: List[float], context_len: int = 0, seq_len: Optional[int] = None, ) -> Tuple[torch.Tensor, int]: @@ -1006,7 +1011,7 @@ class MRotaryEmbedding(RotaryEmbedding): video_grid_thw[video_index][2], ) video_second_per_grid_t = 1.0 - if second_per_grid_ts is not None: + if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] video_index += 1 remain_videos -= 1