mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 06:55:01 +08:00
[torch.compile] expanding support and fix allgather compilation (#9637)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
295a061fb3
commit
ad6f78053e
@ -392,8 +392,12 @@ class GroupCoordinator:
|
|||||||
# Convert negative dim to positive.
|
# Convert negative dim to positive.
|
||||||
dim += input_.dim()
|
dim += input_.dim()
|
||||||
input_size = input_.size()
|
input_size = input_.size()
|
||||||
|
# NOTE: we have to use concat-style all-gather here;
|
||||||
|
# stack-style all-gather has compatibility issues with
|
||||||
|
# torch.compile. See https://github.com/pytorch/pytorch/issues/138795
|
||||||
|
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
||||||
# Allocate output tensor.
|
# Allocate output tensor.
|
||||||
output_tensor = torch.empty((world_size, ) + input_size,
|
output_tensor = torch.empty(output_size,
|
||||||
dtype=input_.dtype,
|
dtype=input_.dtype,
|
||||||
device=input_.device)
|
device=input_.device)
|
||||||
# All-gather.
|
# All-gather.
|
||||||
@ -401,6 +405,7 @@ class GroupCoordinator:
|
|||||||
input_,
|
input_,
|
||||||
group=self.device_group)
|
group=self.device_group)
|
||||||
# Reshape
|
# Reshape
|
||||||
|
output_tensor = output_tensor.reshape((world_size, ) + input_size)
|
||||||
output_tensor = output_tensor.movedim(0, dim)
|
output_tensor = output_tensor.movedim(0, dim)
|
||||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||||
(world_size *
|
(world_size *
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from torch import nn
|
|||||||
from transformers import GPTBigCodeConfig
|
from transformers import GPTBigCodeConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -187,6 +188,7 @@ class GPTBigCodeBlock(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class GPTBigCodeModel(nn.Module):
|
class GPTBigCodeModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from torch import nn
|
|||||||
from transformers import GPTJConfig
|
from transformers import GPTJConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -174,6 +175,7 @@ class GPTJBlock(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class GPTJModel(nn.Module):
|
class GPTJModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from torch import nn
|
|||||||
from transformers import GPTNeoXConfig
|
from transformers import GPTNeoXConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -187,6 +188,7 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class GPTNeoXModel(nn.Module):
|
class GPTNeoXModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from transformers import GraniteConfig
|
from transformers import GraniteConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
@ -254,6 +255,7 @@ class GraniteDecoderLayer(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class GraniteModel(nn.Module):
|
class GraniteModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@ -230,6 +231,7 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class InternLM2Model(nn.Module):
|
class InternLM2Model(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user