mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 02:44:27 +08:00
[Misc] rename torch_dtype to dtype (#26695)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
f93e348010
commit
8f4b313c37
@ -631,7 +631,7 @@ def main(args: argparse.Namespace):
|
|||||||
else:
|
else:
|
||||||
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
|
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
block_quant_shape = get_weight_block_size_safety(config)
|
block_quant_shape = get_weight_block_size_safety(config)
|
||||||
|
|||||||
@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
|
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
use_customized_permute = args.use_customized_permute
|
use_customized_permute = args.use_customized_permute
|
||||||
|
|||||||
@ -58,7 +58,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
from auto_round import AutoRound
|
from auto_round import AutoRound
|
||||||
|
|
||||||
model_name = "Qwen/Qwen3-0.6B"
|
model_name = "Qwen/Qwen3-0.6B"
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
bits, group_size, sym = 4, 128, True
|
bits, group_size, sym = 4, 128, True
|
||||||
|
|||||||
@ -43,7 +43,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
torch_dtype="auto",
|
dtype="auto",
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
```
|
```
|
||||||
|
|||||||
@ -41,7 +41,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
torch_dtype="auto",
|
dtype="auto",
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
```
|
```
|
||||||
|
|||||||
@ -46,7 +46,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
torch_dtype="auto",
|
dtype="auto",
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
```
|
```
|
||||||
|
|||||||
@ -82,7 +82,7 @@ Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models
|
|||||||
|
|
||||||
# Select model and load it
|
# Select model and load it
|
||||||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
|
||||||
# Select calibration dataset
|
# Select calibration dataset
|
||||||
|
|||||||
@ -50,7 +50,7 @@ to fetch model and tokenizer.
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
torch_dtype="auto",
|
dtype="auto",
|
||||||
)
|
)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ You can quantize your own huggingface model with torchao, e.g. [transformers](ht
|
|||||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype="auto",
|
dtype="auto",
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
quantization_config=quantization_config
|
quantization_config=quantization_config
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,7 +7,7 @@ requests >= 2.26.0
|
|||||||
tqdm
|
tqdm
|
||||||
blake3
|
blake3
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
transformers >= 4.55.2
|
transformers >= 4.56.0
|
||||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||||
protobuf # Required by LlamaTokenizer.
|
protobuf # Required by LlamaTokenizer.
|
||||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||||
|
|||||||
@ -334,7 +334,7 @@ class HfRunner:
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
self.device = self.get_default_device()
|
self.device = self.get_default_device()
|
||||||
self.dtype = torch_dtype = _get_and_verify_dtype(
|
self.dtype = dtype = _get_and_verify_dtype(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
self.config,
|
self.config,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -342,7 +342,7 @@ class HfRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||||
model_kwargs.setdefault("torch_dtype", torch_dtype)
|
model_kwargs.setdefault("dtype", dtype)
|
||||||
|
|
||||||
if is_sentence_transformer:
|
if is_sentence_transformer:
|
||||||
# Lazy init required for AMD CI
|
# Lazy init required for AMD CI
|
||||||
@ -388,7 +388,7 @@ class HfRunner:
|
|||||||
if not skip_tokenizer_init:
|
if not skip_tokenizer_init:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch_dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -398,7 +398,7 @@ class HfRunner:
|
|||||||
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch_dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if skip_tokenizer_init:
|
if skip_tokenizer_init:
|
||||||
|
|||||||
@ -38,7 +38,7 @@ def run_intern_vit_test(
|
|||||||
config.norm_type = "rms_norm"
|
config.norm_type = "rms_norm"
|
||||||
|
|
||||||
hf_model = AutoModel.from_pretrained(
|
hf_model = AutoModel.from_pretrained(
|
||||||
model, torch_dtype=torch_dtype, trust_remote_code=True
|
model, dtype=torch_dtype, trust_remote_code=True
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
hf_outputs_per_image = [
|
hf_outputs_per_image = [
|
||||||
hf_model(pixel_value.to("cuda")).last_hidden_state
|
hf_model(pixel_value.to("cuda")).last_hidden_state
|
||||||
|
|||||||
@ -45,7 +45,7 @@ def run_radio_test(
|
|||||||
hf_model = AutoModel.from_pretrained(
|
hf_model = AutoModel.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
config=config,
|
config=config,
|
||||||
torch_dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
hf_model.eval()
|
hf_model.eval()
|
||||||
|
|||||||
@ -251,7 +251,7 @@ def run_hf(
|
|||||||
disable_detokenize: bool = False,
|
disable_detokenize: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
llm = AutoModelForCausalLM.from_pretrained(
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
|
model, dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
if llm.config.model_type == "llama":
|
if llm.config.model_type == "llama":
|
||||||
# To enable padding in the HF backend.
|
# To enable padding in the HF backend.
|
||||||
|
|||||||
@ -1837,18 +1837,18 @@ def _find_dtype(
|
|||||||
*,
|
*,
|
||||||
revision: str | None,
|
revision: str | None,
|
||||||
):
|
):
|
||||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
||||||
# because config.torch_dtype can be None.
|
# because config.dtype can be None.
|
||||||
config_dtype = getattr(config, "torch_dtype", None)
|
config_dtype = getattr(config, "dtype", None)
|
||||||
|
|
||||||
# Fallbacks for multi-modal models if the root config
|
# Fallbacks for multi-modal models if the root config
|
||||||
# does not define torch_dtype
|
# does not define dtype
|
||||||
if config_dtype is None:
|
if config_dtype is None:
|
||||||
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
|
config_dtype = getattr(config.get_text_config(), "dtype", None)
|
||||||
if config_dtype is None and hasattr(config, "vision_config"):
|
if config_dtype is None and hasattr(config, "vision_config"):
|
||||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
config_dtype = getattr(config.vision_config, "dtype", None)
|
||||||
if config_dtype is None and hasattr(config, "encoder_config"):
|
if config_dtype is None and hasattr(config, "encoder_config"):
|
||||||
config_dtype = getattr(config.encoder_config, "torch_dtype", None)
|
config_dtype = getattr(config.encoder_config, "dtype", None)
|
||||||
|
|
||||||
# Try to read the dtype of the weights if they are in safetensors format
|
# Try to read the dtype of the weights if they are in safetensors format
|
||||||
if config_dtype is None:
|
if config_dtype is None:
|
||||||
|
|||||||
@ -117,9 +117,8 @@ class LLM:
|
|||||||
execution with tensor parallelism.
|
execution with tensor parallelism.
|
||||||
dtype: The data type for the model weights and activations. Currently,
|
dtype: The data type for the model weights and activations. Currently,
|
||||||
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
||||||
the `torch_dtype` attribute specified in the model config file.
|
the `dtype` attribute of the Transformers model's config. However,
|
||||||
However, if the `torch_dtype` in the config is `float32`, we will
|
if the `dtype` in the config is `float32`, we will use `float16` instead.
|
||||||
use `float16` instead.
|
|
||||||
quantization: The method used to quantize the model weights. Currently,
|
quantization: The method used to quantize the model weights. Currently,
|
||||||
we support "awq", "gptq", and "fp8" (experimental).
|
we support "awq", "gptq", and "fp8" (experimental).
|
||||||
If None, we first check the `quantization_config` attribute in the
|
If None, we first check the `quantization_config` attribute in the
|
||||||
|
|||||||
@ -518,7 +518,7 @@ def init_tensorizer_model(
|
|||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
assert tensorizer_config.hf_config is not None
|
assert tensorizer_config.hf_config is not None
|
||||||
model_args = tensorizer_config.hf_config
|
model_args = tensorizer_config.hf_config
|
||||||
model_args.torch_dtype = tensorizer_config.dtype
|
model_args.dtype = tensorizer_config.dtype
|
||||||
assert tensorizer_config.model_class is not None
|
assert tensorizer_config.model_class is not None
|
||||||
# TODO: Do we need to consider old-style model class?
|
# TODO: Do we need to consider old-style model class?
|
||||||
with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
|
with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
|
||||||
|
|||||||
@ -999,7 +999,7 @@ class ChameleonForConditionalGeneration(
|
|||||||
return []
|
return []
|
||||||
assert self.model.vqmodel is not None
|
assert self.model.vqmodel is not None
|
||||||
image_tokens = self.model.get_image_tokens(
|
image_tokens = self.model.get_image_tokens(
|
||||||
image_input["data"].to(self.config.torch_dtype)
|
image_input["data"].to(self.config.dtype)
|
||||||
)
|
)
|
||||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|||||||
@ -1089,7 +1089,7 @@ class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessin
|
|||||||
pixel_values = (
|
pixel_values = (
|
||||||
rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor
|
rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor
|
||||||
) / image_std_tensor
|
) / image_std_tensor
|
||||||
pixel_values = pixel_values.to(hf_config.torch_dtype)
|
pixel_values = pixel_values.to(hf_config.dtype)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def _call_hf_processor(
|
def _call_hf_processor(
|
||||||
|
|||||||
@ -615,7 +615,7 @@ class GLM4VForCausalLM(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
|
def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
|
||||||
pixel_values = image_input["data"].to(dtype=self.config.torch_dtype)
|
pixel_values = image_input["data"].to(dtype=self.config.dtype)
|
||||||
|
|
||||||
return self.transformer.vision(pixel_values)
|
return self.transformer.vision(pixel_values)
|
||||||
|
|
||||||
|
|||||||
@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
|
|||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
mla_scale_q_lora=False,
|
mla_scale_q_lora=False,
|
||||||
mla_scale_kv_lora=False,
|
mla_scale_kv_lora=False,
|
||||||
torch_dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
params_dtype="bfloat16",
|
params_dtype="bfloat16",
|
||||||
router_dtype="float32",
|
router_dtype="float32",
|
||||||
router_bias=False,
|
router_bias=False,
|
||||||
@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
|
|||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
torch_dtype=torch_dtype,
|
dtype=dtype,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
router_dtype=router_dtype,
|
router_dtype=router_dtype,
|
||||||
topk_method=topk_method,
|
topk_method=topk_method,
|
||||||
|
|||||||
@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
|
|||||||
prefix=maybe_prefix(prefix, "language_model"),
|
prefix=maybe_prefix(prefix, "language_model"),
|
||||||
)
|
)
|
||||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||||
self.language_model.config.torch_dtype
|
self.language_model.config.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# Construct the vision projection.
|
# Construct the vision projection.
|
||||||
@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
|
|||||||
ReLUSquaredActivation(),
|
ReLUSquaredActivation(),
|
||||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||||
)
|
)
|
||||||
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
|
self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
|
|||||||
@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
group_size=None,
|
group_size=None,
|
||||||
norm_before_gate=True,
|
norm_before_gate=True,
|
||||||
device=current_platform.current_device(),
|
device=current_platform.current_device(),
|
||||||
dtype=config.torch_dtype,
|
dtype=config.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(
|
self.out_proj = RowParallelLinear(
|
||||||
@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.torch_dtype,
|
dtype=config.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.ffn_layer_scale = torch.nn.Parameter(
|
self.ffn_layer_scale = torch.nn.Parameter(
|
||||||
@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.torch_dtype,
|
dtype=config.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
with init_on_device_without_buffers("meta"):
|
with init_on_device_without_buffers("meta"):
|
||||||
self.model: PreTrainedModel = AutoModel.from_config(
|
self.model: PreTrainedModel = AutoModel.from_config(
|
||||||
self.config,
|
self.config,
|
||||||
torch_dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
|
|||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
||||||
self.config,
|
self.config,
|
||||||
torch_dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
|
|||||||
return supported
|
return supported
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
if dtype == torch.bfloat16: # noqa: SIM102
|
||||||
if not cls.has_device_capability(80):
|
if not cls.has_device_capability(80):
|
||||||
capability = cls.get_device_capability()
|
capability = cls.get_device_capability()
|
||||||
gpu_name = cls.get_device_name()
|
gpu_name = cls.get_device_name()
|
||||||
|
|||||||
@ -563,7 +563,7 @@ class Platform:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||||
"""
|
"""
|
||||||
Check if the dtype is supported by the current platform.
|
Check if the dtype is supported by the current platform.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -484,8 +484,8 @@ class RocmPlatform(Platform):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
if dtype == torch.bfloat16: # noqa: SIM102
|
||||||
if not cls.has_device_capability(80):
|
if not cls.has_device_capability(80):
|
||||||
capability = cls.get_device_capability()
|
capability = cls.get_device_capability()
|
||||||
gpu_name = cls.get_device_name()
|
gpu_name = cls.get_device_name()
|
||||||
|
|||||||
@ -236,8 +236,8 @@ class XPUPlatform(Platform):
|
|||||||
return torch.xpu.device_count()
|
return torch.xpu.device_count()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
if dtype == torch.bfloat16: # noqa: SIM102
|
||||||
device_name = cls.get_device_name().lower()
|
device_name = cls.get_device_name().lower()
|
||||||
# client gpu a770
|
# client gpu a770
|
||||||
if device_name.count("a770") > 0:
|
if device_name.count("a770") > 0:
|
||||||
|
|||||||
@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
|
|||||||
|
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||||
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
|
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
|
||||||
assert cache_layout in ("NHD", "HND")
|
assert cache_layout in ("NHD", "HND")
|
||||||
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
|
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
|
||||||
@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
|
|||||||
|
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
key_value_cache = torch.empty(
|
key_value_cache = torch.empty(
|
||||||
size=kv_cache_allocation_shape, dtype=torch_dtype, device=device
|
size=kv_cache_allocation_shape, dtype=dtype, device=device
|
||||||
).permute(*stride_order)
|
).permute(*stride_order)
|
||||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||||
key_value_cache.uniform_(-scale, scale)
|
key_value_cache.uniform_(-scale, scale)
|
||||||
@ -851,14 +851,14 @@ def create_kv_caches_with_random(
|
|||||||
|
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||||
|
|
||||||
scale = head_size**-0.5
|
scale = head_size**-0.5
|
||||||
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
|
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||||
key_caches: list[torch.Tensor] = []
|
key_caches: list[torch.Tensor] = []
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
|
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
|
||||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||||
key_cache.uniform_(-scale, scale)
|
key_cache.uniform_(-scale, scale)
|
||||||
elif cache_dtype == "fp8":
|
elif cache_dtype == "fp8":
|
||||||
@ -870,9 +870,7 @@ def create_kv_caches_with_random(
|
|||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||||
value_caches: list[torch.Tensor] = []
|
value_caches: list[torch.Tensor] = []
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
value_cache = torch.empty(
|
value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
|
||||||
size=value_cache_shape, dtype=torch_dtype, device=device
|
|
||||||
)
|
|
||||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||||
value_cache.uniform_(-scale, scale)
|
value_cache.uniform_(-scale, scale)
|
||||||
elif cache_dtype == "fp8":
|
elif cache_dtype == "fp8":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user