Merge fd593c596de73fcb98ee8636150a385ea896b43d into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Amir Samani 2025-12-25 00:06:49 +00:00 committed by GitHub
commit 3bb2e7950a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 8 deletions

View File

@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
graph = torch.cuda.CUDAGraph()
graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input)

View File

@@ -189,13 +189,16 @@ class TestCUDAGraphWrapper:
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
batch_descriptor = BatchDescriptor(num_tokens=10)
stream = torch.cuda.Stream()
# 0. global warmup
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
with (
set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None,
),
torch.cuda.stream(stream),
):
wrapper(self.input_tensor)
@@ -208,6 +211,7 @@ class TestCUDAGraphWrapper:
batch_descriptor=batch_descriptor,
),
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
torch.cuda.stream(stream),
):
output1 = wrapper(self.input_tensor)
# capturing phase should generate a zero output

View File

@@ -263,7 +263,11 @@ class CUDAGraphWrapper:
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output: