Added validation for max_size parameter in get_3d_rotary_pos_embed function when grid_type is set to 'slice'.

This commit is contained in:
sko00o 2025-07-31 15:01:48 +08:00
parent 389fb0323f
commit 881bbbf6c9
No known key found for this signature in database
GPG Key ID: 240DEE2151D0CF19

View File

@ -174,6 +174,8 @@ def get_3d_rotary_pos_embed(
grid_t = np.arange(temporal_size, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
elif grid_type == "slice":
if max_size is None:
raise ValueError("`max_size` must be provided when `grid_type` is 'slice'")
max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32)