mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 03:29:08 +08:00
[BugFix] Fix multi-modal async scheduling race condition (#28706)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
c36bcfe6b3
commit
bc3e43069a
@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde):
|
|||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
|
|
||||||
self.encoder = MsgpackEncoder()
|
self.encoder = MsgpackEncoder()
|
||||||
self.tensor_decoder = MsgpackDecoder(torch.Tensor)
|
self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
|
||||||
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
|
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
|
||||||
self._mm_kwargs_item_cls = MultiModalKwargsItem
|
self._mm_kwargs_item_cls = MultiModalKwargsItem
|
||||||
|
|
||||||
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
|
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
|
||||||
@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde):
|
|||||||
# pickle.loads do not read past the end of a pickled object
|
# pickle.loads do not read past the end of a pickled object
|
||||||
# within a large buffer, so we can skip storing the metadata size
|
# within a large buffer, so we can skip storing the metadata size
|
||||||
type_name, nbytes, len_arr = pickle.loads(data_view)
|
type_name, nbytes, len_arr = pickle.loads(data_view)
|
||||||
serialized_data = bytearray(data_view[-nbytes:])
|
serialized_data = data_view[-nbytes:]
|
||||||
|
|
||||||
if type_name == torch.Tensor.__name__:
|
if type_name == torch.Tensor.__name__:
|
||||||
obj = []
|
obj = []
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
|
|||||||
MultiModalSharedField,
|
MultiModalSharedField,
|
||||||
NestedTensors,
|
NestedTensors,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.engine import UtilityResult
|
from vllm.v1.engine import UtilityResult
|
||||||
from vllm.v1.utils import tensor_data
|
from vllm.v1.utils import tensor_data
|
||||||
|
|
||||||
@ -282,7 +283,9 @@ class MsgpackDecoder:
|
|||||||
not thread-safe when encoding tensors / numpy arrays.
|
not thread-safe when encoding tensors / numpy arrays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, t: Any | None = None):
|
def __init__(self, t: Any | None = None, share_mem: bool = True):
|
||||||
|
self.share_mem = share_mem
|
||||||
|
self.pin_tensors = is_pin_memory_available()
|
||||||
args = () if t is None else (t,)
|
args = () if t is None else (t,)
|
||||||
self.decoder = msgpack.Decoder(
|
self.decoder = msgpack.Decoder(
|
||||||
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
|
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
|
||||||
@ -347,21 +350,30 @@ class MsgpackDecoder:
|
|||||||
# zero-copy decode. We assume the ndarray will not be kept around,
|
# zero-copy decode. We assume the ndarray will not be kept around,
|
||||||
# as it now locks the whole received message buffer in memory.
|
# as it now locks the whole received message buffer in memory.
|
||||||
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
||||||
return np.frombuffer(buffer, dtype=dtype).reshape(shape)
|
arr = np.frombuffer(buffer, dtype=dtype)
|
||||||
|
if not self.share_mem:
|
||||||
|
arr = arr.copy()
|
||||||
|
return arr.reshape(shape)
|
||||||
|
|
||||||
def _decode_tensor(self, arr: Any) -> torch.Tensor:
|
def _decode_tensor(self, arr: Any) -> torch.Tensor:
|
||||||
dtype, shape, data = arr
|
dtype, shape, data = arr
|
||||||
# Copy from inline representation, to decouple the memory storage
|
is_aux = isinstance(data, int)
|
||||||
# of the message from the original buffer. And also make Torch
|
buffer = self.aux_buffers[data] if is_aux else data
|
||||||
# not complain about a readonly memoryview.
|
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
|
||||||
buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
|
|
||||||
torch_dtype = getattr(torch, dtype)
|
torch_dtype = getattr(torch, dtype)
|
||||||
assert isinstance(torch_dtype, torch.dtype)
|
assert isinstance(torch_dtype, torch.dtype)
|
||||||
if not buffer: # torch.frombuffer doesn't like empty buffers
|
if not buffer.nbytes: # torch.frombuffer doesn't like empty buffers
|
||||||
assert 0 in shape
|
assert 0 in shape
|
||||||
return torch.empty(shape, dtype=torch_dtype)
|
return torch.empty(shape, dtype=torch_dtype)
|
||||||
# Create uint8 array
|
# Create uint8 array
|
||||||
arr = torch.frombuffer(buffer, dtype=torch.uint8)
|
arr = torch.frombuffer(buffer, dtype=torch.uint8)
|
||||||
|
# Clone ensures tensor is backed by pytorch-owned memory for safe
|
||||||
|
# future async CPU->GPU transfer.
|
||||||
|
# Pin larger tensors for more efficient CPU->GPU transfer.
|
||||||
|
if not is_aux:
|
||||||
|
arr = arr.clone()
|
||||||
|
elif not self.share_mem:
|
||||||
|
arr = arr.pin_memory() if self.pin_tensors else arr.clone()
|
||||||
# Convert back to proper shape & type
|
# Convert back to proper shape & type
|
||||||
return arr.view(torch_dtype).view(shape)
|
return arr.view(torch_dtype).view(shape)
|
||||||
|
|
||||||
|
|||||||
@ -2590,28 +2590,28 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
dp_rank = self.parallel_config.data_parallel_rank
|
dp_rank = self.parallel_config.data_parallel_rank
|
||||||
if ubatch_slices:
|
if ubatch_slices:
|
||||||
assert num_tokens_across_dp is not None
|
assert num_tokens_across_dp is not None
|
||||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||||
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
||||||
elif num_tokens_across_dp is not None:
|
elif num_tokens_across_dp is not None:
|
||||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||||
else:
|
else:
|
||||||
num_input_tokens = self._get_num_input_tokens(
|
num_input_tokens = self._get_num_input_tokens(
|
||||||
scheduler_output.total_num_scheduled_tokens
|
scheduler_output.total_num_scheduled_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
positions,
|
positions,
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
model_kwargs,
|
model_kwargs,
|
||||||
ec_connector_output,
|
ec_connector_output,
|
||||||
) = self._preprocess(
|
) = self._preprocess(
|
||||||
scheduler_output, num_input_tokens, intermediate_tensors
|
scheduler_output, num_input_tokens, intermediate_tensors
|
||||||
)
|
)
|
||||||
|
|
||||||
uniform_decode = (
|
uniform_decode = (
|
||||||
max_num_scheduled_tokens == self.uniform_decode_query_len
|
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user