mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 16:45:52 +08:00
[torch.compile] expanding support and fix allgather compilation (#9637)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
295a061fb3
commit
ad6f78053e
@ -392,8 +392,12 @@ class GroupCoordinator:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((world_size, ) + input_size,
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
@ -401,6 +405,7 @@ class GroupCoordinator:
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
|
||||
@ -25,6 +25,7 @@ from torch import nn
|
||||
from transformers import GPTBigCodeConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -187,6 +188,7 @@ class GPTBigCodeBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -23,6 +23,7 @@ from torch import nn
|
||||
from transformers import GPTJConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -174,6 +175,7 @@ class GPTJBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GPTJModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -23,6 +23,7 @@ from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -187,6 +188,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GPTNeoXModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -28,6 +28,7 @@ from torch import nn
|
||||
from transformers import GraniteConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
@ -254,6 +255,7 @@ class GraniteDecoderLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GraniteModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -7,6 +7,7 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -230,6 +231,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class InternLM2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user