[Bugfix] Fix typo in Pallas backend (#5558)

This commit is contained in:
Woosuk Kwon 2024-06-14 14:40:09 -07:00 committed by GitHub
parent e2afb03c92
commit 28c145eb57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower()
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
if not tpu_type.endswith("lite"):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"