[model][refactor] remove cuda hard code in models and layers (#13658)
commit 23eca9cf68
parent 437b76ff59
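The pattern repeated across every hunk below is the same: tensors that were created with a hard-coded device="cuda" are now placed on whatever platform vLLM detected at startup, exposed as current_platform.device_type. A minimal sketch of that pattern, assuming a vLLM install where vllm.platforms.current_platform is importable; the make_workspace helper name is illustrative only and not part of the commit:

import torch

from vllm.platforms import current_platform


def make_workspace(max_workspace_size: int) -> torch.Tensor:
    # Previously device="cuda" pinned the buffer to NVIDIA GPUs.
    # current_platform.device_type resolves to the detected backend
    # instead (for example "cuda", "cpu" or "xpu").
    return torch.zeros(max_workspace_size,
                       dtype=torch.int,
                       device=current_platform.device_type,
                       requires_grad=False)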
@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import direct_register_custom_op
 
@@ -238,7 +239,7 @@ def fused_marlin_moe(
     max_workspace_size = (max(2 * N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
-                            device="cuda",
+                            device=current_platform.device_type,
                             requires_grad=False)
 
     if has_no_zp:
@@ -30,6 +30,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -650,9 +651,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
                          is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0,
+                         self.rotary_dim,
+                         2,
+                         dtype=torch.float,
+                         device=current_platform.device_type) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
 
@@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
+                         device=current_platform.device_type,
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)
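For reference, the frequency math in _compute_inv_freq is untouched by this commit; only the device the tensors land on changes. A standalone restatement of that computation, with illustrative values for rotary_dim, base and scaling_factor:

import torch

from vllm.platforms import current_platform

rotary_dim, base, scaling_factor = 64, 10000.0, 40.0
# Per-pair rotary frequencies: base**(2i / rotary_dim) for i = 0 .. rotary_dim/2 - 1.
pos_freqs = base**(torch.arange(0,
                                rotary_dim,
                                2,
                                dtype=torch.float,
                                device=current_platform.device_type) /
                   rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs                      # original RoPE frequencies
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)   # scaled-context frequencies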
@@ -7,6 +7,8 @@ import torch
 import torch.jit
 import torch.nn as nn
 
+from vllm.platforms import current_platform
+
 
 class SpecDecodeBaseSampler(nn.Module):
     """Base class for samplers used for Speculative Decoding verification
@@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module):
     def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
         if isinstance(device, int):
-            device = f"cuda:{device}"
+            device = f"{current_platform.device_type}:{device}"
         elif not isinstance(device, str):
             raise ValueError(f"Device must be int or str, get {type(device)}")
         self.num_accepted_tokens = torch.tensor(0,
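init_gpu_tensors accepts either an integer device index or a full device string; with this change an integer index is expanded using the detected platform instead of assuming CUDA. A small usage sketch; resolve_device is a hypothetical stand-in for the branch above, and the "cuda:0" value assumes an NVIDIA host:

from typing import Union

from vllm.platforms import current_platform


def resolve_device(device: Union[int, str]) -> str:
    # Same branching as SpecDecodeBaseSampler.init_gpu_tensors.
    if isinstance(device, int):
        return f"{current_platform.device_type}:{device}"
    if not isinstance(device, str):
        raise ValueError(f"Device must be int or str, get {type(device)}")
    return device


print(resolve_device(0))      # "cuda:0" on an NVIDIA host
print(resolve_device("cpu"))  # explicit strings pass through unchanged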
@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 if param_name + "." in k:
                     quant_state[k] = temp_state_dict[k]
 
-            return QuantState.from_dict(quant_state, device="cuda")
+            return QuantState.from_dict(quant_state,
+                                        device=current_platform.device_type)
 
         # Second iterate over all prequant and normal weights
         # pre quantized weights would have a quant_state
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.arctic import ArcticConfig
 
@@ -138,13 +139,13 @@ class ArcticMoE(nn.Module):
             torch.empty(self.num_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,
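The ArcticMoE change above (and the identical MiniCPMMoE change below) follows the same recipe: expert weights are allocated as empty tensors on the detected platform and then tagged with the module's weight loader via set_weight_attrs. A condensed sketch with made-up sizes and a stub loader standing in for the real one:

import torch
import torch.nn as nn

from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

num_experts, hidden_size, intermediate_size = 4, 64, 128
params_dtype = torch.get_default_dtype()

# Fused gate/up projection weights for all experts, allocated on the
# detected platform rather than a hard-coded CUDA device.
ws = nn.Parameter(
    torch.empty(num_experts,
                2 * intermediate_size,
                hidden_size,
                device=current_platform.device_type,
                dtype=params_dtype))


def weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Stub standing in for ArcticMoE.weight_loader.
    param.data.copy_(loaded_weight)


set_weight_attrs(ws, {"weight_loader": weight_loader})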
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module):
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
 
         set_weight_attrs(self.ws, {
@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .idefics2_vision_model import Idefics2VisionTransformer
@@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             quant_config=quant_config,
             prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,
@@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
             quant_config=quant_config,
             prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,
@@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
             quant_config=quant_config,
             prefix=prefix)
 
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
 
     def get_vision_embedding(
         self,