Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-23 13:24:28 +08:00)
[Bugfix] Use runner_type instead of task in GritLM (#11144)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
parent
30870b4f66
commit
1efce68605
@@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
|
|||||||
from vllm.model_executor.models.gritlm import GritLMPooler
|
from vllm.model_executor.models.gritlm import GritLMPooler
|
||||||
|
|
||||||
# Create an LLM object to get the model config.
|
# Create an LLM object to get the model config.
|
||||||
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
|
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
|
||||||
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
|
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
|
||||||
|
|
||||||
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
@@ -55,7 +55,7 @@ def server_embedding():
|
|||||||
with pytest.MonkeyPatch.context() as mp:
|
with pytest.MonkeyPatch.context() as mp:
|
||||||
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
|
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
|
||||||
|
|
||||||
args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
|
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||||
yield remote_server
|
yield remote_server
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):
|
|||||||
|
|
||||||
queries, q_instruction, documents, d_instruction = get_test_data()
|
queries, q_instruction, documents, d_instruction = get_test_data()
|
||||||
|
|
||||||
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
|
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
|
||||||
|
|
||||||
d_rep = run_llm_encode(
|
d_rep = run_llm_encode(
|
||||||
llm,
|
llm,
|
||||||
|
|||||||
@@ -203,12 +203,12 @@ class GritLM(LlamaForCausalLM):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
|
|
||||||
self.task = vllm_config.model_config.task
|
self.runner_type = vllm_config.model_config.runner_type
|
||||||
|
|
||||||
self._pooler = GritLMPooler(vllm_config.model_config)
|
self._pooler = GritLMPooler(vllm_config.model_config)
|
||||||
|
|
||||||
for layer in self.model.layers:
|
for layer in self.model.layers:
|
||||||
if self.task == "embedding" and hasattr(layer, "self_attn"):
|
if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
|
||||||
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
|
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
|
||||||
"GritLM embedding is only supported by XFormers backend, "
|
"GritLM embedding is only supported by XFormers backend, "
|
||||||
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
|
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
|
||||||
@@ -222,8 +222,8 @@ class GritLM(LlamaForCausalLM):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
|
||||||
# Change attention to non-causal for embedding task.
|
# Change attention to non-causal for pooling tasks.
|
||||||
if self.task == "embedding":
|
if self.runner_type == "pooling":
|
||||||
assert attn_metadata.prefill_metadata.attn_bias is None
|
assert attn_metadata.prefill_metadata.attn_bias is None
|
||||||
attn_metadata.prefill_metadata.attn_bias = [
|
attn_metadata.prefill_metadata.attn_bias = [
|
||||||
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
|
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user