[BugFix] Fix multi-modal async scheduling race condition (#28706)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-11-14 01:11:13 -08:00 committed by GitHub
parent c36bcfe6b3
commit bc3e43069a
3 changed files with 43 additions and 31 deletions


@@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde):
         from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 
         self.encoder = MsgpackEncoder()
-        self.tensor_decoder = MsgpackDecoder(torch.Tensor)
-        self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
+        self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
+        self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
         self._mm_kwargs_item_cls = MultiModalKwargsItem
 
     def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
@@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde):
         # pickle.loads do not read past the end of a pickled object
         # within a large buffer, so we can skip storing the metadata size
         type_name, nbytes, len_arr = pickle.loads(data_view)
-        serialized_data = bytearray(data_view[-nbytes:])
+        serialized_data = data_view[-nbytes:]
         if type_name == torch.Tensor.__name__:
             obj = []
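
The switch to `share_mem=False` matters because `MsgpackDecoder` otherwise hands back arrays and tensors that are zero-copy views into the received message buffer; when that buffer is a recycled shared-memory region, decoded multi-modal data can be overwritten while async scheduling is still using it. It also lets `deserialize` pass the raw memoryview slice straight through instead of copying it into a `bytearray` first. A minimal, self-contained sketch of the hazard (the `scratch` buffer below is only a stand-in for the real shared-memory object store):

import numpy as np

# Reusable transport buffer, standing in for the shared-memory region that
# gets recycled between messages (illustrative only, not vLLM's store).
scratch = bytearray(np.arange(1, 5, dtype=np.int64).tobytes())

# Zero-copy decode: the array is a view into `scratch` (share_mem=True behavior).
shared_view = np.frombuffer(scratch, dtype=np.int64)

# Decode into owned memory, as share_mem=False now forces.
owned_copy = np.frombuffer(scratch, dtype=np.int64).copy()

# The next message overwrites the buffer before the old data is consumed.
scratch[:] = bytes(len(scratch))

print(shared_view.tolist())  # silently corrupted: [0, 0, 0, 0]
print(owned_copy.tolist())   # still intact:       [1, 2, 3, 4]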


@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
     MultiModalSharedField,
     NestedTensors,
 )
+from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.engine import UtilityResult
 from vllm.v1.utils import tensor_data
@@ -282,7 +283,9 @@ class MsgpackDecoder:
     not thread-safe when encoding tensors / numpy arrays.
     """
 
-    def __init__(self, t: Any | None = None):
+    def __init__(self, t: Any | None = None, share_mem: bool = True):
+        self.share_mem = share_mem
+        self.pin_tensors = is_pin_memory_available()
         args = () if t is None else (t,)
         self.decoder = msgpack.Decoder(
             *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
@@ -347,21 +350,30 @@ class MsgpackDecoder:
         # zero-copy decode. We assume the ndarray will not be kept around,
         # as it now locks the whole received message buffer in memory.
         buffer = self.aux_buffers[data] if isinstance(data, int) else data
-        return np.frombuffer(buffer, dtype=dtype).reshape(shape)
+        arr = np.frombuffer(buffer, dtype=dtype)
+        if not self.share_mem:
+            arr = arr.copy()
+        return arr.reshape(shape)
 
     def _decode_tensor(self, arr: Any) -> torch.Tensor:
         dtype, shape, data = arr
-        # Copy from inline representation, to decouple the memory storage
-        # of the message from the original buffer. And also make Torch
-        # not complain about a readonly memoryview.
-        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
+        is_aux = isinstance(data, int)
+        buffer = self.aux_buffers[data] if is_aux else data
+        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
         torch_dtype = getattr(torch, dtype)
         assert isinstance(torch_dtype, torch.dtype)
-        if not buffer:  # torch.frombuffer doesn't like empty buffers
+        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
             assert 0 in shape
             return torch.empty(shape, dtype=torch_dtype)
         # Create uint8 array
         arr = torch.frombuffer(buffer, dtype=torch.uint8)
+        # Clone ensures tensor is backed by pytorch-owned memory for safe
+        # future async CPU->GPU transfer.
+        # Pin larger tensors for more efficient CPU->GPU transfer.
+        if not is_aux:
+            arr = arr.clone()
+        elif not self.share_mem:
+            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
         # Convert back to proper shape & type
         return arr.view(torch_dtype).view(shape)
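
For the tensor path, the new branches distinguish three cases: data decoded inline with the message is cloned into PyTorch-owned memory, aux-buffer data is left shared when `share_mem=True`, and otherwise it is pinned (or cloned when pinning is unavailable) so a later non-blocking CPU->GPU copy reads from stable, page-locked memory. A rough standalone illustration of why the clone/pin step is needed (toy buffers, not the actual decoder state):

import torch

# Toy transport buffer holding six float32 values (1.0 .. 6.0).
payload = bytearray(torch.arange(1, 7, dtype=torch.float32).numpy().tobytes())
buffer = memoryview(payload)

# torch.frombuffer shares memory with `buffer`; nothing is copied.
shared = torch.frombuffer(buffer, dtype=torch.uint8).view(torch.float32)

# clone() re-materializes the data in PyTorch-owned memory (safe for a later
# async transfer); pin_memory() does the same into page-locked memory, which
# makes .to("cuda", non_blocking=True) more efficient. The real decoder gates
# pinning on is_pin_memory_available().
owned = shared.clone()
pinned = shared.pin_memory() if torch.cuda.is_available() else shared.clone()

# Recycle the transport buffer, as the shared-memory store eventually will.
buffer[:4] = bytes(4)

print(shared[0].item(), owned[0].item(), pinned[0].item())  # 0.0 1.0 1.0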


@@ -2590,28 +2590,28 @@ class GPUModelRunner(
 )
 )
-dp_rank = self.parallel_config.data_parallel_rank
-if ubatch_slices:
-assert num_tokens_across_dp is not None
-num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
-elif num_tokens_across_dp is not None:
-num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-else:
-num_input_tokens = self._get_num_input_tokens(
-scheduler_output.total_num_scheduled_tokens
-)
+dp_rank = self.parallel_config.data_parallel_rank
+if ubatch_slices:
+assert num_tokens_across_dp is not None
+num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
+self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
+elif num_tokens_across_dp is not None:
+num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
+else:
+num_input_tokens = self._get_num_input_tokens(
+scheduler_output.total_num_scheduled_tokens
+)
 
-(
-input_ids,
-inputs_embeds,
-positions,
-intermediate_tensors,
-model_kwargs,
-ec_connector_output,
-) = self._preprocess(
-scheduler_output, num_input_tokens, intermediate_tensors
-)
+(
+input_ids,
+inputs_embeds,
+positions,
+intermediate_tensors,
+model_kwargs,
+ec_connector_output,
+) = self._preprocess(
+scheduler_output, num_input_tokens, intermediate_tensors
+)
 uniform_decode = (
 max_num_scheduled_tokens == self.uniform_decode_query_len
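
Taken together with the GPUModelRunner adjustment above, the decoder changes target the hazard named in the commit title: under async scheduling, a non-blocking CPU->GPU copy of decoded multi-modal tensors can still be outstanding when the CPU-side buffer they alias is reused for the next message. A sketch of the staging pattern this enables; the helper below is hypothetical, not vLLM's runner code:

import torch


def stage_for_async_h2d(decoded: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: move a decoded CPU tensor to `device` without
    racing the transport buffer it may alias.

    Cloning (or pinning) first gives the asynchronous copy a stable,
    PyTorch-owned source, mirroring share_mem=False / pin_memory() in
    MsgpackDecoder.
    """
    if decoded.device.type != "cpu":
        return decoded
    staged = decoded.pin_memory() if torch.cuda.is_available() else decoded.clone()
    # The non-blocking copy now reads from `staged`, which nothing else mutates,
    # so the original message buffer can be recycled immediately.
    return staged.to(device, non_blocking=True)


if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mm_embeds = torch.randn(4, 8)  # stands in for decoded multi-modal kwargs
    on_device = stage_for_async_h2d(mm_embeds, dev)
    if dev.type == "cuda":
        torch.cuda.synchronize()
    print(on_device.shape, on_device.device)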