Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-13 15:25:01 +08:00
[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture (#11233)
Signed-off-by: Yan Burman <yanburman@users.noreply.github.com>
Signed-off-by: Ido Asraff <idoa@atero.ai>
This commit is contained in:
parent d91457d529
commit 300acb8347
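Summary of the bug being fixed: `.cuda()` with no argument and `torch.cuda.Stream()` with no device both bind to the process's current CUDA device (typically `cuda:0`), which is not necessarily the device assigned to this worker. During CUDA graph capture that could place the capture stream and dummy buffers on the wrong GPU. A minimal illustration (hypothetical, assumes two visible GPUs):

import torch

# Hypothetical two-GPU illustration of the bug this commit fixes.
device = torch.device("cuda:1")            # the worker's assigned device

wrong = torch.zeros(8, dtype=torch.long).cuda()          # lands on the *current* device
right = torch.zeros(8, dtype=torch.long, device=device)  # always on cuda:1

assert right.device == device
# Unless torch.cuda.set_device(1) ran first, wrong.device is cuda:0.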
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture() as graph_capture_context:
+            with graph_capture(device=device) as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
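For context, the captured section of this test is driven roughly as in the sketch below. `device`, `sz`, `dtype`, and `tp_size` come from the surrounding test; the CUDAGraph plumbing is an assumption based on how `graph_capture_context` is consumed:

import torch
from vllm.distributed import graph_capture, tensor_model_parallel_all_reduce

# Hedged sketch, not the verbatim test body.
with graph_capture(device=device) as graph_capture_context:
    inp1 = torch.randint(1, 16, (sz, ), dtype=dtype, device=device)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
        out1 = tensor_model_parallel_all_reduce(inp1)
graph.replay()
torch.cuda.synchronize()
# The test then checks out1 against an eager NCCL all-reduce of the inputs.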
@@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_capture():
+    with graph_capture(device=device):
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
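The point of the fix is visible here: in a one-process-per-GPU setup each rank derives its own device from its rank, so helpers must not assume the process-global current device. A sketch of the failure mode (illustrative, assumes at least two visible GPUs):

import torch

# On rank 1 the worker's device is cuda:1, but unless torch.cuda.set_device(1)
# was called, the process's "current" device is still cuda:0.
rank = 1                                         # pretend this process is rank 1
device = torch.device(f"cuda:{rank}")

stream_wrong = torch.cuda.Stream()               # created on the current device
stream_right = torch.cuda.Stream(device=device)  # pinned to the rank's device
assert stream_right.device == device
# Capturing on stream_wrong while the model's tensors live on cuda:1 is the
# cross-device mismatch this commit removes.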
@@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
 
 
 @contextmanager
-def graph_capture():
+def graph_capture(device: torch.device):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -934,8 +934,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group(
-    ).graph_capture(context):
+    context = GraphCaptureContext(torch.cuda.Stream(device=device))
+    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
+            context):
         yield context
 
 
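This is the core of the fix: the capture stream is now created explicitly on the caller's device rather than on whichever device happens to be current. A simplified standalone sketch of the pattern; the wait_stream handoff is an assumption modeled on typical per-group capture logic, not a verbatim copy of vLLM's implementation:

import torch
from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream

@contextmanager
def graph_capture(device: torch.device):
    # The capture stream is created on the caller's device, not on whichever
    # device happens to be current in this process.
    context = GraphCaptureContext(torch.cuda.Stream(device=device))
    curr_stream = torch.cuda.current_stream()
    if curr_stream != context.stream:
        # Let the capture stream wait for work already queued elsewhere.
        context.stream.wait_stream(curr_stream)
    with torch.cuda.stream(context.stream):
        yield context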
@@ -836,7 +836,7 @@ class GPUModelRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture():
+        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
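The "capture the large shapes first" comment relies on CUDA graphs sharing a memory pool: smaller captures can fit inside the arena sized for the largest. A hedged plain-PyTorch sketch of that mechanism (toy model and sizes, not vLLM's runner):

import torch

# Toy stand-ins for the runner's model and dummy batch (assumptions).
model = torch.nn.Linear(16, 16).cuda()
dummy = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capturing (recommended by PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(dummy)
torch.cuda.current_stream().wait_stream(s)

pool = torch.cuda.graph_pool_handle()   # one memory pool shared by all graphs
graphs, outputs = {}, {}
for n in (8, 4, 2):                     # largest batch first, as in the diff
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        outputs[n] = model(dummy[:n])
    graphs[n] = g
graphs[4].replay()                      # later: replay a captured graph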
@@ -1426,10 +1426,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = self.max_batchsize_to_capture
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size,
+                                   dtype=torch.long,
+                                   device=self.device)
+        input_positions = torch.zeros(max_batch_size,
+                                      dtype=torch.long,
+                                      device=self.device)
         if self.model_config.uses_mrope:
-            input_positions = torch.tile(input_positions, (3, 1))
+            input_positions = torch.tile(input_positions,
+                                         (3, 1)).cuda(device=self.device)
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
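A note on the `uses_mrope` branch: M-RoPE (as used by models like Qwen2-VL) keeps three position components, hence the tile to `(3, N)`. Note that `torch.tile` preserves the input's device, so the trailing `.cuda(device=self.device)` above is a no-op once `input_positions` already lives on `self.device`. Illustrative check:

import torch

device = torch.device("cuda:0")
input_positions = torch.zeros(8, dtype=torch.long, device=device)
mrope_positions = torch.tile(input_positions, (3, 1))  # one row per component
assert mrope_positions.shape == (3, 8)
assert mrope_positions.device == device                # device is preserved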
@@ -1448,8 +1453,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(
-                max_batch_size), graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(max_batch_size), graph_capture(
+                self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
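Within this block, anything captured records onto the context's stream, which `graph_capture` now creates on `self.device`. A standalone check of that wiring (illustrative values, not the runner's code):

import torch

device = torch.device("cuda:0")
stream = torch.cuda.Stream(device=device)     # what graph_capture(device) builds
assert stream.device == device

warmup = torch.ones(4, device=device) * 2     # run kernels once before capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=stream):      # capture on the device-bound stream
    out = warmup * 2
g.replay()                                    # out now holds the replayed result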
@@ -1549,10 +1554,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         """
         # During the decode phase encoder_input_ids and encoder_positions are
         # unset. Do the same thing for graph capture.
-        capture_inputs["encoder_input_ids"] = torch.tensor(
-            [], dtype=torch.long).cuda()
-        capture_inputs["encoder_positions"] = torch.tensor(
-            [], dtype=torch.long).cuda()
+        capture_inputs["encoder_input_ids"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)
+        capture_inputs["encoder_positions"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)
 
     @property
     def vocab_size(self) -> int:
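The same device discipline applies to the empty encoder placeholders: constructing directly on the target device replaces a CPU-side build plus a `.cuda()` hop onto whatever device is current. Quick illustrative check ("cuda:1" stands in for the worker's assigned device):

import torch

device = torch.device("cuda:1")
encoder_input_ids = torch.tensor([], dtype=torch.long, device=device)
assert encoder_input_ids.device == device   # created directly on the target GPU
assert encoder_input_ids.numel() == 0       # empty, matching the decode phase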