[torch.compile] expanding support and fix allgather compilation (#9637)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Author: Yongzao
Date: 2024-10-24 16:32:15 +08:00 (committed by GitHub)
parent 295a061fb3
commit ad6f78053e
6 changed files with 16 additions and 1 deletion


@@ -392,8 +392,12 @@ class GroupCoordinator:
             # Convert negative dim to positive.
             dim += input_.dim()
         input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * world_size, ) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty((world_size, ) + input_size,
+        output_tensor = torch.empty(output_size,
                                     dtype=input_.dtype,
                                     device=input_.device)
         # All-gather.
@@ -401,6 +405,7 @@ class GroupCoordinator:
                                                  input_,
                                                  group=self.device_group)
         # Reshape
+        output_tensor = output_tensor.reshape((world_size, ) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
                                               (world_size *
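
As a minimal single-process sketch (not part of this commit; world_size, dim and the shapes are made up), the reshape/movedim post-processing shown above turns the concat-style gather result into the same tensor a direct gather along `dim` would produce:

import torch

# Made-up values purely for illustration.
world_size, dim = 4, 1
per_rank = [torch.randn(2, 3, 5) for _ in range(world_size)]
input_size = per_rank[0].size()

# all_gather_into_tensor fills a flat buffer with the per-rank tensors
# concatenated along dim 0; simulate that result without a process group.
flat = torch.cat(per_rank, dim=0)                # (world_size * 2, 3, 5)

# The post-processing from GroupCoordinator.all_gather above.
out = flat.reshape((world_size, ) + input_size)  # (world_size, 2, 3, 5)
out = out.movedim(0, dim)                        # (2, world_size, 3, 5)
out = out.reshape(input_size[:dim] +
                  (world_size * input_size[dim], ) +
                  input_size[dim + 1:])          # (2, world_size * 3, 5)

# Same result as gathering directly along `dim`.
assert torch.equal(out, torch.cat(per_rank, dim=dim))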


@@ -25,6 +25,7 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ class GPTBigCodeBlock(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class GPTBigCodeModel(nn.Module):
 
     def __init__(


@@ -23,6 +23,7 @@ from torch import nn
 from transformers import GPTJConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -174,6 +175,7 @@ class GPTJBlock(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class GPTJModel(nn.Module):
 
     def __init__(


@@ -23,6 +23,7 @@ from torch import nn
 from transformers import GPTNeoXConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ class GPTNeoXLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class GPTNeoXModel(nn.Module):
 
     def __init__(


@@ -28,6 +28,7 @@ from torch import nn
 from transformers import GraniteConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -254,6 +255,7 @@ class GraniteDecoderLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class GraniteModel(nn.Module):
 
     def __init__(


@@ -7,6 +7,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -230,6 +231,7 @@ class InternLMDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class InternLM2Model(nn.Module):
 
     def __init__(
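
The five model files above all apply the same opt-in pattern: import support_torch_compile from vllm.compilation.decorators and decorate the inner *Model class. As a hedged, simplified sketch of the idea behind such a class decorator (not vLLM's actual implementation, which also handles compilation levels and dynamic shapes), a toy version could wrap each instance's forward with torch.compile when a switch is enabled:

import os

import torch
from torch import nn


def toy_support_torch_compile(cls):
    """Toy class decorator: compile each instance's forward when opted in.

    Illustration only; vLLM's real support_torch_compile is more involved
    and is driven by vLLM's compilation configuration, not this switch.
    """
    original_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # TOY_COMPILE is a made-up switch for this sketch, not a vLLM flag.
        if os.environ.get("TOY_COMPILE") == "1":
            self.forward = torch.compile(self.forward)

    cls.__init__ = patched_init
    return cls


@toy_support_torch_compile
class TinyModel(nn.Module):
    """Hypothetical stand-in for the decorated *Model classes above."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


if __name__ == "__main__":
    model = TinyModel()
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

With the switch unset, the decorated class behaves exactly like the undecorated one, which mirrors how the opt-in above is a no-op unless torch.compile support is enabled.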