diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 422cb94b036c..f6783704342f 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -62,8 +62,12 @@ class TestSetting: TestSetting( model="BAAI/bge-multilingual-gemma2", model_args=[ - "--runner", "pooling", "--dtype", "bfloat16", - "--max-model-len", "2048" + "--runner", + "pooling", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, @@ -71,17 +75,15 @@ class TestSetting: method="encode", fullgraph=True, ), - # TODO: bert models are not supported in V1 yet - # # encoder-based embedding model (BERT) - # TestSetting( - # model="BAAI/bge-base-en-v1.5", - # model_args=["--runner", "pooling"], - # pp_size=1, - # tp_size=1, - # attn_backend="XFORMERS", - # method="encode", - # fullgraph=True, - # ), + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--runner", "pooling"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -92,7 +94,8 @@ class TestSetting: method="generate_with_image", fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting,