From dba4d9dec606da028fbb28240e99cabd5a761e6a Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 29 Dec 2024 17:03:49 +0800
Subject: [PATCH] [v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)

Signed-off-by: youkaichao
---
 vllm/compilation/wrapper.py                    | 10 +++++++++-
 vllm/model_executor/layers/rotary_embedding.py | 11 +----------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index c10241b483169..e3260a10c02ae 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -28,11 +28,12 @@ class TorchCompileWrapperWithCustomDispatcher:
                  compiled_callable: Optional[Callable] = None,
                  compilation_level: int = 0):
 
+        vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
         if compiled_callable is None:
             # default compilation settings
             # compiling the forward method
 
-            vllm_config = get_current_vllm_config()
             backend = vllm_config.compilation_config.init_backend(vllm_config)
 
             compiled_callable = torch.compile(
@@ -82,6 +83,13 @@ class TorchCompileWrapperWithCustomDispatcher:
 
         self.compiled_codes.append(new_code)
 
+        if self.vllm_config.compilation_config.use_cudagraph and \
+            "update" in new_code.co_names:
+            import depyf
+            src = depyf.decompile(new_code)
+            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
+            raise RuntimeError(msg)
+
     @contextmanager
     def dispatch_to_code(self, index: int):
         """Context manager to dispatch to the compiled code.
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 117fe086e5e87..6695d44dfa32b 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
-        self.register_buffer("short_cos_sin_cache",
-                             short_cache,
-                             persistent=False)
 
         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
         long_cache = long_cache.to(dtype)
-        self.register_buffer("long_cos_sin_cache",
-                             long_cache,
-                             persistent=False)
 
-        long_short_cache = torch.cat(
-            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
         self.register_buffer("long_short_cos_sin_cache",
                              long_short_cache,
                              persistent=False)
@@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
                               torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache: torch.Tensor = (
-            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
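
For context only (not part of the commit): a minimal, hypothetical sketch of the forward-time buffer assignment that the new check in wrapper.py is meant to catch, next to the __init__-time pattern the rotary-embedding change moves to. The ToyRopeBad/ToyRopeGood classes below are illustrative and do not exist in vLLM.

    import torch
    from torch import nn


    class ToyRopeBad(nn.Module):
        """Re-assigns a registered buffer inside forward(). Per the commit
        message, this kind of in-place buffer assignment causes silent errors
        when cudagraph is used inside the compiler."""

        def __init__(self, cache: torch.Tensor):
            super().__init__()
            self.register_buffer("cache", cache, persistent=False)

        def forward(self, idx: torch.Tensor) -> torch.Tensor:
            # Buffer mutation during the forward pass -- the kind of pattern
            # the new use_cudagraph check is intended to reject.
            self.cache = self.cache.to(idx.device)
            return torch.index_select(self.cache, 0, idx)


    class ToyRopeGood(nn.Module):
        """Prepares the buffer once in __init__ and only reads it in
        forward(), mirroring the Phi3LongRoPEScaledRotaryEmbedding change."""

        def __init__(self, cache: torch.Tensor, device: torch.device):
            super().__init__()
            self.register_buffer("cache", cache.to(device), persistent=False)

        def forward(self, idx: torch.Tensor) -> torch.Tensor:
            return torch.index_select(self.cache, 0, idx)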