Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions test/prototype/mx_formats/test_mx_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

import os
import traceback

import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
Expand All @@ -31,14 +32,20 @@

def setup_distributed():
world_size = int(os.environ.get("WORLD_SIZE", -1))
device_mesh = init_device_mesh("cuda", (world_size,))
device = torch.accelerator.current_accelerator()
device_mesh = init_device_mesh(device.type, (world_size,))
# seed must be the same in all processes
torch.manual_seed(1)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
torch.accelerator.set_device_index(local_rank)
return device_mesh


def print_once(msg):
if torch.distributed.get_rank() == 0:
print(msg)


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=1024):
device = mesh.device_type

Expand Down Expand Up @@ -105,20 +112,27 @@ def _test_mxfp8_mlp_tensor_parallelism_auto(mesh: DeviceMesh, size=64):
_test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism_emulated,
]

from torchao.prototype.moe_training.kernels.mxfp8.quant import (
_mxfp8_cuda_kernels_available,
)

if _mxfp8_cuda_kernels_available:
if device_mesh.device_type == "cuda" and _mxfp8_cuda_kernels_available:
tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)
else:
print("Skipping auto test: requires SM >= 100 and CUDA >= 12.8")
print_once("Skipping auto test: requires CUDA SM >= 100 and CUDA >= 12.8")

for test in tqdm(tests, desc="Running tests"):
failed_cnt = 0
for test in tqdm(tests, desc="Running tests", disable=torch.distributed.get_rank() != 0):
try:
test(device_mesh)
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e
print_once(f"\033[31m❌ FAILED {test.__name__}: {e}\033[0m")
print_once(traceback.format_exc())
failed_cnt += 1
else:
print_once(f"\033[32m✅ PASSED {test.__name__}\033[0m")

print_once(f"FAILED: {failed_cnt} PASSED: {len(tests) - failed_cnt}")

torch.distributed.destroy_process_group()
22 changes: 15 additions & 7 deletions test/prototype/mx_formats/test_mx_dtensor.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#!/bin/bash

# terminate script on first error
set -e
set -euo pipefail

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_dtensor.sh because no CUDA devices are available."
exit
fi
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
if python -c 'import torch; assert torch.cuda.is_available()' 2>/dev/null; then
echo "CUDA available, proceeding with test."
exec env NCCL_DEBUG=WARN torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mx_dtensor.py"

elif python -c 'import torch; assert torch.xpu.is_available()' 2>/dev/null; then
echo "XPU available, proceeding with test."
exec torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mx_dtensor.py"

else
echo "Skipping test_mx_dtensor.sh because no CUDA or XPU devices are available."
exit 0
fi
54 changes: 35 additions & 19 deletions test/prototype/mx_formats/test_mxfp8_allgather.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
import traceback

import torch
import torch.distributed as dist

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import is_sm_at_least_90


def setup_distributed():
dist.init_process_group("nccl")
device = torch.accelerator.current_accelerator()
dist.init_process_group()
# seed must be the same in all processes
torch.manual_seed(42)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
return local_rank
torch.accelerator.set_device_index(local_rank)
return torch.device(device.type, local_rank)


def print_once(msg):
if torch.distributed.get_rank() == 0:
print(msg)


def _test_allgather(local_rank):
def _test_allgather(device):
golden_qdata = (
torch.randint(0, 256, (256, 512), dtype=torch.uint8)
.to(torch.float8_e5m2)
.to(local_rank)
.to(device)
)

# Random scale factors (typically float32 or uint8 for e8m0)
golden_scale = (
torch.randint(0, 256, (256, 16), dtype=torch.uint8)
.view(torch.float8_e8m0fnu)
.to(local_rank)
.to(device)
)

# Create golden MXTensor
Expand All @@ -37,7 +44,7 @@ def _test_allgather(local_rank):
orig_dtype=torch.float32,
kernel_preference=None,
act_quant_kwargs=None,
is_swizzled_scales=None,
is_swizzled_scales=False,
)

local_rank = torch.distributed.get_rank()
Expand All @@ -50,14 +57,14 @@ def _test_allgather(local_rank):

# Create local MXTensor from shard
local_mx = MXTensor(
golden_qdata[start_idx:end_idx].clone().to(local_rank),
golden_scale[start_idx:end_idx].clone().to(local_rank),
golden_qdata[start_idx:end_idx].clone().to(device),
golden_scale[start_idx:end_idx].clone().to(device),
elem_dtype=torch.float8_e5m2,
block_size=32,
orig_dtype=torch.float32,
kernel_preference=None,
act_quant_kwargs=None,
is_swizzled_scales=None,
is_swizzled_scales=False,
)

# Perform all_gather
Expand Down Expand Up @@ -93,13 +100,22 @@ def _test_allgather(local_rank):


if __name__ == "__main__":
local_rank = setup_distributed()

assert is_sm_at_least_90() == True, "SM must be > 9.0"

try:
_test_allgather(local_rank)
except Exception as e:
raise e
device = setup_distributed()
tests = [
_test_allgather,
]

failed_cnt = 0
for test in tests:
try:
test(device)
except Exception as e:
print_once(f"\033[31m\u274c FAILED {test.__name__}: {e}\033[0m")
print_once(traceback.format_exc())
failed_cnt += 1
else:
print_once(f"\033[32m\u2705 PASSED {test.__name__}\033[0m")

print_once(f"FAILED: {failed_cnt} PASSED: {len(tests) - failed_cnt}")

torch.distributed.destroy_process_group()
20 changes: 14 additions & 6 deletions test/prototype/mx_formats/test_mxfp8_allgather.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
#!/bin/bash

# terminate script on first error
set -e
set -euo pipefail

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_dtensor.sh because no CUDA devices are available."
exit
fi
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py
if python -c 'import torch; assert torch.cuda.is_available()' 2>/dev/null; then
echo "CUDA available, proceeding with test."
exec env NCCL_DEBUG=WARN torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mxfp8_allgather.py"

elif python -c 'import torch; assert torch.xpu.is_available()' 2>/dev/null; then
echo "XPU available, proceeding with test."
exec torchrun --nproc_per_node 2 "${SCRIPT_DIR}/test_mxfp8_allgather.py"

else
echo "Skipping test_mxfp8_allgather.sh because no CUDA or XPU devices are available."
exit 0
fi
11 changes: 7 additions & 4 deletions torchao/prototype/mx_formats/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
is_mslk_version_at_least,
is_ROCM,
is_sm_at_least_100,
is_XPU,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -455,11 +456,13 @@ def triton_mxfp8_dequant_dim0(
raise AssertionError("needs triton")


_triton_kernels_available = (
has_triton()
and torch.cuda.is_available()
and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8))
_triton_kernels_available = has_triton() and (
(
torch.cuda.is_available()
and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8))
)
or (is_ROCM() and is_MI350())
or is_XPU()
)

if _triton_kernels_available:
Expand Down
4 changes: 4 additions & 0 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,10 @@ def _cpu_is_vnni_supported() -> bool:
return False


def is_XPU():
return hasattr(torch, "xpu") and torch.xpu.is_available()


def should_reduce_range(device: torch.device) -> bool:
"""
Helper to determine if int8 tensor quantization range should be reduced
Expand Down
Loading