mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 23:44:06 +08:00
[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
32b4c63f02
commit
dba4d9dec6
@ -28,11 +28,12 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
compiled_callable: Optional[Callable] = None,
|
||||
compilation_level: int = 0):
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
if compiled_callable is None:
|
||||
# default compilation settings
|
||||
# compiling the forward method
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
|
||||
compiled_callable = torch.compile(
|
||||
@ -82,6 +83,13 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
|
||||
self.compiled_codes.append(new_code)
|
||||
|
||||
if self.vllm_config.compilation_config.use_cudagraph and \
|
||||
"update" in new_code.co_names:
|
||||
import depyf
|
||||
src = depyf.decompile(new_code)
|
||||
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def dispatch_to_code(self, index: int):
|
||||
"""Context manager to dispatch to the compiled code.
|
||||
|
||||
@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
short_cache = self._compute_cos_sin_cache(
|
||||
original_max_position_embeddings, short_factor, short_mscale)
|
||||
short_cache = short_cache.to(dtype)
|
||||
self.register_buffer("short_cos_sin_cache",
|
||||
short_cache,
|
||||
persistent=False)
|
||||
|
||||
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
|
||||
long_factor, long_mscale)
|
||||
long_cache = long_cache.to(dtype)
|
||||
self.register_buffer("long_cos_sin_cache",
|
||||
long_cache,
|
||||
persistent=False)
|
||||
|
||||
long_short_cache = torch.cat(
|
||||
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
|
||||
long_short_cache = torch.cat([short_cache, long_cache], dim=0)
|
||||
self.register_buffer("long_short_cos_sin_cache",
|
||||
long_short_cache,
|
||||
persistent=False)
|
||||
@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
torch.full_like(positions, k)).long()
|
||||
idx = (torch.add(positions, long_prompt_offset)
|
||||
if long_prompt_offset is not None else positions)
|
||||
self.long_short_cos_sin_cache: torch.Tensor = (
|
||||
self.long_short_cos_sin_cache.to(idx.device))
|
||||
idx = torch.add(idx, offsets) if offsets is not None else idx
|
||||
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user