diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index eb32c40aa3..45f6a5d144 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -35,7 +35,9 @@ def setup(rank, world_size): os.environ["MASTER_PORT"] = "12355" # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) + device_type = torch.accelerator.current_accelerator().type + backend = dist.get_default_backend_for_device(device_type) + dist.init_process_group(backend, rank=rank, world_size=world_size) def cleanup(): @@ -72,7 +74,9 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): # and modified def fsdp_main(rank, world_size, args): setup(rank, world_size) - torch.cuda.set_device(rank) + device_type = torch.accelerator.current_accelerator().type + torch.accelerator.set_device_index(rank) + device = f"{device_type}:{rank}" (emulate,) = args @@ -81,20 +85,20 @@ def fsdp_main(rank, world_size, args): # regardless of float8. model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to( - rank + device ) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) - input_local = torch.randn(B, M, K, N, device="cuda") + input_local = torch.randn(B, M, K, N, device=device) model = torch.compile(model) for _iter in range(N_ITER): optimizer.zero_grad() - with torch.autocast("cuda"): + with torch.autocast(device_type): y_local = model(input_local) y_local.sum().backward() optimizer.step() @@ -105,17 +109,17 @@ def fsdp_main(rank, world_size, args): def run(): emulate = False - if not torch.cuda.is_available(): - warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2) + if not torch.accelerator.is_available(): + warnings.warn("GPU not available, running in emulation_mode", stacklevel=2) emulate = True - elif torch.cuda.get_device_capability() < (9, 0): + elif torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0): warnings.warn( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode", stacklevel=2, ) emulate = True - WORLD_SIZE = torch.cuda.device_count() + WORLD_SIZE = torch.accelerator.device_count() args = (emulate,) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/test/float8/test_fsdp_compile.sh b/test/float8/test_fsdp_compile.sh index 192586cee6..2f01911413 100755 --- a/test/float8/test_fsdp_compile.sh +++ b/test/float8/test_fsdp_compile.sh @@ -7,10 +7,10 @@ # terminate script on first error set -e -if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then - echo "Skipping test_fsdp_compile.sh because no CUDA devices are available." +if python -c 'import torch;print(torch.accelerator.is_available())' | grep -q "False"; then + echo "Skipping test_fsdp_compile.sh because no accelerator devices are available." exit fi # Code to be executed if CUDA devices are available -NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py +CCL_LOG_LEVEL=info NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py