[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-12-29 17:03:49 +08:00 (committed by GitHub)
parent 32b4c63f02
commit dba4d9dec6
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 11 deletions
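
Why this fix is needed: when torch.compile wraps the forward pass in a CUDA graph, the graph is captured once and replayed; any tensor assigned to a module attribute during the captured run is created only at capture time, so later replays silently read stale values. This commit therefore (a) teaches the compile wrapper to reject compiled forward code that mutates module buffers when cudagraph is enabled, and (b) removes the in-forward buffer registration and device move from Phi3 LongRoPE. A minimal sketch of the broken and fixed patterns (hypothetical BadRoPE/GoodRoPE modules, not vLLM code):

import torch
import torch.nn as nn

class BadRoPE(nn.Module):
    """Anti-pattern the new guard rejects: the forward pass reassigns a
    module attribute, which a CUDA graph replay silently skips after
    the first captured run."""

    def __init__(self, cache: torch.Tensor):
        super().__init__()
        self.cache = cache

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        self.cache = self.cache.to(idx.device)  # buffer mutated in forward
        return torch.index_select(self.cache, 0, idx)

class GoodRoPE(nn.Module):
    """Fixed pattern (what this commit applies to Phi3 LongRoPE): build
    and register the cache once in __init__, only read it in forward."""

    def __init__(self, cache: torch.Tensor):
        super().__init__()
        self.register_buffer("cache", cache, persistent=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return torch.index_select(self.cache, 0, idx)

Eager PyTorch tolerates both patterns, which is why the bug stays silent until cudagraph replay.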

vllm/compilation/wrapper.py

@@ -28,11 +28,12 @@ class TorchCompileWrapperWithCustomDispatcher:
                  compiled_callable: Optional[Callable] = None,
                  compilation_level: int = 0):
+        vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
         if compiled_callable is None:
             # default compilation settings
             # compiling the forward method
-            vllm_config = get_current_vllm_config()
             backend = vllm_config.compilation_config.init_backend(vllm_config)
             compiled_callable = torch.compile(
@@ -82,6 +83,13 @@ class TorchCompileWrapperWithCustomDispatcher:
         self.compiled_codes.append(new_code)
+        if self.vllm_config.compilation_config.use_cudagraph and \
+                "update" in new_code.co_names:
+            import depyf
+            src = depyf.decompile(new_code)
+            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
+            raise RuntimeError(msg)
+
     @contextmanager
     def dispatch_to_code(self, index: int):
         """Context manager to dispatch to the compiled code.

vllm/model_executor/layers/rotary_embedding.py

@@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
-        self.register_buffer("short_cos_sin_cache",
-                             short_cache,
-                             persistent=False)
         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
         long_cache = long_cache.to(dtype)
-        self.register_buffer("long_cos_sin_cache",
-                             long_cache,
-                             persistent=False)
-        long_short_cache = torch.cat(
-            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
         self.register_buffer("long_short_cos_sin_cache",
                              long_short_cache,
                              persistent=False)
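
The rebuilt cache above concatenates the short and long tables into a single buffer, and the forward in the next hunk relies on that layout: rows [0, original_max_position_embeddings) come from the short cache, and long prompts add an offset k (equal to original_max_position_embeddings elsewhere in this file) so the same position indices land in the long half. A toy sketch of that indexing scheme (made-up sizes, two columns standing in for the real cos/sin rows):

import torch

original_max = 4                              # toy original_max_position_embeddings
short_cache = torch.zeros(original_max, 2)    # rows for positions 0..3, short factors
long_cache = torch.ones(original_max + 2, 2)  # rows for positions 0..5, long factors

combined = torch.cat([short_cache, long_cache], dim=0)

# Short prompts index rows [0, original_max); long prompts add
# k = original_max so the same positions hit the long half.
positions = torch.tensor([0, 1, 2])
idx = positions + original_max
rows = torch.index_select(combined, 0, idx)
assert torch.equal(rows, long_cache[:3])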
@@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
                      torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache: torch.Tensor = (
-            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
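
The deleted device move was both redundant and, under cudagraph, exactly the in-forward buffer mutation the new wrapper guard rejects: a buffer created with register_buffer() already follows the module when the whole model is moved once at load time. A minimal sketch (hypothetical Cache module, assuming a CUDA device is available):

import torch
import torch.nn as nn

class Cache(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Non-persistent buffer: excluded from state_dict, but still
        # moved by Module.to() like any other buffer.
        self.register_buffer("table",
                             torch.arange(8, dtype=torch.float32),
                             persistent=False)

m = Cache()
if torch.cuda.is_available():
    m.to("cuda")
    assert m.table.device.type == "cuda"  # buffer followed the module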