Fix env vars for running Ray distributed backend on GKE (#15166)

Signed-off-by: Richard Liu <ricliu@google.com>
Richard Liu 2025-03-20 07:59:33 -07:00 committed by GitHub
parent 69ae2380c6
commit a8f12a63fd
3 changed files with 8 additions and 0 deletions

vllm/executor/ray_distributed_executor.py

@@ -340,6 +340,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
                 and v not in self.non_carry_over_env_vars
             ]
 
+        env_vars_to_copy.extend(current_platform.additional_env_vars)
+
         # Copy existing env vars to each worker's args
         for args in all_args_to_update_environment_variables:
             # TODO: refactor platform-specific env vars

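A rough, self-contained sketch of the copy path this hunk feeds into: the platform's extra variables join the copy list, and the existing loop then mirrors the driver's current values into each worker's environment-update arguments. The sample variable names, the worker count, and the dict-per-worker shape below are assumptions for illustration, not vLLM's exact internals.

# Sketch only (not vLLM's actual executor code): propagate driver env vars,
# including platform-specific extras, to per-worker env-update args.
import os

# Assumed stand-ins for vLLM internals; names and values are illustrative.
environment_variables = ["VLLM_LOGGING_LEVEL", "VLLM_USE_V1"]
additional_env_vars = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

env_vars_to_copy = list(environment_variables)
env_vars_to_copy.extend(additional_env_vars)

# One dict of environment overrides per Ray worker.
all_args_to_update_environment_variables = [{} for _ in range(4)]
for args in all_args_to_update_environment_variables:
    for name in env_vars_to_copy:
        if name in os.environ:
            args[name] = os.environ[name]

print(all_args_to_update_environment_variables)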
vllm/platforms/interface.py

@@ -112,6 +112,8 @@ class Platform:
     supported_quantization: list[str] = []
 
+    additional_env_vars: list[str] = []
+
     def is_cuda(self) -> bool:
         return self._enum == PlatformEnum.CUDA

vllm/platforms/tpu.py

@@ -29,6 +29,10 @@ class TpuPlatform(Platform):
         "tpu_int8", "compressed-tensors", "compressed_tensors"
     ]
 
+    additional_env_vars: list[str] = [
+        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
+    ]
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
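
On GKE the TPU topology variables are set on the host by the node environment rather than by the user, so the failure mode this commit targets is workers that never see them. A quick, illustrative way to inspect what a Ray worker process actually observes is a remote task that reads the variables back; this uses only public Ray APIs and is not part of the change.

# Illustrative spot-check, not part of this change: report which TPU topology
# variables are visible inside a Ray worker process. Without the copy added
# above, workers launched by the Ray backend may not inherit them.
import os

import ray


@ray.remote
def tpu_env_snapshot() -> dict:
    return {
        name: os.environ.get(name)
        for name in ("TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS")
    }


if __name__ == "__main__":
    ray.init()
    print(ray.get([tpu_env_snapshot.remote() for _ in range(2)]))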