diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index c10241b483169..e3260a10c02ae 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -28,11 +28,12 @@ class TorchCompileWrapperWithCustomDispatcher: compiled_callable: Optional[Callable] = None, compilation_level: int = 0): + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config if compiled_callable is None: # default compilation settings # compiling the forward method - vllm_config = get_current_vllm_config() backend = vllm_config.compilation_config.init_backend(vllm_config) compiled_callable = torch.compile( @@ -82,6 +83,13 @@ class TorchCompileWrapperWithCustomDispatcher: self.compiled_codes.append(new_code) + if self.vllm_config.compilation_config.use_cudagraph and \ + "update" in new_code.co_names: + import depyf + src = depyf.decompile(new_code) + msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + raise RuntimeError(msg) + @contextmanager def dispatch_to_code(self, index: int): """Context manager to dispatch to the compiled code. diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 117fe086e5e87..6695d44dfa32b 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) short_cache = short_cache.to(dtype) - self.register_buffer("short_cos_sin_cache", - short_cache, - persistent=False) long_cache = self._compute_cos_sin_cache(max_position_embeddings, long_factor, long_mscale) long_cache = long_cache.to(dtype) - self.register_buffer("long_cos_sin_cache", - long_cache, - persistent=False) - long_short_cache = torch.cat( - [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + long_short_cache = torch.cat([short_cache, long_cache], dim=0) self.register_buffer("long_short_cos_sin_cache", long_short_cache, persistent=False) @@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): torch.full_like(positions, k)).long() idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) - self.long_short_cos_sin_cache: torch.Tensor = ( - self.long_short_cos_sin_cache.to(idx.device)) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)