From d19fc0d3e13004dda41a02a0db619c07882c393e Mon Sep 17 00:00:00 2001 From: Ula Golowicz Date: Wed, 24 Jun 2026 12:46:08 +0000 Subject: [PATCH] [xpu][mx][test] Enable multicard tests for xpu Signed-off-by: Ula Golowicz --- test/prototype/mx_formats/test_mx_dtensor.py | 28 +++++++--- test/prototype/mx_formats/test_mx_dtensor.sh | 22 +++++--- .../mx_formats/test_mxfp8_allgather.py | 54 ++++++++++++------- .../mx_formats/test_mxfp8_allgather.sh | 20 ++++--- torchao/prototype/mx_formats/kernels.py | 11 ++-- torchao/utils.py | 4 ++ 6 files changed, 96 insertions(+), 43 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py index 8081baf40d..f892c427b9 100644 --- a/test/prototype/mx_formats/test_mx_dtensor.py +++ b/test/prototype/mx_formats/test_mx_dtensor.py @@ -11,6 +11,7 @@ """ import os +import traceback import torch from torch.distributed._tensor import DTensor, Shard, distribute_tensor @@ -31,14 +32,20 @@ def setup_distributed(): world_size = int(os.environ.get("WORLD_SIZE", -1)) - device_mesh = init_device_mesh("cuda", (world_size,)) + device = torch.accelerator.current_accelerator() + device_mesh = init_device_mesh(device.type, (world_size,)) # seed must be the same in all processes torch.manual_seed(1) local_rank = torch.distributed.get_rank() - torch.cuda.set_device(local_rank) + torch.accelerator.set_device_index(local_rank) return device_mesh +def print_once(msg): + if torch.distributed.get_rank() == 0: + print(msg) + + def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=1024): device = mesh.device_type @@ -105,20 +112,27 @@ def _test_mxfp8_mlp_tensor_parallelism_auto(mesh: DeviceMesh, size=64): _test_dtensor_cast_to_mxfp8, _test_mxfp8_mlp_tensor_parallelism_emulated, ] + from torchao.prototype.moe_training.kernels.mxfp8.quant import ( _mxfp8_cuda_kernels_available, ) - if _mxfp8_cuda_kernels_available: + if device_mesh.device_type == "cuda" and _mxfp8_cuda_kernels_available: tests.append(_test_mxfp8_mlp_tensor_parallelism_auto) else: - print("Skipping auto test: requires SM >= 100 and CUDA >= 12.8") + print_once("Skipping auto test: requires CUDA SM >= 100 and CUDA >= 12.8") - for test in tqdm(tests, desc="Running tests"): + failed_cnt = 0 + for test in tqdm(tests, desc="Running tests", disable=torch.distributed.get_rank() != 0): try: test(device_mesh) except Exception as e: - print(f"Test {test.__name__} failed with error: {e}") - raise e + print_once(f"\033[31m❌ FAILED {test.__name__}: {e}\033[0m") + print_once(traceback.format_exc()) + failed_cnt += 1 + else: + print_once(f"\033[32m✅ PASSED {test.__name__}\033[0m") + + print_once(f"FAILED: {failed_cnt} PASSED: {len(tests) - failed_cnt}") torch.distributed.destroy_process_group() diff --git a/test/prototype/mx_formats/test_mx_dtensor.sh b/test/prototype/mx_formats/test_mx_dtensor.sh index abf9424e3c..596d0f79e1 100755 --- a/test/prototype/mx_formats/test_mx_dtensor.sh +++ b/test/prototype/mx_formats/test_mx_dtensor.sh @@ -1,17 +1,25 @@ +#!/bin/bash # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -#!/bin/bash # terminate script on first error -set -e +set -euo pipefail -if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then - echo "Skipping test_dtensor.sh because no CUDA devices are available." - exit -fi +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # integration tests for TP/SP -NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py +if python -c 'import torch; assert torch.cuda.is_available()' 2>/dev/null; then + echo "CUDA available, proceeding with test." + exec env NCCL_DEBUG=WARN torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mx_dtensor.py" + +elif python -c 'import torch; assert torch.xpu.is_available()' 2>/dev/null; then + echo "XPU available, proceeding with test." + exec torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mx_dtensor.py" + +else + echo "Skipping test_mx_dtensor.sh because no CUDA or XPU devices are available." + exit 0 +fi diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py index 698bbf6340..66338a1d63 100644 --- a/test/prototype/mx_formats/test_mxfp8_allgather.py +++ b/test/prototype/mx_formats/test_mxfp8_allgather.py @@ -1,31 +1,38 @@ +import traceback + import torch import torch.distributed as dist from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.utils import is_sm_at_least_90 def setup_distributed(): - dist.init_process_group("nccl") + device = torch.accelerator.current_accelerator() + dist.init_process_group() # seed must be the same in all processes torch.manual_seed(42) local_rank = torch.distributed.get_rank() - torch.cuda.set_device(local_rank) - return local_rank + torch.accelerator.set_device_index(local_rank) + return torch.device(device.type, local_rank) + + +def print_once(msg): + if torch.distributed.get_rank() == 0: + print(msg) -def _test_allgather(local_rank): +def _test_allgather(device): golden_qdata = ( torch.randint(0, 256, (256, 512), dtype=torch.uint8) .to(torch.float8_e5m2) - .to(local_rank) + .to(device) ) # Random scale factors (typically float32 or uint8 for e8m0) golden_scale = ( torch.randint(0, 256, (256, 16), dtype=torch.uint8) .view(torch.float8_e8m0fnu) - .to(local_rank) + .to(device) ) # Create golden MXTensor @@ -37,7 +44,7 @@ def _test_allgather(local_rank): orig_dtype=torch.float32, kernel_preference=None, act_quant_kwargs=None, - is_swizzled_scales=None, + is_swizzled_scales=False, ) local_rank = torch.distributed.get_rank() @@ -50,14 +57,14 @@ def _test_allgather(local_rank): # Create local MXTensor from shard local_mx = MXTensor( - golden_qdata[start_idx:end_idx].clone().to(local_rank), - golden_scale[start_idx:end_idx].clone().to(local_rank), + golden_qdata[start_idx:end_idx].clone().to(device), + golden_scale[start_idx:end_idx].clone().to(device), elem_dtype=torch.float8_e5m2, block_size=32, orig_dtype=torch.float32, kernel_preference=None, act_quant_kwargs=None, - is_swizzled_scales=None, + is_swizzled_scales=False, ) # Perform all_gather @@ -93,13 +100,22 @@ def _test_allgather(local_rank): if __name__ == "__main__": - local_rank = setup_distributed() - - assert is_sm_at_least_90() == True, "SM must be > 9.0" - - try: - _test_allgather(local_rank) - except Exception as e: - raise e + device = setup_distributed() + tests = [ + _test_allgather, + ] + + failed_cnt = 0 + for test in tests: + try: + test(device) + except Exception as e: + print_once(f"\033[31m\u274c FAILED {test.__name__}: {e}\033[0m") + print_once(traceback.format_exc()) + failed_cnt += 1 + else: + print_once(f"\033[32m\u2705 PASSED {test.__name__}\033[0m") + + print_once(f"FAILED: {failed_cnt} PASSED: {len(tests) - failed_cnt}") torch.distributed.destroy_process_group() diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.sh b/test/prototype/mx_formats/test_mxfp8_allgather.sh index 180375af40..a5c6cc4986 100755 --- a/test/prototype/mx_formats/test_mxfp8_allgather.sh +++ b/test/prototype/mx_formats/test_mxfp8_allgather.sh @@ -1,12 +1,20 @@ #!/bin/bash # terminate script on first error -set -e +set -euo pipefail -if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then - echo "Skipping test_dtensor.sh because no CUDA devices are available." - exit -fi +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # integration tests for TP/SP -NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py \ No newline at end of file +if python -c 'import torch; assert torch.cuda.is_available()' 2>/dev/null; then + echo "CUDA available, proceeding with test." + exec env NCCL_DEBUG=WARN torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mxfp8_allgather.py" + +elif python -c 'import torch; assert torch.xpu.is_available()' 2>/dev/null; then + echo "XPU available, proceeding with test." + exec torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mxfp8_allgather.py" + +else + echo "Skipping test_mxfp8_allgather.sh because no CUDA or XPU devices are available." + exit 0 +fi diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 5b1930c22c..ab56de44eb 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -25,6 +25,7 @@ is_mslk_version_at_least, is_ROCM, is_sm_at_least_100, + is_XPU, ) logger = logging.getLogger(__name__) @@ -455,11 +456,13 @@ def triton_mxfp8_dequant_dim0( raise AssertionError("needs triton") -_triton_kernels_available = ( - has_triton() - and torch.cuda.is_available() - and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8)) +_triton_kernels_available = has_triton() and ( + ( + torch.cuda.is_available() + and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8)) + ) or (is_ROCM() and is_MI350()) + or is_XPU() ) if _triton_kernels_available: diff --git a/torchao/utils.py b/torchao/utils.py index b9cfb78b06..78fd3be779 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -1198,6 +1198,10 @@ def _cpu_is_vnni_supported() -> bool: return False +def is_XPU(): + return hasattr(torch, "xpu") and torch.xpu.is_available() + + def should_reduce_range(device: torch.device) -> bool: """ Helper to determine if int8 tensor quantization range should be reduced