[BugFix] Fix GGUF tp>1 when vocab_size is not divisible by 64 (#12230)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-01-21 05:23:14 +01:00 committed by GitHub
parent d4b62d4641
commit 5fe6bf29d6
2 changed files with 12 additions and 2 deletions
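
Background: VocabParallelEmbedding pads the vocabulary up to a multiple of 64 (DEFAULT_VOCAB_PADDING_SIZE) before splitting it across tensor-parallel ranks, so each rank expects a shard of size padded_vocab // tp_size. The GGUF path lazily materializes the embedding from the checkpoint shape and previously sized it as org_vocab // tp_size, which disagrees with the expected shard size whenever the vocabulary is not already a multiple of 64. A minimal sketch of the mismatch, using hypothetical numbers (the concrete vocab size below is illustrative only):

    # Illustrative numbers: a vocab that is not a multiple of 64,
    # sharded across two tensor-parallel ranks.
    def pad_vocab(vocab_size: int, pad_to: int = 64) -> int:
        """Round vocab_size up to the next multiple of pad_to."""
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    org_vocab_size = 32002              # hypothetical GGUF vocab size
    tp_size = 2

    padded = pad_vocab(org_vocab_size)  # 32064
    per_partition = padded // tp_size   # 16032 -- shard size each rank expects
    naive = org_vocab_size // tp_size   # 16001 -- shard size the old code materialized
    assert per_partition != naive       # shape mismatch when loading with tp > 1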

@@ -66,12 +66,20 @@ STARCODER_CONFIG = GGUFTestConfig(
     gguf_filename="starcoder2-3b.Q6_K.gguf",
 )

+DOLPHIN_CONFIG = GGUFTestConfig(
+    # Test VocabParallelEmbedding sharding issue.
+    original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
+    gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
+    gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
+)
+
 MODELS = [
     LLAMA_CONFIG,
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
     STABLELM_CONFIG,
+    DOLPHIN_CONFIG
     # STARCODER_CONFIG, # broken
 ]
@@ -107,6 +115,7 @@ def test_models(
     # Run unquantized model.
     with vllm_runner(model_name=model.original_model,
+                     enforce_eager=True, # faster tests
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
@@ -115,6 +124,7 @@ def test_models(
     # Run gguf model.
     with vllm_runner(model_name=model.gguf_model,
+                     enforce_eager=True,
                      tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
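
For reference, only the three fields used above (original_model, gguf_repo, gguf_filename) are visible in this diff; a plausible sketch of the test helper is shown below, where the gguf_model property and the hf_hub_download resolution are assumptions rather than the actual test code:

    # Hypothetical sketch of the test helper; only the three fields are
    # taken from the diff above, the rest is assumed.
    from dataclasses import dataclass

    from huggingface_hub import hf_hub_download


    @dataclass
    class GGUFTestConfig:
        original_model: str  # HF repo of the unquantized reference model
        gguf_repo: str       # HF repo hosting the GGUF quantization
        gguf_filename: str   # specific GGUF file under test

        @property
        def gguf_model(self) -> str:
            # Resolve the quantized checkpoint to a local file path.
            return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)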

@@ -355,7 +355,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         elif isinstance(param, UninitializedParameter):
             shape = list(loaded_weight.shape)
             if output_dim is not None:
-                shape[output_dim] = shape[output_dim] // self.tp_size
+                shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(tuple(shape), dtype=loaded_weight.dtype)

         # If parameter does not have output dim, then it should
@@ -381,7 +381,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size

-        # Copy the data.
+        # Copy the data. Select chunk corresponding to current shard.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

         if current_platform.is_hpu():
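
The fix above materializes the lazily-allocated GGUF embedding with num_embeddings_per_partition rows (the padded per-rank size) instead of the raw checkpoint shape divided by tp_size, so the narrow/copy of the current shard has a consistently sized destination even when a rank's slice extends past the unpadded vocabulary. A standalone sketch of that idea follows; it is not the vLLM weight_loader, and the helper name, the shard-size clamp, and the zero-fill of padded rows are assumptions:

    # Standalone illustration, not the vLLM implementation.
    import torch
    from torch.nn.parameter import UninitializedParameter


    def materialize_embedding_shard(loaded_weight: torch.Tensor,
                                    num_embeddings_per_partition: int,
                                    start_idx: int,
                                    output_dim: int = 0) -> torch.nn.Parameter:
        param = UninitializedParameter()
        shape = list(loaded_weight.shape)
        # Size the shard from the padded per-rank count, not
        # loaded_weight.shape[output_dim] // tp_size, which under-sizes it
        # when the vocab is not a multiple of the padding (e.g. 64).
        shape[output_dim] = num_embeddings_per_partition
        param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # Copy only the rows this rank actually owns; padded rows stay zero.
        shard_size = max(
            0, min(num_embeddings_per_partition,
                   loaded_weight.shape[output_dim] - start_idx))
        param.data.zero_()
        if shard_size > 0:
            param.data.narrow(output_dim, 0, shard_size).copy_(
                loaded_weight.narrow(output_dim, start_idx, shard_size))
        return param

With the hypothetical 32002-token vocabulary from the earlier sketch and tp_size=2, rank 1 would copy 15970 real rows (indices 16032..32001) and leave its final 62 padded rows zeroed.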