diff --git a/utilities/cli/container.py b/utilities/cli/container.py index 1448fc12b6..ac72c33ed1 100644 --- a/utilities/cli/container.py +++ b/utilities/cli/container.py @@ -20,6 +20,7 @@ import re import shlex import shutil +import socket import stat import subprocess import sys @@ -725,6 +726,27 @@ def get_gpu_runtime_args(self) -> List[str]: args.extend(self.get_device_cgroup_args()) return args + def is_valid_endpoint(self) -> bool: + """Check if SCCACHE_MEMCACHED_ENDPOINT is valid""" + endpoint = os.environ.get("SCCACHE_MEMCACHED_ENDPOINT") + + if endpoint: + try: + host, port_str = endpoint.rsplit(":", 1) + port = int(port_str) + + with socket.create_connection((host, port), timeout=5): + info(f" > Using memcached endpoint {endpoint}") + return True + + except Exception: + warn( + f" > Memcached endpoint {endpoint} is not reachable, " + "falling back to local caching." + ) + + return False + def get_environment_args(self) -> List[str]: """Environment variable arguments""" # Default GPU visibility is controlled via NVIDIA_VISIBLE_DEVICES (from the image and/or @@ -762,7 +784,9 @@ def get_environment_args(self) -> List[str]: args.extend(["-e", f"SCCACHE_DIR={SCCACHE_CONTAINER_DIR}"]) # Forward other SCCACHE_* environment variables present on host for k in sccache_keys: - if k != "SCCACHE_DIR": + if (k != "SCCACHE_DIR") and ( + k != "SCCACHE_MEMCACHED_ENDPOINT" or self.is_valid_endpoint() + ): args.extend(["-e", k]) elif len(sccache_keys) > 0: warn(